YoungjaeDev Claude commited on
Commit
5b423bf
·
1 Parent(s): 0ea4706

feat(viz): 시각화 간소화 및 깜빡임 방지 구현

Browse files

변경 사항:
- visualization.py: visualize_fall_simple() 함수 추가
- Pose skeleton + FALL DETECTED 텍스트만 표시
- FPS, Latency, 정보 패널, 빨간 플래시 오버레이 제거

- app.py: 깜빡임 방지 로직 구현
- FALL_DISPLAY_DURATION = 2.0초 (첫 낙상 후 2초간 텍스트 유지)
- _visualize_single_frame() 워커 함수 간소화
- visualize_clip_parallel()에 first_fall_frame 파라미터 추가

- pose_estimator.py: extract_batch()가 numpy 배열 직접 입력 지원

- stgcn_classifier.py: predict_batch()가 fall_probs 별도 반환
- 100% 확률 버그 수정 (예측 클래스 확률 -> Fall 클래스 확률)

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +381 -228
app.py CHANGED
@@ -1,25 +1,28 @@
1
  #!/usr/bin/env python3
2
  """
3
- Fall Detection Gradio App
4
 
5
  YOLOv11-Pose + ST-GCN 2-stage 파이프라인을 사용한 낙상 감지 데모입니다.
6
- HF Spaces Zero GPU 환경에서 실행됩니다.
7
 
8
- 사용법 (로컬):
9
- python demo_gradio/app.py
 
 
 
10
 
11
- 사용법 (HF Spaces):
12
- 자동으로 app.py가 실행됩니다.
13
 
14
  작성자: Fall Detection Pipeline Team
15
- 작성일: 2025-11-26
16
  """
17
 
18
  import os
19
  import subprocess
20
  import sys
21
  import tempfile
22
- import time
23
  from pathlib import Path
24
  from typing import Iterable, Optional, Tuple
25
 
@@ -33,7 +36,6 @@ from gradio.themes.utils import colors, fonts, sizes
33
  from huggingface_hub import hf_hub_download
34
 
35
  # 프로젝트 루트를 Python path에 추가
36
- # pipeline/demo_gradio/app.py -> pipeline -> project_root
37
  PROJECT_ROOT = Path(__file__).parent.parent.parent
38
  sys.path.insert(0, str(PROJECT_ROOT))
39
 
@@ -150,15 +152,7 @@ HF_MODEL_REPO = "YoungjaeDev/fall-detection-models"
150
 
151
 
152
  def download_models() -> tuple[str, str]:
153
- """
154
- HuggingFace Hub에서 모델 다운로드 (캐시됨)
155
-
156
- Returns:
157
- tuple: (pose_model_path, stgcn_checkpoint_path)
158
-
159
- Raises:
160
- RuntimeError: 모델 다운로드 또는 검증 실패 시
161
- """
162
  # 로컬 경로 우선 확인 (개발 환경)
163
  local_pose = Path("yolo11m-pose.pt")
164
  local_stgcn = Path("runs/stgcn_binary_exp2_fixed_graph/best_acc.pth")
@@ -166,116 +160,309 @@ def download_models() -> tuple[str, str]:
166
  if local_pose.exists() and local_stgcn.exists():
167
  return str(local_pose), str(local_stgcn)
168
 
169
- # HuggingFace Hub에서 다운로드 (Private repo는 HF_TOKEN 환경변수 필요)
170
  token = os.environ.get("HF_TOKEN")
171
-
172
- # Private 저장소 접근을 위한 토큰 확인
173
  if token is None:
174
  raise RuntimeError(
175
  "HF_TOKEN 환경변수가 설정되지 않았습니다. "
176
- "Private 모델 저장소 접근을 위해 HF_TOKEN이 필요합니다. "
177
- "HF Spaces의 경우 Settings > Secrets에서 설정하세요."
178
  )
179
 
180
  try:
181
  pose_model_path = hf_hub_download(
182
- repo_id=HF_MODEL_REPO,
183
- filename="yolo11m-pose.pt",
184
- token=token
185
  )
186
-
187
  stgcn_checkpoint = hf_hub_download(
188
- repo_id=HF_MODEL_REPO,
189
- filename="best_acc.pth",
190
- token=token
191
  )
192
  except Exception as e:
193
- raise RuntimeError(
194
- f"모델 다운로드 실패: {e}\n"
195
- f"저장소: {HF_MODEL_REPO}\n"
196
- f"HF_TOKEN이 올바르게 설정되었는지 확인하세요."
197
- ) from e
198
-
199
- # 다운로드된 파일 검증
200
- pose_path = Path(pose_model_path)
201
- stgcn_path = Path(stgcn_checkpoint)
202
-
203
- if not pose_path.exists():
204
- raise RuntimeError(f"Pose 모델 파일이 존재하지 않습니다: {pose_model_path}")
205
- if not stgcn_path.exists():
206
- raise RuntimeError(f"ST-GCN 체크포인트 파일이 존재하지 않습니다: {stgcn_checkpoint}")
207
-
208
- # 파일 크기 검증 (너무 작으면 손상된 파일일 가능성)
209
- pose_size = pose_path.stat().st_size
210
- stgcn_size = stgcn_path.stat().st_size
211
- if pose_size < 1_000_000: # 1MB 미만
212
- raise RuntimeError(f"Pose 모델 파일이 너무 작습니다: {pose_size} bytes")
213
- if stgcn_size < 1_000_000: # 1MB 미만
214
- raise RuntimeError(f"ST-GCN 체크포인트 파일이 너무 작습니다: {stgcn_size} bytes")
215
 
