File size: 20,499 Bytes
8fe9a5a
 
 
 
 
d744c30
1c4dfa7
8fe9a5a
 
 
d744c30
8fe9a5a
d192993
8fe9a5a
 
 
43ed042
8fe9a5a
 
 
 
 
 
 
 
 
bded07f
2fe3d36
 
57925f4
8fe9a5a
2fe3d36
 
 
 
 
8fe9a5a
43f7645
 
 
 
 
 
 
 
 
8fe9a5a
43f7645
 
8fe9a5a
 
1c4dfa7
9aece52
bded07f
cf90a22
9aece52
767d5c4
41e552b
741c26c
 
 
 
 
 
 
 
d192993
 
 
 
 
 
 
8fe9a5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43f7645
8fe9a5a
 
 
 
 
 
08cf343
d22ea40
8fe9a5a
d22ea40
8fe9a5a
 
 
 
 
 
 
87f81d9
8fe9a5a
 
43f7645
 
 
 
 
 
 
 
b5ca2ed
8fe9a5a
 
 
 
 
 
43f7645
8fe9a5a
 
 
 
 
 
 
 
 
 
87f81d9
8fe9a5a
 
bded07f
43ed042
 
 
 
08cf343
8fe9a5a
 
08cf343
43f7645
8fe9a5a
 
08cf343
43f7645
8fe9a5a
43f7645
8fe9a5a
 
767d5c4
8fe9a5a
 
 
 
 
bded07f
43ed042
 
 
 
08cf343
8fe9a5a
 
08cf343
43f7645
8fe9a5a
 
08cf343
43f7645
8fe9a5a
43f7645
8fe9a5a
 
 
08cf343
cf3f9fd
 
43f7645
 
cf3f9fd
98fee40
8fe9a5a
 
319f52e
8fe9a5a
08cf343
d744c30
 
 
 
8fe9a5a
 
 
 
 
 
d744c30
 
 
 
 
8fe9a5a
 
 
 
08cf343
 
 
 
 
8fe9a5a
 
08cf343
8fe9a5a
 
319f52e
 
 
 
 
8fe9a5a
319f52e
8fe9a5a
319f52e
8fe9a5a
98fee40
 
e3ec7f6
98fee40
 
43f7645
08cf343
e3ec7f6
43f7645
98fee40
 
08cf343
8fe9a5a
 
 
319f52e
 
43ed042
8fe9a5a
 
 
 
 
319f52e
cf3f9fd
08cf343
ab0b18f
cf90a22
8fe9a5a
c843211
 
8fe9a5a
 
 
cf3f9fd
 
 
08cf343
cf90a22
c843211
4b5adda
8fe9a5a
 
08cf343
8fe9a5a
 
 
ab0b18f
8fe9a5a
9aece52
 
 
 
 
 
 
 
 
 
 
 
 
64462ff
8fe9a5a
64462ff
 
547e0cb
c843211
767d5c4
547e0cb
 
8fe9a5a
 
 
 
 
 
3398cf4
43f7645
 
 
767d5c4
43f7645
 
 
 
 
8fe9a5a
43f7645
8fe9a5a
 
08cf343
8fe9a5a
 
 
 
319f52e
8fe9a5a
ef86f0f
 
 
98fee40
 
8494608
 
98fee40
d192993
8fe9a5a
 
 
 
 
 
 
0f5c629
8fe9a5a
43f7645
 
 
 
 
 
 
98fee40
8fe9a5a
 
ab0b18f
8fe9a5a
 
 
 
 
 
 
 
 
ab0b18f
 
8fe9a5a
 
b2bf29b
 
8fe9a5a
ab0b18f
b027d03
 
08775f1
 
b027d03
08775f1
b027d03
a852991
b027d03
 
 
8fe9a5a
98fee40
43f7645
b2bf29b
 
43f7645
8fe9a5a
d192993
8fe9a5a
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
import gradio as gr
import numpy as np
from PIL import Image
import os
import cv2
import math
import spaces
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.signal import medfilt, find_peaks
from functools import partial
from passlib.hash import pbkdf2_sha256
from tqdm import tqdm
import pandas as pd
import plotly.express as px
import onnxruntime as ort
import torch
from torchvision import transforms
import torchvision.transforms.functional as F

from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi

# Dark theme for any matplotlib figures rendered by this app.
plt.style.use('dark_background')

# Download the ONNX export of the NextJump counting model from the Hugging Face Hub.
# The DATASET_SECRET environment variable supplies the access token (KeyError if unset).
onnx_file = hf_hub_download(
    repo_id="dylanplummer/ropenet",
    filename="nextjump.onnx",
    repo_type="model",
    token=os.environ['DATASET_SECRET'],
)


