YoungjaeDev
feat: HF Space Gradio 앱 구현 (Zero GPU)
cacaeb4
raw
history blame
16.2 kB
#!/usr/bin/env python3
"""
Fall Detection Gradio App
Fall-detection demo built on a two-stage YOLOv11-Pose + ST-GCN pipeline.
Runs in the Hugging Face Spaces ZeroGPU environment; when the `spaces`
package is not installed (local runs) the GPU decorator is a no-op.
Usage (local):
python demo_gradio/app.py
(model weight paths are resolved relative to the current working directory)
Usage (HF Spaces):
app.py is launched automatically.
Author: Fall Detection Pipeline Team
Date: 2025-11-26
"""
import contextlib
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Iterable, Optional, Tuple

import cv2
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# Make the project root importable (e.g. `pipeline.core.pipeline`).
# Assumed layout: <project_root>/pipeline/demo_gradio/app.py
#   app.py -> demo_gradio -> pipeline -> <project_root>
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path[:0] = [os.fspath(PROJECT_ROOT)]
# ZeroGPU support: the `spaces` package only exists on HF Spaces.
# SPACES_AVAILABLE records whether it could be imported.
try:
    import spaces
except ImportError:
    SPACES_AVAILABLE = False
else:
    SPACES_AVAILABLE = True
