dylanplummer committed on
Commit
43ed042
·
1 Parent(s): 94e4fda

test onnx

Browse files
Files changed (2) hide show
  1. app.py +21 -27
  2. requirements.txt +1 -1
app.py CHANGED
@@ -16,6 +16,7 @@ from passlib.hash import pbkdf2_sha256
16
  from tqdm import tqdm
17
  import pandas as pd
18
  import plotly.express as px
 
19
  import torch
20
  from torchvision import transforms
21
  import torchvision.transforms.functional as F
@@ -23,14 +24,11 @@ import torchvision.transforms.functional as F
23
  from huggingface_hub import hf_hub_download
24
  from huggingface_hub import HfApi
25
 
26
-
27
 
28
  plt.style.use('dark_background')
29
 
30
- checkpoint = hf_hub_download(repo_id="dylanplummer/ropenet", filename="ropenet_keypoint_0.pt", repo_type="model", token=os.environ['DATASET_SECRET'])
31
- model_file = checkpoint = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.py", repo_type="model", token=os.environ['DATASET_SECRET'])
32
- os.rename(model_file, "model.py")
33
- from model import RepNet
34
  # model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
35
  # hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
36
  #model_xml = "model_ir/model.xml"
@@ -40,14 +38,10 @@ from model import RepNet
40
  # config = {"PERFORMANCE_HINT": "LATENCY"}
41
  # compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU", config=config)
42
 
43
- img_size = 224
44
- backbone = 'mobilenetv3'
45
- embedding_size = 196
46
- n_layers_lstm = 1
47
- separate_rope = False
48
- save_realtime = False
49
- model = RepNet(64, backbone=backbone, backbone_scale='0', trainable_backbone=False, distill_frame_model=save_realtime, img_size=img_size, embedding_size=embedding_size, separate_rope=separate_rope)
50
-
51
 
52
 
53
  class SquarePad:
@@ -98,9 +92,9 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
98
  period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
99
  periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
100
  full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
101
- event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 7))
102
  period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
103
- event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
104
  for _ in range(seq_len + stride_length): # pad full sequence
105
  all_frames.append(all_frames[-1])
106
  batch_list = []
@@ -138,11 +132,11 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
138
  idx_list.append(i)
139
  if len(batch_list) == batch_size:
140
  batch_X = torch.cat(batch_list)
141
- result = model(batch_X)
142
- y1pred = result[0]
143
- y2pred = result[1]
144
- y3pred = result[2]
145
- y4pred = result[3]
146
  for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
147
  periodLength = y1.squeeze()
148
  periodicity = y2.squeeze()
@@ -161,11 +155,11 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
161
  batch_list.append(batch_list[-1])
162
  idx_list.append(idx_list[-1])
163
  batch_X = torch.cat(batch_list)
164
- result = model(batch_X)
165
- y1pred = result[0]
166
- y2pred = result[1]
167
- y3pred = result[2]
168
- y4pred = result[3]
169
  for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
170
  periodLength = y1.squeeze()
171
  periodicity = y2.squeeze()
@@ -240,7 +234,7 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
240
  jumps_per_second[misses] = 0
241
  frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
242
  frame_type[full_marks > marks_threshold] = 'jump'
243
- per_frame_event_types = np.clip(per_frame_event_types, 0, 7) / 7
244
  df = pd.DataFrame.from_dict({'period length': periodLength,
245
  'jumping speed': jumping_speed,
246
  'jumps per second': jumps_per_second,
@@ -302,7 +296,7 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
302
 
303
  # make a bar plot of the event type distribution
304
 
305
- bar = px.bar(x=['single rope speed', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
306
  y=event_type_probs,
307
  template="plotly_dark",
308
  title="Event Type Distribution",
 
16
  from tqdm import tqdm
17
  import pandas as pd
18
  import plotly.express as px
19
+ import onnxruntime as ort
20
  import torch
21
  from torchvision import transforms
22
  import torchvision.transforms.functional as F
 
24
  from huggingface_hub import hf_hub_download
25
  from huggingface_hub import HfApi
26
 
27
+ from model import RepNet
28
 
29
  plt.style.use('dark_background')
30
 
31
+ onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
 
 
 
32
  # model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
33
  # hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
34
  #model_xml = "model_ir/model.xml"
 
38
  # config = {"PERFORMANCE_HINT": "LATENCY"}
39
  # compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU", config=config)
40
 
41
+ providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
42
+ "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
43
+ sess_options = ort.SessionOptions()
44
+ ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
 
 
 
 
45
 
46
 
47
  class SquarePad:
 
92
  period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
93
  periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
94
  full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
95
+ event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 6))
96
  period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
97
+ event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 6))
98
  for _ in range(seq_len + stride_length): # pad full sequence
99
  all_frames.append(all_frames[-1])
100
  batch_list = []
 
132
  idx_list.append(i)
133
  if len(batch_list) == batch_size:
134
  batch_X = torch.cat(batch_list)
135
+ outputs = ort_sess.run(None, {'frames': batch_X.numpy()})
136
+ y1pred = outputs[0]
137
+ y2pred = outputs[1]
138
+ y3pred = outputs[2]
139
+ y4pred = outputs[3]
140
  for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
141
  periodLength = y1.squeeze()
142
  periodicity = y2.squeeze()
 
155
  batch_list.append(batch_list[-1])
156
  idx_list.append(idx_list[-1])
157
  batch_X = torch.cat(batch_list)
158
+ outputs = ort_sess.run(None, {'frames': batch_X.numpy()})
159
+ y1pred = outputs[0]
160
+ y2pred = outputs[1]
161
+ y3pred = outputs[2]
162
+ y4pred = outputs[3]
163
  for y1, y2, y3, y4, idx in zip(y1pred, y2pred, y3pred, y4pred, idx_list):
164
  periodLength = y1.squeeze()
165
  periodicity = y2.squeeze()
 
234
  jumps_per_second[misses] = 0
235
  frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
236
  frame_type[full_marks > marks_threshold] = 'jump'
237
+ per_frame_event_types = np.clip(per_frame_event_types, 0, 6) / 6
238
  df = pd.DataFrame.from_dict({'period length': periodLength,
239
  'jumping speed': jumping_speed,
240
  'jumps per second': jumps_per_second,
 
296
 
297
  # make a bar plot of the event type distribution
298
 
299
+ bar = px.bar(x=['single rope speed', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders'],
300
  y=event_type_probs,
301
  template="plotly_dark",
302
  title="Event Type Distribution",
requirements.txt CHANGED
@@ -9,4 +9,4 @@ opencv-python-headless==4.7.0.68
9
  # openvino-dev==2022.3.0
10
  torch
11
  torchvision
12
- timm
 
9
  # openvino-dev==2022.3.0
10
  torch
11
  torchvision
12
+ onnxruntime-gpu