dylanplummer commited on
Commit
319f52e
·
1 Parent(s): 57925f4

add conf and theme

Browse files
Files changed (2) hide show
  1. app.py +15 -10
  2. requirements.txt +0 -1
app.py CHANGED
@@ -11,7 +11,6 @@ import matplotlib
11
  matplotlib.use('Agg')
12
  import matplotlib.pyplot as plt
13
  from scipy.signal import medfilt
14
- from skimage.measure import block_reduce
15
  from functools import partial
16
  from passlib.hash import pbkdf2_sha256
17
  from tqdm import tqdm
@@ -183,13 +182,11 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
183
 
184
  if median_pred_filter:
185
  periodicity = medfilt(periodicity, 5)
186
- #periodLength = medfilt(periodLength, 5)
187
  periodicity = sigmoid(periodicity)
188
  full_marks = sigmoid(full_marks)
189
  full_marks_mask = np.int32(full_marks > marks_threshold)
190
- #full_marks_reduced = block_reduce(full_marks > marks_threshold, (3,), np.max)
191
  periodicity_mask = np.int32(periodicity > miss_threshold)
192
- #periodicity_mask_reduced = block_reduce(periodicity_mask, (3,), np.max)
193
  numofReps = 0
194
  count = []
195
  for i in range(len(periodLength)):
@@ -209,10 +206,15 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
209
  marks_count_pred = marks_count_pred / 2
210
  count = np.array(count) / 2
211
 
 
 
 
 
 
212
  if both_feet:
213
- count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}"
214
  else:
215
- count_msg = f"## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}"
216
 
217
  if api_call:
218
  if count_only_api:
@@ -229,20 +231,23 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
229
  jumping_speed = np.copy(jumps_per_second)
230
  misses = periodicity < miss_threshold
231
  jumps_per_second[misses] = 0
 
 
232
  df = pd.DataFrame.from_dict({'period length': periodLength,
233
  'jumping speed': jumping_speed,
234
  'jumps per second': jumps_per_second,
235
  'periodicity': periodicity,
236
  'miss': misses,
 
237
  'jumps': full_marks,
238
- 'jumps_size': (full_marks + 0.2) * 10,
239
  'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 10),
240
  'seconds': np.linspace(0, seconds, num=len(periodLength))})
241
  fig = px.scatter(data_frame=df,
242
  x='seconds',
243
  y='jumps per second',
244
- symbol='miss',
245
- symbol_map={False: 'circle', True: 'circle-open'},
246
  color='periodicity',
247
  size='jumps_size',
248
  size_max=10,
@@ -294,7 +299,7 @@ DESCRIPTION += '\n## AI Counting for Competitive Jump Rope'
294
  DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'
295
 
296
 
297
- with gr.Blocks() as demo:
298
  gr.Markdown(DESCRIPTION)
299
  with gr.Column():
300
  with gr.Row():
 
11
  matplotlib.use('Agg')
12
  import matplotlib.pyplot as plt
13
  from scipy.signal import medfilt
 
14
  from functools import partial
15
  from passlib.hash import pbkdf2_sha256
16
  from tqdm import tqdm
 
182
 
183
  if median_pred_filter:
184
  periodicity = medfilt(periodicity, 5)
185
+ periodLength = medfilt(periodLength, 5)
186
  periodicity = sigmoid(periodicity)
187
  full_marks = sigmoid(full_marks)
188
  full_marks_mask = np.int32(full_marks > marks_threshold)
 
189
  periodicity_mask = np.int32(periodicity > miss_threshold)
 
190
  numofReps = 0
191
  count = []
192
  for i in range(len(periodLength)):
 
206
  marks_count_pred = marks_count_pred / 2
207
  count = np.array(count) / 2
208
 
209
+ confidence = (np.mean(periodicity[periodicity > miss_threshold]) - miss_threshold) / (1 - miss_threshold)
210
+ self_err = abs(count_pred - marks_count_pred)
211
+ self_pct_err = self_err / count_pred
212
+ total_confidence = confidence * (1 - self_pct_err)
213
+
214
  if both_feet:
215
+ count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
216
  else:
217
+ count_msg = f"## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
218
 
219
  if api_call:
220
  if count_only_api:
 
231
  jumping_speed = np.copy(jumps_per_second)
232
  misses = periodicity < miss_threshold
233
  jumps_per_second[misses] = 0
234
+ frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
235
+ frame_type[full_marks > marks_threshold] = 'jump'
236
  df = pd.DataFrame.from_dict({'period length': periodLength,
237
  'jumping speed': jumping_speed,
238
  'jumps per second': jumps_per_second,
239
  'periodicity': periodicity,
240
  'miss': misses,
241
+ 'frame_type': frame_type,
242
  'jumps': full_marks,
243
+ 'jumps_size': (full_marks + 0.1) * 8,
244
  'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 10),
245
  'seconds': np.linspace(0, seconds, num=len(periodLength))})
246
  fig = px.scatter(data_frame=df,
247
  x='seconds',
248
  y='jumps per second',
249
+ symbol='frame_type',
250
+ symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
251
  color='periodicity',
252
  size='jumps_size',
253
  size_max=10,
 
299
  DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'
300
 
301
 
302
+ with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
303
  gr.Markdown(DESCRIPTION)
304
  with gr.Column():
305
  with gr.Row():
requirements.txt CHANGED
@@ -4,7 +4,6 @@ matplotlib
4
  plotly
5
  passlib
6
  scipy
7
- scikit-image
8
  --find-links https://download.pytorch.org/whl/torch_stable.html
9
  opencv-python-headless==4.7.0.68
10
  openvino-dev==2022.3.0
 
4
  plotly
5
  passlib
6
  scipy
 
7
  --find-links https://download.pytorch.org/whl/torch_stable.html
8
  opencv-python-headless==4.7.0.68
9
  openvino-dev==2022.3.0