dylanplummer commited on
Commit
98fee40
·
1 Parent(s): 0f5c629

add count only api option

Browse files
Files changed (1) hide show
  1. app.py +22 -15
app.py CHANGED
@@ -118,7 +118,7 @@ def confidence_analysis(periodicity, counts, frames, out_dir='confidence_animati
118
  plt.close(fig)
119
 
120
 
121
- def inference(x, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, miss_threshold=0.85, median_pred_filter=True, center_crop=True, both_feet=True, api_call=False):
122
  print(x)
123
  api = HfApi(token=os.environ['DATASET_SECRET'])
124
  out_file = str(uuid.uuid1())
@@ -217,9 +217,7 @@ def inference(x, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch
217
 
218
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
219
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
220
- if api_call:
221
- return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.3f" % x}).replace('\n', ''), np.array2string(periodicity, formatter={'float_kind':lambda x: "%.3f" % x}).replace('\n', '')
222
-
223
  if median_pred_filter:
224
  periodicity = medfilt(periodicity, 5)
225
  periodLength = medfilt(periodLength, 5)
@@ -243,6 +241,15 @@ def inference(x, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch
243
  else:
244
  count_msg = f"## Predicted Count (one foot): {count_pred:.1f}"
245
 
 
 
 
 
 
 
 
 
 
246
  jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.05), 0, 8)
247
  jumping_speed = np.copy(jumps_per_second)
248
  misses = periodicity < miss_threshold
@@ -306,11 +313,12 @@ with gr.Blocks() as demo:
306
  gr.Markdown(DESCRIPTION)
307
  with gr.Column():
308
  with gr.Row():
309
- in_video = gr.Video(type="file", label="Input Video", elem_id='input-video', format='mp4').style(width=400)
310
- with gr.Column():
311
- with gr.Row():
312
- run_button = gr.Button(label="Run", elem_id='run-button').style(full_width=False)
313
- count_only = gr.Button(label="Run (No Viz)", elem_id='count-only', visible=False).style(full_width=False)
 
314
 
315
  with gr.Column(elem_id='output-video-container'):
316
  with gr.Row():
@@ -322,8 +330,7 @@ with gr.Blocks() as demo:
322
  #out_video = gr.PlayableVideo(label="Output Video", elem_id='output-video', format='mp4')
323
  out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
324
  out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
325
-
326
- inputs = [in_video]
327
  with gr.Accordion(label="Instructions and more information", open=False):
328
  instructions = "## Instructions:"
329
  instructions += "\n* Upload a video and click 'Run' to get a prediction of the number of jumps (either one foot, or both). This could take a couple minutes! The model is trained on single rope and double dutch speed, but try out any videos you want."
@@ -346,13 +353,13 @@ with gr.Blocks() as demo:
346
  [a, True, False, -1, True, 1.0, 0.95],
347
  [b, False, True, -1, True, 1.0, 0.95],
348
  ],
349
- inputs=inputs,
350
  outputs=[out_text, out_plot, out_hist],
351
  fn=inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
352
-
353
- run_button.click(inference, inputs, outputs=[out_text, out_plot, out_hist])
354
  api_inference = partial(inference, api_call=True)
355
- count_only.click(api_inference, inputs, outputs=[period_length, periodicity], api_name='inference')
356
 
357
 
358
  if __name__ == "__main__":
 
118
  plt.close(fig)
119
 
120
 
121
+ def inference(x, count_only_api, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, miss_threshold=0.85, median_pred_filter=True, center_crop=True, both_feet=True, api_call=False):
122
  print(x)
123
  api = HfApi(token=os.environ['DATASET_SECRET'])
124
  out_file = str(uuid.uuid1())
 
217
 
218
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
219
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
220
+
 
 
221
  if median_pred_filter:
222
  periodicity = medfilt(periodicity, 5)
223
  periodLength = medfilt(periodLength, 5)
 
241
  else:
242
  count_msg = f"## Predicted Count (one foot): {count_pred:.1f}"
243
 
244
+ if api_call:
245
+ if count_only_api:
246
+ return f"{count_pred:.2f}"
247
+ else:
248
+ return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
249
+ np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
250
+ np.array2string(count_pred, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', '')
251
+
252
+
253
  jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.05), 0, 8)
254
  jumping_speed = np.copy(jumps_per_second)
255
  misses = periodicity < miss_threshold
 
313
  gr.Markdown(DESCRIPTION)
314
  with gr.Column():
315
  with gr.Row():
316
+ in_video = gr.Video(label="Input Video", elem_id='input-video', format='mp4', width=400, scale=2)
317
+
318
+ with gr.Row():
319
+ run_button = gr.Button(label="Run", elem_id='run-button', style=dict(full_width=False), scale=1)
320
+ api_dummy_button = gr.Button(label="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
321
+ count_only = gr.Checkbox(label="Count Only", visible=False)
322
 
323
  with gr.Column(elem_id='output-video-container'):
324
  with gr.Row():
 
330
  #out_video = gr.PlayableVideo(label="Output Video", elem_id='output-video', format='mp4')
331
  out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
332
  out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
333
+
 
334
  with gr.Accordion(label="Instructions and more information", open=False):
335
  instructions = "## Instructions:"
336
  instructions += "\n* Upload a video and click 'Run' to get a prediction of the number of jumps (either one foot, or both). This could take a couple minutes! The model is trained on single rope and double dutch speed, but try out any videos you want."
 
353
  [a, True, False, -1, True, 1.0, 0.95],
354
  [b, False, True, -1, True, 1.0, 0.95],
355
  ],
356
+ inputs=[in_video],
357
  outputs=[out_text, out_plot, out_hist],
358
  fn=inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
359
+ demo_inference = partial(inference, count_only_api=False)
360
+ run_button.click(demo_inference, [in_video], outputs=[out_text, out_plot, out_hist])
361
  api_inference = partial(inference, api_call=True)
362
+ api_dummy_button.click(api_inference, [in_video, count_only], outputs=[period_length], api_name='inference')
363
 
364
 
365
  if __name__ == "__main__":