dylanplummer commited on
Commit
d744c30
·
1 Parent(s): 1c4dfa7

add mark period reset

Browse files
Files changed (1) hide show
  1. app.py +11 -5
app.py CHANGED
@@ -3,14 +3,12 @@ import numpy as np
3
  from PIL import Image
4
  import os
5
  import cv2
6
- import uuid
7
- import time
8
  import spaces
9
- import subprocess
10
  import matplotlib
11
  matplotlib.use('Agg')
12
  import matplotlib.pyplot as plt
13
- from scipy.signal import medfilt
14
  from functools import partial
15
  from passlib.hash import pbkdf2_sha256
16
  from tqdm import tqdm
@@ -195,13 +193,21 @@ def inference(x, count_only_api, api_key,
195
  periodLength = medfilt(periodLength, 5)
196
  periodicity = sigmoid(periodicity)
197
  full_marks = sigmoid(full_marks)
198
- full_marks_mask = np.int32(full_marks > marks_threshold)
 
 
 
199
  periodicity_mask = np.int32(periodicity > miss_threshold)
200
  numofReps = 0
201
  count = []
202
  for i in range(len(periodLength)):
203
  if periodLength[i] < 2 or periodicity_mask[i] == 0:
204
  numofReps += 0
 
 
 
 
 
205
  else:
206
  numofReps += max(0, periodicity_mask[i]/(periodLength[i]))
207
  count.append(round(float(numofReps), 2))
 
3
  from PIL import Image
4
  import os
5
  import cv2
6
+ import math
 
7
  import spaces
 
8
  import matplotlib
9
  matplotlib.use('Agg')
10
  import matplotlib.pyplot as plt
11
+ from scipy.signal import medfilt, find_peaks
12
  from functools import partial
13
  from passlib.hash import pbkdf2_sha256
14
  from tqdm import tqdm
 
193
  periodLength = medfilt(periodLength, 5)
194
  periodicity = sigmoid(periodicity)
195
  full_marks = sigmoid(full_marks)
196
+ #full_marks_mask = np.int32(full_marks > marks_threshold)
197
+ pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
198
+ full_marks_mask = np.zeros(len(full_marks))
199
+ full_marks_mask[pred_marks_peaks] = 1
200
  periodicity_mask = np.int32(periodicity > miss_threshold)
201
  numofReps = 0
202
  count = []
203
  for i in range(len(periodLength)):
204
  if periodLength[i] < 2 or periodicity_mask[i] == 0:
205
  numofReps += 0
206
+ elif full_marks_mask[i]: # high confidence mark detected
207
+ if math.modf(numofReps)[0] < 0.2: # probably false positive/late detection
208
+ numofReps = float(int(numofReps))
209
+ else:
210
+ numofReps = float(int(numofReps) + 1.01) # round up
211
  else:
212
  numofReps += max(0, periodicity_mask[i]/(periodLength[i]))
213
  count.append(round(float(numofReps), 2))