YoungjaeDev
feat(batch): ๋ฐฐ์น˜ ์ถ”๋ก  ๋ฐ ์Šค๋งˆํŠธ ํด๋ฆฝ ์ถ”์ถœ ๊ตฌํ˜„ (Issue #77, #78, #82)
0ea4706
raw
history blame
22.5 kB
#!/usr/bin/env python3
"""
Fall Detection Gradio App
YOLOv11-Pose + ST-GCN 2-stage ํŒŒ์ดํ”„๋ผ์ธ์„ ์‚ฌ์šฉํ•œ ๋‚™์ƒ ๊ฐ์ง€ ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.
HF Spaces Zero GPU ํ™˜๊ฒฝ์—์„œ ์‹คํ–‰๋ฉ๋‹ˆ๋‹ค.
์‚ฌ์šฉ๋ฒ• (๋กœ์ปฌ):
python demo_gradio/app.py
์‚ฌ์šฉ๋ฒ• (HF Spaces):
์ž๋™์œผ๋กœ app.py๊ฐ€ ์‹คํ–‰๋ฉ๋‹ˆ๋‹ค.
์ž‘์„ฑ์ž: Fall Detection Pipeline Team
์ž‘์„ฑ์ผ: 2025-11-26
"""
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Iterable, Optional, Tuple
import cv2
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from huggingface_hub import hf_hub_download
# ํ”„๋กœ์ ํŠธ ๋ฃจํŠธ๋ฅผ Python path์— ์ถ”๊ฐ€
# pipeline/demo_gradio/app.py -> pipeline -> project_root
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
# Zero GPU ํ˜ธํ™˜ ์„ค์ •
try:
import spaces
SPACES_AVAILABLE = True
except ImportError:
SPACES_AVAILABLE = False
# -----------------------------------------------------------------------------
# Custom theme (PRITHIVSAKTHIUR style)
# -----------------------------------------------------------------------------
# Steel-blue palette registered on the gradio colors module; used as the
# secondary hue of the custom theme below.
colors.custom_color = colors.Color(
    name="custom_color",
    c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1",
    c300="#7DB3D2", c400="#529AC3", c500="#4682B4",
    c600="#3E72A0", c700="#36638C", c800="#2E5378",
    c900="#264364", c950="#1E3450",
)
class CustomTheme(Soft):
    """Soft-based Gradio theme with a steel-blue gradient look.

    Overrides the default hues, text size, and fonts, then tweaks a handful
    of style tokens (background gradient, button colors, block borders and
    shadows) on top of the Soft defaults.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.custom_color,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Fine-grained token overrides applied after the base initialization.
        super().set(
            background_fill_primary="*primary_50",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            button_primary_text_color="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            slider_color="*secondary_500",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
        )


# Single theme instance passed to gr.Blocks below.
custom_theme = CustomTheme()
# -----------------------------------------------------------------------------
# CSS ์Šคํƒ€์ผ
# -----------------------------------------------------------------------------
css = """
#col-container { margin: 0 auto; max-width: 1200px; }
#main-title h1 { font-size: 2.3em !important; }
.submit-btn {
background-color: #4682B4 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #5A9BD4 !important;
}
.result-label {
font-size: 1.5em !important;
font-weight: bold !important;
padding: 10px !important;
border-radius: 8px !important;
}
.fall-detected {
background-color: #FF4444 !important;
color: white !important;
}
.non-fall {
background-color: #44BB44 !important;
color: white !important;
}
"""
# -----------------------------------------------------------------------------
# ๋””๋ฐ”์ด์Šค ์„ค์ •
# -----------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# -----------------------------------------------------------------------------
# GPU decorator (local / HF Spaces compatible)
# -----------------------------------------------------------------------------
def gpu_decorator(duration: int = 120):
    """Decorator factory: request a Zero GPU on Spaces, no-op locally.

    Args:
        duration: Maximum GPU allocation time in seconds (Spaces only).
    """
    def wrap(func):
        # Without the `spaces` package the function is returned untouched.
        if not SPACES_AVAILABLE:
            return func
        return spaces.GPU(duration=duration)(func)
    return wrap
# -----------------------------------------------------------------------------
# Model download (HuggingFace Hub)
# -----------------------------------------------------------------------------
HF_MODEL_REPO = "YoungjaeDev/fall-detection-models"


def download_models() -> tuple[str, str]:
    """Resolve the pose and ST-GCN model weight files.

    Local checkpoints take precedence (development environment); otherwise
    the files are pulled from the HuggingFace Hub (cached). The repository
    is private, so an HF_TOKEN environment variable is required.

    Returns:
        tuple: (pose_model_path, stgcn_checkpoint_path)

    Raises:
        RuntimeError: when the token is missing, the download fails, or the
            downloaded files do not pass validation.
    """
    # Development shortcut: use local checkpoints when both are present.
    dev_pose = Path("yolo11m-pose.pt")
    dev_stgcn = Path("runs/stgcn_binary_exp2_fixed_graph/best_acc.pth")
    if dev_pose.exists() and dev_stgcn.exists():
        return str(dev_pose), str(dev_stgcn)

    # Private repository access requires a token from the environment.
    token = os.environ.get("HF_TOKEN")
    if token is None:
        raise RuntimeError(
            "HF_TOKEN ํ™˜๊ฒฝ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. "
            "Private ๋ชจ๋ธ ์ €์žฅ์†Œ ์ ‘๊ทผ์„ ์œ„ํ•ด HF_TOKEN์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. "
            "HF Spaces์˜ ๊ฒฝ์šฐ Settings > Secrets์—์„œ ์„ค์ •ํ•˜์„ธ์š”."
        )

    try:
        pose_model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO, filename="yolo11m-pose.pt", token=token
        )
        stgcn_checkpoint = hf_hub_download(
            repo_id=HF_MODEL_REPO, filename="best_acc.pth", token=token
        )
    except Exception as e:
        raise RuntimeError(
            f"๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์‹คํŒจ: {e}\n"
            f"์ €์žฅ์†Œ: {HF_MODEL_REPO}\n"
            f"HF_TOKEN์ด ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์„ค์ •๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”."
        ) from e

    downloads = (
        ("Pose ๋ชจ๋ธ", pose_model_path),
        ("ST-GCN ์ฒดํฌํฌ์ธํŠธ", stgcn_checkpoint),
    )
    # Existence check for both files first (mirrors original check order).
    for label, raw_path in downloads:
        if not Path(raw_path).exists():
            raise RuntimeError(f"{label} ํŒŒ์ผ์ด ์กด์žฌํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค: {raw_path}")
    # Size check: files under 1 MB are almost certainly corrupted/truncated.
    for label, raw_path in downloads:
        size = Path(raw_path).stat().st_size
        if size < 1_000_000:
            raise RuntimeError(f"{label} ํŒŒ์ผ์ด ๋„ˆ๋ฌด ์ž‘์Šต๋‹ˆ๋‹ค: {size} bytes")

    return pose_model_path, stgcn_checkpoint
# -----------------------------------------------------------------------------
# Pipeline initialization (lazy loading)
# -----------------------------------------------------------------------------
_pipeline = None


def get_pipeline():
    """Return the singleton pipeline instance, building it on first use."""
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    # Deferred import: only needed once the pipeline is actually built.
    from pipeline.core.pipeline import FallDetectionPipeline

    # Resolve model weights (download is cached).
    pose_model_path, stgcn_checkpoint = download_models()

    _pipeline = FallDetectionPipeline(
        pose_model_path=pose_model_path,
        stgcn_checkpoint=stgcn_checkpoint,
        window_size=60,
        conf_threshold=0.5,
        fall_threshold=0.85,  # guideline recommendation: 0.8-0.9 (false positive <5%)
        temporal_window=5,
        stgcn_stride=5,
        alert_duration=150,
        post_fall_frames=15,  # 2.5s @ 30fps with stride=5 (guideline: 2-3s)
        device=str(device),
        debug=False,
        headless=False,
        viz_keypoints="all",
        viz_scale=1.0,
        viz_optimized=True
    )
    return _pipeline
# -----------------------------------------------------------------------------
# Probability graph
# -----------------------------------------------------------------------------
def create_probability_graph(
    frame_indices: list,
    probabilities: list,
    fall_threshold: float = 0.7
) -> go.Figure:
    """Build the fall-probability-over-time Plotly figure.

    Args:
        frame_indices: Frame indices (x axis).
        probabilities: Fall probabilities in [0.0, 1.0] (y axis).
        fall_threshold: Decision threshold drawn as a dashed red line.

    Returns:
        Plotly Figure object.
    """
    figure = go.Figure()

    # Filled probability trace.
    probability_trace = go.Scatter(
        x=frame_indices,
        y=probabilities,
        mode='lines',
        name='Fall Probability',
        line=dict(color='#4682B4', width=2),
        fill='tozeroy',
        fillcolor='rgba(70, 130, 180, 0.3)',
    )
    figure.add_trace(probability_trace)

    # Dashed horizontal threshold marker.
    figure.add_hline(
        y=fall_threshold,
        line_dash="dash",
        line_color="red",
        annotation_text=f"Threshold ({fall_threshold})",
        annotation_position="right",
    )

    # Overall layout: fixed [0, 1] y-range, horizontal legend above the plot.
    layout_options = dict(
        title="Fall Detection Probability Over Time",
        xaxis_title="Frame",
        yaxis_title="Probability",
        yaxis=dict(range=[0, 1]),
        template="plotly_white",
        height=300,
        margin=dict(l=50, r=50, t=50, b=50),
        showlegend=True,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1,
        ),
    )
    figure.update_layout(**layout_options)
    return figure
# -----------------------------------------------------------------------------
# Smart clip extraction settings (Issue #82)
# -----------------------------------------------------------------------------
CLIP_PRE_FALL_SECONDS = 1.0   # seconds kept before the first fall frame
CLIP_POST_FALL_SECONDS = 2.0  # seconds kept after the first fall frame


# -----------------------------------------------------------------------------
# Main inference function
# -----------------------------------------------------------------------------
@gpu_decorator(duration=120)
def process_video(
    video_path: str,
    fall_threshold: float,
    viz_keypoints: str,
    progress: gr.Progress = gr.Progress()
) -> Tuple[Optional[str], Optional[go.Figure], str]:
    """
    Run fall detection on a video and extract a smart clip.

    Issue #82: only the detected fall segment is encoded, which drastically
    reduces encoding time.
    - Fall detected: extract 1s before + 2s after the first fall frame.
    - No fall: return a "not detected" message without a clip.

    Args:
        video_path: Path of the uploaded input video.
        fall_threshold: Fall decision threshold (0.0-1.0).
        viz_keypoints: Keypoint display mode ('all' or 'major').
        progress: Gradio progress indicator.

    Returns:
        output_video_path: Result clip path (on detection) or None.
        probability_graph: Probability-over-time figure, or None if no
            probabilities were produced.
        result_text: Human-readable verdict text.
    """
    if video_path is None:
        return None, None, "๋น„๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”."
    try:
        # Load the (cached) pipeline and apply per-request settings.
        progress(0.1, desc="๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
        pipeline = get_pipeline()
        pipeline.fall_threshold = fall_threshold
        pipeline.stgcn_classifier.fall_threshold = fall_threshold
        pipeline.viz_keypoints = viz_keypoints
        pipeline.reset()

        # Open the video.
        progress(0.2, desc="๋น„๋””์˜ค ์—ด๊ธฐ...")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return None, None, "๋น„๋””์˜ค๋ฅผ ์—ด ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."

        # Video metadata. CAP_PROP_FRAME_COUNT can be 0/unreliable for some
        # containers; it is re-derived after the read loop below.
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Length guard against the 120s Zero-GPU timeout.
        if fps > 0:
            video_duration = total_frames / fps
            # Rough processing estimate: ~1.5x realtime + 10s encoding.
            estimated_time = video_duration * 1.5 + 10
            if estimated_time > 110:  # keep headroom under the 120s limit
                cap.release()
                return None, None, (
                    f"๋น„๋””์˜ค๊ฐ€ ๋„ˆ๋ฌด ๊น๋‹ˆ๋‹ค. "
                    f"๋น„๋””์˜ค ๊ธธ์ด: {video_duration:.1f}์ดˆ, "
                    f"์˜ˆ์ƒ ์ฒ˜๋ฆฌ ์‹œ๊ฐ„: {estimated_time:.1f}์ดˆ (์ œํ•œ: 110์ดˆ). "
                    f"60์ดˆ ์ด๋‚ด์˜ ๋น„๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”."
                )

        # Clip extent in frames around the first fall detection.
        pre_fall_frames = int(fps * CLIP_PRE_FALL_SECONDS)
        post_fall_frames = int(fps * CLIP_POST_FALL_SECONDS)

        # Processing loop: buffer visualization frames + detect falls.
        frame_idx = 0
        frame_indices = []
        probabilities = []
        max_confidence = 0.0

        first_fall_frame = None  # frame index of the first fall alert
        fall_detected = False

        # Visualization frame buffer (source of the extracted clip).
        # NOTE: the previous version also kept a copy of every raw frame
        # ("for re-processing") that was never read — dropped to halve the
        # buffering memory.
        vis_frame_buffer = []

        while True:
            with pipeline.profiler.profile('video_read'):
                ret, frame = cap.read()
            if not ret:
                break

            vis_frame, info = pipeline.process_frame(frame, frame_idx)
            vis_frame_buffer.append(vis_frame)

            # Record the per-frame fall probability when available.
            if info['confidence'] is not None:
                frame_indices.append(frame_idx)
                probabilities.append(info['confidence'])
                max_confidence = max(max_confidence, info['confidence'])

            # Remember the first frame on which the alert fired.
            if info['alert'] and first_fall_frame is None:
                first_fall_frame = frame_idx
                fall_detected = True

            frame_idx += 1

            # Progress update (guarded against a zero frame-count header).
            if total_frames > 0 and frame_idx % 10 == 0:
                progress_val = 0.2 + 0.6 * (frame_idx / total_frames)
                progress(progress_val, desc=f"๋ถ„์„ ์ค‘... ({frame_idx}/{total_frames})")

        cap.release()

        # Fall back to the actual number of frames read when the container
        # metadata reported 0 (prevents division by zero in the summary).
        if total_frames <= 0:
            total_frames = frame_idx

        # Probability graph (generated whenever data exists).
        progress(0.85, desc="๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ ์ค‘...")
        if frame_indices and probabilities:
            fig = create_probability_graph(frame_indices, probabilities, fall_threshold)
        else:
            fig = None

        # No fall: return without producing a clip.
        if not fall_detected or first_fall_frame is None:
            progress(1.0, desc="์™„๋ฃŒ!")
            result_text = (
                f"[Non-Fall] ๋‚™์ƒ์ด ๊ฐ์ง€๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค.\n"
                f"์ตœ๋Œ€ ํ™•๋ฅ : {max_confidence:.1%}\n"
                f"๋ถ„์„ ํ”„๋ ˆ์ž„: {total_frames}๊ฐœ"
            )
            return None, fig, result_text

        # Clip boundaries around the first fall frame.
        clip_start = max(0, first_fall_frame - pre_fall_frames)
        clip_end = min(len(vis_frame_buffer), first_fall_frame + post_fall_frames)
        clip_frames = vis_frame_buffer[clip_start:clip_end]
        if not clip_frames:
            progress(1.0, desc="์™„๋ฃŒ!")
            return None, fig, "ํด๋ฆฝ ์ถ”์ถœ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค."

        # Encode only the clip (far fewer frames => far faster encoding).
        progress(0.9, desc="ํด๋ฆฝ ์ธ์ฝ”๋”ฉ ์ค‘...")
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        clip_height, clip_width = clip_frames[0].shape[:2]
        out = cv2.VideoWriter(output_path, fourcc, fps, (clip_width, clip_height))
        for vis_frame in clip_frames:
            out.write(vis_frame)
        out.release()

        # Re-encode with H.264 for browser compatibility.
        # BUGFIX: `-loglevel` must precede the output file — ffmpeg ignores
        # trailing options placed after the last output.
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_h264 = tmp.name
        with pipeline.profiler.profile('ffmpeg_encode'):
            encode_result = subprocess.run(
                [
                    'ffmpeg', '-y', '-loglevel', 'quiet', '-i', output_path,
                    '-c:v', 'libx264', '-preset', 'fast', '-crf', '23',
                    output_h264,
                ],
                check=False,
                capture_output=True
            )

        # BUGFIX: NamedTemporaryFile pre-creates the file, so mere existence
        # never proved that ffmpeg succeeded; check the return code and a
        # non-empty result instead. Also, the old code deleted the mp4v file
        # before falling back to it — keep it until the outcome is known.
        h264_ok = (
            encode_result.returncode == 0
            and os.path.exists(output_h264)
            and os.path.getsize(output_h264) > 0
        )
        if h264_ok:
            final_output = output_h264
            # Drop the mp4v intermediate only after a successful re-encode.
            if os.path.exists(output_path):
                os.remove(output_path)
        else:
            # Fallback: serve the mp4v clip and discard the useless H.264 temp.
            final_output = output_path
            if os.path.exists(output_h264):
                os.remove(output_h264)

        # Final verdict.
        progress(1.0, desc="์™„๋ฃŒ!")
        fall_time = first_fall_frame / fps if fps > 0 else 0
        clip_duration = len(clip_frames) / fps if fps > 0 else 0
        result_text = (
            f"[FALL DETECTED] ๋‚™์ƒ์ด ๊ฐ์ง€๋˜์—ˆ์Šต๋‹ˆ๋‹ค!\n"
            f"๋‚™์ƒ ์‹œ์ : {fall_time:.2f}์ดˆ (ํ”„๋ ˆ์ž„ #{first_fall_frame})\n"
            f"์ตœ๋Œ€ ํ™•๋ฅ : {max_confidence:.1%}\n"
            f"ํด๋ฆฝ ๊ธธ์ด: {clip_duration:.1f}์ดˆ ({len(clip_frames)}ํ”„๋ ˆ์ž„)\n"
            f"์›๋ณธ ๋Œ€๋น„: {len(clip_frames)}/{total_frames}ํ”„๋ ˆ์ž„ "
            f"({len(clip_frames)/total_frames*100:.1f}% ์ธ์ฝ”๋”ฉ)"
        )
        return final_output, fig, result_text

    except Exception as e:
        import traceback
        error_msg = f"์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
        return None, None, error_msg
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
def create_demo() -> gr.Blocks:
    """Build the Gradio Blocks demo: inputs, outputs, examples, and events."""
    with gr.Blocks(theme=custom_theme, css=css) as demo:
        # Header / pipeline description.
        gr.Markdown(
            """
            # Fall Detection Demo
            YOLOv11-Pose + ST-GCN 2-stage ํŒŒ์ดํ”„๋ผ์ธ์„ ์‚ฌ์šฉํ•œ ์‹ค์‹œ๊ฐ„ ๋‚™์ƒ ๊ฐ์ง€ ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.
            ๋น„๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด ๋‚™์ƒ ์—ฌ๋ถ€๋ฅผ ๋ถ„์„ํ•˜๊ณ , ๊ฒฐ๊ณผ ๋น„๋””์˜ค์™€ ํ™•๋ฅ  ๊ทธ๋ž˜ํ”„๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
            **ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌ์„ฑ:**
            - Stage 1: YOLOv11m-pose (Pose Estimation)
            - Stage 2: ST-GCN (Temporal Classification)
            - Window Size: 60 frames (2์ดˆ @ 30fps)
            """,
            elem_id="main-title"
        )
        with gr.Row():
            with gr.Column(scale=1):
                # Input section: video upload + advanced settings.
                gr.Markdown("### ์ž…๋ ฅ")
                video_input = gr.Video(
                    label="๋น„๋””์˜ค ์—…๋กœ๋“œ",
                    sources=["upload"],
                )
                with gr.Accordion("๊ณ ๊ธ‰ ์„ค์ •", open=False):
                    fall_threshold = gr.Slider(
                        minimum=0.7,
                        maximum=0.95,
                        value=0.85,
                        step=0.05,
                        label="๋‚™์ƒ ํŒ์ • ์ž„๊ณ„๊ฐ’",
                        info="๊ถŒ์žฅ: 0.8-0.9 (false positive <5% ๋ชฉํ‘œ)"
                    )
                    viz_keypoints = gr.Radio(
                        choices=["all", "major"],
                        value="all",
                        label="ํ‚คํฌ์ธํŠธ ํ‘œ์‹œ",
                        info="all: ์ „์ฒด 17๊ฐœ, major: ์ฃผ์š” 9๊ฐœ"
                    )
                submit_btn = gr.Button(
                    "๋ถ„์„ ์‹œ์ž‘",
                    variant="primary",
                    elem_classes="submit-btn"
                )
            with gr.Column(scale=1):
                # Output section: verdict text, result clip, probability plot.
                gr.Markdown("### ๊ฒฐ๊ณผ")
                result_text = gr.Textbox(
                    label="ํŒ์ • ๊ฒฐ๊ณผ",
                    lines=2,
                    interactive=False
                )
                video_output = gr.Video(
                    label="๊ฒฐ๊ณผ ๋น„๋””์˜ค",
                )
                prob_graph = gr.Plot(
                    label="๋‚™์ƒ ํ™•๋ฅ  ๊ทธ๋ž˜ํ”„",
                )
        # Example videos: pick up to three clips shipped next to this file.
        gr.Markdown("### ์˜ˆ์ œ ๋น„๋””์˜ค")
        example_dir = Path(__file__).parent / "examples"
        examples = []
        if example_dir.exists():
            for ext in ["*.mp4", "*.avi", "*.mov"]:
                examples.extend([str(p) for p in example_dir.glob(ext)])
        if examples:
            gr.Examples(
                examples=[[ex, 0.85, "all"] for ex in examples[:3]],
                inputs=[video_input, fall_threshold, viz_keypoints],
                outputs=[video_output, prob_graph, result_text],
                fn=process_video,
                cache_examples=False,
            )
        # Wire the submit button to the inference function.
        submit_btn.click(
            fn=process_video,
            inputs=[video_input, fall_threshold, viz_keypoints],
            outputs=[video_output, prob_graph, result_text],
        )
        # Footer.
        gr.Markdown(
            """
            ---
            **References:**
            - [YOLOv11](https://github.com/ultralytics/ultralytics) - Pose Estimation
            - [ST-GCN](https://arxiv.org/abs/1801.07455) - Spatial Temporal Graph Convolutional Networks
            - AI Hub Fall Detection Dataset
            """
        )
    return demo
# -----------------------------------------------------------------------------
# Entry point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    demo = create_demo()
    # Bounded queue limits concurrent requests; share=True requests a public
    # link when running locally (presumably unnecessary on HF Spaces — confirm).
    demo.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )