dylanplummer committed on
Commit
1431cde
·
1 Parent(s): 392e794

optimize inference

Browse files
Files changed (1) hide show
  1. app.py +100 -103
app.py CHANGED
@@ -4,10 +4,10 @@ from PIL import Image
4
  import os
5
  import cv2
6
  import math
7
- import spaces
8
  import matplotlib
9
  matplotlib.use('Agg')
10
  import matplotlib.pyplot as plt
 
11
  from scipy.signal import medfilt, find_peaks
12
  from functools import partial
13
  from passlib.hash import pbkdf2_sha256
@@ -26,15 +26,20 @@ from hls_download import download_clips
26
 
27
  plt.style.use('dark_background')
28
 
 
 
29
  onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
30
- # model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
31
- # hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
32
- #model_xml = "model_ir/model.xml"
 
 
 
 
 
 
 
33
 
34
- # ie = Core()
35
- # model_ir = ie.read_model(model=model_xml)
36
- # config = {"PERFORMANCE_HINT": "LATENCY"}
37
- # compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU", config=config)
38
 
39
 
40
  class SquarePad:
@@ -46,52 +51,72 @@ class SquarePad:
46
  vp = int((max_wh - h) / 2)
47
  padding = (hp, vp, hp, vp)
48
  return F.pad(image, padding, 0, 'constant')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  def sigmoid(x):
51
  return 1 / (1 + np.exp(-x))
52
 
53
 
54
- @spaces.GPU()
55
  def inference(stream_url, start_time, end_time, count_only_api, api_key,
56
  img_size=256, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
57
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
58
  api_call=False,
59
  progress=gr.Progress()):
60
- progress(0, desc="Starting...")
61
- x = download_clips(stream_url, os.getcwd(), start_time, end_time)
62
- # check if GPU is available
63
- if torch.cuda.is_available():
64
- providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
65
- "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
66
- sess_options = ort.SessionOptions()
67
- ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
68
- else:
69
- ort_sess = ort.InferenceSession(onnx_file)
70
- #api = HfApi(token=os.environ['DATASET_SECRET'])
71
- #out_file = str(uuid.uuid1())
72
  has_access = False
73
  if api_call:
74
  has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
75
  if not has_access:
76
  return "Invalid API Key"
77
 
78
- cap = cv2.VideoCapture(x)
79
  length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
80
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
81
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
82
  period_length_overlaps = np.zeros(length + seq_len)
83
  fps = int(cap.get(cv2.CAP_PROP_FPS))
84
  seconds = length / fps
85
  all_frames = []
86
  frame_i = 1
 
87
  while cap.isOpened():
88
  ret, frame = cap.read()
89
  if ret is False:
90
  frame = all_frames[-1] # padding will be with last frame
91
  break
92
  frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
93
- img = Image.fromarray(frame)
94
- all_frames.append(img)
 
 
 
 
 
 
 
95
  frame_i += 1
96
  cap.release()
97
 
@@ -106,47 +131,45 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
106
  all_frames.append(all_frames[-1])
107
  batch_list = []
108
  idx_list = []
109
- for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
110
- batch = all_frames[i:i + seq_len]
111
- Xlist = []
112
- print('Preprocessing...')
113
- for img in batch:
114
- transforms_list = []
115
- # if center_crop:
116
- # if width > height:
117
- # transforms_list.append(transforms.Resize((int(width / (height / img_size)), img_size)))
118
- # else:
119
- # transforms_list.append(transforms.Resize((img_size, int(height / (width / img_size)))))
120
- # transforms_list.append(transforms.CenterCrop((img_size, img_size)))
121
- # else:
122
- transforms_list.append(SquarePad())
123
- transforms_list.append(transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC))
124
 
125
-
126
- transforms_list += [
127
- transforms.ToTensor()]
128
- #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
129
- preprocess = transforms.Compose(transforms_list)
130
- frameTensor = preprocess(img).unsqueeze(0)
131
- Xlist.append(frameTensor)
132
-
133
- if len(Xlist) < seq_len:
134
- for _ in range(seq_len - len(Xlist)):
135
- Xlist.append(Xlist[-1])
 
 
 
 
 
 
 
 
 
 
136
 
137
- X = torch.cat(Xlist)
138
- X *= 255
139
- batch_list.append(X.unsqueeze(0))
140
- idx_list.append(i)
141
- print('Running inference...')
142
- if len(batch_list) == batch_size:
143
- batch_X = torch.cat(batch_list)
144
- outputs = ort_sess.run(None, {'video': batch_X.numpy()})
145
- y1pred = outputs[0]
146
- y2pred = outputs[1]
147
- y3pred = outputs[2]
148
- y4pred = outputs[3]
149
- for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
150
  periodLength = y1.squeeze()
151
  periodicity = y2.squeeze()
152
  marks = y3.squeeze()
@@ -157,30 +180,6 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
157
  event_type_logits[idx:idx+seq_len] += event_type
158
  period_length_overlaps[idx:idx+seq_len] += 1
159
  event_type_logit_overlaps[idx:idx+seq_len] += 1
