#!/usr/bin/env python3 """ Fall Detection Gradio App (Batch Processing Pipeline) YOLOv11-Pose + ST-GCN 2-stage 파이프라인을 사용한 낙상 감지 데모입니다. 배치 처리로 최적화되어 빠른 추론 속도를 제공합니다. Pipeline: 1. decord로 전체 프레임 배치 로드 2. YOLO Pose 배치 추론 → keypoints 누적 3. 윈도우 단위 ST-GCN 배치 추론 4. 낙상 시점 -1s ~ +2s 구간만 시각화 사용법 (로컬): python pipeline/demo_gradio/app.py 작성자: Fall Detection Pipeline Team 작성일: 2025-11-27 """ import json import os import subprocess import sys import tempfile from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Iterable, Optional, Tuple import cv2 import gradio as gr import numpy as np import plotly.graph_objects as go import torch from gradio.themes import Soft from gradio.themes.utils import colors, fonts, sizes from huggingface_hub import hf_hub_download from visualization import visualize_fall_simple # HF Spaces 배포용: 프로젝트 루트 설정 불필요 (self-contained) # Zero GPU 호환 설정 try: import spaces SPACES_AVAILABLE = True except ImportError: SPACES_AVAILABLE = False # ----------------------------------------------------------------------------- # Authentication (multi-user support via environment variable) # ----------------------------------------------------------------------------- def get_auth_credentials(): """Load auth credentials from environment variable (multi-user support). Environment variable format: GRADIO_AUTH='[["user1","pass1"],["user2","pass2"]]' Returns None if not set (auth disabled for local development). """ auth_json = os.environ.get("GRADIO_AUTH") if auth_json: try: auth_list = json.loads(auth_json) # [["user1","pass1"],["user2","pass2"]] -> [("user1","pass1"),("user2","pass2")] return [tuple(pair) for pair in auth_list] except json.JSONDecodeError: print("Warning: Invalid GRADIO_AUTH format, auth disabled") return None return None # ----------------------------------------------------------------------------- # 커스텀 테마 (PRITHIVSAKTHIUR 스타일) # ----------------------------------------------------------------------------- colors.custom_color = colors.Color( name="custom_color", c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2", c400="#529AC3", c500="#4682B4", c600="#3E72A0", c700="#36638C", c800="#2E5378", c900="#264364", c950="#1E3450", ) class CustomTheme(Soft): def __init__( self, *, primary_hue: colors.Color | str = colors.gray, secondary_hue: colors.Color | str = colors.custom_color, neutral_hue: colors.Color | str = colors.slate, text_size: sizes.Size | str = sizes.text_lg, font: fonts.Font | str | Iterable[fonts.Font | str] = ( fonts.GoogleFont("Outfit"), "Arial", "sans-serif", ), font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace", ), ): super().__init__( primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue, text_size=text_size, font=font, font_mono=font_mono, ) super().set( background_fill_primary="*primary_50", body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)", button_primary_text_color="white", button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)", button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)", slider_color="*secondary_500", block_title_text_weight="600", block_border_width="3px", block_shadow="*shadow_drop_lg", button_primary_shadow="*shadow_drop_lg", ) custom_theme = CustomTheme() # ----------------------------------------------------------------------------- # CSS 스타일 # ----------------------------------------------------------------------------- css = """ #col-container { margin: 0 auto; max-width: 1200px; } #main-title h1 { font-size: 2.3em !important; } .submit-btn { background-color: #4682B4 !important; color: white !important; } .submit-btn:hover { background-color: #5A9BD4 !important; } .result-label { font-size: 1.5em !important; font-weight: bold !important; padding: 10px !important; border-radius: 8px !important; } .fall-detected { background-color: #FF4444 !important; color: white !important; } .non-fall { background-color: #44BB44 !important; color: white !important; } """ # ----------------------------------------------------------------------------- # 디바이스 설정 # ----------------------------------------------------------------------------- device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ----------------------------------------------------------------------------- # GPU 데코레이터 (로컬/HF Spaces 호환) # ----------------------------------------------------------------------------- def gpu_decorator(duration: int = 120): """로컬에서는 그냥 실행, Spaces에서는 GPU 할당""" def decorator(func): if SPACES_AVAILABLE: return spaces.GPU(duration=duration)(func) return func return decorator # ----------------------------------------------------------------------------- # 모델 다운로드 (HuggingFace Hub) # ----------------------------------------------------------------------------- HF_MODEL_REPO = "YoungjaeDev/fall-detection-models" def download_models() -> tuple[str, str]: """HuggingFace Hub에서 모델 다운로드 (캐시됨)""" # 로컬 경로 우선 확인 (개발 환경) local_pose = Path("yolo11m-pose.pt") local_stgcn = Path("runs/stgcn_binary_exp2_fixed_graph/best_acc.pth") if local_pose.exists() and local_stgcn.exists(): return str(local_pose), str(local_stgcn) # HuggingFace Hub에서 다운로드 token = os.environ.get("HF_TOKEN") if token is None: raise RuntimeError( "HF_TOKEN 환경변수가 설정되지 않았습니다. " "Private 모델 저장소 접근을 위해 HF_TOKEN이 필요합니다." ) try: pose_model_path = hf_hub_download( repo_id=HF_MODEL_REPO, filename="yolo11m-pose.pt", token=token ) stgcn_checkpoint = hf_hub_download( repo_id=HF_MODEL_REPO, filename="best_acc.pth", token=token ) except Exception as e: raise RuntimeError(f"모델 다운로드 실패: {e}") from e return pose_model_path, stgcn_checkpoint # ----------------------------------------------------------------------------- # 모델 싱글톤 (지연 로딩) # ----------------------------------------------------------------------------- _pose_estimator = None _stgcn_classifier = None def get_pose_estimator(): """PoseEstimator 싱글톤 반환""" global _pose_estimator if _pose_estimator is None: from models.pose_estimator import PoseEstimator pose_model_path, _ = download_models() _pose_estimator = PoseEstimator( model_path=pose_model_path, conf_threshold=0.5, device=str(device) ) return _pose_estimator def get_stgcn_classifier(): """STGCNClassifier 싱글톤 반환""" global _stgcn_classifier if _stgcn_classifier is None: from models.stgcn_classifier import STGCNClassifier _, stgcn_checkpoint = download_models() _stgcn_classifier = STGCNClassifier( checkpoint_path=stgcn_checkpoint, fall_threshold=0.7, device=str(device) ) return _stgcn_classifier # ----------------------------------------------------------------------------- # 프레임 로드 (cv2 사용 - 대부분의 비디오에서 더 빠름) # ----------------------------------------------------------------------------- def load_video_frames(video_path: str) -> Tuple[np.ndarray, float]: """ 비디오에서 전체 프레임 로드 (cv2 사용) Returns: frames: (N, H, W, C) numpy array (BGR) fps: 프레임 레이트 """ cap = cv2.VideoCapture(video_path) fps = cap.get(cv2.CAP_PROP_FPS) # 일부 비디오 컨테이너는 FPS 정보를 제공하지 않아 0을 반환할 수 있음 if not fps or fps <= 0: fps = 30.0 # 기본값 (ZeroDivisionError 방지) frames = [] while True: ret, frame = cap.read() if not ret: break frames.append(frame) cap.release() return np.array(frames), fps # ----------------------------------------------------------------------------- # 배치 Pose 추론 # ----------------------------------------------------------------------------- def extract_all_keypoints( frames: np.ndarray, pose_estimator, batch_size: int = 8, progress_callback=None ) -> list[Optional[np.ndarray]]: """ 전체 프레임에 대해 배치 Pose 추론 Args: frames: (N, H, W, C) 전체 비디오 프레임 pose_estimator: PoseEstimator 인스턴스 batch_size: 배치 크기 progress_callback: 진행률 콜백 함수 Returns: keypoints_list: [(17, 3) or None, ...] N개의 keypoints """ n_frames = len(frames) all_keypoints = [] for i in range(0, n_frames, batch_size): batch = list(frames[i:i+batch_size]) batch_keypoints = pose_estimator.extract_batch(batch) all_keypoints.extend(batch_keypoints) if progress_callback: progress_callback(min(i + batch_size, n_frames), n_frames) return all_keypoints # ----------------------------------------------------------------------------- # 윈도우 생성 및 ST-GCN 배치 추론 # ----------------------------------------------------------------------------- def create_windows_and_predict( keypoints_list: list[Optional[np.ndarray]], stgcn_classifier, window_size: int = 60, stride: int = 5, fall_threshold: float = 0.7 ) -> Tuple[list[int], list[float], Optional[int]]: """ keypoints에서 윈도우 생성 후 ST-GCN 배치 추론 Args: keypoints_list: 프레임별 keypoints 리스트 stgcn_classifier: STGCNClassifier 인스턴스 window_size: 윈도우 크기 (프레임 수) stride: 추론 간격 (N 프레임마다 1번) fall_threshold: 낙상 판정 임계값 Returns: frame_indices: ST-GCN 예측이 있는 프레임 인덱스 fall_probs: 각 프레임의 낙상 확률 (class 1 확률) first_fall_frame: 첫 낙상 감지 프레임 인덱스 (없으면 None) """ n_frames = len(keypoints_list) # None을 빈 keypoints로 대체 processed_keypoints = [] for kpts in keypoints_list: if kpts is None: processed_keypoints.append(np.zeros((17, 3), dtype=np.float32)) else: processed_keypoints.append(kpts) # 윈도우 생성 (stride 간격으로) frame_indices = [] windows = [] for frame_idx in range(window_size - 1, n_frames, stride): # 이전 window_size 프레임으로 윈도우 구성 window_keypoints = processed_keypoints[frame_idx - window_size + 1:frame_idx + 1] # (T, V, C) -> (C, T, V, M) 변환 window = np.array(window_keypoints) # (T=60, V=17, C=3) window = window.transpose(2, 0, 1) # (C=3, T=60, V=17) window = np.expand_dims(window, -1) # (C=3, T=60, V=17, M=1) frame_indices.append(frame_idx) windows.append(window.astype(np.float32)) if not windows: return [], [], None # ST-GCN 배치 추론 predictions, confidences, fall_probs = stgcn_classifier.predict_batch(windows) # 첫 낙상 감지 프레임 찾기 first_fall_frame = None for i, (pred, fall_prob) in enumerate(zip(predictions, fall_probs)): if pred == 1 and fall_prob >= fall_threshold: first_fall_frame = frame_indices[i] break return frame_indices, fall_probs.tolist(), first_fall_frame # ----------------------------------------------------------------------------- # 시각화 워커 함수 (ThreadPoolExecutor용 - HF Spaces daemon 프로세스 호환) # ----------------------------------------------------------------------------- # FALL DETECTED 텍스트 표시 지속 시간 (초) FALL_DISPLAY_DURATION = 2.0 def _visualize_single_frame(args: tuple) -> Tuple[int, np.ndarray]: """단일 프레임 시각화 워커 (간소화된 버전)""" (frame_idx, frame, keypoints, show_fall_text, viz_keypoints, viz_scale) = args vis_frame = visualize_fall_simple( frame=frame, keypoints=keypoints if keypoints is not None and keypoints.sum() > 0 else None, show_fall_text=show_fall_text, keypoint_mode=viz_keypoints, output_scale=viz_scale ) return frame_idx, vis_frame def visualize_clip_parallel( frames: np.ndarray, keypoints_list: list[Optional[np.ndarray]], frame_indices: list[int], fall_probs: list[float], clip_start: int, clip_end: int, fps: float, first_fall_frame: Optional[int] = None, fall_threshold: float = 0.7, viz_keypoints: str = "all", viz_scale: float = 1.0, num_workers: int = 4 ) -> list[np.ndarray]: """ 클립 구간 병렬 시각화 (간소화된 버전) Args: frames: 전체 프레임 keypoints_list: 전체 keypoints frame_indices: ST-GCN 예측 프레임 인덱스 fall_probs: 프레임별 낙상 확률 clip_start: 클립 시작 인덱스 clip_end: 클립 종료 인덱스 fps: 프레임 레이트 first_fall_frame: 첫 낙상 감지 프레임 (깜빡임 방지용) fall_threshold: 낙상 판정 임계값 viz_keypoints: 키포인트 표시 모드 viz_scale: 출력 스케일 num_workers: 병렬 워커 수 Returns: vis_frames: 시각화된 프레임 리스트 """ # 깜빡임 방지: 첫 낙상 후 N초간 FALL DETECTED 표시 fall_display_end_frame = None if first_fall_frame is not None: fall_display_end_frame = first_fall_frame + int(fps * FALL_DISPLAY_DURATION) # 시각화 인자 준비 viz_args = [] for i in range(clip_start, clip_end): frame = frames[i] keypoints = keypoints_list[i] # FALL DETECTED 텍스트 표시 여부 결정 (깜빡임 방지) show_fall_text = False if first_fall_frame is not None and fall_display_end_frame is not None: if first_fall_frame <= i <= fall_display_end_frame: show_fall_text = True args = ( i, # frame_idx frame, # frame keypoints, # keypoints show_fall_text, # show_fall_text (깜빡임 방지 적용) viz_keypoints, # viz_keypoints viz_scale # viz_scale ) viz_args.append(args) # 병렬 시각화 (ThreadPoolExecutor 사용 - HF Spaces daemon 프로세스 호환) with ThreadPoolExecutor(max_workers=num_workers) as executor: results = list(executor.map(_visualize_single_frame, viz_args)) # 순서대로 정렬 results.sort(key=lambda x: x[0]) vis_frames = [frame for _, frame in results] return vis_frames # ----------------------------------------------------------------------------- # 확률 그래프 생성 # ----------------------------------------------------------------------------- def create_probability_graph( frame_indices: list[int], fall_probs: list[float], fall_threshold: float = 0.7, fps: float = 30.0 ) -> go.Figure: """낙상 확률 그래프 생성 (X축: 시간)""" # 프레임 인덱스 -> 시간(초) 변환 time_seconds = [idx / fps for idx in frame_indices] fig = go.Figure() # 확률 라인 fig.add_trace(go.Scatter( x=time_seconds, y=fall_probs, mode='lines', name='Fall Probability', line=dict(color='#4682B4', width=2), fill='tozeroy', fillcolor='rgba(70, 130, 180, 0.3)' )) # 임계값 라인 fig.add_hline( y=fall_threshold, line_dash="dash", line_color="red", annotation_text=f"Threshold ({fall_threshold})", annotation_position="right" ) # 레이아웃 fig.update_layout( title="Fall Detection Probability Over Time", xaxis_title="Time (seconds)", yaxis_title="Probability", yaxis=dict(range=[0, 1.05]), template="plotly_white", height=300, margin=dict(l=50, r=50, t=50, b=50), showlegend=True, legend=dict( orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 ) ) return fig # ----------------------------------------------------------------------------- # 스마트 클립 추출 설정 # ----------------------------------------------------------------------------- CLIP_PRE_FALL_SECONDS = 1.0 # 낙상 전 1초 CLIP_POST_FALL_SECONDS = 2.0 # 낙상 후 2초 # ----------------------------------------------------------------------------- # 메인 추론 함수 # ----------------------------------------------------------------------------- @gpu_decorator(duration=120) def process_video( video_path: str, fall_threshold: float, viz_keypoints: str, progress: gr.Progress = gr.Progress() ) -> Tuple[Optional[str], Optional[go.Figure], str]: """ 비디오 처리 및 낙상 감지 (배치 처리 파이프라인) Pipeline: 1. decord로 전체 프레임 배치 로드 2. YOLO Pose 배치 추론 → keypoints 누적 3. 윈도우 단위 ST-GCN 배치 추론 4. 낙상 시점 -1s ~ +2s 구간만 시각화 Args: video_path: 입력 비디오 경로 fall_threshold: 낙상 판정 임계값 (0.0-1.0) viz_keypoints: 키포인트 표시 모드 ('all' 또는 'major') progress: Gradio 진행률 표시 Returns: output_video_path: 결과 클립 경로 (낙상 감지 시) 또는 None probability_graph: 확률 그래프 result_text: 최종 판정 텍스트 """ if video_path is None: return None, None, "비디오를 업로드해주세요." try: # Stage 0: 모델 로드 progress(0.05, desc="모델 로딩 중...") pose_estimator = get_pose_estimator() stgcn_classifier = get_stgcn_classifier() stgcn_classifier.fall_threshold = fall_threshold # Stage 1: 프레임 로드 (decord) progress(0.1, desc="비디오 로딩 중...") frames, fps = load_video_frames(video_path) n_frames = len(frames) if n_frames == 0: return None, None, "비디오를 읽을 수 없습니다." # 비디오 길이 검증 (120s GPU 타임아웃 대비) video_duration = n_frames / fps if video_duration > 60: return None, None, ( f"비디오가 너무 깁니다. " f"비디오 길이: {video_duration:.1f}초 (제한: 60초). " f"60초 이내의 비디오를 업로드하세요." ) # Stage 2: 배치 Pose 추론 progress(0.15, desc="Pose 추출 중...") def pose_progress(current, total): pct = 0.15 + 0.35 * (current / total) progress(pct, desc=f"Pose 추출 중... ({current}/{total})") keypoints_list = extract_all_keypoints( frames, pose_estimator, batch_size=8, progress_callback=pose_progress ) # Stage 3: ST-GCN 배치 추론 progress(0.55, desc="낙상 분석 중...") frame_indices, fall_probs, first_fall_frame = create_windows_and_predict( keypoints_list, stgcn_classifier, window_size=60, stride=5, fall_threshold=fall_threshold ) # 확률 그래프 생성 progress(0.7, desc="그래프 생성 중...") fig = None if frame_indices and fall_probs: fig = create_probability_graph(frame_indices, fall_probs, fall_threshold, fps) # 낙상 미감지 시 if first_fall_frame is None: progress(1.0, desc="완료!") result_text = ( f"[Non-Fall] 낙상이 감지되지 않았습니다.\n" f"분석 프레임: {n_frames}개" ) return None, fig, result_text # Stage 4: 낙상 구간만 시각화 progress(0.75, desc="클립 시각화 중...") pre_fall_frames = int(fps * CLIP_PRE_FALL_SECONDS) post_fall_frames = int(fps * CLIP_POST_FALL_SECONDS) clip_start = max(0, first_fall_frame - pre_fall_frames) clip_end = min(n_frames, first_fall_frame + post_fall_frames) vis_frames = visualize_clip_parallel( frames=frames, keypoints_list=keypoints_list, frame_indices=frame_indices, fall_probs=fall_probs, clip_start=clip_start, clip_end=clip_end, fps=fps, first_fall_frame=first_fall_frame, # 깜빡임 방지용 fall_threshold=fall_threshold, viz_keypoints=viz_keypoints, viz_scale=1.0, num_workers=4 ) if not vis_frames: progress(1.0, desc="완료!") return None, fig, "클립 추출에 실패했습니다." # Stage 5: 비디오 인코딩 progress(0.9, desc="클립 인코딩 중...") with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp: output_path = tmp.name fourcc = cv2.VideoWriter_fourcc(*'mp4v') clip_height, clip_width = vis_frames[0].shape[:2] out = cv2.VideoWriter(output_path, fourcc, fps, (clip_width, clip_height)) for vis_frame in vis_frames: out.write(vis_frame) out.release() # H.264 재인코딩 (브라우저 호환) with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp: output_h264 = tmp.name subprocess.run( [ 'ffmpeg', '-y', '-i', output_path, '-c:v', 'libx264', '-preset', 'fast', '-crf', '23', output_h264, '-loglevel', 'quiet' ], check=False, capture_output=True ) # 임시 파일 정리 if os.path.exists(output_path): os.remove(output_path) final_output = output_h264 if os.path.exists(output_h264) else None # 최종 판정 progress(1.0, desc="완료!") fall_time = first_fall_frame / fps clip_duration = len(vis_frames) / fps result_text = ( f"[FALL DETECTED] 낙상이 감지되었습니다!\n" f"낙상 시점: {fall_time:.2f}초 (프레임 #{first_fall_frame})\n" f"클립 길이: {clip_duration:.1f}초 ({len(vis_frames)}프레임)" ) return final_output, fig, result_text except Exception as e: import traceback error_msg = f"처리 중 오류 발생: {str(e)}\n{traceback.format_exc()}" return None, None, error_msg # ----------------------------------------------------------------------------- # Gradio UI # ----------------------------------------------------------------------------- def create_demo() -> gr.Blocks: """Gradio 데모 생성""" with gr.Blocks(theme=custom_theme, css=css) as demo: gr.Markdown( """ # Fall Detection Demo YOLOv11-Pose + ST-GCN 2-stage 파이프라인을 사용한 실시간 낙상 감지 데모입니다. 비디오를 업로드하면 낙상 여부를 분석하고, 결과 비디오와 확률 그래프를 제공합니다. **파이프라인 구성:** - Stage 1: YOLOv11m-pose (Pose Estimation) - Batch Processing - Stage 2: ST-GCN (Temporal Classification) - Batch Processing - Window Size: 60 frames (2s @ 30fps) """, elem_id="main-title" ) with gr.Row(): with gr.Column(scale=1): # 입력 섹션 gr.Markdown("### 입력") video_input = gr.Video( label="비디오 업로드", sources=["upload"], ) with gr.Accordion("고급 설정", open=False): fall_threshold = gr.Slider( minimum=0.5, maximum=0.95, value=0.7, step=0.05, label="낙상 판정 임계값", info="권장: 0.7-0.85" ) viz_keypoints = gr.Radio( choices=["all", "major"], value="all", label="키포인트 표시", info="all: 전체 17개, major: 주요 9개" ) submit_btn = gr.Button( "분석 시작", variant="primary", elem_classes="submit-btn" ) with gr.Column(scale=1): # 출력 섹션 gr.Markdown("### 결과") result_text = gr.Textbox( label="판정 결과", lines=3, interactive=False ) video_output = gr.Video( label="결과 비디오", ) prob_graph = gr.Plot( label="낙상 확률 그래프", ) # 예제 비디오 gr.Markdown("### 예제 비디오") example_dir = Path(__file__).parent / "examples" examples = [] if example_dir.exists(): for ext in ["*.mp4", "*.avi", "*.mov"]: examples.extend([str(p) for p in example_dir.glob(ext)]) if examples: gr.Examples( examples=[[ex, 0.7, "all"] for ex in sorted(examples)], inputs=[video_input, fall_threshold, viz_keypoints], outputs=[video_output, prob_graph, result_text], fn=process_video, cache_examples=False, examples_per_page=4, label="예제 비디오", ) # 이벤트 연결 submit_btn.click( fn=process_video, inputs=[video_input, fall_threshold, viz_keypoints], outputs=[video_output, prob_graph, result_text], ) # 푸터 gr.Markdown( """ --- **References:** - [YOLOv11](https://github.com/ultralytics/ultralytics) - Pose Estimation - [ST-GCN](https://arxiv.org/abs/1801.07455) - Spatial Temporal Graph Convolutional Networks - AI Hub Fall Detection Dataset """ ) return demo # ----------------------------------------------------------------------------- # 메인 실행 # ----------------------------------------------------------------------------- if __name__ == "__main__": demo = create_demo() demo.queue(max_size=10).launch( server_name="0.0.0.0", server_port=7860, share=False, # HF Spaces에서는 이미 public URL 제공 show_error=True, auth=get_auth_credentials(), ssr_mode=False, # svelte-i18n locale 에러 방지 )