216
  return pose_model_path, stgcn_checkpoint
217
 
218
 
219
  # -----------------------------------------------------------------------------
220
- # 파이프라인 초기화 (지연 로딩)
221
  # -----------------------------------------------------------------------------
222
- _pipeline = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
 
225
- def get_pipeline():
226
- """파이프라인 싱글톤 반환 (지연 로딩)"""
227
- global _pipeline
228
- if _pipeline is None:
229
- from pipeline.core.pipeline import FallDetectionPipeline
 
230
 
231
- # 모델 다운로드 (캐시됨)
232
- pose_model_path, stgcn_checkpoint = download_models()
 
 
 
 
 
233
 
234
- _pipeline = FallDetectionPipeline(
235
- pose_model_path=pose_model_path,
236
- stgcn_checkpoint=stgcn_checkpoint,
237
- window_size=60,
238
- conf_threshold=0.5,
239
- fall_threshold=0.85, # 가이드라인 권장: 0.8-0.9 (false positive <5%)
240
- temporal_window=5,
241
- stgcn_stride=5,
242
- alert_duration=150,
243
- post_fall_frames=15, # 2.5초 @ 30fps with stride=5 (가이드라인: 2-3초)
244
- device=str(device),
245
- debug=False,
246
- headless=False,
247
- viz_keypoints="all",
248
- viz_scale=1.0,
249
- viz_optimized=True
250
- )
251
- return _pipeline
252
 
253
 
254
  # -----------------------------------------------------------------------------
255
- # 확률 그래프 생성
256
  # -----------------------------------------------------------------------------
257
- def create_probability_graph(
258
- frame_indices: list,
259
- probabilities: list,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
  fall_threshold: float = 0.7
261
- ) -> go.Figure:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
  """
263
- 낙상 확률 그래프 생성
264
 
265
  Args:
266
- frame_indices: 프레임 인덱스 리스트
267
- probabilities: 낙상 확률 리스트 (0.0-1.0)
 
 
 
 
 
 
268
  fall_threshold: 낙상 판정 임계값
 
 
 
269
 
270
  Returns:
271
- Plotly Figure 객체
272
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
  fig = go.Figure()
274
 
275
  # 확률 라인
276
  fig.add_trace(go.Scatter(
277
- x=frame_indices,
278
- y=probabilities,
279
  mode='lines',
280
  name='Fall Probability',
281
  line=dict(color='#4682B4', width=2),
@@ -295,9 +482,9 @@ def create_probability_graph(
295
  # 레이아웃
296
  fig.update_layout(
297
  title="Fall Detection Probability Over Time",
298
- xaxis_title="Frame",
299
  yaxis_title="Probability",
300
- yaxis=dict(range=[0, 1]),
301
  template="plotly_white",
302
  height=300,
303
  margin=dict(l=50, r=50, t=50, b=50),
@@ -315,7 +502,7 @@ def create_probability_graph(
315
 
316
 
317
  # -----------------------------------------------------------------------------
318
- # 스마트 클립 추출 설정 (Issue #82)
319
  # -----------------------------------------------------------------------------
320
  CLIP_PRE_FALL_SECONDS = 1.0 # 낙상 전 1초
321
  CLIP_POST_FALL_SECONDS = 2.0 # 낙상 후 2초
@@ -332,11 +519,13 @@ def process_video(
332
  progress: gr.Progress = gr.Progress()
333
  ) -> Tuple[Optional[str], Optional[go.Figure], str]:
334
  """
335
- 비디오 처리 및 낙상 감지 (스마트 클립 추출)
336
 
337
- Issue #82: 낙상 감지 구간만 클립으로 추출하여 인코딩 시간 대폭 감소
338
- - 낙상 감지 시: 낙상 전 1 + 낙상 2초 구간만 추출
339
- - 비낙상 시: 낙상 미감지 메시지 반환
 
 
340
 
341
  Args:
342
  video_path: 입력 비디오 경로