160
- batch_list = []
161
- idx_list = []
162
- progress(i / (length + stride_length - stride_pad), desc="Processing...")
163
- if len(batch_list) != 0: # still some leftover frames
164
- while len(batch_list) != batch_size:
165
- batch_list.append(batch_list[-1])
166
- idx_list.append(idx_list[-1])
167
- batch_X = torch.cat(batch_list)
168
- outputs = ort_sess.run(None, {'video': batch_X.numpy()})
169
- y1pred = outputs[0]
170
- y2pred = outputs[1]
171
- y3pred = outputs[2]
172
- y4pred = outputs[3]
173
- for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
174
- periodLength = y1.squeeze()
175
- periodicity = y2.squeeze()
176
- marks = y3.squeeze()
177
- event_type = y4.squeeze()
178
- period_lengths[idx:idx+seq_len] += periodLength
179
- periodicities[idx:idx+seq_len] += periodicity
180
- full_marks[idx:idx+seq_len] += marks
181
- event_type_logits[idx:idx+seq_len] += event_type
182
- period_length_overlaps[idx:idx+seq_len] += 1
183
- event_type_logit_overlaps[idx:idx+seq_len] += 1
184
 
185
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
186
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
@@ -196,7 +195,6 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
196
  periodLength = medfilt(periodLength, 5)
197
  periodicity = sigmoid(periodicity)
198
  full_marks = sigmoid(full_marks)
199
- #full_marks_mask = np.int32(full_marks > marks_threshold)
200
  pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
201
  full_marks_mask = np.zeros(len(full_marks))
202
  full_marks_mask[pred_marks_peaks] = 1
@@ -328,24 +326,15 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
328
  labels={'x': 'event type', 'y': 'probability'},
329
  range_y=[0, 1])
330
 
331
- return x, count_msg, fig, hist, bar
332
 
333
 
334
- DESCRIPTION = '# NextJump 🦘'
335
- DESCRIPTION += '\n## AI Counting for Competitive Jump Rope'
336
- DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'
337
-
338
-
339
- with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
340
- gr.Markdown(DESCRIPTION)
341
- # in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
342
- # width=400, height=400, interactive=True, container=True,
343
- # max_length=150)
344
  with gr.Row():
345
- in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
346
  with gr.Column():
347
- in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True)
348
  with gr.Column():
 
349
  in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
350
  with gr.Column(min_width=480):
351
  out_video = gr.PlayableVideo(label="Video Clip", elem_id='output-video', format='mp4', width=400, height=400)
@@ -376,7 +365,15 @@ with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
376
  run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
377
  api_inference = partial(inference, api_call=True)
378
  api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
 
 
 
 
 
 
 
379
 
380
 
381
  if __name__ == "__main__":
 
382
  demo.queue(api_open=True, max_size=15).launch(share=False)
 
4
  import os
5
  import cv2
6
  import math
 
7
  import matplotlib
8
  matplotlib.use('Agg')
9
  import matplotlib.pyplot as plt
10
+ import concurrent.futures
11
  from scipy.signal import medfilt, find_peaks
12
  from functools import partial
13
  from passlib.hash import pbkdf2_sha256
 
26
 
27
  plt.style.use('dark_background')
28
 
29
+ IMG_SIZE = 256
30
+
31
  onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
32
+ if torch.cuda.is_available():
33
+ providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
34
+ "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
35
+ sess_options = ort.SessionOptions()
36
+ ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
37
+ else:
38
+ ort_sess = ort.InferenceSession(onnx_file)
39
+
40
+ # warmup inference
41
+ ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
42
 
 
 
 
 
43
 
44
 
45
  class SquarePad:
 
51
  vp = int((max_wh - h) / 2)
52
  padding = (hp, vp, hp, vp)
53
  return F.pad(image, padding, 0, 'constant')
54
+
55
+ def square_pad_opencv(image):
56
+ h, w = image.shape[:2]
57
+ max_wh = max(w, h)
58
+ hp = int((max_wh - w) / 2)
59
+ vp = int((max_wh - h) / 2)
60
+ return cv2.copyMakeBorder(image, vp, vp, hp, hp, cv2.BORDER_CONSTANT, value=[0, 0, 0])
61
+
62
+
63
+ def preprocess_image(img, img_size):
64
+ #img = square_pad_opencv(img)
65
+ #img = cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
66
+ img = Image.fromarray(img)
67
+ transforms_list = []
68
+ transforms_list.append(transforms.ToTensor())
69
+ preprocess = transforms.Compose(transforms_list)
70
+ return preprocess(img).unsqueeze(0)
71
+
72
+ def run_inference(batch_X):
73
+ batch_X = torch.cat(batch_X)
74
+ return ort_sess.run(None, {'video': batch_X.numpy()})
75
+
76
 
77
  def sigmoid(x):
78
  return 1 / (1 + np.exp(-x))
79
 
80
 
 
81
  def inference(stream_url, start_time, end_time, count_only_api, api_key,
82
  img_size=256, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
83
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
84
  api_call=False,
85
  progress=gr.Progress()):
