improve model uploader ui
Browse files
app.py
CHANGED
|
@@ -8,7 +8,6 @@ from PIL import Image
|
|
| 8 |
spec = importlib.util.find_spec('mmdet')
|
| 9 |
if spec and spec.origin:
|
| 10 |
src = open(spec.origin, encoding='utf-8').read()
|
| 11 |
-
# strip out the mmcv_minimum_version…assert… block up to __all__
|
| 12 |
patched = re.sub(r'(?ms)^[ \t]*mmcv_minimum_version.*?^__all__', '__all__', src)
|
| 13 |
m = importlib.util.module_from_spec(spec)
|
| 14 |
m.__loader__ = spec.loader
|
|
@@ -28,16 +27,18 @@ def load_inferencer(checkpoint_path=None, device=None):
|
|
| 28 |
|
| 29 |
# ——— Gradio prediction function ———
|
| 30 |
@spaces.GPU()
|
| 31 |
-
def predict(image: Image.Image, checkpoint
|
| 32 |
# save upload to temp file
|
| 33 |
inp_path = "/tmp/upload.jpg"
|
| 34 |
image.save(inp_path)
|
| 35 |
|
|
|
|
|
|
|
|
|
|
| 36 |
vis_dir = "/tmp/vis"
|
| 37 |
os.makedirs(vis_dir, exist_ok=True)
|
| 38 |
|
| 39 |
-
inferencer = load_inferencer(checkpoint_path=
|
| 40 |
-
# run inference & visualization
|
| 41 |
for result in inferencer(
|
| 42 |
inputs=inp_path,
|
| 43 |
bbox_thr=0.1,
|
|
@@ -48,7 +49,6 @@ def predict(image: Image.Image, checkpoint: str = None):
|
|
| 48 |
):
|
| 49 |
pass
|
| 50 |
|
| 51 |
-
# return the first visualization
|
| 52 |
out_files = sorted(os.listdir(vis_dir))
|
| 53 |
if out_files:
|
| 54 |
return Image.open(os.path.join(vis_dir, out_files[0]))
|
|
@@ -59,11 +59,11 @@ demo = gr.Interface(
|
|
| 59 |
fn=predict,
|
| 60 |
inputs=[
|
| 61 |
gr.Image(type="pil", label="Upload Image"),
|
| 62 |
-
gr.
|
| 63 |
],
|
| 64 |
outputs=gr.Image(type="pil", label="Annotated Image"),
|
| 65 |
title="RTMO Pose Demo",
|
| 66 |
-
description="Upload an image
|
| 67 |
)
|
| 68 |
|
| 69 |
def main():
|
|
|
|
| 8 |
spec = importlib.util.find_spec('mmdet')
|
| 9 |
if spec and spec.origin:
|
| 10 |
src = open(spec.origin, encoding='utf-8').read()
|
|
|
|
| 11 |
patched = re.sub(r'(?ms)^[ \t]*mmcv_minimum_version.*?^__all__', '__all__', src)
|
| 12 |
m = importlib.util.module_from_spec(spec)
|
| 13 |
m.__loader__ = spec.loader
|
|
|
|
| 27 |
|
| 28 |
# ——— Gradio prediction function ———
|
| 29 |
@spaces.GPU()
|
| 30 |
+
def predict(image: Image.Image, checkpoint):
|
| 31 |
# save upload to temp file
|
| 32 |
inp_path = "/tmp/upload.jpg"
|
| 33 |
image.save(inp_path)
|
| 34 |
|
| 35 |
+
# determine checkpoint path if user uploaded
|
| 36 |
+
ckpt_path = checkpoint.name if checkpoint else None
|
| 37 |
+
|
| 38 |
vis_dir = "/tmp/vis"
|
| 39 |
os.makedirs(vis_dir, exist_ok=True)
|
| 40 |
|
| 41 |
+
inferencer = load_inferencer(checkpoint_path=ckpt_path, device=None)
|
|
|
|
| 42 |
for result in inferencer(
|
| 43 |
inputs=inp_path,
|
| 44 |
bbox_thr=0.1,
|
|
|
|
| 49 |
):
|
| 50 |
pass
|
| 51 |
|
|
|
|
| 52 |
out_files = sorted(os.listdir(vis_dir))
|
| 53 |
if out_files:
|
| 54 |
return Image.open(os.path.join(vis_dir, out_files[0]))
|
|
|
|
| 59 |
fn=predict,
|
| 60 |
inputs=[
|
| 61 |
gr.Image(type="pil", label="Upload Image"),
|
| 62 |
+
gr.File(file_types=['.pth'], label="Upload RTMO .pth Checkpoint (optional)")
|
| 63 |
],
|
| 64 |
outputs=gr.Image(type="pil", label="Annotated Image"),
|
| 65 |
title="RTMO Pose Demo",
|
| 66 |
+
description="Upload an image and (optionally) a RTMO .pth checkpoint to get 2D pose annotation.",
|
| 67 |
)
|
| 68 |
|
| 69 |
def main():
|