File size: 2,649 Bytes
e99c165
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""Gradio web application for ViTPose human pose estimation."""

import logging
import os

import gradio as gr
from PIL import Image

from src.predictor import Predictor

# Root logging configuration: INFO level with timestamped, named records.
logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
# Logger scoped to this module's name.
logger = logging.getLogger(__name__)

# Single predictor shared by every request. Its ``initialize`` step is not
# run here; ``estimate_pose`` triggers it lazily on first use.
predictor = Predictor()


def estimate_pose(image: Image.Image | None) -> Image.Image | None:
    """Run pose estimation on the input image.

    The shared module-level ``predictor`` is initialized lazily on the first
    call that receives an image. The inference device defaults to ``"cuda"``
    and can be overridden with the ``VITPOSE_DEVICE`` environment variable
    (e.g. ``VITPOSE_DEVICE=cpu`` on hardware without a GPU).

    Args:
        image: Input PIL image from Gradio upload, or None when the input
            component holds no image.

    Returns:
        PIL image with skeleton overlay, or None if no image provided.
    """
    if image is None:
        # Nothing to process; returning None leaves the output empty.
        return None

    # NOTE(review): reads a private attribute of Predictor; prefer a public
    # accessor if Predictor provides one.
    # NOTE(review): this check-then-initialize is not guarded by a lock; if
    # Gradio runs the click and change handlers concurrently, two first
    # requests could both initialize — confirm against concurrency settings.
    if not predictor._initialized:
        # Configurable instead of hard-coded so the app can start on
        # CPU-only hosts; the default keeps the previous behavior.
        device = os.environ.get("VITPOSE_DEVICE", "cuda")
        logger.info("Initializing predictor on device %s", device)
        predictor.initialize(device=device)

    return predictor.predict(image)


def build_app() -> gr.Blocks:
    """Assemble the ViTPose Gradio interface.

    Layout:
        - a Markdown header describing the model,
        - a row with the input image and submit button on the left and the
          annotated result image on the right,
        - a Markdown footer note.

    ``estimate_pose`` is wired to both the button click and the input
    image's change event.

    Returns:
        The configured ``gr.Blocks`` app (not launched).
    """
    with gr.Blocks(
        title="ViTPose - Human Pose Estimation",
        theme=gr.themes.Soft(),
    ) as app:
        gr.Markdown(
            """
            # ViTPose - Human Pose Estimation
            Upload an image to detect human poses and visualize the skeleton overlay.

            **Model**: [ViTPose-base-simple](https://huggingface.co/usyd-community/vitpose-base-simple)
            (Xu et al., 2022) with COCO 17 keypoints.
            """
        )

        with gr.Row():
            # Left column: image input plus an explicit trigger button.
            with gr.Column():
                source_img = gr.Image(
                    label="Input Image",
                    type="pil",
                    sources=["upload", "clipboard"],
                )
                run_btn = gr.Button("Estimate Pose", variant="primary")

            # Right column: skeleton overlay produced by estimate_pose.
            with gr.Column():
                result_img = gr.Image(label="Pose Estimation Result", type="pil")

        # Register the same handler on both triggers, click first, then change.
        for trigger in (run_btn.click, source_img.change):
            trigger(
                fn=estimate_pose,
                inputs=[source_img],
                outputs=[result_img],
            )

        gr.Markdown(
            """
            ---
            **Note**: This demo uses the full image as a bounding box for
            single-person pose estimation. For multi-person scenarios,
            an object detector (e.g., RT-DETR) would be used upstream.
            """
        )

    return app


# Module-level demo for HF Spaces: built at import time so ``demo`` exists
# even when this file is not executed as a script.
demo = build_app()

if __name__ == "__main__":
    # Script entry point: listen on GRADIO_SERVER_PORT, or 7860 if unset.
    demo.launch(server_port=int(os.getenv("GRADIO_SERVER_PORT", "7860")))