YoungjaeDev
fix: Private HF repo 모델 다운로드 에러 처리 강화
169758c
raw
history blame
19.8 kB
#!/usr/bin/env python3
"""
Fall Detection Gradio App

Fall detection demo using a 2-stage YOLOv11-Pose + ST-GCN pipeline.
Runs in the HF Spaces Zero GPU environment.

Usage (local):
    python demo_gradio/app.py

Usage (HF Spaces):
    app.py is executed automatically.

Author: Fall Detection Pipeline Team
Date: 2025-11-26
"""
import os
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import Iterable, Optional, Tuple
import cv2
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes
from huggingface_hub import hf_hub_download
# Add the project root to the Python import path.
# pipeline/demo_gradio/app.py -> pipeline -> project_root
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))
# Zero GPU compatibility: `spaces` is treated as optional, so record whether
# it could be imported instead of failing when it is not installed.
try:
    import spaces
    SPACES_AVAILABLE = True
except ImportError:
    SPACES_AVAILABLE = False
# -----------------------------------------------------------------------------
# Custom theme (PRITHIVSAKTHIUR style)
# -----------------------------------------------------------------------------
# Register an extra palette on gradio's `colors` module so the theme below can
# refer to it as `colors.custom_color`. Shades run from c50 to c950.
colors.custom_color = colors.Color(
    name="custom_color",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)
