Spaces:
Sleeping
Sleeping
| """ | |
| ST-GCN ๋์ ๋ถ๋ฅ๊ธฐ ๋ํผ ํด๋์ค | |
| Spatial-Temporal Graph Convolutional Network์ ์ด์ฉํ ๋์ ๋ถ๋ฅ๊ธฐ์ ๋๋ค. | |
| Note: HF Spaces ๋ฐฐํฌ์ฉ์ผ๋ก import ๊ฒฝ๋ก๊ฐ ์์ ๋์์ต๋๋ค. | |
| """ | |
| import logging | |
| from typing import Optional, Tuple | |
| import numpy as np | |
| import torch | |
| # HF Spaces ๋ฐฐํฌ์ฉ ์๋ import | |
| from augmentation import normalize_skeleton | |
| from stgcn.model import STGCN | |
class STGCNClassifier:
    """Fall classifier wrapper around an ST-GCN model.

    Loads an ST-GCN checkpoint once and exposes single-window inference
    (:meth:`predict`), chunked batch inference (:meth:`predict_batch`) and a
    thresholded decision helper (:meth:`is_fall`). Class index 1 is treated
    as "Fall" and index 0 as "Non-Fall".
    """

    def __init__(
        self,
        checkpoint_path: str = "runs/stgcn_binary_exp2_fixed_graph/best_acc.pth",
        fall_threshold: float = 0.7,
        device: str = "cuda:0",
        in_channels: int = 3,
        num_classes: int = 2,
        dropout: float = 0.5,
        logger: Optional[logging.Logger] = None
    ):
        """
        Args:
            checkpoint_path: Path to the ST-GCN checkpoint. The loaded object
                must contain a ``'model_state_dict'`` entry; optional
                ``'epoch'`` and ``'metrics'`` entries are logged if present.
            fall_threshold: Minimum confidence :meth:`is_fall` requires
                before reporting a fall.
            device: Requested torch device (e.g. ``"cuda:0"``, ``"cpu"``).
                A CUDA request falls back to CPU when CUDA is unavailable;
                any other device is used as given.
            in_channels: Number of input channels (x, y, conf).
            num_classes: Number of output classes (Non-Fall, Fall).
            dropout: Dropout rate passed to the model.
            logger: Logger instance; defaults to this module's logger.
        """
        self.device = self._resolve_device(device)
        self.fall_threshold = fall_threshold
        self.logger = logger or logging.getLogger(__name__)
        # NOTE(review): this message text appears mojibake-damaged (Korean
        # mis-decoded during extraction); kept verbatim - restore from the
        # original source.
        self.logger.info(f"[Stage 2] ST-GCN ๋ก๋ ์ค: {checkpoint_path}")

        # Build the network. graph_cfg={} presumably makes STGCN use its
        # default skeleton graph - confirm in stgcn.model.
        self.model = STGCN(
            in_channels=in_channels,
            num_classes=num_classes,
            graph_cfg={},
            edge_importance_weighting=True,
            dropout=dropout
        )

        # Load weights. torch.load is pickle-based: only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()  # inference mode (e.g. dropout disabled)

        self._log_checkpoint_info(checkpoint)
        self.logger.info(f" - Fall threshold: {fall_threshold}")
        self.logger.info(f" - Device: {self.device}")

    @staticmethod
    def _resolve_device(device: str) -> torch.device:
        """Return ``torch.device(device)``, falling back to CPU only when a
        CUDA device is requested but CUDA is not available."""
        resolved = torch.device(device)
        if resolved.type == "cuda" and not torch.cuda.is_available():
            return torch.device("cpu")
        return resolved

    def _log_checkpoint_info(self, checkpoint: dict) -> None:
        """Log optional training metadata (epoch, accuracy, F1) stored in the checkpoint."""
        epoch = checkpoint.get('epoch')
        if epoch is not None:
            self.logger.info(f" - Checkpoint epoch: {epoch}")
        metrics = checkpoint.get('metrics')
        if isinstance(metrics, dict):
            acc = metrics.get('accuracy')
            f1 = metrics.get('f1')
            if isinstance(acc, (int, float)):
                self.logger.info(f" - Accuracy: {acc:.4f}")
            if isinstance(f1, (int, float)):
                self.logger.info(f" - F1 Score: {f1:.4f}")

    @staticmethod
    def _prepare_input(window: np.ndarray, normalize: bool) -> torch.Tensor:
        """Optionally hip-center-normalize one window and return it as a float32 tensor.

        ``np.ascontiguousarray`` makes the array safe for ``torch.from_numpy``
        even when it has negative strides (e.g. after a flip augmentation);
        it copies only when the array is not already contiguous float32.
        """
        if normalize:
            window = normalize_skeleton(window, method='hip_center')
        return torch.from_numpy(np.ascontiguousarray(window, dtype=np.float32))

    def predict(
        self,
        window: np.ndarray,
        normalize: bool = True,
        debug: bool = False
    ) -> Tuple[int, float]:
        """
        Predict fall / non-fall for a single window.

        Args:
            window: ST-GCN input shaped (C, T, V, M), e.g. (3, 60, 17, 1).
            normalize: Apply hip-center normalization before inference.
            debug: Emit a debug log line with the prediction.

        Returns:
            prediction: 0 (Non-Fall) or 1 (Fall) - the argmax class.
            confidence: Softmax probability of the predicted class (0.0-1.0).
        """
        window_tensor = (
            self._prepare_input(window, normalize).unsqueeze(0).to(self.device)
        )  # (1, C, T, V, M)

        with torch.no_grad():
            outputs = self.model(window_tensor)
            probs = torch.softmax(outputs, dim=1)
            prediction = torch.argmax(outputs, dim=1).item()
            confidence = probs[0, prediction].item()

        if debug:
            self.logger.debug(f" ST-GCN prediction: {prediction} (conf={confidence:.3f})")

        return prediction, confidence

    def predict_batch(
        self,
        windows: list[np.ndarray],
        batch_size: int = 32,
        normalize: bool = True,
        debug: bool = False
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Predict fall / non-fall for many windows, running the model in chunks.

        Args:
            windows: Sequence of (C, T, V, M) ST-GCN input windows (a list, or
                an array stacked along a leading axis).
            batch_size: Maximum windows per forward pass (default 32); limits
                peak device memory.
            normalize: Apply hip-center normalization to each window.
            debug: Emit one debug log line per window.

        Returns:
            predictions: (N,) int array of 0 (Non-Fall) or 1 (Fall).
            confidences: (N,) array of the predicted class probability (0.0-1.0).
            fall_probs: (N,) array of the Fall (index 1) probability (0.0-1.0).

        Raises:
            ValueError: If ``batch_size`` is not positive (for non-empty input).
        """
        # len() rather than truthiness so a stacked ndarray is accepted too.
        if len(windows) == 0:
            return (
                np.empty(0, dtype=np.int64),
                np.empty(0, dtype=np.float32),
                np.empty(0, dtype=np.float32),
            )
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        all_predictions = []
        all_confidences = []
        all_fall_probs = []

        for chunk_start in range(0, len(windows), batch_size):
            chunk_windows = windows[chunk_start:chunk_start + batch_size]
            batch_tensor = torch.stack(
                [self._prepare_input(window, normalize) for window in chunk_windows]
            ).to(self.device)  # (B, C, T, V, M)

            with torch.no_grad():
                outputs = self.model(batch_tensor)
                probs = torch.softmax(outputs, dim=1)
                preds = torch.argmax(outputs, dim=1)
                # Probability of each sample's own predicted class.
                chosen = probs.gather(1, preds.unsqueeze(1)).squeeze(1)

            predictions = preds.cpu().numpy()
            confidences = chosen.cpu().numpy()
            fall_probs = probs[:, 1].cpu().numpy()

            all_predictions.append(predictions)
            all_confidences.append(confidences)
            all_fall_probs.append(fall_probs)

            if debug:
                for i, (pred, conf, fall_p) in enumerate(zip(predictions, confidences, fall_probs)):
                    global_idx = chunk_start + i
                    self.logger.debug(f" Batch[{global_idx}] ST-GCN: pred={pred}, conf={conf:.3f}, fall_prob={fall_p:.3f}")

        return (
            np.concatenate(all_predictions),
            np.concatenate(all_confidences),
            np.concatenate(all_fall_probs)
        )

    def is_fall(self, prediction: int, confidence: float) -> bool:
        """
        Decide whether a prediction counts as a fall.

        Args:
            prediction: Model prediction (0 or 1).
            confidence: Confidence of that prediction.

        Returns:
            True if the model predicted Fall (1) with confidence at or above
            ``self.fall_threshold``.
        """
        return prediction == 1 and confidence >= self.fall_threshold