dylanplummer commited on
Commit
cf3f9fd
·
1 Parent(s): 57c7b88

add event type coloring

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -172,10 +172,11 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
172
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
173
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
174
  full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
175
- event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
176
- event_type_logits = np.mean(event_type_logits, axis=0)
177
  # softmax of event type logits
178
  event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
 
179
 
180
  if median_pred_filter:
181
  periodicity = medfilt(periodicity, 5)
@@ -230,12 +231,14 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
230
  jumps_per_second[misses] = 0
231
  frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
232
  frame_type[full_marks > marks_threshold] = 'jump'
 
233
  df = pd.DataFrame.from_dict({'period length': periodLength,
234
  'jumping speed': jumping_speed,
235
  'jumps per second': jumps_per_second,
236
  'periodicity': periodicity,
237
  'miss': misses,
238
  'frame_type': frame_type,
 
239
  'jumps': full_marks,
240
  'jumps_size': (full_marks + 0.05) * 10,
241
  'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 10),
@@ -244,12 +247,12 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
244
  fig = px.scatter(data_frame=df,
245
  x='seconds',
246
  y='jumps per second',
247
- symbol='frame_type',
248
- symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
249
- color='periodicity',
250
  size='jumps_size',
251
  size_max=10,
252
- color_continuous_scale='RdYlGn',
253
  title="Jumping speed (jumps-per-second)",
254
  trendline='rolling',
255
  trendline_options=dict(window=16),
@@ -302,7 +305,7 @@ with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
302
  gr.Markdown(DESCRIPTION)
303
  with gr.Column():
304
  with gr.Row():
305
- in_video = gr.Video(label="Input Video", elem_id='input-video', format='mp4', width=400, height=400)
306
 
307
  with gr.Row():
308
  run_button = gr.Button(label="Run", elem_id='run-button', scale=1)
@@ -356,6 +359,7 @@ with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
356
  [os.path.join(os.path.dirname(__file__), "files", "train_95.mp4")],
357
  [os.path.join(os.path.dirname(__file__), "files", "train_253.mp4")],
358
  [os.path.join(os.path.dirname(__file__), "files", "train_66.mp4")],
 
359
  ],
360
  inputs=[in_video],
361
  outputs=[out_text, out_plot, out_hist, out_event_type_dist],
 
172
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
173
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
174
  full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
175
+ per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
176
+ event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
177
  # softmax of event type logits
178
  event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
179
+ per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)
180
 
181
  if median_pred_filter:
182
  periodicity = medfilt(periodicity, 5)
 
231
  jumps_per_second[misses] = 0
232
  frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
233
  frame_type[full_marks > marks_threshold] = 'jump'
234
+ per_frame_event_types = per_frame_event_types / 6
235
  df = pd.DataFrame.from_dict({'period length': periodLength,
236
  'jumping speed': jumping_speed,
237
  'jumps per second': jumps_per_second,
238
  'periodicity': periodicity,
239
  'miss': misses,
240
  'frame_type': frame_type,
241
+ 'event_type': per_frame_event_types,
242
  'jumps': full_marks,
243
  'jumps_size': (full_marks + 0.05) * 10,
244
  'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 10),
 
247
  fig = px.scatter(data_frame=df,
248
  x='seconds',
249
  y='jumps per second',
250
+ #symbol='frame_type',
251
+ #symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
252
+ color='event_type',
253
  size='jumps_size',
254
  size_max=10,
255
+ color_continuous_scale='Rainbow',
256
  title="Jumping speed (jumps-per-second)",
257
  trendline='rolling',
258
  trendline_options=dict(window=16),
 
305
  gr.Markdown(DESCRIPTION)
306
  with gr.Column():
307
  with gr.Row():
308
+ in_video = gr.Video(label="Input Video", elem_id='input-video', format='mp4', width=400, height=400, interactive=True)
309
 
310
  with gr.Row():
311
  run_button = gr.Button(label="Run", elem_id='run-button', scale=1)
 
359
  [os.path.join(os.path.dirname(__file__), "files", "train_95.mp4")],
360
  [os.path.join(os.path.dirname(__file__), "files", "train_253.mp4")],
361
  [os.path.join(os.path.dirname(__file__), "files", "train_66.mp4")],
362
+ [os.path.join(os.path.dirname(__file__), "files", "train_21.mp4")]
363
  ],
364
  inputs=[in_video],
365
  outputs=[out_text, out_plot, out_hist, out_event_type_dist],