# Wearable_TimeSeries_Health_Monitor / wearable_anomaly_detector.py
# kokemn: Duplicate from oscarzhang/Wearable_TimeSeries_Health_Monitor (commit bae9e74)
"""
Wearable健康异常检测模型 - 标准化封装
提供简单的API接口,用于实时异常检测
"""
import torch
import numpy as np
import json
import pickle
from pathlib import Path
from typing import Dict, List, Optional, Union
from datetime import datetime
import pandas as pd
# 添加项目根目录到路径
import sys
sys.path.insert(0, str(Path(__file__).parent.parent))
from models.phased_lstm_tft import PhasedLSTM_TFT, PhasedLSTM_TFT_WithEnhancedAnomalyDetection
from feature_calculator import FeatureCalculator
class WearableAnomalyDetector:
    """
    Wearable health anomaly detector.

    Loads a Phase1 ``PhasedLSTM_TFT`` base model, wraps it in a Phase2
    ``PhasedLSTM_TFT_WithEnhancedAnomalyDetection`` whose weights come from
    ``<model_dir>/best_model.pt``, and scores windows of data points against
    an anomaly threshold.

    Usage:
        detector = WearableAnomalyDetector(model_dir="checkpoints/phase2/exp_factor_balanced")
        result = detector.predict(data_points)
    """

    # Directory containing this module; data artifacts are resolved relative to it.
    _BASE_DIR = Path(__file__).parent
    # Threshold used when neither the caller nor the loaded config provides one.
    DEFAULT_THRESHOLD = 0.53
    # Length used for placeholder feature series only when the window carries
    # no delta_t values to infer the real length from.
    DEFAULT_SEQ_LEN = 12
    # Number of known-future steps used when the window has no target timestamps.
    DEFAULT_PRED_LEN = 6

    def __init__(
        self,
        model_dir: Union[str, Path],
        device: Optional[str] = None,
        threshold: Optional[float] = None
    ):
        """
        Initialize the detector.

        Args:
            model_dir: Phase2 model directory (holds best_model.pt and optionally
                config.json / summary.json). The Phase1 checkpoint is expected at
                ``model_dir/../../phase1/best_model.pt``.
            device: 'cuda' or 'cpu'; when None, CUDA is used if available.
            threshold: Anomaly threshold; when None it is read from the config,
                falling back to ``DEFAULT_THRESHOLD``.

        Raises:
            FileNotFoundError: If the Phase1 or Phase2 checkpoint is missing.
        """
        self.model_dir = Path(model_dir)
        self.device = torch.device(device or ('cuda' if torch.cuda.is_available() else 'cpu'))

        # Configuration: config.json, else summary.json, else {}.
        self.config = self._load_config()

        # Threshold precedence: explicit argument > config value > class default.
        if threshold is not None:
            self.threshold = float(threshold)
        else:
            config_threshold = self.config.get('threshold')
            if config_threshold is not None:
                self.threshold = float(config_threshold)
            else:
                self.threshold = self.DEFAULT_THRESHOLD
                print(f" ⚠️ 未找到阈值配置,使用默认值: {self.threshold:.4f}")

        # NOTE: the model is built before self.factor_metadata is assigned below,
        # so during construction _load_factor_config() resolves the factor config
        # from the stage3 window-info file, not from the feature calculator.
        self.model = self._load_model()
        self.model.eval()

        # Normalization parameters (kept for backward compatibility).
        self.norm_params = self._load_norm_params()

        # Config-driven feature computation.
        data_root = self._BASE_DIR / 'processed_data'
        self.feature_calculator = FeatureCalculator(
            config_path=self.config.get('feature_config_path'),
            norm_params_path=data_root / 'stage3' / 'norm_params.json',
            static_features_path=data_root / 'stage2' / 'static_features.csv',
            storage_dir=Path(self.config.get('storage_dir', self._BASE_DIR / 'data_storage'))
        )
        self.features = self.feature_calculator.get_enabled_feature_names()
        self.static_feature_names = [cfg["name"] for cfg in self.feature_calculator.static_feature_defs]
        # At least 1 so the known-future tensor never has a zero-width last axis.
        self.known_future_dim = max(len(self.feature_calculator.known_future_defs), 1)
        self.factor_metadata = {
            'enabled': self.feature_calculator.factor_enabled,
            'factor_names': self.feature_calculator.factor_names,
            'factor_dim': self.feature_calculator.factor_dim
        }

        print(f"✅ 模型加载成功")
        print(f" - 设备: {self.device}")
        print(f" - 阈值: {self.threshold:.4f}")
        print(f" - 特征数: {len(self.features)}")

    def _load_config(self) -> Dict:
        """
        Load the model configuration.

        Tries ``config.json`` first. It then falls back to ``summary.json``, from
        which only 'best_threshold' is used. Returns an empty dict when neither
        file exists.
        """
        config_file = self.model_dir / 'config.json'
        if config_file.exists():
            with open(config_file, 'r', encoding='utf-8') as f:
                return json.load(f)

        summary_file = self.model_dir / 'summary.json'
        if summary_file.exists():
            with open(summary_file, 'r', encoding='utf-8') as f:
                summary = json.load(f)
            return {
                'threshold': summary.get('best_threshold'),
                'features': [],  # not available in summary.json; must come from elsewhere
            }

        # Nothing found: empty config, so callers fall back to defaults.
        print(f" ⚠️ 未找到配置文件,使用默认配置")
        return {}

    def _load_model(self):
        """
        Build the Phase2 model on top of the Phase1 base model and load its weights.

        Raises:
            FileNotFoundError: If either checkpoint file is missing.
        """
        # Phase1 base model lives two levels above the Phase2 experiment dir.
        phase1_model_path = self.model_dir.parent.parent / 'phase1' / 'best_model.pt'
        if not phase1_model_path.exists():
            raise FileNotFoundError(f"Phase1模型不存在: {phase1_model_path}")

        # SECURITY: weights_only=False performs full unpickling (needed for the
        # stored 'config' object); only load checkpoints from trusted sources.
        checkpoint_phase1 = torch.load(phase1_model_path, map_location=self.device, weights_only=False)
        phase1_config = checkpoint_phase1['config']
        base_model = PhasedLSTM_TFT(phase1_config)
        base_model.load_state_dict(checkpoint_phase1['model_state_dict'])
        base_model = base_model.to(self.device)

        factor_config = self._load_factor_config()

        # Phase2 architecture flags must match the ones used at training time,
        # otherwise load_state_dict below will fail.
        model = PhasedLSTM_TFT_WithEnhancedAnomalyDetection(
            base_model,
            num_anomaly_types=4,
            use_enhanced_head=True,
            use_multi_source_heads=False,
            use_domain_adversarial=False,
            factor_config=factor_config
        )
        model = model.to(self.device)

        phase2_model_path = self.model_dir / 'best_model.pt'
        if not phase2_model_path.exists():
            raise FileNotFoundError(f"Phase2模型不存在: {phase2_model_path}")

        # SECURITY: full unpickling, same caveat as above.
        checkpoint_phase2 = torch.load(phase2_model_path, map_location=self.device, weights_only=False)
        model.load_state_dict(checkpoint_phase2['model_state_dict'])
        return model

    @staticmethod
    def _factor_config_from_metadata(metadata: Dict) -> Dict:
        """Build the ``factor_config`` dict passed to the Phase2 model from factor metadata."""
        factor_names = metadata.get('factor_names', [])
        return {
            'num_factors': len(factor_names),
            'factor_dim': metadata.get('factor_dim', 0),
            'factor_names': factor_names,
            'min_weight': 0.2,
            'dropout': 0.1,
        }

    def _load_factor_config(self) -> Optional[Dict]:
        """
        Resolve the factor-feature config, or None if factor features are disabled.

        Sources, in order:
          1. ``self.factor_metadata`` when it is already set and enabled. It is
             not yet set while ``__init__`` is building the model.
          2. ``factor_features`` in ``processed_data/stage3/window_info_multi_scale.json``.
        """
        factor_metadata = getattr(self, 'factor_metadata', None)
        if factor_metadata and factor_metadata.get('enabled'):
            return self._factor_config_from_metadata(factor_metadata)

        window_info_file = self._BASE_DIR / 'processed_data' / 'stage3' / 'window_info_multi_scale.json'
        if window_info_file.exists():
            with open(window_info_file, 'r', encoding='utf-8') as f:
                window_info = json.load(f)
            file_metadata = window_info.get('factor_features', {})
            if file_metadata and file_metadata.get('enabled'):
                return self._factor_config_from_metadata(file_metadata)

        return None

    def _load_norm_params(self) -> Optional[Dict]:
        """Load ``processed_data/stage3/norm_params.json``, or return None if it is absent."""
        norm_file = self._BASE_DIR / 'processed_data' / 'stage3' / 'norm_params.json'
        if norm_file.exists():
            with open(norm_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        return None

    def predict(
        self,
        data_points: List[Dict],
        return_score: bool = True,
        return_details: bool = False
    ) -> Dict:
        """
        Score one window of data points.

        Args:
            data_points: Non-empty list of data-point dicts. Each carries
                'timestamp' (datetime or str), 'features' (feature-value dict) and
                optionally 'static_features'. The first point's 'deviceId' (or
                'user_id') selects the user. The window is assembled by
                ``FeatureCalculator.build_window``.
            return_score: Include 'anomaly_score' in the result.
            return_details: Include a 'details' dict in the result.

        Returns:
            {
                'is_anomaly': bool,        # anomaly_score >= threshold
                'threshold': float,        # threshold that was applied
                'anomaly_score': float,    # only if return_score (presumably in [0, 1])
                'details': dict,           # only if return_details
            }
        """
        first_point = data_points[0]
        user_id = first_point.get('deviceId') or first_point.get('user_id')
        window = self.feature_calculator.build_window(data_points, user_id=user_id)

        model_input = self._prepare_model_input(window)

        with torch.no_grad():
            # forward() takes the first four inputs positionally, in this order.
            outputs = self.model(
                model_input['x'],
                model_input['delta_t'],
                model_input['static_features'],
                model_input['known_future_features'],
                mask=model_input.get('mask'),
                return_contrastive_features=model_input.get('return_contrastive_features', False),
                source=None,
                return_domain_features=False,
                factor_features=model_input.get('factor_features')
            )
            anomaly_score = outputs['anomaly_score'].cpu().item()

        is_anomaly = anomaly_score >= self.threshold

        result = {
            'is_anomaly': bool(is_anomaly),
            'threshold': float(self.threshold),
        }
        if return_score:
            result['anomaly_score'] = float(anomaly_score)
        if return_details:
            result['details'] = {
                'window_size': len(data_points),
                'model_output': float(anomaly_score),
                # Distance from the decision boundary.
                'prediction_confidence': abs(anomaly_score - self.threshold),
            }
        return result

    def _prepare_model_input(self, window: Dict) -> Dict:
        """
        Convert a window dict (from ``build_window``) into batched model tensors.

        Reads 'input_features', 'input_delta_t', 'static_features', and the
        optional 'target_timestamp', 'known_future_features' and 'factor_features'.
        The sequence length T is taken from 'input_delta_t'. Each entry in
        'input_features' is assumed to have length T; verify against FeatureCalculator.

        Returns:
            Dict with 'x' [1, T, num_features], 'delta_t' [1, T, 1],
            'static_features' [1, num_static], 'known_future_features'
            [1, pred_len, known_future_dim], 'mask' (ones, same shape as 'x'),
            'factor_features' [1, num_factors, factor_dim] or None, plus
            fixed flags.
        """
        delta_t_values = window['input_delta_t']
        # Use the actual window length instead of assuming 12 steps.
        seq_len = len(delta_t_values) or self.DEFAULT_SEQ_LEN

        # Dynamic features; a missing feature becomes an all-zero series.
        input_features_list = [
            window['input_features'].get(feat, [0.0] * seq_len)
            for feat in self.features
        ]
        input_features = torch.tensor(
            np.stack(input_features_list, axis=1),
            dtype=torch.float32
        ).unsqueeze(0).to(self.device)  # [1, T, num_features]

        delta_t = torch.tensor(
            delta_t_values,
            dtype=torch.float32
        ).unsqueeze(-1).unsqueeze(0).to(self.device)  # [1, T, 1]

        # Static features: configured order, else sorted keys of the window.
        static_keys = self.static_feature_names or sorted(window['static_features'].keys())
        static_feature_values = [float(window['static_features'].get(key, 0.0)) for key in static_keys]
        if not static_feature_values:
            # The model still expects at least one static slot.
            static_feature_values = [0.0]
        static_features = torch.tensor(
            static_feature_values,
            dtype=torch.float32
        ).unsqueeze(0).to(self.device)  # [1, num_static]

        # Known-future features.
        pred_len = len(window.get('target_timestamp', [])) or self.DEFAULT_PRED_LEN
        known_future = torch.zeros(1, pred_len, self.known_future_dim, dtype=torch.float32).to(self.device)
        if 'known_future_features' in window:
            kf = window['known_future_features']
            for idx, cfg in enumerate(self.feature_calculator.known_future_defs):
                name = cfg['name']
                if name not in kf:
                    continue
                series = kf[name][:pred_len]
                # Scale calendar features: hour 0-23 and weekday 0-6 map to [0, 1].
                if name == 'hour_of_day':
                    values = torch.tensor([float(h) / 23.0 for h in series], dtype=torch.float32)
                elif name == 'day_of_week':
                    values = torch.tensor([float(d) / 6.0 for d in series], dtype=torch.float32)
                else:
                    values = torch.tensor([float(v) for v in series], dtype=torch.float32)
                known_future[0, :len(series), idx] = values

        # Input mask: every position is treated as valid. It always matches x's shape.
        input_mask = torch.ones_like(input_features)

        # Factor features: one fixed-width vector per configured factor name.
        factor_features = None
        if window.get('factor_features'):
            factor_names = self.factor_metadata.get('factor_names', [])
            factor_dim = self.factor_metadata.get('factor_dim', 4)
            factor_vectors = []
            for name in factor_names:
                vec = list(window['factor_features'].get(name, [0.0] * factor_dim))[:factor_dim]
                # Zero-pad short vectors so the tensor is rectangular.
                vec += [0.0] * (factor_dim - len(vec))
                factor_vectors.append(vec)
            if factor_vectors:
                factor_features = torch.tensor(
                    factor_vectors,
                    dtype=torch.float32
                ).unsqueeze(0).to(self.device)  # [1, num_factors, factor_dim]

        return {
            'x': input_features,
            'delta_t': delta_t,
            'static_features': static_features,
            'known_future_features': known_future,
            'mask': input_mask,
            'factor_features': factor_features,
            'return_contrastive_features': False,
            'source': None,
            'return_domain_features': False,
        }

    def batch_predict(
        self,
        windows: List[List[Dict]],
        return_scores: bool = True
    ) -> List[Dict]:
        """
        Score several windows by calling ``predict`` on each one in turn.

        Args:
            windows: List of windows; each window is a list of data points.
            return_scores: Include 'anomaly_score' in each result.

        Returns:
            One ``predict`` result dict per window, in input order.
        """
        return [self.predict(window_data, return_score=return_scores) for window_data in windows]

    def update_threshold(self, threshold: float):
        """Replace the anomaly threshold, stored as a float just as ``__init__`` stores it."""
        self.threshold = float(threshold)
        print(f"✅ 阈值已更新为: {self.threshold:.4f}")
def load_detector(model_dir: Union[str, Path], **kwargs) -> WearableAnomalyDetector:
    """
    Convenience factory for :class:`WearableAnomalyDetector`.

    Args:
        model_dir: Path to the Phase2 model directory.
        **kwargs: Passed through unchanged to the constructor (e.g. ``device``,
            ``threshold``).

    Returns:
        A newly constructed ``WearableAnomalyDetector``.
    """
    detector = WearableAnomalyDetector(model_dir, **kwargs)
    return detector
if __name__ == '__main__':
    # Usage example: load the bundled checkpoint and score one synthetic window.
    banner = "=" * 80
    print(banner)
    print("Wearable健康异常检测器 - 使用示例")
    print(banner)

    demo_model_dir = Path(__file__).parent / 'checkpoints' / 'phase2' / 'exp_factor_balanced'
    detector = load_detector(demo_model_dir)

    # Synthetic data points; a real deployment would read these from a live stream.
    print("\n模拟数据点...")
    base_time = datetime.now()

    # Use the first deviceId from the static-feature table when it is non-empty,
    # so the full static features can be filled in; otherwise pass None.
    static_dict = detector.feature_calculator.static_features_dict
    example_device_id = None
    if static_dict:
        example_device_id = next(iter(static_dict))
        print(f" 使用示例用户ID: {example_device_id}")

    # 12 points at 5-minute spacing within the current hour. Only two features are
    # supplied here for brevity; a real window needs every enabled feature.
    # 'static_features' is left empty so values come from the static-feature table.
    data_points = [
        {
            'timestamp': base_time.replace(minute=step * 5),
            'deviceId': example_device_id,
            'features': {
                'hr': 70.0 + np.random.randn() * 5,
                'hrv_rmssd': 30.0 + np.random.randn() * 3,
            },
            'static_features': {},
        }
        for step in range(12)
    ]

    prediction = detector.predict(data_points, return_score=True, return_details=True)
    print(f"\n预测结果:")
    print(f" - 是否异常: {prediction['is_anomaly']}")
    print(f" - 异常分数: {prediction['anomaly_score']:.4f}")
    print(f" - 阈值: {prediction['threshold']:.4f}")
    if 'details' in prediction:
        print(f" - 详细信息: {prediction['details']}")