@@ -345,7 +534,7 @@ def process_video(
345
  progress: Gradio 진행률 표시
346
 
347
  Returns:
348
- output_video_path: 결과 클립 경로 (낙상 감지 시) 또는 None (비낙상)
349
  probability_graph: 확률 그래프
350
  result_text: 최종 판정 텍스트
351
  """
@@ -353,170 +542,134 @@ def process_video(
353
  return None, None, "비디오를 업로드해주세요."
354
 
355
  try:
356
- # 파이프라인 로드
357
- progress(0.1, desc="모델 로딩 중...")
358
- pipeline = get_pipeline()
359
- pipeline.fall_threshold = fall_threshold
360
- pipeline.stgcn_classifier.fall_threshold = fall_threshold
361
- pipeline.viz_keypoints = viz_keypoints
362
- pipeline.reset()
363
-
364
- # 비디오 열기
365
- progress(0.2, desc="비디오 열기...")
366
- cap = cv2.VideoCapture(video_path)
367
- if not cap.isOpened():
368
- return None, None, "비디오를 열 수 없습니다."
369
-
370
- # 비디오 정보
371
- fps = cap.get(cv2.CAP_PROP_FPS)
372
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
373
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
374
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
375
-
376
- # 비디오 길이 검증 (120s GPU 타임아웃 대비)
377
- if fps > 0:
378
- video_duration = total_frames / fps
379
- # 처리 시간 추정: 대략 실시간의 1.5배 + 인코딩 10초
380
- estimated_time = video_duration * 1.5 + 10
381
- if estimated_time > 110: # 120s 타임아웃에 여유 두기
382
- cap.release()
383
- return None, None, (
384
- f"비디오가 너무 깁니다. "
385
- f"비디오 길이: {video_duration:.1f}초, "
386
- f"예상 처리 시간: {estimated_time:.1f}초 (제한: 110초). "
387
- f"60초 이내의 비디오를 업로드하세요."
388
- )
389
-
390
- # 클립 추출을 위한 프레임 수 계산
391
- pre_fall_frames = int(fps * CLIP_PRE_FALL_SECONDS)
392
- post_fall_frames = int(fps * CLIP_POST_FALL_SECONDS)
393
-
394
- # 처리 루프 - 프레임 버퍼링 + 낙상 감지
395
- frame_idx = 0
396
- frame_indices = []
397
- probabilities = []
398
- max_confidence = 0.0
399
-
400
- # 낙상 감지 추적
401
- first_fall_frame = None # 첫 낙상 감지 프레임
402
- fall_detected = False
403
 
404
- # 시각화 프레임 버퍼 (클립 추출용)
405
- vis_frame_buffer = []
406
- raw_frame_buffer = [] # 원본 프레임 버퍼 (재처리용)
 
407
 
408
- while True:
409
- # 프레임 읽기
410
- with pipeline.profiler.profile('video_read'):
411
- ret, frame = cap.read()
412
- if not ret:
413
- break
414
 
415
- # 원본 프레임 버퍼에 저장 (클립 추출에 필요)
416
- raw_frame_buffer.append(frame.copy())
417
-
418
- # 프레임 처리
419
- vis_frame, info = pipeline.process_frame(frame, frame_idx)
420
-
421
- # 시각화 프레임 버퍼에 저장
422
- vis_frame_buffer.append(vis_frame)
423
-
424
- # 확률 기록
425
- if info['confidence'] is not None:
426
- frame_indices.append(frame_idx)
427
- probabilities.append(info['confidence'])
428
- max_confidence = max(max_confidence, info['confidence'])
429
-
430
- # 첫 낙상 감지 시점 기록
431
- if info['alert'] and first_fall_frame is None:
432
- first_fall_frame = frame_idx
433
- fall_detected = True
434
 
435
- frame_idx += 1
 
 
 
 
436
 
437
- # 진행률 업데이트
438
- if frame_idx % 10 == 0:
439
- progress_val = 0.2 + 0.6 * (frame_idx / total_frames)
440
- progress(progress_val, desc=f"분석 중... ({frame_idx}/{total_frames})")
 
441
 
442
- cap.release()
 
 
 
 
 
 
 
 
443
 
444
- # 확률 그래프 생성 (항상 생성)
445
- progress(0.85, desc="그래프 생성 중...")
446
- if frame_indices and probabilities:
447
- fig = create_probability_graph(frame_indices, probabilities, fall_threshold)
448
- else:
449
- fig = None
450
 
451
- # 낙상 미감지 시 클립 없이 반환
452
- if not fall_detected or first_fall_frame is None:
453
  progress(1.0, desc="완료!")
454
  result_text = (
455
  f"[Non-Fall] 낙상이 감지되지 않았습니다.\n"
456
- f"최대 확률: {max_confidence:.1%}\n"
457
- f"분석 프레임: {total_frames}개"
458
  )
459
  return None, fig, result_text
460
 
461
- # 클립 구간 계산
 
 
 
 
462
  clip_start = max(0, first_fall_frame - pre_fall_frames)
463
- clip_end = min(len(vis_frame_buffer), first_fall_frame + post_fall_frames)
464
- clip_frames = vis_frame_buffer[clip_start:clip_end]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
465
 
466
- if not clip_frames:
467
  progress(1.0, desc="완료!")
468
  return None, fig, "클립 추출에 실패했습니다."
469
 
470
- # 클립 비디오 생성 (프레임 수 감소로 인코딩 시간 대폭 감소)
471
  progress(0.9, desc="클립 인코딩 중...")
472
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
473
  output_path = tmp.name
474
 
475
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
476
- # Info panel 추가로 높이 80px 증가
477
- clip_height, clip_width = clip_frames[0].shape[:2]
478
  out = cv2.VideoWriter(output_path, fourcc, fps, (clip_width, clip_height))
479
 
480
- for vis_frame in clip_frames:
481
  out.write(vis_frame)
482
  out.release()
483
 
484
- # H.264 코덱으로 재인코딩 (브라우저 호환)
485
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
486
  output_h264 = tmp.name
487
 
488
- with pipeline.profiler.profile('ffmpeg_encode'):
489
- subprocess.run(
490
- [
491
- 'ffmpeg', '-y', '-i', output_path,
492
- '-c:v', 'libx264', '-preset', 'fast', '-crf', '23',
493
- output_h264, '-loglevel', 'quiet'
494
- ],
495
- check=False,
496
- capture_output=True
497
- )
498
 
499
- # mp4v 임시 파일 삭제
500
  if os.path.exists(output_path):
501
  os.remove(output_path)
502
 
503
- # H.264 변환 성공 여부 확인
504
- if os.path.exists(output_h264):
505
- final_output = output_h264
506
- else:
507
- final_output = output_path # 폴백
508
 
509
  # 최종 판정
510
  progress(1.0, desc="완료!")
511
- fall_time = first_fall_frame / fps if fps > 0 else 0
512
- clip_duration = len(clip_frames) / fps if fps > 0 else 0
513
  result_text = (
514
  f"[FALL DETECTED] 낙상이 감지되었습니다!\n"
515
  f"낙상 시점: {fall_time:.2f}초 (프레임 #{first_fall_frame})\n"
516
- f"최대 확률: {max_confidence:.1%}\n"
517
- f"클립 길이: {clip_duration:.1f}초 ({len(clip_frames)}프레임)\n"
518
- f"원본 대비: {len(clip_frames)}/{total_frames}프레임 "
519
- f"({len(clip_frames)/total_frames*100:.1f}% 인코딩)"
520
  )
521
 
522
  return final_output, fig, result_text
@@ -542,9 +695,9 @@ def create_demo() -> gr.Blocks:
542
  비디오를 업로드하면 낙상 여부를 분석하고, 결과 비디오와 확률 그래프를 제공합니다.
543
 
544
  **파이프라인 구성:**
545
- - Stage 1: YOLOv11m-pose (Pose Estimation)
546
- - Stage 2: ST-GCN (Temporal Classification)
547
- - Window Size: 60 frames (2초 @ 30fps)
548
  """,
549
  elem_id="main-title"
550
  )