# -----------------------------------------------------------------------------
# Custom theme (PRITHIVSAKTHIUR style)
# -----------------------------------------------------------------------------
# Steel-blue scale (c500 is CSS "steelblue", #4682B4), lightest to darkest.
_STEEL_BLUE_SHADES = {
    "c50": "#EBF3F8",
    "c100": "#D3E5F0",
    "c200": "#A8CCE1",
    "c300": "#7DB3D2",
    "c400": "#529AC3",
    "c500": "#4682B4",
    "c600": "#3E72A0",
    "c700": "#36638C",
    "c800": "#2E5378",
    "c900": "#264364",
    "c950": "#1E3450",
}
# Attached to gradio's `colors` module so CustomTheme's defaults can refer to
# it as `colors.custom_color`; this must run before CustomTheme is defined.
colors.custom_color = colors.Color(name="custom_color", **_STEEL_BLUE_SHADES)
class CustomTheme(Soft):
    """Gradio ``Soft`` theme with a steel-blue accent.

    The default hues are gray (primary), ``colors.custom_color`` (secondary)
    and slate (neutral). Text is large, and the fonts are Outfit and
    IBM Plex Mono. After the base theme is built, a few CSS variables are
    overridden:
    - gradient page background;
    - gradient primary buttons;
    - 3px block borders;
    - drop shadows.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.custom_color,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        base_options = dict(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        super().__init__(**base_options)

        # "*name" values reference other theme variables (e.g. *secondary_500).
        variable_overrides = {
            "background_fill_primary": "*primary_50",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "button_primary_text_color": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "slider_color": "*secondary_500",
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
        }
        super().set(**variable_overrides)


# Shared theme instance; passed to `launch(theme=...)` in the entry point.
custom_theme = CustomTheme()
# -----------------------------------------------------------------------------
# CSS styles
# -----------------------------------------------------------------------------
# Extra stylesheet passed to `launch(css=...)`.
# - `#main-title` targets the header Markdown (elem_id="main-title").
# - `.submit-btn` targets the analyze button (elem_classes="submit-btn").
# NOTE(review): `#col-container`, `.result-label`, `.fall-detected` and
# `.non-fall` are not referenced by any component in this file; confirm
# whether they are still needed.
css = """
#col-container { margin: 0 auto; max-width: 1200px; }
#main-title h1 { font-size: 2.3em !important; }
.submit-btn {
background-color: #4682B4 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #5A9BD4 !important;
}
.result-label {
font-size: 1.5em !important;
font-weight: bold !important;
padding: 10px !important;
border-radius: 8px !important;
}
.fall-detected {
background-color: #FF4444 !important;
color: white !important;
}
.non-fall {
background-color: #44BB44 !important;
color: white !important;
}
"""
# -----------------------------------------------------------------------------
# Device selection
# -----------------------------------------------------------------------------
# Use CUDA when torch reports a usable GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# -----------------------------------------------------------------------------
# GPU decorator (works both locally and on HF Spaces)
# -----------------------------------------------------------------------------
def gpu_decorator(duration: int = 120):
    """Build a decorator that requests a ZeroGPU slot on HF Spaces.

    Args:
        duration: Forwarded unchanged to ``spaces.GPU(duration=...)``.
            It is ignored when ``spaces`` is unavailable.

    Returns:
        A decorator. If the ``spaces`` package was importable, the decorator
        wraps the function with ``spaces.GPU``. Otherwise it hands the
        function back untouched.
    """
    def decorator(func):
        if not SPACES_AVAILABLE:
            # Local run: no GPU allocator, call the function directly.
            return func
        return spaces.GPU(duration=duration)(func)

    return decorator
# -----------------------------------------------------------------------------
# Pipeline initialization (lazy loading)
# -----------------------------------------------------------------------------
_pipeline = None


def get_pipeline():
    """Return the module-level FallDetectionPipeline, building it on first use.

    The project import and model construction are deferred until the first
    call, and later calls return the cached instance.

    Weight files are looked up in this order:
    1. ``pipeline/demo_gradio/models/`` (the HF Spaces layout).
    2. Otherwise, the local development paths. These fallbacks are used
       as-is, even if they do not exist.

    NOTE(review): every path here is relative to the current working
    directory, not to this file.
    """
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    from pipeline.core.pipeline import FallDetectionPipeline

    bundled_pose = "pipeline/demo_gradio/models/yolo11m-pose.pt"
    bundled_stgcn = "pipeline/demo_gradio/models/best_acc.pth"
    pose_model_path = (
        bundled_pose if Path(bundled_pose).exists() else "yolo11m-pose.pt"
    )
    stgcn_checkpoint = (
        bundled_stgcn
        if Path(bundled_stgcn).exists()
        else "runs/stgcn_binary_exp2_fixed_graph/best_acc.pth"
    )

    _pipeline = FallDetectionPipeline(
        pose_model_path=pose_model_path,
        stgcn_checkpoint=stgcn_checkpoint,
        window_size=60,
        conf_threshold=0.5,
        fall_threshold=0.7,
        temporal_window=5,
        stgcn_stride=5,
        alert_duration=150,
        post_fall_frames=3,
        device=str(device),
        debug=False,
        headless=False,
        viz_keypoints="all",
        viz_scale=1.0,
        viz_optimized=True,
    )
    return _pipeline
# -----------------------------------------------------------------------------
# Probability graph
# -----------------------------------------------------------------------------
def create_probability_graph(
    frame_indices: list,
    probabilities: list,
    fall_threshold: float = 0.7
) -> go.Figure:
    """Build a Plotly chart of the fall probability over the video.

    Args:
        frame_indices: Frame numbers for the x-axis, one per probability.
        probabilities: Fall probabilities (0.0-1.0) for the y-axis.
        fall_threshold: Decision threshold, drawn as a dashed red line.

    Returns:
        A Plotly Figure containing:
        - a filled probability line;
        - the threshold line;
        - a y-axis fixed to [0, 1].
    """
    fig = go.Figure()

    # Probability trace, filled down to zero.
    fig.add_trace(go.Scatter(
        x=frame_indices,
        y=probabilities,
        mode='lines',
        name='Fall Probability',
        line=dict(color='#4682B4', width=2),
        fill='tozeroy',
        fillcolor='rgba(70, 130, 180, 0.3)'
    ))

    # Threshold line. `:g` drops float noise that slider arithmetic can
    # introduce (0.6500000000000001 -> "0.65") but leaves clean values like
    # 0.7 unchanged.
    fig.add_hline(
        y=fall_threshold,
        line_dash="dash",
        line_color="red",
        annotation_text=f"Threshold ({fall_threshold:g})",
        annotation_position="right"
    )

    fig.update_layout(
        title="Fall Detection Probability Over Time",
        xaxis_title="Frame",
        yaxis_title="Probability",
        yaxis=dict(range=[0, 1]),
        template="plotly_white",
        height=300,
        margin=dict(l=50, r=50, t=50, b=50),
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )
    return fig
# -----------------------------------------------------------------------------
# Main inference function
# -----------------------------------------------------------------------------
# Used when OpenCV reports no usable FPS (0 or NaN) for the input. It matches
# the "60 frames (2 s @ 30 fps)" window described in the UI.
_FALLBACK_FPS = 30.0


def _new_temp_path(suffix: str) -> str:
    """Securely create an empty temp file and return its path; the caller deletes it."""
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def _remove_quietly(path: Optional[str]) -> None:
    """Best-effort delete of ``path``; ``None``, missing files and OS errors are ignored."""
    if path:
        with contextlib.suppress(OSError):
            os.remove(path)


def _reencode_h264(src_path: str) -> str:
    """Re-encode ``src_path`` to H.264 with ffmpeg so browsers can play it.

    Returns the new H.264 file's path on success. Failure means one of:
    - ffmpeg is missing;
    - ffmpeg exits non-zero;
    - ffmpeg writes an empty file.

    On failure the partial output is removed and ``src_path`` is returned
    unchanged, so the caller still has a video. ``src_path`` itself is never
    deleted here.
    """
    dst_path = _new_temp_path(".mp4")
    cmd = [
        "ffmpeg", "-y", "-loglevel", "quiet",
        "-i", src_path,
        "-c:v", "libx264", "-preset", "fast", "-crf", "23",
        dst_path,
    ]
    try:
        # Argument list (no shell), so paths are never shell-interpreted.
        # stdin=DEVNULL keeps ffmpeg from waiting on interactive input.
        completed = subprocess.run(cmd, stdin=subprocess.DEVNULL, check=False)
        succeeded = completed.returncode == 0 and os.path.getsize(dst_path) > 0
    except OSError:  # ffmpeg binary not found / not executable
        succeeded = False
    if succeeded:
        return dst_path
    _remove_quietly(dst_path)
    return src_path


@gpu_decorator(duration=120)
def process_video(
    video_path: str,
    fall_threshold: float,
    viz_keypoints: str,
    progress: gr.Progress = gr.Progress()
) -> Tuple[Optional[str], Optional[go.Figure], str]:
    """Run fall detection on a video and render the annotated result.

    Args:
        video_path: Path of the uploaded input video (``None`` if nothing
            was uploaded).
        fall_threshold: Fall decision threshold (0.0-1.0), applied to the
            pipeline and its ST-GCN classifier.
        viz_keypoints: Keypoint display mode ('all' or 'major').
        progress: Gradio progress tracker.

    Returns:
        A tuple ``(output_video_path, probability_graph, result_text)``.
        - ``output_video_path`` is the H.264 video, or the mp4v intermediate
          if re-encoding failed.
        - ``probability_graph`` is ``None`` when no frame produced a
          confidence.
        - On failure the first two items are ``None`` and ``result_text``
          holds the error message.
    """
    if video_path is None:
        return None, None, "๋น„๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”."

    cap = None
    writer = None
    raw_path = None      # mp4v intermediate written by OpenCV
    final_output = None  # path handed back to Gradio; must survive cleanup
    try:
        # Load the shared pipeline and apply this request's settings.
        progress(0.1, desc="๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
        pipeline = get_pipeline()
        pipeline.fall_threshold = fall_threshold
        pipeline.stgcn_classifier.fall_threshold = fall_threshold
        pipeline.viz_keypoints = viz_keypoints
        pipeline.reset()

        # Open the input video.
        progress(0.2, desc="๋น„๋””์˜ค ์—ด๊ธฐ...")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return None, None, "๋น„๋””์˜ค๋ฅผ ์—ด ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # 0 or NaN would produce an unusable writer
            fps = _FALLBACK_FPS
        # CAP_PROP_FRAME_COUNT is only an estimate and may be 0 (unknown);
        # it is used purely for progress reporting.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        raw_path = _new_temp_path(".mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

        frame_idx = 0
        frame_indices = []
        probabilities = []
        fall_detected = False
        max_confidence = 0.0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            vis_frame, info = pipeline.process_frame(frame, frame_idx)

            # Record the probability for the graph.
            confidence = info['confidence']
            if confidence is not None:
                frame_indices.append(frame_idx)
                probabilities.append(confidence)
                max_confidence = max(max_confidence, confidence)
            if info['alert']:
                fall_detected = True

            if writer is None:
                # Size the writer from the rendered frame. The pipeline adds
                # an info panel, and the original code assumed it was
                # +80 px tall. If the configured size is wrong, OpenCV's
                # writer skips the mismatched frames without raising.
                out_height, out_width = vis_frame.shape[:2]
                writer = cv2.VideoWriter(raw_path, fourcc, fps, (out_width, out_height))
            writer.write(vis_frame)
            frame_idx += 1

            if frame_idx % 10 == 0:
                done = min(frame_idx / total_frames, 1.0) if total_frames > 0 else 0.0
                progress(0.2 + 0.7 * done, desc=f"์ฒ˜๋ฆฌ ์ค‘... ({frame_idx}/{total_frames})")

        cap.release()
        cap = None
        if writer is None:
            # Not a single frame could be decoded.
            return None, None, "๋น„๋””์˜ค๋ฅผ ์—ด ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
        writer.release()  # finalize the container before ffmpeg reads it
        writer = None

        # Re-encode to H.264 for browser playback. The mp4v file is kept as
        # the fallback if this fails.
        progress(0.9, desc="๋น„๋””์˜ค ์ธ์ฝ”๋”ฉ ์ค‘...")
        final_output = _reencode_h264(raw_path)

        progress(0.95, desc="๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ ์ค‘...")
        if frame_indices and probabilities:
            fig = create_probability_graph(frame_indices, probabilities, fall_threshold)
        else:
            fig = None

        # Final verdict.
        progress(1.0, desc="์™„๋ฃŒ!")
        if fall_detected:
            result_text = f"[FALL DETECTED] ๋‚™์ƒ์ด ๊ฐ์ง€๋˜์—ˆ์Šต๋‹ˆ๋‹ค! (์ตœ๋Œ€ ํ™•๋ฅ : {max_confidence:.1%})"
        else:
            result_text = f"[Non-Fall] ๋‚™์ƒ์ด ๊ฐ์ง€๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. (์ตœ๋Œ€ ํ™•๋ฅ : {max_confidence:.1%})"
        return final_output, fig, result_text
    except Exception as e:
        import traceback
        # Nothing will be returned to the user, so drop any produced output.
        if final_output is not None and final_output != raw_path:
            _remove_quietly(final_output)
        final_output = None
        error_msg = f"์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
        return None, None, error_msg
    finally:
        # Release OpenCV handles on every path (including exceptions).
        if cap is not None:
            cap.release()
        if writer is not None:
            writer.release()
        # Delete the intermediate unless it is the file being returned.
        if raw_path != final_output:
            _remove_quietly(raw_path)
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
# Video extensions offered as examples, in display-priority order.
_EXAMPLE_EXTENSIONS = (".mp4", ".avi", ".mov")


def _find_example_videos(example_dir: Path, limit: int = 3) -> list[str]:
    """Return up to ``limit`` example video paths found in ``example_dir``.

    Only regular files are considered, and extensions match
    case-insensitively. Files are ordered by ``_EXAMPLE_EXTENSIONS``
    priority, then by name, so the chosen examples do not depend on
    filesystem listing order. A missing directory yields an empty list.
    """
    if not example_dir.is_dir():
        return []
    candidates = [
        p for p in example_dir.iterdir()
        if p.is_file() and p.suffix.lower() in _EXAMPLE_EXTENSIONS
    ]
    candidates.sort(key=lambda p: (_EXAMPLE_EXTENSIONS.index(p.suffix.lower()), p.name))
    return [str(p) for p in candidates[:limit]]


def create_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI.

    The layout has three parts:
    - A header.
    - An input column (video upload, advanced settings, submit button)
      beside an output column (verdict text, result video, probability
      plot).
    - Optional example videos and a references footer.
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            """
# Fall Detection Demo
YOLOv11-Pose + ST-GCN 2-stage ํŒŒ์ดํ”„๋ผ์ธ์„ ์‚ฌ์šฉํ•œ ์‹ค์‹œ๊ฐ„ ๋‚™์ƒ ๊ฐ์ง€ ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.
๋น„๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด ๋‚™์ƒ ์—ฌ๋ถ€๋ฅผ ๋ถ„์„ํ•˜๊ณ , ๊ฒฐ๊ณผ ๋น„๋””์˜ค์™€ ํ™•๋ฅ  ๊ทธ๋ž˜ํ”„๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
**ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌ์„ฑ:**
- Stage 1: YOLOv11m-pose (Pose Estimation)
- Stage 2: ST-GCN (Temporal Classification)
- Window Size: 60 frames (2์ดˆ @ 30fps)
""",
            elem_id="main-title"
        )
        with gr.Row():
            with gr.Column(scale=1):
                # Input section
                gr.Markdown("### ์ž…๋ ฅ")
                video_input = gr.Video(
                    label="๋น„๋””์˜ค ์—…๋กœ๋“œ",
                    sources=["upload"],
                )
                with gr.Accordion("๊ณ ๊ธ‰ ์„ค์ •", open=False):
                    fall_threshold = gr.Slider(
                        minimum=0.5,
                        maximum=0.95,
                        value=0.7,
                        step=0.05,
                        label="๋‚™์ƒ ํŒ์ • ์ž„๊ณ„๊ฐ’",
                        info="์ด ๊ฐ’ ์ด์ƒ์˜ ํ™•๋ฅ ์ด๋ฉด ๋‚™์ƒ์œผ๋กœ ํŒ์ •ํ•ฉ๋‹ˆ๋‹ค."
                    )
                    viz_keypoints = gr.Radio(
                        choices=["all", "major"],
                        value="all",
                        label="ํ‚คํฌ์ธํŠธ ํ‘œ์‹œ",
                        info="all: ์ „์ฒด 17๊ฐœ, major: ์ฃผ์š” 9๊ฐœ"
                    )
                submit_btn = gr.Button(
                    "๋ถ„์„ ์‹œ์ž‘",
                    variant="primary",
                    elem_classes="submit-btn"
                )
            with gr.Column(scale=1):
                # Output section
                gr.Markdown("### ๊ฒฐ๊ณผ")
                result_text = gr.Textbox(
                    label="ํŒ์ • ๊ฒฐ๊ณผ",
                    lines=2,
                    interactive=False
                )
                video_output = gr.Video(
                    label="๊ฒฐ๊ณผ ๋น„๋””์˜ค",
                )
                prob_graph = gr.Plot(
                    label="๋‚™์ƒ ํ™•๋ฅ  ๊ทธ๋ž˜ํ”„",
                )

        # Example videos (deterministic selection; see _find_example_videos).
        gr.Markdown("### ์˜ˆ์ œ ๋น„๋””์˜ค")
        examples = _find_example_videos(Path(__file__).parent / "examples")
        if examples:
            gr.Examples(
                examples=[[ex, 0.7, "all"] for ex in examples],
                inputs=[video_input, fall_threshold, viz_keypoints],
                outputs=[video_output, prob_graph, result_text],
                fn=process_video,
                cache_examples=False,
            )

        # Wire the submit button to the inference function.
        submit_btn.click(
            fn=process_video,
            inputs=[video_input, fall_threshold, viz_keypoints],
            outputs=[video_output, prob_graph, result_text],
        )

        # Footer
        gr.Markdown(
            """
---
**References:**
- [YOLOv11](https://github.com/ultralytics/ultralytics) - Pose Estimation
- [ST-GCN](https://arxiv.org/abs/1801.07455) - Spatial Temporal Graph Convolutional Networks
- AI Hub Fall Detection Dataset
"""
        )
    return demo
# -----------------------------------------------------------------------------
# Entry point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    demo = create_demo()
    # Bounded request queue: at most 10 pending jobs.
    demo.queue(max_size=10)
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=False,
        show_error=True,
        theme=custom_theme,
        css=css,
    )