dylanplummer committed on
Commit
0f5c629
·
1 Parent(s): 0cdfae9

simplify for api

Browse files
Files changed (1) hide show
  1. app.py +7 -167
app.py CHANGED
@@ -118,22 +118,10 @@ def confidence_analysis(periodicity, counts, frames, out_dir='confidence_animati
118
  plt.close(fig)
119
 
120
 
121
- def inference(x, both_feet, has_misses, true_count, center_crop, img_resize, miss_threshold, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, median_pred_filter=True, api_call=False):
122
  print(x)
123
- img_size = int((img_size - 64) * img_resize + 64)
124
  api = HfApi(token=os.environ['DATASET_SECRET'])
125
  out_file = str(uuid.uuid1())
126
- if has_misses:
127
- out_file = "misses_" + out_file
128
- if true_count != -1:
129
- out_file += '_' + str(true_count)
130
- out_file = f"labeled_videos/{out_file}.mp4"
131
- api.upload_file(
132
- path_or_fileobj=x,
133
- path_in_repo=out_file,
134
- repo_id="dylanplummer/jumprope",
135
- repo_type="dataset",
136
- )
137
 
138
 
139
  cap = cv2.VideoCapture(x)
@@ -231,6 +219,7 @@ def inference(x, both_feet, has_misses, true_count, center_crop, img_resize, mis
231
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
232
  if api_call:
233
  return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.3f" % x}).replace('\n', ''), np.array2string(periodicity, formatter={'float_kind':lambda x: "%.3f" % x}).replace('\n', '')
 
234
  if median_pred_filter:
235
  periodicity = medfilt(periodicity, 5)
236
  periodLength = medfilt(periodLength, 5)
@@ -249,130 +238,6 @@ def inference(x, both_feet, has_misses, true_count, center_crop, img_resize, mis
249
  count_pred = count_pred / 2
250
  count = np.array(count) / 2
251
 
252
- animate_frames = len(count)
253
- anim_subsample = max(1, int(fps / 24))
254
- fig, ax = plt.subplots(figsize = (3, 3))
255
- canvas_width, canvas_height = fig.canvas.get_width_height()
256
- img_ax = ax
257
-
258
- # imgs = []
259
- # for img in all_frames:
260
- # img = Image.fromarray(img)
261
- # width, height = img.size
262
- # if width > height:
263
- # img = img.resize((int(width / (height / img_size)), img_size))
264
- # else:
265
- # img = img.resize((img_size, int(height / (width / img_size))))
266
- # h_center, w_center = height / 2, width / 2
267
- # h_start, w_start = int(h_center - img_size / 2), int(w_center - img_size / 2)
268
- # cropped = img.crop((w_start, h_start, w_start + img_size, h_start + img_size))
269
- # imgs.append(cropped)
270
-
271
- # confidence_analysis(periodicity, count, imgs)
272
-
273
- # alpha=1.0
274
- # colormap=plt.cm.OrRd
275
- # h, w, _ = np.shape(imgs[0])
276
- # wedge_x = 34 / canvas_width * w
277
- # wedge_y = 34 / canvas_height * h
278
- # wedge_r = 30 / canvas_height * h
279
- # txt_x = 34 / canvas_width * w
280
- # txt_y = 36 / canvas_height * h
281
- # otxt_size = 25 / canvas_height * h
282
- # wedge1 = matplotlib.patches.Wedge(
283
- # center=(wedge_x, wedge_y),
284
- # r=wedge_r,
285
- # theta1=0,
286
- # theta2=0,
287
- # color=colormap(1.),
288
- # alpha=alpha)
289
- # wedge2 = matplotlib.patches.Wedge(
290
- # center=(wedge_x, wedge_y),
291
- # r=wedge_r,
292
- # theta1=0,
293
- # theta2=0,
294
- # color=colormap(0.5),
295
- # alpha=alpha)
296
-
297
- # im = img_ax.imshow(cropped)
298
-
299
- # img_ax.add_patch(wedge1)
300
- # img_ax.add_patch(wedge2)
301
- # txt = img_ax.text(
302
- # txt_x,
303
- # txt_y,
304
- # '0',
305
- # size=otxt_size,
306
- # ha='center',
307
- # va='center',
308
- # alpha=0.9,
309
- # color='white',
310
- # )
311
-
312
- # def animate_fn(i):
313
- # if anim_subsample:
314
- # i *= anim_subsample
315
- # cropped = imgs[i + stride_pad]
316
- # current_count = count[i]
317
- # if current_count % 2 == 0:
318
- # wedge1.set_color(colormap(1.0))
319
- # wedge2.set_color(colormap(0.5))
320
- # else:
321
- # wedge1.set_color(colormap(0.5))
322
- # wedge2.set_color(colormap(1.0))
323
- # txt.set_text(int(current_count))
324
- # wedge1.set_theta1(-90)
325
- # wedge1.set_theta2(-90 - 360 * (1 - current_count % 1.0))
326
- # wedge2.set_theta1(-90 - 360 * (1 - current_count % 1.0))
327
- # wedge2.set_theta2(-90)
328
-
329
- # im.set_data(cropped)
330
- # img_ax.set_title(f"Time: {i / fps:.1f}s, {current_count:.1f}/{count_pred:.1f} jumps")
331
- # img_ax.set_xticks([])
332
- # img_ax.set_yticks([])
333
- # img_ax.spines['top'].set_visible(False)
334
- # img_ax.spines['right'].set_visible(False)
335
- # img_ax.spines['bottom'].set_visible(False)
336
- # img_ax.spines['left'].set_visible(False)
337
-
338
- outf = x
339
- # anim_start_time = time.time()
340
- # # Open an ffmpeg process
341
- # outf = x.replace('.mp4', '_jump.mp4')
342
- # cmdstring = ('ffmpeg',
343
- # '-y', '-r', f'{30 if anim_subsample != 1 else int(fps)}', # overwrite, 24fps
344
- # '-s', f'{canvas_width}x{canvas_height}',
345
- # '-pix_fmt', 'argb',
346
- # '-hide_banner', '-loglevel', 'error',
347
- # '-f', 'rawvideo', '-i', '-', # tell ffmpeg to expect raw video from the pipe
348
- # '-i', x, '-map', '0:v', '-map', '1:a', # map video from the pipe, audio from the input file
349
- # '-c:v', 'libx264', # https://trac.ffmpeg.org/wiki/Encode/H.264
350
- # outf) # output encoding
351
- # try:
352
- # p = subprocess.Popen(cmdstring, stdin=subprocess.PIPE)
353
- # except FileNotFoundError as e:
354
- # print(e)
355
- # print('Trying to install ffmpeg...')
356
- # os.system("apt install ffmpeg -y")
357
- # p = subprocess.Popen(cmdstring, stdin=subprocess.PIPE)
358
-
359
- # # Draw frames and write to the pipe
360
- # anim_length = int(animate_frames / anim_subsample) - 1
361
- # for frame in range(anim_length):
362
- # # draw the frame
363
- # animate_fn(frame)
364
- # fig.canvas.draw()
365
- # # extract the image as an ARGB string
366
- # string = fig.canvas.tostring_argb()
367
- # # write to pipe
368
- # p.stdin.write(string)
369
-
370
- # # Finish up
371
- # p.communicate()
372
- # print(f"Animation done in {time.time() - anim_start_time:.2f} seconds")
373
- #cvrt_string = f'ffmpeg -hide_banner -loglevel error -i "{outf}" -i "{x}" -map 0:v -map 1:a -y -r 24 -s {canvas_width}x{canvas_height} -c:v libx264 "{outf.replace(".mp4", "_audio.mp4")}"'
374
- #os.system(cvrt_string)
375
-
376
  if both_feet:
377
  count_msg = f"## Predicted Count (both feet): {count_pred:.1f}"
378
  else:
@@ -428,13 +293,8 @@ def inference(x, both_feet, has_misses, true_count, center_crop, img_resize, mis
428
  histnorm='percent',
429
  title="Distribution of jumping speed (jumps-per-second)",
430
  range_x=[np.min(jumps_per_second[jumps_per_second > 0]) - 0.5, np.max(jumps_per_second) + 0.5])
431
-
432
- vid = px.imshow(np.uint8(all_frames)[:128], animation_frame=0, binary_string=True, binary_compression_level=5, binary_format='jpg', template="plotly_dark")
433
- vid.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)
434
- vid.layout.updatemenus[0].buttons[0].args[1]['frame']['duration'] = 13
435
- vid.layout.updatemenus[0].buttons[0].args[1]['transition']['duration'] = 5
436
 
437
- return count_msg, vid, fig, hist, periodLength
438
 
439
 
440
  DESCRIPTION = '# NextJump'
@@ -448,13 +308,6 @@ with gr.Blocks() as demo:
448
  with gr.Row():
449
  in_video = gr.Video(type="file", label="Input Video", elem_id='input-video', format='mp4').style(width=400)
450
  with gr.Column():
451
- gr.Markdown(label="Optional settings and parameters:")
452
- true_count = gr.Number(label="True Count (optional)", info="Provide a true count if you are ok with us using your video for training", elem_id='true-count', value=-1)
453
- both_feet = gr.Checkbox(label="Both feet", info="Count both feet rather than only one", elem_id='both-feet', value=False)
454
- misses = gr.Checkbox(label="Contains misses", info="Only necessary if providing a true count for training", elem_id='both-feet', value=False)
455
- center_crop = gr.Checkbox(label="Center crop square", info="Either crop a square out of the center or stretch to a square", elem_id='center-crop', value=True)
456
- image_size = gr.Slider(label="Image size", info="Lower image size is faster but less accurate", elem_id='miss-thresh', minimum=0.0, maximum=1.0, step=0.01, value=1.0)
457
- miss_thresh = gr.Slider(label="Miss threshold", info="Lower values are more sensitive to misses", elem_id='miss-thresh', minimum=0.0, maximum=0.99, step=0.01, value=0.95)
458
  with gr.Row():
459
  run_button = gr.Button(label="Run", elem_id='run-button').style(full_width=False)
460
  count_only = gr.Button(label="Run (No Viz)", elem_id='count-only', visible=False).style(full_width=False)
@@ -465,13 +318,12 @@ with gr.Blocks() as demo:
465
  out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
466
  period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
467
  periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
468
- with gr.Column(min_width=480):
469
  #out_video = gr.PlayableVideo(label="Output Video", elem_id='output-video', format='mp4')
470
- out_video = gr.Plot(label="Output Video", elem_id='output-video')
471
  out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
472
  out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
473
 
474
- inputs = [in_video, both_feet, misses, true_count, center_crop, image_size, miss_thresh]
475
  with gr.Accordion(label="Instructions and more information", open=False):
476
  instructions = "## Instructions:"
477
  instructions += "\n* Upload a video and click 'Run' to get a prediction of the number of jumps (either one foot, or both). This could take a couple minutes! The model is trained on single rope and double dutch speed, but try out any videos you want."
@@ -495,22 +347,10 @@ with gr.Blocks() as demo:
495
  [b, False, True, -1, True, 1.0, 0.95],
496
  ],
497
  inputs=inputs,
498
- outputs=[out_text, out_video, out_plot, out_hist],
499
  fn=inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
500
- with gr.Accordion(label="Data usage and disclaimer", open=False):
501
- data_usage = "## Data usage:"
502
- data_usage += "\n* By default, no data submitted to this demo is stored by us."
503
- data_usage += "\n* If you would like to contribute your video for further model improvements please provide the true count (either one or both feet) and specify if the video contains any misses."
504
- data_usage += "\n* The video will be uploaded to a private dataset repository here on HuggingFace"
505
- gr.Markdown(data_usage)
506
-
507
- disclaimer = "## Disclaimer:"
508
- disclaimer += "\n* This model was trained on a small dataset of videos (~20 hours). It is not guaranteed to work on all videos and should not be used yet in real competitive settings."
509
- disclaimer += "\n* Deep learning models such as this one are susceptible to biases in the training data."
510
- disclaimer += " We are aware of the potential for bias and expect quantifiable differences in performance on counting videos from different demographics not represented in the training data. We are working to improve the model and mitigate these biases. If you notice any issues, please let us know."
511
- gr.Markdown(disclaimer)
512
 
513
- run_button.click(inference, inputs, outputs=[out_text, out_video, out_plot, out_hist])
514
  api_inference = partial(inference, api_call=True)
515
  count_only.click(api_inference, inputs, outputs=[period_length, periodicity], api_name='inference')
516
 
 
118
  plt.close(fig)
119
 
120
 
121
+ def inference(x, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, miss_threshold=0.85, median_pred_filter=True, center_crop=True, both_feet=True, api_call=False):
122
  print(x)
 
123
  api = HfApi(token=os.environ['DATASET_SECRET'])
124
  out_file = str(uuid.uuid1())
 
 
 
 
 
 
 
 
 
 
 
125
 
126
 
127
  cap = cv2.VideoCapture(x)
 
219
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
220
  if api_call:
221
  return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.3f" % x}).replace('\n', ''), np.array2string(periodicity, formatter={'float_kind':lambda x: "%.3f" % x}).replace('\n', '')
222
+
223
  if median_pred_filter:
224
  periodicity = medfilt(periodicity, 5)
225
  periodLength = medfilt(periodLength, 5)
 
238
  count_pred = count_pred / 2
239
  count = np.array(count) / 2
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  if both_feet:
242
  count_msg = f"## Predicted Count (both feet): {count_pred:.1f}"
243
  else:
 
293
  histnorm='percent',
294
  title="Distribution of jumping speed (jumps-per-second)",
295
  range_x=[np.min(jumps_per_second[jumps_per_second > 0]) - 0.5, np.max(jumps_per_second) + 0.5])
 
 
 
 
 
