#!/usr/bin/env python3 """ Fall Detection Gradio App YOLOv11-Pose + ST-GCN 2-stage 파이프라인을 사용한 낙상 감지 데모입니다. HF Spaces Zero GPU 환경에서 실행됩니다. 사용법 (로컬): python demo_gradio/app.py 사용법 (HF Spaces): 자동으로 app.py가 실행됩니다. 작성자: Fall Detection Pipeline Team 작성일: 2025-11-26 """ import os import sys import tempfile import time from pathlib import Path from typing import Iterable, Optional, Tuple import cv2 import gradio as gr import numpy as np import plotly.graph_objects as go import torch from gradio.themes import Soft from gradio.themes.utils import colors, fonts, sizes # 프로젝트 루트를 Python path에 추가 # pipeline/demo_gradio/app.py -> pipeline -> project_root PROJECT_ROOT = Path(__file__).parent.parent.parent sys.path.insert(0, str(PROJECT_ROOT)) # Zero GPU 호환 설정 try: import spaces SPACES_AVAILABLE = True except ImportError: SPACES_AVAILABLE = False # ----------------------------------------------------------------------------- # 커스텀 테마 (PRITHIVSAKTHIUR 스타일) # ----------------------------------------------------------------------------- colors.custom_color = colors.Color( name="custom_color", c50="#EBF3F8", c100="#D3E5F0", c200="#A8CCE1", c300="#7DB3D2", c400="#529AC3", c500="#4682B4", c600="#3E72A0", c700="#36638C", c800="#2E5378", c900="#264364", c950="#1E3450", ) class CustomTheme(Soft): def __init__( self, *, primary_hue: colors.Color | str = colors.gray, secondary_hue: colors.Color | str = colors.custom_color, neutral_hue: colors.Color | str = colors.slate, text_size: sizes.Size | str = sizes.text_lg, font: fonts.Font | str | Iterable[fonts.Font | str] = ( fonts.GoogleFont("Outfit"), "Arial", "sans-serif", ), font_mono: fonts.Font | str | Iterable[fonts.Font | str] = ( fonts.GoogleFont("IBM Plex Mono"), "ui-monospace", "monospace", ), ): super().__init__( primary_hue=primary_hue, secondary_hue=secondary_hue, neutral_hue=neutral_hue, text_size=text_size, font=font, font_mono=font_mono, ) super().set( background_fill_primary="*primary_50", body_background_fill="linear-gradient(135deg, *primary_200, *primary_100)", button_primary_text_color="white", button_primary_background_fill="linear-gradient(90deg, *secondary_500, *secondary_600)", button_primary_background_fill_hover="linear-gradient(90deg, *secondary_600, *secondary_700)", slider_color="*secondary_500", block_title_text_weight="600", block_border_width="3px", block_shadow="*shadow_drop_lg", button_primary_shadow="*shadow_drop_lg", ) custom_theme = CustomTheme() # ----------------------------------------------------------------------------- # CSS 스타일 # ----------------------------------------------------------------------------- css = """ #col-container { margin: 0 auto; max-width: 1200px; } #main-title h1 { font-size: 2.3em !important; } .submit-btn { background-color: #4682B4 !important; color: white !important; } .submit-btn:hover { background-color: #5A9BD4 !important; } .result-label { font-size: 1.5em !important; font-weight: bold !important; padding: 10px !important; border-radius: 8px !important; } .fall-detected { background-color: #FF4444 !important; color: white !important; } .non-fall { background-color: #44BB44 !important; color: white !important; } """ # ----------------------------------------------------------------------------- # 디바이스 설정 # ----------------------------------------------------------------------------- device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ----------------------------------------------------------------------------- # GPU 데코레이터 (로컬/HF Spaces 호환) # ----------------------------------------------------------------------------- def gpu_decorator(duration: int = 120): """로컬에서는 그냥 실행, Spaces에서는 GPU 할당""" def decorator(func): if SPACES_AVAILABLE: return spaces.GPU(duration=duration)(func) return func return decorator # ----------------------------------------------------------------------------- # 파이프라인 초기화 (지연 로딩) # ----------------------------------------------------------------------------- _pipeline = None def get_pipeline(): """파이프라인 싱글톤 반환 (지연 로딩)""" global _pipeline if _pipeline is None: from pipeline.core.pipeline import FallDetectionPipeline # HF Spaces에서는 models 폴더에서 로드 pose_model_path = "pipeline/demo_gradio/models/yolo11m-pose.pt" stgcn_checkpoint = "pipeline/demo_gradio/models/best_acc.pth" # 로컬 경로 폴백 if not Path(pose_model_path).exists(): pose_model_path = "yolo11m-pose.pt" if not Path(stgcn_checkpoint).exists(): stgcn_checkpoint = "runs/stgcn_binary_exp2_fixed_graph/best_acc.pth" _pipeline = FallDetectionPipeline( pose_model_path=pose_model_path, stgcn_checkpoint=stgcn_checkpoint, window_size=60, conf_threshold=0.5, fall_threshold=0.7, temporal_window=5, stgcn_stride=5, alert_duration=150, post_fall_frames=3, device=str(device), debug=False, headless=False, viz_keypoints="all", viz_scale=1.0, viz_optimized=True ) return _pipeline # ----------------------------------------------------------------------------- # 확률 그래프 생성 # ----------------------------------------------------------------------------- def create_probability_graph( frame_indices: list, probabilities: list, fall_threshold: float = 0.7 ) -> go.Figure: """ 낙상 확률 그래프 생성 Args: frame_indices: 프레임 인덱스 리스트 probabilities: 낙상 확률 리스트 (0.0-1.0) fall_threshold: 낙상 판정 임계값 Returns: Plotly Figure 객체 """ fig = go.Figure() # 확률 라인 fig.add_trace(go.Scatter( x=frame_indices, y=probabilities, mode='lines', name='Fall Probability', line=dict(color='#4682B4', width=2), fill='tozeroy', fillcolor='rgba(70, 130, 180, 0.3)' )) # 임계값 라인 fig.add_hline( y=fall_threshold, line_dash="dash", line_color="red", annotation_text=f"Threshold ({fall_threshold})", annotation_position="right" ) # 레이아웃 fig.update_layout( title="Fall Detection Probability Over Time", xaxis_title="Frame", yaxis_title="Probability", yaxis=dict(range=[0, 1]), template="plotly_white", height=300, margin=dict(l=50, r=50, t=50, b=50), showlegend=True, legend=dict( orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 ) ) return fig # ----------------------------------------------------------------------------- # 메인 추론 함수 # ----------------------------------------------------------------------------- @gpu_decorator(duration=120) def process_video( video_path: str, fall_threshold: float, viz_keypoints: str, progress: gr.Progress = gr.Progress() ) -> Tuple[Optional[str], Optional[go.Figure], str]: """ 비디오 처리 및 낙상 감지 Args: video_path: 입력 비디오 경로 fall_threshold: 낙상 판정 임계값 (0.0-1.0) viz_keypoints: 키포인트 표시 모드 ('all' 또는 'major') progress: Gradio 진행률 표시 Returns: output_video_path: 결과 비디오 경로 probability_graph: 확률 그래프 result_text: 최종 판정 텍스트 """ if video_path is None: return None, None, "비디오를 업로드해주세요." try: # 파이프라인 로드 progress(0.1, desc="모델 로딩 중...") pipeline = get_pipeline() pipeline.fall_threshold = fall_threshold pipeline.stgcn_classifier.fall_threshold = fall_threshold pipeline.viz_keypoints = viz_keypoints pipeline.reset() # 비디오 열기 progress(0.2, desc="비디오 열기...") cap = cv2.VideoCapture(video_path) if not cap.isOpened(): return None, None, "비디오를 열 수 없습니다." # 비디오 정보 fps = cap.get(cv2.CAP_PROP_FPS) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) # 출력 비디오 설정 output_path = tempfile.mktemp(suffix=".mp4") fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Info panel 추가로 높이 80px 증가 out = cv2.VideoWriter(output_path, fourcc, fps, (width, height + 80)) # 처리 루프 frame_idx = 0 frame_indices = [] probabilities = [] fall_detected = False max_confidence = 0.0 while True: ret, frame = cap.read() if not ret: break # 프레임 처리 vis_frame, info = pipeline.process_frame(frame, frame_idx) # 확률 기록 if info['confidence'] is not None: frame_indices.append(frame_idx) probabilities.append(info['confidence']) max_confidence = max(max_confidence, info['confidence']) # 낙상 감지 확인 if info['alert']: fall_detected = True # 출력 저장 out.write(vis_frame) frame_idx += 1 # 진행률 업데이트 if frame_idx % 10 == 0: progress_val = 0.2 + 0.7 * (frame_idx / total_frames) progress(progress_val, desc=f"처리 중... ({frame_idx}/{total_frames})") # 리소스 해제 cap.release() out.release() # H.264 코덱으로 재인코딩 (브라우저 호환) progress(0.9, desc="비디오 인코딩 중...") output_h264 = tempfile.mktemp(suffix=".mp4") os.system(f'ffmpeg -y -i "{output_path}" -c:v libx264 -preset fast -crf 23 "{output_h264}" -loglevel quiet') # mp4v 임시 파일 삭제 if os.path.exists(output_path): os.remove(output_path) # H.264 변환 성공 여부 확인 if os.path.exists(output_h264): final_output = output_h264 else: final_output = output_path # 폴백 # 확률 그래프 생성 progress(0.95, desc="그래프 생성 중...") if frame_indices and probabilities: fig = create_probability_graph(frame_indices, probabilities, fall_threshold) else: fig = None # 최종 판정 progress(1.0, desc="완료!") if fall_detected: result_text = f"[FALL DETECTED] 낙상이 감지되었습니다! (최대 확률: {max_confidence:.1%})" else: result_text = f"[Non-Fall] 낙상이 감지되지 않았습니다. (최대 확률: {max_confidence:.1%})" return final_output, fig, result_text except Exception as e: import traceback error_msg = f"처리 중 오류 발생: {str(e)}\n{traceback.format_exc()}" return None, None, error_msg # ----------------------------------------------------------------------------- # Gradio UI # ----------------------------------------------------------------------------- def create_demo() -> gr.Blocks: """Gradio 데모 생성""" with gr.Blocks() as demo: gr.Markdown( """ # Fall Detection Demo YOLOv11-Pose + ST-GCN 2-stage 파이프라인을 사용한 실시간 낙상 감지 데모입니다. 비디오를 업로드하면 낙상 여부를 분석하고, 결과 비디오와 확률 그래프를 제공합니다. **파이프라인 구성:** - Stage 1: YOLOv11m-pose (Pose Estimation) - Stage 2: ST-GCN (Temporal Classification) - Window Size: 60 frames (2초 @ 30fps) """, elem_id="main-title" ) with gr.Row(): with gr.Column(scale=1): # 입력 섹션 gr.Markdown("### 입력") video_input = gr.Video( label="비디오 업로드", sources=["upload"], ) with gr.Accordion("고급 설정", open=False): fall_threshold = gr.Slider( minimum=0.5, maximum=0.95, value=0.7, step=0.05, label="낙상 판정 임계값", info="이 값 이상의 확률이면 낙상으로 판정합니다." ) viz_keypoints = gr.Radio( choices=["all", "major"], value="all", label="키포인트 표시", info="all: 전체 17개, major: 주요 9개" ) submit_btn = gr.Button( "분석 시작", variant="primary", elem_classes="submit-btn" ) with gr.Column(scale=1): # 출력 섹션 gr.Markdown("### 결과") result_text = gr.Textbox( label="판정 결과", lines=2, interactive=False ) video_output = gr.Video( label="결과 비디오", ) prob_graph = gr.Plot( label="낙상 확률 그래프", ) # 예제 비디오 gr.Markdown("### 예제 비디오") example_dir = Path(__file__).parent / "examples" examples = [] if example_dir.exists(): for ext in ["*.mp4", "*.avi", "*.mov"]: examples.extend([str(p) for p in example_dir.glob(ext)]) if examples: gr.Examples( examples=[[ex, 0.7, "all"] for ex in examples[:3]], inputs=[video_input, fall_threshold, viz_keypoints], outputs=[video_output, prob_graph, result_text], fn=process_video, cache_examples=False, ) # 이벤트 연결 submit_btn.click( fn=process_video, inputs=[video_input, fall_threshold, viz_keypoints], outputs=[video_output, prob_graph, result_text], ) # 푸터 gr.Markdown( """ --- **References:** - [YOLOv11](https://github.com/ultralytics/ultralytics) - Pose Estimation - [ST-GCN](https://arxiv.org/abs/1801.07455) - Spatial Temporal Graph Convolutional Networks - AI Hub Fall Detection Dataset """ ) return demo # ----------------------------------------------------------------------------- # 메인 실행 # ----------------------------------------------------------------------------- if __name__ == "__main__": demo = create_demo() demo.queue(max_size=10).launch( server_name="0.0.0.0", server_port=7860, share=False, show_error=True, theme=custom_theme, css=css, )