class SquarePad:
    """Pad a PIL image with constant fill 0 (black) so that it becomes square.

    The image is centred inside the square. When the size difference is odd,
    the extra pixel goes on the right/bottom so the result is exactly square
    (splitting ``int(diff / 2)`` on both sides would leave it one pixel short).

    Adapted from
    https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
    """

    @staticmethod
    def _square_padding(w, h):
        """Return the (left, top, right, bottom) padding that makes a ``w`` x ``h`` image square."""
        max_wh = max(w, h)
        left = (max_wh - w) // 2
        top = (max_wh - h) // 2
        # Give any odd leftover pixel to the right/bottom edge.
        return (left, top, max_wh - w - left, max_wh - h - top)

    def __call__(self, image):
        w, h = image.size
        return F.pad(image, self._square_padding(w, h), 0, 'constant')

def sigmoid(x):
    """Return the element-wise logistic function ``1 / (1 + exp(-x))``.

    Delegates to ``scipy.special.expit``, which gives the same values but stays
    numerically stable for large-magnitude inputs. The naive form overflows in
    ``np.exp(-x)`` for very negative logits and emits a RuntimeWarning.
    Works on scalars and arrays; array dtypes are preserved for float inputs.
    """
    from scipy.special import expit
    return expit(x)


def _read_video_frames(path):
    """Decode every frame of the video at ``path`` into RGB PIL images.

    Returns:
        (frames, fps): the decoded frames in order, and the frame rate reported by
        the container as a float. If the container reports no usable frame rate
        (0, negative or NaN), 30.0 is assumed so that time axes can still be drawn.
        That value is an assumption, not a measured value.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        fps = float(cap.get(cv2.CAP_PROP_FPS))
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(Image.fromarray(cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)))
    finally:
        cap.release()
    if not fps > 0:  # also catches NaN
        fps = 30.0
    return frames, fps


def _run_batch(ort_sess, batch_list, idx_list, batch_size, seq_len, sums, overlaps):
    """Run one ONNX batch and add each window's outputs into the running per-frame sums.

    ``batch_list`` holds ``(1, seq_len, 3, H, W)`` window tensors whose first frames
    are at the offsets in ``idx_list``. A short batch is padded to ``batch_size`` by
    repeating its last window (presumably the ONNX graph has a fixed batch
    dimension — TODO confirm). Outputs for those padding copies are discarded so
    they do not bias the overlap average.

    ``sums`` is a sequence of accumulator arrays aligned with the first four model
    outputs (period length, periodicity, marks, event-type logits). ``overlaps``
    counts how many windows contributed to each frame.
    """
    padded = batch_list + [batch_list[-1]] * (batch_size - len(batch_list))
    outputs = ort_sess.run(None, {'video': torch.cat(padded).numpy()})
    # Only iterate over the real windows (len(idx_list)), never the padding copies.
    for b, idx in enumerate(idx_list):
        window = slice(idx, idx + seq_len)
        for acc, out in zip(sums, outputs[:4]):
            acc[window] += out[b].squeeze()
        overlaps[window] += 1


@spaces.GPU()
def inference(x, count_only_api, api_key, 

              img_size=288, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, 

              miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True, 

              api_call=False,

              progress=gr.Progress()):
    """Count jump-rope repetitions in a video with the NextJump ONNX model.

    The video is cut into windows of ``seq_len`` frames, one starting every
    ``stride_length`` frames. Each frame is square-padded, resized to
    ``img_size`` x ``img_size`` and scaled to [0, 255]. The per-frame model
    outputs are averaged over all windows that cover the frame. These outputs
    are period length, periodicity, jump marks and 7 event-type logits. The
    averages are then turned into a running rep count.

    Args:
        x: path of the input video (anything ``cv2.VideoCapture`` accepts).
        count_only_api: with ``api_call``, return only the count string.
        api_key: with ``api_call``, checked via
            ``pbkdf2_sha256.verify(DEV_API_TOKEN, api_key)``, i.e. expected to be
            a pbkdf2_sha256 hash of that token.
        img_size: side length of the square model input.
        seq_len: frames per model window.
        stride_length: frames between window starts.
        stride_pad: trims the last window start (range end is
            ``length + stride_length - stride_pad``).
        batch_size: windows per ONNX call.
        miss_threshold: periodicity (after sigmoid) a frame must exceed to count
            as jumping.
        marks_threshold: minimum peak height of the (sigmoid) mark signal.
        median_pred_filter: apply a 5-frame median filter to periodicity and
            period length.
        center_crop: currently unused; frames are always square-padded.
        both_feet: if False, the counts are halved.
        api_call: return plain strings for API clients instead of plots.
        progress: Gradio progress tracker.

    Returns:
        With ``api_call``: "Invalid API Key" on authentication failure. Otherwise
        a count string if ``count_only_api``, else a 5-tuple of strings.
        Without ``api_call``: (markdown count message, speed scatter figure,
        speed histogram, event-type bar chart).

    Raises:
        gr.Error: if no frame can be decoded from the video.
    """
    progress(0, desc="Starting...")
    # Authenticate before doing any expensive work. A missing or malformed key
    # makes passlib raise; treat that as a failed check instead of crashing.
    if api_call:
        try:
            has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
        except (TypeError, ValueError):
            has_access = False
        if not has_access:
            return "Invalid API Key"

    # check if GPU is available
    if torch.cuda.is_available():
        providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
                                                "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
        ort_sess = ort.InferenceSession(onnx_file, sess_options=ort.SessionOptions(), providers=providers)
    else:
        ort_sess = ort.InferenceSession(onnx_file)

    all_frames, fps = _read_video_frames(x)
    if not all_frames:
        raise gr.Error("Could not read any frames from the input video.")
    length = len(all_frames)
    # Use the decoded frame count; container frame-count metadata can be inaccurate.
    seconds = length / fps

    # Accumulators are oversized so the windows that run past the end never go out of bounds.
    n_slots = length + seq_len + stride_length
    period_lengths = np.zeros(n_slots)
    periodicities = np.zeros(n_slots)
    mark_sums = np.zeros(n_slots)
    event_type_sums = np.zeros((n_slots, 7))
    overlaps = np.zeros(n_slots)
    sums = (period_lengths, periodicities, mark_sums, event_type_sums)

    # Pad with the last frame so every window has seq_len images.
    all_frames.extend([all_frames[-1]] * (seq_len + stride_length))

    # Built once: the pipeline is identical for every frame.
    preprocess = transforms.Compose([
        SquarePad(),
        transforms.Resize((img_size, img_size), interpolation=Image.BICUBIC),
        transforms.ToTensor(),
    ])

    span = length + stride_length - stride_pad
    batch_list = []
    idx_list = []
    for i in tqdm(range(0, span, stride_length)):
        print('Preprocessing...')
        window = [preprocess(img).unsqueeze(0) for img in all_frames[i:i + seq_len]]
        window += [window[-1]] * (seq_len - len(window))
        X = torch.cat(window) * 255  # ToTensor yields [0, 1]; the model takes [0, 255]
        batch_list.append(X.unsqueeze(0))
        idx_list.append(i)
        print('Running inference...')
        if len(batch_list) == batch_size:
            _run_batch(ort_sess, batch_list, idx_list, batch_size, seq_len, sums, overlaps)
            batch_list = []
            idx_list = []
        progress(i / span, desc="Processing...")
    if batch_list:  # still some leftover windows
        _run_batch(ort_sess, batch_list, idx_list, batch_size, seq_len, sums, overlaps)

    # Average over overlapping windows. Without out=, np.divide leaves entries
    # where the mask is False uninitialized, so zero-fill them explicitly.
    covered = overlaps != 0
    periodLength = np.divide(period_lengths, overlaps, out=np.zeros(n_slots), where=covered)[:length]
    periodicity = np.divide(periodicities, overlaps, out=np.zeros(n_slots), where=covered)[:length]
    full_marks = np.divide(mark_sums, overlaps, out=np.zeros(n_slots), where=covered)[:length]
    per_frame_event_type_logits = np.divide(event_type_sums, overlaps[:, None],
                                            out=np.zeros((n_slots, 7)), where=covered[:, None])[:length]
    event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
    # Max-shifted softmax: same result as exp/sum, but cannot overflow.
    exp_logits = np.exp(event_type_logits - np.max(event_type_logits))
    event_type_probs = exp_logits / np.sum(exp_logits)
    per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)

    if median_pred_filter:
        periodicity = medfilt(periodicity, 5)
        periodLength = medfilt(periodLength, 5)
    periodicity = sigmoid(periodicity)
    full_marks = sigmoid(full_marks)
    pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
    full_marks_mask = np.zeros(len(full_marks))
    full_marks_mask[pred_marks_peaks] = 1
    periodicity_mask = np.int32(periodicity > miss_threshold)

    # Running count: add 1/period per periodic frame. A detected mark snaps the
    # count to a whole number (down if within 0.2 of it, otherwise up).
    numofReps = 0.0
    count = []
    for p_len, is_periodic, is_mark in zip(periodLength, periodicity_mask, full_marks_mask):
        if p_len < 2 or is_periodic == 0:
            pass  # not jumping in this frame: count unchanged
        elif is_mark:  # high confidence mark detected
            if math.modf(numofReps)[0] < 0.2:  # probably false positive/late detection
                numofReps = float(int(numofReps))
            else:
                numofReps = float(int(numofReps) + 1.01)  # round up
        else:
            numofReps += max(0, is_periodic / p_len)
        count.append(round(float(numofReps), 2))
    count_pred = count[-1]
    # A mark counts when it is in a periodic frame and the next frame is not also a mark (avoid double counting).
    marks_count_pred = int(np.count_nonzero(
        (full_marks_mask[:-1] > 0) & (periodicity_mask[:-1] > 0) & (full_marks_mask[1:] == 0)))
    if not both_feet:
        count_pred = count_pred / 2
        marks_count_pred = marks_count_pred / 2

    confident = periodicity[periodicity > miss_threshold]
    if confident.size:
        confidence = (np.mean(confident) - miss_threshold) / (1 - miss_threshold)
    else:
        confidence = 0.0  # no frame above the threshold: nothing to be confident about
    self_err = abs(count_pred - marks_count_pred)
    if count_pred:
        self_pct_err = self_err / count_pred
    else:
        # Zero reps: agreement is perfect only if the marks count is also zero.
        self_pct_err = 0.0 if self_err == 0 else 1.0
    total_confidence = confidence * (1 - self_pct_err)

    if both_feet:
        count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
    else:
        count_msg = f"## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"

    if api_call:
        if count_only_api:
            return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
        else:
            return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
                np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
                np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
                f"reps: {count_pred:.2f}, marks: {marks_count_pred:.1f}, confidence: {total_confidence:.2f}", \
                f"single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}"

    jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.01), 0, 10)
    jumping_speed = np.copy(jumps_per_second)
    misses = periodicity < miss_threshold
    jumps_per_second[misses] = 0
    frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
    frame_type[full_marks > marks_threshold] = 'jump'
    per_frame_event_types = np.clip(per_frame_event_types, 0, 6) / 6
    df = pd.DataFrame.from_dict({'period length': periodLength, 
                                 'jumping speed': jumping_speed,
                                'jumps per second': jumps_per_second,
                                'periodicity': periodicity,
                                'miss': misses,
                                'frame_type': frame_type,
                                'event_type': per_frame_event_types,
                                'jumps': full_marks,
                                'jumps_size': (full_marks + 0.05) * 10,
                                'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 8),
                                'seconds': np.linspace(0, seconds, num=len(periodLength))})
    event_type_tick_vals = np.linspace(0, 1, num=7)
    event_type_colors = ['red', 'orange', 'green', 'blue', 'purple', 'pink', 'black']
    fig = px.scatter(data_frame=df,
                    x='seconds', 
                    y='jumps per second',
                    color='event_type',
                    size='jumps_size',
                    size_max=8,
                    color_continuous_scale=[(t, c) for t, c in zip(event_type_tick_vals, event_type_colors)],
                    range_color=(0,1),
                    title="Jumping speed (jumps-per-second)",
                    trendline='rolling',
                    trendline_options=dict(window=16),
                    trendline_color_override="goldenrod",
                    trendline_scope='overall',
                    template="plotly_dark")
    
    fig.update_layout(legend=dict(
            orientation="h",
            yanchor="bottom",
            y=0.98,
            xanchor="right",
            x=1,
            font=dict(
                family="Courier",
                size=12,
                color="black"
                ),
            bgcolor="AliceBlue",
        ),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)'
    )
    # remove white outline from marks
    fig.update_traces(marker_line_width = 0)
    fig.update_layout(coloraxis_colorbar=dict(
        tickvals=event_type_tick_vals,
        ticktext=['single<br>rope', 'double<br>dutch', 'double<br>unders', 'single<br>bounces', 'double<br>bounces', 'triple<br>unders', 'other'],
        title='event type'
    ))

    hist = px.histogram(df, 
                        x="jumps per second", 
                        template="plotly_dark", 
                        marginal="box",
                        histnorm='percent',
                        title="Distribution of jumping speed (jumps-per-second)")
    
    # make a bar plot of the event type distribution
    bar = px.bar(x=['single rope', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'], 
                 y=event_type_probs,
                 template="plotly_dark",
                 title="Event Type Distribution",
                 labels={'x': 'event type', 'y': 'probability'},
                 range_y=[0, 1])

    return count_msg, fig, hist, bar
        

# Markdown header shown at the top of the demo page (title, subtitle, credits).
DESCRIPTION = '\n'.join([
    '# NextJump 🦘',
    '## AI Counting for Competitive Jump Rope',
    'Demo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).',
])


# Gradio UI: a video input, a visible "Run" button for the interactive demo,
# and hidden controls that back the "inference" API endpoint.
# Component creation order below determines the page layout.
with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
    gr.Markdown(DESCRIPTION)
    in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4', 
                                width=400, height=400, interactive=True, container=True,
                                max_length=150)

    with gr.Row():
        run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
        # Hidden (visible=False) inputs used only by the API endpoint registered at the bottom.
        api_dummy_button = gr.Button(value="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
        count_only = gr.Checkbox(label="Count Only", visible=False)
        api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)

    # Output area: count message, speed scatter plot, speed histogram, event-type bar chart.
    with gr.Column(elem_id='output-video-container'):
        with gr.Row():
            with gr.Column():
                out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
                # Hidden textboxes; period_length is the output of the API endpoint.
                period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
                periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
            #with gr.Column(min_width=480):
                #out_video = gr.PlayableVideo(label="Output Video", elem_id='output-video', format='mp4')
        with gr.Row():
            out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
        with gr.Row():
            with gr.Column():
                out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
            with gr.Column():
                out_event_type_dist = gr.Plot(label="Event Type Distribution", elem_id='output-event-type-dist')

    # Collapsible help section (collapsed by default).
    with gr.Accordion(label="Instructions and more information", open=False):
        instructions = "## Instructions:"
        instructions += "\n* Upload a video and click 'Run' to get a prediction of the number of jumps (either one foot, or both). This could take a couple minutes!"
        instructions += "\n\n## Tips (optional):"
        instructions += "\n* Trim the video to start and end of the event"
        instructions += "\n* Frame the jumper fully, in the center of the frame"
        instructions += "\n* Videos are automatically resized, so higher resolution will not help, but a closer framing of the jumper might help. Try cropping the video differently."
        gr.Markdown(instructions)

        # NOTE(review): the first answer refers to a "miss threshold slider", but no
        # slider is created in this UI (miss_threshold is only an inference() default).
        # The double-unders question is also missing its "?". Confirm and update the text.
        faq = "## FAQ:"
        faq += "\n* **Q:** Does the model recognize misses?\n    * **A:** Yes, but if it fails, you can try tuning the miss threshold slider to make it more sensitive."
        faq += "\n* **Q:** Does the model recognize double dutch?\n    * **A:** Yes, but it is trained on a smaller set of double dutch videos, so it may not work perfectly."
        faq += "\n* **Q:** Does the model recognize double unders\n    * **A:** Yes, but it is trained on a smaller set of double under videos, so it may not work perfectly. It is also trained to count the rope, not the jumps so you will need to divide the count by 2 to get the traditional double under count."
        faq += "\n* **Q:** Does the model count both feet?\n    * **A:** Yes, it counts every time the rope goes around no matter the event."
        gr.Markdown(faq)

    # UI path: api_call stays False, so inference() returns (message, scatter, histogram, bar).
    demo_inference = partial(inference, count_only_api=False, api_key=None)

    # Example outputs are cached only when the SYSTEM env var is 'spaces'.
    gr.Examples(examples=[
                        [os.path.join(os.path.dirname(__file__), "files", "dylan.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train14.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train_17.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train13.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train_213.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train_156.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train_202.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train_57.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train_95.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train_253.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train_66.mp4")],
                        #[os.path.join(os.path.dirname(__file__), "files", "train_21.mp4")]
                    ],
                inputs=[in_video],
                outputs=[out_text, out_plot, out_hist, out_event_type_dist],
                fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')

    run_button.click(demo_inference, [in_video], outputs=[out_text, out_plot, out_hist, out_event_type_dist])
    # API path: the inputs (video, count_only, api_token) map positionally to
    # inference(x, count_only_api, api_key), exposed under api_name='inference'.
    # NOTE(review): with count_only_api=False, inference() returns five strings, but only
    # one output component (period_length) is wired here. Confirm Gradio accepts that
    # return, or add matching hidden outputs.
    api_inference = partial(inference, api_call=True)
    api_dummy_button.click(api_inference, [in_video, count_only, api_token], outputs=[period_length], api_name='inference')


if __name__ == "__main__":
    # Enable request queuing (at most 15 pending requests, API left open) and serve
    # without creating a public share link.
    demo.queue(api_open=True, max_size=15).launch(share=False)