dylanplummer commited on
Commit
08cf343
·
1 Parent(s): 6df803b

update to use marks

Browse files
Files changed (2) hide show
  1. app.py +36 -14
  2. requirements.txt +1 -0
app.py CHANGED
@@ -11,6 +11,7 @@ import matplotlib
11
  matplotlib.use('Agg')
12
  import matplotlib.pyplot as plt
13
  from scipy.signal import medfilt
 
14
  from functools import partial
15
  from passlib.hash import pbkdf2_sha256
16
  from tqdm import tqdm
@@ -28,7 +29,7 @@ plt.style.use('dark_background')
28
  hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.bin", repo_type="model", token=os.environ['DATASET_SECRET'])
29
  model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
30
  hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
31
- #model_xml = "model_ir/model.xml"
32
 
33
  ie = Core()
34
  model_ir = ie.read_model(model=model_xml)
@@ -53,7 +54,7 @@ def sigmoid(x):
53
  return 1 / (1 + np.exp(-x))
54
 
55
 
56
- def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, miss_threshold=0.8, median_pred_filter=True, center_crop=True, both_feet=True, api_call=False):
57
  print(x)
58
  #api = HfApi(token=os.environ['DATASET_SECRET'])
59
  #out_file = str(uuid.uuid1())
@@ -91,6 +92,7 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
91
  length = len(all_frames)
92
  period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
93
  periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
 
94
  event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 4))
95
  period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
96
  event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 4))
@@ -134,13 +136,16 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
134
  result = compiled_model_ir(batch_X)
135
  y1pred = result[output_layer_period_length]
136
  y2pred = result[output_layer_periodicity]
 
137
  y4pred = result[output_layer_event_type]
138
- for y1, y2, y4, idx in zip(y1pred, y2pred, y4pred, idx_list):
139
  periodLength = y1.squeeze()
140
  periodicity = y2.squeeze()
 
141
  event_type = y4.squeeze()
142
  period_lengths[idx:idx+seq_len] += periodLength
143
  periodicities[idx:idx+seq_len] += periodicity
 
144
  event_type_logits[idx:idx+seq_len] += event_type
145
  period_length_overlaps[idx:idx+seq_len] += 1
146
  event_type_logit_overlaps[idx:idx+seq_len] += 1
@@ -154,19 +159,23 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
154
  result = compiled_model_ir(batch_X)
155
  y1pred = result[output_layer_period_length]
156
  y2pred = result[output_layer_periodicity]
 
157
  y4pred = result[output_layer_event_type]
158
- for y1, y2, y4, idx in zip(y1pred, y2pred, y4pred, idx_list):
159
  periodLength = y1.squeeze()
160
  periodicity = y2.squeeze()
 
161
  event_type = y4.squeeze()
162
  period_lengths[idx:idx+seq_len] += periodLength
163
  periodicities[idx:idx+seq_len] += periodicity
 
164
  event_type_logits[idx:idx+seq_len] += event_type
165
  period_length_overlaps[idx:idx+seq_len] += 1
166
  event_type_logit_overlaps[idx:idx+seq_len] += 1
167
 
168
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
169
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
 
170
  event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
171
  event_type_logits = np.mean(event_type_logits, axis=0)
172
  # softmax of event type logits
@@ -174,9 +183,13 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
174
 
175
  if median_pred_filter:
176
  periodicity = medfilt(periodicity, 5)
177
- periodLength = medfilt(periodLength, 5)
178
  periodicity = sigmoid(periodicity)
 
 
 
179
  periodicity_mask = np.int32(periodicity > miss_threshold)
 
180
  numofReps = 0
181
  count = []
182
  for i in range(len(periodLength)):
@@ -186,14 +199,20 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
186
  numofReps += max(0, periodicity_mask[i]/(periodLength[i]))
187
  count.append(round(float(numofReps), 2))
188
  count_pred = count[-1]
 
 
 
 
 
189
  if not both_feet:
190
  count_pred = count_pred / 2
 
191
  count = np.array(count) / 2
192
 
193
  if both_feet:
194
- count_msg = f"## Predicted Count (both feet): {count_pred:.1f}"
195
  else:
196
- count_msg = f"## Predicted Count (one foot): {count_pred:.1f}"
197
 
198
  if api_call:
199
  if count_only_api:
@@ -201,11 +220,12 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
201
  else:
202
  return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
203
  np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
204
- f"{count_pred:.2f}", \
 
205
  f"single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}"
206
 
207
 
208
- jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.05), 0, 8)
209
  jumping_speed = np.copy(jumps_per_second)
210
  misses = periodicity < miss_threshold
211
  jumps_per_second[misses] = 0
@@ -214,20 +234,22 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
214
  'jumps per second': jumps_per_second,
215
  'periodicity': periodicity,
216
  'miss': misses,
 
 
217
  'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 10),
218
  'seconds': np.linspace(0, seconds, num=len(periodLength))})
219
  fig = px.scatter(data_frame=df,
220
  x='seconds',
221
  y='jumps per second',
222
  symbol='miss',
223
- symbol_map={False: 'triangle-down', True: 'circle-open'},
224
  color='periodicity',
225
- size='miss_size',
226
- size_max=8,
227
  color_continuous_scale='RdYlGn',
228
  title="Jumping speed (jumps-per-second)",
229
  trendline='rolling',
230
- trendline_options=dict(window=32),
231
  trendline_color_override="goldenrod",
232
  trendline_scope='overall',
233
  template="plotly_dark")
@@ -267,7 +289,7 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
267
  return count_msg, fig, hist, bar
268
 
269
 
270
- DESCRIPTION = '# NextJump'
271
  DESCRIPTION += '\n## AI Counting for Competitive Jump Rope'
272
  DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'
273
 
 
11
  matplotlib.use('Agg')
12
  import matplotlib.pyplot as plt
13
  from scipy.signal import medfilt
14
+ from skimage.measure import block_reduce
15
  from functools import partial
16
  from passlib.hash import pbkdf2_sha256
17
  from tqdm import tqdm
 
29
  hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.bin", repo_type="model", token=os.environ['DATASET_SECRET'])
30
  model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
31
  hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
32
+ model_xml = "model_ir/model.xml"
33
 
34
  ie = Core()
35
  model_ir = ie.read_model(model=model_xml)
 
54
  return 1 / (1 + np.exp(-x))
55
 
56
 
57
+ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, miss_threshold=0.8, marks_threshold=0.6, median_pred_filter=True, center_crop=True, both_feet=True, api_call=False):
58
  print(x)
59
  #api = HfApi(token=os.environ['DATASET_SECRET'])
60
  #out_file = str(uuid.uuid1())
 
92
  length = len(all_frames)
93
  period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
94
  periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
95
+ full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
96
  event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 4))
97
  period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
98
  event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 4))
 
136
  result = compiled_model_ir(batch_X)
137
  y1pred = result[output_layer_period_length]
138
  y2pred = result[output_layer_periodicity]
139
+ y3pred = result[output_layer_marks]
140
  y4pred = result[output_layer_event_type]
141
+ for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
142
  periodLength = y1.squeeze()
143
  periodicity = y2.squeeze()
144
+ marks = y3.squeeze()
145
  event_type = y4.squeeze()
146
  period_lengths[idx:idx+seq_len] += periodLength
147
  periodicities[idx:idx+seq_len] += periodicity
148
+ full_marks[idx:idx+seq_len] += marks
149
  event_type_logits[idx:idx+seq_len] += event_type
150
  period_length_overlaps[idx:idx+seq_len] += 1
151
  event_type_logit_overlaps[idx:idx+seq_len] += 1
 
159
  result = compiled_model_ir(batch_X)
160
  y1pred = result[output_layer_period_length]
161
  y2pred = result[output_layer_periodicity]
162
+ y3pred = result[output_layer_marks]
163
  y4pred = result[output_layer_event_type]
164
+ for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
165
  periodLength = y1.squeeze()
166
  periodicity = y2.squeeze()
167
+ marks = y3.squeeze()
168
  event_type = y4.squeeze()
169
  period_lengths[idx:idx+seq_len] += periodLength
170
  periodicities[idx:idx+seq_len] += periodicity
171
+ full_marks[idx:idx+seq_len] += marks
172
  event_type_logits[idx:idx+seq_len] += event_type
173
  period_length_overlaps[idx:idx+seq_len] += 1