@@ -560,12 +713,12 @@ def create_demo() -> gr.Blocks:
560
 
561
  with gr.Accordion("고급 설정", open=False):
562
  fall_threshold = gr.Slider(
563
- minimum=0.7,
564
  maximum=0.95,
565
- value=0.85,
566
  step=0.05,
567
  label="낙상 판정 임계값",
568
- info="권장: 0.8-0.9 (false positive <5% 목표)"
569
  )
570
  viz_keypoints = gr.Radio(
571
  choices=["all", "major"],
@@ -585,7 +738,7 @@ def create_demo() -> gr.Blocks:
585
  gr.Markdown("### 결과")
586
  result_text = gr.Textbox(
587
  label="판정 결과",
588
- lines=2,
589
  interactive=False
590
  )
591
  video_output = gr.Video(
@@ -605,7 +758,7 @@ def create_demo() -> gr.Blocks:
605
 
606
  if examples:
607
  gr.Examples(
608
- examples=[[ex, 0.85, "all"] for ex in examples[:3]],
609
  inputs=[video_input, fall_threshold, viz_keypoints],
610
  outputs=[video_output, prob_graph, result_text],
611
  fn=process_video,
 
1
  #!/usr/bin/env python3
2
  """
3
+ Fall Detection Gradio App (Batch Processing Pipeline)
4
 
5
  YOLOv11-Pose + ST-GCN 2-stage 파이프라인을 사용한 낙상 감지 데모입니다.
6
+ 배치 처리로 최적화되어 빠른 추론 속도를 제공합니다.
7
 
8
+ Pipeline:
9
+ 1. decord로 전체 프레임 배치 로드
10
+ 2. YOLO Pose 배치 추론 → keypoints 누적
11
+ 3. 윈도우 단위 ST-GCN 배치 추론
12
+ 4. 낙상 시점 -1s ~ +2s 구간만 시각화
13
 
14
+ 사용법 (로컬):
15
+ python pipeline/demo_gradio/app.py
16
 
17
  작성자: Fall Detection Pipeline Team
18
+ 작성일: 2025-11-27
19
  """
20
 
21
  import os
22
  import subprocess
23
  import sys
24
  import tempfile
25
+ from concurrent.futures import ProcessPoolExecutor
26
  from pathlib import Path
27
  from typing import Iterable, Optional, Tuple
28
 
 
36
  from huggingface_hub import hf_hub_download
37
 
38
  # 프로젝트 루트를 Python path에 추가
 
39
  PROJECT_ROOT = Path(__file__).parent.parent.parent
40
  sys.path.insert(0, str(PROJECT_ROOT))
41
 
 
152
 
153
 
154
  def download_models() -> tuple[str, str]:
155
+ """HuggingFace Hub에서 모델 다운로드 (캐시됨)"""
 
 
 
 
 
 
 
 
156
  # 로컬 경로 우선 확인 (개발 환경)
157
  local_pose = Path("yolo11m-pose.pt")
158
  local_stgcn = Path("runs/stgcn_binary_exp2_fixed_graph/best_acc.pth")
 
160
  if local_pose.exists() and local_stgcn.exists():
161
  return str(local_pose), str(local_stgcn)
162
 
163
+ # HuggingFace Hub에서 다운로드
164
  token = os.environ.get("HF_TOKEN")
 
 
165
  if token is None:
166
  raise RuntimeError(
167
  "HF_TOKEN 환경변수가 설정되지 않았습니다. "
168
+ "Private 모델 저장소 접근을 위해 HF_TOKEN이 필요합니다."
 
169
  )
170
 
171
  try:
172
  pose_model_path = hf_hub_download(
173
+ repo_id=HF_MODEL_REPO, filename="yolo11m-pose.pt", token=token
 
 
174
  )
 
175
  stgcn_checkpoint = hf_hub_download(
176
+ repo_id=HF_MODEL_REPO, filename="best_acc.pth", token=token
 
 
177
  )
178
  except Exception as e:
179
+ raise RuntimeError(f"모델 다운로드 실패: {e}") from e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
  return pose_model_path, stgcn_checkpoint
182
 
183
 
184
  # -----------------------------------------------------------------------------
185
+ # 모델 싱글톤 (지연 로딩)
186
  # -----------------------------------------------------------------------------
187
+ _pose_estimator = None
188
+ _stgcn_classifier = None
189
+
190
+
191
+ def get_pose_estimator():
192
+ """PoseEstimator 싱글톤 반환"""
193
+ global _pose_estimator
194
+ if _pose_estimator is None:
195
+ from pipeline.models.pose_estimator import PoseEstimator
196
+ pose_model_path, _ = download_models()
197
+ _pose_estimator = PoseEstimator(
198
+ model_path=pose_model_path,
199
+ conf_threshold=0.5,
200
+ device=str(device)
201
+ )
202
+ return _pose_estimator
203
+
204
+
205
+ def get_stgcn_classifier():
206
+ """STGCNClassifier 싱글톤 반환"""
207
+ global _stgcn_classifier
208
+ if _stgcn_classifier is None:
209
+ from pipeline.models.stgcn_classifier import STGCNClassifier
210
+ _, stgcn_checkpoint = download_models()
211
+ _stgcn_classifier = STGCNClassifier(
212
+ checkpoint_path=stgcn_checkpoint,
213
+ fall_threshold=0.7,
214
+ device=str(device)
215
+ )
216
+ return _stgcn_classifier
217
 
218
 
219
+ # -----------------------------------------------------------------------------
220
+ # 프레임 로드 (cv2 사용 - 대부분의 비디오에서 더 빠름)
221
+ # -----------------------------------------------------------------------------
222
+ def load_video_frames(video_path: str) -> Tuple[np.ndarray, float]:
223
+ """
224
+ 비디오에서 전체 프레임 로드 (cv2 사용)
225
 
226
+ Returns:
227
+ frames: (N, H, W, C) numpy array (BGR)
228
+ fps: 프레임 레이트
229
+ """
230
+ cap = cv2.VideoCapture(video_path)
231
+ fps = cap.get(cv2.CAP_PROP_FPS)
232
+ frames = []
233
 
234
+ while True:
235
+ ret, frame = cap.read()
236
+ if not ret:
237
+ break
238
+ frames.append(frame)
239
+
240
+ cap.release()
241
+ return np.array(frames), fps
 
 
 
 
 
 
 
 
 
 
242
 
243
 
244
  # -----------------------------------------------------------------------------
245
+ # 배치 Pose 추론
246
  # -----------------------------------------------------------------------------
247
+ def extract_all_keypoints(
248
+ frames: np.ndarray,
249
+ pose_estimator,
250
+ batch_size: int = 8,
251
+ progress_callback=None
252
+ ) -> list[Optional[np.ndarray]]:
253
+ """
254
+ 전체 프레임에 대해 배치 Pose 추론
255
+
256
+ Args:
257
+ frames: (N, H, W, C) 전체 비디오 프레임
258
+ pose_estimator: PoseEstimator 인스턴스
259
+ batch_size: 배치 크기
260
+ progress_callback: 진행률 콜백 함수
261
+
262
+ Returns:
263
+ keypoints_list: [(17, 3) or None, ...] N개의 keypoints
264
+ """
265
+ n_frames = len(frames)
266
+ all_keypoints = []
267
+
268
+ for i in range(0, n_frames, batch_size):
269
+ batch = list(frames[i:i+batch_size])
270
+ batch_keypoints = pose_estimator.extract_batch(batch)
271
+ all_keypoints.extend(batch_keypoints)
272
+
273
+ if progress_callback:
274
+ progress_callback(min(i + batch_size, n_frames), n_frames)
275
+
276
+ return all_keypoints
277
+
278
+
279
+ # -----------------------------------------------------------------------------
280
+ # 윈도우 생성 및 ST-GCN 배치 추론
281
+ # -----------------------------------------------------------------------------
282
+ def create_windows_and_predict(
283
+ keypoints_list: list[Optional[np.ndarray]],
284
+ stgcn_classifier,
285
+ window_size: int = 60,
286
+ stride: int = 5,
287
  fall_threshold: float = 0.7
288
+ ) -> Tuple[list[int], list[float], Optional[int]]:
289
+ """
290
+ keypoints에서 윈도우 생성 후 ST-GCN 배치 추론
291
+
292
+ Args:
293
+ keypoints_list: 프레임별 keypoints 리스트
294
+ stgcn_classifier: STGCNClassifier 인스턴스
295
+ window_size: 윈도우 크기 (프레임 수)
296
+ stride: 추론 간격 (N 프레임마다 1번)
297
+ fall_threshold: 낙상 판정 임계값
298
+
299
+ Returns:
300
+ frame_indices: ST-GCN 예측이 있는 프레임 인덱스
301
+ fall_probs: 각 프레임의 낙상 확률 (class 1 확률)
302
+ first_fall_frame: 첫 낙상 감지 프레임 인덱스 (없으면 None)
303
+ """
304
+ n_frames = len(keypoints_list)
305
+
306
+ # None을 빈 keypoints로 대체
307
+ processed_keypoints = []
308
+ for kpts in keypoints_list:
309
+ if kpts is None:
310
+ processed_keypoints.append(np.zeros((17, 3), dtype=np.float32))
311
+ else:
312
+ processed_keypoints.append(kpts)
313
+
314
+ # 윈도우 생성 (stride 간격으로)
315
+ frame_indices = []
316
+ windows = []
317
+
318
+ for frame_idx in range(window_size - 1, n_frames, stride):
319
+ # 이전 window_size 프레임으로 윈도우 구성
320
+ window_keypoints = processed_keypoints[frame_idx - window_size + 1:frame_idx + 1]
321
+
322
+ # (T, V, C) -> (C, T, V, M) 변환
323
+ window = np.array(window_keypoints) # (T=60, V=17, C=3)
324
+ window = window.transpose(2, 0, 1) # (C=3, T=60, V=17)
325
+ window = np.expand_dims(window, -1) # (C=3, T=60, V=17, M=1)
326
+
327
+ frame_indices.append(frame_idx)
328
+ windows.append(window.astype(np.float32))
329
+
330
+ if not windows:
331
+ return [], [], None
332
+
333
+ # ST-GCN 배치 추론
334
+ predictions, confidences, fall_probs = stgcn_classifier.predict_batch(windows)
335
+
336
+ # 첫 낙상 감지 프레임 찾기
337
+ first_fall_frame = None
338
+ for i, (pred, fall_prob) in enumerate(zip(predictions, fall_probs)):
339
+ if pred == 1 and fall_prob >= fall_threshold:
340
+ first_fall_frame = frame_indices[i]
341
+ break
342
+
343
+ return frame_indices, fall_probs.tolist(), first_fall_frame
344
+
345
+
346
+ # -----------------------------------------------------------------------------
347
+ # 시각화 워커 함수 (ProcessPoolExecutor용)
348
+ # -----------------------------------------------------------------------------
349
+ # FALL DETECTED 텍스트 표시 지속 시간 (초)
350
+ FALL_DISPLAY_DURATION = 2.0
351
+
352
+
353
+ def _visualize_single_frame(args: tuple) -> Tuple[int, np.ndarray]:
354
+ """단일 프레임 시각화 워커 (간소화된 버전)"""
355
+ (frame_idx, frame, keypoints, show_fall_text,
356
+ viz_keypoints, viz_scale) = args
357
+
358
+ # 프로젝트 import (워커 프로세스에서)
359
+ import sys
360
+ from pathlib import Path
361
+ sys.path.insert(0, str(Path(__file__).parent.parent.parent))
362
+ from pipeline.visualization import visualize_fall_simple
363
+
364
+ vis_frame = visualize_fall_simple(
365
+ frame=frame,
366
+ keypoints=keypoints if keypoints is not None and keypoints.sum() > 0 else None,
367
+ show_fall_text=show_fall_text,
368
+ keypoint_mode=viz_keypoints,
369
+ output_scale=viz_scale
370
+ )
371
+
372
+ return frame_idx, vis_frame
373
+
374
+
375
+ def visualize_clip_parallel(
376
+ frames: np.ndarray,
377
+ keypoints_list: list[Optional[np.ndarray]],
378
+ frame_indices: list[int],
379
+ fall_probs: list[float],
380
+ clip_start: int,
381
+ clip_end: int,
382
+ fps: float,
383
+ first_fall_frame: Optional[int] = None,
384
+ fall_threshold: float = 0.7,
385
+ viz_keypoints: str = "all",
386
+ viz_scale: float = 1.0,
387
+ num_workers: int = 4
388
+ ) -> list[np.ndarray]:
389
  """
390
+ 클립 구간 병렬 시각화 (간소화된 버전)
391
 
392
  Args:
393
+ frames: 전체 프레임
394
+ keypoints_list: 전체 keypoints
395
+ frame_indices: ST-GCN 예측 프레임 인덱스
396
+ fall_probs: 프레임별 낙상 확률
397
+ clip_start: 클립 시작 인덱스
398
+ clip_end: 클립 종료 인덱스
399
+ fps: 프레임 레이트
400
+ first_fall_frame: 첫 낙상 감지 프레임 (깜빡임 방지용)
401
  fall_threshold: 낙상 판정 임계값
402
+ viz_keypoints: 키포인트 표시 모드
403
+ viz_scale: 출력 스케일
404
+ num_workers: 병렬 워커 수
405
 
406
  Returns:
407
+ vis_frames: 시각화된 프레임 리스트
408
  """
409
+ # 깜빡임 방지: 첫 낙상 후 N초간 FALL DETECTED 표시
410
+ fall_display_end_frame = None
411
+ if first_fall_frame is not None:
412
+ fall_display_end_frame = first_fall_frame + int(fps * FALL_DISPLAY_DURATION)
413
+
414
+ # 시각화 인자 준비
415
+ viz_args = []
416
+ for i in range(clip_start, clip_end):
417
+ frame = frames[i]
418
+ keypoints = keypoints_list[i]
419
+
420
+ # FALL DETECTED 텍스트 표시 여부 결정 (깜빡임 방지)
421
+ show_fall_text = False
422
+ if first_fall_frame is not None and fall_display_end_frame is not None:
423
+ if first_fall_frame <= i <= fall_display_end_frame:
424
+ show_fall_text = True
425
+
426
+ args = (
427
+ i, # frame_idx
428
+ frame, # frame
429
+ keypoints, # keypoints
430
+ show_fall_text, # show_fall_text (깜빡임 방지 적용)
431
+ viz_keypoints, # viz_keypoints
432
+ viz_scale # viz_scale
433
+ )
434
+ viz_args.append(args)
435
+
436
+ # 병렬 시각화
437
+ with ProcessPoolExecutor(max_workers=num_workers) as executor:
438
+ results = list(executor.map(_visualize_single_frame, viz_args))
439
+
440
+ # 순서대로 정렬
441
+ results.sort(key=lambda x: x[0])
442
+ vis_frames = [frame for _, frame in results]
443
+
444
+ return vis_frames
445
+
446
+
447
+ # -----------------------------------------------------------------------------
448
+ # 확률 그래프 생성
449
+ # -----------------------------------------------------------------------------
450
+ def create_probability_graph(
451
+ frame_indices: list[int],
452
+ fall_probs: list[float],
453
+ fall_threshold: float = 0.7,
454
+ fps: float = 30.0
455
+ ) -> go.Figure:
456
+ """낙상 확률 그래프 생성 (X축: 시간)"""
457
+ # 프레임 인덱스 -> 시간(초) 변환
458
+ time_seconds = [idx / fps for idx in frame_indices]
459
+
460
  fig = go.Figure()
461
 
462
  # 확률 라인
463
  fig.add_trace(go.Scatter(
464
+ x=time_seconds,
465
+ y=fall_probs,
466
  mode='lines',
467
  name='Fall Probability',
468
  line=dict(color='#4682B4', width=2),
 
482
  # 레이아웃
483
  fig.update_layout(
484
  title="Fall Detection Probability Over Time",
485
+ xaxis_title="Time (seconds)",
486
  yaxis_title="Probability",
487
+ yaxis=dict(range=[0, 1.05]),
488
  template="plotly_white",
489
  height=300,
490
  margin=dict(l=50, r=50, t=50, b=50),
 
502
 
503
 
504
  # -----------------------------------------------------------------------------
505
+ # 스마트 클립 추출 설정
506
  # -----------------------------------------------------------------------------
507
  CLIP_PRE_FALL_SECONDS = 1.0 # 낙상 전 1초
508
  CLIP_POST_FALL_SECONDS = 2.0 # 낙상 후 2초
 
519
  progress: gr.Progress = gr.Progress()
520
  ) -> Tuple[Optional[str], Optional[go.Figure], str]:
521
  """
522
+ 비디오 처리 및 낙상 감지 (배치 처리 파이프라인)
523
 
524
+ Pipeline:
525
+ 1. decord로 전체 프레임 배치 로드
526
+ 2. YOLO Pose 배치 추론 keypoints 누적
527
+ 3. 윈도우 단위 ST-GCN 배치 추론
528
+ 4. 낙상 시점 -1s ~ +2s 구간만 시각화
529
 
530
  Args:
531
  video_path: 입력 비디오 경로
 
534
  progress: Gradio 진행률 표시
535
 
536
  Returns:
537
+ output_video_path: 결과 클립 경로 (낙상 감지 시) 또는 None
538
  probability_graph: 확률 그래프
539
  result_text: 최종 판정 텍스트
540
  """
 
542
  return None, None, "비디오를 업로드해주세요."
543
 
544
  try:
545
+ # Stage 0: 모델 로드
546
+ progress(0.05, desc="모델 로딩 중...")
547
+ pose_estimator = get_pose_estimator()
548
+ stgcn_classifier = get_stgcn_classifier()
549
+ stgcn_classifier.fall_threshold = fall_threshold
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
550
 
551
+ # Stage 1: 프레임 로드 (decord)
552
+ progress(0.1, desc="비디오 로딩 중...")
553
+ frames, fps = load_video_frames(video_path)
554
+ n_frames = len(frames)
555
 
556
+ if n_frames == 0:
557
+ return None, None, "비디오를 읽을 수 없습니다."
 
 
 
 
558
 
559
+ # 비디오 길이 검증 (120s GPU 타임아웃 대비)
560
+ video_duration = n_frames / fps
561
+ if video_duration > 60:
562
+ return None, None, (
563
+ f"비디오가 너무 깁니다. "
564
+ f"비디오 길이: {video_duration:.1f}초 (제한: 60초). "
565
+ f"60초 이내의 비디오를 업로드하세요."
566
+ )
 
 
 
 
 
 
 
 
 
 
 
567
 
568
+ # Stage 2: 배치 Pose 추론
569
+ progress(0.15, desc="Pose 추출 중...")
570
+ def pose_progress(current, total):
571
+ pct = 0.15 + 0.35 * (current / total)
572
+ progress(pct, desc=f"Pose 추출 중... ({current}/{total})")
573
 
574
+ keypoints_list = extract_all_keypoints(
575
+ frames, pose_estimator,
576
+ batch_size=8,
577
+ progress_callback=pose_progress
578
+ )
579
 
580
+ # Stage 3: ST-GCN 배치 추론
581
+ progress(0.55, desc="낙상 분석 중...")
582
+ frame_indices, fall_probs, first_fall_frame = create_windows_and_predict(
583
+ keypoints_list,
584
+ stgcn_classifier,
585
+ window_size=60,
586
+ stride=5,
587
+ fall_threshold=fall_threshold
588
+ )
589
 
590
+ # 확률 그래프 생성
591
+ progress(0.7, desc="그래프 생성 중...")
592
+ fig = None
593
+ if frame_indices and fall_probs:
594
+ fig = create_probability_graph(frame_indices, fall_probs, fall_threshold, fps)
 
595
 
596
+ # 낙상 미감지 시
597
+ if first_fall_frame is None:
598
  progress(1.0, desc="완료!")
599
  result_text = (
600
  f"[Non-Fall] 낙상이 감지되지 않았습니다.\n"
601
+ f"분석 프레임: {n_frames}"
 
602
  )
603
  return None, fig, result_text
604
 
605
+ # Stage 4: 낙상 구간만 시각화
606
+ progress(0.75, desc="클립 시각화 중...")
607
+ pre_fall_frames = int(fps * CLIP_PRE_FALL_SECONDS)
608
+ post_fall_frames = int(fps * CLIP_POST_FALL_SECONDS)
609
+
610
  clip_start = max(0, first_fall_frame - pre_fall_frames)
611
+ clip_end = min(n_frames, first_fall_frame + post_fall_frames)
612
+
613
+ vis_frames = visualize_clip_parallel(
614
+ frames=frames,
615
+ keypoints_list=keypoints_list,
616
+ frame_indices=frame_indices,
617
+ fall_probs=fall_probs,
618
+ clip_start=clip_start,
619
+ clip_end=clip_end,
620
+ fps=fps,
621
+ first_fall_frame=first_fall_frame, # 깜빡임 방지용
622
+ fall_threshold=fall_threshold,
623
+ viz_keypoints=viz_keypoints,
624
+ viz_scale=1.0,
625
+ num_workers=4
626
+ )
627
 
628
+ if not vis_frames:
629
  progress(1.0, desc="완료!")
630
  return None, fig, "클립 추출에 실패했습니다."
631
 
632
+ # Stage 5: 비디오 인코딩
633
  progress(0.9, desc="클립 인코딩 중...")
634
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
635
  output_path = tmp.name
636
 
637
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
638
+ clip_height, clip_width = vis_frames[0].shape[:2]
 
639
  out = cv2.VideoWriter(output_path, fourcc, fps, (clip_width, clip_height))
640
 
641
+ for vis_frame in vis_frames:
642
  out.write(vis_frame)
643
  out.release()
644
 
645
+ # H.264 재인코딩 (브라우저 호환)
646
  with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
647
  output_h264 = tmp.name
648
 
649
+ subprocess.run(
650
+ [
651
+ 'ffmpeg', '-y', '-i', output_path,
652
+ '-c:v', 'libx264', '-preset', 'fast', '-crf', '23',
653
+ output_h264, '-loglevel', 'quiet'
654
+ ],
655
+ check=False,
656
+ capture_output=True
657
+ )
 
658
 
659
+ # 임시 파일 정리
660
  if os.path.exists(output_path):
661
  os.remove(output_path)
662
 
663
+ final_output = output_h264 if os.path.exists(output_h264) else None
 
 
 
 
664
 
665
  # 최종 판정
666
  progress(1.0, desc="완료!")
667
+ fall_time = first_fall_frame / fps
668
+ clip_duration = len(vis_frames) / fps
669
  result_text = (
670
  f"[FALL DETECTED] 낙상이 감지되었습니다!\n"
671
  f"낙상 시점: {fall_time:.2f}초 (프레임 #{first_fall_frame})\n"
672
+ f"클립 길이: {clip_duration:.1f}초 ({len(vis_frames)}프레임)"
 
 
 
673
  )
674
 
675
  return final_output, fig, result_text
 
695
  비디오를 업로드하면 낙상 여부를 분석하고, 결과 비디오와 확률 그래프를 제공합니다.
696
 
697
  **파이프라인 구성:**
698
+ - Stage 1: YOLOv11m-pose (Pose Estimation) - Batch Processing
699
+ - Stage 2: ST-GCN (Temporal Classification) - Batch Processing
700
+ - Window Size: 60 frames (2s @ 30fps)
701
  """,
702
  elem_id="main-title"
703
  )
 
713
 
714
  with gr.Accordion("고급 설정", open=False):
715
  fall_threshold = gr.Slider(
716
+ minimum=0.5,
717
  maximum=0.95,
718
+ value=0.7,
719
  step=0.05,
720
  label="낙상 판정 임계값",
721
+ info="권장: 0.7-0.85"
722
  )
723
  viz_keypoints = gr.Radio(
724
  choices=["all", "major"],
 
738
  gr.Markdown("### 결과")
739
  result_text = gr.Textbox(
740
  label="판정 결과",
741
+ lines=3,
742
  interactive=False
743
  )
744
  video_output = gr.Video(
 
758
 
759
  if examples:
760
  gr.Examples(
761
+ examples=[[ex, 0.7, "all"] for ex in examples[:3]],
762
  inputs=[video_input, fall_threshold, viz_keypoints],
763
  outputs=[video_output, prob_graph, result_text],
764
  fn=process_video,