86
+ progress(0, desc="Downloading clip...")
87
+ in_video = download_clips(stream_url, os.getcwd(), start_time, end_time)
88
+ progress(0, desc="Running inference...")
 
 
 
 
 
 
 
 
 
89
  has_access = False
90
  if api_call:
91
  has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
92
  if not has_access:
93
  return "Invalid API Key"
94
 
95
+ cap = cv2.VideoCapture(in_video)
96
  length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
97
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
98
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
99
  period_length_overlaps = np.zeros(length + seq_len)
100
  fps = int(cap.get(cv2.CAP_PROP_FPS))
101
  seconds = length / fps
102
  all_frames = []
103
  frame_i = 1
104
+ resize_size = max(frame_width, frame_height)
105
  while cap.isOpened():
106
  ret, frame = cap.read()
107
  if ret is False:
108
  frame = all_frames[-1] # padding will be with last frame
109
  break
110
  frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
111
+ # add square padding with opencv
112
+ #frame = square_pad_opencv(frame)
113
+ frame = cv2.resize(frame, (resize_size, resize_size), interpolation=cv2.INTER_CUBIC)
114
+ frame_center_x = frame.shape[1] // 2
115
+ frame_center_y = frame.shape[0] // 2
116
+ crop_x = frame_center_x - img_size // 2
117
+ crop_y = frame_center_y - img_size // 2
118
+ frame = frame[crop_y:crop_y+img_size, crop_x:crop_x+img_size]
119
+ all_frames.append(frame)
120
  frame_i += 1
121
  cap.release()
122
 
 
131
  all_frames.append(all_frames[-1])
132
  batch_list = []
133
  idx_list = []
134
+ inference_futures = []
135
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
136
+ for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
137
+ batch = all_frames[i:i + seq_len]
138
+ Xlist = []
139
+ preprocess_tasks = [(idx, executor.submit(preprocess_image, img, img_size)) for idx, img in enumerate(batch)]
140
+ for idx, future in sorted(preprocess_tasks, key=lambda x: x[0]):
141
+ Xlist.append(future.result())
 
 
 
 
 
 
 
142
 
143
+ if len(Xlist) < seq_len:
144
+ for _ in range(seq_len - len(Xlist)):
145
+ Xlist.append(Xlist[-1])
146
+
147
+ X = torch.cat(Xlist)
148
+ X *= 255
149
+ batch_list.append(X.unsqueeze(0))
150
+ idx_list.append(i)
151
+
152
+ if len(batch_list) == batch_size:
153
+ future = executor.submit(run_inference, batch_list)
154
+ inference_futures.append((batch_list, idx_list, future))
155
+ batch_list = []
156
+ idx_list = []
157
+ # Process any remaining batches
158
+ if batch_list:
159
+ while len(batch_list) != batch_size:
160
+ batch_list.append(batch_list[-1])
161
+ idx_list.append(idx_list[-1])
162
+ future = executor.submit(run_inference, batch_list)
163
+ inference_futures.append((batch_list, idx_list, future))
164
 
165
+ # Collect and process the inference results
166
+ for batch_list, idx_list, future in inference_futures:
167
+ outputs = future.result()
168
+ y1_out = outputs[0]
169
+ y2_out = outputs[1]
170
+ y3_out = outputs[2]
171
+ y4_out = outputs[3]
172
+ for y1, y2, y3, y4, idx in zip(y1_out, y2_out, y3_out, y4_out, idx_list):
 
 
 
 
 
173
  periodLength = y1.squeeze()
174
  periodicity = y2.squeeze()
175
  marks = y3.squeeze()
 
180
  event_type_logits[idx:idx+seq_len] += event_type
181
  period_length_overlaps[idx:idx+seq_len] += 1
182
  event_type_logit_overlaps[idx:idx+seq_len] += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
185
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
 
195
  periodLength = medfilt(periodLength, 5)
196
  periodicity = sigmoid(periodicity)
197
  full_marks = sigmoid(full_marks)
 
198
  pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
199
  full_marks_mask = np.zeros(len(full_marks))
200
  full_marks_mask[pred_marks_peaks] = 1
 
326
  labels={'x': 'event type', 'y': 'probability'},
327
  range_y=[0, 1])
328
 
329
+ return in_video, count_msg, fig, hist, bar
330
 
331
 
332
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
333
  with gr.Row():
 
334
  with gr.Column():
335
+ in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
336
  with gr.Column():
337
+ in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True)
338
  in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
339
  with gr.Column(min_width=480):
340
  out_video = gr.PlayableVideo(label="Video Clip", elem_id='output-video', format='mp4', width=400, height=400)
 
365
  run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
366
  api_inference = partial(inference, api_call=True)
367
  api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
368
+ examples = [
369
+ ['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '00:43:10', '00:43:40'],
370
+ ]
371
+ gr.Examples(examples,
372
+ inputs=[in_stream_url, in_stream_start, in_stream_end],
373
+ outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist],
374
+ fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
375
 
376
 
377
  if __name__ == "__main__":
378
+
379
  demo.queue(api_open=True, max_size=15).launch(share=False)