Adjust layout
Browse files
app.py
CHANGED
|
@@ -89,7 +89,7 @@ def load_inferencer(checkpoint_path=None, device=None):
|
|
| 89 |
kwargs['pose2d_weights'] = checkpoint_path
|
| 90 |
else:
|
| 91 |
# default to rtmo-s
|
| 92 |
-
kwargs['pose2d'] = 'rtmo
|
| 93 |
return MMPoseInferencer(**kwargs)
|
| 94 |
|
| 95 |
# βββ Gradio prediction function βββ
|
|
@@ -132,31 +132,30 @@ def predict(image: Image.Image,
|
|
| 132 |
vis_img = Image.open(os.path.join(vis_dir, out_files[0])) if out_files else None
|
| 133 |
return vis_img, active
|
| 134 |
|
| 135 |
-
#
|
| 136 |
-
inputs = [
|
| 137 |
-
gr.Image(type="pil",
|
| 138 |
-
label="Upload Image"),
|
| 139 |
-
gr.Dropdown(label="Select Remote Checkpoint",
|
| 140 |
-
choices=list(REMOTE_CHECKPOINTS.keys()),
|
| 141 |
-
value=list(REMOTE_CHECKPOINTS.keys())[0]),
|
| 142 |
-
gr.File(file_types=['.pth'], label="Or Upload Your Own Checkpoint (optional)"),
|
| 143 |
-
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.1, label="Bounding Box Threshold (bbox_thr)"),
|
| 144 |
-
gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.65, label="NMS Threshold (nms_thr)"),
|
| 145 |
-
]
|
| 146 |
-
outputs = [
|
| 147 |
-
gr.Image(type="pil", label="Annotated Image"),
|
| 148 |
-
gr.Textbox(label="Active Checkpoint", interactive=False),
|
| 149 |
-
]
|
| 150 |
-
|
| 151 |
-
demo = gr.Interface(
|
| 152 |
-
fn=predict,
|
| 153 |
-
inputs=inputs,
|
| 154 |
-
outputs=outputs,
|
| 155 |
-
title="RTMO Pose Demo",
|
| 156 |
-
description="Upload an image and select or upload a RTMO .pth checkpoint to get 2D pose annotation.",
|
| 157 |
-
)
|
| 158 |
|
| 159 |
def main():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
demo.launch()
|
| 161 |
|
| 162 |
if __name__ == "__main__":
|
|
|
|
| 89 |
kwargs['pose2d_weights'] = checkpoint_path
|
| 90 |
else:
|
| 91 |
# default to rtmo-s
|
| 92 |
+
kwargs['pose2d'] = 'rtmo'
|
| 93 |
return MMPoseInferencer(**kwargs)
|
| 94 |
|
| 95 |
# βββ Gradio prediction function βββ
|
|
|
|
| 132 |
vis_img = Image.open(os.path.join(vis_dir, out_files[0])) if out_files else None
|
| 133 |
return vis_img, active
|
| 134 |
|
| 135 |
+
# Build Gradio UI with Blocks for improved layout
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
def main():
|
| 138 |
+
with gr.Blocks() as demo:
|
| 139 |
+
gr.Markdown("## RTMO Pose Demo")
|
| 140 |
+
with gr.Row():
|
| 141 |
+
with gr.Column(scale=1, min_width=300):
|
| 142 |
+
img_input = gr.Image(type="pil", label="Upload Image")
|
| 143 |
+
remote_dd = gr.Dropdown(label="Select Remote Checkpoint",
|
| 144 |
+
choices=list(REMOTE_CHECKPOINTS.keys()),
|
| 145 |
+
value=list(REMOTE_CHECKPOINTS.keys())[0])
|
| 146 |
+
upload_ckpt = gr.File(file_types=['.pth'], label="Or Upload Your Own Checkpoint (optional)")
|
| 147 |
+
bbox_thr = gr.Slider(minimum=0.0, maximum=1.0, step=0.01,
|
| 148 |
+
value=0.1, label="Bounding Box Threshold")
|
| 149 |
+
nms_thr = gr.Slider(minimum=0.0, maximum=1.0, step=0.01,
|
| 150 |
+
value=0.65, label="NMS Threshold")
|
| 151 |
+
run_btn = gr.Button("Run Inference")
|
| 152 |
+
with gr.Column(scale=2):
|
| 153 |
+
output_img = gr.Image(type="pil", label="Annotated Image",
|
| 154 |
+
elem_id="output_image", interactive=False)
|
| 155 |
+
active_tb = gr.Textbox(label="Active Checkpoint", interactive=False)
|
| 156 |
+
run_btn.click(predict,
|
| 157 |
+
inputs=[img_input, remote_dd, upload_ckpt, bbox_thr, nms_thr],
|
| 158 |
+
outputs=[output_img, active_tb])
|
| 159 |
demo.launch()
|
| 160 |
|
| 161 |
if __name__ == "__main__":
|