296
 
297
+ return count_msg, fig, hist, periodLength
298
 
299
 
300
  DESCRIPTION = '# NextJump'
 
308
  with gr.Row():
309
  in_video = gr.Video(type="file", label="Input Video", elem_id='input-video', format='mp4').style(width=400)
310
  with gr.Column():
 
 
 
 
 
 
 
311
  with gr.Row():
312
  run_button = gr.Button(label="Run", elem_id='run-button').style(full_width=False)
313
  count_only = gr.Button(label="Run (No Viz)", elem_id='count-only', visible=False).style(full_width=False)
 
318
  out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
319
  period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
320
  periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
321
+ #with gr.Column(min_width=480):
322
  #out_video = gr.PlayableVideo(label="Output Video", elem_id='output-video', format='mp4')
 
323
  out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
324
  out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
325
 
326
+ inputs = [in_video]
327
  with gr.Accordion(label="Instructions and more information", open=False):
328
  instructions = "## Instructions:"
329
  instructions += "\n* Upload a video and click 'Run' to get a prediction of the number of jumps (either one foot, or both). This could take a couple minutes! The model is trained on single rope and double dutch speed, but try out any videos you want."
 
347
  [b, False, True, -1, True, 1.0, 0.95],
348
  ],
349
  inputs=inputs,
350
+ outputs=[out_text, out_plot, out_hist],
351
  fn=inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
 
 
 
 
 
 
 
 
 
 
 
 
352
 
353
+ run_button.click(inference, inputs, outputs=[out_text, out_plot, out_hist])
354
  api_inference = partial(inference, api_call=True)
355
  count_only.click(api_inference, inputs, outputs=[period_length, periodicity], api_name='inference')
356