dylan-plummer committed on
Commit
ed4eebf
·
1 Parent(s): 6249b75

prepare for testing

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +15 -200
  3. hls_download.py +3 -11
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: NextJump
3
  emoji: 🦘
4
  colorFrom: blue
5
  colorTo: green
 
1
  ---
2
+ title: NextJump + Single Rope Contest
3
  emoji: 🦘
4
  colorFrom: blue
5
  colorTo: green
app.py CHANGED
@@ -46,9 +46,11 @@ if torch.cuda.is_available():
46
  sess_options = ort.SessionOptions()
47
  #sess_options.log_severity_level = 0
48
  ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
 
49
  else:
50
  print("Using CPU")
51
  ort_sess = ort.InferenceSession(onnx_file)
 
52
 
53
  # warmup inference
54
  ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
@@ -182,121 +184,6 @@ def detect_beeps(video_path, target_event_length=30, beep_height=0.8):
182
  return event_start, event_end
183
 
184
 
185
- def detect_relay_beeps(video_path, event_start, relay_length=30, n_jumpers=4, beep_height=0.8):
186
- reference_file = 'relay_beep.WAV'
187
- fs, beep = wavfile.read(reference_file)
188
- beep = beep[:, 0] + beep[:, 1] # combine stereo to mono
189
- video = cv2.VideoCapture(video_path)
190
- try:
191
- os.remove('temp.wav')
192
- except FileNotFoundError:
193
- pass
194
- audio_convert_command = f'ffmpeg -i {video_path} -vn -acodec pcm_s16le -ar {fs} -ac 2 temp.wav'
195
- subprocess.call(audio_convert_command, shell=True)
196
- length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
197
- fps = int(video.get(cv2.CAP_PROP_FPS))
198
- audio = wavfile.read('temp.wav')[1]
199
- audio = (audio[:, 0] + audio[:, 1]) / 2 # combine stereo to mono
200
- corr = correlate(audio, beep, mode='same') / audio.size
201
- # min max scale to -1, 1
202
- corr = 2 * (corr - np.min(corr)) / (np.max(corr) - np.min(corr)) - 1
203
-
204
- # Calculate total event length in frames
205
- total_event_length_frames = fps * relay_length * n_jumpers
206
- print(event_start, total_event_length_frames)
207
- expected_event_end = event_start + total_event_length_frames
208
-
209
- # Find all significant peaks in the correlation
210
- peaks, _ = find_peaks(corr, height=beep_height, distance=fs)
211
-
212
- # Convert peaks from sample indices to frame indices
213
- peak_frames = [int(peak / fs * fps) for peak in peaks]
214
-
215
- # For debugging
216
- plt.plot(corr)
217
- plt.plot(peaks, corr[peaks], "x")
218
- plt.savefig('beep.png')
219
- plt.close()
220
-
221
- starts = []
222
- ends = []
223
-
224
- # Add the event start for the first jumper
225
- starts.append(event_start)
226
-
227
- # Convert event_start back to sample index for comparison
228
- event_start_sample = int(event_start * fs / fps)
229
-
230
- # Find peaks that come after the event start but before the expected end
231
- # Convert expected_event_end to sample index
232
- expected_event_end_sample = int(expected_event_end * fs / fps)
233
- relevant_peaks = [p for p in peaks if event_start_sample < p < expected_event_end_sample]
234
-
235
- # If we don't have enough peaks, try lowering the threshold
236
- if len(relevant_peaks) < n_jumpers - 1: # We need n_jumpers-1 transitions
237
- for lower_height in [0.7, 0.6, 0.5, 0.4, 0.3]:
238
- peaks, _ = find_peaks(corr, height=lower_height, distance=fs)
239
- relevant_peaks = [p for p in peaks if event_start_sample < p < expected_event_end_sample]
240
- if len(relevant_peaks) >= n_jumpers - 1:
241
- break
242
-
243
- # If we still don't have enough peaks, we'll need to estimate some transitions
244
- relay_length_frames = fps * relay_length
245
-
246
- # Process peaks to identify jumper transitions
247
- if len(relevant_peaks) >= n_jumpers - 1:
248
- # Ideal case: we found enough beeps for transitions
249
- # Sort peaks by time to ensure correct order
250
- relevant_peaks.sort()
251
-
252
- # Use the first n_jumpers-1 peaks as transition points
253
- transition_frames = [int(p / fs * fps) for p in relevant_peaks[:n_jumpers-1]]
254
-
255
- # Set ends for jumpers based on transition points
256
- for i in range(n_jumpers - 1):
257
- ends.append(transition_frames[i])
258
- starts.append(transition_frames[i])
259
-
260
- # Add end for the last jumper
261
- ends.append(expected_event_end)
262
- else:
263
- # Not enough peaks detected, use expected relay_length to estimate
264
- for i in range(n_jumpers):
265
- if i == 0:
266
- # First jumper starts at event_start (already added to starts)
267
- jumper_end = event_start + relay_length_frames
268
- ends.append(jumper_end)
269
- if i < n_jumpers - 1:
270
- starts.append(jumper_end)
271
- elif i < n_jumpers - 1:
272
- jumper_end = starts[i] + relay_length_frames
273
- ends.append(jumper_end)
274
- starts.append(jumper_end)
275
- else:
276
- # Last jumper
277
- jumper_end = starts[i] + relay_length_frames
278
- ends.append(jumper_end)
279
-
280
- # Validate and adjust if necessary
281
- # Make sure all intervals are close to relay_length
282
- for i in range(n_jumpers):
283
- interval = ends[i] - starts[i]
284
- # If an interval is significantly different from relay_length, adjust it
285
- if abs(interval - relay_length_frames) > relay_length_frames * 0.2: # 20% tolerance
286
- # Adjust the end time to match expected relay_length
287
- ends[i] = starts[i] + relay_length_frames
288
- # If not the last jumper, adjust the next start time
289
- if i < n_jumpers - 1:
290
- starts[i + 1] = ends[i]
291
-
292
- # Final check: ensure the total length matches expected
293
- if ends[-1] != expected_event_end:
294
- # Adjust the last end to match the expected total event end
295
- ends[-1] = expected_event_end
296
-
297
- return starts, ends
298
-
299
-
300
  def upload_video(out_text, in_video):
