dylanplummer committed on
Commit
8585ab0
·
1 Parent(s): 823860d

test higher resolution

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -27,6 +27,8 @@ from hls_download import download_clips
27
 
28
  plt.style.use('dark_background')
29
 
 
 
30
  onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
31
  # model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
32
  # hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
@@ -46,7 +48,7 @@ else:
46
  ort_sess = ort.InferenceSession(onnx_file)
47
 
48
  # warmup inference
49
- ort_sess.run(None, {'video': np.zeros((4, 64, 3, 224, 224), dtype=np.float32)})
50
 
51
 
52
  class SquarePad:
@@ -86,7 +88,7 @@ def run_inference(batch_X):
86
 
87
  @spaces.GPU()
88
  def inference(x, count_only_api, api_key,
89
- img_size=224, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
90
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
91
  api_call=False,
92
  progress=gr.Progress()):
 
27
 
28
  plt.style.use('dark_background')
29
 
30
+ IMG_SIZE = 256
31
+
32
  onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
33
  # model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
34
  # hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
 
48
  ort_sess = ort.InferenceSession(onnx_file)
49
 
50
  # warmup inference
51
+ ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
52
 
53
 
54
  class SquarePad:
 
88
 
89
  @spaces.GPU()
90
  def inference(x, count_only_api, api_key,
91
+ img_size=IMG_SIZE, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
92
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
93
  api_call=False,
94
  progress=gr.Progress()):