Commit ·
cf3f9fd
1
Parent(s): 57c7b88
add event type coloring
Browse files
app.py
CHANGED
|
@@ -172,10 +172,11 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
|
|
| 172 |
periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 173 |
periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 174 |
full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 175 |
-
|
| 176 |
-
event_type_logits = np.mean(
|
| 177 |
# softmax of event type logits
|
| 178 |
event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
|
|
|
|
| 179 |
|
| 180 |
if median_pred_filter:
|
| 181 |
periodicity = medfilt(periodicity, 5)
|
|
@@ -230,12 +231,14 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
|
|
| 230 |
jumps_per_second[misses] = 0
|
| 231 |
frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
|
| 232 |
frame_type[full_marks > marks_threshold] = 'jump'
|
|
|
|
| 233 |
df = pd.DataFrame.from_dict({'period length': periodLength,
|
| 234 |
'jumping speed': jumping_speed,
|
| 235 |
'jumps per second': jumps_per_second,
|
| 236 |
'periodicity': periodicity,
|
| 237 |
'miss': misses,
|
| 238 |
'frame_type': frame_type,
|
|
|
|
| 239 |
'jumps': full_marks,
|
| 240 |
'jumps_size': (full_marks + 0.05) * 10,
|
| 241 |
'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 10),
|
|
@@ -244,12 +247,12 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
|
|
| 244 |
fig = px.scatter(data_frame=df,
|
| 245 |
x='seconds',
|
| 246 |
y='jumps per second',
|
| 247 |
-
symbol='frame_type',
|
| 248 |
-
symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
|
| 249 |
-
color='
|
| 250 |
size='jumps_size',
|
| 251 |
size_max=10,
|
| 252 |
-
color_continuous_scale='
|
| 253 |
title="Jumping speed (jumps-per-second)",
|
| 254 |
trendline='rolling',
|
| 255 |
trendline_options=dict(window=16),
|
|
@@ -302,7 +305,7 @@ with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
|
|
| 302 |
gr.Markdown(DESCRIPTION)
|
| 303 |
with gr.Column():
|
| 304 |
with gr.Row():
|
| 305 |
-
in_video = gr.Video(label="Input Video", elem_id='input-video', format='mp4', width=400, height=400)
|
| 306 |
|
| 307 |
with gr.Row():
|
| 308 |
run_button = gr.Button(label="Run", elem_id='run-button', scale=1)
|
|
@@ -356,6 +359,7 @@ with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
|
|
| 356 |
[os.path.join(os.path.dirname(__file__), "files", "train_95.mp4")],
|
| 357 |
[os.path.join(os.path.dirname(__file__), "files", "train_253.mp4")],
|
| 358 |
[os.path.join(os.path.dirname(__file__), "files", "train_66.mp4")],
|
|
|
|
| 359 |
],
|
| 360 |
inputs=[in_video],
|
| 361 |
outputs=[out_text, out_plot, out_hist, out_event_type_dist],
|
|
|
|
| 172 |
periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 173 |
periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 174 |
full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 175 |
+
per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
|
| 176 |
+
event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
|
| 177 |
# softmax of event type logits
|
| 178 |
event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
|
| 179 |
+
per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)
|
| 180 |
|
| 181 |
if median_pred_filter:
|
| 182 |
periodicity = medfilt(periodicity, 5)
|
|
|
|
| 231 |
jumps_per_second[misses] = 0
|
| 232 |
frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
|
| 233 |
frame_type[full_marks > marks_threshold] = 'jump'
|
| 234 |
+
per_frame_event_types = per_frame_event_types / 6
|
| 235 |
df = pd.DataFrame.from_dict({'period length': periodLength,
|
| 236 |
'jumping speed': jumping_speed,
|
| 237 |
'jumps per second': jumps_per_second,
|
| 238 |
'periodicity': periodicity,
|
| 239 |
'miss': misses,
|
| 240 |
'frame_type': frame_type,
|
| 241 |
+
'event_type': per_frame_event_types,
|
| 242 |
'jumps': full_marks,
|
| 243 |
'jumps_size': (full_marks + 0.05) * 10,
|
| 244 |
'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 10),
|
|
|
|
| 247 |
fig = px.scatter(data_frame=df,
|
| 248 |
x='seconds',
|
| 249 |
y='jumps per second',
|
| 250 |
+
#symbol='frame_type',
|
| 251 |
+
#symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
|
| 252 |
+
color='event_type',
|
| 253 |
size='jumps_size',
|
| 254 |
size_max=10,
|
| 255 |
+
color_continuous_scale='Rainbow',
|
| 256 |
title="Jumping speed (jumps-per-second)",
|
| 257 |
trendline='rolling',
|
| 258 |
trendline_options=dict(window=16),
|
|
|
|
| 305 |
gr.Markdown(DESCRIPTION)
|
| 306 |
with gr.Column():
|
| 307 |
with gr.Row():
|
| 308 |
+
in_video = gr.Video(label="Input Video", elem_id='input-video', format='mp4', width=400, height=400, interactive=True)
|
| 309 |
|
| 310 |
with gr.Row():
|
| 311 |
run_button = gr.Button(label="Run", elem_id='run-button', scale=1)
|
|
|
|
| 359 |
[os.path.join(os.path.dirname(__file__), "files", "train_95.mp4")],
|
| 360 |
[os.path.join(os.path.dirname(__file__), "files", "train_253.mp4")],
|
| 361 |
[os.path.join(os.path.dirname(__file__), "files", "train_66.mp4")],
|
| 362 |
+
[os.path.join(os.path.dirname(__file__), "files", "train_21.mp4")]
|
| 363 |
],
|
| 364 |
inputs=[in_video],
|
| 365 |
outputs=[out_text, out_plot, out_hist, out_event_type_dist],
|