dylanplummer commited on
Commit
41e552b
·
1 Parent(s): 64462ff

init progress

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -57,14 +57,13 @@ def inference(x, count_only_api, api_key,
57
  miss_threshold=0.8, marks_threshold=0.6, median_pred_filter=True, center_crop=True, both_feet=True,
58
  api_call=False,
59
  progress=gr.Progress(track_tqdm=True)):
60
- print(x)
61
  # check if GPU is available
62
  if torch.cuda.is_available():
63
  providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
64
  "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
65
  sess_options = ort.SessionOptions()
66
  ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
67
- print("Using GPU")
68
  else:
69
  ort_sess = ort.InferenceSession(onnx_file)
70
  #api = HfApi(token=os.environ['DATASET_SECRET'])
@@ -292,7 +291,7 @@ def inference(x, count_only_api, api_key,
292
  fig.update_traces(marker_line_width = 0)
293
  fig.update_layout(coloraxis_colorbar=dict(
294
  tickvals=event_type_tick_vals,
295
- ticktext=['single rope speed', 'double dutch', 'double unders', 'single bounces', 'double bounces', 'triple unders', 'other'],
296
  title='event type'
297
  ))
298
 
 
57
  miss_threshold=0.8, marks_threshold=0.6, median_pred_filter=True, center_crop=True, both_feet=True,
58
  api_call=False,
59
  progress=gr.Progress(track_tqdm=True)):
60
+ progress(0, desc="Starting...")
61
  # check if GPU is available
62
  if torch.cuda.is_available():
63
  providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
64
  "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
65
  sess_options = ort.SessionOptions()
66
  ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
 
67
  else:
68
  ort_sess = ort.InferenceSession(onnx_file)
69
  #api = HfApi(token=os.environ['DATASET_SECRET'])
 
291
  fig.update_traces(marker_line_width = 0)
292
  fig.update_layout(coloraxis_colorbar=dict(
293
  tickvals=event_type_tick_vals,
294
+ ticktext=['single\nrope', 'double\ndutch', 'double\nunders', 'single\nbounces', 'double\nbounces', 'triple\nunders', 'other'],
295
  title='event type'
296
  ))
297