#!/usr/bin/env python3
"""
Fall Detection Gradio App
YOLOv11-Pose + ST-GCN 2-stage ํŒŒ์ดํ”„๋ผ์ธ์„ ์‚ฌ์šฉํ•œ ๋‚™์ƒ ๊ฐ์ง€ ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.
HF Spaces Zero GPU ํ™˜๊ฒฝ์—์„œ ์‹คํ–‰๋ฉ๋‹ˆ๋‹ค.
์‚ฌ์šฉ๋ฒ• (๋กœ์ปฌ):
python demo_gradio/app.py
์‚ฌ์šฉ๋ฒ• (HF Spaces):
์ž๋™์œผ๋กœ app.py๊ฐ€ ์‹คํ–‰๋ฉ๋‹ˆ๋‹ค.
์ž‘์„ฑ์ž: Fall Detection Pipeline Team
์ž‘์„ฑ์ผ: 2025-11-26
"""
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Iterable, Optional, Tuple
import cv2
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
# ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ๋ฅผ Python path์— ์ถ”๊ฐ€
# pipeline/demo_gradio/app.py -> pipeline -> project_root
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
# Zero GPU compatibility setup
try:
import spaces
SPACES_AVAILABLE = True
except ImportError:
SPACES_AVAILABLE = False
# -----------------------------------------------------------------------------
# Custom theme (PRITHIVSAKTHIUR style)
# -----------------------------------------------------------------------------
colors.custom_color = colors.Color(
name="custom_color",
c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1",
c300="#7DB3D2", c400="#529AC3", c500="#4682B4",
c600="#3E72A0", c700="#36638C", c800="#2E5378",
c900="#264364", c950="#1E3450",
)
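# The palette above is attached to gradio's colors module, so CustomTheme below can
# reference it as colors.custom_color for its secondary hue.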
class CustomTheme(Soft):
def __init__(
self,
*,
primary_hue: colors.Color | str = colors.gray,
secondary_hue: colors.Color | str = colors.custom_color,
neutral_hue: colors.Color | str = colors.slate,
text_size: sizes.Size | str = sizes.text_lg,
font: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
),
font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
),
):
super().__init__(
primary_hue=primary_hue,
secondary_hue=secondary_hue,
neutral_hue=neutral_hue,
text_size=text_size,
font=font,
font_mono=font_mono,
)
super().set(
background_fill_primary="*primary_50",
body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
button_primary_text_color="white",
button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
slider_color="*secondary_500",
block_title_text_weight="600",
block_border_width="3px",
block_shadow="*shadow_drop_lg",
button_primary_shadow="*shadow_drop_lg",
)
custom_theme = CustomTheme()
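# Note: custom_theme and the css string below are applied through gr.Blocks(theme=..., css=...)
# inside create_demo(); Blocks, not launch(), accepts these arguments.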
# -----------------------------------------------------------------------------
# CSS styles
# -----------------------------------------------------------------------------
css = """
#col-container { margin: 0 auto; max-width: 1200px; }
#main-title h1 { font-size: 2.3em !important; }
.submit-btn {
background-color: #4682B4 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #5A9BD4 !important;
}
.result-label {
font-size: 1.5em !important;
font-weight: bold !important;
padding: 10px !important;
border-radius: 8px !important;
}
.fall-detected {
background-color: #FF4444 !important;
color: white !important;
}
.non-fall {
background-color: #44BB44 !important;
color: white !important;
}
"""
# -----------------------------------------------------------------------------
# ๋””๋ฐ”์ด์Šค ์„ค์ •
# -----------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------------------------
# GPU decorator (works locally and on HF Spaces)
# -----------------------------------------------------------------------------
def gpu_decorator(duration: int = 120):
"""๋กœ์ปฌ์—์„œ๋Š” ๊ทธ๋ƒฅ ์‹คํ–‰, Spaces์—์„œ๋Š” GPU ํ• ๋‹น"""
def decorator(func):
if SPACES_AVAILABLE:
return spaces.GPU(duration=duration)(func)
return func
return decorator
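# Usage sketch: gpu_decorator is a thin dispatch wrapper; process_video below is
# decorated with @gpu_decorator(duration=120), which returns the function unchanged
# locally and wraps it with spaces.GPU(duration=120) on HF Spaces.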
# -----------------------------------------------------------------------------
# ํŒŒ์ดํ”„๋ผ์ธ ์ดˆ๊ธฐํ™” (์ง€์—ฐ ๋กœ๋”ฉ)
# -----------------------------------------------------------------------------
_pipeline = None
def get_pipeline():
"""ํŒŒ์ดํ”„๋ผ์ธ ์‹ฑ๊ธ€ํ†ค ๋ฐ˜ํ™˜ (์ง€์—ฐ ๋กœ๋”ฉ)"""
global _pipeline
if _pipeline is None:
from pipeline.core.pipeline import FallDetectionPipeline
        # On HF Spaces, load from the models folder
pose_model_path = "pipeline/demo_gradio/models/yolo11m-pose.pt"
stgcn_checkpoint = "pipeline/demo_gradio/models/best_acc.pth"
        # Fall back to local paths
if not Path(pose_model_path).exists():
pose_model_path = "yolo11m-pose.pt"
if not Path(stgcn_checkpoint).exists():
stgcn_checkpoint = "runs/stgcn_binary_exp2_fixed_graph/best_acc.pth"
_pipeline = FallDetectionPipeline(
pose_model_path=pose_model_path,
stgcn_checkpoint=stgcn_checkpoint,
window_size=60,
conf_threshold=0.5,
            fall_threshold=0.85,  # guideline recommendation: 0.8-0.9 (false positive <5%)
temporal_window=5,
stgcn_stride=5,
alert_duration=150,
            post_fall_frames=15,  # 2.5 s @ 30 fps with stride=5 (guideline: 2-3 s)
device=str(device),
debug=False,
headless=False,
viz_keypoints="all",
viz_scale=1.0,
viz_optimized=True
)
return _pipeline
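# Note: get_pipeline() caches a single pipeline instance at module level; per-request
# settings (fall_threshold, viz_keypoints) are overridden in process_video() and the
# temporal state is cleared with pipeline.reset() before each run.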
# -----------------------------------------------------------------------------
# Probability graph creation
# -----------------------------------------------------------------------------
def create_probability_graph(
frame_indices: list,
probabilities: list,
fall_threshold: float = 0.7
) -> go.Figure:
"""
๋‚™์ƒ ํ™•๋ฅ  ๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ
Args:
frame_indices: ํ”„๋ ˆ์ž„ ์ธ๋ฑ์Šค ๋ฆฌ์ŠคํŠธ
probabilities: ๋‚™์ƒ ํ™•๋ฅ  ๋ฆฌ์ŠคํŠธ (0.0-1.0)
fall_threshold: ๋‚™์ƒ ํŒ์ • ์ž„๊ณ„๊ฐ’
Returns:
Plotly Figure ๊ฐ์ฒด
"""
fig = go.Figure()
    # Probability line
fig.add_trace(go.Scatter(
x=frame_indices,
y=probabilities,
mode='lines',
name='Fall Probability',
line=dict(color='#4682B4', width=2),
fill='tozeroy',
fillcolor='rgba(70, 130, 180, 0.3)'
))
    # Threshold line
fig.add_hline(
y=fall_threshold,
line_dash="dash",
line_color="red",
annotation_text=f"Threshold ({fall_threshold})",
annotation_position="right"
)
# ๋ ˆ์ด์•„์›ƒ
fig.update_layout(
title="Fall Detection Probability Over Time",
xaxis_title="Frame",
yaxis_title="Probability",
yaxis=dict(range=[0, 1]),
template="plotly_white",
height=300,
margin=dict(l=50, r=50, t=50, b=50),
showlegend=True,
legend=dict(
orientation="h",
yanchor="bottom",
y=1.02,
xanchor="right",
x=1
)
)
return fig
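# Minimal standalone sketch of create_probability_graph (the sample values below are
# illustrative only, not real pipeline output):
#
#   fig = create_probability_graph(
#       frame_indices=[0, 5, 10, 15, 20],
#       probabilities=[0.05, 0.12, 0.48, 0.91, 0.76],
#       fall_threshold=0.85,
#   )
#   fig.show()  # or return it from a Gradio handler, as process_video() does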
# -----------------------------------------------------------------------------
# ๋ฉ”์ธ ์ถ”๋ก  ํ•จ์ˆ˜
# -----------------------------------------------------------------------------
@gpu_decorator(duration=120)
def process_video(
video_path: str,
fall_threshold: float,
viz_keypoints: str,
progress: gr.Progress = gr.Progress()
) -> Tuple[Optional[str], Optional[go.Figure], str]:
"""
๋น„๋””์˜ค ์ฒ˜๋ฆฌ ๋ฐ ๋‚™์ƒ ๊ฐ์ง€
Args:
video_path: ์ž…๋ ฅ ๋น„๋””์˜ค ๊ฒฝ๋กœ
fall_threshold: ๋‚™์ƒ ํŒ์ • ์ž„๊ณ„๊ฐ’ (0.0-1.0)
viz_keypoints: ํ‚คํฌ์ธํŠธ ํ‘œ์‹œ ๋ชจ๋“œ ('all' ๋˜๋Š” 'major')
progress: Gradio ์ง„ํ–‰๋ฅ  ํ‘œ์‹œ
Returns:
output_video_path: ๊ฒฐ๊ณผ ๋น„๋””์˜ค ๊ฒฝ๋กœ
probability_graph: ํ™•๋ฅ  ๊ทธ๋ž˜ํ”„
result_text: ์ตœ์ข… ํŒ์ • ํ…์ŠคํŠธ
"""
if video_path is None:
        return None, None, "Please upload a video."
try:
# ํŒŒ์ดํ”„๋ผ์ธ ๋กœ๋“œ
progress(0.1, desc="๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
pipeline = get_pipeline()
pipeline.fall_threshold = fall_threshold
pipeline.stgcn_classifier.fall_threshold = fall_threshold
pipeline.viz_keypoints = viz_keypoints
pipeline.reset()
        # Open the video
        progress(0.2, desc="Opening video...")
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
            return None, None, "Could not open the video."
        # Video info
fps = cap.get(cv2.CAP_PROP_FPS)
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Validate video length (to stay within the 120 s GPU timeout)
if fps > 0:
video_duration = total_frames / fps
            # Estimated processing time: roughly 1.5x real time + 10 s for encoding
estimated_time = video_duration * 1.5 + 10
            if estimated_time > 110:  # leave headroom under the 120 s timeout
cap.release()
                return None, None, (
                    f"Video is too long. "
                    f"Video length: {video_duration:.1f} s, "
                    f"estimated processing time: {estimated_time:.1f} s (limit: 110 s). "
                    f"Please upload a video shorter than 60 seconds."
                )
        # Output video setup (security: use NamedTemporaryFile)
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
output_path = tmp.name
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # Height increased by 80 px for the added info panel
out = cv2.VideoWriter(output_path, fourcc, fps, (width, height + 80))
        # Processing loop
frame_idx = 0
frame_indices = []
probabilities = []
fall_detected = False
max_confidence = 0.0
while True:
ret, frame = cap.read()
if not ret:
break
            # Process the frame
vis_frame, info = pipeline.process_frame(frame, frame_idx)
            # Record the probability
if info['confidence'] is not None:
frame_indices.append(frame_idx)
probabilities.append(info['confidence'])
max_confidence = max(max_confidence, info['confidence'])
            # Check for fall detection
if info['alert']:
fall_detected = True
            # Write the output frame
out.write(vis_frame)
frame_idx += 1
            # Update progress
if frame_idx % 10 == 0:
progress_val = 0.2 + 0.7 * (frame_idx / total_frames)
                progress(progress_val, desc=f"Processing... ({frame_idx}/{total_frames})")
        # Release resources
cap.release()
out.release()
        # Re-encode with the H.264 codec (browser compatible)
        progress(0.9, desc="Encoding video...")
        # Security: use NamedTemporaryFile (prevents CWE-377)
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_h264 = tmp.name
        # Security: use subprocess.run with an argument list (prevents shell injection)
        result = subprocess.run(
            [
                'ffmpeg', '-y', '-loglevel', 'quiet', '-i', output_path,
                '-c:v', 'libx264', '-preset', 'fast', '-crf', '23',
                output_h264
            ],
            check=False,
            capture_output=True
        )
        # Use the H.264 file only if re-encoding actually succeeded; the temp file always
        # exists, so check the return code and that the output is non-empty.
        if result.returncode == 0 and os.path.getsize(output_h264) > 0:
            final_output = output_h264
            # Remove the intermediate mp4v file
            if os.path.exists(output_path):
                os.remove(output_path)
        else:
            # Fall back to the mp4v output and discard the unused H.264 temp file
            final_output = output_path
            if os.path.exists(output_h264):
                os.remove(output_h264)
        # Create the probability graph
        progress(0.95, desc="Creating graph...")
if frame_indices and probabilities:
fig = create_probability_graph(frame_indices, probabilities, fall_threshold)
else:
fig = None
        # Final decision
        progress(1.0, desc="Done!")
if fall_detected:
result_text = f"[FALL DETECTED] ๋‚™์ƒ์ด ๊ฐ์ง€๋˜์—ˆ์Šต๋‹ˆ๋‹ค! (์ตœ๋Œ€ ํ™•๋ฅ : {max_confidence:.1%})"
else:
result_text = f"[Non-Fall] ๋‚™์ƒ์ด ๊ฐ์ง€๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. (์ตœ๋Œ€ ํ™•๋ฅ : {max_confidence:.1%})"
return final_output, fig, result_text
except Exception as e:
import traceback
error_msg = f"์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
return None, None, error_msg
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
def create_demo() -> gr.Blocks:
"""Gradio ๋ฐ๋ชจ ์ƒ์„ฑ"""
    with gr.Blocks(theme=custom_theme, css=css) as demo:
gr.Markdown(
"""
# Fall Detection Demo
YOLOv11-Pose + ST-GCN 2-stage ํŒŒ์ดํ”„๋ผ์ธ์„ ์‚ฌ์šฉํ•œ ์‹ค์‹œ๊ฐ„ ๋‚™์ƒ ๊ฐ์ง€ ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.
๋น„๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด ๋‚™์ƒ ์—ฌ๋ถ€๋ฅผ ๋ถ„์„ํ•˜๊ณ , ๊ฒฐ๊ณผ ๋น„๋””์˜ค์™€ ํ™•๋ฅ  ๊ทธ๋ž˜ํ”„๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
**ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌ์„ฑ:**
- Stage 1: YOLOv11m-pose (Pose Estimation)
- Stage 2: ST-GCN (Temporal Classification)
- Window Size: 60 frames (2์ดˆ @ 30fps)
""",
elem_id="main-title"
)
with gr.Row():
with gr.Column(scale=1):
                # Input section
                gr.Markdown("### Input")
video_input = gr.Video(
label="๋น„๋””์˜ค ์—…๋กœ๋“œ",
sources=["upload"],
)
                with gr.Accordion("Advanced settings", open=False):
fall_threshold = gr.Slider(
minimum=0.7,
maximum=0.95,
value=0.85,
step=0.05,
label="๋‚™์ƒ ํŒ์ • ์ž„๊ณ„๊ฐ’",
info="๊ถŒ์žฅ: 0.8-0.9 (false positive <5% ๋ชฉํ‘œ)"
)
viz_keypoints = gr.Radio(
choices=["all", "major"],
value="all",
label="ํ‚คํฌ์ธํŠธ ํ‘œ์‹œ",
info="all: ์ „์ฒด 17๊ฐœ, major: ์ฃผ์š” 9๊ฐœ"
)
submit_btn = gr.Button(
"๋ถ„์„ ์‹œ์ž‘",
variant="primary",
elem_classes="submit-btn"
)
with gr.Column(scale=1):
                # Output section
                gr.Markdown("### Results")
result_text = gr.Textbox(
label="ํŒ์ • ๊ฒฐ๊ณผ",
lines=2,
interactive=False
)
video_output = gr.Video(
label="๊ฒฐ๊ณผ ๋น„๋””์˜ค",
)
prob_graph = gr.Plot(
label="๋‚™์ƒ ํ™•๋ฅ  ๊ทธ๋ž˜ํ”„",
)
        # Example videos
        gr.Markdown("### Example videos")
example_dir = Path(__file__).parent / "examples"
examples = []
if example_dir.exists():
for ext in ["*.mp4", "*.avi", "*.mov"]:
examples.extend([str(p) for p in example_dir.glob(ext)])
if examples:
gr.Examples(
examples=[[ex, 0.85, "all"] for ex in examples[:3]],
inputs=[video_input, fall_threshold, viz_keypoints],
outputs=[video_output, prob_graph, result_text],
fn=process_video,
cache_examples=False,
)
        # Wire up events
submit_btn.click(
fn=process_video,
inputs=[video_input, fall_threshold, viz_keypoints],
outputs=[video_output, prob_graph, result_text],
)
        # Footer
gr.Markdown(
"""
---
**References:**
- [YOLOv11](https://github.com/ultralytics/ultralytics) - Pose Estimation
- [ST-GCN](https://arxiv.org/abs/1801.07455) - Spatial Temporal Graph Convolutional Networks
- AI Hub Fall Detection Dataset
"""
)
return demo
# -----------------------------------------------------------------------------
# ๋ฉ”์ธ ์‹คํ–‰
# -----------------------------------------------------------------------------
if __name__ == "__main__":
demo = create_demo()
    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )
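# Once launched, the demo is served on port 7860 (server_port above); on HF Spaces the
# same app.py is picked up and launched automatically, as noted in the module docstring.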