dylanplummer committed on
Commit
f73de5f
·
1 Parent(s): 5a7cfc7

add tensorrt if available

Browse files
Files changed (1) hide show
  1. app.py +86 -122
app.py CHANGED
@@ -7,7 +7,6 @@ import math
7
  import matplotlib
8
  matplotlib.use('Agg')
9
  import matplotlib.pyplot as plt
10
- import multiprocessing as mp
11
  from scipy.signal import medfilt, find_peaks
12
  from functools import partial
13
  from passlib.hash import pbkdf2_sha256
@@ -26,16 +25,16 @@ from hls_download import download_clips
26
 
27
  plt.style.use('dark_background')
28
 
29
- onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
30
- #onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump_fp16.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
31
- # model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
32
- # hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
33
- #model_xml = "model_ir/model.xml"
34
 
35
  # ie = Core()
36
  # model_ir = ie.read_model(model=model_xml)
37
- # config = {"PERFORMANCE_HINT": "LATENCY"}
38
- # compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU", config=config)
39
 
40
 
41
  class SquarePad:
@@ -59,32 +58,18 @@ def create_transform(img_size):
59
  transforms.ToTensor(),
60
  ])
61
 
62
- def preprocess_frame(img, img_size):
63
- preprocess = create_transform(img_size)
64
- frameTensor = preprocess(img).unsqueeze(0)
65
- return frameTensor * 255
66
-
67
-
68
- def worker_function(frame_queue, batch_queue, img_size, seq_len):
69
- while True:
70
- frames = frame_queue.get()
71
- if frames is None: # Signal to exit
72
- break
73
- batch = torch.cat([preprocess_frame(img, img_size) for img in frames])
74
- batch_queue.put(batch)
75
-
76
 
77
  def inference(stream_url, start_time, end_time, count_only_api, api_key,
78
  img_size=288, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
79
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
80
  api_call=False,
81
  progress=gr.Progress()):
82
- progress(0, desc="Starting...")
83
  x = download_clips(stream_url, os.getcwd(), start_time, end_time)
84
  # check if GPU is available
85
  if torch.cuda.is_available():
86
- providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
87
- "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
88
  sess_options = ort.SessionOptions()
89
  sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
90
  ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
@@ -96,7 +81,7 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
96
  if api_call:
97
  has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
98
  if not has_access:
99
- return "Invalid API Key"
100
 
101
  cap = cv2.VideoCapture(x)
102
  length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
@@ -127,37 +112,26 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
127
  event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
128
  for _ in range(seq_len + stride_length): # pad full sequence
129
  all_frames.append(all_frames[-1])
130
-
131
- num_workers = mp.cpu_count() # Use all available CPU cores
132
- frame_queue = mp.Queue(maxsize=num_workers * 2)
133
- batch_queue = mp.Queue(maxsize=num_workers * 2)
134
-
135
- # Start worker processes
136
- processes = []
137
- for _ in range(num_workers):
138
- p = mp.Process(target=worker_function, args=(frame_queue, batch_queue, img_size, seq_len))
139
- p.start()
140
- processes.append(p)
141
-
142
- # Enqueue frame batches
143
- for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
144
- batch = all_frames[i:i + seq_len]
145
- if len(batch) < seq_len:
146
- batch.extend([batch[-1]] * (seq_len - len(batch)))
147
- frame_queue.put(batch)
148
-
149
- # Signal workers to exit after all frames are processed
150
- for _ in range(num_workers):
151
- frame_queue.put(None)
152
-
153
  batch_list = []
154
  idx_list = []
155
- #preprocess = create_transform(img_size)
156
  for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
157
- X = batch_queue.get()
 
 
 
 
 
 
 
 
 
 
 
 
158
  batch_list.append(X.unsqueeze(0))
159
  idx_list.append(i)
160
-
161
  if len(batch_list) == batch_size:
162
  batch_X = torch.cat(batch_list)
163
  outputs = ort_sess.run(None, {'video': batch_X.numpy()})
@@ -178,10 +152,7 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
178
  event_type_logit_overlaps[idx:idx+seq_len] += 1
179
  batch_list = []
180
  idx_list = []
181
- progress(i / (length + stride_length - stride_pad), desc="Processing...")
182
- # Wait for all processes to finish
183
- for p in processes:
184
- p.join()
185
  if len(batch_list) != 0: # still some leftover frames
186
  while len(batch_list) != batch_size:
187
  batch_list.append(batch_list[-1])
@@ -258,19 +229,19 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
258
  total_confidence = confidence * (1 - self_pct_err)
259
 
260
  if both_feet:
261
- count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
262
  else:
263
- count_msg = f"## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
264
 
265
  if api_call:
266
  if count_only_api:
267
- return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
268
  else:
269
- return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
270
- np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
271
- np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
272
- f"reps: {count_pred:.2f}, marks: {marks_count_pred:.1f}, confidence: {total_confidence:.2f}", \
273
- f"single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}"
274
 
275
 
276
  jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.01), 0, 10)
@@ -303,25 +274,25 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
303
  size_max=8,
304
  color_continuous_scale=[(t, c) for t, c in zip(event_type_tick_vals, event_type_colors)],
305
  range_color=(0,1),
306
- title="Jumping speed (jumps-per-second)",
307
  trendline='rolling',
308
  trendline_options=dict(window=16),
309
- trendline_color_override="goldenrod",
310
  trendline_scope='overall',
311
- template="plotly_dark")
312
 
313
  fig.update_layout(legend=dict(
314
- orientation="h",
315
- yanchor="bottom",
316
  y=0.98,
317
- xanchor="right",
318
  x=1,
319
  font=dict(
320
- family="Courier",
321
  size=12,
322
- color="black"
323
  ),
324
- bgcolor="AliceBlue",
325
  ),
326
  paper_bgcolor='rgba(0,0,0,0)',
327
  plot_bgcolor='rgba(0,0,0,0)'
@@ -335,71 +306,64 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
335
  ))
336
 
337
  hist = px.histogram(df,
338
- x="jumps per second",
339
- template="plotly_dark",
340
- marginal="box",
341
  histnorm='percent',
342
- title="Distribution of jumping speed (jumps-per-second)")
343
 
344
  # make a bar plot of the event type distribution
345
 
346
  bar = px.bar(x=['single rope', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
347
  y=event_type_probs,
348
- template="plotly_dark",
349
- title="Event Type Distribution",
350
  labels={'x': 'event type', 'y': 'probability'},
351
  range_y=[0, 1])
352
 
353
  return x, count_msg, fig, hist, bar
354
 
355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
-
358
-
359
- if __name__ == "__main__":
360
- DESCRIPTION = '# NextJump 🦘'
361
- DESCRIPTION += '\n## AI Counting for Competitive Jump Rope'
362
- DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'
363
-
364
-
365
- with gr.Blocks() as demo:
366
- gr.Markdown(DESCRIPTION)
367
- # in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
368
- # width=400, height=400, interactive=True, container=True,
369
- # max_length=150)
370
  with gr.Row():
371
- in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
372
  with gr.Column():
373
- in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True)
374
- with gr.Column():
375
- in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
376
- with gr.Column(min_width=480):
377
- out_video = gr.PlayableVideo(label="Video Clip", elem_id='output-video', format='mp4', width=400, height=400)
378
-
379
  with gr.Row():
380
- run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
381
- api_dummy_button = gr.Button(value="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
382
- count_only = gr.Checkbox(label="Count Only", visible=False)
383
- api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)
 
384
 
385
- with gr.Column(elem_id='output-video-container'):
386
- with gr.Row():
387
- with gr.Column():
388
- out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
389
- period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
390
- periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
391
- with gr.Row():
392
- out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
393
- with gr.Row():
394
- with gr.Column():
395
- out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
396
- with gr.Column():
397
- out_event_type_dist = gr.Plot(label="Event Type Distribution", elem_id='output-event-type-dist')
398
-
399
 
400
- demo_inference = partial(inference, count_only_api=False, api_key=None)
401
-
402
- run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
403
- api_inference = partial(inference, api_call=True)
404
- api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
405
  demo.queue(api_open=True, max_size=15).launch(share=False)
 
7
  import matplotlib
8
  matplotlib.use('Agg')
9
  import matplotlib.pyplot as plt
 
10
  from scipy.signal import medfilt, find_peaks
11
  from functools import partial
12
  from passlib.hash import pbkdf2_sha256
 
25
 
26
  plt.style.use('dark_background')
27
 
28
+ onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
29
+ #onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_fp16.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
30
+ # model_xml = hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.xml', repo_type='model', token=os.environ['DATASET_SECRET'])
31
+ # hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.mapping', repo_type='model', token=os.environ['DATASET_SECRET'])
32
+ #model_xml = 'model_ir/model.xml'
33
 
34
  # ie = Core()
35
  # model_ir = ie.read_model(model=model_xml)
36
+ # config = {'PERFORMANCE_HINT': 'LATENCY'}
37
+ # compiled_model_ir = ie.compile_model(model=model_ir, device_name='CPU', config=config)
38
 
39
 
40
  class SquarePad:
 
58
  transforms.ToTensor(),
59
  ])
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  def inference(stream_url, start_time, end_time, count_only_api, api_key,
63
  img_size=288, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
64
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
65
  api_call=False,
66
  progress=gr.Progress()):
67
+ progress(0, desc='Starting...')
68
  x = download_clips(stream_url, os.getcwd(), start_time, end_time)
69
  # check if GPU is available
70
  if torch.cuda.is_available():
71
+ providers = ['TensorrtExecutionProvider', ('CUDAExecutionProvider', {'device_id': torch.cuda.current_device(),
72
+ 'user_compute_stream': str(torch.cuda.current_stream().cuda_stream)})]
73
  sess_options = ort.SessionOptions()
74
  sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
75
  ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
 
81
  if api_call:
82
  has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
83
  if not has_access:
84
+ return 'Invalid API Key'
85
 
86
  cap = cv2.VideoCapture(x)
87
  length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
112
  event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
113
  for _ in range(seq_len + stride_length): # pad full sequence
114
  all_frames.append(all_frames[-1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  batch_list = []
116
  idx_list = []
117
+ preprocess = create_transform(img_size)
118
  for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
119
+ batch = all_frames[i:i + seq_len]
120
+ Xlist = []
121
+ print('Preprocessing...')
122
+ for img in batch:
123
+ frameTensor = preprocess(img).unsqueeze(0)
124
+ Xlist.append(frameTensor)
125
+
126
+ if len(Xlist) < seq_len:
127
+ for _ in range(seq_len - len(Xlist)):
128
+ Xlist.append(Xlist[-1])
129
+
130
+ X = torch.cat(Xlist)
131
+ X *= 255
132
  batch_list.append(X.unsqueeze(0))
133
  idx_list.append(i)
134
+ print('Running inference...')
135
  if len(batch_list) == batch_size:
136
  batch_X = torch.cat(batch_list)
137
  outputs = ort_sess.run(None, {'video': batch_X.numpy()})
 
152
  event_type_logit_overlaps[idx:idx+seq_len] += 1
153
  batch_list = []
154
  idx_list = []
155
+ progress(i / (length + stride_length - stride_pad), desc='Processing...')
 
 
 
156
  if len(batch_list) != 0: # still some leftover frames
157
  while len(batch_list) != batch_size:
158
  batch_list.append(batch_list[-1])
 
229
  total_confidence = confidence * (1 - self_pct_err)
230
 
231
  if both_feet:
232
+ count_msg = f'## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}'
233
  else:
234
+ count_msg = f'## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}'
235
 
236
  if api_call:
237
  if count_only_api:
238
+ return f'{count_pred:.2f} (conf: {total_confidence:.2f})'
239
  else:
240
+ return np.array2string(periodLength, formatter={'float_kind':lambda x: '%.2f' % x}).replace('\n', ''), \
241
+ np.array2string(periodicity, formatter={'float_kind':lambda x: '%.2f' % x}).replace('\n', ''), \
242
+ np.array2string(full_marks, formatter={'float_kind':lambda x: '%.2f' % x}).replace('\n', ''), \
243
+ f'reps: {count_pred:.2f}, marks: {marks_count_pred:.1f}, confidence: {total_confidence:.2f}', \
244
+ f'single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}'
245
 
246
 
247
  jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.01), 0, 10)
 