class CustomTheme(Soft):
    """Gradio theme derived from Soft with this demo's colours and fonts.

    Every constructor argument is keyword-only and is forwarded unchanged to
    ``Soft.__init__``; a fixed set of style overrides is then applied through
    ``set()``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.custom_color,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"), "Arial", "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace",
        ),
    ):
        # Base theme configuration, passed straight through to Soft.
        base_options = {
            "primary_hue": primary_hue,
            "secondary_hue": secondary_hue,
            "neutral_hue": neutral_hue,
            "text_size": text_size,
            "font": font,
            "font_mono": font_mono,
        }
        super().__init__(**base_options)

        # Style overrides layered on top of the base theme.
        style_overrides = {
            "background_fill_primary": "*primary_50",
            "body_background_fill": "linear-gradient(135deg, *primary_200, *primary_100)",
            "button_primary_text_color": "white",
            "button_primary_background_fill": "linear-gradient(90deg, *secondary_500, *secondary_600)",
            "button_primary_background_fill_hover": "linear-gradient(90deg, *secondary_600, *secondary_700)",
            "slider_color": "*secondary_500",
            "block_title_text_weight": "600",
            "block_border_width": "3px",
            "block_shadow": "*shadow_drop_lg",
            "button_primary_shadow": "*shadow_drop_lg",
        }
        super().set(**style_overrides)
# Theme instance handed to gr.Blocks when the UI is built.
custom_theme = CustomTheme()
# -----------------------------------------------------------------------------
# CSS styles
# -----------------------------------------------------------------------------
# Extra CSS handed to gr.Blocks; selectors target element ids/classes
# (e.g. #main-title, .submit-btn) assigned to components in the UI code.
css = """
#col-container { margin: 0 auto; max-width: 1200px; }
#main-title h1 { font-size: 2.3em !important; }
.submit-btn {
background-color: #4682B4 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #5A9BD4 !important;
}
.result-label {
font-size: 1.5em !important;
font-weight: bold !important;
padding: 10px !important;
border-radius: 8px !important;
}
.fall-detected {
background-color: #FF4444 !important;
color: white !important;
}
.non-fall {
background-color: #44BB44 !important;
color: white !important;
}
"""
# -----------------------------------------------------------------------------
# Device selection
# -----------------------------------------------------------------------------
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# -----------------------------------------------------------------------------
# GPU decorator (works both locally and on HF Spaces)
# -----------------------------------------------------------------------------
def gpu_decorator(duration: int = 120):
    """Return a decorator that requests a Spaces GPU when `spaces` is present.

    When SPACES_AVAILABLE is false the decorated function is returned
    unchanged; otherwise it is wrapped with ``spaces.GPU(duration=duration)``.
    """
    def decorator(func):
        if not SPACES_AVAILABLE:
            # Local run: nothing to allocate, hand the function back as-is.
            return func
        allocate_gpu = spaces.GPU(duration=duration)
        return allocate_gpu(func)

    return decorator
# -----------------------------------------------------------------------------
# Model download (HuggingFace Hub)
# -----------------------------------------------------------------------------
HF_MODEL_REPO = "YoungjaeDev/fall-detection-models"


def download_models() -> tuple[str, str]:
    """Locate the pose model and ST-GCN checkpoint, downloading them if needed.

    Local files (paths relative to the current working directory, i.e. a
    development checkout) take precedence. Otherwise both files are fetched
    from HF_MODEL_REPO with ``hf_hub_download`` using the token taken from the
    HF_TOKEN environment variable.

    Returns:
        tuple: (pose_model_path, stgcn_checkpoint_path)

    Raises:
        RuntimeError: HF_TOKEN is unset or blank, the download fails, or a
            downloaded file is missing or smaller than 1 MB.
    """
    # Prefer local copies (development environment).
    local_pose = Path("yolo11m-pose.pt")
    local_stgcn = Path("runs/stgcn_binary_exp2_fixed_graph/best_acc.pth")
    if local_pose.exists() and local_stgcn.exists():
        return str(local_pose), str(local_stgcn)

    # The repository is private, so a token is mandatory. An empty or
    # whitespace-only variable is treated the same as an unset one; otherwise
    # it would only surface later as a generic download failure. Surrounding
    # whitespace (e.g. a trailing newline from a pasted secret) is dropped.
    token = (os.environ.get("HF_TOKEN") or "").strip()
    if not token:
        raise RuntimeError(
            "HF_TOKEN ํ™˜๊ฒฝ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. "
            "Private ๋ชจ๋ธ ์ €์žฅ์†Œ ์ ‘๊ทผ์„ ์œ„ํ•ด HF_TOKEN์ด ํ•„์š”ํ•ฉ๋‹ˆ๋‹ค. "
            "HF Spaces์˜ ๊ฒฝ์šฐ Settings > Secrets์—์„œ ์„ค์ •ํ•˜์„ธ์š”."
        )

    try:
        pose_model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename="yolo11m-pose.pt",
            token=token
        )
        stgcn_checkpoint = hf_hub_download(
            repo_id=HF_MODEL_REPO,
            filename="best_acc.pth",
            token=token
        )
    except Exception as e:
        # Any hub/network/auth failure is reported as one RuntimeError that
        # keeps the original exception as its cause.
        raise RuntimeError(
            f"๋ชจ๋ธ ๋‹ค์šด๋กœ๋“œ ์‹คํŒจ: {e}\n"
            f"์ €์žฅ์†Œ: {HF_MODEL_REPO}\n"
            f"HF_TOKEN์ด ์˜ฌ๋ฐ”๋ฅด๊ฒŒ ์„ค์ •๋˜์—ˆ๋Š”์ง€ ํ™•์ธํ•˜์„ธ์š”."
        ) from e

    # Verify that the downloaded files are really on disk.
    pose_path = Path(pose_model_path)
    stgcn_path = Path(stgcn_checkpoint)
    if not pose_path.exists():
        raise RuntimeError(f"Pose ๋ชจ๋ธ ํŒŒ์ผ์ด ์กด์žฌํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค: {pose_model_path}")
    if not stgcn_path.exists():
        raise RuntimeError(f"ST-GCN ์ฒดํฌํฌ์ธํŠธ ํŒŒ์ผ์ด ์กด์žฌํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค: {stgcn_checkpoint}")

    # Size sanity check: a file under 1 MB is treated as corrupted.
    pose_size = pose_path.stat().st_size
    stgcn_size = stgcn_path.stat().st_size
    if pose_size < 1_000_000:
        raise RuntimeError(f"Pose ๋ชจ๋ธ ํŒŒ์ผ์ด ๋„ˆ๋ฌด ์ž‘์Šต๋‹ˆ๋‹ค: {pose_size} bytes")
    if stgcn_size < 1_000_000:
        raise RuntimeError(f"ST-GCN ์ฒดํฌํฌ์ธํŠธ ํŒŒ์ผ์ด ๋„ˆ๋ฌด ์ž‘์Šต๋‹ˆ๋‹ค: {stgcn_size} bytes")

    return pose_model_path, stgcn_checkpoint
# -----------------------------------------------------------------------------
# Pipeline initialisation (lazy loading)
# -----------------------------------------------------------------------------
_pipeline = None


def get_pipeline():
    """Return the shared FallDetectionPipeline, building it on first use."""
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    # Imported lazily so the project package is only loaded when needed.
    from pipeline.core.pipeline import FallDetectionPipeline

    # Resolve the model files (local copies or cached Hub downloads).
    pose_model_path, stgcn_checkpoint = download_models()

    pipeline_options = {
        "pose_model_path": pose_model_path,
        "stgcn_checkpoint": stgcn_checkpoint,
        "window_size": 60,
        "conf_threshold": 0.5,
        # Guideline recommendation: 0.8-0.9 (false positive <5%)
        "fall_threshold": 0.85,
        "temporal_window": 5,
        "stgcn_stride": 5,
        "alert_duration": 150,
        # 2.5 s @ 30fps with stride=5 (guideline: 2-3 s)
        "post_fall_frames": 15,
        "device": str(device),
        "debug": False,
        "headless": False,
        "viz_keypoints": "all",
        "viz_scale": 1.0,
        "viz_optimized": True,
    }
    _pipeline = FallDetectionPipeline(**pipeline_options)
    return _pipeline
# -----------------------------------------------------------------------------
# Probability graph
# -----------------------------------------------------------------------------
def create_probability_graph(
    frame_indices: list,
    probabilities: list,
    fall_threshold: float = 0.7
) -> go.Figure:
    """
    Build the fall-probability plot.

    Args:
        frame_indices: frame indices used as x values
        probabilities: fall probabilities (0.0-1.0) used as y values
        fall_threshold: decision threshold drawn as a horizontal line

    Returns:
        Plotly Figure object
    """
    # Probability curve, filled down to the x axis.
    probability_trace = go.Scatter(
        x=frame_indices,
        y=probabilities,
        mode='lines',
        name='Fall Probability',
        line={"color": '#4682B4', "width": 2},
        fill='tozeroy',
        fillcolor='rgba(70, 130, 180, 0.3)',
    )

    fig = go.Figure()
    fig.add_trace(probability_trace)

    # Dashed red threshold marker.
    threshold_label = f"Threshold ({fall_threshold})"
    fig.add_hline(
        y=fall_threshold,
        line_dash="dash",
        line_color="red",
        annotation_text=threshold_label,
        annotation_position="right",
    )

    # Layout: fixed 0..1 y range, horizontal legend above the plot area.
    legend_options = {
        "orientation": "h",
        "yanchor": "bottom",
        "y": 1.02,
        "xanchor": "right",
        "x": 1,
    }
    fig.update_layout(
        title="Fall Detection Probability Over Time",
        xaxis_title="Frame",
        yaxis_title="Probability",
        yaxis={"range": [0, 1]},
        template="plotly_white",
        height=300,
        margin={"l": 50, "r": 50, "t": 50, "b": 50},
        showlegend=True,
        legend=legend_options,
    )
    return fig
# -----------------------------------------------------------------------------
# Main inference function
# -----------------------------------------------------------------------------
# Frame rate handed to the output writer when the input reports none.
_FALLBACK_FPS = 30.0


def _reencode_to_h264(source_path: str) -> str:
    """Re-encode ``source_path`` to H.264 with ffmpeg (browser-compatible).

    Returns the path of the H.264 file on success, in which case the source
    file is deleted. If ffmpeg is missing, exits with an error, or leaves an
    empty file, the partial output is removed and ``source_path`` is returned
    untouched so the caller still has a playable fallback.
    """
    # Security: NamedTemporaryFile instead of a predictable name (CWE-377).
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        target_path = tmp.name

    try:
        # Security: argument list, no shell (prevents shell injection).
        completed = subprocess.run(
            [
                'ffmpeg', '-y', '-loglevel', 'quiet', '-i', source_path,
                '-c:v', 'libx264', '-preset', 'fast', '-crf', '23',
                target_path,
            ],
            check=False,
            capture_output=True
        )
        # The temp file exists even when ffmpeg failed, so existence alone
        # proves nothing: require a clean exit and a non-empty result.
        converted = (
            completed.returncode == 0 and os.path.getsize(target_path) > 0
        )
    except OSError:
        # ffmpeg is not installed / not executable, or the output vanished.
        converted = False

    if converted:
        if os.path.exists(source_path):
            os.remove(source_path)
        return target_path

    if os.path.exists(target_path):
        os.remove(target_path)
    return source_path


@gpu_decorator(duration=120)
def process_video(
    video_path: str,
    fall_threshold: float,
    viz_keypoints: str,
    progress: gr.Progress = gr.Progress()
) -> Tuple[Optional[str], Optional[go.Figure], str]:
    """
    Run fall detection on a video and produce the annotated result.

    Args:
        video_path: path of the input video (None when nothing was uploaded)
        fall_threshold: fall decision threshold (0.0-1.0)
        viz_keypoints: keypoint display mode ('all' or 'major')
        progress: Gradio progress reporter

    Returns:
        output_video_path: path of the result video (None on failure)
        probability_graph: probability plot (None if no confidence was recorded)
        result_text: final verdict, or an error description on failure
    """
    if video_path is None:
        return None, None, "๋น„๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”."

    cap = None
    out = None
    # Temp file currently owned by this call; removed if processing fails.
    pending_output = None
    failed = False
    try:
        # Load the pipeline and apply the per-request settings.
        progress(0.1, desc="๋ชจ๋ธ ๋กœ๋”ฉ ์ค‘...")
        pipeline = get_pipeline()
        pipeline.fall_threshold = fall_threshold
        pipeline.stgcn_classifier.fall_threshold = fall_threshold
        pipeline.viz_keypoints = viz_keypoints
        pipeline.reset()

        # Open the video.
        progress(0.2, desc="๋น„๋””์˜ค ์—ด๊ธฐ...")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return None, None, "๋น„๋””์˜ค๋ฅผ ์—ด ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."

        # Video metadata.
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Length check against the 120 s GPU timeout.
        if fps > 0:
            video_duration = total_frames / fps
            # Estimated processing time: about 1.5x real time + 10 s encoding.
            estimated_time = video_duration * 1.5 + 10
            if estimated_time > 110:  # keep some margin below the timeout
                return None, None, (
                    f"๋น„๋””์˜ค๊ฐ€ ๋„ˆ๋ฌด ๊น๋‹ˆ๋‹ค. "
                    f"๋น„๋””์˜ค ๊ธธ์ด: {video_duration:.1f}์ดˆ, "
                    f"์˜ˆ์ƒ ์ฒ˜๋ฆฌ ์‹œ๊ฐ„: {estimated_time:.1f}์ดˆ (์ œํ•œ: 110์ดˆ). "
                    f"60์ดˆ ์ด๋‚ด์˜ ๋น„๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”."
                )
        else:
            # No usable frame rate in the metadata: a writer cannot be
            # created with fps <= 0, so use a fixed fallback rate.
            fps = _FALLBACK_FPS

        # Output video (security: NamedTemporaryFile).
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        pending_output = output_path
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # The info panel adds 80 px to the frame height.
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height + 80))
        if not out.isOpened():
            raise RuntimeError(f"cv2.VideoWriter could not open {output_path}")

        # Processing loop.
        frame_idx = 0
        frame_indices = []
        probabilities = []
        fall_detected = False
        max_confidence = 0.0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            vis_frame, info = pipeline.process_frame(frame, frame_idx)

            # Record the probability when the pipeline reports one.
            confidence = info['confidence']
            if confidence is not None:
                frame_indices.append(frame_idx)
                probabilities.append(confidence)
                max_confidence = max(max_confidence, confidence)

            if info['alert']:
                fall_detected = True

            out.write(vis_frame)
            frame_idx += 1

            # Progress update every 10 frames. The reported frame count can
            # be zero or too small, so guard the division and clamp.
            if frame_idx % 10 == 0:
                if total_frames > 0:
                    fraction = min(frame_idx / total_frames, 1.0)
                else:
                    fraction = 0.0
                progress_val = 0.2 + 0.7 * fraction
                progress(progress_val, desc=f"์ฒ˜๋ฆฌ ์ค‘... ({frame_idx}/{total_frames})")

        # Release both handles before ffmpeg reads the written file.
        cap.release()
        cap = None
        out.release()
        out = None

        # Re-encode to H.264; falls back to the mp4v file if that fails.
        progress(0.9, desc="๋น„๋””์˜ค ์ธ์ฝ”๋”ฉ ์ค‘...")
        final_output = _reencode_to_h264(output_path)
        pending_output = final_output

        # Probability graph.
        progress(0.95, desc="๊ทธ๋ž˜ํ”„ ์ƒ์„ฑ ์ค‘...")
        if frame_indices and probabilities:
            fig = create_probability_graph(frame_indices, probabilities, fall_threshold)
        else:
            fig = None

        # Final verdict.
        progress(1.0, desc="์™„๋ฃŒ!")
        if fall_detected:
            result_text = f"[FALL DETECTED] ๋‚™์ƒ์ด ๊ฐ์ง€๋˜์—ˆ์Šต๋‹ˆ๋‹ค! (์ตœ๋Œ€ ํ™•๋ฅ : {max_confidence:.1%})"
        else:
            result_text = f"[Non-Fall] ๋‚™์ƒ์ด ๊ฐ์ง€๋˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค. (์ตœ๋Œ€ ํ™•๋ฅ : {max_confidence:.1%})"
        return final_output, fig, result_text
    except Exception as e:
        # UI boundary: report the failure in the result box instead of raising.
        import traceback
        error_msg = f"์ฒ˜๋ฆฌ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}\n{traceback.format_exc()}"
        failed = True
        return None, None, error_msg
    finally:
        # Always free the capture/writer handles, including on early return.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
        # A failed run must not leave its temp video behind (best effort).
        if failed and pending_output is not None:
            try:
                if os.path.exists(pending_output):
                    os.remove(pending_output)
            except OSError:
                pass
# -----------------------------------------------------------------------------
# Gradio UI
# -----------------------------------------------------------------------------
def create_demo() -> gr.Blocks:
    """Build the Gradio Blocks UI and wire it to process_video."""
    # Static markdown shown at the top and bottom of the page.
    intro_markdown = """
# Fall Detection Demo
YOLOv11-Pose + ST-GCN 2-stage ํŒŒ์ดํ”„๋ผ์ธ์„ ์‚ฌ์šฉํ•œ ์‹ค์‹œ๊ฐ„ ๋‚™์ƒ ๊ฐ์ง€ ๋ฐ๋ชจ์ž…๋‹ˆ๋‹ค.
๋น„๋””์˜ค๋ฅผ ์—…๋กœ๋“œํ•˜๋ฉด ๋‚™์ƒ ์—ฌ๋ถ€๋ฅผ ๋ถ„์„ํ•˜๊ณ , ๊ฒฐ๊ณผ ๋น„๋””์˜ค์™€ ํ™•๋ฅ  ๊ทธ๋ž˜ํ”„๋ฅผ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.
**ํŒŒ์ดํ”„๋ผ์ธ ๊ตฌ์„ฑ:**
- Stage 1: YOLOv11m-pose (Pose Estimation)
- Stage 2: ST-GCN (Temporal Classification)
- Window Size: 60 frames (2์ดˆ @ 30fps)
"""
    footer_markdown = """
---
**References:**
- [YOLOv11](https://github.com/ultralytics/ultralytics) - Pose Estimation
- [ST-GCN](https://arxiv.org/abs/1801.07455) - Spatial Temporal Graph Convolutional Networks
- AI Hub Fall Detection Dataset
"""

    with gr.Blocks(theme=custom_theme, css=css) as demo:
        gr.Markdown(intro_markdown, elem_id="main-title")

        with gr.Row():
            # Left column: inputs.
            with gr.Column(scale=1):
                gr.Markdown("### ์ž…๋ ฅ")
                video_input = gr.Video(
                    label="๋น„๋””์˜ค ์—…๋กœ๋“œ",
                    sources=["upload"],
                )
                with gr.Accordion("๊ณ ๊ธ‰ ์„ค์ •", open=False):
                    fall_threshold = gr.Slider(
                        minimum=0.7,
                        maximum=0.95,
                        value=0.85,
                        step=0.05,
                        label="๋‚™์ƒ ํŒ์ • ์ž„๊ณ„๊ฐ’",
                        info="๊ถŒ์žฅ: 0.8-0.9 (false positive <5% ๋ชฉํ‘œ)"
                    )
                    viz_keypoints = gr.Radio(
                        choices=["all", "major"],
                        value="all",
                        label="ํ‚คํฌ์ธํŠธ ํ‘œ์‹œ",
                        info="all: ์ „์ฒด 17๊ฐœ, major: ์ฃผ์š” 9๊ฐœ"
                    )
                submit_btn = gr.Button(
                    "๋ถ„์„ ์‹œ์ž‘",
                    variant="primary",
                    elem_classes="submit-btn"
                )

            # Right column: outputs.
            with gr.Column(scale=1):
                gr.Markdown("### ๊ฒฐ๊ณผ")
                result_text = gr.Textbox(
                    label="ํŒ์ • ๊ฒฐ๊ณผ",
                    lines=2,
                    interactive=False
                )
                video_output = gr.Video(
                    label="๊ฒฐ๊ณผ ๋น„๋””์˜ค",
                )
                prob_graph = gr.Plot(
                    label="๋‚™์ƒ ํ™•๋ฅ  ๊ทธ๋ž˜ํ”„",
                )

        # Example videos: collected from the "examples" directory next to
        # this file, grouped by extension in the order listed below.
        gr.Markdown("### ์˜ˆ์ œ ๋น„๋””์˜ค")
        example_dir = Path(__file__).parent / "examples"
        if example_dir.exists():
            examples = [
                str(found)
                for pattern in ["*.mp4", "*.avi", "*.mov"]
                for found in example_dir.glob(pattern)
            ]
        else:
            examples = []

        if examples:
            # At most three examples, each with the default settings.
            example_rows = [[ex, 0.85, "all"] for ex in examples[:3]]
            gr.Examples(
                examples=example_rows,
                inputs=[video_input, fall_threshold, viz_keypoints],
                outputs=[video_output, prob_graph, result_text],
                fn=process_video,
                cache_examples=False,
            )

        # Event wiring.
        submit_btn.click(
            fn=process_video,
            inputs=[video_input, fall_threshold, viz_keypoints],
            outputs=[video_output, prob_graph, result_text],
        )

        # Footer.
        gr.Markdown(footer_markdown)

    return demo
# -----------------------------------------------------------------------------
# Entry point
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    # Server settings: bind to all interfaces on port 7860, request a share
    # link, and show errors in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
        "show_error": True,
    }
    demo = create_demo()
    # Queue holds at most 10 pending requests.
    demo.queue(max_size=10).launch(**launch_options)