# (HF Spaces page-status banner removed — scraper residue, not part of the source)
#!/usr/bin/env python3
"""
Fall Detection Gradio App

Demo for fall detection using a YOLOv11-Pose + ST-GCN two-stage pipeline.
Runs in the HF Spaces Zero GPU environment.

Usage (local):
    python demo_gradio/app.py

Usage (HF Spaces):
    app.py is executed automatically.

Author: Fall Detection Pipeline Team
Date: 2025-11-26
"""
| import os | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from typing import Iterable, Optional, Tuple | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import torch | |
| from gradio.themes import Soft | |
| from gradio.themes.utils import colors, fonts, sizes | |
| from huggingface_hub import hf_hub_download | |
# Make the project root importable so `pipeline.*` modules resolve.
# Layout: pipeline/demo_gradio/app.py -> demo_gradio -> pipeline -> project root.
PROJECT_ROOT = Path(__file__).parent.parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

# Zero GPU compatibility: the `spaces` package only exists on HF Spaces.
try:
    import spaces
except ImportError:
    SPACES_AVAILABLE = False
else:
    SPACES_AVAILABLE = True
| # ----------------------------------------------------------------------------- | |
| # ์ปค์คํ ํ ๋ง (PRITHIVSAKTHIUR ์คํ์ผ) | |
| # ----------------------------------------------------------------------------- | |
# Register a steel-blue palette as a custom Gradio color
# (PRITHIVSAKTHIUR-style theme).
colors.custom_color = colors.Color(
    name="custom_color",
    c50="#EBF3F8",
    c100="#D3E5F0",
    c200="#A8CCE1",
    c300="#7DB3D2",
    c400="#529AC3",
    c500="#4682B4",
    c600="#3E72A0",
    c700="#36638C",
    c800="#2E5378",
    c900="#264364",
    c950="#1E3450",
)
class CustomTheme(Soft):
    """Soft-derived Gradio theme with a steel-blue gradient look.

    Uses a gray primary hue, the custom steel-blue secondary hue, slate
    neutrals and Outfit / IBM Plex Mono fonts, then layers gradient
    backgrounds, heavier borders and drop shadows on top of Soft.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.gray,
        secondary_hue: colors.Color | str = colors.custom_color,
        neutral_hue: colors.Color | str = colors.slate,
        text_size: sizes.Size | str = sizes.text_lg,
        font: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("Outfit"),
            "Arial",
            "sans-serif",
        ),
        font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
            fonts.GoogleFont("IBM Plex Mono"),
            "ui-monospace",
            "monospace",
        ),
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # Visual overrides applied on top of the Soft defaults.
        super().set(
            background_fill_primary="*primary_50",
            body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)",
            button_primary_text_color="white",
            button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)",
            button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)",
            slider_color="*secondary_500",
            block_title_text_weight="600",
            block_border_width="3px",
            block_shadow="*shadow_drop_lg",
            button_primary_shadow="*shadow_drop_lg",
        )


custom_theme = CustomTheme()
| # ----------------------------------------------------------------------------- | |
| # CSS ์คํ์ผ | |
| # ----------------------------------------------------------------------------- | |
# Page-level CSS: layout width, title size, submit-button colors, and the
# (fall-detected / non-fall) result-label styles referenced by the UI.
css = """
#col-container { margin: 0 auto; max-width: 1200px; }
#main-title h1 { font-size: 2.3em !important; }
.submit-btn {
background-color: #4682B4 !important;
color: white !important;
}
.submit-btn:hover {
background-color: #5A9BD4 !important;
}
.result-label {
font-size: 1.5em !important;
font-weight: bold !important;
padding: 10px !important;
border-radius: 8px !important;
}
.fall-detected {
background-color: #FF4444 !important;
color: white !important;
}
.non-fall {
background-color: #44BB44 !important;
color: white !important;
}
"""
| # ----------------------------------------------------------------------------- | |
| # ๋๋ฐ์ด์ค ์ค์ | |
| # ----------------------------------------------------------------------------- | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # ----------------------------------------------------------------------------- | |
| # GPU ๋ฐ์ฝ๋ ์ดํฐ (๋ก์ปฌ/HF Spaces ํธํ) | |
| # ----------------------------------------------------------------------------- | |
def gpu_decorator(duration: int = 120):
    """Return a decorator: allocates a GPU on HF Spaces, a no-op locally.

    Args:
        duration: maximum GPU allocation time in seconds (Spaces only).
    """
    def decorator(func):
        # Outside of Spaces there is nothing to allocate — pass through.
        if not SPACES_AVAILABLE:
            return func
        return spaces.GPU(duration=duration)(func)
    return decorator
| # ----------------------------------------------------------------------------- | |
| # ๋ชจ๋ธ ๋ค์ด๋ก๋ (HuggingFace Hub) | |
| # ----------------------------------------------------------------------------- | |
HF_MODEL_REPO = "YoungjaeDev/fall-detection-models"


def download_models() -> tuple[str, str]:
    """Fetch model weights from the HuggingFace Hub (results are cached).

    Returns:
        tuple: (pose_model_path, stgcn_checkpoint_path)

    Raises:
        RuntimeError: when the download fails or the files do not validate.
    """
    # Development shortcut: prefer local weight files when both exist.
    local_pose = Path("yolo11m-pose.pt")
    local_stgcn = Path("runs/stgcn_binary_exp2_fixed_graph/best_acc.pth")
    if local_pose.exists() and local_stgcn.exists():
        return str(local_pose), str(local_stgcn)

    # Private repositories require the HF_TOKEN environment variable.
    token = os.environ.get("HF_TOKEN")
    if token is None:
        raise RuntimeError(
            "HF_TOKEN ํ๊ฒฝ๋ณ์๊ฐ ์ค์ ๋์ง ์์์ต๋๋ค. "
            "Private ๋ชจ๋ธ ์ ์ฅ์ ์ ๊ทผ์ ์ํด HF_TOKEN์ด ํ์ํฉ๋๋ค. "
            "HF Spaces์ ๊ฒฝ์ฐ Settings > Secrets์์ ์ค์ ํ์ธ์."
        )

    try:
        pose_model_path = hf_hub_download(
            repo_id=HF_MODEL_REPO, filename="yolo11m-pose.pt", token=token
        )
        stgcn_checkpoint = hf_hub_download(
            repo_id=HF_MODEL_REPO, filename="best_acc.pth", token=token
        )
    except Exception as e:
        raise RuntimeError(
            f"๋ชจ๋ธ ๋ค์ด๋ก๋ ์คํจ: {e}\n"
            f"์ ์ฅ์: {HF_MODEL_REPO}\n"
            f"HF_TOKEN์ด ์ฌ๋ฐ๋ฅด๊ฒ ์ค์ ๋์๋์ง ํ์ธํ์ธ์."
        ) from e

    downloaded = (
        ("Pose ๋ชจ๋ธ", pose_model_path),
        ("ST-GCN ์ฒดํฌํฌ์ธํธ", stgcn_checkpoint),
    )

    # Existence check for both files first (same order as before).
    for label, path_str in downloaded:
        if not Path(path_str).exists():
            raise RuntimeError(f"{label} ํ์ผ์ด ์กด์ฌํ์ง ์์ต๋๋ค: {path_str}")

    # Size sanity check: anything under 1 MB is almost certainly corrupt.
    for label, path_str in downloaded:
        size = Path(path_str).stat().st_size
        if size < 1_000_000:
            raise RuntimeError(f"{label} ํ์ผ์ด ๋๋ฌด ์์ต๋๋ค: {size} bytes")

    return pose_model_path, stgcn_checkpoint
| # ----------------------------------------------------------------------------- | |
| # ํ์ดํ๋ผ์ธ ์ด๊ธฐํ (์ง์ฐ ๋ก๋ฉ) | |
| # ----------------------------------------------------------------------------- | |
_pipeline = None


def get_pipeline():
    """Lazily construct and return the FallDetectionPipeline singleton."""
    global _pipeline
    if _pipeline is not None:
        return _pipeline

    # Imported here to defer the heavy model dependencies until first use.
    from pipeline.core.pipeline import FallDetectionPipeline

    # Weight files are fetched (and cached) on demand.
    pose_model_path, stgcn_checkpoint = download_models()

    _pipeline = FallDetectionPipeline(
        pose_model_path=pose_model_path,
        stgcn_checkpoint=stgcn_checkpoint,
        window_size=60,
        conf_threshold=0.5,
        fall_threshold=0.85,   # guideline recommendation: 0.8-0.9 (false positives < 5%)
        temporal_window=5,
        stgcn_stride=5,
        alert_duration=150,
        post_fall_frames=15,   # ~2.5 s @ 30 fps with stride=5 (guideline: 2-3 s)
        device=str(device),
        debug=False,
        headless=False,
        viz_keypoints="all",
        viz_scale=1.0,
        viz_optimized=True,
    )
    return _pipeline
| # ----------------------------------------------------------------------------- | |
| # ํ๋ฅ ๊ทธ๋ํ ์์ฑ | |
| # ----------------------------------------------------------------------------- | |
def create_probability_graph(
    frame_indices: list,
    probabilities: list,
    fall_threshold: float = 0.7
) -> go.Figure:
    """Build the fall-probability-over-time Plotly figure.

    Args:
        frame_indices: frame indices (x axis).
        probabilities: fall probabilities in [0.0, 1.0] (y axis).
        fall_threshold: decision threshold drawn as a dashed red line.

    Returns:
        A Plotly Figure object.
    """
    # Filled probability curve.
    prob_trace = go.Scatter(
        x=frame_indices,
        y=probabilities,
        mode='lines',
        name='Fall Probability',
        line=dict(color='#4682B4', width=2),
        fill='tozeroy',
        fillcolor='rgba(70, 130, 180, 0.3)',
    )
    fig = go.Figure(data=[prob_trace])

    # Dashed horizontal marker at the decision threshold.
    fig.add_hline(
        y=fall_threshold,
        line_dash="dash",
        line_color="red",
        annotation_text=f"Threshold ({fall_threshold})",
        annotation_position="right",
    )

    # Layout: fixed [0, 1] y-range and a horizontal legend above the plot.
    fig.update_layout(
        title="Fall Detection Probability Over Time",
        xaxis_title="Frame",
        yaxis_title="Probability",
        yaxis=dict(range=[0, 1]),
        template="plotly_white",
        height=300,
        margin=dict(l=50, r=50, t=50, b=50),
        showlegend=True,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )
    return fig
| # ----------------------------------------------------------------------------- | |
| # ๋ฉ์ธ ์ถ๋ก ํจ์ | |
| # ----------------------------------------------------------------------------- | |
# Request a GPU slot on HF Spaces for the inference function; locally this is
# a no-op. (The decorator existed but was never applied, so on Spaces the
# function previously ran without a GPU allocation.)
@gpu_decorator(duration=120)
def process_video(
    video_path: str,
    fall_threshold: float,
    viz_keypoints: str,
    progress: gr.Progress = gr.Progress()
) -> Tuple[Optional[str], Optional[go.Figure], str]:
    """Run fall detection over an uploaded video.

    Args:
        video_path: path of the input video.
        fall_threshold: fall decision threshold (0.0-1.0).
        viz_keypoints: keypoint display mode ('all' or 'major').
        progress: Gradio progress reporter.

    Returns:
        (output_video_path, probability_graph, result_text); the first two
        are None when no video was supplied or processing failed.
    """
    if video_path is None:
        return None, None, "๋น๋์ค๋ฅผ ์ ๋ก๋ํด์ฃผ์ธ์."

    try:
        # Lazy-load the pipeline and apply the per-request settings.
        progress(0.1, desc="๋ชจ๋ธ ๋ก๋ฉ ์ค...")
        pipeline = get_pipeline()
        pipeline.fall_threshold = fall_threshold
        pipeline.stgcn_classifier.fall_threshold = fall_threshold
        pipeline.viz_keypoints = viz_keypoints
        pipeline.reset()

        progress(0.2, desc="๋น๋์ค ์ด๊ธฐ...")
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return None, None, "๋น๋์ค๋ฅผ ์ด ์ ์์ต๋๋ค."

        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Reject videos that would blow the 120 s Zero-GPU budget.
        # Estimate: ~1.5x realtime processing plus ~10 s for encoding.
        if fps > 0:
            video_duration = total_frames / fps
            estimated_time = video_duration * 1.5 + 10
            if estimated_time > 110:  # leave headroom under the 120 s timeout
                cap.release()
                return None, None, (
                    f"๋น๋์ค๊ฐ ๋๋ฌด ๊น๋๋ค. "
                    f"๋น๋์ค ๊ธธ์ด: {video_duration:.1f}์ด, "
                    f"์์ ์ฒ๋ฆฌ ์๊ฐ: {estimated_time:.1f}์ด (์ ํ: 110์ด). "
                    f"60์ด ์ด๋ด์ ๋น๋์ค๋ฅผ ์ ๋ก๋ํ์ธ์."
                )

        # Output video setup (NamedTemporaryFile avoids predictable paths).
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_path = tmp.name
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        # Info panel adds 80 px of height below each frame.
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height + 80))

        frame_idx = 0
        frame_indices = []
        probabilities = []
        fall_detected = False
        max_confidence = 0.0

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                vis_frame, info = pipeline.process_frame(frame, frame_idx)

                # Record the per-window fall probability when one is available.
                if info['confidence'] is not None:
                    frame_indices.append(frame_idx)
                    probabilities.append(info['confidence'])
                    max_confidence = max(max_confidence, info['confidence'])

                if info['alert']:
                    fall_detected = True

                out.write(vis_frame)
                frame_idx += 1

                # BUGFIX: some containers report CAP_PROP_FRAME_COUNT == 0,
                # which previously caused a ZeroDivisionError here.
                if frame_idx % 10 == 0 and total_frames > 0:
                    progress_val = 0.2 + 0.7 * (frame_idx / total_frames)
                    progress(progress_val, desc=f"์ฒ๋ฆฌ ์ค... ({frame_idx}/{total_frames})")
        finally:
            # BUGFIX: always release the capture and writer, even when
            # process_frame raises, so no handle/file is leaked.
            cap.release()
            out.release()

        # Re-encode with H.264 for browser compatibility.
        progress(0.9, desc="๋น๋์ค ์ธ์ฝ๋ฉ ์ค...")
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            output_h264 = tmp.name
        # Argument-list invocation (shell=False) avoids shell injection.
        result = subprocess.run(
            [
                'ffmpeg', '-y', '-i', output_path,
                '-c:v', 'libx264', '-preset', 'fast', '-crf', '23',
                output_h264, '-loglevel', 'quiet'
            ],
            check=False,
            capture_output=True
        )

        # BUGFIX: the previous code deleted output_path unconditionally and
        # then "fell back" to that already-deleted path; NamedTemporaryFile
        # also pre-creates output_h264, so os.path.exists() could never detect
        # an ffmpeg failure. Check the return code and output size instead,
        # and only delete the file we are not serving.
        h264_ok = (
            result.returncode == 0
            and os.path.exists(output_h264)
            and os.path.getsize(output_h264) > 0
        )
        if h264_ok:
            final_output = output_h264
            if os.path.exists(output_path):
                os.remove(output_path)
        else:
            final_output = output_path  # serve the mp4v file as a fallback
            if os.path.exists(output_h264):
                os.remove(output_h264)

        # Probability graph (only when at least one window was classified).
        progress(0.95, desc="๊ทธ๋ํ ์์ฑ ์ค...")
        if frame_indices and probabilities:
            fig = create_probability_graph(frame_indices, probabilities, fall_threshold)
        else:
            fig = None

        # Final verdict text.
        progress(1.0, desc="์๋ฃ!")
        if fall_detected:
            result_text = f"[FALL DETECTED] ๋์์ด ๊ฐ์ง๋์์ต๋๋ค! (์ต๋ ํ๋ฅ : {max_confidence:.1%})"
        else:
            result_text = f"[Non-Fall] ๋์์ด ๊ฐ์ง๋์ง ์์์ต๋๋ค. (์ต๋ ํ๋ฅ : {max_confidence:.1%})"

        return final_output, fig, result_text

    except Exception as e:
        import traceback
        error_msg = f"์ฒ๋ฆฌ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}\n{traceback.format_exc()}"
        return None, None, error_msg
| # ----------------------------------------------------------------------------- | |
| # Gradio UI | |
| # ----------------------------------------------------------------------------- | |
def create_demo() -> gr.Blocks:
    """Build and return the Gradio Blocks demo.

    Layout: header markdown, an input column (video upload + advanced
    settings + submit button), an output column (verdict text, annotated
    video, probability plot), optional bundled example videos, and a footer.
    """
    with gr.Blocks(theme=custom_theme, css=css) as demo:
        # Header / description (shown above the two columns).
        gr.Markdown(
            """