274
  size_max=8,
275
  color_continuous_scale=[(t, c) for t, c in zip(event_type_tick_vals, event_type_colors)],
276
  range_color=(0,1),
277
+ title='Jumping speed (jumps-per-second)',
278
  trendline='rolling',
279
  trendline_options=dict(window=16),
280
+ trendline_color_override='goldenrod',
281
  trendline_scope='overall',
282
+ template='plotly_dark')
283
 
284
  fig.update_layout(legend=dict(
285
+ orientation='h',
286
+ yanchor='bottom',
287
  y=0.98,
288
+ xanchor='right',
289
  x=1,
290
  font=dict(
291
+ family='Courier',
292
  size=12,
293
+ color='black'
294
  ),
295
+ bgcolor='AliceBlue',
296
  ),
297
  paper_bgcolor='rgba(0,0,0,0)',
298
  plot_bgcolor='rgba(0,0,0,0)'
 
306
  ))
307
 
308
  hist = px.histogram(df,
309
+ x='jumps per second',
310
+ template='plotly_dark',
311
+ marginal='box',
312
  histnorm='percent',
313
+ title='Distribution of jumping speed (jumps-per-second)')
314
 
315
  # make a bar plot of the event type distribution
316
 
317
  bar = px.bar(x=['single rope', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
318
  y=event_type_probs,
319
+ template='plotly_dark',
320
+ title='Event Type Distribution',
321
  labels={'x': 'event type', 'y': 'probability'},
322
  range_y=[0, 1])
323
 
324
  return x, count_msg, fig, hist, bar
325
 
326
 
327
+ with gr.Blocks() as demo:
328
+ # in_video = gr.PlayableVideo(label='Input Video', elem_id='input-video', format='mp4',
329
+ # width=400, height=400, interactive=True, container=True,
330
+ # max_length=150)
331
+ with gr.Row():
332
+ in_stream_url = gr.Textbox(label='Stream URL', elem_id='stream-url', visible=True)
333
+ with gr.Column():
334
+ in_stream_start = gr.Textbox(label='Start Time', elem_id='stream-start', visible=True)
335
+ with gr.Column():
336
+ in_stream_end = gr.Textbox(label='End Time', elem_id='stream-end', visible=True)
337
+ with gr.Column(min_width=480):
338
+ out_video = gr.PlayableVideo(label='Video Clip', elem_id='output-video', format='mp4', width=400, height=400)
339
+
340
+ with gr.Row():
341
+ run_button = gr.Button(value='Run', elem_id='run-button', scale=1)
342
+ api_dummy_button = gr.Button(value='Run (No Viz)', elem_id='count-only', visible=False, scale=2)
343
+ count_only = gr.Checkbox(label='Count Only', visible=False)
344
+ api_token = gr.Textbox(label='API Key', elem_id='api-token', visible=False)
345
 
346
+ with gr.Column(elem_id='output-video-container'):
 
 
 
 
 
 
 
 
 
 
 
 
347
  with gr.Row():
 
348
  with gr.Column():
349
+ out_text = gr.Markdown(label='Predicted Count', elem_id='output-text')
350
+ period_length = gr.Textbox(label='Period Length', elem_id='period-length', visible=False)
351
+ periodicity = gr.Textbox(label='Periodicity', elem_id='periodicity', visible=False)
352
+ with gr.Row():
353
+ out_plot = gr.Plot(label='Jumping Speed', elem_id='output-plot')
 
354
  with gr.Row():
355
+ with gr.Column():
356
+ out_hist = gr.Plot(label='Speed Histogram', elem_id='output-hist')
357
+ with gr.Column():
358
+ out_event_type_dist = gr.Plot(label='Event Type Distribution', elem_id='output-event-type-dist')
359
+
360
 
361
+ demo_inference = partial(inference, count_only_api=False, api_key=None)
362
+
363
+ run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
364
+ api_inference = partial(inference, api_call=True)
365
+ api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
 
 
 
 
 
 
 
 
 
366
 
367
+
368
+ if __name__ == '__main__':
 
 
 
369
  demo.queue(api_open=True, max_size=15).launch(share=False)