| """ |
| Wearable健康异常检测模型 - 标准化封装 |
| 提供简单的API接口,用于实时异常检测 |
| """ |
|
|
| import torch |
| import numpy as np |
| import json |
| import pickle |
| from pathlib import Path |
| from typing import Dict, List, Optional, Union |
| from datetime import datetime |
| import pandas as pd |
|
|
| |
| import sys |
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| from models.phased_lstm_tft import PhasedLSTM_TFT, PhasedLSTM_TFT_WithEnhancedAnomalyDetection |
| from feature_calculator import FeatureCalculator |
|
|
|
|
class WearableAnomalyDetector:
    """
    Wearable health anomaly detector.

    Loads a phase-1 base model and a phase-2 anomaly-detection model from
    checkpoints, builds model inputs from raw data points through a
    ``FeatureCalculator``, and compares the model's anomaly score with a
    threshold.

    Example:
        detector = WearableAnomalyDetector(model_dir="checkpoints/phase2/exp_factor_balanced")
        result = detector.predict(data_points)
    """

    # Threshold used when neither the caller nor the config provides one.
    DEFAULT_THRESHOLD = 0.53
    # Input window length used only when it cannot be derived from the window.
    DEFAULT_WINDOW_LEN = 12
    # Prediction horizon used when the window carries no target timestamps.
    DEFAULT_PRED_LEN = 6

    def __init__(
        self,
        model_dir: Union[str, Path],
        device: Optional[str] = None,
        threshold: Optional[float] = None
    ):
        """
        Initialize the detector.

        Args:
            model_dir: Directory holding the phase-2 ``best_model.pt`` and,
                optionally, ``config.json`` / ``summary.json``.
            device: ``'cuda'`` or ``'cpu'``; auto-selected when ``None``.
            threshold: Anomaly threshold; read from the config when ``None``,
                falling back to ``DEFAULT_THRESHOLD``.

        Raises:
            FileNotFoundError: If a required checkpoint file is missing.
        """
        self.model_dir = Path(model_dir)
        self.device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))

        self.config = self._load_config()

        # Threshold priority: explicit argument > config value > default.
        if threshold is not None:
            self.threshold = float(threshold)
        else:
            config_threshold = self.config.get('threshold')
            if config_threshold is not None:
                self.threshold = float(config_threshold)
            else:
                self.threshold = self.DEFAULT_THRESHOLD
                print(f" ⚠️ 未找到阈值配置,使用默认值: {self.threshold:.4f}")

        # NOTE(review): the model is loaded before ``self.factor_metadata`` is
        # assigned below, so ``_load_factor_config`` always falls through to
        # the window-info file here. Confirm whether that order is intended.
        self.model = self._load_model()
        self.model.eval()

        self.norm_params = self._load_norm_params()

        module_dir = Path(__file__).parent
        self.feature_calculator = FeatureCalculator(
            config_path=self.config.get('feature_config_path'),
            norm_params_path=module_dir / 'processed_data' / 'stage3' / 'norm_params.json',
            static_features_path=module_dir / 'processed_data' / 'stage2' / 'static_features.csv',
            storage_dir=Path(self.config.get('storage_dir', module_dir / 'data_storage'))
        )
        self.features = self.feature_calculator.get_enabled_feature_names()
        self.static_feature_names = [cfg["name"] for cfg in self.feature_calculator.static_feature_defs]
        # The known-future tensor always has at least one channel.
        self.known_future_dim = max(len(self.feature_calculator.known_future_defs), 1)
        self.factor_metadata = {
            'enabled': self.feature_calculator.factor_enabled,
            'factor_names': self.feature_calculator.factor_names,
            'factor_dim': self.feature_calculator.factor_dim
        }

        print(f"✅ 模型加载成功")
        print(f" - 设备: {self.device}")
        print(f" - 阈值: {self.threshold:.4f}")
        print(f" - 特征数: {len(self.features)}")

    def _load_config(self) -> Dict:
        """
        Load the model configuration from ``model_dir``.

        Prefers ``config.json``. Otherwise derives a minimal config from
        ``summary.json`` (its ``best_threshold`` becomes ``threshold``).
        Returns an empty dict when neither file exists.
        """
        config_file = self.model_dir / 'config.json'
        if config_file.exists():
            # Explicit UTF-8 so the read does not depend on the platform locale.
            with open(config_file, 'r', encoding='utf-8') as f:
                return json.load(f)

        summary_file = self.model_dir / 'summary.json'
        if summary_file.exists():
            with open(summary_file, 'r', encoding='utf-8') as f:
                summary = json.load(f)
            return {
                'threshold': summary.get('best_threshold'),
                'features': [],
            }

        print(f" ⚠️ 未找到配置文件,使用默认配置")
        return {}

    def _load_model(self):
        """
        Build the phase-2 model on top of the phase-1 base model and load weights.

        The phase-1 checkpoint is expected at
        ``<model_dir>/../../phase1/best_model.pt`` and the phase-2 checkpoint
        at ``<model_dir>/best_model.pt``.

        Raises:
            FileNotFoundError: If either checkpoint file does not exist.
        """
        phase1_model_path = self.model_dir.parent.parent / 'phase1' / 'best_model.pt'
        if not phase1_model_path.exists():
            raise FileNotFoundError(f"Phase1模型不存在: {phase1_model_path}")

        # SECURITY: weights_only=False unpickles arbitrary objects. Only load
        # checkpoints from trusted sources.
        checkpoint_phase1 = torch.load(phase1_model_path, map_location=self.device, weights_only=False)
        phase1_config = checkpoint_phase1['config']

        base_model = PhasedLSTM_TFT(phase1_config)
        base_model.load_state_dict(checkpoint_phase1['model_state_dict'])
        base_model = base_model.to(self.device)

        factor_config = self._load_factor_config()

        model = PhasedLSTM_TFT_WithEnhancedAnomalyDetection(
            base_model,
            num_anomaly_types=4,
            use_enhanced_head=True,
            use_multi_source_heads=False,
            use_domain_adversarial=False,
            factor_config=factor_config
        )
        model = model.to(self.device)

        phase2_model_path = self.model_dir / 'best_model.pt'
        if not phase2_model_path.exists():
            raise FileNotFoundError(f"Phase2模型不存在: {phase2_model_path}")

        # SECURITY: same trust requirement as the phase-1 checkpoint above.
        checkpoint_phase2 = torch.load(phase2_model_path, map_location=self.device, weights_only=False)
        model.load_state_dict(checkpoint_phase2['model_state_dict'])

        return model

    @staticmethod
    def _build_factor_config(metadata: Optional[Dict]) -> Optional[Dict]:
        """Turn factor metadata into a model factor config, or ``None`` if disabled."""
        if not metadata or not metadata.get('enabled'):
            return None
        factor_names = metadata.get('factor_names', [])
        return {
            'num_factors': len(factor_names),
            'factor_dim': metadata.get('factor_dim', 0),
            'factor_names': factor_names,
            'min_weight': 0.2,
            'dropout': 0.1,
        }

    def _load_factor_config(self) -> Optional[Dict]:
        """
        Load the factor-feature configuration for the model.

        Uses ``self.factor_metadata`` when it is already set and enabled.
        Otherwise reads the ``factor_features`` entry of
        ``processed_data/stage3/window_info_multi_scale.json`` next to this
        module. Returns ``None`` when no enabled factor metadata is found.
        """
        config = self._build_factor_config(getattr(self, 'factor_metadata', None))
        if config is not None:
            return config

        window_info_file = Path(__file__).parent / 'processed_data' / 'stage3' / 'window_info_multi_scale.json'
        if window_info_file.exists():
            with open(window_info_file, 'r', encoding='utf-8') as f:
                window_info = json.load(f)
            return self._build_factor_config(window_info.get('factor_features', {}))
        return None

    def _load_norm_params(self) -> Optional[Dict]:
        """Load normalization parameters from ``processed_data/stage3/norm_params.json``, or ``None`` if absent."""
        norm_file = Path(__file__).parent / 'processed_data' / 'stage3' / 'norm_params.json'
        if norm_file.exists():
            with open(norm_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return None

    def predict(
        self,
        data_points: List[Dict],
        return_score: bool = True,
        return_details: bool = False
    ) -> Dict:
        """
        Run anomaly detection on one window of data points.

        Args:
            data_points: Non-empty list of data-point dicts. The first entry's
                ``deviceId`` (or ``user_id``) identifies the user; the list is
                handed to ``FeatureCalculator.build_window``. Per the original
                documentation each point carries ``timestamp``, ``features``
                and optionally ``static_features``.
            return_score: Include ``anomaly_score`` in the result.
            return_details: Include a ``details`` dict in the result.

        Returns:
            A dict with ``is_anomaly`` (bool) and ``threshold`` (float), plus
            ``anomaly_score`` (float) and ``details`` (dict) when requested.

        Raises:
            ValueError: If ``data_points`` is empty.
        """
        if not data_points:
            raise ValueError("data_points must contain at least one data point")

        first_point = data_points[0]
        user_id = first_point.get('deviceId') or first_point.get('user_id')
        window = self.feature_calculator.build_window(data_points, user_id=user_id)

        model_input = self._prepare_model_input(window)

        with torch.no_grad():
            outputs = self.model(
                model_input['x'],
                model_input['delta_t'],
                model_input['static_features'],
                model_input['known_future_features'],
                mask=model_input.get('mask'),
                return_contrastive_features=model_input.get('return_contrastive_features', False),
                source=None,
                return_domain_features=False,
                factor_features=model_input.get('factor_features')
            )
            anomaly_score = outputs['anomaly_score'].cpu().item()

        is_anomaly = anomaly_score >= self.threshold

        result = {
            'is_anomaly': bool(is_anomaly),
            'threshold': float(self.threshold),
        }

        if return_score:
            result['anomaly_score'] = float(anomaly_score)

        if return_details:
            result['details'] = {
                'window_size': len(data_points),
                'model_output': float(anomaly_score),
                # Distance of the score from the decision threshold.
                'prediction_confidence': abs(anomaly_score - self.threshold),
            }

        return result

    def _prepare_model_input(self, window: Dict) -> Dict:
        """
        Convert a feature window into the batch-of-one tensors the model takes.

        Args:
            window: Dict with ``input_features`` (feature name -> series),
                ``input_delta_t`` and ``static_features``; optionally
                ``target_timestamp``, ``known_future_features`` and
                ``factor_features``.

        Returns:
            Dict with tensors ``x``, ``delta_t``, ``static_features``,
            ``known_future_features``, ``mask``, ``factor_features`` (or
            ``None``) and the fixed flags passed on to the model.
        """
        # Take the window length from the time deltas instead of assuming 12,
        # so zero-filled features and the mask always match the real window.
        delta_t_values = window['input_delta_t']
        seq_len = len(delta_t_values) or self.DEFAULT_WINDOW_LEN

        # Features missing from the window are filled with zeros.
        feature_columns = [
            window['input_features'].get(feat, [0.0] * seq_len)
            for feat in self.features
        ]
        input_features = torch.tensor(
            np.stack(feature_columns, axis=1),
            dtype=torch.float32
        ).unsqueeze(0).to(self.device)

        delta_t = torch.tensor(
            delta_t_values,
            dtype=torch.float32
        ).unsqueeze(-1).unsqueeze(0).to(self.device)

        # Static features follow the configured order; if none are configured,
        # use the window's keys in sorted order.
        static_keys = self.static_feature_names or sorted(window['static_features'].keys())
        static_feature_values = [
            float(window['static_features'].get(key, 0.0)) for key in static_keys
        ]
        if not static_feature_values:
            # Keep at least one static channel.
            static_feature_values = [0.0]

        static_features = torch.tensor(
            static_feature_values,
            dtype=torch.float32
        ).unsqueeze(0).to(self.device)

        pred_len = len(window.get('target_timestamp', [])) or self.DEFAULT_PRED_LEN

        known_future = torch.zeros(1, pred_len, self.known_future_dim, dtype=torch.float32).to(self.device)
        if 'known_future_features' in window:
            kf = window['known_future_features']
            for idx, cfg in enumerate(self.feature_calculator.known_future_defs):
                name = cfg['name']
                if name not in kf:
                    continue
                series = kf[name][:pred_len]
                # Calendar features are scaled by their maximum value.
                if name == 'hour_of_day':
                    scale = 23.0
                elif name == 'day_of_week':
                    scale = 6.0
                else:
                    scale = None
                if scale is None:
                    values = torch.tensor([float(v) for v in series], dtype=torch.float32)
                else:
                    values = torch.tensor([float(v) / scale for v in series], dtype=torch.float32)
                known_future[0, :len(series), idx] = values

        # Every input position is marked as observed.
        input_mask = torch.ones(1, seq_len, len(self.features), dtype=torch.float32).to(self.device)

        factor_features = None
        if window.get('factor_features'):
            factor_names = self.factor_metadata.get('factor_names', [])
            factor_dim = self.factor_metadata.get('factor_dim', 4)
            factor_vectors = [
                window['factor_features'].get(name, [0.0] * factor_dim)[:factor_dim]
                for name in factor_names
            ]
            if factor_vectors:
                factor_features = torch.tensor(
                    factor_vectors,
                    dtype=torch.float32
                ).unsqueeze(0).to(self.device)

        return {
            'x': input_features,
            'delta_t': delta_t,
            'static_features': static_features,
            'known_future_features': known_future,
            'mask': input_mask,
            'factor_features': factor_features,
            'return_contrastive_features': False,
            'source': None,
            'return_domain_features': False,
        }

    def batch_predict(
        self,
        windows: List[List[Dict]],
        return_scores: bool = True
    ) -> List[Dict]:
        """
        Run ``predict`` on each window in turn.

        Args:
            windows: List of windows, each a list of data-point dicts.
            return_scores: Include ``anomaly_score`` in every result.

        Returns:
            One result dict per window, in input order.
        """
        return [
            self.predict(window_data, return_score=return_scores)
            for window_data in windows
        ]

    def update_threshold(self, threshold: float):
        """Replace the anomaly threshold (coerced to ``float``, as in ``__init__``)."""
        self.threshold = float(threshold)
        print(f"✅ 阈值已更新为: {self.threshold:.4f}")
|
|
|
|
def load_detector(model_dir: Union[str, Path], **kwargs) -> WearableAnomalyDetector:
    """
    Convenience factory for a ``WearableAnomalyDetector``.

    Args:
        model_dir: Path to the model directory.
        **kwargs: Passed through to the constructor (e.g. ``device``,
            ``threshold``).

    Returns:
        The constructed ``WearableAnomalyDetector`` instance.
    """
    detector = WearableAnomalyDetector(model_dir, **kwargs)
    return detector
|
|
|
|
if __name__ == '__main__':
    # Usage demo: load the detector, build one synthetic window of 12 points
    # and print the prediction.
    banner = "=" * 80
    print(banner)
    print("Wearable健康异常检测器 - 使用示例")
    print(banner)

    checkpoint_dir = Path(__file__).parent / 'checkpoints' / 'phase2' / 'exp_factor_balanced'
    detector = load_detector(checkpoint_dir)

    print("\n模拟数据点...")
    base_time = datetime.now()

    # Use the first device ID known to the feature calculator, if there is one.
    known_devices = detector.feature_calculator.static_features_dict
    example_device_id = None
    if known_devices:
        example_device_id = next(iter(known_devices.keys()))
        print(f" 使用示例用户ID: {example_device_id}")

    data_points = []
    for step in range(12):
        # Draw hr first, then hrv_rmssd, one normal sample each per point.
        hr_sample = 70.0 + np.random.randn() * 5
        hrv_sample = 30.0 + np.random.randn() * 3
        data_points.append({
            'timestamp': base_time.replace(minute=step * 5),
            'deviceId': example_device_id,
            'features': {
                'hr': hr_sample,
                'hrv_rmssd': hrv_sample,
            },
            'static_features': {},
        })

    result = detector.predict(data_points, return_score=True, return_details=True)

    print("\n预测结果:")
    print(f" - 是否异常: {result['is_anomaly']}")
    print(f" - 异常分数: {result['anomaly_score']:.4f}")
    print(f" - 阈值: {result['threshold']:.4f}")
    if 'details' in result:
        print(f" - 详细信息: {result['details']}")
|
|
|
|