dylanplummer commited on
Commit
d22ea40
·
1 Parent(s): 0f7b211

check for gpu

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -35,10 +35,14 @@ onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.onnx
35
  # config = {"PERFORMANCE_HINT": "LATENCY"}
36
  # compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU", config=config)
37
 
38
- providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
39
- "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
40
- sess_options = ort.SessionOptions()
41
- ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
 
 
 
 
42
 
43
 
44
  class SquarePad:
@@ -55,7 +59,7 @@ def sigmoid(x):
55
  return 1 / (1 + np.exp(-x))
56
 
57
 
58
- def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, miss_threshold=0.8, marks_threshold=0.6, median_pred_filter=True, center_crop=True, both_feet=True, api_call=False):
59
  print(x)
60
  #api = HfApi(token=os.environ['DATASET_SECRET'])
61
  #out_file = str(uuid.uuid1())
@@ -89,9 +93,9 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
89
  period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
90
  periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
91
  full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
92
- event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 6))
93
  period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
94
- event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 6))
95
  for _ in range(seq_len + stride_length): # pad full sequence
96
  all_frames.append(all_frames[-1])
97
  batch_list = []
@@ -293,7 +297,7 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
293
 
294
  # make a bar plot of the event type distribution
295
 
296
- bar = px.bar(x=['single rope speed', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders'],
297
  y=event_type_probs,
298
  template="plotly_dark",
299
  title="Event Type Distribution",
 
35
  # config = {"PERFORMANCE_HINT": "LATENCY"}
36
  # compiled_model_ir = ie.compile_model(model=model_ir, device_name="CPU", config=config)
37
 
38
+ # check if GPU is available
39
+ if torch.cuda.is_available():
40
+ providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
41
+ "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
42
+ sess_options = ort.SessionOptions()
43
+ ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
44
+ else:
45
+ ort_sess = ort.InferenceSession(onnx_file)
46
 
47
 
48
  class SquarePad:
 
59
  return 1 / (1 + np.exp(-x))
60
 
61
 
62
+ def inference(x, count_only_api, api_key, img_size=224, seq_len=64, stride_length=32, stride_pad=3, batch_size=4, miss_threshold=0.8, marks_threshold=0.6, median_pred_filter=True, center_crop=True, both_feet=True, api_call=False):
63
  print(x)
64
  #api = HfApi(token=os.environ['DATASET_SECRET'])
65
  #out_file = str(uuid.uuid1())
 
93
  period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
94
  periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
95
  full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
96
+ event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 7))
97
  period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
98
+ event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
99
  for _ in range(seq_len + stride_length): # pad full sequence
100
  all_frames.append(all_frames[-1])
101
  batch_list = []
 
297
 
298
  # make a bar plot of the event type distribution
299
 
300
+ bar = px.bar(x=['single rope speed', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
301
  y=event_type_probs,
302
  template="plotly_dark",
303
  title="Event Type Distribution",