dylanplummer committed on
Commit
2e275e6
·
1 Parent(s): 4d03e81

update for srcvol3

Browse files
Files changed (2) hide show
  1. app.py +90 -42
  2. hls_download.py +2 -2
app.py CHANGED
@@ -34,9 +34,9 @@ IMG_SIZE = 256
34
  CACHE_API_CALLS = False
35
  os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
36
 
37
- onnx_file = hf_hub_download(repo_id="lumos-motion/nextjump", filename="nextjump_256.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
 
38
 
39
- #onnx_file = 'nextjump.onnx'
40
 
41
  if torch.cuda.is_available():
42
  print("Using CUDA")
@@ -236,16 +236,32 @@ def detect_relay_beeps(video_path, event_start, relay_length=30, n_jumpers=4, be
236
  return starts, ends
237
 
238
 
239
- def inference(in_video, stream_url, start_time, end_time, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay,
240
  count_only_api, api_key, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
241
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
242
  api_call=False,
243
  progress=gr.Progress()):
244
- progress(0, desc="Downloading clip...")
245
- if in_video is None:
246
- in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
247
- else: # local uploaded video (still resize with ffmpeg)
248
- in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  progress(0, desc="Running inference...")
250
  has_access = False
251
  if api_call:
@@ -350,14 +366,14 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
350
  y5_out = outputs[4]
351
  y6_out = outputs[5]
352
  for y1, y2, y3, y4, y5, y6, idx in zip(y1_out, y2_out, y3_out, y4_out, y5_out, y6_out, idx_list):
353
- periodLength = y1.squeeze()
354
  periodicity = y2.squeeze()
355
  marks = y3.squeeze()
356
  event_type = y4.squeeze()
357
  foot_type = y5.squeeze()
358
  phase = y6.squeeze()
359
  period_lengths[idx:idx+seq_len] += periodLength[:, 0]
360
- period_lengths_rope[idx:idx+seq_len] += periodLength[:, 1]
361
  periodicities[idx:idx+seq_len] += periodicity
362
  full_marks[idx:idx+seq_len] += marks
363
  event_type_logits[idx:idx+seq_len] += event_type
@@ -368,7 +384,7 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
368
  del y1_out, y2_out, y3_out, y4_out # free up memory
369
 
370
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
371
- periodLength_rope = np.divide(period_lengths_rope, period_length_overlaps, where=period_length_overlaps!=0)[:length]
372
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
373
  full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
374
  per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
@@ -442,14 +458,14 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
442
 
443
  if LOCAL:
444
  if both_feet:
445
- count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
446
  else:
447
- count_msg = f"## Reps Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
448
  else:
449
  if both_feet:
450
- count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Confidence: {total_confidence:.2f}"
451
  else:
452
- count_msg = f"## Reps Count (one foot): {count_pred:.1f}, Confidence: {total_confidence:.2f}"
453
 
454
  if api_call:
455
  if CACHE_API_CALLS:
@@ -458,7 +474,8 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
458
  periodicity_str = np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
459
  periodLength_str = np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
460
  full_marks_str = np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
461
- f.write(f"{stream_url}\t{start_time}\t{end_time}\t{beep_detection_on}\t{event_length}\t{periodicity_str}\t{periodLength_str}\t{full_marks_str}\t{count_pred}\t{total_confidence}\n")
 
462
  if count_only_api:
463
  return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
464
  else:
@@ -552,10 +569,10 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
552
  y='jumps per second',
553
  #symbol='frame_type',
554
  #symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
555
- color='event_type',
556
  size='jumps_size',
557
  size_max=8,
558
- color_continuous_scale=[(t, c) for t, c in zip(event_type_tick_vals, event_type_colors)],
559
  range_color=(0,1),
560
  title="Jumping speed (jumps-per-second)",
561
  trendline='rolling',
@@ -593,11 +610,11 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
593
  )
594
  # remove white outline from marks
595
  fig.update_traces(marker_line_width = 0)
596
- fig.update_layout(coloraxis_colorbar=dict(
597
- tickvals=event_type_tick_vals,
598
- ticktext=['single<br>rope', 'double<br>dutch', 'double<br>unders', 'single<br>bounces', 'double<br>bounces', 'triple<br>unders', 'other'],
599
- title='event type'
600
- ))
601
 
602
 
603
  # -pi/2 phase offset to make the bottom of the plot the start of the jump
@@ -674,20 +691,53 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
674
  histnorm='percent',
675
  title="Distribution of jumping speed (jumps-per-second)")
676
 
677
- # make a bar plot of the event type distribution
678
-
679
- bar = px.bar(x=['single rope', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
680
- y=event_type_probs,
681
- template="plotly_dark",
682
- title="Event Type Distribution",
683
- labels={'x': 'event type', 'y': 'probability'},
684
- range_y=[0, 1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
685
  try:
686
  os.remove('temp.wav')
687
  except FileNotFoundError:
688
  pass
689
 
690
- return in_video, count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks, hist, bar
691
 
692
  #css = '#phase-spiral {transform: rotate(0.25turn);}\n#phase-spiral-marks {transform: rotate(0.25turn);}'
693
  with gr.Blocks() as demo:
@@ -697,12 +747,10 @@ with gr.Blocks() as demo:
697
  max_length=300)
698
  with gr.Row():
699
  with gr.Column():
700
- in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
701
-
702
- in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True, value='00:00:00')
703
-
704
- in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True, value='00:00:00')
705
- in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
706
  with gr.Column():
707
  beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
708
  event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True)
@@ -740,19 +788,19 @@ with gr.Blocks() as demo:
740
 
741
  demo_inference = partial(inference, count_only_api=False, api_key=None)
742
 
743
- run_button.click(demo_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
744
  outputs=[out_video, out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist])
745
  api_inference = partial(inference, api_call=True)
746
- api_dummy_button.click(api_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only, api_token],
747
  outputs=[period_length], api_name='inference')
748
  examples = [
749
  #['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
750
- ['files/wc2023.mp4', '', '00:00:00', '', True, 30, False, '30', '0.2'],
751
  #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
752
  #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod, '00:52:53', '00:55:00', True, 120]
753
  ]
