Spaces:
Build error
Build error
Commit ·
8585ab0
1
Parent(s): 823860d
test higher resolution
Browse files
app.py
CHANGED
|
@@ -27,6 +27,8 @@ from hls_download import download_clips
|
|
| 27 |
|
| 28 |
plt.style.use('dark_background')
|
| 29 |
|
|
|
|
|
|
|
| 30 |
onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
|
| 31 |
# model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
|
| 32 |
# hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
|
|
@@ -46,7 +48,7 @@ else:
|
|
| 46 |
ort_sess = ort.InferenceSession(onnx_file)
|
| 47 |
|
| 48 |
# warmup inference
|
| 49 |
-
ort_sess.run(None, {'video': np.zeros((4, 64, 3,
|
| 50 |
|
| 51 |
|
| 52 |
class SquarePad:
|
|
@@ -86,7 +88,7 @@ def run_inference(batch_X):
|
|
| 86 |
|
| 87 |
@spaces.GPU()
|
| 88 |
def inference(x, count_only_api, api_key,
|
| 89 |
-
img_size=
|
| 90 |
miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
|
| 91 |
api_call=False,
|
| 92 |
progress=gr.Progress()):
|
|
|
|
| 27 |
|
| 28 |
plt.style.use('dark_background')
|
| 29 |
|
| 30 |
+
IMG_SIZE = 256
|
| 31 |
+
|
| 32 |
onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
|
| 33 |
# model_xml = hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.xml", repo_type="model", token=os.environ['DATASET_SECRET'])
|
| 34 |
# hf_hub_download(repo_id="dylanplummer/ropenet", filename="model.mapping", repo_type="model", token=os.environ['DATASET_SECRET'])
|
|
|
|
| 48 |
ort_sess = ort.InferenceSession(onnx_file)
|
| 49 |
|
| 50 |
# warmup inference
|
| 51 |
+
ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
|
| 52 |
|
| 53 |
|
| 54 |
class SquarePad:
|
|
|
|
| 88 |
|
| 89 |
@spaces.GPU()
|
| 90 |
def inference(x, count_only_api, api_key,
|
| 91 |
+
img_size=IMG_SIZE, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
|
| 92 |
miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
|
| 93 |
api_call=False,
|
| 94 |
progress=gr.Progress()):
|