dylanplummer committed on
Commit
580b186
·
verified ·
1 Parent(s): d0e184f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -25
app.py CHANGED
@@ -4,6 +4,7 @@ from PIL import Image
4
  import os
5
  import cv2
6
  import math
 
7
  import subprocess
8
  import matplotlib
9
  matplotlib.use('Agg')
@@ -26,28 +27,22 @@ from huggingface_hub import HfApi
26
 
27
  from hls_download import download_clips
28
 
29
- plt.style.use('dark_background')
30
 
31
- LOCAL = False
32
  LOCAL = False
33
  IMG_SIZE = 256
34
  CACHE_API_CALLS = True
35
  os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
36
- CACHE_API_CALLS = True
37
- os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
38
 
39
  onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
40
  if torch.cuda.is_available():
41
- print("Using CUDA")
42
  print("Using CUDA")
43
  providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
44
  "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
45
  sess_options = ort.SessionOptions()
46
  #sess_options.log_severity_level = 0
47
- #sess_options.log_severity_level = 0
48
  ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
49
  else:
50
- print("Using CPU")
51
  print("Using CPU")
52
  ort_sess = ort.InferenceSession(onnx_file)
53
 
@@ -245,10 +240,6 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
245
  api_call=False,
246
  progress=gr.Progress()):
247
  progress(0, desc="Downloading clip...")
248
- if in_video is None:
249
- in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
250
- else: # local uploaded video (still resize with ffmpeg)
251
- in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
252
  if in_video is None:
253
  in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
254
  else: # local uploaded video (still resize with ffmpeg)
@@ -448,13 +439,58 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
448
  if count_only_api:
449
  return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
450
  else:
451
- return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''), \
452
- np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''), \
453
- np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''), \
454
- f"reps: {count_pred:.2f}, marks: {marks_count_pred:.1f}, confidence: {total_confidence:.2f}", \
455
- f"single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456
 
457
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
  jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
459
  jumping_speed = np.copy(jumps_per_second)
460
  misses = periodicity < miss_threshold
@@ -602,7 +638,7 @@ with gr.Blocks() as demo:
602
  outputs=[period_length], api_name='inference')
603
  examples = [
604
  #['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
605
- [None, 'https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_UGEhqlMh/vod', '00:00:18', '00:00:55', True, 30],
606
  #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
607
  #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod', '00:52:53', '00:55:00', True, 120]
608
  ]
@@ -613,14 +649,6 @@ with gr.Blocks() as demo:
613
 
614
 
615
  if __name__ == "__main__":
616
- if LOCAL:
617
- demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
618
- server_port=7860,
619
- debug=False,
620
- ssl_verify=False,
621
- share=False)
622
- else:
623
- demo.queue(api_open=True, max_size=15).launch(share=False)
624
  if LOCAL:
625
  demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
626
  server_port=7860,
 
4
  import os
5
  import cv2
6
  import math
7
+ import json
8
  import subprocess
9
  import matplotlib
10
  matplotlib.use('Agg')
 
27
 
28
  from hls_download import download_clips
29
 
30
+ #plt.style.use('dark_background')
31
 
 
32
  LOCAL = False
33
  IMG_SIZE = 256
34
  CACHE_API_CALLS = True
35
  os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
 
 
36
 
37
  onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
38
  if torch.cuda.is_available():
 
39
  print("Using CUDA")
40
  providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
41
  "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
42
  sess_options = ort.SessionOptions()
43
  #sess_options.log_severity_level = 0
 
44
  ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
45
  else:
 
46
  print("Using CPU")
47
  ort_sess = ort.InferenceSession(onnx_file)
48
 
 
240
  api_call=False,
241
  progress=gr.Progress()):
242
  progress(0, desc="Downloading clip...")
 
 
 
 
243
  if in_video is None:
244
  in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
245
  else: # local uploaded video (still resize with ffmpeg)
 
439
  if count_only_api:
440
  return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
441
  else:
442
+ # create a nice json object to return
443
+ results_dict = {
444
+ "periodLength": np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
445
+ "periodicity": np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
446
+ "full_marks": np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
447
+ "cum_count": np.array2string(np.array(count), formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
448
+ "count": f"{count_pred:.2f}",
449
+ "marks": f"{marks_count_pred:.1f}",
450
+ "confidence": f"{total_confidence:.2f}",
451
+ "single_rope_speed": f"{event_type_probs[0]:.3f}",
452
+ "double_dutch": f"{event_type_probs[1]:.3f}",
453
+ "double_unders": f"{event_type_probs[2]:.3f}",
454
+ "single_bounce": f"{event_type_probs[3]:.3f}"
455
+ }
456
+ if beep_detection_on:
457
+ results_dict['event_start'] = event_start
458
+ results_dict['event_end'] = event_end
459
+ if relay_detection_on:
460
+ results_dict['relay_starts'] = relay_starts
461
+ results_dict['relay_ends'] = relay_ends
462
+ return json.dumps(results_dict)
463
+
464
 
465
 
466
+ fig, axs = plt.subplots(4, 1, figsize=(12, 10)) # Added a plot for count
467
+
468
+ # Ensure data exists before plotting
469
+ axs[0].plot(periodLength)
470
+ axs[0].set_title(f"Stream 0 - Period Length")
471
+
472
+ axs[1].plot(periodicity)
473
+ axs[1].set_title("Stream 0 - Periodicity")
474
+ axs[1].set_ylim(0, 1)
475
+ axs[1].axhline(miss_threshold, color='r', linestyle=':', label=f'Miss Thresh ({miss_threshold})')
476
+
477
+
478
+ axs[2].plot(full_marks, label='Raw Marks')
479
+ marks_peaks_vis, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
480
+ axs[2].plot(marks_peaks_vis, np.array(full_marks)[marks_peaks_vis], "x", label='Detected Peaks')
481
+ axs[2].set_title("Stream 0 - Marks")
482
+ axs[2].set_ylim(0, 1)
483
+ axs[2].axhline(marks_threshold, color='r', linestyle=':', label=f'Mark Thresh ({marks_threshold})')
484
+
485
+
486
+ axs[3].plot(count)
487
+ axs[3].set_title("Stream 0 - Calculated Count")
488
+
489
+ plt.tight_layout()
490
+
491
+ plt.savefig('plot.png')
492
+ plt.close()
493
+
494
  jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
495
  jumping_speed = np.copy(jumps_per_second)
496
  misses = periodicity < miss_threshold
 
638
  outputs=[period_length], api_name='inference')
639
  examples = [
640
  #['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
641
+ ['files/wc2023.mp4', '', '00:00:00', '', True, 30, False, '', '0.2'],
642
  #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
643
  #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod', '00:52:53', '00:55:00', True, 120]
644
  ]
 
649
 
650
 
651
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
652
  if LOCAL:
653
  demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
654
  server_port=7860,