754
  gr.Examples(examples,
755
- inputs=[in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
756
  outputs=[out_video, out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist],
757
  fn=demo_inference, cache_examples=False)
758
 
 
34
  CACHE_API_CALLS = False
35
  os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
36
 
37
+ current_model = 'nextjump_speed'
38
+ onnx_file = hf_hub_download(repo_id="lumos-motion/nextjump", filename=f"{current_model}.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
39
 
 
40
 
41
  if torch.cuda.is_available():
42
  print("Using CUDA")
 
236
  return starts, ends
237
 
238
 
239
+ def inference(in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay,
240
  count_only_api, api_key, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
241
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
242
  api_call=False,
243
  progress=gr.Progress()):
244
+ global current_model
245
+ if model_choice != current_model:
246
+ current_model = model_choice
247
+ onnx_file = hf_hub_download(repo_id="lumos-motion/nextjump", filename=f"{current_model}.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
248
+
249
+
250
+ if torch.cuda.is_available():
251
+ print("Using CUDA")
252
+ providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
253
+ "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
254
+ sess_options = ort.SessionOptions()
255
+ #sess_options.log_severity_level = 0
256
+ ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
257
+ else:
258
+ print("Using CPU")
259
+ ort_sess = ort.InferenceSession(onnx_file)
260
+
261
+ # warmup inference
262
+ ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
263
+
264
+ in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), '00:00:00', '', use_60fps=use_60fps)
265
  progress(0, desc="Running inference...")
266
  has_access = False
267
  if api_call:
 
366
  y5_out = outputs[4]
367
  y6_out = outputs[5]
368
  for y1, y2, y3, y4, y5, y6, idx in zip(y1_out, y2_out, y3_out, y4_out, y5_out, y6_out, idx_list):
369
+ periodLength = y1
370
  periodicity = y2.squeeze()
371
  marks = y3.squeeze()
372
  event_type = y4.squeeze()
373
  foot_type = y5.squeeze()
374
  phase = y6.squeeze()
375
  period_lengths[idx:idx+seq_len] += periodLength[:, 0]
376
+ #period_lengths_rope[idx:idx+seq_len] += periodLength[:, 1]
377
  periodicities[idx:idx+seq_len] += periodicity
378
  full_marks[idx:idx+seq_len] += marks
379
  event_type_logits[idx:idx+seq_len] += event_type
 
384
  del y1_out, y2_out, y3_out, y4_out # free up memory
385
 
386
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
387
+ #periodLength_rope = np.divide(period_lengths_rope, period_length_overlaps, where=period_length_overlaps!=0)[:length]
388
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
389
  full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
390
  per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
 
458
 
459
  if LOCAL:
460
  if both_feet:
461
+ count_msg = f"## Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
462
  else:
463
+ count_msg = f"## Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
464
  else:
465
  if both_feet:
466
+ count_msg = f"## Count (both feet): {count_pred:.1f}, Confidence: {total_confidence:.2f}"
467
  else:
468
+ count_msg = f"## Count (one foot): {count_pred:.1f}, Confidence: {total_confidence:.2f}"
469
 
470
  if api_call:
471
  if CACHE_API_CALLS:
 
474
  periodicity_str = np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
475
  periodLength_str = np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
476
  full_marks_str = np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
477
+ f.write(f"{beep_detection_on}\t{event_length}\t{periodicity_str}\t{periodLength_str}\t{full_marks_str}\t{count_pred}\t{total_confidence}\n")
478
+
479
  if count_only_api:
480
  return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
481
  else:
 
569
  y='jumps per second',
570
  #symbol='frame_type',
571
  #symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
572
+ color='periodicity',
573
  size='jumps_size',
574
  size_max=8,
575
+ color_continuous_scale='rainbow',
576
  range_color=(0,1),
577
  title="Jumping speed (jumps-per-second)",
578
  trendline='rolling',
 
610
  )
611
  # remove white outline from marks
612
  fig.update_traces(marker_line_width = 0)
613
+ # fig.update_layout(coloraxis_colorbar=dict(
614
+ # tickvals=event_type_tick_vals,
615
+ # ticktext=['single<br>rope', 'double<br>dutch', 'double<br>unders', 'single<br>bounces', 'double<br>bounces', 'triple<br>unders', 'other'],
616
+ # title='event type'
617
+ # ))
618
 
619
 
620
  # -pi/2 phase offset to make the bottom of the plot the start of the jump
 
691
  histnorm='percent',
692
  title="Distribution of jumping speed (jumps-per-second)")
693
 
694
+ # plot the full count and predict a count for 30s, 60s, and 180s if the video is shorter than that
695
+ count = np.array(count)
696
+ regression_plot = px.scatter(x=np.arange(len(count)), y=count,
697
+ color=periodicity,
698
+ color_continuous_scale='rainbow',
699
+ title="Count Prediction (Perfect Run)",
700
+ template="plotly_dark")
701
+ regression_plot.update_coloraxes(colorbar=dict(
702
+ title="Periodicity"))
703
+ regression_plot.update_traces(marker=dict(size=6, opacity=0.5))
704
+ regression_plot.update_layout(
705
+ xaxis_title="Frame",
706
+ yaxis_title="Count",
707
+ xaxis=dict(range=[0, len(count)]),
708
+ yaxis=dict(range=[0, max(count) * 1.2]),
709
+ showlegend=False,
710
+ )
711
+
712
+ # add 30s, 60s, and 180s predictions
713
+ pred_count_30s = int(np.median(jumps_per_second[~misses]) * 30)
714
+ pred_count_60s = int(np.median(jumps_per_second[~misses]) * 60)
715
+ pred_count_180s = int(np.median(jumps_per_second[~misses]) * 180)
716
+ # add text to the plot
717
+ regression_plot.add_annotation(
718
+ x=0.5,
719
+ y=0.95,
720
+ xref="paper",
721
+ yref="paper",
722
+ text=f"No-Miss Count (30s): {pred_count_30s}<br>No-Miss Count (60s): {pred_count_60s}<br>No-Miss Count (180s): {pred_count_180s}",
723
+ showarrow=False,
724
+ font=dict(
725
+ size=16,
726
+ color="white"
727
+ ),
728
+ align="center",
729
+ bgcolor="rgba(0, 0, 0, 0.5)",
730
+ bordercolor="white",
731
+ borderwidth=2,
732
+ borderpad=4,
733
+ opacity=0.8
734
+ )
735
  try:
