"""
ST-GCN fall classifier wrapper.

Fall classifier built on a Spatial-Temporal Graph Convolutional Network (ST-GCN).

Note: import paths were adjusted for Hugging Face Spaces deployment.
"""
import logging
from typing import Optional, Tuple

import numpy as np
import torch

# Flat (top-level) import paths used by the HF Spaces deployment layout.
from augmentation import normalize_skeleton
from stgcn.model import STGCN


class STGCNClassifier:
    """ST-GCN based fall classifier.

    Loads a trained ``STGCN`` checkpoint, optionally normalizes input skeleton
    windows, and returns class predictions (0 = Non-Fall, 1 = Fall) together
    with softmax confidences.
    """

    def __init__(
        self,
        checkpoint_path: str = "runs/stgcn_binary_exp2_fixed_graph/best_acc.pth",
        fall_threshold: float = 0.7,
        device: str = "cuda:0",
        in_channels: int = 3,
        num_classes: int = 2,
        dropout: float = 0.5,
        logger: Optional[logging.Logger] = None,
    ):
        """
        Args:
            checkpoint_path: Path to the ST-GCN checkpoint. It must be a dict
                holding ``'model_state_dict'``; ``'epoch'`` and ``'metrics'``
                (with ``'accuracy'`` / ``'f1'``) are logged when present.
            fall_threshold: Minimum confidence for a Fall prediction to be
                reported as a fall by :meth:`is_fall`.
            device: Torch device string (e.g. ``"cuda:0"``, ``"cpu"``). A CUDA
                device falls back to CPU when CUDA is unavailable; other
                device types are used as requested.
            in_channels: Number of input channels (x, y, conf).
            num_classes: Number of output classes (Fall, Non-Fall).
            dropout: Dropout rate passed to the model.
            logger: Logger instance; defaults to this module's logger.

        Raises:
            KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        """
        # The logger is set first so device resolution can report a fallback.
        self.logger = logger or logging.getLogger(__name__)
        self.device = self._resolve_device(device)
        self.fall_threshold = fall_threshold

        self.logger.info(f"[Stage 2] ST-GCN 로드 중: {checkpoint_path}")

        # Build the model
        self.model = STGCN(
            in_channels=in_channels,
            num_classes=num_classes,
            graph_cfg={},
            edge_importance_weighting=True,
            dropout=dropout,
        )

        # Load the checkpoint.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        # Log checkpoint metadata when available
        epoch = checkpoint.get('epoch')
        if epoch is not None:
            self.logger.info(f" - Checkpoint epoch: {epoch}")

        metrics = checkpoint.get('metrics')
        if isinstance(metrics, dict):
            acc = metrics.get('accuracy')
            f1 = metrics.get('f1')
            if isinstance(acc, (int, float)):
                self.logger.info(f" - Accuracy: {acc:.4f}")
            if isinstance(f1, (int, float)):
                self.logger.info(f" - F1 Score: {f1:.4f}")

        self.logger.info(f" - Fall threshold: {fall_threshold}")
        self.logger.info(f" - Device: {self.device}")

    def _resolve_device(self, device: str) -> torch.device:
        """Return the requested device, using CPU if CUDA is requested but unavailable."""
        requested = torch.device(device)
        if requested.type == "cuda" and not torch.cuda.is_available():
            self.logger.warning(
                f"CUDA is not available; falling back to CPU (requested {device!r})"
            )
            return torch.device("cpu")
        return requested

    @staticmethod
    def _to_tensor(window: np.ndarray, normalize: bool) -> torch.Tensor:
        """Optionally hip-center normalize one window and convert it to a float32 tensor."""
        if normalize:
            window_input = normalize_skeleton(window, method='hip_center')
        else:
            window_input = window
        # torch.from_numpy rejects arrays with negative strides (e.g. reversed
        # slices); ascontiguousarray is a no-op for already-contiguous arrays.
        return torch.from_numpy(np.ascontiguousarray(window_input)).float()

    def predict(
        self,
        window: np.ndarray,
        normalize: bool = True,
        debug: bool = False,
    ) -> Tuple[int, float]:
        """
        Predict fall / non-fall for a single window with ST-GCN.

        Args:
            window: (C, T, V, M) ST-GCN input (C=3, T=60, V=17, M=1).
            normalize: Whether to apply hip-center normalization.
            debug: Whether to emit a debug log line for the prediction.

        Returns:
            prediction: 0 (Non-Fall) or 1 (Fall).
            confidence: Softmax probability of the predicted class (0.0-1.0).
        """
        # Normalize skeleton (hip center + skeleton size scaling), add batch dim.
        window_tensor = self._to_tensor(window, normalize).unsqueeze(0).to(self.device)  # (1, C, T, V, M)

        with torch.no_grad():
            outputs = self.model(window_tensor)
            probs = torch.softmax(outputs, dim=1)
            pred = torch.argmax(outputs, dim=1)

        prediction = pred.item()
        confidence = probs[0, prediction].item()

        if debug:
            self.logger.debug(f" ST-GCN prediction: {prediction} (conf={confidence:.3f})")

        return prediction, confidence

    def predict_batch(
        self,
        windows: list[np.ndarray],
        batch_size: int = 32,
        normalize: bool = True,
        debug: bool = False,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Batched ST-GCN fall prediction.

        Args:
            windows: Sequence of (C, T, V, M) ST-GCN input windows. A list, a
                tuple, or a stacked ndarray of shape (N, C, T, V, M) all work.
            batch_size: Number of windows per forward pass (default 32; limits
                device memory use). Must be >= 1.
            normalize: Whether to apply hip-center normalization.
            debug: Whether to emit a debug log line per window.

        Returns:
            predictions: (N,) int array of 0 (Non-Fall) or 1 (Fall).
            confidences: (N,) array of the predicted class probability (0.0-1.0).
            fall_probs: (N,) array of the Fall class (index 1) probability (0.0-1.0).

        Raises:
            ValueError: If ``windows`` is non-empty and ``batch_size`` < 1.
        """
        # len() rather than truthiness so a stacked ndarray input also works.
        if len(windows) == 0:
            # Same dtypes as the non-empty path: int64 labels from argmax and
            # float32 probabilities from a float32 model.
            return (
                np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.float32),
                np.empty(0, dtype=np.float32),
            )
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        all_predictions = []
        all_confidences = []
        all_fall_probs = []

        for chunk_start in range(0, len(windows), batch_size):
            chunk_windows = windows[chunk_start:chunk_start + batch_size]
            batch_tensor = torch.stack(
                [self._to_tensor(window, normalize) for window in chunk_windows]
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(batch_tensor)
                probs = torch.softmax(outputs, dim=1)
                preds = torch.argmax(outputs, dim=1)
                # Probability of the predicted class for each row.
                confs = probs.gather(1, preds.unsqueeze(1)).squeeze(1)

            predictions = preds.cpu().numpy()
            confidences = confs.cpu().numpy()
            fall_probs = probs[:, 1].cpu().numpy()

            all_predictions.append(predictions)
            all_confidences.append(confidences)
            all_fall_probs.append(fall_probs)

            if debug:
                for i, (pred, conf, fall_p) in enumerate(zip(predictions, confidences, fall_probs)):
                    global_idx = chunk_start + i
                    self.logger.debug(f" Batch[{global_idx}] ST-GCN: pred={pred}, conf={conf:.3f}, fall_prob={fall_p:.3f}")

        return (
            np.concatenate(all_predictions),
            np.concatenate(all_confidences),
            np.concatenate(all_fall_probs),
        )

    def is_fall(self, prediction: int, confidence: float) -> bool:
        """
        Decide whether a prediction counts as a fall.

        Args:
            prediction: Model prediction (0 or 1); numpy integer scalars are accepted.
            confidence: Confidence of the prediction.

        Returns:
            True if the prediction is Fall (1) and ``confidence`` is at least
            ``self.fall_threshold``.
        """
        # bool() so numpy scalar inputs yield a plain Python bool, not np.bool_.
        return bool(prediction == 1 and confidence >= self.fall_threshold)