Spaces: YoungjaeDev
Claude committed
Commit · 8133f1d
1 Parent(s): f09549f
fix: Resolve HF Spaces import error by switching to a self-contained layout
- Copy pose_estimator.py and stgcn_classifier.py into the demo_gradio/models/ directory
- Copy model.py and graph.py into the demo_gradio/stgcn/ directory
- Copy augmentation.py and visualization.py
- Change the import paths in app.py and stgcn_classifier.py to relative imports
- Add a pipeline/demo_gradio/models/ exception to .gitignore
With this change, the HF Space runs standalone without the pipeline module.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude <noreply@anthropic.com>
- app.py +5 -10
- augmentation.py +725 -0
- models/__init__.py +1 -0
- models/pose_estimator.py +150 -0
- models/stgcn_classifier.py +183 -0
- stgcn/__init__.py +1 -0
- stgcn/graph.py +291 -0
- stgcn/model.py +391 -0
- visualization.py +973 -0
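For orientation, this is the self-contained layout implied by the commit message and the file list above (paths under demo_gradio/; the pipeline/ prefix appears only in the .gitignore exception):

demo_gradio/
├── app.py
├── augmentation.py
├── visualization.py
├── models/
│   ├── __init__.py
│   ├── pose_estimator.py
│   └── stgcn_classifier.py
└── stgcn/
    ├── __init__.py
    ├── graph.py
    └── model.py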
app.py
CHANGED
@@ -35,9 +35,7 @@ from gradio.themes import Soft
 from gradio.themes.utils import colors, fonts, sizes
 from huggingface_hub import hf_hub_download
 
-# Project
-PROJECT_ROOT = Path(__file__).parent.parent.parent
-sys.path.insert(0, str(PROJECT_ROOT))
+# For HF Spaces deployment: no project-root setup needed (self-contained)
 
 # Zero GPU compatibility setup
 try:
@@ -192,7 +190,7 @@ def get_pose_estimator():
     """Return the PoseEstimator singleton"""
     global _pose_estimator
     if _pose_estimator is None:
-        from
+        from models.pose_estimator import PoseEstimator
         pose_model_path, _ = download_models()
         _pose_estimator = PoseEstimator(
             model_path=pose_model_path,
@@ -206,7 +204,7 @@ def get_stgcn_classifier():
     """Return the STGCNClassifier singleton"""
     global _stgcn_classifier
     if _stgcn_classifier is None:
-        from
+        from models.stgcn_classifier import STGCNClassifier
         _, stgcn_checkpoint = download_models()
         _stgcn_classifier = STGCNClassifier(
             checkpoint_path=stgcn_checkpoint,
@@ -355,11 +353,8 @@ def _visualize_single_frame(args: tuple) -> Tuple[int, np.ndarray]:
     (frame_idx, frame, keypoints, show_fall_text,
      viz_keypoints, viz_scale) = args
 
-    #
-    import
-    from pathlib import Path
-    sys.path.insert(0, str(Path(__file__).parent.parent.parent))
-    from pipeline.visualization import visualize_fall_simple
+    # Relative import for HF Spaces deployment (runs inside worker processes)
+    from visualization import visualize_fall_simple
 
     vis_frame = visualize_fall_simple(
         frame=frame,
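The two getter hunks above use a lazy-singleton pattern: the model class is imported and constructed only on first call, so app startup stays light on Zero GPU hardware. A minimal sketch of that pattern as the diff leaves it (download_models is assumed to be the helper already defined in app.py; the surrounding code is illustrative, not the app's exact source):

_pose_estimator = None

def get_pose_estimator():
    """Return the PoseEstimator singleton, creating it on first use."""
    global _pose_estimator
    if _pose_estimator is None:
        # Works without sys.path edits because models/ sits next to app.py
        from models.pose_estimator import PoseEstimator
        pose_model_path, _ = download_models()  # assumed app.py helper
        _pose_estimator = PoseEstimator(model_path=pose_model_path)
    return _pose_estimator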
augmentation.py
ADDED
@@ -0,0 +1,725 @@
#!/usr/bin/env python3
"""
Skeleton Data Augmentation for ST-GCN Fall Detection

This module provides augmentation strategies for skeleton sequence data to improve
model generalization and robustness. All augmentations preserve the spatial-temporal
structure required by ST-GCN while introducing controlled variations.

Input Format: (C, T, V, M) where
    C = 3 channels (x, y, confidence)
    T = 60 frames (temporal window)
    V = 17 keypoints (COCO skeleton)
    M = 1 person (max persons tracked)

Augmentation Strategies:
    1. Horizontal Flip: Mirror skeleton across vertical axis with keypoint swapping
    2. Gaussian Noise: Add random noise to x,y coordinates (preserves confidence)
    3. Temporal Crop: Random crop + resize to simulate variable fall speeds

Reference: Issue #34 - ST-GCN Training Dataset Creation
"""

import numpy as np
from typing import Tuple, Optional


# COCO 17-keypoint left/right pairs for horizontal flip
# Format: (left_index, right_index)
COCO_LEFT_RIGHT_PAIRS = [
    (1, 2),    # left_eye <-> right_eye
    (3, 4),    # left_ear <-> right_ear
    (5, 6),    # left_shoulder <-> right_shoulder
    (7, 8),    # left_elbow <-> right_elbow
    (9, 10),   # left_wrist <-> right_wrist
    (11, 12),  # left_hip <-> right_hip
    (13, 14),  # left_knee <-> right_knee
    (15, 16),  # left_ankle <-> right_ankle
]


def augment_skeleton(data: np.ndarray, prob: float = 0.5) -> np.ndarray:
    """
    Apply random augmentations to skeleton sequence data.

    This function applies three augmentation strategies, each with probability `prob`:
    1. Horizontal flip with keypoint swapping
    2. Gaussian noise injection to x,y coordinates
    3. Temporal crop and resize

    Mathematical Formulations:
    -------------------------
    1. Horizontal Flip:
        x' = -x
        For each (left, right) keypoint pair: swap(left, right)

    2. Gaussian Noise:
        x' = x + N(0, sigma^2)
        y' = y + N(0, sigma^2)
        where sigma = 0.01 (default)

    3. Temporal Crop & Resize:
        T_crop ~ Uniform(0.8 * T, 1.0 * T)
        start_frame ~ Uniform(0, T - T_crop)
        cropped = data[:, start:start+T_crop, :, :]
        resized = interpolate(cropped, T)

    Args:
        data: Skeleton data with shape (C, T, V, M) where
            C = 3 (x, y, confidence)
            T = 60 (number of frames)
            V = 17 (number of keypoints)
            M = 1 (number of persons)
        prob: Probability of applying each augmentation (default: 0.5)

    Returns:
        augmented_data: Augmented skeleton data with same shape (C, T, V, M)

    Example:
        >>> data = np.random.rand(3, 60, 17, 1)
        >>> augmented = augment_skeleton(data, prob=0.5)
        >>> augmented.shape
        (3, 60, 17, 1)
    """
    C, T, V, M = data.shape
    assert C == 3, f"Expected 3 channels (x, y, conf), got {C}"
    assert V == 17, f"Expected 17 COCO keypoints, got {V}"
    assert M == 1, f"Expected max 1 person, got {M}"

    # Create a copy to avoid modifying original data
    augmented_data = data.copy()

    # 1. Horizontal Flip (flip x-coordinate + swap left/right keypoints)
    if np.random.rand() < prob:
        augmented_data = _horizontal_flip(augmented_data)

    # 2. Random Noise Injection (add Gaussian noise to x,y only)
    if np.random.rand() < prob:
        augmented_data = _add_gaussian_noise(augmented_data)

    # 3. Temporal Crop and Resize (crop 0.8-1.0 of length, resize back)
    if np.random.rand() < prob:
        augmented_data = _temporal_crop_resize(augmented_data)

    return augmented_data


def _horizontal_flip(data: np.ndarray) -> np.ndarray:
    """
    Horizontally flip skeleton by negating x-coordinate and swapping left/right keypoints.

    Mathematical Formulation:
        x' = -x
        y' = y
        conf' = conf
        For each (left_idx, right_idx) pair: swap keypoints

    Args:
        data: Skeleton data (C, T, V, M)

    Returns:
        flipped_data: Horizontally flipped data (C, T, V, M)
    """
    flipped_data = data.copy()

    # Flip x-coordinate (channel 0)
    flipped_data[0] = -flipped_data[0]

    # Swap left/right keypoint pairs
    for left_idx, right_idx in COCO_LEFT_RIGHT_PAIRS:
        # Swap all channels (x, y, conf) for the keypoint pair
        temp = flipped_data[:, :, left_idx, :].copy()
        flipped_data[:, :, left_idx, :] = flipped_data[:, :, right_idx, :]
        flipped_data[:, :, right_idx, :] = temp

    return flipped_data


def _add_gaussian_noise(data: np.ndarray, std: float = 0.01) -> np.ndarray:
    """
    Add Gaussian noise to x,y coordinates (preserves confidence channel).

    Mathematical Formulation:
        x' = x + N(0, sigma^2)
        y' = y + N(0, sigma^2)
        conf' = conf (unchanged)
        where sigma = 0.01 (default)

    The noise magnitude is calibrated for normalized coordinates in range [-0.5, 0.5].
    With std=0.01, 99.7% of noise values fall within [-0.03, 0.03] (3-sigma rule).

    Args:
        data: Skeleton data (C, T, V, M)
        std: Standard deviation of Gaussian noise (default: 0.01)

    Returns:
        noisy_data: Data with Gaussian noise added to x,y coordinates
    """
    C, T, V, M = data.shape
    noisy_data = data.copy()

    # Generate Gaussian noise for x,y channels only (not confidence)
    noise_shape = (2, T, V, M)  # Only x,y channels
    noise = np.random.normal(0, std, noise_shape).astype(data.dtype)

    # Add noise to x,y channels (0, 1), leave confidence channel (2) unchanged
    noisy_data[:2] += noise

    return noisy_data


def _temporal_crop_resize(data: np.ndarray, crop_ratio_range: Tuple[float, float] = (0.8, 1.0)) -> np.ndarray:
    """
    Randomly crop temporal sequence and resize back to original length.

    This augmentation simulates variable fall speeds by compressing or expanding
    the temporal dimension. A crop ratio of 0.8 means the fall happens 20% faster,
    while 1.0 means no temporal change.

    Mathematical Formulation:
        T_crop ~ Uniform(crop_min * T, crop_max * T)
        start ~ Uniform(0, T - T_crop)
        cropped = data[:, start:start+T_crop, :, :]
        resized = interpolate(cropped, T) using linear interpolation

    Args:
        data: Skeleton data (C, T, V, M)
        crop_ratio_range: (min_ratio, max_ratio) for crop length (default: (0.8, 1.0))

    Returns:
        resized_data: Temporally augmented data with original shape (C, T, V, M)
    """
    C, T, V, M = data.shape
    min_ratio, max_ratio = crop_ratio_range

    # Sample random crop ratio
    crop_ratio = np.random.uniform(min_ratio, max_ratio)
    crop_length = int(T * crop_ratio)
    crop_length = max(1, crop_length)  # Ensure at least 1 frame

    # Sample random start position
    max_start = max(0, T - crop_length)
    start_frame = np.random.randint(0, max_start + 1) if max_start > 0 else 0

    # Extract cropped window
    cropped = data[:, start_frame:start_frame + crop_length, :, :]

    # Resize back to original temporal length using linear interpolation
    resized_data = _temporal_interpolate(cropped, T)

    return resized_data


def _temporal_interpolate(data: np.ndarray, target_length: int) -> np.ndarray:
    """
    Interpolate temporal dimension to target length using linear interpolation.

    This function performs 1D linear interpolation along the temporal axis (axis=1)
    for each channel, keypoint, and person independently.

    Args:
        data: Skeleton data (C, T, V, M)
        target_length: Target number of frames

    Returns:
        interpolated_data: Data with temporal dimension resized to target_length
    """
    C, T_src, V, M = data.shape

    if T_src == target_length:
        return data

    # Create target time indices
    src_indices = np.linspace(0, T_src - 1, T_src)
    target_indices = np.linspace(0, T_src - 1, target_length)

    # Interpolate each channel, keypoint, person combination
    interpolated_data = np.zeros((C, target_length, V, M), dtype=data.dtype)

    for c in range(C):
        for v in range(V):
            for m in range(M):
                interpolated_data[c, :, v, m] = np.interp(
                    target_indices,
                    src_indices,
                    data[c, :, v, m]
                )

    return interpolated_data


def _normalize_by_hip_center(data: np.ndarray) -> np.ndarray:
    """
    Normalize skeleton by hip center position and skeleton size (ST-GCN standard).

    This is the recommended normalization method for skeleton-based action recognition,
    following the ST-GCN paper and NTU RGB+D dataset preprocessing.

    Algorithm:
    ----------
    1. Calculate hip center from left_hip (11) and right_hip (12)
    2. If hips have low confidence (<0.3), fall back to shoulder center
    3. Center all keypoints by subtracting hip center
    4. Calculate skeleton size as average shoulder-to-hip distance
    5. Scale all coordinates by skeleton size

    COCO Keypoints Used:
    - 5: left_shoulder
    - 6: right_shoulder
    - 11: left_hip
    - 12: right_hip

    Args:
        data: Skeleton data (C, T, V, M) with C=3 (x, y, conf)

    Returns:
        normalized_data: (C, T, V, M) centered at hip, scaled by skeleton size
            - x,y channels: relative to hip center, scaled by skeleton size
            - conf channel: unchanged

    Example:
        >>> data = np.random.rand(3, 60, 17, 1) * np.array([3840, 2160, 1]).reshape(3, 1, 1, 1)
        >>> normalized = _normalize_by_hip_center(data)
        >>> # Hip center is now at (0, 0)
        >>> hip_center_x = (normalized[0, :, 11, :] + normalized[0, :, 12, :]) / 2
        >>> np.allclose(hip_center_x, 0.0, atol=1e-6)
        True
    """
    C, T, V, M = data.shape
    normalized_data = data.copy()

    # Extract hip keypoints (COCO: 11=left_hip, 12=right_hip)
    left_hip_xy = data[:2, :, 11:12, :]      # (2, T, 1, M)
    right_hip_xy = data[:2, :, 12:13, :]     # (2, T, 1, M)
    left_hip_conf = data[2:3, :, 11:12, :]   # (1, T, 1, M)
    right_hip_conf = data[2:3, :, 12:13, :]  # (1, T, 1, M)

    # Calculate average hip confidence across all frames
    left_hip_conf_mean = np.mean(left_hip_conf)
    right_hip_conf_mean = np.mean(right_hip_conf)

    # Determine center point (hip or shoulder fallback)
    if left_hip_conf_mean >= 0.3 and right_hip_conf_mean >= 0.3:
        # Normal case: use hip center
        center_point = (left_hip_xy + right_hip_xy) / 2.0  # (2, T, 1, M)

        # Calculate skeleton size from shoulder-to-hip distance
        left_shoulder_xy = data[:2, :, 5:6, :]   # (2, T, 1, M)
        right_shoulder_xy = data[:2, :, 6:7, :]  # (2, T, 1, M)

        # Left torso distance: ||left_shoulder - left_hip||
        left_torso = left_shoulder_xy - left_hip_xy  # (2, T, 1, M)
        left_torso_dist = np.sqrt(np.sum(left_torso ** 2, axis=0))  # (T, 1, M)

        # Right torso distance: ||right_shoulder - right_hip||
        right_torso = right_shoulder_xy - right_hip_xy  # (2, T, 1, M)
        right_torso_dist = np.sqrt(np.sum(right_torso ** 2, axis=0))  # (T, 1, M)

        # Average skeleton size across frames and left/right
        skeleton_size = np.mean([left_torso_dist, right_torso_dist])  # scalar

    else:
        # Fallback: use shoulder center if hips not detected
        left_shoulder_xy = data[:2, :, 5:6, :]
        right_shoulder_xy = data[:2, :, 6:7, :]
        center_point = (left_shoulder_xy + right_shoulder_xy) / 2.0  # (2, T, 1, M)

        # Use shoulder width as skeleton size estimate
        shoulder_vector = right_shoulder_xy - left_shoulder_xy  # (2, T, 1, M)
        shoulder_width = np.sqrt(np.sum(shoulder_vector ** 2, axis=0))  # (T, 1, M)
        skeleton_size = np.mean(shoulder_width) * 2.0  # Approximate torso height

    # Prevent division by zero
    skeleton_size = max(skeleton_size, 1e-6)

    # Normalize x,y channels: center and scale
    normalized_data[:2] = (normalized_data[:2] - center_point) / skeleton_size

    # Confidence channel unchanged
    # normalized_data[2] remains as is

    return normalized_data


def _normalize_by_image_center(
    data: np.ndarray,
    img_width: int = 3840,
    img_height: int = 2160
) -> np.ndarray:
    """
    Legacy normalization by image center (for comparison only).

    This method is NOT recommended for ST-GCN training as it:
    - Includes absolute position information
    - Varies with camera angle
    - Does not normalize body size

    Use this only for comparing with old implementations or specific use cases
    where absolute position in frame matters.

    Args:
        data: Skeleton data (C, T, V, M)
        img_width: Image width in pixels (default: 3840 for AI Hub 4K)
        img_height: Image height in pixels (default: 2160 for AI Hub 4K)

    Returns:
        normalized_data: (C, T, V, M) with x,y in [-0.5, 0.5]
    """
    C, T, V, M = data.shape
    normalized_data = data.copy()

    # Normalize x-coordinate (channel 0): [0, img_width] -> [-0.5, 0.5]
    normalized_data[0] = (normalized_data[0] / img_width) - 0.5

    # Normalize y-coordinate (channel 1): [0, img_height] -> [-0.5, 0.5]
    normalized_data[1] = (normalized_data[1] / img_height) - 0.5

    # Confidence channel (2) remains unchanged in [0, 1]

    return normalized_data


def normalize_skeleton(
    data: np.ndarray,
    method: str = 'hip_center',
    img_width: int = 3840,
    img_height: int = 2160
) -> np.ndarray:
    """
    Normalize skeleton coordinates using the ST-GCN standard method.

    This normalization removes absolute position information and makes the model
    focus on relative pose patterns, which is critical for fall detection across
    different camera angles (AI Hub 8-camera setup).

    Methods:
    --------
    1. 'hip_center' (default, ST-GCN standard):
        - Center: Hip center (average of left_hip and right_hip)
        - Scale: Skeleton size (shoulder-to-hip distance)
        - Fallback: Shoulder center if hips not detected
        - Reference: ST-GCN (Yan et al., AAAI 2018), NTU RGB+D normalization

    2. 'image_center' (legacy, not recommended):
        - Center: Image center
        - Scale: Image dimensions
        - Use only for comparison with old implementations

    Mathematical Formulations (hip_center):
    ----------------------------------------
    Step 1: Calculate hip center
        hip_center = (left_hip + right_hip) / 2  # COCO keypoints 11, 12

    Step 2: Center all keypoints
        x' = x - hip_center_x
        y' = y - hip_center_y

    Step 3: Scale by skeleton size (shoulder-to-hip distance)
        skeleton_size = mean(||shoulder - hip||) over left and right
        x'' = x' / skeleton_size
        y'' = y' / skeleton_size

    Advantages of hip_center normalization:
    - Camera angle invariant (critical for 8-camera AI Hub dataset)
    - Absolute position independent (person can be anywhere in frame)
    - Body size normalized (tall/short people comparable)
    - Matches the ST-GCN paper and most skeleton action recognition work

    Args:
        data: Skeleton data with shape (C, T, V, M) where
            C = 3 (x in pixels, y in pixels, confidence)
            T = number of frames
            V = 17 (COCO keypoints)
            M = 1 (max persons)
        method: Normalization method - 'hip_center' (default) or 'image_center'
        img_width: Image width for image_center method (default: 3840 for AI Hub 4K)
        img_height: Image height for image_center method (default: 2160 for AI Hub 4K)

    Returns:
        normalized_data: Normalized skeleton data with shape (C, T, V, M)
            For hip_center: relative coordinates centered at hip, scaled by skeleton size
            For image_center: x,y in [-0.5, 0.5], conf in [0, 1]

    Example:
        >>> # ST-GCN standard normalization
        >>> data = np.random.rand(3, 60, 17, 1) * np.array([3840, 2160, 1]).reshape(3, 1, 1, 1)
        >>> normalized = normalize_skeleton(data, method='hip_center')
        >>> # Hip is now at origin (0, 0)
        >>> # Coordinates scaled by skeleton size

        >>> # Legacy image center normalization
        >>> normalized_legacy = normalize_skeleton(data, method='image_center')
        >>> normalized_legacy[0].min(), normalized_legacy[0].max()  # x range
        (-0.5, 0.5)
    """
    C, T, V, M = data.shape
    assert C == 3, f"Expected 3 channels (x, y, conf), got {C}"
    assert V == 17, f"Expected 17 COCO keypoints, got {V}"

    if method == 'hip_center':
        return _normalize_by_hip_center(data)
    elif method == 'image_center':
        return _normalize_by_image_center(data, img_width, img_height)
    else:
        raise ValueError(
            f"Unknown normalization method: '{method}'. "
            f"Use 'hip_center' (ST-GCN standard) or 'image_center' (legacy)."
        )


def denormalize_skeleton(
    data: np.ndarray,
    method: str = 'hip_center',
    hip_center: Optional[np.ndarray] = None,
    skeleton_size: Optional[float] = None,
    img_width: int = 3840,
    img_height: int = 2160
) -> np.ndarray:
    """
    Denormalize skeleton coordinates back to original space.

    NOTE: For the hip_center method, denormalization requires storing the original
    hip_center and skeleton_size values during normalization. This function is
    primarily for visualization purposes.

    For most ST-GCN training workflows, you don't need denormalization since:
    - Training works directly on normalized coordinates
    - Model predictions are classification labels (not coordinates)

    Methods:
    --------
    1. 'hip_center': Requires hip_center and skeleton_size parameters
    2. 'image_center': Only requires img_width and img_height

    Args:
        data: Normalized skeleton data (C, T, V, M)
        method: Denormalization method - 'hip_center' or 'image_center'
        hip_center: Original hip center position (2, T, 1, M) - required for hip_center method
        skeleton_size: Original skeleton size (scalar) - required for hip_center method
        img_width: Image width for image_center method (default: 3840)
        img_height: Image height for image_center method (default: 2160)

    Returns:
        denormalized_data: Skeleton data in original coordinate space

    Example:
        >>> # Hip center denormalization (requires original values)
        >>> data_original = np.random.rand(3, 60, 17, 1) * np.array([3840, 2160, 1]).reshape(3, 1, 1, 1)
        >>> normalized = normalize_skeleton(data_original, method='hip_center')
        >>> # Note: In practice, you need to store hip_center and skeleton_size
        >>> # during normalization for accurate denormalization

        >>> # Image center denormalization (simpler)
        >>> normalized = normalize_skeleton(data_original, method='image_center')
        >>> denormalized = denormalize_skeleton(normalized, method='image_center')
        >>> np.allclose(data_original[:2], denormalized[:2], atol=1.0)  # Within 1 pixel
        True
    """
    C, T, V, M = data.shape
    assert C == 3, f"Expected 3 channels (x, y, conf), got {C}"

    if method == 'hip_center':
        if hip_center is None or skeleton_size is None:
            raise ValueError(
                "hip_center denormalization requires 'hip_center' and 'skeleton_size' parameters. "
                "These values must be saved during normalization. "
                "For visualization without original values, consider using method='image_center'."
            )
        return _denormalize_by_hip_center(data, hip_center, skeleton_size)

    elif method == 'image_center':
        return _denormalize_by_image_center(data, img_width, img_height)

    else:
        raise ValueError(
            f"Unknown denormalization method: '{method}'. "
            f"Use 'hip_center' or 'image_center'."
        )


def _denormalize_by_hip_center(
    data: np.ndarray,
    hip_center: np.ndarray,
    skeleton_size: float
) -> np.ndarray:
    """
    Reverse hip center normalization.

    Args:
        data: Normalized skeleton data (C, T, V, M)
        hip_center: Original hip center (2, T, 1, M) or (2,) for constant
        skeleton_size: Original skeleton size (scalar)

    Returns:
        denormalized_data: (C, T, V, M) in original pixel coordinates
    """
    C, T, V, M = data.shape
    denormalized_data = data.copy()

    # Reverse scale and centering: x_original = x_normalized * skeleton_size + hip_center
    denormalized_data[:2] = denormalized_data[:2] * skeleton_size + hip_center

    # Confidence channel unchanged

    return denormalized_data


def _denormalize_by_image_center(
    data: np.ndarray,
    img_width: int = 3840,
    img_height: int = 2160
) -> np.ndarray:
    """
    Reverse image center normalization.

    Args:
        data: Normalized skeleton data (C, T, V, M) with x,y in [-0.5, 0.5]
        img_width: Image width in pixels (default: 3840)
        img_height: Image height in pixels (default: 2160)

    Returns:
        denormalized_data: (C, T, V, M) with x,y in pixel coordinates
    """
    C, T, V, M = data.shape
    denormalized_data = data.copy()

    # Denormalize x-coordinate: [-0.5, 0.5] -> [0, img_width]
    denormalized_data[0] = (denormalized_data[0] + 0.5) * img_width

    # Denormalize y-coordinate: [-0.5, 0.5] -> [0, img_height]
    denormalized_data[1] = (denormalized_data[1] + 0.5) * img_height

    # Confidence channel remains unchanged

    return denormalized_data

+
def test_augmentation():
|
| 598 |
+
"""
|
| 599 |
+
Test augmentation functions and demonstrate their effects.
|
| 600 |
+
|
| 601 |
+
This function creates synthetic skeleton data and applies each augmentation
|
| 602 |
+
to verify correctness and visualize the transformations.
|
| 603 |
+
"""
|
| 604 |
+
print("Skeleton Data Augmentation Test")
|
| 605 |
+
print("=" * 80)
|
| 606 |
+
|
| 607 |
+
# Create synthetic skeleton data (C, T, V, M)
|
| 608 |
+
C, T, V, M = 3, 60, 17, 1
|
| 609 |
+
np.random.seed(42)
|
| 610 |
+
|
| 611 |
+
# Generate synthetic data in pixel coordinates
|
| 612 |
+
data = np.random.rand(C, T, V, M)
|
| 613 |
+
data[0] *= 1920 # x in [0, 1920]
|
| 614 |
+
data[1] *= 1080 # y in [0, 1080]
|
| 615 |
+
data[2] = np.random.uniform(0.5, 1.0, (T, V, M)) # confidence in [0.5, 1.0]
|
| 616 |
+
|
| 617 |
+
print(f"\nOriginal data shape: {data.shape}")
|
| 618 |
+
print(f"Original x range: [{data[0].min():.2f}, {data[0].max():.2f}] pixels")
|
| 619 |
+
print(f"Original y range: [{data[1].min():.2f}, {data[1].max():.2f}] pixels")
|
| 620 |
+
print(f"Original confidence range: [{data[2].min():.3f}, {data[2].max():.3f}]")
|
| 621 |
+
|
| 622 |
+
# Test 1: Normalization
|
| 623 |
+
print("\n" + "-" * 80)
|
| 624 |
+
print("Test 1: Normalization")
|
| 625 |
+
print("-" * 80)
|
| 626 |
+
normalized = normalize_skeleton(data, img_width=1920, img_height=1080)
|
| 627 |
+
print(f"Normalized x range: [{normalized[0].min():.3f}, {normalized[0].max():.3f}]")
|
| 628 |
+
print(f"Normalized y range: [{normalized[1].min():.3f}, {normalized[1].max():.3f}]")
|
| 629 |
+
print(f"Normalized confidence range: [{normalized[2].min():.3f}, {normalized[2].max():.3f}]")
|
| 630 |
+
|
| 631 |
+
# Verify denormalization
|
| 632 |
+
denormalized = denormalize_skeleton(normalized, img_width=1920, img_height=1080)
|
| 633 |
+
reconstruction_error = np.abs(data - denormalized).max()
|
| 634 |
+
print(f"Denormalization reconstruction error: {reconstruction_error:.6f} pixels")
|
| 635 |
+
|
| 636 |
+
# Test 2: Horizontal Flip
|
| 637 |
+
print("\n" + "-" * 80)
|
| 638 |
+
print("Test 2: Horizontal Flip")
|
| 639 |
+
print("-" * 80)
|
| 640 |
+
np.random.seed(42)
|
| 641 |
+
flipped = augment_skeleton(normalized, prob=1.0) # Force all augmentations
|
| 642 |
+
print(f"Original x (frame 0, keypoint 0): {normalized[0, 0, 0, 0]:.3f}")
|
| 643 |
+
print(f"After augmentation x: {flipped[0, 0, 0, 0]:.3f}")
|
| 644 |
+
print(f"X-coordinate sign flipped: {np.sign(normalized[0].mean()) != np.sign(flipped[0].mean())}")
|
| 645 |
+
|
| 646 |
+
# Test 3: Check left/right keypoint swapping
|
| 647 |
+
print("\n" + "-" * 80)
|
| 648 |
+
print("Test 3: Keypoint Pair Swapping (Horizontal Flip)")
|
| 649 |
+
print("-" * 80)
|
| 650 |
+
# Create data with distinctive values for left/right pairs
|
| 651 |
+
test_data = np.zeros((3, 60, 17, 1))
|
| 652 |
+
test_data[0, :, 5, 0] = 100 # left_shoulder x = 100
|
| 653 |
+
test_data[0, :, 6, 0] = -100 # right_shoulder x = -100
|
| 654 |
+
flipped_test = _horizontal_flip(test_data)
|
| 655 |
+
print(f"Original left_shoulder (idx 5) x: {test_data[0, 0, 5, 0]:.1f}")
|
| 656 |
+
print(f"Original right_shoulder (idx 6) x: {test_data[0, 0, 6, 0]:.1f}")
|
| 657 |
+
print(f"Flipped left_shoulder (idx 5) x: {flipped_test[0, 0, 5, 0]:.1f}")
|
| 658 |
+
print(f"Flipped right_shoulder (idx 6) x: {flipped_test[0, 0, 6, 0]:.1f}")
|
| 659 |
+
print(f"Swap successful: {flipped_test[0, 0, 5, 0] == 100 and flipped_test[0, 0, 6, 0] == -100}")
|
| 660 |
+
|
| 661 |
+
# Test 4: Gaussian Noise
|
| 662 |
+
print("\n" + "-" * 80)
|
| 663 |
+
print("Test 4: Gaussian Noise")
|
| 664 |
+
print("-" * 80)
|
| 665 |
+
np.random.seed(42)
|
| 666 |
+
noisy = _add_gaussian_noise(normalized, std=0.01)
|
| 667 |
+
noise_magnitude = np.abs(noisy[:2] - normalized[:2]).max()
|
| 668 |
+
confidence_unchanged = np.allclose(noisy[2], normalized[2])
|
| 669 |
+
print(f"Max noise magnitude (x,y): {noise_magnitude:.4f}")
|
| 670 |
+
print(f"Confidence channel unchanged: {confidence_unchanged}")
|
| 671 |
+
|
| 672 |
+
# Test 5: Temporal Crop and Resize
|
| 673 |
+
print("\n" + "-" * 80)
|
| 674 |
+
print("Test 5: Temporal Crop and Resize")
|
| 675 |
+
print("-" * 80)
|
| 676 |
+
np.random.seed(42)
|
| 677 |
+
cropped = _temporal_crop_resize(normalized, crop_ratio_range=(0.8, 1.0))
|
| 678 |
+
print(f"Original temporal length: {normalized.shape[1]}")
|
| 679 |
+
print(f"Cropped temporal length: {cropped.shape[1]}")
|
| 680 |
+
print(f"Shape preserved: {cropped.shape == normalized.shape}")
|
| 681 |
+
|
| 682 |
+
# Test 6: Full Augmentation Pipeline
|
| 683 |
+
print("\n" + "-" * 80)
|
| 684 |
+
print("Test 6: Full Augmentation Pipeline")
|
| 685 |
+
print("-" * 80)
|
| 686 |
+
np.random.seed(42)
|
| 687 |
+
augmented = augment_skeleton(normalized, prob=0.5)
|
| 688 |
+
print(f"Augmented shape: {augmented.shape}")
|
| 689 |
+
print(f"Augmented x range: [{augmented[0].min():.3f}, {augmented[0].max():.3f}]")
|
| 690 |
+
print(f"Augmented y range: [{augmented[1].min():.3f}, {augmented[1].max():.3f}]")
|
| 691 |
+
print(f"Augmented confidence range: [{augmented[2].min():.3f}, {augmented[2].max():.3f}]")
|
| 692 |
+
|
| 693 |
+
# Test 7: Augmentation Statistics (Run 100 times)
|
| 694 |
+
print("\n" + "-" * 80)
|
| 695 |
+
print("Test 7: Augmentation Statistics (100 runs with prob=0.5)")
|
| 696 |
+
print("-" * 80)
|
| 697 |
+
np.random.seed(42)
|
| 698 |
+
augmentation_counts = {"flip": 0, "noise": 0, "crop": 0}
|
| 699 |
+
num_runs = 100
|
| 700 |
+
|
| 701 |
+
for _ in range(num_runs):
|
| 702 |
+
original_copy = normalized.copy()
|
| 703 |
+
augmented = augment_skeleton(original_copy, prob=0.5)
|
| 704 |
+
|
| 705 |
+
# Detect which augmentations were applied (heuristics)
|
| 706 |
+
x_sign_changed = np.sign(augmented[0].mean()) != np.sign(normalized[0].mean())
|
| 707 |
+
noise_added = not np.allclose(augmented[:2], normalized[:2], atol=1e-4)
|
| 708 |
+
# Crop detection is harder, skip for now
|
| 709 |
+
|
| 710 |
+
if x_sign_changed:
|
| 711 |
+
augmentation_counts["flip"] += 1
|
| 712 |
+
if noise_added and not x_sign_changed:
|
| 713 |
+
augmentation_counts["noise"] += 1
|
| 714 |
+
|
| 715 |
+
print(f"Horizontal flip applied: {augmentation_counts['flip']}/{num_runs} times")
|
| 716 |
+
print(f"Gaussian noise applied: {augmentation_counts['noise']}/{num_runs} times")
|
| 717 |
+
print(f"Expected frequency (prob=0.5): ~50 times per augmentation")
|
| 718 |
+
|
| 719 |
+
print("\n" + "=" * 80)
|
| 720 |
+
print("All tests completed successfully")
|
| 721 |
+
print("=" * 80)
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
if __name__ == "__main__":
|
| 725 |
+
test_augmentation()
|
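A minimal usage sketch for the module above, mirroring its docstring examples rather than any specific training script (the 1920x1080 frame size is an arbitrary assumption):

import numpy as np
from augmentation import augment_skeleton, normalize_skeleton

# Synthetic (C, T, V, M) = (3, 60, 17, 1) window in pixel coordinates
window = np.random.rand(3, 60, 17, 1)
window[:2] *= np.array([1920, 1080]).reshape(2, 1, 1, 1)  # x, y in pixels

# Normalize first (hip-center, the ST-GCN standard), then augment for training
normalized = normalize_skeleton(window, method='hip_center')
augmented = augment_skeleton(normalized, prob=0.5)
assert augmented.shape == (3, 60, 17, 1)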
models/__init__.py
ADDED
@@ -0,0 +1 @@
# Models package for HF Spaces deployment
models/pose_estimator.py
ADDED
@@ -0,0 +1,150 @@
"""
YOLOv11-Pose wrapper class

A YOLOv11-Pose model wrapper for real-time pose estimation.
"""

import logging
from typing import Optional

import numpy as np
import torch
from ultralytics import YOLO


class PoseEstimator:
    """YOLOv11-Pose-based pose estimator"""

    def __init__(
        self,
        model_path: str = "yolo11m-pose.pt",
        conf_threshold: float = 0.5,
        imgsz: int = 640,
        device: str = "cuda:0",
        logger: Optional[logging.Logger] = None
    ):
        """
        Args:
            model_path: Path to the YOLOv11-Pose model
            conf_threshold: Detection confidence threshold
            imgsz: Input image size
            device: Device (cuda:0, cpu, etc.)
            logger: Logger instance
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.conf_threshold = conf_threshold
        self.imgsz = imgsz
        self.logger = logger or logging.getLogger(__name__)

        # Load the model
        self.logger.info(f"[Stage 1] Loading YOLOv11-Pose: {model_path}")
        self.model = YOLO(model_path)
        self.model.to(self.device)
        self.logger.info(f"  - Confidence threshold: {conf_threshold}")
        self.logger.info(f"  - Image size: {imgsz}")
        self.logger.info(f"  - Device: {self.device}")

    def extract(self, frame: np.ndarray, debug: bool = False) -> Optional[np.ndarray]:
        """
        Extract pose keypoints from a frame.

        Args:
            frame: OpenCV image (H, W, 3)
            debug: Whether to emit debug logs

        Returns:
            keypoints: (17, 3) numpy array, or None if no person was detected.
                Each keypoint is (x, y, confidence).
        """
        results = self.model.predict(
            frame,
            imgsz=self.imgsz,
            conf=self.conf_threshold,
            verbose=False
        )

        if results and len(results) > 0 and results[0].keypoints is not None:
            keypoints_data = results[0].keypoints.data.cpu().numpy()

            if len(keypoints_data) > 0:
                # Select the person with the highest confidence
                if results[0].boxes is not None:
                    confidences = results[0].boxes.conf.cpu().numpy()
                    best_idx = np.argmax(confidences)
                    keypoints = keypoints_data[best_idx]  # (17, 3)
                else:
                    keypoints = keypoints_data[0]

                if debug:
                    avg_conf = keypoints[:, 2].mean()
                    self.logger.debug(f"  Pose detected: avg_conf={avg_conf:.3f}")

                return keypoints

        if debug:
            self.logger.debug("  No pose detected")

        return None

    def extract_batch(
        self, frames: list[np.ndarray] | np.ndarray, debug: bool = False
    ) -> list[Optional[np.ndarray]]:
        """
        Extract pose keypoints from multiple frames in one batch (maximizes GPU utilization).

        Args:
            frames: List of OpenCV images [(H, W, 3), ...] or a numpy array (N, H, W, C)
            debug: Whether to emit debug logs

        Returns:
            keypoints_list: [(17, 3) numpy array or None, ...], keypoints per frame
        """
        # Check for empty input (supports both lists and numpy arrays)
        if isinstance(frames, np.ndarray):
            if frames.size == 0:
                return []
            # Convert the numpy array to a list
            frames = list(frames)
        elif not frames:
            return []

        # YOLO batch inference
        results = self.model.predict(
            frames,
            imgsz=self.imgsz,
            conf=self.conf_threshold,
            verbose=False
        )

        keypoints_list = []
        for i, result in enumerate(results):
            if result.keypoints is not None:
                keypoints_data = result.keypoints.data.cpu().numpy()

                if len(keypoints_data) > 0:
                    # Select the person with the highest confidence
                    if result.boxes is not None:
                        confidences = result.boxes.conf.cpu().numpy()
                        best_idx = np.argmax(confidences)
                        keypoints = keypoints_data[best_idx]  # (17, 3)
                    else:
                        keypoints = keypoints_data[0]

                    if debug:
                        avg_conf = keypoints[:, 2].mean()
                        self.logger.debug(
                            f"  Batch[{i}] Pose detected: avg_conf={avg_conf:.3f}"
                        )

                    keypoints_list.append(keypoints)
                    continue

            if debug:
                self.logger.debug(f"  Batch[{i}] No pose detected")
            keypoints_list.append(None)

        return keypoints_list

    def get_empty_keypoints(self) -> np.ndarray:
        """Return an empty keypoints array (used when no person is detected)."""
        return np.zeros((17, 3), dtype=np.float32)
models/stgcn_classifier.py
ADDED
@@ -0,0 +1,183 @@
"""
ST-GCN fall classifier wrapper class

A fall classifier built on a Spatial-Temporal Graph Convolutional Network.

Note: Import paths were modified for HF Spaces deployment.
"""

import logging
from typing import Optional, Tuple

import numpy as np
import torch

# Relative imports for HF Spaces deployment
from augmentation import normalize_skeleton
from stgcn.model import STGCN


class STGCNClassifier:
    """ST-GCN-based fall classifier"""

    def __init__(
        self,
        checkpoint_path: str = "runs/stgcn_binary_exp2_fixed_graph/best_acc.pth",
        fall_threshold: float = 0.7,
        device: str = "cuda:0",
        in_channels: int = 3,
        num_classes: int = 2,
        dropout: float = 0.5,
        logger: Optional[logging.Logger] = None
    ):
        """
        Args:
            checkpoint_path: Path to the ST-GCN checkpoint
            fall_threshold: Confidence threshold for a fall verdict
            device: Device (cuda:0, cpu, etc.)
            in_channels: Number of input channels (x, y, conf)
            num_classes: Number of output classes (Fall, Non-Fall)
            dropout: Dropout rate
            logger: Logger instance
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.fall_threshold = fall_threshold
        self.logger = logger or logging.getLogger(__name__)

        self.logger.info(f"[Stage 2] Loading ST-GCN: {checkpoint_path}")

        # Initialize the model
        self.model = STGCN(
            in_channels=in_channels,
            num_classes=num_classes,
            graph_cfg={},
            edge_importance_weighting=True,
            dropout=dropout
        )

        # Load the checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Log checkpoint info
        epoch = checkpoint.get('epoch')
        if epoch is not None:
            self.logger.info(f"  - Checkpoint epoch: {epoch}")

        metrics = checkpoint.get('metrics')
        if isinstance(metrics, dict):
            acc = metrics.get('accuracy')
            f1 = metrics.get('f1')
            if isinstance(acc, (int, float)):
                self.logger.info(f"  - Accuracy: {acc:.4f}")
            if isinstance(f1, (int, float)):
                self.logger.info(f"  - F1 Score: {f1:.4f}")

        self.logger.info(f"  - Fall threshold: {fall_threshold}")
        self.logger.info(f"  - Device: {self.device}")

    def predict(
        self,
        window: np.ndarray,
        normalize: bool = True,
        debug: bool = False
    ) -> Tuple[int, float]:
        """
        Predict a fall with ST-GCN.

        Args:
            window: (C, T, V, M) ST-GCN input (C=3, T=60, V=17, M=1)
            normalize: Whether to apply hip-center normalization
            debug: Whether to emit debug logs

        Returns:
            prediction: 0 (Non-Fall) or 1 (Fall)
            confidence: Prediction confidence (0.0-1.0)
        """
        # Normalize skeleton (hip center + skeleton size scaling)
        if normalize:
            window_input = normalize_skeleton(window, method='hip_center')
        else:
            window_input = window

        # ST-GCN inference
        window_tensor = torch.from_numpy(window_input).unsqueeze(0).to(self.device)  # (1, C, T, V, M)

        with torch.no_grad():
            outputs = self.model(window_tensor)
            probs = torch.softmax(outputs, dim=1)
            pred = torch.argmax(outputs, dim=1)

        prediction = pred.item()
        confidence = probs[0, prediction].item()

        if debug:
            self.logger.debug(f"  ST-GCN prediction: {prediction} (conf={confidence:.3f})")

        return prediction, confidence

    def predict_batch(
        self,
        windows: list[np.ndarray],
        normalize: bool = True,
        debug: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Batch fall prediction with ST-GCN (maximizes GPU utilization).

        Args:
            windows: [(C, T, V, M), ...] list of ST-GCN input windows
            normalize: Whether to apply hip-center normalization
            debug: Whether to emit debug logs

        Returns:
            predictions: (N,) numpy array of 0 (Non-Fall) or 1 (Fall)
            confidences: (N,) numpy array of predicted class confidence (0.0-1.0)
            fall_probs: (N,) numpy array of Fall class probability (0.0-1.0)
        """
        if not windows:
            return np.array([]), np.array([]), np.array([])

        # Normalize and prepare the batch
        batch_list = []
        for window in windows:
            if normalize:
                window_input = normalize_skeleton(window, method='hip_center')
            else:
                window_input = window
            batch_list.append(torch.from_numpy(window_input).float())

        # Build the batch tensor (N, C, T, V, M)
        batch_tensor = torch.stack(batch_list).to(self.device)

        with torch.no_grad():
            outputs = self.model(batch_tensor)
            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(outputs, dim=1)

        predictions = preds.cpu().numpy()
        # Use the probability of each predicted class as its confidence
        confidences = probs[torch.arange(len(preds)), preds].cpu().numpy()
        # Probability of the Fall class (class 1) - used for plotting
        fall_probs = probs[:, 1].cpu().numpy()

        if debug:
            for i, (pred, conf, fall_p) in enumerate(zip(predictions, confidences, fall_probs)):
                self.logger.debug(f"  Batch[{i}] ST-GCN: pred={pred}, conf={conf:.3f}, fall_prob={fall_p:.3f}")

        return predictions, confidences, fall_probs

    def is_fall(self, prediction: int, confidence: float) -> bool:
        """
        Decide whether a fall occurred.

        Args:
            prediction: Model prediction (0 or 1)
            confidence: Prediction confidence

        Returns:
            True if fall detected with sufficient confidence
        """
        return prediction == 1 and confidence >= self.fall_threshold
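A sketch of how the two stages chain together, assuming a 60-frame sliding window has already been collected (buffering, thresholds, and visualization omitted; the checkpoint path and placeholder frames are illustrative):

import numpy as np
from models.pose_estimator import PoseEstimator
from models.stgcn_classifier import STGCNClassifier

pose = PoseEstimator()
clf = STGCNClassifier(checkpoint_path="best_acc.pth")  # illustrative path

frames = [np.zeros((480, 640, 3), dtype=np.uint8)] * 60  # stand-in for real frames
keypoints_list = pose.extract_batch(frames)
keypoints_list = [kp if kp is not None else pose.get_empty_keypoints()
                  for kp in keypoints_list]

# Stack to (T, V, 3), then rearrange to ST-GCN's (C, T, V, M) layout
window = np.stack(keypoints_list).transpose(2, 0, 1)[..., None].astype(np.float32)

prediction, confidence = clf.predict(window)  # hip-center normalization by default
if clf.is_fall(prediction, confidence):
    print(f"Fall detected (conf={confidence:.2f})")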
stgcn/__init__.py
ADDED
@@ -0,0 +1 @@
# ST-GCN package for HF Spaces deployment
stgcn/graph.py
ADDED
@@ -0,0 +1,291 @@
| 1 |
+
"""
|
| 2 |
+
COCO Skeleton Graph Definition for ST-GCN
|
| 3 |
+
|
| 4 |
+
This module defines the skeleton graph structure for COCO 17-keypoint format
|
| 5 |
+
used by YOLOv11-Pose. The graph represents spatial relationships between joints
|
| 6 |
+
as an adjacency matrix for Spatial-Temporal Graph Convolutional Networks.
|
| 7 |
+
|
| 8 |
+
COCO 17 Keypoints:
|
| 9 |
+
0: nose, 1: left_eye, 2: right_eye, 3: left_ear, 4: right_ear
|
| 10 |
+
5: left_shoulder, 6: right_shoulder, 7: left_elbow, 8: right_elbow
|
| 11 |
+
9: left_wrist, 10: right_wrist, 11: left_hip, 12: right_hip
|
| 12 |
+
13: left_knee, 14: right_knee, 15: left_ankle, 16: right_ankle
|
| 13 |
+
|
| 14 |
+
References:
|
| 15 |
+
- ST-GCN Paper: https://arxiv.org/abs/1801.07455
|
| 16 |
+
- COCO Dataset: https://cocodataset.org/#keypoints-2020
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Graph:
|
| 23 |
+
"""COCO skeleton graph for ST-GCN."""
|
| 24 |
+
|
| 25 |
+
def __init__(self, labeling_mode='spatial'):
|
| 26 |
+
"""
|
| 27 |
+
Initialize COCO skeleton graph.
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
labeling_mode: Partitioning strategy for skeleton graph
|
| 31 |
+
- 'spatial': Partition based on spatial distance from center
|
| 32 |
+
- 'uniform': All edges treated equally (baseline)
|
| 33 |
+
"""
|
| 34 |
+
self.num_nodes = 17 # COCO keypoints
|
| 35 |
+
self.labeling_mode = labeling_mode
|
| 36 |
+
|
| 37 |
+
# Define skeleton connectivity (parent-child relationships)
|
| 38 |
+
self.edges = self._get_edges()
|
| 39 |
+
|
| 40 |
+
# Create adjacency matrix
|
| 41 |
+
self.A = self._create_adjacency_matrix()
|
| 42 |
+
|
| 43 |
+
# Get partitioning strategy
|
| 44 |
+
self.A_with_partitions = self._get_partitioned_adjacency()
|
| 45 |
+
|
| 46 |
+
def _get_edges(self):
|
| 47 |
+
"""
|
| 48 |
+
Define COCO skeleton edges (connections between keypoints).
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
List of tuples representing connected joints
|
| 52 |
+
"""
|
| 53 |
+
# COCO skeleton structure (17 keypoints)
|
| 54 |
+
edges = [
|
| 55 |
+
# Head connections
|
| 56 |
+
(0, 1), (0, 2), # nose to eyes
|
| 57 |
+
(1, 3), (2, 4), # eyes to ears
|
| 58 |
+
|
| 59 |
+
# Torso connections
|
| 60 |
+
(5, 6), # shoulders
|
| 61 |
+
(5, 11), (6, 12), # shoulders to hips
|
| 62 |
+
(11, 12), # hips
|
| 63 |
+
|
| 64 |
+
# Left arm
|
| 65 |
+
(5, 7), (7, 9), # shoulder -> elbow -> wrist
|
| 66 |
+
|
| 67 |
+
# Right arm
|
| 68 |
+
(6, 8), (8, 10), # shoulder -> elbow -> wrist
|
| 69 |
+
|
| 70 |
+
# Left leg
|
| 71 |
+
(11, 13), (13, 15), # hip -> knee -> ankle
|
| 72 |
+
|
| 73 |
+
# Right leg
|
| 74 |
+
(12, 14), (14, 16), # hip -> knee -> ankle
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
return edges
|
| 78 |
+
|
| 79 |
+
def _create_adjacency_matrix(self):
|
| 80 |
+
"""
|
| 81 |
+
Create adjacency matrix from skeleton edges.
|
| 82 |
+
|
| 83 |
+
Returns:
|
| 84 |
+
A: (V, V) adjacency matrix where V=17 (number of keypoints)
|
| 85 |
+
"""
|
| 86 |
+
A = np.zeros((self.num_nodes, self.num_nodes))
|
| 87 |
+
|
| 88 |
+
# Add edges (bidirectional connections)
|
| 89 |
+
for i, j in self.edges:
|
| 90 |
+
A[i, j] = 1
|
| 91 |
+
A[j, i] = 1
|
| 92 |
+
|
| 93 |
+
# Add self-connections
|
| 94 |
+
A += np.eye(self.num_nodes)
|
| 95 |
+
|
| 96 |
+
return A
|
| 97 |
+
|
| 98 |
+
    def _get_partitioned_adjacency(self):
        """
        Partition adjacency matrix based on labeling strategy.

        For spatial labeling, partitions are:
        - Partition 0: Self-connections (root group)
        - Partition 1: Joints closer to skeleton center (centripetal group)
        - Partition 2: Joints farther from skeleton center (centrifugal group)

        Returns:
            A_partitioned: (num_partitions, V, V) stacked adjacency matrices
        """
        if self.labeling_mode == 'uniform':
            # Uniform labeling: all edges treated equally
            return self.A[np.newaxis, :, :]

        elif self.labeling_mode == 'spatial':
            # Spatial labeling: partition based on distance from center.
            # The center is anchored at the shoulder joints (5, 6), both
            # seeded at graph distance 0.
            center_joints = [5, 6]  # Left and right shoulders

            # Initialize partition matrices
            A_partitions = []

            # Partition 0: Self-connections
            A_self = np.eye(self.num_nodes)
            A_partitions.append(A_self)

            # Partition 1: Centripetal (moving toward center)
            # Partition 2: Centrifugal (moving away from center)
            A_centripetal = np.zeros((self.num_nodes, self.num_nodes))
            A_centrifugal = np.zeros((self.num_nodes, self.num_nodes))

            # Compute distances from center for each joint
            distances = self._compute_center_distances(center_joints)

            # Classify edges based on distance change (both directions)
            for i, j in self.edges:
                if distances[j] < distances[i]:
                    # Moving toward center (j is closer than i)
                    A_centripetal[i, j] = 1
                    # Reverse direction: moving away from center
                    A_centrifugal[j, i] = 1
                elif distances[j] > distances[i]:
                    # Moving away from center (j is farther than i)
                    A_centrifugal[i, j] = 1
                    # Reverse direction: moving toward center
                    A_centripetal[j, i] = 1
                else:
                    # Same distance: treat as centripetal for both directions
                    A_centripetal[i, j] = 1
                    A_centripetal[j, i] = 1

            A_partitions.append(A_centripetal)
            A_partitions.append(A_centrifugal)

            # Stack partitions: (3, V, V)
            A_partitioned = np.stack(A_partitions, axis=0)

            return A_partitioned

        else:
            raise ValueError(f"Unknown labeling mode: {self.labeling_mode}")

    def _compute_center_distances(self, center_joints):
        """
        Compute graph distance from center joints to all other joints.

        Uses BFS to compute shortest path distance in graph.

        Args:
            center_joints: List of joint indices considered as center

        Returns:
            distances: (V,) array of distances from center
        """
        from collections import deque

        distances = np.full(self.num_nodes, np.inf)
        queue = deque()

        # Initialize center joints with distance 0
        for joint in center_joints:
            distances[joint] = 0
            queue.append(joint)

        # BFS to compute distances
        while queue:
            current = queue.popleft()
            current_dist = distances[current]

            # Check all neighbors
            for neighbor in range(self.num_nodes):
                if self.A[current, neighbor] > 0 and neighbor != current:
                    if distances[neighbor] > current_dist + 1:
                        distances[neighbor] = current_dist + 1
                        queue.append(neighbor)

        return distances

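    # Worked example of the BFS above (illustrative note, not part of graph.py):
    # with center_joints=[5, 6], the shoulders get distance 0; elbows (7, 8)
    # and hips (11, 12) get 1; wrists (9, 10) and knees (13, 14) get 2; ankles
    # (15, 16) get 3. The head joints (0-4) have no edge to the torso in
    # _get_edges(), so they stay at inf and their edges land in the
    # equal-distance branch of _get_partitioned_adjacency().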
    def get_adjacency_matrix(self, normalize=True):
        """
        Get normalized adjacency matrix for ST-GCN.

        Args:
            normalize: Whether to apply symmetric normalization (D^-0.5 * A * D^-0.5)

        Returns:
            A_normalized: Normalized adjacency matrix
        """
        if self.labeling_mode == 'spatial':
            # Return partitioned adjacency matrices
            A = self.A_with_partitions

            if normalize:
                # Normalize each partition separately
                A_normalized = []
                for partition in A:
                    A_norm = self._normalize_adjacency(partition)
                    A_normalized.append(A_norm)
                return np.stack(A_normalized, axis=0)
            else:
                return A

        else:
            # Return single adjacency matrix
            A = self.A[np.newaxis, :, :]

            if normalize:
                A_norm = self._normalize_adjacency(A[0])
                return A_norm[np.newaxis, :, :]
            else:
                return A

    def _normalize_adjacency(self, A):
        """
        Apply symmetric normalization: D^-0.5 * A * D^-0.5

        Args:
            A: (V, V) adjacency matrix

        Returns:
            A_normalized: (V, V) normalized adjacency matrix
        """
        # Compute degree matrix
        D = np.sum(A, axis=1)

        # Avoid division by zero
        D[D == 0] = 1

        # Compute D^-0.5
        D_inv_sqrt = np.power(D, -0.5)

        # Apply normalization: D^-0.5 * A * D^-0.5
        A_normalized = A * D_inv_sqrt[:, np.newaxis] * D_inv_sqrt[np.newaxis, :]

        return A_normalized


def get_coco_graph(labeling_mode='spatial'):
    """
    Convenience function to get COCO skeleton graph.

    Args:
        labeling_mode: Partitioning strategy ('spatial' or 'uniform')

    Returns:
        Graph object with COCO skeleton structure
    """
    return Graph(labeling_mode=labeling_mode)


if __name__ == '__main__':
    # Test graph construction
    print("Testing COCO Skeleton Graph...")

    # Test uniform labeling
    graph_uniform = Graph(labeling_mode='uniform')
    print(f"\nUniform labeling:")
    print(f"  Adjacency shape: {graph_uniform.A.shape}")
    print(f"  Partitions shape: {graph_uniform.A_with_partitions.shape}")
    print(f"  Number of edges: {len(graph_uniform.edges)}")

    # Test spatial labeling
    graph_spatial = Graph(labeling_mode='spatial')
    print(f"\nSpatial labeling:")
    print(f"  Adjacency shape: {graph_spatial.A.shape}")
    print(f"  Partitions shape: {graph_spatial.A_with_partitions.shape}")

    # Get normalized adjacency
    A_norm = graph_spatial.get_adjacency_matrix(normalize=True)
    print(f"\nNormalized adjacency shape: {A_norm.shape}")

    print("\nCOCO skeleton graph construction successful!")
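Since the spatial partitioning and normalization are the most subtle parts of this file, here is a short runnable sketch of what the public API returns (an assumption: demo_gradio's root is on sys.path so stgcn.graph resolves):

import numpy as np
from stgcn.graph import Graph

g = Graph(labeling_mode='spatial')
A = g.get_adjacency_matrix(normalize=True)
print(A.shape)  # (3, 17, 17): self / centripetal / centrifugal partitions

# D^-0.5 * A * D^-0.5 is applied per partition; zero-degree rows are clamped
# to 1 before the power, so every entry stays finite.
assert np.isfinite(A).all()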
stgcn/model.py
ADDED
@@ -0,0 +1,391 @@
"""
|
| 2 |
+
ST-GCN Model for Fall Detection
|
| 3 |
+
|
| 4 |
+
Spatial-Temporal Graph Convolutional Networks for skeleton-based action recognition.
|
| 5 |
+
Adapted for binary fall detection (Fall vs Non-Fall) and multi-class fall type classification.
|
| 6 |
+
|
| 7 |
+
References:
|
| 8 |
+
- ST-GCN Paper: https://arxiv.org/abs/1801.07455
|
| 9 |
+
- Official Implementation: https://github.com/yysijie/st-gcn
|
| 10 |
+
- Fall Detection: Keskes & Noumeir (2021)
|
| 11 |
+
|
| 12 |
+
Input Shape: (N, C, T, V, M)
|
| 13 |
+
- N: Batch size
|
| 14 |
+
- C: Number of channels (3: x, y, confidence)
|
| 15 |
+
- T: Temporal dimension (number of frames)
|
| 16 |
+
- V: Number of vertices (17 COCO keypoints)
|
| 17 |
+
- M: Number of persons (1 for single-person scenarios)
|
| 18 |
+
|
| 19 |
+
Output: Class logits for Fall/Non-Fall (binary) or BY/FY/SY/N (multi-class)
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
import torch.nn.functional as F
|
| 25 |
+
|
| 26 |
+
from .graph import Graph
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class STGCNLayer(nn.Module):
|
| 30 |
+
"""
|
| 31 |
+
Spatial-Temporal Graph Convolutional Layer.
|
| 32 |
+
|
| 33 |
+
Combines spatial graph convolution and temporal convolution.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
in_channels,
|
| 39 |
+
out_channels,
|
| 40 |
+
kernel_size,
|
| 41 |
+
stride=1,
|
| 42 |
+
dropout=0.5,
|
| 43 |
+
residual=True
|
| 44 |
+
):
|
| 45 |
+
"""
|
| 46 |
+
Initialize ST-GCN layer.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
in_channels: Number of input channels
|
| 50 |
+
out_channels: Number of output channels
|
| 51 |
+
kernel_size: Tuple (temporal_kernel_size, spatial_kernel_size)
|
| 52 |
+
stride: Temporal stride for downsampling
|
| 53 |
+
dropout: Dropout probability
|
| 54 |
+
residual: Whether to use residual connection
|
| 55 |
+
"""
|
| 56 |
+
super(STGCNLayer, self).__init__()
|
| 57 |
+
|
| 58 |
+
assert len(kernel_size) == 2, "Kernel size must be (temporal, spatial)"
|
| 59 |
+
assert kernel_size[0] % 2 == 1, "Temporal kernel size must be odd"
|
| 60 |
+
|
| 61 |
+
padding = ((kernel_size[0] - 1) // 2, 0) # Temporal padding only
|
| 62 |
+
|
| 63 |
+
# Spatial graph convolution
|
| 64 |
+
self.gcn = SpatialGraphConv(
|
| 65 |
+
in_channels,
|
| 66 |
+
out_channels,
|
| 67 |
+
kernel_size[1]
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Temporal convolution
|
| 71 |
+
self.tcn = nn.Sequential(
|
| 72 |
+
nn.BatchNorm2d(out_channels),
|
| 73 |
+
nn.ReLU(inplace=True),
|
| 74 |
+
nn.Conv2d(
|
| 75 |
+
out_channels,
|
| 76 |
+
out_channels,
|
| 77 |
+
(kernel_size[0], 1),
|
| 78 |
+
(stride, 1),
|
| 79 |
+
padding,
|
| 80 |
+
),
|
| 81 |
+
nn.BatchNorm2d(out_channels),
|
| 82 |
+
nn.Dropout(dropout, inplace=True),
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
# Residual connection
|
| 86 |
+
if not residual:
|
| 87 |
+
self.residual = lambda x: 0
|
| 88 |
+
elif (in_channels == out_channels) and (stride == 1):
|
| 89 |
+
self.residual = lambda x: x
|
| 90 |
+
else:
|
| 91 |
+
self.residual = nn.Sequential(
|
| 92 |
+
nn.Conv2d(
|
| 93 |
+
in_channels,
|
| 94 |
+
out_channels,
|
| 95 |
+
kernel_size=1,
|
| 96 |
+
stride=(stride, 1)
|
| 97 |
+
),
|
| 98 |
+
nn.BatchNorm2d(out_channels),
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
self.relu = nn.ReLU(inplace=True)
|
| 102 |
+
|
| 103 |
+
def forward(self, x, A):
|
| 104 |
+
"""
|
| 105 |
+
Forward pass.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
x: Input tensor (N, C, T, V)
|
| 109 |
+
A: Adjacency matrix (K, V, V) where K is number of partitions
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
Output tensor (N, C', T', V)
|
| 113 |
+
"""
|
| 114 |
+
res = self.residual(x)
|
| 115 |
+
x = self.gcn(x, A)
|
| 116 |
+
x = self.tcn(x) + res
|
| 117 |
+
|
| 118 |
+
return self.relu(x)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class SpatialGraphConv(nn.Module):
|
| 122 |
+
"""
|
| 123 |
+
Spatial graph convolutional layer.
|
| 124 |
+
|
| 125 |
+
Applies graph convolution on skeleton graph using adjacency matrix.
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
def __init__(self, in_channels, out_channels, kernel_size, bias=True):
|
| 129 |
+
"""
|
| 130 |
+
Initialize spatial graph convolution.
|
| 131 |
+
|
| 132 |
+
Args:
|
| 133 |
+
in_channels: Number of input channels
|
| 134 |
+
out_channels: Number of output channels
|
| 135 |
+
kernel_size: Number of adjacency matrix partitions (1 or 3)
|
| 136 |
+
bias: Whether to include bias term
|
| 137 |
+
"""
|
| 138 |
+
super(SpatialGraphConv, self).__init__()
|
| 139 |
+
|
| 140 |
+
self.kernel_size = kernel_size
|
| 141 |
+
|
| 142 |
+
# Convolutional weights for each partition
|
| 143 |
+
self.conv = nn.Conv2d(
|
| 144 |
+
in_channels,
|
| 145 |
+
out_channels * kernel_size,
|
| 146 |
+
kernel_size=1,
|
| 147 |
+
bias=bias
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
def forward(self, x, A):
|
| 151 |
+
"""
|
| 152 |
+
Forward pass.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
x: Input tensor (N, C, T, V)
|
| 156 |
+
A: Adjacency matrix (K, V, V)
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
Output tensor (N, C', T, V)
|
| 160 |
+
"""
|
| 161 |
+
assert A.size(0) == self.kernel_size, \
|
| 162 |
+
f"Adjacency matrix size {A.size(0)} != kernel size {self.kernel_size}"
|
| 163 |
+
|
| 164 |
+
# Apply convolution
|
| 165 |
+
x = self.conv(x) # (N, C'*K, T, V)
|
| 166 |
+
|
| 167 |
+
# Split channels for each partition
|
| 168 |
+
n, kc, t, v = x.size()
|
| 169 |
+
x = x.view(n, self.kernel_size, kc // self.kernel_size, t, v) # (N, K, C', T, V)
|
| 170 |
+
|
| 171 |
+
# Apply graph convolution with each partition
|
| 172 |
+
# A: (K, V, V)
|
| 173 |
+
# x: (N, K, C', T, V)
|
| 174 |
+
x = torch.einsum('nkctv,kvw->nctw', x, A) # (N, C', T, V)
|
| 175 |
+
|
| 176 |
+
return x.contiguous()
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
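# Shape walk-through for the einsum above (illustrative note, not part of
# model.py):
#   x: (N, K, C', T, V) -- features split per partition
#   A: (K, V, V)        -- one normalized adjacency per partition
#   'nkctv,kvw->nctw' contracts both K and V, i.e.
#       out[n, c, t, w] = sum_k sum_v x[n, k, c, t, v] * A[k, v, w]
# so aggregation over joints and over the three partitions happens in one call.
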
class STGCN(nn.Module):
    """
    ST-GCN model for fall detection.

    Architecture:
        - Input: (N, 3, 60, 17, 1) - batch, channels, frames, joints, persons
        - ST-GCN layers: Extract spatial-temporal features
        - Global pooling: Aggregate features across time and space
        - FC layers: Classification (binary or multi-class)
    """

    def __init__(
        self,
        num_classes=2,
        in_channels=3,
        edge_importance_weighting=True,
        graph_cfg=None,
        dropout=0.5,
        **kwargs
    ):
        """
        Initialize ST-GCN model.

        Args:
            num_classes: Number of output classes (2 for binary, 4 for multi-class)
            in_channels: Number of input channels (3: x, y, confidence)
            edge_importance_weighting: Whether to learn edge importance weights
            graph_cfg: Graph configuration (default: spatial labeling)
            dropout: Dropout probability
        """
        super(STGCN, self).__init__()

        # Load graph
        if graph_cfg is None:
            graph_cfg = {'labeling_mode': 'spatial'}

        self.graph = Graph(**graph_cfg)

        # Get adjacency matrix (K, V, V) where K=3 for spatial labeling
        A = torch.tensor(
            self.graph.get_adjacency_matrix(normalize=True),
            dtype=torch.float32,
            requires_grad=False
        )
        self.register_buffer('A', A)

        # Number of adjacency matrix partitions
        spatial_kernel_size = A.size(0)  # 3 for spatial labeling

        # Temporal kernel size (odd numbers for symmetric padding)
        temporal_kernel_size = 9

        # Build ST-GCN layers
        kernel_size = (temporal_kernel_size, spatial_kernel_size)

        # Layer configurations: (in_channels, out_channels, stride)
        self.st_gcn_networks = nn.ModuleList((
            STGCNLayer(in_channels, 64, kernel_size, 1, dropout, residual=False),
            STGCNLayer(64, 64, kernel_size, 1, dropout),
            STGCNLayer(64, 64, kernel_size, 1, dropout),
            STGCNLayer(64, 64, kernel_size, 1, dropout),
            STGCNLayer(64, 128, kernel_size, 2, dropout),
            STGCNLayer(128, 128, kernel_size, 1, dropout),
            STGCNLayer(128, 128, kernel_size, 1, dropout),
            STGCNLayer(128, 256, kernel_size, 2, dropout),
            STGCNLayer(256, 256, kernel_size, 1, dropout),
            STGCNLayer(256, 256, kernel_size, 1, dropout),
        ))

        # Edge importance weighting
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for _ in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)

        # Fully connected layer for classification
        self.fcn = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, x):
        """
        Forward pass.

        Args:
            x: Input tensor (N, C, T, V, M)
                - N: Batch size
                - C: Number of channels (3)
                - T: Number of frames (60)
                - V: Number of joints (17)
                - M: Number of persons (1)

        Returns:
            Output logits (N, num_classes)
        """
        # Reshape input: (N, C, T, V, M) -> (N*M, C, T, V)
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 1, 2, 3).contiguous()  # (N, M, C, T, V)
        x = x.view(N * M, C, T, V)  # (N*M, C, T, V)

        # Forward through ST-GCN layers
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x = gcn(x, self.A * importance)

        # Global pooling: (N*M, C, T, V) -> (N*M, C)
        x = F.avg_pool2d(x, x.size()[2:])  # (N*M, C, 1, 1)
        x = x.view(N, M, -1, 1, 1).mean(dim=1)  # Average across persons: (N, C, 1, 1)

        # Classification
        x = self.fcn(x)  # (N, num_classes, 1, 1)
        x = x.view(x.size(0), -1)  # (N, num_classes)

        return x

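    # Shape trace through forward() for the default input (illustrative note,
    # not part of model.py):
    #   (N, 3, 60, 17, 1) -> reshape -> (N, 3, 60, 17)
    #   four 64-channel layers (stride 1) -> (N, 64, 60, 17)
    #   64->128 layer with stride 2       -> (N, 128, 30, 17)
    #   128->256 layer with stride 2      -> (N, 256, 15, 17)
    #   avg_pool2d + fcn                  -> (N, num_classes)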
    def extract_features(self, x):
        """
        Extract features before classification layer.

        Args:
            x: Input tensor (N, C, T, V, M)

        Returns:
            Feature tensor (N, 256)
        """
        # Reshape input
        N, C, T, V, M = x.size()
        x = x.permute(0, 4, 1, 2, 3).contiguous()
        x = x.view(N * M, C, T, V)

        # Forward through ST-GCN layers
        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x = gcn(x, self.A * importance)

        # Global pooling
        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(N, M, -1).mean(dim=1)  # (N, 256)

        return x


def stgcn_binary(pretrained=False, **kwargs):
    """
    ST-GCN for binary fall detection (Fall vs Non-Fall).

    Args:
        pretrained: Whether to load pretrained weights (not implemented)
        **kwargs: Additional model arguments

    Returns:
        ST-GCN model
    """
    model = STGCN(num_classes=2, **kwargs)

    if pretrained:
        raise NotImplementedError("Pretrained weights not available")

    return model


def stgcn_multiclass(pretrained=False, **kwargs):
    """
    ST-GCN for multi-class fall detection (BY/FY/SY/N).

    Args:
        pretrained: Whether to load pretrained weights (not implemented)
        **kwargs: Additional model arguments

    Returns:
        ST-GCN model
    """
    model = STGCN(num_classes=4, **kwargs)

    if pretrained:
        raise NotImplementedError("Pretrained weights not available")

    return model


if __name__ == '__main__':
    # Test model construction
    print("Testing ST-GCN Model...")

    # Binary classification
    model_binary = stgcn_binary()
    print(f"\nBinary ST-GCN:")
    print(f"  Parameters: {sum(p.numel() for p in model_binary.parameters()):,}")
    print(f"  Trainable: {sum(p.numel() for p in model_binary.parameters() if p.requires_grad):,}")

    # Multi-class classification
    model_multiclass = stgcn_multiclass()
    print(f"\nMulti-class ST-GCN:")
    print(f"  Parameters: {sum(p.numel() for p in model_multiclass.parameters()):,}")

    # Test forward pass
    batch_size = 4
    input_tensor = torch.randn(batch_size, 3, 60, 17, 1)
    print(f"\nInput shape: {input_tensor.shape}")

    # Binary output
    output_binary = model_binary(input_tensor)
    print(f"Binary output shape: {output_binary.shape}")
    print(f"Binary output: {output_binary}")

    # Multi-class output
    output_multiclass = model_multiclass(input_tensor)
    print(f"Multi-class output shape: {output_multiclass.shape}")

    # Feature extraction
    features = model_binary.extract_features(input_tensor)
    print(f"Feature shape: {features.shape}")

    print("\nST-GCN model construction successful!")
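A runnable sanity check of that shape trace (an assumption: demo_gradio's root is on sys.path so stgcn.model resolves):

import torch
from stgcn.model import stgcn_binary

model = stgcn_binary()
model.eval()
with torch.no_grad():
    x = torch.randn(2, 3, 60, 17, 1)    # (N, C, T, V, M)
    logits = model(x)                   # (2, 2)
    feats = model.extract_features(x)   # (2, 256)
print(logits.shape, feats.shape)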
visualization.py
ADDED
@@ -0,0 +1,973 @@
"""
|
| 2 |
+
Real-time Fall Detection Visualization Module
|
| 3 |
+
|
| 4 |
+
이 모듈은 실시간 낙상 감지 파이프라인의 시각화 기능을 제공합니다.
|
| 5 |
+
COCO 17 keypoints 스켈레톤 렌더링, 예측 결과 오버레이, 성능 메트릭 표시 등을 포함합니다.
|
| 6 |
+
|
| 7 |
+
주요 기능:
|
| 8 |
+
- COCO 17 keypoints 스켈레톤 렌더링
|
| 9 |
+
- Bounding box 렌더링
|
| 10 |
+
- Fall/Non-Fall 라벨 + 신뢰도 표시
|
| 11 |
+
- FPS/Latency 실시간 표시
|
| 12 |
+
- 색상 코딩 (Fall: 빨강, Non-Fall: 초록)
|
| 13 |
+
|
| 14 |
+
최적화 (Issue #56):
|
| 15 |
+
- NumPy 벡터화로 cv2.circle()/cv2.line() 루프 대체
|
| 16 |
+
- morphological dilation으로 keypoint 원 그리기 (30배 속도 향상)
|
| 17 |
+
- cv2.polylines()로 skeleton 선 일괄 그리기
|
| 18 |
+
- 주요 keypoint만 표시 옵션 (--viz-keypoints major)
|
| 19 |
+
- 출력 해상도 조절 옵션 (--viz-scale 0.5)
|
| 20 |
+
|
| 21 |
+
Reference:
|
| 22 |
+
- COCO Keypoints: https://cocodataset.org/#keypoints-2017
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import cv2
|
| 26 |
+
import numpy as np
|
| 27 |
+
from typing import Tuple, Optional, List, Literal
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# COCO 17 keypoints 인덱스
|
| 31 |
+
COCO_KEYPOINT_NAMES = [
|
| 32 |
+
'nose', # 0
|
| 33 |
+
'left_eye', # 1
|
| 34 |
+
'right_eye', # 2
|
| 35 |
+
'left_ear', # 3
|
| 36 |
+
'right_ear', # 4
|
| 37 |
+
'left_shoulder', # 5
|
| 38 |
+
'right_shoulder', # 6
|
| 39 |
+
'left_elbow', # 7
|
| 40 |
+
'right_elbow', # 8
|
| 41 |
+
'left_wrist', # 9
|
| 42 |
+
'right_wrist', # 10
|
| 43 |
+
'left_hip', # 11
|
| 44 |
+
'right_hip', # 12
|
| 45 |
+
'left_knee', # 13
|
| 46 |
+
'right_knee', # 14
|
| 47 |
+
'left_ankle', # 15
|
| 48 |
+
'right_ankle', # 16
|
| 49 |
+
]
|
| 50 |
+
|
| 51 |
+
# COCO 스켈레톤 연결 정의 (뼈대 구조)
|
| 52 |
+
COCO_SKELETON = [
|
| 53 |
+
# 얼굴
|
| 54 |
+
(0, 1), # nose -> left_eye
|
| 55 |
+
(0, 2), # nose -> right_eye
|
| 56 |
+
(1, 3), # left_eye -> left_ear
|
| 57 |
+
(2, 4), # right_eye -> right_ear
|
| 58 |
+
|
| 59 |
+
# 상체
|
| 60 |
+
(0, 5), # nose -> left_shoulder
|
| 61 |
+
(0, 6), # nose -> right_shoulder
|
| 62 |
+
(5, 6), # left_shoulder <-> right_shoulder
|
| 63 |
+
|
| 64 |
+
# 왼팔
|
| 65 |
+
(5, 7), # left_shoulder -> left_elbow
|
| 66 |
+
(7, 9), # left_elbow -> left_wrist
|
| 67 |
+
|
| 68 |
+
# 오른팔
|
| 69 |
+
(6, 8), # right_shoulder -> right_elbow
|
| 70 |
+
(8, 10), # right_elbow -> right_wrist
|
| 71 |
+
|
| 72 |
+
# 몸통
|
| 73 |
+
(5, 11), # left_shoulder -> left_hip
|
| 74 |
+
(6, 12), # right_shoulder -> right_hip
|
| 75 |
+
(11, 12), # left_hip <-> right_hip
|
| 76 |
+
|
| 77 |
+
# 왼다리
|
| 78 |
+
(11, 13), # left_hip -> left_knee
|
| 79 |
+
(13, 15), # left_knee -> left_ankle
|
| 80 |
+
|
| 81 |
+
# 오른다리
|
| 82 |
+
(12, 14), # right_hip -> right_knee
|
| 83 |
+
(14, 16), # right_knee -> right_ankle
|
| 84 |
+
]
|
| 85 |
+
|
| 86 |
+
# 신체 부위별 색상 정의 (BGR 포맷)
|
| 87 |
+
BODY_PART_COLORS = {
|
| 88 |
+
'face': (0, 255, 255), # 노란색
|
| 89 |
+
'left_arm': (255, 0, 180), # 분홍색
|
| 90 |
+
'right_arm': (0, 165, 255), # 오렌지색
|
| 91 |
+
'torso': (255, 150, 0), # 파란색
|
| 92 |
+
'left_leg': (0, 0, 255), # 빨간색
|
| 93 |
+
'right_leg': (180, 0, 255), # 보라색
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
# 각 스켈레톤 연결에 대한 신체 부위 매핑
|
| 97 |
+
SKELETON_PART_MAPPING = [
|
| 98 |
+
'face', # (0, 1) nose -> left_eye
|
| 99 |
+
'face', # (0, 2) nose -> right_eye
|
| 100 |
+
'face', # (1, 3) left_eye -> left_ear
|
| 101 |
+
'face', # (2, 4) right_eye -> right_ear
|
| 102 |
+
'face', # (0, 5) nose -> left_shoulder
|
| 103 |
+
'face', # (0, 6) nose -> right_shoulder
|
| 104 |
+
'torso', # (5, 6) left_shoulder <-> right_shoulder
|
| 105 |
+
'left_arm', # (5, 7) left_shoulder -> left_elbow
|
| 106 |
+
'left_arm', # (7, 9) left_elbow -> left_wrist
|
| 107 |
+
'right_arm', # (6, 8) right_shoulder -> right_elbow
|
| 108 |
+
'right_arm', # (8, 10) right_elbow -> right_wrist
|
| 109 |
+
'torso', # (5, 11) left_shoulder -> left_hip
|
| 110 |
+
'torso', # (6, 12) right_shoulder -> right_hip
|
| 111 |
+
'torso', # (11, 12) left_hip <-> right_hip
|
| 112 |
+
'left_leg', # (11, 13) left_hip -> left_knee
|
| 113 |
+
'left_leg', # (13, 15) left_knee -> left_ankle
|
| 114 |
+
'right_leg', # (12, 14) right_hip -> right_knee
|
| 115 |
+
'right_leg', # (14, 16) right_knee -> right_ankle
|
| 116 |
+
]
|
| 117 |
+
|
| 118 |
+
# 예측 결과 색상 정의
|
| 119 |
+
PREDICTION_COLORS = {
|
| 120 |
+
'Fall': (0, 0, 255), # 빨강
|
| 121 |
+
'Non-Fall': (0, 255, 0), # 초록
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
# 주요 keypoint 인덱스 (9개: 코, 어깨, 엉덩이, 무릎, 발목)
|
| 125 |
+
# 낙상 감지에 중요한 신체 부위만 선택
|
| 126 |
+
MAJOR_KEYPOINT_INDICES = [
|
| 127 |
+
0, # nose - 머리 위치
|
| 128 |
+
5, # left_shoulder
|
| 129 |
+
6, # right_shoulder
|
| 130 |
+
11, # left_hip
|
| 131 |
+
12, # right_hip
|
| 132 |
+
13, # left_knee
|
| 133 |
+
14, # right_knee
|
| 134 |
+
15, # left_ankle
|
| 135 |
+
16, # right_ankle
|
| 136 |
+
]
|
| 137 |
+
|
| 138 |
+
# 주요 keypoint용 skeleton 연결 (8개 연결)
|
| 139 |
+
MAJOR_SKELETON = [
|
| 140 |
+
(5, 6), # left_shoulder <-> right_shoulder
|
| 141 |
+
(5, 11), # left_shoulder -> left_hip
|
| 142 |
+
(6, 12), # right_shoulder -> right_hip
|
| 143 |
+
(11, 12), # left_hip <-> right_hip
|
| 144 |
+
(11, 13), # left_hip -> left_knee
|
| 145 |
+
(12, 14), # right_hip -> right_knee
|
| 146 |
+
(13, 15), # left_knee -> left_ankle
|
| 147 |
+
(14, 16), # right_knee -> right_ankle
|
| 148 |
+
]
|
| 149 |
+
|
| 150 |
+
# 주요 skeleton 신체 부위 매핑
|
| 151 |
+
MAJOR_SKELETON_PART_MAPPING = [
|
| 152 |
+
'torso', # (5, 6)
|
| 153 |
+
'torso', # (5, 11)
|
| 154 |
+
'torso', # (6, 12)
|
| 155 |
+
'torso', # (11, 12)
|
| 156 |
+
'left_leg', # (11, 13)
|
| 157 |
+
'right_leg', # (12, 14)
|
| 158 |
+
'left_leg', # (13, 15)
|
| 159 |
+
'right_leg', # (14, 16)
|
| 160 |
+
]
|
| 161 |
+
|
| 162 |
+
# Morphological dilation용 커널 캐시 (동일 크기 재사용)
|
| 163 |
+
_KERNEL_CACHE = {}
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def draw_skeleton(
|
| 167 |
+
frame: np.ndarray,
|
| 168 |
+
keypoints: np.ndarray,
|
| 169 |
+
color: Tuple[int, int, int] = (0, 255, 0),
|
| 170 |
+
thickness: int = 2,
|
| 171 |
+
conf_threshold: float = 0.5,
|
| 172 |
+
keypoint_radius: int = 4,
|
| 173 |
+
use_body_part_colors: bool = True
|
| 174 |
+
) -> np.ndarray:
|
| 175 |
+
"""
|
| 176 |
+
COCO 17 keypoints 스켈레톤 렌더링
|
| 177 |
+
|
| 178 |
+
Args:
|
| 179 |
+
frame: OpenCV 이미지 (H, W, 3) BGR 포맷
|
| 180 |
+
keypoints: (17, 3) numpy array - (x, y, conf)
|
| 181 |
+
color: BGR 색상 (use_body_part_colors=False일 때 사용)
|
| 182 |
+
thickness: 선 두께
|
| 183 |
+
conf_threshold: 최소 신뢰도 임계값 (이 값 이하는 그리지 않음)
|
| 184 |
+
keypoint_radius: 키포인트 원의 반지름
|
| 185 |
+
use_body_part_colors: True면 신체 부위별 색상 사용, False면 단일 색상 사용
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
frame: 스켈레톤이 렌더링된 이미지
|
| 189 |
+
"""
|
| 190 |
+
if keypoints.shape != (17, 3):
|
| 191 |
+
raise ValueError(f"Expected keypoints shape (17, 3), got {keypoints.shape}")
|
| 192 |
+
|
| 193 |
+
frame = frame.copy()
|
| 194 |
+
|
| 195 |
+
# 1. 스켈레톤 연결선 그리기
|
| 196 |
+
for i, (start_idx, end_idx) in enumerate(COCO_SKELETON):
|
| 197 |
+
x1, y1, conf1 = keypoints[start_idx]
|
| 198 |
+
x2, y2, conf2 = keypoints[end_idx]
|
| 199 |
+
|
| 200 |
+
# 두 키포인트 모두 신뢰도 임계값을 넘어야 선을 그림
|
| 201 |
+
if conf1 > conf_threshold and conf2 > conf_threshold:
|
| 202 |
+
# 신체 부위별 색상 또는 단일 색상 선택
|
| 203 |
+
if use_body_part_colors:
|
| 204 |
+
part_name = SKELETON_PART_MAPPING[i]
|
| 205 |
+
line_color = BODY_PART_COLORS[part_name]
|
| 206 |
+
else:
|
| 207 |
+
line_color = color
|
| 208 |
+
|
| 209 |
+
# 선 그리기
|
| 210 |
+
pt1 = (int(x1), int(y1))
|
| 211 |
+
pt2 = (int(x2), int(y2))
|
| 212 |
+
cv2.line(frame, pt1, pt2, line_color, thickness, cv2.LINE_AA)
|
| 213 |
+
|
| 214 |
+
# 2. 키포인트 원 그리기 (선 위에 그려서 더 눈에 띄게)
|
| 215 |
+
for i, (x, y, conf) in enumerate(keypoints):
|
| 216 |
+
if conf > conf_threshold:
|
| 217 |
+
center = (int(x), int(y))
|
| 218 |
+
|
| 219 |
+
# 외곽 흰색 테두리
|
| 220 |
+
cv2.circle(frame, center, keypoint_radius + 2, (255, 255, 255), -1, cv2.LINE_AA)
|
| 221 |
+
|
| 222 |
+
# 내부 색상 원 (밝은 하늘색)
|
| 223 |
+
cv2.circle(frame, center, keypoint_radius, (255, 200, 0), -1, cv2.LINE_AA)
|
| 224 |
+
|
| 225 |
+
return frame
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def _get_ellipse_kernel(radius: int) -> np.ndarray:
|
| 229 |
+
"""
|
| 230 |
+
캐시된 ellipse 커널 반환 (morphological dilation용)
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
radius: 커널 반지름
|
| 234 |
+
|
| 235 |
+
Returns:
|
| 236 |
+
ellipse 커널
|
| 237 |
+
"""
|
| 238 |
+
if radius not in _KERNEL_CACHE:
|
| 239 |
+
kernel_size = radius * 2 + 1
|
| 240 |
+
_KERNEL_CACHE[radius] = cv2.getStructuringElement(
|
| 241 |
+
cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)
|
| 242 |
+
)
|
| 243 |
+
return _KERNEL_CACHE[radius]
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def draw_skeleton_vectorized(
|
| 247 |
+
frame: np.ndarray,
|
| 248 |
+
keypoints: np.ndarray,
|
| 249 |
+
conf_threshold: float = 0.5,
|
| 250 |
+
keypoint_radius: int = 4,
|
| 251 |
+
thickness: int = 2,
|
| 252 |
+
keypoint_mode: Literal['all', 'major'] = 'all',
|
| 253 |
+
use_body_part_colors: bool = True,
|
| 254 |
+
keypoint_color: Tuple[int, int, int] = (255, 200, 0),
|
| 255 |
+
border_color: Tuple[int, int, int] = (255, 255, 255)
|
| 256 |
+
) -> np.ndarray:
|
| 257 |
+
"""
|
| 258 |
+
최적화된 skeleton 렌더링
|
| 259 |
+
|
| 260 |
+
최적화 전략:
|
| 261 |
+
- cv2.polylines()로 skeleton 선 일괄 처리 (색상별 그룹화)
|
| 262 |
+
- 주요 keypoint만 표시 옵션으로 그리기 횟수 감소 (17개 -> 9개)
|
| 263 |
+
- Anti-aliasing 비활성화 옵션 (cv2.LINE_AA -> cv2.LINE_8)
|
| 264 |
+
|
| 265 |
+
Note: morphological dilation은 4K 해상도에서 전체 이미지 마스크 생성으로
|
| 266 |
+
오히려 느려지므로, keypoint 원은 기존 cv2.circle() 유지
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
frame: OpenCV 이미지 (H, W, 3) BGR 포맷
|
| 270 |
+
keypoints: (17, 3) numpy array - (x, y, conf)
|
| 271 |
+
conf_threshold: 최소 신뢰도 임계값 (이 값 이하는 그리지 않음)
|
| 272 |
+
keypoint_radius: 키포인트 원의 반지름
|
| 273 |
+
thickness: skeleton 선 두께
|
| 274 |
+
keypoint_mode: 'all'=전체 17개, 'major'=주요 9개만 표시
|
| 275 |
+
use_body_part_colors: True면 신체 부위별 색상 사용
|
| 276 |
+
keypoint_color: keypoint 원 색상 (BGR)
|
| 277 |
+
border_color: keypoint 테두리 색상 (BGR)
|
| 278 |
+
|
| 279 |
+
Returns:
|
| 280 |
+
frame: 스켈레톤이 렌더링된 이미지
|
| 281 |
+
"""
|
| 282 |
+
if keypoints.shape != (17, 3):
|
| 283 |
+
raise ValueError(f"Expected keypoints shape (17, 3), got {keypoints.shape}")
|
| 284 |
+
|
| 285 |
+
result = frame.copy()
|
| 286 |
+
|
| 287 |
+
# keypoint 모드에 따른 인덱스/skeleton 선택
|
| 288 |
+
if keypoint_mode == 'major':
|
| 289 |
+
kpt_indices = MAJOR_KEYPOINT_INDICES
|
| 290 |
+
skeleton = MAJOR_SKELETON
|
| 291 |
+
skeleton_parts = MAJOR_SKELETON_PART_MAPPING
|
| 292 |
+
else:
|
| 293 |
+
kpt_indices = list(range(17))
|
| 294 |
+
skeleton = COCO_SKELETON
|
| 295 |
+
skeleton_parts = SKELETON_PART_MAPPING
|
| 296 |
+
|
| 297 |
+
# 유효한 keypoints 필터링 (confidence > threshold)
|
| 298 |
+
valid_mask = keypoints[:, 2] > conf_threshold
|
| 299 |
+
if keypoint_mode == 'major':
|
| 300 |
+
# 주요 keypoint 인덱스만 고려
|
| 301 |
+
major_mask = np.zeros(17, dtype=bool)
|
| 302 |
+
major_mask[kpt_indices] = True
|
| 303 |
+
valid_mask = valid_mask & major_mask
|
| 304 |
+
|
| 305 |
+
valid_indices = np.where(valid_mask)[0]
|
| 306 |
+
|
| 307 |
+
if len(valid_indices) == 0:
|
| 308 |
+
return result
|
| 309 |
+
|
| 310 |
+
# 1. Skeleton 선 그리기 (cv2.polylines 사용 - 배치 처리)
|
| 311 |
+
if use_body_part_colors:
|
| 312 |
+
# 색상별로 선 그룹화
|
| 313 |
+
color_groups = {}
|
| 314 |
+
for i, (start_idx, end_idx) in enumerate(skeleton):
|
| 315 |
+
if valid_mask[start_idx] and valid_mask[end_idx]:
|
| 316 |
+
part_name = skeleton_parts[i]
|
| 317 |
+
color = BODY_PART_COLORS[part_name]
|
| 318 |
+
if color not in color_groups:
|
| 319 |
+
color_groups[color] = []
|
| 320 |
+
pt1 = (int(keypoints[start_idx, 0]), int(keypoints[start_idx, 1]))
|
| 321 |
+
pt2 = (int(keypoints[end_idx, 0]), int(keypoints[end_idx, 1]))
|
| 322 |
+
color_groups[color].append(np.array([pt1, pt2], dtype=np.int32))
|
| 323 |
+
|
| 324 |
+
# 색상별로 일괄 그리기
|
| 325 |
+
for color, lines in color_groups.items():
|
| 326 |
+
if lines:
|
| 327 |
+
cv2.polylines(result, lines, isClosed=False, color=color,
|
| 328 |
+
thickness=thickness, lineType=cv2.LINE_AA)
|
| 329 |
+
else:
|
| 330 |
+
# 단일 색상으로 모든 선 그리기
|
| 331 |
+
lines = []
|
| 332 |
+
for start_idx, end_idx in skeleton:
|
| 333 |
+
if valid_mask[start_idx] and valid_mask[end_idx]:
|
| 334 |
+
pt1 = (int(keypoints[start_idx, 0]), int(keypoints[start_idx, 1]))
|
| 335 |
+
pt2 = (int(keypoints[end_idx, 0]), int(keypoints[end_idx, 1]))
|
| 336 |
+
lines.append(np.array([pt1, pt2], dtype=np.int32))
|
| 337 |
+
|
| 338 |
+
if lines:
|
| 339 |
+
cv2.polylines(result, lines, isClosed=False, color=(255, 255, 255),
|
| 340 |
+
thickness=thickness, lineType=cv2.LINE_AA)
|
| 341 |
+
|
| 342 |
+
# 2. Keypoint 원 그리기 (cv2.circle 사용 - 개수가 적어 루프가 효율적)
|
| 343 |
+
for idx in valid_indices:
|
| 344 |
+
x, y = int(keypoints[idx, 0]), int(keypoints[idx, 1])
|
| 345 |
+
center = (x, y)
|
| 346 |
+
|
| 347 |
+
# 외곽 테두리
|
| 348 |
+
cv2.circle(result, center, keypoint_radius + 2, border_color, -1, cv2.LINE_AA)
|
| 349 |
+
|
| 350 |
+
# 내부 색상 원
|
| 351 |
+
cv2.circle(result, center, keypoint_radius, keypoint_color, -1, cv2.LINE_AA)
|
| 352 |
+
|
| 353 |
+
return result
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
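# Usage sketch for the two renderers above (illustrative note, not part of
# visualization.py) -- both take the same (17, 3) array of (x, y, conf) rows:
#
#   frame = np.zeros((480, 640, 3), dtype=np.uint8)
#   kpts = (np.random.rand(17, 3) * np.array([640, 480, 1.0])).astype(np.float32)
#   slow = draw_skeleton(frame, kpts)                       # per-edge cv2 calls
#   fast = draw_skeleton_vectorized(frame, kpts,
#                                   keypoint_mode='major')  # batched polylines
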
def draw_prediction(
|
| 357 |
+
frame: np.ndarray,
|
| 358 |
+
prediction: str,
|
| 359 |
+
confidence: float,
|
| 360 |
+
bbox: Optional[Tuple[int, int, int, int]] = None,
|
| 361 |
+
fps: Optional[float] = None,
|
| 362 |
+
latency: Optional[float] = None,
|
| 363 |
+
position: str = 'top-left'
|
| 364 |
+
) -> np.ndarray:
|
| 365 |
+
"""
|
| 366 |
+
예측 결과 오버레이 렌더링
|
| 367 |
+
|
| 368 |
+
Args:
|
| 369 |
+
frame: OpenCV 이미지
|
| 370 |
+
prediction: 'Fall' 또는 'Non-Fall'
|
| 371 |
+
confidence: 신뢰도 (0.0-1.0)
|
| 372 |
+
bbox: (x1, y1, x2, y2) 바운딩 박스 (선택)
|
| 373 |
+
fps: FPS 값 (선택)
|
| 374 |
+
latency: Latency (ms) (선택)
|
| 375 |
+
position: 텍스트 위치 ('top-left', 'top-right', 'bottom-left', 'bottom-right')
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
frame: 렌더링된 이미지
|
| 379 |
+
"""
|
| 380 |
+
frame = frame.copy()
|
| 381 |
+
h, w = frame.shape[:2]
|
| 382 |
+
|
| 383 |
+
# 1. 바운딩 박스 그리기 (있을 경우)
|
| 384 |
+
if bbox is not None:
|
| 385 |
+
x1, y1, x2, y2 = bbox
|
| 386 |
+
pred_color = PREDICTION_COLORS.get(prediction, (255, 255, 255))
|
| 387 |
+
|
| 388 |
+
# 박스 두께는 Fall일 때 더 두껍게
|
| 389 |
+
box_thickness = 4 if prediction == 'Fall' else 2
|
| 390 |
+
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), pred_color, box_thickness)
|
| 391 |
+
|
| 392 |
+
# 2. 예측 라벨 + 신뢰도 텍스트 준비
|
| 393 |
+
if confidence is not None:
|
| 394 |
+
pred_text = f"{prediction}: {confidence:.2%}"
|
| 395 |
+
else:
|
| 396 |
+
pred_text = f"{prediction}"
|
| 397 |
+
pred_color = PREDICTION_COLORS.get(prediction, (255, 255, 255))
|
| 398 |
+
|
| 399 |
+
# 3. FPS/Latency 텍스트 준비 (있을 경우)
|
| 400 |
+
info_texts = []
|
| 401 |
+
if fps is not None:
|
| 402 |
+
info_texts.append(f"FPS: {fps:.1f}")
|
| 403 |
+
if latency is not None:
|
| 404 |
+
info_texts.append(f"Latency: {latency:.1f}ms")
|
| 405 |
+
|
| 406 |
+
# 4. 텍스트 위치 계산
|
| 407 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 408 |
+
font_scale = 0.8
|
| 409 |
+
font_thickness = 2
|
| 410 |
+
padding = 10
|
| 411 |
+
line_height = 35
|
| 412 |
+
|
| 413 |
+
# 예측 텍스트 크기
|
| 414 |
+
(pred_w, pred_h), _ = cv2.getTextSize(pred_text, font, font_scale, font_thickness)
|
| 415 |
+
|
| 416 |
+
# 위치별 좌표 계산
|
| 417 |
+
if position == 'top-left':
|
| 418 |
+
pred_x, pred_y = padding, padding + pred_h
|
| 419 |
+
elif position == 'top-right':
|
| 420 |
+
pred_x, pred_y = w - pred_w - padding, padding + pred_h
|
| 421 |
+
elif position == 'bottom-left':
|
| 422 |
+
pred_x, pred_y = padding, h - padding - (len(info_texts) * line_height) - 10
|
| 423 |
+
elif position == 'bottom-right':
|
| 424 |
+
pred_x, pred_y = w - pred_w - padding, h - padding - (len(info_texts) * line_height) - 10
|
| 425 |
+
else:
|
| 426 |
+
raise ValueError(f"Unknown position: {position}")
|
| 427 |
+
|
| 428 |
+
# 5. 배경 박스 그리기 (가독성 향상)
|
| 429 |
+
bg_x1 = pred_x - 5
|
| 430 |
+
bg_y1 = pred_y - pred_h - 5
|
| 431 |
+
bg_x2 = pred_x + pred_w + 5
|
| 432 |
+
bg_y2 = pred_y + 5
|
| 433 |
+
|
| 434 |
+
# 반투명 검은 배경
|
| 435 |
+
overlay = frame.copy()
|
| 436 |
+
cv2.rectangle(overlay, (bg_x1, bg_y1), (bg_x2, bg_y2), (0, 0, 0), -1)
|
| 437 |
+
cv2.addWeighted(overlay, 0.6, frame, 0.4, 0, frame)
|
| 438 |
+
|
| 439 |
+
# 6. 예측 텍스트 그리기
|
| 440 |
+
cv2.putText(frame, pred_text, (pred_x, pred_y), font, font_scale, pred_color, font_thickness, cv2.LINE_AA)
|
| 441 |
+
|
| 442 |
+
# 7. FPS/Latency 정보 그리기 (있을 경우)
|
| 443 |
+
if info_texts:
|
| 444 |
+
info_y = pred_y + line_height
|
| 445 |
+
for info_text in info_texts:
|
| 446 |
+
(info_w, info_h), _ = cv2.getTextSize(info_text, font, font_scale, font_thickness)
|
| 447 |
+
|
| 448 |
+
# 배경 박스
|
| 449 |
+
bg_x1 = pred_x - 5
|
| 450 |
+
bg_y1 = info_y - info_h - 5
|
| 451 |
+
bg_x2 = pred_x + info_w + 5
|
| 452 |
+
bg_y2 = info_y + 5
|
| 453 |
+
|
| 454 |
+
overlay = frame.copy()
|
| 455 |
+
cv2.rectangle(overlay, (bg_x1, bg_y1), (bg_x2, bg_y2), (0, 0, 0), -1)
|
| 456 |
+
cv2.addWeighted(overlay, 0.6, frame, 0.4, 0, frame)
|
| 457 |
+
|
| 458 |
+
# 텍스트 (흰색)
|
| 459 |
+
cv2.putText(frame, info_text, (pred_x, info_y), font, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA)
|
| 460 |
+
info_y += line_height
|
| 461 |
+
|
| 462 |
+
return frame
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
def create_info_panel(
|
| 466 |
+
frame_width: int,
|
| 467 |
+
frame_height: int,
|
| 468 |
+
fps: float,
|
| 469 |
+
latency: float,
|
| 470 |
+
prediction: str,
|
| 471 |
+
confidence: float,
|
| 472 |
+
panel_height: int = 80,
|
| 473 |
+
position: str = 'top'
|
| 474 |
+
) -> np.ndarray:
|
| 475 |
+
"""
|
| 476 |
+
정보 패널 생성 (상단 또는 하단 오버레이)
|
| 477 |
+
|
| 478 |
+
Args:
|
| 479 |
+
frame_width: 프레임 너비
|
| 480 |
+
frame_height: 프레임 높이
|
| 481 |
+
fps: FPS 값
|
| 482 |
+
latency: Latency (ms)
|
| 483 |
+
prediction: 'Fall' 또는 'Non-Fall'
|
| 484 |
+
confidence: 신뢰도 (0.0-1.0)
|
| 485 |
+
panel_height: 패널 높이
|
| 486 |
+
position: 패널 위치 ('top' 또는 'bottom')
|
| 487 |
+
|
| 488 |
+
Returns:
|
| 489 |
+
panel: 정보 패널 이미지 (panel_height, frame_width, 3)
|
| 490 |
+
"""
|
| 491 |
+
# 패널 생성 (검은 배경)
|
| 492 |
+
panel = np.zeros((panel_height, frame_width, 3), dtype=np.uint8)
|
| 493 |
+
|
| 494 |
+
# 예측 결과 색상
|
| 495 |
+
pred_color = PREDICTION_COLORS.get(prediction, (255, 255, 255))
|
| 496 |
+
|
| 497 |
+
# 폰트 설정
|
| 498 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 499 |
+
font_scale = 0.7
|
| 500 |
+
font_thickness = 2
|
| 501 |
+
|
| 502 |
+
# 텍스트 준비
|
| 503 |
+
pred_text = f"{prediction}: {confidence:.1%}" if confidence is not None else f"{prediction}"
|
| 504 |
+
texts = [
|
| 505 |
+
(f"FPS: {fps:.1f}", (255, 255, 255)),
|
| 506 |
+
(f"Latency: {latency:.1f}ms", (255, 255, 255)),
|
| 507 |
+
(pred_text, pred_color),
|
| 508 |
+
]
|
| 509 |
+
|
| 510 |
+
# 텍스트 균등 배치
|
| 511 |
+
section_width = frame_width // len(texts)
|
| 512 |
+
y_pos = panel_height // 2 + 10
|
| 513 |
+
|
| 514 |
+
for i, (text, color) in enumerate(texts):
|
| 515 |
+
# 텍스트 크기 계산
|
| 516 |
+
(text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
|
| 517 |
+
|
| 518 |
+
# 중앙 정렬
|
| 519 |
+
x_pos = (i * section_width) + (section_width - text_w) // 2
|
| 520 |
+
|
| 521 |
+
# 텍스트 그리기
|
| 522 |
+
cv2.putText(panel, text, (x_pos, y_pos), font, font_scale, color, font_thickness, cv2.LINE_AA)
|
| 523 |
+
|
| 524 |
+
# 구분선 그리기
|
| 525 |
+
for i in range(1, len(texts)):
|
| 526 |
+
x_pos = i * section_width
|
| 527 |
+
cv2.line(panel, (x_pos, 10), (x_pos, panel_height - 10), (80, 80, 80), 1)
|
| 528 |
+
|
| 529 |
+
return panel
|
| 530 |
+
|
| 531 |
+
|
| 532 |
+
def add_info_panel_to_frame(
|
| 533 |
+
frame: np.ndarray,
|
| 534 |
+
fps: float,
|
| 535 |
+
latency: float,
|
| 536 |
+
prediction: str,
|
| 537 |
+
confidence: float,
|
| 538 |
+
panel_height: int = 80,
|
| 539 |
+
position: str = 'top'
|
| 540 |
+
) -> np.ndarray:
|
| 541 |
+
"""
|
| 542 |
+
프레임에 정보 패널 추가
|
| 543 |
+
|
| 544 |
+
Args:
|
| 545 |
+
frame: 원본 프레임
|
| 546 |
+
fps: FPS 값
|
| 547 |
+
latency: Latency (ms)
|
| 548 |
+
prediction: 'Fall' 또는 'Non-Fall'
|
| 549 |
+
confidence: 신뢰도
|
| 550 |
+
panel_height: 패널 높이
|
| 551 |
+
position: 패널 위치 ('top' 또는 'bottom')
|
| 552 |
+
|
| 553 |
+
Returns:
|
| 554 |
+
result: 패널이 추가된 프레임
|
| 555 |
+
"""
|
| 556 |
+
h, w = frame.shape[:2]
|
| 557 |
+
|
| 558 |
+
# 정보 패널 생성
|
| 559 |
+
panel = create_info_panel(w, h, fps, latency, prediction, confidence, panel_height, position)
|
| 560 |
+
|
| 561 |
+
# 패널 위치에 따라 결합
|
| 562 |
+
if position == 'top':
|
| 563 |
+
result = np.vstack([panel, frame])
|
| 564 |
+
elif position == 'bottom':
|
| 565 |
+
result = np.vstack([frame, panel])
|
| 566 |
+
else:
|
| 567 |
+
raise ValueError(f"Unknown position: {position}. Use 'top' or 'bottom'.")
|
| 568 |
+
|
| 569 |
+
return result
|
| 570 |
+
|
| 571 |
+
|
| 572 |
+
def draw_fall_alert_overlay(
|
| 573 |
+
frame: np.ndarray,
|
| 574 |
+
alert_text: str = "FALL DETECTED!",
|
| 575 |
+
flash: bool = True
|
| 576 |
+
) -> np.ndarray:
|
| 577 |
+
"""
|
| 578 |
+
낙상 경보 오버레이 그리기 (전체 화면 플래시 효과)
|
| 579 |
+
|
| 580 |
+
Args:
|
| 581 |
+
frame: 원본 프레임
|
| 582 |
+
alert_text: 경보 텍스트
|
| 583 |
+
flash: True면 화면 전체에 빨간 반투명 오버레이 추가
|
| 584 |
+
|
| 585 |
+
Returns:
|
| 586 |
+
result: 경보 오버레이가 추가된 프레임
|
| 587 |
+
"""
|
| 588 |
+
frame = frame.copy()
|
| 589 |
+
h, w = frame.shape[:2]
|
| 590 |
+
|
| 591 |
+
# 1. 플래시 효과 (빨간 반투명 오버레이)
|
| 592 |
+
if flash:
|
| 593 |
+
overlay = frame.copy()
|
| 594 |
+
cv2.rectangle(overlay, (0, 0), (w, h), (0, 0, 255), -1)
|
| 595 |
+
cv2.addWeighted(overlay, 0.3, frame, 0.7, 0, frame)
|
| 596 |
+
|
| 597 |
+
# 2. 중앙에 큰 경고 텍스트
|
| 598 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 599 |
+
font_scale = 2.5
|
| 600 |
+
font_thickness = 8 # 두꺼운 굵기로 볼드체 효과
|
| 601 |
+
|
| 602 |
+
(text_w, text_h), _ = cv2.getTextSize(alert_text, font, font_scale, font_thickness)
|
| 603 |
+
text_x = (w - text_w) // 2
|
| 604 |
+
text_y = (h + text_h) // 2
|
| 605 |
+
|
| 606 |
+
# 텍스트 그림자 (검은색)
|
| 607 |
+
cv2.putText(frame, alert_text, (text_x + 3, text_y + 3), font, font_scale, (0, 0, 0), font_thickness + 2, cv2.LINE_AA)
|
| 608 |
+
|
| 609 |
+
# 텍스트 본문 (흰색)
|
| 610 |
+
cv2.putText(frame, alert_text, (text_x, text_y), font, font_scale, (255, 255, 255), font_thickness, cv2.LINE_AA)
|
| 611 |
+
|
| 612 |
+
return frame
|
| 613 |
+
|
| 614 |
+
def visualize_fall_simple(
    frame: np.ndarray,
    keypoints: Optional[np.ndarray] = None,
    show_fall_text: bool = False,
    keypoint_mode: Literal['all', 'major'] = 'all',
    output_scale: float = 1.0
) -> np.ndarray:
    """
    Simplified fall-detection visualization (pose skeleton + FALL DETECTED text only)

    Shown:
    - Pose skeleton (colored by body part)
    - FALL DETECTED text (when show_fall_text=True)

    Removed (compared to visualize_fall_detection):
    - FPS/latency readout
    - Info panel
    - Red flash overlay
    - Confidence display

    Args:
        frame: input frame
        keypoints: (17, 3) pose keypoints (optional)
        show_fall_text: if True, draw the FALL DETECTED text
        keypoint_mode: 'all' = draw all 17 keypoints, 'major' = only the 9 major ones
        output_scale: output resolution scale (0.5 = 50%, 1.0 = 100%)

    Returns:
        result: the visualized frame
    """
    # 1. Resolution scaling (when output_scale < 1.0)
    original_h, original_w = frame.shape[:2]
    if output_scale < 1.0:
        new_w = int(original_w * output_scale)
        new_h = int(original_h * output_scale)
        result = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # Rescale the keypoint coordinates as well
        if keypoints is not None:
            keypoints = keypoints.copy()
            keypoints[:, 0] *= output_scale  # x coordinates
            keypoints[:, 1] *= output_scale  # y coordinates
    else:
        result = frame.copy()

    # 2. Draw the skeleton
    if keypoints is not None:
        result = draw_skeleton_vectorized(
            result, keypoints,
            keypoint_mode=keypoint_mode,
            use_body_part_colors=True
        )

    # 3. Draw the FALL DETECTED text (no flash)
    if show_fall_text:
        h, w = result.shape[:2]
        alert_text = "FALL DETECTED"

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 2.0
        font_thickness = 6

        (text_w, text_h), _ = cv2.getTextSize(alert_text, font, font_scale, font_thickness)
        text_x = (w - text_w) // 2
        text_y = 80  # Near the top of the frame

        # Text background (semi-transparent black)
        bg_padding = 15
        overlay = result.copy()
        cv2.rectangle(
            overlay,
            (text_x - bg_padding, text_y - text_h - bg_padding),
            (text_x + text_w + bg_padding, text_y + bg_padding),
            (0, 0, 0),
            -1
        )
        cv2.addWeighted(overlay, 0.6, result, 0.4, 0, result)

        # Text shadow (black)
        cv2.putText(result, alert_text, (text_x + 2, text_y + 2),
                    font, font_scale, (0, 0, 0), font_thickness + 2, cv2.LINE_AA)

        # Text body (red)
        cv2.putText(result, alert_text, (text_x, text_y),
                    font, font_scale, (0, 0, 255), font_thickness, cv2.LINE_AA)

    return result

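Because the function resizes the frame before any drawing, callers always pass keypoints in original-frame pixel coordinates; the x/y columns are rescaled internally and the confidence column is left untouched. A hedged usage sketch with dummy inputs, assuming only the signature above (the assertion checks the documented resize behavior, nothing else):

    import numpy as np

    frame = np.full((720, 1280, 3), 50, dtype=np.uint8)  # gray dummy frame
    kpts = np.zeros((17, 3), dtype=np.float32)           # (x, y, confidence) per COCO-17 joint
    vis = visualize_fall_simple(frame, kpts, show_fall_text=True,
                                keypoint_mode='major', output_scale=0.5)
    assert vis.shape[:2] == (360, 640)                   # 720x1280 halved before drawing
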
def visualize_fall_detection(
    frame: np.ndarray,
    keypoints: Optional[np.ndarray] = None,
    prediction: str = 'Non-Fall',
    confidence: float = 0.0,
    bbox: Optional[Tuple[int, int, int, int]] = None,
    fps: Optional[float] = None,
    latency: Optional[float] = None,
    show_skeleton: bool = True,
    show_info_panel: bool = True,
    show_alert: bool = False,
    use_optimized: bool = True,
    keypoint_mode: Literal['all', 'major'] = 'all',
    output_scale: float = 1.0
) -> np.ndarray:
    """
    Comprehensive visualization of fall-detection results (all-in-one function)

    Args:
        frame: input frame
        keypoints: (17, 3) pose keypoints (optional)
        prediction: 'Fall' or 'Non-Fall'
        confidence: prediction confidence
        bbox: bounding box (optional)
        fps: FPS value (optional)
        latency: latency in ms (optional)
        show_skeleton: if True, draw the skeleton
        show_info_panel: if True, attach the info panel at the bottom
        show_alert: if True, add the fall-alert overlay (only when prediction='Fall')
        use_optimized: if True, use the vectorized drawing path (roughly 30x faster)
        keypoint_mode: 'all' = draw all 17 keypoints, 'major' = only the 9 major ones
        output_scale: output resolution scale (0.5 = 50%, 1.0 = 100%)

    Returns:
        result: the visualized frame
    """
    # 1. Resolution scaling (when output_scale < 1.0)
    original_h, original_w = frame.shape[:2]
    if output_scale < 1.0:
        new_w = int(original_w * output_scale)
        new_h = int(original_h * output_scale)
        result = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # Rescale the keypoint coordinates as well
        if keypoints is not None:
            keypoints = keypoints.copy()
            keypoints[:, 0] *= output_scale  # x coordinates
            keypoints[:, 1] *= output_scale  # y coordinates

        # Rescale the bbox coordinates as well
        if bbox is not None:
            bbox = tuple(int(v * output_scale) for v in bbox)
    else:
        result = frame.copy()

    # 2. Draw the skeleton
    if show_skeleton and keypoints is not None:
        if use_optimized:
            result = draw_skeleton_vectorized(
                result, keypoints,
                keypoint_mode=keypoint_mode,
                use_body_part_colors=True
            )
        else:
            result = draw_skeleton(result, keypoints, use_body_part_colors=True)

    # 3. Prediction overlay
    if fps is not None or latency is not None:
        result = draw_prediction(result, prediction, confidence, bbox, fps, latency, position='top-left')

    # 4. Fall-alert overlay (only when prediction is 'Fall' and show_alert=True)
    if show_alert and prediction == 'Fall':
        result = draw_fall_alert_overlay(result, alert_text="FALL DETECTED!", flash=True)

    # 5. Info panel (optional)
    if show_info_panel and fps is not None and latency is not None:
        result = add_info_panel_to_frame(result, fps, latency, prediction, confidence, position='bottom')

    return result

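Note how steps 3-5 are each gated: the prediction overlay is drawn only when fps or latency is supplied, the alert overlay only when show_alert=True and prediction == 'Fall', and the info panel only when both fps and latency are given. A sketch of a fully enabled call (metric values illustrative only):

    out = visualize_fall_detection(
        frame, keypoints,
        prediction='Fall', confidence=0.87,
        fps=30.0, latency=45.0,   # supplying both enables the overlay and the info panel
        show_alert=True,          # fires only because prediction == 'Fall'
        keypoint_mode='major', output_scale=0.5,
    )
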
if __name__ == '__main__':
    import time
    import argparse

    parser = argparse.ArgumentParser(description='Visualization module test and benchmark')
    parser.add_argument('--benchmark', action='store_true', help='Run performance benchmark')
    parser.add_argument('--resolution', type=str, default='640x480',
                        help='Test resolution (default: 640x480, options: 640x480, 1920x1080, 3840x2160)')
    parser.add_argument('--iterations', type=int, default=100, help='Benchmark iterations')
    args = parser.parse_args()

    # Parse the requested resolution
    res_map = {
        '640x480': (480, 640),
        '1920x1080': (1080, 1920),
        '3840x2160': (2160, 3840),
        '4k': (2160, 3840),
        'fhd': (1080, 1920),
        'vga': (480, 640),
    }
    h, w = res_map.get(args.resolution.lower(), (480, 640))

    print(f"Testing visualization module at {w}x{h}...")

    # 1. Create a dummy frame
    frame = np.zeros((h, w, 3), dtype=np.uint8)
    frame[:, :] = (50, 50, 50)

    # 2. Create dummy keypoints (scaled to the test resolution)
    scale_x = w / 640
    scale_y = h / 480
    keypoints = np.array([
        [320, 100, 0.9],   # 0: nose
        [310, 90, 0.9],    # 1: left_eye
        [330, 90, 0.9],    # 2: right_eye
        [300, 90, 0.8],    # 3: left_ear
        [340, 90, 0.8],    # 4: right_ear
        [300, 150, 0.95],  # 5: left_shoulder
        [340, 150, 0.95],  # 6: right_shoulder
        [280, 200, 0.9],   # 7: left_elbow
        [360, 200, 0.9],   # 8: right_elbow
        [270, 250, 0.85],  # 9: left_wrist
        [370, 250, 0.85],  # 10: right_wrist
        [300, 250, 0.95],  # 11: left_hip
        [340, 250, 0.95],  # 12: right_hip
        [300, 350, 0.9],   # 13: left_knee
        [340, 350, 0.9],   # 14: right_knee
        [300, 450, 0.85],  # 15: left_ankle
        [340, 450, 0.85],  # 16: right_ankle
    ], dtype=np.float32)
    keypoints[:, 0] *= scale_x
    keypoints[:, 1] *= scale_y

    if args.benchmark:
        print("\n" + "=" * 70)
        print("BENCHMARK: Visualization Performance Comparison")
        print("=" * 70)
        print(f"Resolution: {w}x{h}")
        print(f"Iterations: {args.iterations}")
        print("=" * 70)

        # Benchmark the original draw_skeleton
        print("\n[1] draw_skeleton (original - cv2.circle/line loops)")
        times_original = []
        for _ in range(args.iterations):
            start = time.perf_counter()
            _ = draw_skeleton(frame.copy(), keypoints, use_body_part_colors=True)
            times_original.append((time.perf_counter() - start) * 1000)
        avg_original = np.mean(times_original)
        std_original = np.std(times_original)
        print(f"  Average: {avg_original:.2f}ms (+/- {std_original:.2f}ms)")

        # Benchmark the vectorized draw_skeleton_vectorized (all keypoints)
        print("\n[2] draw_skeleton_vectorized (optimized - all keypoints)")
        times_vectorized = []
        for _ in range(args.iterations):
            start = time.perf_counter()
            _ = draw_skeleton_vectorized(frame.copy(), keypoints, keypoint_mode='all')
            times_vectorized.append((time.perf_counter() - start) * 1000)
        avg_vectorized = np.mean(times_vectorized)
        std_vectorized = np.std(times_vectorized)
        speedup_all = avg_original / avg_vectorized
        print(f"  Average: {avg_vectorized:.2f}ms (+/- {std_vectorized:.2f}ms)")
        print(f"  Speedup: {speedup_all:.1f}x faster")

        # Benchmark the vectorized draw_skeleton_vectorized (major keypoints)
        print("\n[3] draw_skeleton_vectorized (optimized - major keypoints only)")
        times_major = []
        for _ in range(args.iterations):
            start = time.perf_counter()
            _ = draw_skeleton_vectorized(frame.copy(), keypoints, keypoint_mode='major')
            times_major.append((time.perf_counter() - start) * 1000)
        avg_major = np.mean(times_major)
        std_major = np.std(times_major)
        speedup_major = avg_original / avg_major
        print(f"  Average: {avg_major:.2f}ms (+/- {std_major:.2f}ms)")
        print(f"  Speedup: {speedup_major:.1f}x faster")

        # Benchmark resolution scaling + vectorized drawing
        if w > 640:
            print("\n[4] draw_skeleton_vectorized + 50% scale")
            times_scaled = []
            for _ in range(args.iterations):
                start = time.perf_counter()
                result = visualize_fall_detection(
                    frame.copy(), keypoints,
                    prediction='Fall', confidence=0.9,
                    fps=30.0, latency=50.0,
                    use_optimized=True,
                    keypoint_mode='all',
                    output_scale=0.5
                )
                times_scaled.append((time.perf_counter() - start) * 1000)
            avg_scaled = np.mean(times_scaled)
            std_scaled = np.std(times_scaled)
            print(f"  Average: {avg_scaled:.2f}ms (+/- {std_scaled:.2f}ms)")
            print(f"  Output size: {result.shape[1]}x{result.shape[0]}")

        print("\n" + "=" * 70)
        print("SUMMARY")
        print("=" * 70)
        print(f"Original: {avg_original:.2f}ms")
        print(f"Optimized: {avg_vectorized:.2f}ms ({speedup_all:.1f}x faster)")
        print(f"Major only: {avg_major:.2f}ms ({speedup_major:.1f}x faster)")
        target_met = avg_vectorized < 10.0
        print(f"\nTarget (<10ms): {'MET' if target_met else 'NOT MET'}")
        print("=" * 70)

    else:
        # Basic functionality tests
        print("\n1. Testing draw_skeleton (original)...")
        result = draw_skeleton(frame.copy(), keypoints, use_body_part_colors=True)
        print(f"   Output shape: {result.shape}")

        print("\n2. Testing draw_skeleton_vectorized (optimized)...")
        result = draw_skeleton_vectorized(frame.copy(), keypoints, keypoint_mode='all')
        print(f"   Output shape: {result.shape}")

        print("\n3. Testing draw_skeleton_vectorized (major only)...")
        result = draw_skeleton_vectorized(frame.copy(), keypoints, keypoint_mode='major')
        print(f"   Output shape: {result.shape}")

        print("\n4. Testing draw_prediction...")
        result = draw_prediction(
            frame.copy(),
            prediction='Non-Fall',
            confidence=0.95,
            bbox=(int(270 * scale_x), int(90 * scale_y), int(370 * scale_x), int(450 * scale_y)),
            fps=30.0,
            latency=50.0
        )
        print(f"   Output shape: {result.shape}")

        print("\n5. Testing create_info_panel...")
        panel = create_info_panel(w, h, fps=30.0, latency=50.0, prediction='Non-Fall', confidence=0.95)
        print(f"   Panel shape: {panel.shape}")

        print("\n6. Testing visualize_fall_detection (optimized=True)...")
        result = visualize_fall_detection(
            frame=frame,
            keypoints=keypoints,
            prediction='Fall',
            confidence=0.87,
            fps=30.0,
            latency=50.0,
            show_skeleton=True,
            show_info_panel=True,
            show_alert=True,
            use_optimized=True,
            keypoint_mode='all'
        )
        print(f"   Output shape: {result.shape}")

        print("\n7. Testing visualize_fall_detection (output_scale=0.5)...")
        result = visualize_fall_detection(
            frame=frame,
            keypoints=keypoints,
            prediction='Non-Fall',
            confidence=0.95,
            fps=30.0,
            latency=50.0,
            show_skeleton=True,
            show_info_panel=True,
            use_optimized=True,
            output_scale=0.5
        )
        print(f"   Output shape: {result.shape}")

        print("\nAll tests passed!")
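Run directly, the module exercises its own drawing functions; with --benchmark it times the loop-based path against the vectorized one, using the flags defined above, e.g.:

    python visualization.py --benchmark --resolution 1920x1080 --iterations 200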