736
  os.remove('temp.wav')
737
  except FileNotFoundError:
738
  pass
739
 
740
+ return in_video, count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks, hist, regression_plot
741
 
742
  #css = '#phase-spiral {transform: rotate(0.25turn);}\n#phase-spiral-marks {transform: rotate(0.25turn);}'
743
  with gr.Blocks() as demo:
 
747
  max_length=300)
748
  with gr.Row():
749
  with gr.Column():
750
+ use_60fps = gr.Checkbox(label="Use 60 FPS", elem_id='use-60fps', visible=True)
751
+ model_choice = gr.Dropdown(
752
+ ["nextjump_speed", "nextjump_all"], label="Model Choice", info="For now just speed-only or general model",
753
+ )
 
 
754
  with gr.Column():
755
  beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
756
  event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True)
 
788
 
789
  demo_inference = partial(inference, count_only_api=False, api_key=None)
790
 
791
+ run_button.click(demo_inference, [in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
792
  outputs=[out_video, out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist])
793
  api_inference = partial(inference, api_call=True)
794
+ api_dummy_button.click(api_inference, [in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only, api_token],
795
  outputs=[period_length], api_name='inference')
796
  examples = [
797
  #['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
798
+ ['files/wc2023.mp4', True, 'nextjump_speed', True, 30, False, '30', '0.2'],
799
  #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
800
  #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod, '00:52:53', '00:55:00', True, 120]
801
  ]
802
  gr.Examples(examples,
803
+ inputs=[in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
804
  outputs=[out_video, out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist],
805
  fn=demo_inference, cache_examples=False)
806
 
hls_download.py CHANGED
@@ -1,7 +1,7 @@
1
  import subprocess
2
  import os
3
 
4
- def download_clips(stream_url, out_dir, start_time, end_time, resize=True):
5
  # remove all .mp4 files in out_dir to avoid confusion
6
  if len(os.listdir(out_dir)) > 5:
7
  for f in os.listdir(out_dir):
@@ -14,7 +14,7 @@ def download_clips(stream_url, out_dir, start_time, end_time, resize=True):
14
  '-i', stream_url,
15
  '-c:v', 'libx264',
16
  '-crf', '23',
17
- '-r', '30',
18
  '-maxrate', '2M',
19
  '-bufsize', '4M',
20
  '-vf', f"scale=-2:300",
 
1
  import subprocess
2
  import os
3
 
4
+ def download_clips(stream_url, out_dir, start_time, end_time, resize=True, use_60fps=False):
5
  # remove all .mp4 files in out_dir to avoid confusion
6
  if len(os.listdir(out_dir)) > 5:
7
  for f in os.listdir(out_dir):
 
14
  '-i', stream_url,
15
  '-c:v', 'libx264',
16
  '-crf', '23',
17
+ '-r', '30' if not use_60fps else '60',
18
  '-maxrate', '2M',
19
  '-bufsize', '4M',
20
  '-vf', f"scale=-2:300",