174
  event_type_logit_overlaps[idx:idx+seq_len] += 1
175
 
176
  periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
177
  periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
178
+ full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
179
  event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
180
  event_type_logits = np.mean(event_type_logits, axis=0)
181
  # softmax of event type logits
 
183
 
184
  if median_pred_filter:
185
  periodicity = medfilt(periodicity, 5)
186
+ #periodLength = medfilt(periodLength, 5)
187
  periodicity = sigmoid(periodicity)
188
+ full_marks = sigmoid(full_marks)
189
+ full_marks_mask = np.int32(full_marks > marks_threshold)
190
+ #full_marks_reduced = block_reduce(full_marks > marks_threshold, (3,), np.max)
191
  periodicity_mask = np.int32(periodicity > miss_threshold)
192
+ #periodicity_mask_reduced = block_reduce(periodicity_mask, (3,), np.max)
193
  numofReps = 0
194
  count = []
195
  for i in range(len(periodLength)):
 
199
  numofReps += max(0, periodicity_mask[i]/(periodLength[i]))
200
  count.append(round(float(numofReps), 2))
201
  count_pred = count[-1]
202
+ marks_count_pred = 0
203
+ for i in range(len(full_marks) - 1):
204
+ # if a jump was counted, and periodicity is high, and the next frame was not counted (to avoid double counting)
205
+ if full_marks_mask[i] > 0 and periodicity_mask[i] > 0 and full_marks_mask[i + 1] == 0:
206
+ marks_count_pred += 1
207
  if not both_feet:
208
  count_pred = count_pred / 2
209
+ marks_count_pred = marks_count_pred / 2
210
  count = np.array(count) / 2
211
 
212
  if both_feet:
213
+ count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}"
214
  else:
215
+ count_msg = f"## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}"
216
 
217
  if api_call:
218
  if count_only_api:
 
220
  else:
221
  return np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
222
  np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
223
+ np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}).replace('\n', ''), \
224
+ f"reps: {count_pred:.2f}, marks: {marks_count_pred:.1f}", \
225
  f"single_rope_speed: {event_type_probs[0]:.3f}, double_dutch: {event_type_probs[1]:.3f}, double_unders: {event_type_probs[2]:.3f}, single_bounce: {event_type_probs[3]:.3f}"
226
 
227
 
228
+ jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.01), 0, 10)
229
  jumping_speed = np.copy(jumps_per_second)
230
  misses = periodicity < miss_threshold
231
  jumps_per_second[misses] = 0
 
234
  'jumps per second': jumps_per_second,
235
  'periodicity': periodicity,
236
  'miss': misses,
237
+ 'jumps': full_marks,
238
+ 'jumps_size': (full_marks + 0.2) * 10,
239
  'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 10),
240
  'seconds': np.linspace(0, seconds, num=len(periodLength))})
241
  fig = px.scatter(data_frame=df,
242
  x='seconds',
243
  y='jumps per second',
244
  symbol='miss',
245
+ symbol_map={False: 'circle', True: 'circle-open'},
246
  color='periodicity',
247
+ size='jumps_size',
248
+ size_max=10,
249
  color_continuous_scale='RdYlGn',
250
  title="Jumping speed (jumps-per-second)",
251
  trendline='rolling',
252
+ trendline_options=dict(window=16),
253
  trendline_color_override="goldenrod",
254
  trendline_scope='overall',
255
  template="plotly_dark")
 
289
  return count_msg, fig, hist, bar
290
 
291
 
292
+ DESCRIPTION = '# NextJump 🦘'
293
  DESCRIPTION += '\n## AI Counting for Competitive Jump Rope'
294
  DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'
295
 
requirements.txt CHANGED
@@ -4,6 +4,7 @@ matplotlib
4
  plotly
5
  passlib
6
  scipy
 
7
  --find-links https://download.pytorch.org/whl/torch_stable.html
8
  opencv-python-headless==4.7.0.68
9
  openvino-dev==2022.3.0
 
4
  plotly
5
  passlib
6
  scipy
7
+ scikit-image
8
  --find-links https://download.pytorch.org/whl/torch_stable.html
9
  opencv-python-headless==4.7.0.68
10
  openvino-dev==2022.3.0