# Fall Detection Demo
YOLOv11-Pose + ST-GCN 2-stage ํ์ดํ๋ผ์ธ์ ์ฌ์ฉํ ์ค์๊ฐ ๋์ ๊ฐ์ง ๋ฐ๋ชจ์ ๋๋ค.
๋น๋์ค๋ฅผ ์ ๋ก๋ํ๋ฉด ๋์ ์ฌ๋ถ๋ฅผ ๋ถ์ํ๊ณ , ๊ฒฐ๊ณผ ๋น๋์ค์ ํ๋ฅ ๊ทธ๋ํ๋ฅผ ์ ๊ณตํฉ๋๋ค.
**ํ์ดํ๋ผ์ธ ๊ตฌ์ฑ:**
- Stage 1: YOLOv11m-pose (Pose Estimation)
- Stage 2: ST-GCN (Temporal Classification)
- Window Size: 60 frames (2์ด @ 30fps)
""",
            elem_id="main-title"
        )
        with gr.Row():
            with gr.Column(scale=1):
                # Input section: upload widget, advanced settings, submit.
                gr.Markdown("### ์ ๋ ฅ")
                video_input = gr.Video(
                    label="๋น๋์ค ์ ๋ก๋",
                    sources=["upload"],
                )
                with gr.Accordion("๊ณ ๊ธ ์ค์ ", open=False):
                    # Slider range mirrors the recommended 0.8-0.9 band.
                    fall_threshold = gr.Slider(
                        minimum=0.7,
                        maximum=0.95,
                        value=0.85,
                        step=0.05,
                        label="๋์ ํ์ ์๊ณ๊ฐ",
                        info="๊ถ์ฅ: 0.8-0.9 (false positive <5% ๋ชฉํ)"
                    )
                    viz_keypoints = gr.Radio(
                        choices=["all", "major"],
                        value="all",
                        label="ํคํฌ์ธํธ ํ์",
                        info="all: ์ ์ฒด 17๊ฐ, major: ์ฃผ์ 9๊ฐ"
                    )
                submit_btn = gr.Button(
                    "๋ถ์ ์์",
                    variant="primary",
                    elem_classes="submit-btn"
                )
            with gr.Column(scale=1):
                # Output section: verdict text, annotated video, plot.
                gr.Markdown("### ๊ฒฐ๊ณผ")
                result_text = gr.Textbox(
                    label="ํ์ ๊ฒฐ๊ณผ",
                    lines=2,
                    interactive=False
                )
                video_output = gr.Video(
                    label="๊ฒฐ๊ณผ ๋น๋์ค",
                )
                prob_graph = gr.Plot(
                    label="๋์ ํ๋ฅ ๊ทธ๋ํ",
                )
        # Example videos: scan the bundled "examples" dir next to this file
        # and offer the first three matches (if any).
        gr.Markdown("### ์์ ๋น๋์ค")
        example_dir = Path(__file__).parent / "examples"
        examples = []
        if example_dir.exists():
            for ext in ["*.mp4", "*.avi", "*.mov"]:
                examples.extend([str(p) for p in example_dir.glob(ext)])
        if examples:
            gr.Examples(
                examples=[[ex, 0.85, "all"] for ex in examples[:3]],
                inputs=[video_input, fall_threshold, viz_keypoints],
                outputs=[video_output, prob_graph, result_text],
                fn=process_video,
                cache_examples=False,
            )
        # Wire the submit button to the inference function.
        submit_btn.click(
            fn=process_video,
            inputs=[video_input, fall_threshold, viz_keypoints],
            outputs=[video_output, prob_graph, result_text],
        )
        # Footer with references.
        gr.Markdown(
            """
---
**References:**
- [YOLOv11](https://github.com/ultralytics/ultralytics) - Pose Estimation
- [ST-GCN](https://arxiv.org/abs/1801.07455) - Spatial Temporal Graph Convolutional Networks
- AI Hub Fall Detection Dataset
"""
        )
    return demo
| # ----------------------------------------------------------------------------- | |
| # ๋ฉ์ธ ์คํ | |
| # ----------------------------------------------------------------------------- | |
if __name__ == "__main__":
    # queue() caps concurrent jobs at 10; share=True yields a public URL
    # when running locally (it is ignored on HF Spaces).
    app = create_demo()
    app.queue(max_size=10).launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )