dylanplummer committed on
Commit
5a7cfc7
·
1 Parent(s): 4e386bd

multiprocessing

Browse files
Files changed (1) hide show
  1. app.py +91 -61
app.py CHANGED
@@ -7,6 +7,7 @@ import math
7
  import matplotlib
8
  matplotlib.use('Agg')
9
  import matplotlib.pyplot as plt
 
10
  from scipy.signal import medfilt, find_peaks
11
  from functools import partial
12
  from passlib.hash import pbkdf2_sha256
@@ -25,8 +26,8 @@ from hls_download import download_clips
25
 
26
  plt.style.use('dark_background')
27
 
28
- #onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
29
- onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump_fp16.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
30
  # model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
31
  # hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
32
  #model_xml = "model_ir/model.xml"
@@ -58,6 +59,20 @@ def create_transform(img_size):
58
  transforms.ToTensor(),
59
  ])
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  def inference(stream_url, start_time, end_time, count_only_api, api_key,
63
  img_size=288, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
@@ -112,29 +127,40 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
112
  event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
113
  for _ in range(seq_len + stride_length): # pad full sequence
114
  all_frames.append(all_frames[-1])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
  batch_list = []
116
  idx_list = []
117
- preprocess = create_transform(img_size)
118
  for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
119
- batch = all_frames[i:i + seq_len]
120
- Xlist = []
121
- print('Preprocessing...')
122
- for img in batch:
123
- frameTensor = preprocess(img).unsqueeze(0)
124
- Xlist.append(frameTensor)
125
-
126
- if len(Xlist) < seq_len:
127
- for _ in range(seq_len - len(Xlist)):
128
- Xlist.append(Xlist[-1])
129
-
130
- X = torch.cat(Xlist)
131
- X *= 255
132
  batch_list.append(X.unsqueeze(0))
133
  idx_list.append(i)
134
- print('Running inference...')
135
  if len(batch_list) == batch_size:
136
  batch_X = torch.cat(batch_list)
137
- outputs = ort_sess.run(None, {'video': np.float16(batch_X.numpy())})
138
  y1pred = outputs[0]
139
  y2pred = outputs[1]
140
  y3pred = outputs[2]
@@ -153,12 +179,15 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
153
  batch_list = []
154
  idx_list = []
155
  progress(i / (length + stride_length - stride_pad), desc="Processing...")
 
 
 
156
  if len(batch_list) != 0: # still some leftover frames
157
  while len(batch_list) != batch_size:
158
  batch_list.append(batch_list[-1])
159
  idx_list.append(idx_list[-1])
160
  batch_X = torch.cat(batch_list)
161
- outputs = ort_sess.run(None, {'video': np.float16(batch_X.numpy())})
162
  y1pred = outputs[0]
163
  y2pred = outputs[1]
164
  y3pred = outputs[2]
@@ -324,52 +353,53 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
324
  return x, count_msg, fig, hist, bar
325
 
326
 
327
- DESCRIPTION = '# NextJump 🦘'
328
- DESCRIPTION += '\n## AI Counting for Competitive Jump Rope'
329
- DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'
330
-
331
-
332
- with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
333
- gr.Markdown(DESCRIPTION)
334
- # in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
335
- # width=400, height=400, interactive=True, container=True,
336
- # max_length=150)
337
- with gr.Row():
338
- in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
339
- with gr.Column():
340
- in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True)
341
- with gr.Column():
342
- in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
343
- with gr.Column(min_width=480):
344
- out_video = gr.PlayableVideo(label="Video Clip", elem_id='output-video', format='mp4', width=400, height=400)
345
-
346
- with gr.Row():
347
- run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
348
- api_dummy_button = gr.Button(value="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
349
- count_only = gr.Checkbox(label="Count Only", visible=False)
350
- api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)
351
 
352
- with gr.Column(elem_id='output-video-container'):
353
- with gr.Row():
354
- with gr.Column():
355
- out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
356
- period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
357
- periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
358
- with gr.Row():
359
- out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
 
 
 
 
 
360
  with gr.Row():
 
361
  with gr.Column():
362
- out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
363
  with gr.Column():
364
- out_event_type_dist = gr.Plot(label="Event Type Distribution", elem_id='output-event-type-dist')
365
-
366
-
367
- demo_inference = partial(inference, count_only_api=False, api_key=None)
368
-
369
- run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
370
- api_inference = partial(inference, api_call=True)
371
- api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
 
372
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
 
374
- if __name__ == "__main__":
 
 
 
 
375
  demo.queue(api_open=True, max_size=15).launch(share=False)
 
7
  import matplotlib
8
  matplotlib.use('Agg')
9
  import matplotlib.pyplot as plt
10
+ import multiprocessing as mp
11
  from scipy.signal import medfilt, find_peaks
12
  from functools import partial
13
  from passlib.hash import pbkdf2_sha256
 
26
 
27
  plt.style.use('dark_background')
28
 
29
+ onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
30
+ #onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump_fp16.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
31
  # model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
32
  # hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
33
  #model_xml = "model_ir/model.xml"
 
59
  transforms.ToTensor(),
60
  ])
61
 
