fall-detection-demo / models /stgcn_classifier.py
YoungjaeDev
fix: predict_batch์— batch_size ์ฒญํ‚น ์ถ”๊ฐ€ - GPU OOM ๋ฐฉ์ง€
dc83241
"""
ST-GCN fall classifier wrapper class.
Fall classifier built on a Spatial-Temporal Graph Convolutional Network (ST-GCN).
Note: import paths were modified for Hugging Face Spaces deployment.
"""
import logging
from typing import Optional, Tuple
import numpy as np
import torch
# Top-level module imports (import paths adjusted for HF Spaces deployment)
from augmentation import normalize_skeleton
from stgcn.model import STGCN
class STGCNClassifier:
    """ST-GCN based fall classifier.

    Wraps a pretrained ST-GCN model. The checkpoint is loaded once in
    ``__init__`` and the model is kept in eval mode for inference.
    Class index 1 means "Fall" and class index 0 means "Non-Fall".
    """

    def __init__(
        self,
        checkpoint_path: str = "runs/stgcn_binary_exp2_fixed_graph/best_acc.pth",
        fall_threshold: float = 0.7,
        device: str = "cuda:0",
        in_channels: int = 3,
        num_classes: int = 2,
        dropout: float = 0.5,
        logger: Optional[logging.Logger] = None
    ):
        """
        Build the ST-GCN model and load its weights.

        Args:
            checkpoint_path: Path to the ST-GCN checkpoint. Either a training
                checkpoint dict holding the weights under ``'model_state_dict'``
                (optional ``'epoch'`` / ``'metrics'`` entries are logged), or a
                bare state dict.
            fall_threshold: Confidence threshold used by :meth:`is_fall`.
            device: Torch device string (``"cuda:0"``, ``"cpu"``, ...). A CUDA
                request falls back to CPU when CUDA is unavailable; any other
                device is used as given.
            in_channels: Number of input channels (x, y, conf).
            num_classes: Number of output classes (Fall, Non-Fall).
            dropout: Dropout rate of the ST-GCN model.
            logger: Logger instance; defaults to this module's logger.
        """
        self.device = self._resolve_device(device)
        self.fall_threshold = fall_threshold
        self.logger = logger or logging.getLogger(__name__)
        self.logger.info(f"[Stage 2] ST-GCN ๋กœ๋“œ ์ค‘: {checkpoint_path}")

        # Build the model
        self.model = STGCN(
            in_channels=in_channels,
            num_classes=num_classes,
            graph_cfg={},
            edge_importance_weighting=True,
            dropout=dropout
        )

        # Load the checkpoint.
        # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        # Training checkpoints store the weights under 'model_state_dict';
        # a bare state dict is accepted as-is.
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Log checkpoint metadata when it is present
        epoch = checkpoint.get('epoch')
        if epoch is not None:
            self.logger.info(f" - Checkpoint epoch: {epoch}")
        metrics = checkpoint.get('metrics')
        if isinstance(metrics, dict):
            acc = metrics.get('accuracy')
            f1 = metrics.get('f1')
            if isinstance(acc, (int, float)):
                self.logger.info(f" - Accuracy: {acc:.4f}")
            if isinstance(f1, (int, float)):
                self.logger.info(f" - F1 Score: {f1:.4f}")
        self.logger.info(f" - Fall threshold: {fall_threshold}")
        self.logger.info(f" - Device: {self.device}")

    @staticmethod
    def _resolve_device(device) -> torch.device:
        """Return the torch device to use, falling back to CPU only for CUDA requests without CUDA."""
        # Non-CUDA requests (e.g. "cpu", "mps") must not be silently forced to CPU.
        if str(device).startswith("cuda") and not torch.cuda.is_available():
            return torch.device("cpu")
        return torch.device(device)

    def _to_tensor(self, window: np.ndarray, normalize: bool) -> torch.Tensor:
        """Optionally normalize one (C, T, V, M) window and return it as a float32 CPU tensor."""
        # Normalize skeleton (hip center + skeleton size scaling)
        window_input = normalize_skeleton(window, method='hip_center') if normalize else window
        return torch.from_numpy(window_input).float()

    def predict(
        self,
        window: np.ndarray,
        normalize: bool = True,
        debug: bool = False
    ) -> Tuple[int, float]:
        """
        Predict fall / non-fall for a single window with ST-GCN.

        Args:
            window: (C, T, V, M) ST-GCN input (C=3, T=60, V=17, M=1)
            normalize: Whether to apply hip-center normalization
            debug: Whether to emit a debug log line

        Returns:
            prediction: 0 (Non-Fall) or 1 (Fall)
            confidence: Softmax probability of the predicted class (0.0-1.0)
        """
        # Add the batch dimension: (1, C, T, V, M)
        window_tensor = self._to_tensor(window, normalize).unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(window_tensor)
            probs = torch.softmax(outputs, dim=1)
            pred = torch.argmax(outputs, dim=1)
        prediction = pred.item()
        confidence = probs[0, prediction].item()
        if debug:
            self.logger.debug(f" ST-GCN prediction: {prediction} (conf={confidence:.3f})")
        return prediction, confidence

    def predict_batch(
        self,
        windows: list[np.ndarray],
        batch_size: int = 32,
        normalize: bool = True,
        debug: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Batched ST-GCN fall prediction, processed in chunks of ``batch_size``.

        Chunking bounds the number of windows sent to the device at once
        (to avoid GPU out-of-memory errors on long window lists).

        Args:
            windows: List of (C, T, V, M) ST-GCN input windows
            batch_size: Maximum number of windows per forward pass (default 32)
            normalize: Whether to apply hip-center normalization
            debug: Whether to emit one debug log line per window

        Returns:
            predictions: (N,) int64 array of 0 (Non-Fall) or 1 (Fall)
            confidences: (N,) array of the predicted class probability (0.0-1.0)
            fall_probs: (N,) array of the Fall class (index 1) probability (0.0-1.0)

        Raises:
            ValueError: If ``windows`` is non-empty and ``batch_size`` < 1.
        """
        if not windows:
            # Keep dtypes consistent with the non-empty path (int predictions).
            return (
                np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.float32),
                np.empty(0, dtype=np.float32),
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        all_predictions = []
        all_confidences = []
        all_fall_probs = []
        for chunk_start in range(0, len(windows), batch_size):
            chunk = windows[chunk_start:chunk_start + batch_size]
            # Stack on CPU, then move the whole chunk to the device in one transfer.
            batch_tensor = torch.stack([self._to_tensor(w, normalize) for w in chunk]).to(self.device)
            with torch.no_grad():
                outputs = self.model(batch_tensor)
                probs = torch.softmax(outputs, dim=1)
                preds = torch.argmax(outputs, dim=1)
                # Probability of each row's predicted class.
                pred_probs = probs.gather(1, preds.unsqueeze(1)).squeeze(1)

            predictions = preds.cpu().numpy()
            confidences = pred_probs.cpu().numpy()
            fall_probs = probs[:, 1].cpu().numpy()
            all_predictions.append(predictions)
            all_confidences.append(confidences)
            all_fall_probs.append(fall_probs)

            if debug:
                for i, (pred, conf, fall_p) in enumerate(zip(predictions, confidences, fall_probs)):
                    global_idx = chunk_start + i
                    self.logger.debug(f" Batch[{global_idx}] ST-GCN: pred={pred}, conf={conf:.3f}, fall_prob={fall_p:.3f}")

        return (
            np.concatenate(all_predictions),
            np.concatenate(all_confidences),
            np.concatenate(all_fall_probs)
        )

    def is_fall(self, prediction: int, confidence: float) -> bool:
        """
        Decide whether a prediction counts as a fall.

        Args:
            prediction: Model prediction (0 or 1)
            confidence: Prediction confidence

        Returns:
            True if the prediction is Fall (1) and ``confidence`` is at least
            ``self.fall_threshold``.
        """
        return prediction == 1 and confidence >= self.fall_threshold