301
  if out_text != '':
302
  # generate a timestamp name for the video
@@ -336,38 +223,16 @@ def count_phases(phase_sin, phase_cos, threshold=0.5):
336
 
337
 
338
 
339
- def inference(in_video, stream_url, start_time, end_time, use_60fps, model_choice,
340
- beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay,
341
  count_only_api, api_key, seq_len=64, stride_length=32, stride_pad=3, batch_size=2,
342
  miss_threshold=0.5, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
343
  api_call=False,
344
  progress=gr.Progress()):
345
- global current_model, ort_sess
346
  print(in_video)
347
- if model_choice != current_model:
348
- current_model = model_choice
349
- onnx_file = hf_hub_download(repo_id="lumos-motion/nextjump", filename=f"{current_model}.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
350
- #onnx_file = f'{current_model}.onnx'
351
-
352
- if torch.cuda.is_available():
353
- print("Using CUDA")
354
- providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
355
- "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
356
- sess_options = ort.SessionOptions()
357
- #sess_options.log_severity_level = 0
358
- ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
359
- else:
360
- print("Using CPU")
361
- ort_sess = ort.InferenceSession(onnx_file)
362
-
363
- # warmup inference
364
- ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
365
  if in_video is None:
366
- print("No video input provided.")
367
- in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time, use_60fps=use_60fps)
368
- else: # local uploaded video (still resize with ffmpeg)
369
- print("Using uploaded video input.")
370
- in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), start_time, end_time, use_60fps=use_60fps)
371
  progress(0, desc="Running inference...")
372
  has_access = False
373
  if api_call:
@@ -379,10 +244,6 @@ def inference(in_video, stream_url, start_time, end_time, use_60fps, model_choic
379
  event_length = int(event_length)
380
  event_start, event_end = detect_beeps(in_video, event_length)
381
  print(event_start, event_end)
382
- if relay_detection_on:
383
- n_jumpers = int(int(event_length) / int(relay_length))
384
- relay_starts, relay_ends = detect_relay_beeps(in_video, event_start, int(relay_length), n_jumpers)
385
- print(relay_starts, relay_ends)
386
 
387
  cap = cv2.VideoCapture(in_video)
388
  length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -521,12 +382,6 @@ def inference(in_video, stream_url, start_time, end_time, use_60fps, model_choic
521
  print(f"Event detected: {event_start} - {event_end}")
522
  periodicity[:event_start] = 0
523
  periodicity[event_end:] = 0
524
- if relay_detection_on:
525
- for start, end in zip(relay_starts, relay_ends):
526
- if start > 0 and end > 0:
527
- print(f"Relay Event detected: {start} - {end}")
528
- # immediately after the beep set periodicity to 0 for switch_delay seconds
529
- periodicity[start:start + int(float(switch_delay) * fps)] = 0
530
  pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
531
  full_marks_mask = np.zeros(len(full_marks))
532
  full_marks_mask[pred_marks_peaks] = 1
@@ -672,7 +527,7 @@ def inference(in_video, stream_url, start_time, end_time, use_60fps, model_choic
672
  periodicity_str = np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
673
  periodLength_str = np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
674
  full_marks_str = np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
675
- f.write(f"{stream_url}\t{start_time}\t{end_time}\t{beep_detection_on}\t{event_length}\t{periodicity_str}\t{periodLength_str}\t{full_marks_str}\t{count_pred}\t{total_confidence}\n")
676
  if count_only_api:
677
  return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
678
  else:
@@ -707,9 +562,6 @@ def inference(in_video, stream_url, start_time, end_time, use_60fps, model_choic
707
  if beep_detection_on:
708
  results_dict['event_start'] = event_start
709
  results_dict['event_end'] = event_end
710
- if relay_detection_on:
711
- results_dict['relay_starts'] = relay_starts
712
- results_dict['relay_ends'] = relay_ends
713
  return json.dumps(results_dict)
714
 
715
 
@@ -795,12 +647,6 @@ def inference(in_video, stream_url, start_time, end_time, use_60fps, model_choic
795
  if beep_detection_on:
796
  # add vertical lines for beep event
797
  fig.add_vrect(x0=event_start / fps, x1=event_end / fps, fillcolor="LightSalmon", opacity=0.25, layer="below", line_width=0)
798
- if relay_detection_on:
799
- for start, end in zip(relay_starts, relay_ends):
800
- start += 10 # add some padding
801
- end -= 10
802
- fig.add_vrect(x0=start / fps, x1=end / fps, fillcolor="LightGreen", opacity=0.25, layer="below",
803
- line_width=0)
804
 
805
 
806
  fig.update_layout(legend=dict(
@@ -904,20 +750,12 @@ def inference(in_video, stream_url, start_time, end_time, use_60fps, model_choic
904
  histnorm='percent',
905
  title="Distribution of jumping speed (jumps-per-second)")
906
 
907
- # make a bar plot of the event type distribution
908
-
909
- bar = px.bar(x=['single rope', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
910
- y=event_type_probs,
911
- template="plotly_dark",
912
- title="Event Type Distribution",
913
- labels={'x': 'event type', 'y': 'probability'},
914
- range_y=[0, 1])
915
  try:
916
  os.remove('temp.wav')
917
  except FileNotFoundError:
918
  pass
919
 
920
- return count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks, hist, bar
921
 
922
  #css = '#phase-spiral {transform: rotate(0.25turn);}\n#phase-spiral-marks {transform: rotate(0.25turn);}'
923
  with gr.Blocks() as demo:
@@ -926,30 +764,15 @@ with gr.Blocks() as demo:
926
  width=400, height=400, interactive=True, container=True,
927
  max_length=300)
928
  with gr.Row():
929
- with gr.Column():
930
- gr.Markdown(
931
- """
932
- ### Stream Input Options
933
- Either upload a video file above, or provide a stream URL below.
934
- """,
935
- elem_id='stream-input-options',
936
- )
937
- in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
938
- in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True, value='00:00:00')
939
- in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
940
  with gr.Column():
941
  gr.Markdown(
942
  """
943
  ### Inference Options
944
- Select the model and framerate for inference.
945
  """,
946
  elem_id='inference-options',
947
  )
948
  use_60fps = gr.Checkbox(label="Use 60 FPS", elem_id='use-60fps', visible=True)
949
- model_choice = gr.Dropdown(
950
- ["nextjump_speed", "nextjump_all", "nextjump_both_feet"], label="Model Choice", info="For now just speed-only or general model",
951
- value="nextjump_speed", elem_id='model-choice'
952
- )
953
  with gr.Column():
954
  gr.Markdown(
955
  """
@@ -960,9 +783,6 @@ with gr.Blocks() as demo:
960
  )
961
  beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
962
  event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True)
963
- relay_detection_on = gr.Checkbox(label="Relay Event", elem_id='relay-beeps', visible=True)
964
- relay_length = gr.Textbox(label="Relay Length (s)", elem_id='relay-length', visible=True, value='30')
965
- switch_delay = gr.Textbox(label="Expected Switch Delay (s)", elem_id='event-length', visible=True, value='0.2')
966
 
967
  with gr.Row():
968
  run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
@@ -986,26 +806,21 @@ with gr.Blocks() as demo:
986
  with gr.Row():
987
  with gr.Column():
988
  out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
989
- with gr.Column():
990
- out_event_type_dist = gr.Plot(label="Event Type Distribution", elem_id='output-event-type-dist')
991
 
992
 
993
  demo_inference = partial(inference, count_only_api=False, api_key=None)
994
 
995
- run_button.click(demo_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
996
- outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist])
997
  api_inference = partial(inference, api_call=True)
998
- api_dummy_button.click(api_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only, api_token],
999
  outputs=[period_length], api_name='inference')
1000
  examples = [
1001
- #['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
1002
- ['files/wc2023.mp4', '', '00:00:00', '', True, 'nextjump_speed', True, 30, False, '30', '0.2'],
1003
- #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
1004
- #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod, '00:52:53', '00:55:00', True, 120]
1005
  ]
1006
  gr.Examples(examples,
1007
- inputs=[in_video, in_stream_url, in_stream_start, in_stream_end, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
1008
- outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist],
1009
  fn=demo_inference, cache_examples=False)
1010
 
1011
 
 
46
  sess_options = ort.SessionOptions()
47
  #sess_options.log_severity_level = 0
48
  ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
49
+ use_cuda = True
50
  else:
51
  print("Using CPU")
52
  ort_sess = ort.InferenceSession(onnx_file)
53
+ use_cuda = False
54
 
55
  # warmup inference
56
  ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
 
184
  return event_start, event_end
185
 
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  def upload_video(out_text, in_video):
188
  if out_text != '':
189
  # generate a timestamp name for the video
 
223
 
224
 
225
 
226
+ def inference(in_video, use_60fps,
227
+ beep_detection_on, event_length,
228
  count_only_api, api_key, seq_len=64, stride_length=32, stride_pad=3, batch_size=2,
229
  miss_threshold=0.5, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
230
  api_call=False,
231
  progress=gr.Progress()):
 
232
  print(in_video)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  if in_video is None:
234
+ return "No video input provided."
235
+ in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), '00:00:00', '', use_60fps=use_60fps, use_cuda=use_cuda)
 
 
 
236
  progress(0, desc="Running inference...")
237
  has_access = False
238
  if api_call:
 
244
  event_length = int(event_length)
245
  event_start, event_end = detect_beeps(in_video, event_length)
246
  print(event_start, event_end)
 
 
 
 
247
 
248
  cap = cv2.VideoCapture(in_video)
249
  length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
382
  print(f"Event detected: {event_start} - {event_end}")
383
  periodicity[:event_start] = 0
384
  periodicity[event_end:] = 0
 
 
 
 
 
 
385
  pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
386
  full_marks_mask = np.zeros(len(full_marks))
387
  full_marks_mask[pred_marks_peaks] = 1
 
527
  periodicity_str = np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
528
  periodLength_str = np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
529
  full_marks_str = np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
530
+ f.write(f"{in_video}\t{beep_detection_on}\t{event_length}\t{periodicity_str}\t{periodLength_str}\t{full_marks_str}\t{count_pred}\t{total_confidence}\n")
531
  if count_only_api:
532
  return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
533
  else:
 
562
  if beep_detection_on:
563
  results_dict['event_start'] = event_start
564
  results_dict['event_end'] = event_end
 
 
 
565
  return json.dumps(results_dict)
566
 
567
 
 
647
  if beep_detection_on:
648
  # add vertical lines for beep event
649
  fig.add_vrect(x0=event_start / fps, x1=event_end / fps, fillcolor="LightSalmon", opacity=0.25, layer="below", line_width=0)
 
 
 
 
 
 
650
 
651
 
652
  fig.update_layout(legend=dict(
 
750
  histnorm='percent',
751
  title="Distribution of jumping speed (jumps-per-second)")
752
 
 
 
 
 
 
 
 
 
753
  try:
754
  os.remove('temp.wav')
755
  except FileNotFoundError:
756
  pass
757
 
758
+ return count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks, hist
759
 
760
  #css = '#phase-spiral {transform: rotate(0.25turn);}\n#phase-spiral-marks {transform: rotate(0.25turn);}'
761
  with gr.Blocks() as demo:
 
764
  width=400, height=400, interactive=True, container=True,
765
  max_length=300)
766
  with gr.Row():
 
 
 
 
 
 
 
 
 
 
 
767
  with gr.Column():
768
  gr.Markdown(
769
  """
770
  ### Inference Options
771
+ Select the framerate for inference.
772
  """,
773
  elem_id='inference-options',
774
  )
775
  use_60fps = gr.Checkbox(label="Use 60 FPS", elem_id='use-60fps', visible=True)
 
 
 
 
776
  with gr.Column():
777
  gr.Markdown(
778
  """
 
783
  )
784
  beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
785
  event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True)
 
 
 
786
 
787
  with gr.Row():
788
  run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
 
806
  with gr.Row():
807
  with gr.Column():
808
  out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
 
 
809
 
810
 
811
  demo_inference = partial(inference, count_only_api=False, api_key=None)
812
 
813
+ run_button.click(demo_inference, [in_video, use_60fps, beep_detection_on, event_length],
814
+ outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist])
815
  api_inference = partial(inference, api_call=True)
816
+ api_dummy_button.click(api_inference, [in_video, use_60fps, beep_detection_on, event_length, count_only, api_token],
817
  outputs=[period_length], api_name='inference')
818
  examples = [
819
+ ['files/wc2023.mp4', True, True, 30],
 
 
 
820
  ]
821
  gr.Examples(examples,
822
+ inputs=[in_video, use_60fps, beep_detection_on, event_length],
823
+ outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist],
824
  fn=demo_inference, cache_examples=False)
825
 
826
 
hls_download.py CHANGED
@@ -1,26 +1,19 @@
1
  import subprocess
2
  import os
3
 
4
- def download_clips(stream_url, out_dir, start_time, end_time, resize=True, use_60fps=False):
5
  # remove all .mp4 files in out_dir to avoid confusion
6
  if len(os.listdir(out_dir)) > 5:
7
  for f in os.listdir(out_dir):
8
  if f.endswith('.mp4'):
9
  os.remove(os.path.join(out_dir, f))
10
- os.makedirs(out_dir, exist_ok=True)
11
  output_file = os.path.join(out_dir, f"train_{len(os.listdir(out_dir))}.mp4")
12
- try:
13
- os.remove(output_file)
14
- except FileNotFoundError:
15
- pass
16
  if resize: # resize and convert to 30 fps
17
  ffmpeg_cmd = [
18
  'ffmpeg',
19
- #'-hwaccel', 'cuda',
20
- '-allowed_extensions', 'ALL',
21
- '-extension_picky', '0', # https://github.com/yt-dlp/yt-dlp/issues/12700#issuecomment-2745400091
22
  '-i', stream_url,
23
- '-c:v', 'libx264',
24
  '-crf', '23',
25
  '-r', '30' if not use_60fps else '60',
26
  '-maxrate', '2M',
@@ -38,7 +31,6 @@ def download_clips(stream_url, out_dir, start_time, end_time, resize=True, use_6
38
  except subprocess.CalledProcessError as e:
39
  print(f"Error occurred: {e}")
40
  print(f"ffmpeg output: {e.output}")
41
- print(f"ffmpeg stderr: {e.stderr}")
42
  return output_file
43
  # else:
44
  # os.rename(tmp_file, output_file)
 
1
  import subprocess
2
  import os
3
 
4
+ def download_clips(stream_url, out_dir, start_time, end_time, resize=True, use_60fps=False, use_cuda=True):
5
  # remove all .mp4 files in out_dir to avoid confusion
6
  if len(os.listdir(out_dir)) > 5:
7
  for f in os.listdir(out_dir):
8
  if f.endswith('.mp4'):
9
  os.remove(os.path.join(out_dir, f))
 
10
  output_file = os.path.join(out_dir, f"train_{len(os.listdir(out_dir))}.mp4")
 
 
 
 
11
  if resize: # resize and convert to 30 fps
12
  ffmpeg_cmd = [
13
  'ffmpeg',
14
+ '-hwaccel', 'cuda' if use_cuda else 'none',
 
 
15
  '-i', stream_url,
16
+ #'-c:v', 'libx264',
17
  '-crf', '23',
18
  '-r', '30' if not use_60fps else '60',
19
  '-maxrate', '2M',
 
31
  except subprocess.CalledProcessError as e:
32
  print(f"Error occurred: {e}")
33
  print(f"ffmpeg output: {e.output}")
 
34
  return output_file
35
  # else:
36
  # os.rename(tmp_file, output_file)