62
def preprocess_frame(img, img_size):
    """Convert one frame into a batched tensor rescaled to the 0-255 range.

    Applies the transform built by create_transform(img_size), adds a leading
    batch dimension, and multiplies by 255 (ToTensor normalizes to [0, 1],
    so this restores pixel-scale values — presumably what the ONNX model
    expects; confirm against the model's input spec).
    """
    # NOTE(review): the transform pipeline is rebuilt on every call; hoisting
    # it to the worker loop would avoid redundant construction — verify
    # create_transform is cheap before leaving as-is.
    return create_transform(img_size)(img).unsqueeze(0) * 255
66
+
67
+
68
def worker_function(frame_queue, batch_queue, img_size, seq_len):
    """Worker-process loop: turn lists of frames into preprocessed batches.

    Blocks on frame_queue for the next list of frames, preprocesses each
    frame via preprocess_frame, concatenates the results with torch.cat,
    and pushes the batch onto batch_queue. A ``None`` item is the shutdown
    sentinel. ``seq_len`` is part of the worker signature but is not read
    here — NOTE(review): confirm whether it was meant to pad short batches.
    """
    while True:
        frames = frame_queue.get()
        if frames is None:
            # Producer posted the exit sentinel: no more work for this process.
            return
        tensors = [preprocess_frame(frame, img_size) for frame in frames]
        batch_queue.put(torch.cat(tensors))
76
 
77
  def inference(stream_url, start_time, end_time, count_only_api, api_key,
78
  img_size=288, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
 
127
  event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
128
  for _ in range(seq_len + stride_length): # pad full sequence
129
  all_frames.append(all_frames[-1])
130
+
131
+ num_workers = mp.cpu_count() # Use all available CPU cores
132
+ frame_queue = mp.Queue(maxsize=num_workers * 2)
133
+ batch_queue = mp.Queue(maxsize=num_workers * 2)
134
+
135
+ # Start worker processes
136
+ processes = []
137
+ for _ in range(num_workers):
138
+ p = mp.Process(target=worker_function, args=(frame_queue, batch_queue, img_size, seq_len))
139
+ p.start()
140
+ processes.append(p)
141
+
142
+ # Enqueue frame batches
143
+ for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
144
+ batch = all_frames[i:i + seq_len]
145
+ if len(batch) < seq_len:
146
+ batch.extend([batch[-1]] * (seq_len - len(batch)))
147
+ frame_queue.put(batch)
148
+
149
+ # Signal workers to exit after all frames are processed
150
+ for _ in range(num_workers):
151
+ frame_queue.put(None)
152
+
153
  batch_list = []
154
  idx_list = []
155
+ #preprocess = create_transform(img_size)
156
  for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
157
+ X = batch_queue.get()
 
 
 
 
 
 
 
 
 
 
 
 
158
  batch_list.append(X.unsqueeze(0))
159
  idx_list.append(i)
160
+
161
  if len(batch_list) == batch_size:
162
  batch_X = torch.cat(batch_list)
163
+ outputs = ort_sess.run(None, {'video': batch_X.numpy()})
164
  y1pred = outputs[0]
165
  y2pred = outputs[1]
166
  y3pred = outputs[2]
 
179
  batch_list = []
180
  idx_list = []
181
  progress(i / (length + stride_length - stride_pad), desc="Processing...")
182
+ # Wait for all processes to finish
183
+ for p in processes:
184
+ p.join()
185
  if len(batch_list) != 0: # still some leftover frames
186
  while len(batch_list) != batch_size:
187
  batch_list.append(batch_list[-1])
188
  idx_list.append(idx_list[-1])
189
  batch_X = torch.cat(batch_list)
190
+ outputs = ort_sess.run(None, {'video': batch_X.numpy()})
191
  y1pred = outputs[0]
192
  y2pred = outputs[1]
193
  y3pred = outputs[2]
 
353
  return x, count_msg, fig, hist, bar
354
 
355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
+
358
+
359
+ if __name__ == "__main__":
360
+ DESCRIPTION = '# NextJump 🦘'
361
+ DESCRIPTION += '\n## AI Counting for Competitive Jump Rope'
362
+ DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'
363
+
364
+
365
+ with gr.Blocks() as demo:
366
+ gr.Markdown(DESCRIPTION)
367
+ # in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
368
+ # width=400, height=400, interactive=True, container=True,
369
+ # max_length=150)
370
  with gr.Row():
371
+ in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
372
  with gr.Column():
373
+ in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True)
374
  with gr.Column():
375
+ in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
376
+ with gr.Column(min_width=480):
377
+ out_video = gr.PlayableVideo(label="Video Clip", elem_id='output-video', format='mp4', width=400, height=400)
378
+
379
+ with gr.Row():
380
+ run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
381
+ api_dummy_button = gr.Button(value="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
382
+ count_only = gr.Checkbox(label="Count Only", visible=False)
383
+ api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)
384
 
385
+ with gr.Column(elem_id='output-video-container'):
386
+ with gr.Row():
387
+ with gr.Column():
388
+ out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
389
+ period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
390
+ periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
391
+ with gr.Row():
392
+ out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
393
+ with gr.Row():
394
+ with gr.Column():
395
+ out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
396
+ with gr.Column():
397
+ out_event_type_dist = gr.Plot(label="Event Type Distribution", elem_id='output-event-type-dist')
398
+
399
 
400
+ demo_inference = partial(inference, count_only_api=False, api_key=None)
401
+
402
+ run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
403
+ api_inference = partial(inference, api_call=True)
404
+ api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
405
  demo.queue(api_open=True, max_size=15).launch(share=False)