dylanplummer committed on
Commit
fd5ccd5
·
verified ·
1 Parent(s): c22a4b8

update for mobilenetv4

Browse files
Files changed (1) hide show
  1. app.py +381 -372
app.py CHANGED
@@ -1,373 +1,382 @@
1
- import gradio as gr
2
- import numpy as np
3
- from PIL import Image
4
- import os
5
- import cv2
6
- import math
7
- import matplotlib
8
- matplotlib.use('Agg')
9
- import matplotlib.pyplot as plt
10
- from scipy.signal import medfilt, find_peaks
11
- from functools import partial
12
- from passlib.hash import pbkdf2_sha256
13
- from tqdm import tqdm
14
- import pandas as pd
15
- import plotly.express as px
16
- import onnxruntime as ort
17
- import torch
18
- from torchvision import transforms
19
- import torchvision.transforms.functional as F
20
-
21
- from huggingface_hub import hf_hub_download
22
- from huggingface_hub import HfApi
23
-
24
- from hls_download import download_clips
25
-
26
- plt.style.use('dark_background')
27
-
28
- onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_mobilenetv3.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
29
- #onnx_file = hf_hub_download(repo_id='dylanplummer/ropenet', filename='nextjump_fp16.onnx', repo_type='model', token=os.environ['DATASET_SECRET'])
30
- # model_xml = hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.xml', repo_type='model', token=os.environ['DATASET_SECRET'])
31
- # hf_hub_download(repo_id='dylanplummer/ropenet', filename='model.mapping', repo_type='model', token=os.environ['DATASET_SECRET'])
32
- #model_xml = 'model_ir/model.xml'
33
-
34
- # ie = Core()
35
- # model_ir = ie.read_model(model=model_xml)
36
- # config = {'PERFORMANCE_HINT': 'LATENCY'}
37
- # compiled_model_ir = ie.compile_model(model=model_ir, device_name='CPU', config=config)
38
- # check if GPU is available
39
- if torch.cuda.is_available():
40
- providers = [('CUDAExecutionProvider', {'device_id': torch.cuda.current_device(),
41
- 'user_compute_stream': str(torch.cuda.current_stream().cuda_stream)})]
42
- sess_options = ort.SessionOptions()
43
- sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
44
- ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
45
- else:
46
- ort_sess = ort.InferenceSession(onnx_file)
47
-
48
- print('Warmup...')
49
- dummy_input = torch.randn(4, 64, 3, 224, 224)
50
- ort_sess.run(None, {'video': dummy_input.numpy()})
51
- print('Done!')
52
-
53
- class SquarePad:
54
- # https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
55
- def __call__(self, image):
56
- w, h = image.size
57
- max_wh = max(w, h)
58
- hp = int((max_wh - w) / 2)
59
- vp = int((max_wh - h) / 2)
60
- padding = (hp, vp, hp, vp)
61
- return F.pad(image, padding, 0, 'constant')
62
-
63
- def sigmoid(x):
64
- return 1 / (1 + np.exp(-x))
65
-
66
-
67
- def create_transform(img_size):
68
- return transforms.Compose([
69
- SquarePad(),
70
- transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),
71
- transforms.ToTensor(),
72
- ])
73
-
74
-
75
- def inference(stream_url, start_time, end_time, count_only_api, api_key,
76
- img_size=224, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
77
- miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
78
- api_call=False,
79
- progress=gr.Progress()):
80
- progress(0, desc='Starting...')
81
- x = download_clips(stream_url, os.getcwd(), start_time, end_time)
82
-
83
- #api = HfApi(token=os.environ['DATASET_SECRET'])
84
- #out_file = str(uuid.uuid1())
85
- has_access = False
86
- if api_call:
87
- has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
88
- if not has_access:
89
- return 'Invalid API Key'
90
-
91
- cap = cv2.VideoCapture(x)
92
- length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
93
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
94
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
95
- period_length_overlaps = np.zeros(length + seq_len)
96
- fps = int(cap.get(cv2.CAP_PROP_FPS))
97
- seconds = length / fps
98
- all_frames = []
99
- frame_i = 1
100
- print('Reading frames...')
101
- while cap.isOpened():
102
- ret, frame = cap.read()
103
- if ret is False:
104
- frame = all_frames[-1] # padding will be with last frame
105
- break
106
- frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
107
- img = Image.fromarray(frame)
108
- all_frames.append(img)
109
- frame_i += 1
110
- cap.release()
111
- print('Done!')
112
-
113
- length = len(all_frames)
114
- period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
115
- periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
116
- full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
117
- event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 7))
118
- period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
119
- event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
120
- for _ in range(seq_len + stride_length): # pad full sequence
121
- all_frames.append(all_frames[-1])
122
- batch_list = []
123
- idx_list = []
124
- preprocess = create_transform(img_size)
125
- for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
126
- batch = all_frames[i:i + seq_len]
127
- Xlist = []
128
- for img in batch:
129
- frameTensor = preprocess(img).unsqueeze(0)
130
- Xlist.append(frameTensor)
131
-
132
- if len(Xlist) < seq_len:
133
- for _ in range(seq_len - len(Xlist)):
134
- Xlist.append(Xlist[-1])
135
-
136
- X = torch.cat(Xlist)
137
- X *= 255
138
- batch_list.append(X.unsqueeze(0))
139
- idx_list.append(i)
140
- if len(batch_list) == batch_size:
141
- batch_X = torch.cat(batch_list)
142
- outputs = ort_sess.run(None, {'video': batch_X.numpy()})
143
- y1pred = outputs[0]
144
- y2pred = outputs[1]
145
- y3pred = outputs[2]
146
- y4pred = outputs[3]
147
- for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
148
- periodLength = y1.squeeze()
149
- periodicity = y2.squeeze()
150
- marks = y3.squeeze()
151
- event_type = y4.squeeze()
152
- period_lengths[idx:idx+seq_len] += periodLength
153
- periodicities[idx:idx+seq_len] += periodicity
154
- full_marks[idx:idx+seq_len] += marks
155
- event_type_logits[idx:idx+seq_len] += event_type
156
- period_length_overlaps[idx:idx+seq_len] += 1
157
- event_type_logit_overlaps[idx:idx+seq_len] += 1
158
- batch_list = []
159
- idx_list = []
160
- progress(i / (length + stride_length - stride_pad), desc='Processing...')
161
- if len(batch_list) != 0: # still some leftover frames
162
- while len(batch_list) != batch_size:
163
- batch_list.append(batch_list[-1])
164
- idx_list.append(idx_list[-1])
165
- batch_X = torch.cat(batch_list)
166
- outputs = ort_sess.run(None, {'video': batch_X.numpy()})
167
- y1pred = outputs[0]
168
- y2pred = outputs[1]
169
- y3pred = outputs[2]
170
- y4pred = outputs[3]
171
- for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
172
- periodLength = y1.squeeze()
173
- periodicity = y2.squeeze()
174
- marks = y3.squeeze()
175
- event_type = y4.squeeze()
176
- period_lengths[idx:idx+seq_len] += periodLength
177
- periodicities[idx:idx+seq_len] += periodicity
178
- full_marks[idx:idx+seq_len] += marks
179
- event_type_logits[idx:idx+seq_len] += event_type
180
- period_length_overlaps[idx:idx+seq_len] += 1
181
- event_type_logit_overlaps[idx:idx+seq_len] += 1
182
-
183
- periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
184
- periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
185
- full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
186
- per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
187
- event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
188
- # softmax of event type logits
189
- event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
190
- per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)
191
-
192
- if median_pred_filter:
193
- periodicity = medfilt(periodicity, 5)
194
- periodLength = medfilt(periodLength, 5)
195
- periodicity = sigmoid(periodicity)
196
- full_marks = sigmoid(full_marks)
197
- #full_marks_mask = np.int32(full_marks > marks_threshold)
198
- pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
199
- full_marks_mask = np.zeros(len(full_marks))
200
- full_marks_mask[pred_marks_peaks] = 1
201
- periodicity_mask = np.int32(periodicity > miss_threshold)
202
- numofReps = 0
203
- count = []
204
- for i in range(len(periodLength)):
205
- if periodLength[i] < 2 or periodicity_mask[i] == 0:
206
- numofReps += 0
207
- elif full_marks_mask[i]: # high confidence mark detected
208
- if math.modf(numofReps)[0] < 0.2: # probably false positive/late detection
209
- numofReps = float(int(numofReps))
210
- else:
211
- numofReps = float(int(numofReps) + 1.01) # round up
212
- else:
213
- numofReps += max(0, periodicity_mask[i]/(periodLength[i]))
214
- count.append(round(float(numofReps), 2))
215
- count_pred = count[-1]
216
- marks_count_pred = 0
217
- for i in range(len(full_marks) - 1):
218
- # if a jump was counted, and periodicity is high, and the next frame was not counted (to avoid double counting)
219
- if full_marks_mask[i] > 0 and periodicity_mask[i] > 0 and full_marks_mask[i + 1] == 0:
220
- marks_count_pred += 1
221
- if not both_feet:
222
- count_pred = count_pred / 2
223
- marks_count_pred = marks_count_pred / 2
224
- count = np.array(count) / 2
225
- try:
226
- confidence = (np.mean(periodicity[periodicity > miss_threshold]) - miss_threshold) / (1 - miss_threshold)
227
- except ZeroDivisionError:
228
- confidence = 0
229
- self_err = abs(count_pred - marks_count_pred)
230
- try:
231
- self_pct_err = self_err / count_pred
232
- except ZeroDivisionError:
233
- self_pct_err = 0
234
- total_confidence = confidence * (1 - self_pct_err)
235
-
236
- if both_feet:
237
- count_msg = f'## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}'
238
- else:
239
- count_msg = f'## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}'
240
-
241
- if api_call:
242
- if count_only_api:
243
- return f'{count_pred:.2f} (conf: {total_confidence:.2f})'
244
- else:
245
- return np.array2string(periodLength, formatter={'float_kind':lambda x: '%.2f' % x}).replace('\n', ''), \
246
- np.array2string(periodicity, formatter={'float_kind':lambda x: '%.2f' % x}).replace('\n', ''), \
247
- np.array2string(full_marks, formatter={'float_kind':lambda x: '%.2f' % x}).replace('\n', ''), \
248
- f'reps: {count_pred:.2f}, marks: {marks_count_pred:.1f}, confidence: {total_confidence:.2f}', \
249
- f'single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}'
250
-
251
-
252
- jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.01), 0, 10)
253
- jumping_speed = np.copy(jumps_per_second)
254
- misses = periodicity < miss_threshold
255
- jumps_per_second[misses] = 0
256
- frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
257
- frame_type[full_marks > marks_threshold] = 'jump'
258
- per_frame_event_types = np.clip(per_frame_event_types, 0, 6) / 6
259
- df = pd.DataFrame.from_dict({'period length': periodLength,
260
- 'jumping speed': jumping_speed,
261
- 'jumps per second': jumps_per_second,
262
- 'periodicity': periodicity,
263
- 'miss': misses,
264
- 'frame_type': frame_type,
265
- 'event_type': per_frame_event_types,
266
- 'jumps': full_marks,
267
- 'jumps_size': (full_marks + 0.05) * 10,
268
- 'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 8),
269
- 'seconds': np.linspace(0, seconds, num=len(periodLength))})
270
- event_type_tick_vals = np.linspace(0, 1, num=7)
271
- event_type_colors = ['red', 'orange', 'green', 'blue', 'purple', 'pink', 'black']
272
- fig = px.scatter(data_frame=df,
273
- x='seconds',
274
- y='jumps per second',
275
- #symbol='frame_type',
276
- #symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
277
- color='event_type',
278
- size='jumps_size',
279
- size_max=8,
280
- color_continuous_scale=[(t, c) for t, c in zip(event_type_tick_vals, event_type_colors)],
281
- range_color=(0,1),
282
- title='Jumping speed (jumps-per-second)',
283
- trendline='rolling',
284
- trendline_options=dict(window=16),
285
- trendline_color_override='goldenrod',
286
- trendline_scope='overall',
287
- template='plotly_dark')
288
-
289
- fig.update_layout(legend=dict(
290
- orientation='h',
291
- yanchor='bottom',
292
- y=0.98,
293
- xanchor='right',
294
- x=1,
295
- font=dict(
296
- family='Courier',
297
- size=12,
298
- color='black'
299
- ),
300
- bgcolor='AliceBlue',
301
- ),
302
- paper_bgcolor='rgba(0,0,0,0)',
303
- plot_bgcolor='rgba(0,0,0,0)'
304
- )
305
- # remove white outline from marks
306
- fig.update_traces(marker_line_width = 0)
307
- fig.update_layout(coloraxis_colorbar=dict(
308
- tickvals=event_type_tick_vals,
309
- ticktext=['single<br>rope', 'double<br>dutch', 'double<br>unders', 'single<br>bounces', 'double<br>bounces', 'triple<br>unders', 'other'],
310
- title='event type'
311
- ))
312
-
313
- hist = px.histogram(df,
314
- x='jumps per second',
315
- template='plotly_dark',
316
- marginal='box',
317
- histnorm='percent',
318
- title='Distribution of jumping speed (jumps-per-second)')
319
-
320
- # make a bar plot of the event type distribution
321
-
322
- bar = px.bar(x=['single rope', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
323
- y=event_type_probs,
324
- template='plotly_dark',
325
- title='Event Type Distribution',
326
- labels={'x': 'event type', 'y': 'probability'},
327
- range_y=[0, 1])
328
-
329
- return x, count_msg, fig, hist, bar
330
-
331
-
332
- if __name__ == '__main__':
333
- with gr.Blocks() as demo:
334
- # in_video = gr.PlayableVideo(label='Input Video', elem_id='input-video', format='mp4',
335
- # width=400, height=400, interactive=True, container=True,
336
- # max_length=150)
337
- with gr.Row():
338
- in_stream_url = gr.Textbox(label='Stream URL', elem_id='stream-url', visible=True)
339
- with gr.Column():
340
- in_stream_start = gr.Textbox(label='Start Time', elem_id='stream-start', visible=True)
341
- with gr.Column():
342
- in_stream_end = gr.Textbox(label='End Time', elem_id='stream-end', visible=True)
343
- with gr.Column(min_width=480):
344
- out_video = gr.PlayableVideo(label='Video Clip', elem_id='output-video', format='mp4', width=400, height=400)
345
-
346
- with gr.Row():
347
- run_button = gr.Button(value='Run', elem_id='run-button', scale=1)
348
- api_dummy_button = gr.Button(value='Run (No Viz)', elem_id='count-only', visible=False, scale=2)
349
- count_only = gr.Checkbox(label='Count Only', visible=False)
350
- api_token = gr.Textbox(label='API Key', elem_id='api-token', visible=False)
351
-
352
- with gr.Column(elem_id='output-video-container'):
353
- with gr.Row():
354
- with gr.Column():
355
- out_text = gr.Markdown(label='Predicted Count', elem_id='output-text')
356
- period_length = gr.Textbox(label='Period Length', elem_id='period-length', visible=False)
357
- periodicity = gr.Textbox(label='Periodicity', elem_id='periodicity', visible=False)
358
- with gr.Row():
359
- out_plot = gr.Plot(label='Jumping Speed', elem_id='output-plot')
360
- with gr.Row():
361
- with gr.Column():
362
- out_hist = gr.Plot(label='Speed Histogram', elem_id='output-hist')
363
- with gr.Column():
364
- out_event_type_dist = gr.Plot(label='Event Type Distribution', elem_id='output-event-type-dist')
365
-
366
-
367
- demo_inference = partial(inference, count_only_api=False, api_key=None)
368
-
369
- run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
370
- api_inference = partial(inference, api_call=True)
371
- api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
372
-
 
 
 
 
 
 
 
 
 
373
  demo.queue(api_open=True, max_size=15).launch(share=False)
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import os
5
+ import cv2
6
+ import math
7
+ import spaces
8
+ import matplotlib
9
+ matplotlib.use('Agg')
10
+ import matplotlib.pyplot as plt
11
+ from scipy.signal import medfilt, find_peaks
12
+ from functools import partial
13
+ from passlib.hash import pbkdf2_sha256
14
+ from tqdm import tqdm
15
+ import pandas as pd
16
+ import plotly.express as px
17
+ import onnxruntime as ort
18
+ import torch
19
+ from torchvision import transforms
20
+ import torchvision.transforms.functional as F
21
+
22
+ from huggingface_hub import hf_hub_download
23
+ from huggingface_hub import HfApi
24
+
25
+ from hls_download import download_clips
26
+
27
# Dark plot style to match the dark Gradio theme.
plt.style.use('dark_background')

# Download the exported ONNX model from the private HF model repo.
# NOTE(review): requires the DATASET_SECRET env var; raises KeyError if unset — confirm the Space sets it.
onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
# Leftovers from an earlier OpenVINO-based deployment, kept for reference:
# model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
# hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
#model_xml = "model_ir/model.xml"

# ie = Core()
# model_ir = ie.read_model(model=model_xml)
# config = {"PERFORMANCE_HINT": "LATENCY"}
# compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU", config=config)
38
+
39
+
40
class SquarePad:
    """Transform that zero-pads a PIL image symmetrically into a square.

    Adapted from:
    https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
    """

    def __call__(self, image):
        # PIL gives (width, height); pad the shorter side up to the longer one.
        width, height = image.size
        target = max(width, height)
        pad_x = int((target - width) / 2)
        pad_y = int((target - height) / 2)
        # (left, top, right, bottom) constant zero padding.
        return F.pad(image, (pad_x, pad_y, pad_x, pad_y), 0, 'constant')
49
+
50
def sigmoid(x):
    """Logistic function mapping logits to (0, 1); works elementwise on numpy arrays."""
    z = np.exp(-x)
    return 1 / (1 + z)
52
+
53
+
54
@spaces.GPU()
def inference(stream_url, start_time, end_time, count_only_api, api_key,
              img_size=224, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
              miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
              api_call=False,
              progress=gr.Progress()):
    """Download a clip from an HLS stream, run the jump-counting ONNX model, and build result plots.

    Args:
        stream_url: HLS stream URL handed to ``download_clips``.
        start_time / end_time: clip boundaries (format defined by ``download_clips``).
        count_only_api: API mode only — return just the count/confidence string.
        api_key: raw token verified against the pbkdf2 hash in DEV_API_TOKEN when ``api_call``.
        img_size: square side length frames are padded/resized to.
        seq_len: frames per model input window.
        stride_length: hop between windows (windows overlap by ``seq_len - stride_length``).
        stride_pad: trimmed from the window range end to skip a mostly-padded final window.
        batch_size: windows per ONNX session run.
        miss_threshold: post-sigmoid periodicity below this marks a frame as a miss.
        marks_threshold: minimum peak height for explicit jump-mark detections.
        median_pred_filter: apply a width-5 median filter to raw predictions.
        center_crop: currently unused — the center-crop preprocessing path is commented out;
            kept for interface compatibility.
        both_feet: if False, halve the counts (single-foot events).
        api_call: True when invoked through the hidden API button; gates the key check.
        progress: gradio progress callback.

    Returns:
        UI mode: (clip path, markdown count message, speed figure, histogram, event-type bar).
        API mode: a short count string, or a tuple of stringified arrays + summary strings,
        or "Invalid API Key" on auth failure.
    """
    progress(0, desc="Starting...")
    x = download_clips(stream_url, os.getcwd(), start_time, end_time)
    # Build the ONNX session inside the function so it runs on the ZeroGPU-allocated device.
    if torch.cuda.is_available():
        providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
                                                "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
        sess_options = ort.SessionOptions()
        ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
    else:
        ort_sess = ort.InferenceSession(onnx_file)
    #api = HfApi(token=os.environ['DATASET_SECRET'])
    #out_file = str(uuid.uuid1())
    has_access = False
    if api_call:
        # API callers must present the raw token matching the stored pbkdf2 hash.
        has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
        if not has_access:
            return "Invalid API Key"

    # Decode every frame of the downloaded clip into PIL RGB images.
    cap = cv2.VideoCapture(x)
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # width/height are only needed by the commented-out center-crop path below.
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    seconds = length / fps
    all_frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # End of stream; trailing windows are padded with the last frame below.
            break
        frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
        all_frames.append(Image.fromarray(frame))
    cap.release()

    # Accumulators sized to cover the padded tail of the final window.
    length = len(all_frames)
    total = length + seq_len + stride_length
    period_lengths = np.zeros(total)
    periodicities = np.zeros(total)
    full_marks = np.zeros(total)
    event_type_logits = np.zeros((total, 7))
    period_length_overlaps = np.zeros(total)
    event_type_logit_overlaps = np.zeros((total, 7))
    for _ in range(seq_len + stride_length):  # pad full sequence with the last frame
        all_frames.append(all_frames[-1])

    # Build the preprocessing pipeline ONCE — the previous code recreated
    # transforms.Compose for every single frame inside the hot loop.
    # if center_crop:
    #     resize to the short side == img_size, then transforms.CenterCrop((img_size, img_size))
    preprocess = transforms.Compose([
        SquarePad(),
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),
        transforms.ToTensor(),
        #transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    def _accumulate(batch_X, idxs):
        # Run one batch of windows and add each window's per-frame outputs into the
        # accumulators; overlap counters let us average overlapping windows later.
        outputs = ort_sess.run(None, {'video': batch_X.numpy()})
        for y1, y2, y3, y4, idx in zip(outputs[0], outputs[1], outputs[2], outputs[3], idxs):
            period_lengths[idx:idx+seq_len] += y1.squeeze()
            periodicities[idx:idx+seq_len] += y2.squeeze()
            full_marks[idx:idx+seq_len] += y3.squeeze()
            event_type_logits[idx:idx+seq_len] += y4.squeeze()
            period_length_overlaps[idx:idx+seq_len] += 1
            event_type_logit_overlaps[idx:idx+seq_len] += 1

    batch_list = []
    idx_list = []
    for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
        batch = all_frames[i:i + seq_len]
        print('Preprocessing...')
        Xlist = [preprocess(img).unsqueeze(0) for img in batch]
        if len(Xlist) < seq_len:  # pad a short final window with its last frame
            for _ in range(seq_len - len(Xlist)):
                Xlist.append(Xlist[-1])
        X = torch.cat(Xlist)
        X *= 255  # model expects 0-255 pixel values; ToTensor scaled to 0-1
        batch_list.append(X.unsqueeze(0))
        idx_list.append(i)
        print('Running inference...')
        if len(batch_list) == batch_size:
            _accumulate(torch.cat(batch_list), idx_list)
            batch_list = []
            idx_list = []
        progress(i / (length + stride_length - stride_pad), desc="Processing...")
    if len(batch_list) != 0:  # still some leftover windows: duplicate last to fill the batch
        while len(batch_list) != batch_size:
            batch_list.append(batch_list[-1])
            idx_list.append(idx_list[-1])
        _accumulate(torch.cat(batch_list), idx_list)

    # Average overlapping window predictions; frames no window touched stay 0.
    periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
    periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
    full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
    per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
    event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
    # softmax of event type logits
    event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
    per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)

    if median_pred_filter:
        periodicity = medfilt(periodicity, 5)
        periodLength = medfilt(periodLength, 5)
    periodicity = sigmoid(periodicity)
    full_marks = sigmoid(full_marks)
    #full_marks_mask = np.int32(full_marks > marks_threshold)
    # Explicit jump marks: local peaks of the mark signal above the threshold.
    pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
    full_marks_mask = np.zeros(len(full_marks))
    full_marks_mask[pred_marks_peaks] = 1
    periodicity_mask = np.int32(periodicity > miss_threshold)

    # Integrate 1/period over periodic frames to get a fractional rep count,
    # snapping to integers whenever a high-confidence mark fires.
    numofReps = 0
    count = []
    for i in range(len(periodLength)):
        if periodLength[i] < 2 or periodicity_mask[i] == 0:
            numofReps += 0
        elif full_marks_mask[i]:  # high confidence mark detected
            if math.modf(numofReps)[0] < 0.2:  # probably false positive/late detection
                numofReps = float(int(numofReps))
            else:
                numofReps = float(int(numofReps) + 1.01)  # round up
        else:
            numofReps += max(0, periodicity_mask[i]/(periodLength[i]))
        count.append(round(float(numofReps), 2))
    count_pred = count[-1] if count else 0  # empty clip -> zero reps (was IndexError)
    marks_count_pred = 0
    for i in range(len(full_marks) - 1):
        # if a jump was counted, and periodicity is high, and the next frame was not counted (to avoid double counting)
        if full_marks_mask[i] > 0 and periodicity_mask[i] > 0 and full_marks_mask[i + 1] == 0:
            marks_count_pred += 1
    if not both_feet:
        count_pred = count_pred / 2
        marks_count_pred = marks_count_pred / 2
        count = np.array(count) / 2

    # Confidence: mean margin above the miss threshold, scaled to [0, 1].
    # np.mean of an empty selection yields nan (NOT ZeroDivisionError), so the
    # previous try/except never fired; guard explicitly instead.
    periodic_vals = periodicity[periodicity > miss_threshold]
    if len(periodic_vals) > 0:
        confidence = (np.mean(periodic_vals) - miss_threshold) / (1 - miss_threshold)
    else:
        confidence = 0
    # Self-consistency: disagreement between the integrated count and the mark count.
    self_err = abs(count_pred - marks_count_pred)
    try:
        self_pct_err = self_err / count_pred
    except ZeroDivisionError:
        self_pct_err = 0
    total_confidence = confidence * (1 - self_pct_err)

    if both_feet:
        count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
    else:
        count_msg = f"## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"

    if api_call:
        if count_only_api:
            return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
        else:
            return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
                   np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
                   np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
                   f"reps: {count_pred:.2f}, marks: {marks_count_pred:.1f}, confidence: {total_confidence:.2f}", \
                   f"single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}"

    # --- UI mode: build plots ---
    # Instantaneous speed; +0.01 avoids division by zero for zero-length periods.
    jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.01), 0, 10)
    jumping_speed = np.copy(jumps_per_second)
    misses = periodicity < miss_threshold
    jumps_per_second[misses] = 0
    frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
    frame_type[full_marks > marks_threshold] = 'jump'
    per_frame_event_types = np.clip(per_frame_event_types, 0, 6) / 6  # normalize class index for the colorbar
    df = pd.DataFrame.from_dict({'period length': periodLength,
                                 'jumping speed': jumping_speed,
                                 'jumps per second': jumps_per_second,
                                 'periodicity': periodicity,
                                 'miss': misses,
                                 'frame_type': frame_type,
                                 'event_type': per_frame_event_types,
                                 'jumps': full_marks,
                                 'jumps_size': (full_marks + 0.05) * 10,
                                 'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 8),
                                 'seconds': np.linspace(0, seconds, num=len(periodLength))})
    event_type_tick_vals = np.linspace(0, 1, num=7)
    event_type_colors = ['red', 'orange', 'green', 'blue', 'purple', 'pink', 'black']
    fig = px.scatter(data_frame=df,
                     x='seconds',
                     y='jumps per second',
                     #symbol='frame_type',
                     #symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
                     color='event_type',
                     size='jumps_size',
                     size_max=8,
                     color_continuous_scale=[(t, c) for t, c in zip(event_type_tick_vals, event_type_colors)],
                     range_color=(0,1),
                     title="Jumping speed (jumps-per-second)",
                     trendline='rolling',
                     trendline_options=dict(window=16),
                     trendline_color_override="goldenrod",
                     trendline_scope='overall',
                     template="plotly_dark")

    fig.update_layout(legend=dict(
                          orientation="h",
                          yanchor="bottom",
                          y=0.98,
                          xanchor="right",
                          x=1,
                          font=dict(
                              family="Courier",
                              size=12,
                              color="black"
                          ),
                          bgcolor="AliceBlue",
                      ),
                      paper_bgcolor='rgba(0,0,0,0)',
                      plot_bgcolor='rgba(0,0,0,0)'
                      )
    # remove white outline from marks
    fig.update_traces(marker_line_width = 0)
    fig.update_layout(coloraxis_colorbar=dict(
        tickvals=event_type_tick_vals,
        ticktext=['single<br>rope', 'double<br>dutch', 'double<br>unders', 'single<br>bounces', 'double<br>bounces', 'triple<br>unders', 'other'],
        title='event type'
    ))

    hist = px.histogram(df,
                        x="jumps per second",
                        template="plotly_dark",
                        marginal="box",
                        histnorm='percent',
                        title="Distribution of jumping speed (jumps-per-second)")

    # make a bar plot of the event type distribution
    bar = px.bar(x=['single rope', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
                 y=event_type_probs,
                 template="plotly_dark",
                 title="Event Type Distribution",
                 labels={'x': 'event type', 'y': 'probability'},
                 range_y=[0, 1])

    return x, count_msg, fig, hist, bar
332
+
333
+
334
# Markdown header rendered at the top of the Space.
DESCRIPTION = '# NextJump 🦘'
DESCRIPTION += '\n## AI Counting for Competitive Jump Rope'
DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'


# Gradio UI: stream inputs + clip preview on top, results (count, plots) below.
with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
    gr.Markdown(DESCRIPTION)
    # in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
    #                             width=400, height=400, interactive=True, container=True,
    #                             max_length=150)
    with gr.Row():
        in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
        with gr.Column():
            in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True)
        with gr.Column():
            in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
        with gr.Column(min_width=480):
            out_video = gr.PlayableVideo(label="Video Clip", elem_id='output-video', format='mp4', width=400, height=400)

    with gr.Row():
        run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
        # Hidden controls used only by the named-API path below.
        api_dummy_button = gr.Button(value="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
        count_only = gr.Checkbox(label="Count Only", visible=False)
        api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)

    with gr.Column(elem_id='output-video-container'):
        with gr.Row():
            with gr.Column():
                out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
                # Hidden textboxes that receive raw arrays when called through the API.
                period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
                periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
        with gr.Row():
            out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
        with gr.Row():
            with gr.Column():
                out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
            with gr.Column():
                out_event_type_dist = gr.Plot(label="Event Type Distribution", elem_id='output-event-type-dist')


    # UI path: no API key, full visualization outputs.
    demo_inference = partial(inference, count_only_api=False, api_key=None)

    run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
    # API path: exposed as the named endpoint 'inference'; requires the API key.
    api_inference = partial(inference, api_call=True)
    api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')


if __name__ == "__main__":
    demo.queue(api_open=True, max_size=15).launch(share=False)