dylanplummer commited on
Commit
0340538
·
1 Parent(s): 35202d4

add warmup

Browse files
Files changed (2) hide show
  1. app.py +373 -368
  2. requirements.txt +1 -2
app.py CHANGED
@@ -1,369 +1,374 @@
1
- import gradio as gr
2
- import numpy as np
3
- from PIL import Image
4
- import os
5
- import cv2
6
- import math
7
- import matplotlib
8
- matplotlib.use('Agg')
9
- import matplotlib.pyplot as plt
10
- from scipy.signal import medfilt, find_peaks
11
- from functools import partial
12
- from passlib.hash import pbkdf2_sha256
13
- from tqdm import tqdm
14
- import pandas as pd
15
- import plotly.express as px
16
- import onnxruntime as ort
17
- import torch
18
- from torchvision import transforms
19
- import torchvision.transforms.functional as F
20
-
21
- from huggingface_hub import hf_hub_download
22
- from huggingface_hub import HfApi
23
-
24
- from hls_download import download_clips
25
-
26
- plt.style.use('dark_background')
27
-
28
- onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
29
- #onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_fp16.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
30
- # model_xml = hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.xml', repo_type='model', token=os.environ['DATASET_SECRET'])
31
- # hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.mapping', repo_type='model', token=os.environ['DATASET_SECRET'])
32
- #model_xml = 'model_ir/model.xml'
33
-
34
- # ie = Core()
35
- # model_ir = ie.read_model(model=model_xml)
36
- # config = {'PERFORMANCE_HINT': 'LATENCY'}
37
- # compiled_model_ir = ie.compile_model(model=model_ir, device_name='CPU', config=config)
38
-
39
-
40
- class SquarePad:
41
- # https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
42
- def __call__(self, image):
43
- w, h = image.size
44
- max_wh = max(w, h)
45
- hp = int((max_wh - w) / 2)
46
- vp = int((max_wh - h) / 2)
47
- padding = (hp, vp, hp, vp)
48
- return F.pad(image, padding, 0, 'constant')
49
-
50
- def sigmoid(x):
51
- return 1 / (1 + np.exp(-x))
52
-
53
-
54
- def create_transform(img_size):
55
- return transforms.Compose([
56
- SquarePad(),
57
- transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),
58
- transforms.ToTensor(),
59
- ])
60
-
61
-
62
- def inference(stream_url, start_time, end_time, count_only_api, api_key,
63
- img_size=288, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
64
- miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
65
- api_call=False,
66
- progress=gr.Progress()):
67
- progress(0, desc='Starting...')
68
- x = download_clips(stream_url, os.getcwd(), start_time, end_time)
69
- # check if GPU is available
70
- if torch.cuda.is_available():
71
- providers = ['TensorrtExecutionProvider', ('CUDAExecutionProvider', {'device_id': torch.cuda.current_device(),
72
- 'user_compute_stream': str(torch.cuda.current_stream().cuda_stream)})]
73
- sess_options = ort.SessionOptions()
74
- sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
75
- ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
76
- else:
77
- ort_sess = ort.InferenceSession(onnx_file)
78
- #api = HfApi(token=os.environ['DATASET_SECRET'])
79
- #out_file = str(uuid.uuid1())
80
- has_access = False
81
- if api_call:
82
- has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
83
- if not has_access:
84
- return 'Invalid API Key'
85
-
86
- cap = cv2.VideoCapture(x)
87
- length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
88
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
89
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
90
- period_length_overlaps = np.zeros(length + seq_len)
91
- fps = int(cap.get(cv2.CAP_PROP_FPS))
92
- seconds = length / fps
93
- all_frames = []
94
- frame_i = 1
95
- while cap.isOpened():
96
- ret, frame = cap.read()
97
- if ret is False:
98
- frame = all_frames[-1] # padding will be with last frame
99
- break
100
- frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
101
- img = Image.fromarray(frame)
102
- all_frames.append(img)
103
- frame_i += 1
104
- cap.release()
105
-
106
- length = len(all_frames)
107
- period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
108
- periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
109
- full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
110
- event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 7))
111
- period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
112
- event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
113
- for _ in range(seq_len + stride_length): # pad full sequence
114
- all_frames.append(all_frames[-1])
115
- batch_list = []
116
- idx_list = []
117
- preprocess = create_transform(img_size)
118
- for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
119
- batch = all_frames[i:i + seq_len]
120
- Xlist = []
121
- print('Preprocessing...')
122
- for img in batch:
123
- frameTensor = preprocess(img).unsqueeze(0)
124
- Xlist.append(frameTensor)
125
-
126
- if len(Xlist) < seq_len:
127
- for _ in range(seq_len - len(Xlist)):
128
- Xlist.append(Xlist[-1])
129
-
130
- X = torch.cat(Xlist)
131
- X *= 255
132
- batch_list.append(X.unsqueeze(0))
133
- idx_list.append(i)
134
- print('Running inference...')
135
- if len(batch_list) == batch_size:
136
- batch_X = torch.cat(batch_list)
137
- outputs = ort_sess.run(None, {'video': batch_X.numpy()})
138
- y1pred = outputs[0]
139
- y2pred = outputs[1]
140
- y3pred = outputs[2]
141
- y4pred = outputs[3]
142
- for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
143
- periodLength = y1.squeeze()
144
- periodicity = y2.squeeze()
145
- marks = y3.squeeze()
146
- event_type = y4.squeeze()
147
- period_lengths[idx:idx+seq_len] += periodLength
148
- periodicities[idx:idx+seq_len] += periodicity
149
- full_marks[idx:idx+seq_len] += marks
150
- event_type_logits[idx:idx+seq_len] += event_type
151
- period_length_overlaps[idx:idx+seq_len] += 1
152
- event_type_logit_overlaps[idx:idx+seq_len] += 1
153
- batch_list = []
154
- idx_list = []
155
- progress(i / (length + stride_length - stride_pad), desc='Processing...')
156
- if len(batch_list) != 0: # still some leftover frames
157
- while len(batch_list) != batch_size:
158
- batch_list.append(batch_list[-1])
159
- idx_list.append(idx_list[-1])
160
- batch_X = torch.cat(batch_list)
161
- outputs = ort_sess.run(None, {'video': batch_X.numpy()})
162
- y1pred = outputs[0]
163
- y2pred = outputs[1]
164
- y3pred = outputs[2]
165
- y4pred = outputs[3]
166
- for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
167
- periodLength = y1.squeeze()
168
- periodicity = y2.squeeze()
169
- marks = y3.squeeze()
170
- event_type = y4.squeeze()
171
- period_lengths[idx:idx+seq_len] += periodLength
172
- periodicities[idx:idx+seq_len] += periodicity
173
- full_marks[idx:idx+seq_len] += marks
174
- event_type_logits[idx:idx+seq_len] += event_type
175
- period_length_overlaps[idx:idx+seq_len] += 1
176
- event_type_logit_overlaps[idx:idx+seq_len] += 1
177
-
178
- periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
179
- periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
180
- full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
181
- per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
182
- event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
183
- # softmax of event type logits
184
- event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
185
- per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)
186
-
187
- if median_pred_filter:
188
- periodicity = medfilt(periodicity, 5)
189
- periodLength = medfilt(periodLength, 5)
190
- periodicity = sigmoid(periodicity)
191
- full_marks = sigmoid(full_marks)
192
- #full_marks_mask = np.int32(full_marks > marks_threshold)
193
- pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
194
- full_marks_mask = np.zeros(len(full_marks))
195
- full_marks_mask[pred_marks_peaks] = 1
196
- periodicity_mask = np.int32(periodicity > miss_threshold)
197
- numofReps = 0
198
- count = []
199
- for i in range(len(periodLength)):
200
- if periodLength[i] < 2 or periodicity_mask[i] == 0:
201
- numofReps += 0
202
- elif full_marks_mask[i]: # high confidence mark detected
203
- if math.modf(numofReps)[0] < 0.2: # probably false positive/late detection
204
- numofReps = float(int(numofReps))
205
- else:
206
- numofReps = float(int(numofReps) + 1.01) # round up
207
- else:
208
- numofReps += max(0, periodicity_mask[i]/(periodLength[i]))
209
- count.append(round(float(numofReps), 2))
210
- count_pred = count[-1]
211
- marks_count_pred = 0
212
- for i in range(len(full_marks) - 1):
213
- # if a jump was counted, and periodicity is high, and the next frame was not counted (to avoid double counting)
214
- if full_marks_mask[i] > 0 and periodicity_mask[i] > 0 and full_marks_mask[i + 1] == 0:
215
- marks_count_pred += 1
216
- if not both_feet:
217
- count_pred = count_pred / 2
218
- marks_count_pred = marks_count_pred / 2
219
- count = np.array(count) / 2
220
- try:
221
- confidence = (np.mean(periodicity[periodicity > miss_threshold]) - miss_threshold) / (1 - miss_threshold)
222
- except ZeroDivisionError:
223
- confidence = 0
224
- self_err = abs(count_pred - marks_count_pred)
225
- try:
226
- self_pct_err = self_err / count_pred
227
- except ZeroDivisionError:
228
- self_pct_err = 0
229
- total_confidence = confidence * (1 - self_pct_err)
230
-
231
- if both_feet:
232
- count_msg = f'## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}'
233
- else:
234
- count_msg = f'## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}'
235
-
236
- if api_call:
237
- if count_only_api:
238
- return f'{count_pred:.2f} (conf: {total_confidence:.2f})'
239
- else:
240
- return np.array2string(periodLength, formatter={'float_kind':lambda x: '%.2f' % x}).replace('\n', ''), \
241
- np.array2string(periodicity, formatter={'float_kind':lambda x: '%.2f' % x}).replace('\n', ''), \
242
- np.array2string(full_marks, formatter={'float_kind':lambda x: '%.2f' % x}).replace('\n', ''), \
243
- f'reps: {count_pred:.2f}, marks: {marks_count_pred:.1f}, confidence: {total_confidence:.2f}', \
244
- f'single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}'
245
-
246
-
247
- jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.01), 0, 10)
248
- jumping_speed = np.copy(jumps_per_second)
249
- misses = periodicity < miss_threshold
250
- jumps_per_second[misses] = 0
251
- frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
252
- frame_type[full_marks > marks_threshold] = 'jump'
253
- per_frame_event_types = np.clip(per_frame_event_types, 0, 6) / 6
254
- df = pd.DataFrame.from_dict({'period length': periodLength,
255
- 'jumping speed': jumping_speed,
256
- 'jumps per second': jumps_per_second,
257
- 'periodicity': periodicity,
258
- 'miss': misses,
259
- 'frame_type': frame_type,
260
- 'event_type': per_frame_event_types,
261
- 'jumps': full_marks,
262
- 'jumps_size': (full_marks + 0.05) * 10,
263
- 'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 8),
264
- 'seconds': np.linspace(0, seconds, num=len(periodLength))})
265
- event_type_tick_vals = np.linspace(0, 1, num=7)
266
- event_type_colors = ['red', 'orange', 'green', 'blue', 'purple', 'pink', 'black']
267
- fig = px.scatter(data_frame=df,
268
- x='seconds',
269
- y='jumps per second',
270
- #symbol='frame_type',
271
- #symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
272
- color='event_type',
273
- size='jumps_size',
274
- size_max=8,
275
- color_continuous_scale=[(t, c) for t, c in zip(event_type_tick_vals, event_type_colors)],
276
- range_color=(0,1),
277
- title='Jumping speed (jumps-per-second)',
278
- trendline='rolling',
279
- trendline_options=dict(window=16),
280
- trendline_color_override='goldenrod',
281
- trendline_scope='overall',
282
- template='plotly_dark')
283
-
284
- fig.update_layout(legend=dict(
285
- orientation='h',
286
- yanchor='bottom',
287
- y=0.98,
288
- xanchor='right',
289
- x=1,
290
- font=dict(
291
- family='Courier',
292
- size=12,
293
- color='black'
294
- ),
295
- bgcolor='AliceBlue',
296
- ),
297
- paper_bgcolor='rgba(0,0,0,0)',
298
- plot_bgcolor='rgba(0,0,0,0)'
299
- )
300
- # remove white outline from marks
301
- fig.update_traces(marker_line_width = 0)
302
- fig.update_layout(coloraxis_colorbar=dict(
303
- tickvals=event_type_tick_vals,
304
- ticktext=['single<br>rope', 'double<br>dutch', 'double<br>unders', 'single<br>bounces', 'double<br>bounces', 'triple<br>unders', 'other'],
305
- title='event type'
306
- ))
307
-
308
- hist = px.histogram(df,
309
- x='jumps per second',
310
- template='plotly_dark',
311
- marginal='box',
312
- histnorm='percent',
313
- title='Distribution of jumping speed (jumps-per-second)')
314
-
315
- # make a bar plot of the event type distribution
316
-
317
- bar = px.bar(x=['single rope', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
318
- y=event_type_probs,
319
- template='plotly_dark',
320
- title='Event Type Distribution',
321
- labels={'x': 'event type', 'y': 'probability'},
322
- range_y=[0, 1])
323
-
324
- return x, count_msg, fig, hist, bar
325
-
326
-
327
- with gr.Blocks() as demo:
328
- # in_video = gr.PlayableVideo(label='Input Video', elem_id='input-video', format='mp4',
329
- # width=400, height=400, interactive=True, container=True,
330
- # max_length=150)
331
- with gr.Row():
332
- in_stream_url = gr.Textbox(label='Stream URL', elem_id='stream-url', visible=True)
333
- with gr.Column():
334
- in_stream_start = gr.Textbox(label='Start Time', elem_id='stream-start', visible=True)
335
- with gr.Column():
336
- in_stream_end = gr.Textbox(label='End Time', elem_id='stream-end', visible=True)
337
- with gr.Column(min_width=480):
338
- out_video = gr.PlayableVideo(label='Video Clip', elem_id='output-video', format='mp4', width=400, height=400)
339
-
340
- with gr.Row():
341
- run_button = gr.Button(value='Run', elem_id='run-button', scale=1)
342
- api_dummy_button = gr.Button(value='Run (No Viz)', elem_id='count-only', visible=False, scale=2)
343
- count_only = gr.Checkbox(label='Count Only', visible=False)
344
- api_token = gr.Textbox(label='API Key', elem_id='api-token', visible=False)
345
-
346
- with gr.Column(elem_id='output-video-container'):
347
- with gr.Row():
348
- with gr.Column():
349
- out_text = gr.Markdown(label='Predicted Count', elem_id='output-text')
350
- period_length = gr.Textbox(label='Period Length', elem_id='period-length', visible=False)
351
- periodicity = gr.Textbox(label='Periodicity', elem_id='periodicity', visible=False)
352
- with gr.Row():
353
- out_plot = gr.Plot(label='Jumping Speed', elem_id='output-plot')
354
- with gr.Row():
355
- with gr.Column():
356
- out_hist = gr.Plot(label='Speed Histogram', elem_id='output-hist')
357
- with gr.Column():
358
- out_event_type_dist = gr.Plot(label='Event Type Distribution', elem_id='output-event-type-dist')
359
-
360
-
361
- demo_inference = partial(inference, count_only_api=False, api_key=None)
362
-
363
- run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
364
- api_inference = partial(inference, api_call=True)
365
- api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
366
-
367
-
368
- if __name__ == '__main__':
 
 
 
 
 
369
  demo.queue(api_open=True, max_size=15).launch(share=False)
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import os
5
+ import cv2
6
+ import math
7
+ import matplotlib
8
+ matplotlib.use('Agg')
9
+ import matplotlib.pyplot as plt
10
+ from scipy.signal import medfilt, find_peaks
11
+ from functools import partial
12
+ from passlib.hash import pbkdf2_sha256
13
+ from tqdm import tqdm
14
+ import pandas as pd
15
+ import plotly.express as px
16
+ import onnxruntime as ort
17
+ import torch
18
+ from torchvision import transforms
19
+ import torchvision.transforms.functional as F
20
+
21
+ from huggingface_hub import hf_hub_download
22
+ from huggingface_hub import HfApi
23
+
24
+ from hls_download import download_clips
25
+
26
+ plt.style.use('dark_background')
27
+
28
+ onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
29
+ #onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_fp16.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
30
+ # model_xml = hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.xml', repo_type='model', token=os.environ['DATASET_SECRET'])
31
+ # hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.mapping', repo_type='model', token=os.environ['DATASET_SECRET'])
32
+ #model_xml = 'model_ir/model.xml'
33
+
34
+ # ie = Core()
35
+ # model_ir = ie.read_model(model=model_xml)
36
+ # config = {'PERFORMANCE_HINT': 'LATENCY'}
37
+ # compiled_model_ir = ie.compile_model(model=model_ir, device_name='CPU', config=config)
38
# check if GPU is available
if torch.cuda.is_available():
    # Run ONNX Runtime on the same CUDA device and stream as PyTorch so the
    # two runtimes don't serialize against each other on separate streams.
    providers = [('CUDAExecutionProvider', {'device_id': torch.cuda.current_device(),
                                            'user_compute_stream': str(torch.cuda.current_stream().cuda_stream)})]
    sess_options = ort.SessionOptions()
    # Enable all graph optimizations (constant folding, node fusion, ...).
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
else:
    # CPU fallback with default session options.
    ort_sess = ort.InferenceSession(onnx_file)

# Warm the session up once at startup with a dummy clip shaped like real
# input (batch=1, 64 frames, 3x288x288) so kernel selection / memory
# allocation cost is paid here instead of on the first user request.
print('Warmup...')
dummy_input = torch.randn(1, 64, 3, 288, 288)
ort_sess.run(None, {'video': dummy_input.numpy()})
print('Done!')
52
+
53
class SquarePad:
    """Pad a PIL image with black borders so it becomes exactly square.

    https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
    """
    def __call__(self, image):
        w, h = image.size
        max_wh = max(w, h)
        hp = (max_wh - w) // 2
        vp = (max_wh - h) // 2
        # Put any odd leftover pixel on the right/bottom so the result is
        # exactly max_wh x max_wh. The original used int(d / 2) on BOTH
        # sides, which dropped one pixel whenever the difference was odd and
        # produced a not-quite-square image.
        padding = (hp, vp, max_wh - w - hp, max_wh - h - vp)
        return F.pad(image, padding, 0, 'constant')
62
+
63
def sigmoid(x):
    """Numerically stable element-wise logistic function 1 / (1 + e^-x).

    Clipping the logits to [-500, 500] keeps np.exp from overflowing (and
    emitting RuntimeWarnings) for large-magnitude negative inputs; e^500 is
    still finite in float64, so results are unchanged to double precision.
    """
    return 1 / (1 + np.exp(-np.clip(x, -500, 500)))
65
+
66
+
67
def create_transform(img_size):
    """Build the per-frame preprocessing pipeline.

    Square-pads the frame, resizes it to img_size x img_size with bicubic
    interpolation, and converts it to a float tensor in [0, 1].
    """
    steps = [
        SquarePad(),
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)
73
+
74
+
75
def _accumulate_batch(batch_list, idx_list, seq_len, acc):
    """Run one ONNX batch and add per-frame predictions into the accumulators.

    `acc` is the tuple (period_lengths, periodicities, full_marks,
    event_type_logits, period_length_overlaps, event_type_logit_overlaps);
    all six arrays are updated in place so overlapping windows can later be
    averaged by dividing by the overlap counts.
    """
    (period_lengths, periodicities, full_marks,
     event_type_logits, period_length_overlaps, event_type_logit_overlaps) = acc
    batch_X = torch.cat(batch_list)
    outputs = ort_sess.run(None, {'video': batch_X.numpy()})
    y1pred, y2pred, y3pred, y4pred = outputs[0], outputs[1], outputs[2], outputs[3]
    for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
        period_lengths[idx:idx + seq_len] += y1.squeeze()
        periodicities[idx:idx + seq_len] += y2.squeeze()
        full_marks[idx:idx + seq_len] += y3.squeeze()
        event_type_logits[idx:idx + seq_len] += y4.squeeze()
        period_length_overlaps[idx:idx + seq_len] += 1
        event_type_logit_overlaps[idx:idx + seq_len] += 1


def inference(stream_url, start_time, end_time, count_only_api, api_key,
              img_size=288, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
              miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
              api_call=False,
              progress=gr.Progress()):
    """Download a clip of [start_time, end_time] from `stream_url`, run the
    jump-counting ONNX model over it in overlapping seq_len-frame windows,
    and return the clip path, a count message, and three Plotly figures.

    For api_call=True the key is checked and text predictions are returned
    instead of figures (or just the count when count_only_api is set).
    Note: `center_crop` is currently unused — presumably reserved; confirm
    before removing.
    """
    progress(0, desc='Starting...')
    x = download_clips(stream_url, os.getcwd(), start_time, end_time)

    #api = HfApi(token=os.environ['DATASET_SECRET'])
    #out_file = str(uuid.uuid1())
    has_access = False
    if api_call:
        has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
        if not has_access:
            return 'Invalid API Key'

    # Decode every frame of the clip into RGB PIL images.
    cap = cv2.VideoCapture(x)
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    seconds = length / fps
    all_frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if ret is False:
            break  # padding with the last frame happens below
        frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
        all_frames.append(Image.fromarray(frame))
    cap.release()

    # Allocate accumulators slightly longer than the clip so the final
    # (padded) windows can write without bounds checks; trimmed to `length`
    # after averaging.
    length = len(all_frames)
    padded_len = length + seq_len + stride_length
    period_lengths = np.zeros(padded_len)
    periodicities = np.zeros(padded_len)
    full_marks = np.zeros(padded_len)
    event_type_logits = np.zeros((padded_len, 7))
    period_length_overlaps = np.zeros(padded_len)
    event_type_logit_overlaps = np.zeros((padded_len, 7))
    acc = (period_lengths, periodicities, full_marks,
           event_type_logits, period_length_overlaps, event_type_logit_overlaps)
    for _ in range(seq_len + stride_length):  # pad full sequence with the last frame
        all_frames.append(all_frames[-1])

    batch_list = []
    idx_list = []
    preprocess = create_transform(img_size)
    total_windows = length + stride_length - stride_pad
    for i in tqdm(range(0, total_windows, stride_length)):
        batch = all_frames[i:i + seq_len]
        Xlist = [preprocess(img).unsqueeze(0) for img in batch]
        if len(Xlist) < seq_len:  # short tail window: repeat the last frame
            Xlist.extend(Xlist[-1] for _ in range(seq_len - len(Xlist)))
        X = torch.cat(Xlist)
        X *= 255  # model expects 0-255 pixels; ToTensor produced 0-1
        batch_list.append(X.unsqueeze(0))
        idx_list.append(i)
        if len(batch_list) == batch_size:
            _accumulate_batch(batch_list, idx_list, seq_len, acc)
            batch_list = []
            idx_list = []
        progress(i / total_windows, desc='Processing...')
    if len(batch_list) != 0:  # still some leftover windows: pad to a full batch
        while len(batch_list) != batch_size:
            batch_list.append(batch_list[-1])
            idx_list.append(idx_list[-1])
        _accumulate_batch(batch_list, idx_list, seq_len, acc)

    # Average overlapping window predictions and trim the padding.
    periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps != 0)[:length]
    periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps != 0)[:length]
    full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps != 0)[:length]
    per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps != 0)[:length]
    event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
    # softmax of event type logits
    event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
    per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)

    if median_pred_filter:
        periodicity = medfilt(periodicity, 5)
        periodLength = medfilt(periodLength, 5)
    periodicity = sigmoid(periodicity)
    full_marks = sigmoid(full_marks)
    #full_marks_mask = np.int32(full_marks > marks_threshold)
    pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
    full_marks_mask = np.zeros(len(full_marks))
    full_marks_mask[pred_marks_peaks] = 1
    periodicity_mask = np.int32(periodicity > miss_threshold)

    # Integrate 1/period_length over confidently-periodic frames, snapping to
    # whole reps whenever a high-confidence mark (jump peak) is detected.
    numofReps = 0
    count = []
    for i in range(len(periodLength)):
        if periodLength[i] < 2 or periodicity_mask[i] == 0:
            numofReps += 0
        elif full_marks_mask[i]:  # high confidence mark detected
            if math.modf(numofReps)[0] < 0.2:  # probably false positive/late detection
                numofReps = float(int(numofReps))
            else:
                numofReps = float(int(numofReps) + 1.01)  # round up
        else:
            numofReps += max(0, periodicity_mask[i] / (periodLength[i]))
        count.append(round(float(numofReps), 2))
    # BUG FIX: guard against an empty clip instead of crashing on count[-1].
    count_pred = count[-1] if count else 0.0

    marks_count_pred = 0
    for i in range(len(full_marks) - 1):
        # if a jump was counted, and periodicity is high, and the next frame was not counted (to avoid double counting)
        if full_marks_mask[i] > 0 and periodicity_mask[i] > 0 and full_marks_mask[i + 1] == 0:
            marks_count_pred += 1
    if not both_feet:
        count_pred = count_pred / 2
        marks_count_pred = marks_count_pred / 2
        count = np.array(count) / 2

    # BUG FIX: np.mean of an empty selection returns nan (it does NOT raise
    # ZeroDivisionError as the old try/except assumed), which made the
    # confidence silently NaN; test the selection explicitly instead.
    periodic_vals = periodicity[periodicity > miss_threshold]
    if periodic_vals.size > 0:
        confidence = (np.mean(periodic_vals) - miss_threshold) / (1 - miss_threshold)
    else:
        confidence = 0
    # Self-consistency: disagreement between the two counting methods lowers
    # the reported confidence.
    self_err = abs(count_pred - marks_count_pred)
    self_pct_err = self_err / count_pred if count_pred else 0
    total_confidence = confidence * (1 - self_pct_err)

    if both_feet:
        count_msg = f'## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}'
    else:
        count_msg = f'## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}'

    if api_call:
        if count_only_api:
            return f'{count_pred:.2f} (conf: {total_confidence:.2f})'
        else:
            return np.array2string(periodLength, formatter={'float_kind': lambda v: '%.2f' % v}).replace('\n', ''), \
                   np.array2string(periodicity, formatter={'float_kind': lambda v: '%.2f' % v}).replace('\n', ''), \
                   np.array2string(full_marks, formatter={'float_kind': lambda v: '%.2f' % v}).replace('\n', ''), \
                   f'reps: {count_pred:.2f}, marks: {marks_count_pred:.1f}, confidence: {total_confidence:.2f}', \
                   f'single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}'

    # Jumping-speed series: clamp to [0, 10] jumps/s, zero out miss frames.
    jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.01), 0, 10)
    jumping_speed = np.copy(jumps_per_second)
    misses = periodicity < miss_threshold
    jumps_per_second[misses] = 0
    frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
    frame_type[full_marks > marks_threshold] = 'jump'
    per_frame_event_types = np.clip(per_frame_event_types, 0, 6) / 6  # class index normalized to [0, 1] for the colorbar
    df = pd.DataFrame.from_dict({'period length': periodLength,
                                 'jumping speed': jumping_speed,
                                 'jumps per second': jumps_per_second,
                                 'periodicity': periodicity,
                                 'miss': misses,
                                 'frame_type': frame_type,
                                 'event_type': per_frame_event_types,
                                 'jumps': full_marks,
                                 'jumps_size': (full_marks + 0.05) * 10,
                                 'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 8),
                                 'seconds': np.linspace(0, seconds, num=len(periodLength))})
    event_type_tick_vals = np.linspace(0, 1, num=7)
    event_type_colors = ['red', 'orange', 'green', 'blue', 'purple', 'pink', 'black']
    fig = px.scatter(data_frame=df,
                     x='seconds',
                     y='jumps per second',
                     #symbol='frame_type',
                     #symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
                     color='event_type',
                     size='jumps_size',
                     size_max=8,
                     color_continuous_scale=[(t, c) for t, c in zip(event_type_tick_vals, event_type_colors)],
                     range_color=(0, 1),
                     title='Jumping speed (jumps-per-second)',
                     trendline='rolling',
                     trendline_options=dict(window=16),
                     trendline_color_override='goldenrod',
                     trendline_scope='overall',
                     template='plotly_dark')

    fig.update_layout(legend=dict(
        orientation='h',
        yanchor='bottom',
        y=0.98,
        xanchor='right',
        x=1,
        font=dict(
            family='Courier',
            size=12,
            color='black'
        ),
        bgcolor='AliceBlue',
    ),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)'
    )
    # remove white outline from marks
    fig.update_traces(marker_line_width=0)
    fig.update_layout(coloraxis_colorbar=dict(
        tickvals=event_type_tick_vals,
        ticktext=['single<br>rope', 'double<br>dutch', 'double<br>unders', 'single<br>bounces', 'double<br>bounces', 'triple<br>unders', 'other'],
        title='event type'
    ))

    hist = px.histogram(df,
                        x='jumps per second',
                        template='plotly_dark',
                        marginal='box',
                        histnorm='percent',
                        title='Distribution of jumping speed (jumps-per-second)')

    # make a bar plot of the event type distribution
    bar = px.bar(x=['single rope', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
                 y=event_type_probs,
                 template='plotly_dark',
                 title='Event Type Distribution',
                 labels={'x': 'event type', 'y': 'probability'},
                 range_y=[0, 1])

    return x, count_msg, fig, hist, bar
330
+
331
+
332
with gr.Blocks() as demo:
    # (disabled) direct video-upload input — the app clips from a stream URL instead
    # in_video = gr.PlayableVideo(label='Input Video', elem_id='input-video', format='mp4',
    #                             width=400, height=400, interactive=True, container=True,
    #                             max_length=150)

    # Input row: stream URL plus start/end timestamps, with the downloaded
    # clip shown alongside once inference has run.
    with gr.Row():
        in_stream_url = gr.Textbox(label='Stream URL', elem_id='stream-url', visible=True)
        with gr.Column():
            in_stream_start = gr.Textbox(label='Start Time', elem_id='stream-start', visible=True)
        with gr.Column():
            in_stream_end = gr.Textbox(label='End Time', elem_id='stream-end', visible=True)
        with gr.Column(min_width=480):
            out_video = gr.PlayableVideo(label='Video Clip', elem_id='output-video', format='mp4', width=400, height=400)

    # The second button and the key/checkbox are hidden: they exist only to
    # expose the named 'inference' API endpoint, not for interactive use.
    with gr.Row():
        run_button = gr.Button(value='Run', elem_id='run-button', scale=1)
        api_dummy_button = gr.Button(value='Run (No Viz)', elem_id='count-only', visible=False, scale=2)
        count_only = gr.Checkbox(label='Count Only', visible=False)
        api_token = gr.Textbox(label='API Key', elem_id='api-token', visible=False)

    # Output area: count message, two hidden textboxes for API text output,
    # and the three Plotly figures returned by inference().
    with gr.Column(elem_id='output-video-container'):
        with gr.Row():
            with gr.Column():
                out_text = gr.Markdown(label='Predicted Count', elem_id='output-text')
                period_length = gr.Textbox(label='Period Length', elem_id='period-length', visible=False)
                periodicity = gr.Textbox(label='Periodicity', elem_id='periodicity', visible=False)
        with gr.Row():
            out_plot = gr.Plot(label='Jumping Speed', elem_id='output-plot')
        with gr.Row():
            with gr.Column():
                out_hist = gr.Plot(label='Speed Histogram', elem_id='output-hist')
            with gr.Column():
                out_event_type_dist = gr.Plot(label='Event Type Distribution', elem_id='output-event-type-dist')

    # Interactive entry point: no API-key arguments, full visualization.
    demo_inference = partial(inference, count_only_api=False, api_key=None)

    run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
    # API entry point (key-checked inside inference via api_call=True).
    api_inference = partial(inference, api_call=True)
    api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
371
+
372
+
373
if __name__ == '__main__':
    # Queue requests (max 15 waiting) and keep the API open so the hidden
    # 'inference' endpoint is callable programmatically.
    demo.queue(api_open=True, max_size=15).launch(share=False)
requirements.txt CHANGED
@@ -10,5 +10,4 @@ opencv-python-headless==4.7.0.68
10
  torch
11
  torchvision
12
  onnxruntime-gpu
13
- yt-dlp
14
- nvidia-tensorrt
 
10
  torch
11
  torchvision
12
  onnxruntime-gpu
13
+ yt-dlp