| | from __future__ import annotations |
| |
|
| | import json |
| | import statistics |
| | from dataclasses import dataclass |
| | from datetime import datetime, timezone |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Tuple |
| | from uuid import uuid4 |
| |
|
| | from utils.patchad_filter import PatchADFilter |
| |
|
| | |
# Optional dependency: make the package root importable, then try to load the
# main time-series detector. If the module is missing we degrade gracefully
# and the pipeline falls back to PatchAD-only scoring.
try:
    import sys

    _pkg_root = str(Path(__file__).parent.parent)
    if _pkg_root not in sys.path:
        sys.path.insert(0, _pkg_root)
    from wearable_anomaly_detector import WearableAnomalyDetector
except ImportError:
    WearableAnomalyDetector = None
| |
|
| |
|
@dataclass
class ValidationResult:
    """Outcome of payload validation.

    `errors` are hard failures that block case building; `warnings` are
    advisory only and never flip `passed`.
    """

    passed: bool
    errors: List[str]
    warnings: List[str]

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the result to a plain dict for JSON output."""
        return dict(passed=self.passed, errors=self.errors, warnings=self.warnings)
| |
|
| |
|
class CaseBuilder:
    """Build a case JSON plus a Markdown LLM-input description from a normalized payload.

    Optional confirmation by the main time-series model (auto/manual):
      - enabled=true, mode=auto: automatically call the main model to confirm risk
      - enabled=false: use only the PatchAD score; the main model is never called

    Two operating modes:
      - Mode A: the platform runs PatchAD itself and passes all data in directly
        (window_data + history_windows)
      - Mode B: the platform has no PatchAD; window_data is fetched from the
        PrecheckServer cache via event_id
    """

    def __init__(
        self,
        config_path: Optional[Path] = None,
        detector: Optional[Any] = None,
        precheck_server: Optional[Any] = None
    ) -> None:
        """Initialize the CaseBuilder.

        Args:
            config_path: path to the config file; None falls back to the default config.
            detector: optional main time-series model detector instance
                (if the caller already loaded one).
            precheck_server: optional PrecheckServer instance, used in Mode B to
                pull window_data from its cache.
        """
        self.patchad = PatchADFilter()
        self.detector = detector
        self.precheck_server = precheck_server
        self.config = self._load_config(config_path)

        # Lazily auto-load the main model only when confirmation is enabled,
        # no detector was injected, and the optional import at module top succeeded.
        if self.config.get("main_model_confirmation", {}).get("enabled", False):
            if self.detector is None and WearableAnomalyDetector is not None:
                model_dir = self.config["main_model_confirmation"].get("model_dir")
                if model_dir:
                    # Resolve relative model paths against the package root.
                    base_dir = Path(__file__).parent.parent
                    model_path = base_dir / model_dir if not Path(model_dir).is_absolute() else Path(model_dir)
                    device = self.config["main_model_confirmation"].get("device", "cpu")
                    threshold = self.config["main_model_confirmation"].get("threshold")

                    try:
                        self.detector = WearableAnomalyDetector(
                            model_dir=model_path,
                            device=device,
                            threshold=threshold
                        )
                        print(f"✅ 已自动加载主时序模型: {model_path}")
                    except Exception as e:
                        # Best effort: any load failure falls back to PatchAD-only scoring.
                        print(f"⚠️ 自动加载主时序模型失败: {e},将只使用 PatchAD 分数")
                        self.detector = None

    def _load_config(self, config_path: Optional[Path]) -> Dict[str, Any]:
        """Load the JSON config file; fall back to built-in defaults on any failure.

        Args:
            config_path: explicit config path, or None to use
                ``<package root>/configs/case_builder_config.json``.

        Returns:
            The parsed config dict, or the hard-coded default config when the
            file is missing or unreadable.
        """
        if config_path is None:
            config_path = Path(__file__).parent.parent / "configs" / "case_builder_config.json"

        if config_path.exists():
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except Exception as e:
                # Fall through to the default config below.
                print(f"⚠️ 加载 CaseBuilder 配置失败: {e},使用默认配置")

        # Default: main-model confirmation disabled; PatchAD-only scoring.
        return {
            "main_model_confirmation": {
                "enabled": False,
                "mode": "auto",
                "model_dir": "checkpoints/phase2/exp_factor_balanced",
                "device": "cpu",
                "threshold": None,
                "min_duration_days": 3
            }
        }

    def validate_payload(self, payload: Dict[str, Any]) -> ValidationResult:
        """Validate (and, in Mode B, hydrate) an incoming payload.

        NOTE: this method mutates ``payload`` in place when Mode B applies —
        window_data / user_id / metadata may be backfilled from the
        PrecheckServer cache, and the cached event is consumed (popped).

        Returns:
            A ValidationResult; ``passed`` is True only when no errors were found.
        """
        errors: List[str] = []
        warnings: List[str] = []

        # NOTE(review): user_id is checked BEFORE the Mode-B backfill below, so
        # a payload that relies on the cached user_id still gets this error —
        # confirm whether that ordering is intentional.
        if not payload.get("user_id"):
            errors.append("缺少 user_id")

        # Mode B: no window_data supplied — try to pull it from the precheck cache.
        event_id = payload.get("event_id")
        window = payload.get("window_data")

        if event_id and self.precheck_server and not window:
            if self.precheck_server.has_pending(event_id):
                # pop_event consumes the cached entry; a second validation with
                # the same event_id will not find it again.
                pending = self.precheck_server.pop_event(event_id)
                window = pending.get("window_data")
                payload["window_data"] = window
                if pending.get("user_id") and not payload.get("user_id"):
                    payload["user_id"] = pending.get("user_id")

                # Backfill PatchAD score/threshold from the cache without
                # overwriting values the caller already provided.
                if not payload.get("metadata", {}).get("patchad_score"):
                    metadata = payload.get("metadata", {})
                    if "patchad_score" not in metadata:
                        metadata["patchad_score"] = pending.get("metadata", {}).get("patchad_score", 0.0)
                    if "threshold" not in metadata:
                        metadata["threshold"] = pending.get("metadata", {}).get("threshold", 0.0)
                    payload["metadata"] = metadata

        # A usable window needs at least 12 data points.
        if not window or len(window) < 12:
            errors.append("window_data 不足 12 条(模式A需直接传入,模式B需通过 event_id 从缓存获取)")

        profile = payload.get("user_profile")
        if not profile:
            errors.append("缺少 user_profile")

        history = payload.get("history_windows", [])
        if not history:
            errors.append("缺少 history_windows(平台需自己合成历史窗口数据)")

        metadata = payload.get("metadata", {})
        if not metadata.get("detector"):
            warnings.append("metadata.detector 未提供,将标记为 unknown")

        return ValidationResult(passed=not errors, errors=errors, warnings=warnings)

    @staticmethod
    def _infer_date(window: List[Dict[str, Any]]) -> str:
        """Return the date (YYYY-MM-DD) of the first timestamped point in the window.

        Falls back to today's local date when no point carries a timestamp.
        Assumes timestamps are ISO-8601-like so the first 10 chars are the date
        — TODO confirm against upstream producers.
        """
        for point in window:
            ts = point.get("timestamp")
            if ts:
                return ts[:10]
        return datetime.now().strftime("%Y-%m-%d")

    def _normalize_history(self, history_windows: List[Any]) -> List[Dict[str, Any]]:
        """Normalize history entries to ``{"date": ..., "window": ...}`` dicts.

        Accepts either already-wrapped dicts (with a "window" key, optional
        "date") or bare windows (lists of points); dates are inferred from
        timestamps when absent.
        """
        normalized = []
        for entry in history_windows:
            if isinstance(entry, dict) and "window" in entry:
                window = entry["window"]
                date = entry.get("date") or self._infer_date(window)
            else:
                window = entry
                date = self._infer_date(window)
            normalized.append({"date": date, "window": window})
        return normalized

    @staticmethod
    def _avg_feature(window: List[Dict[str, Any]], feature_name: str) -> float:
        """Return the mean of a numeric feature across the window's points.

        Non-numeric / missing values are skipped; an empty sample yields 0.0.
        """
        values = []
        for point in window:
            features = point.get("features", {})
            value = features.get(feature_name)
            if isinstance(value, (int, float)):
                values.append(float(value))
        return statistics.fmean(values) if values else 0.0

    @staticmethod
    def _extract_steps(window: List[Dict[str, Any]]) -> float:
        """Sum step counts over the window, taking the first matching key per point.

        Key priority per point: "steps" > "step_count" > "average_steps".
        """
        steps = []
        for point in window:
            features = point.get("features", {})
            for key in ("steps", "step_count", "average_steps"):
                if key in features and isinstance(features[key], (int, float)):
                    steps.append(float(features[key]))
                    break
        return sum(steps)

    def _build_daily_results(self, normalized_history: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Score each history window with PatchAD and summarize per-day vitals.

        Returns one dict per day with date, mean HRV RMSSD, mean HR, and the
        PatchAD anomaly score.
        """
        daily_results = []
        for item in normalized_history:
            window = item["window"]
            score = self.patchad.score_window(window)
            daily_results.append(
                {
                    "date": item["date"],
                    "hrv_rmssd": round(self._avg_feature(window, "hrv_rmssd"), 2),
                    "hr": round(self._avg_feature(window, "hr"), 2),
                    "anomaly_score": round(score, 4),
                }
            )
        return daily_results

    def _compute_baseline(self, normalized_history: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Derive a personal HRV baseline from "baseline_hrv_mean" feature values.

        Reliability is "high" with >= 30 samples, else "medium". The group mean
        is approximated as 90% of the personal mean (no real cohort data here).
        """
        hrv_values = []
        for item in normalized_history:
            for point in item["window"]:
                val = point.get("features", {}).get("baseline_hrv_mean")
                if isinstance(val, (int, float)):
                    hrv_values.append(float(val))
        personal_mean = statistics.fmean(hrv_values) if hrv_values else 0.0
        # Population std (pstdev) over the collected samples; 0.0 when < 2 samples.
        personal_std = statistics.pstdev(hrv_values) if len(hrv_values) > 1 else 0.0
        return {
            "personal_mean": round(personal_mean, 2),
            "personal_std": round(personal_std, 2),
            "group_mean": round(personal_mean * 0.9, 2) if personal_mean else 0.0,
            "record_count": len(hrv_values),
            "baseline_type": "personal",
            "baseline_reliability": "high" if len(hrv_values) >= 30 else "medium",
        }

    @staticmethod
    def _related_indicators(normalized_history: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Summarize auxiliary indicators (activity; sleep/stress are placeholders).

        NOTE(review): the activity banding maps >=3000 steps to "中等", <500 to
        "低", and the 500–3000 band to "未知" — verify the middle band is meant
        to be "unknown" rather than "low".
        """
        steps = [CaseBuilder._extract_steps(item["window"]) for item in normalized_history]
        avg_steps = statistics.fmean(steps) if steps else 0.0
        activity_level = "中等" if avg_steps >= 3000 else ("低" if avg_steps < 500 else "未知")

        return {
            "activity_level": {
                "level": activity_level,
                # Trend compares only last vs first day's steps.
                "avg_steps": round(avg_steps, 1),
                "trend": "increasing" if len(steps) > 1 and steps[-1] > steps[0] else "stable",
            },
            "sleep_quality": {"available": False, "quality": "数据不可用"},
            "stress_indicators": {"level": "unknown"},
        }

    def _anomaly_pattern(self, daily_results: List[Dict[str, Any]], metadata: Dict[str, Any]) -> Dict[str, Any]:
        """Aggregate per-day anomaly scores into a pattern summary.

        Falls back to metadata["patchad_score"] when there are no daily results.
        NOTE(review): with fewer than 2 scores the trend stays at the default
        "worsening" — confirm that pessimistic default is intended.
        """
        scores = [item["anomaly_score"] for item in daily_results]
        if not scores:
            scores = [metadata.get("patchad_score", 0.0)]
        trend = "worsening"
        if len(scores) >= 2:
            # improving: last < first; stable: |delta| < 0.02; else worsening.
            trend = "improving" if scores[-1] < scores[0] else ("stable" if abs(scores[-1] - scores[0]) < 0.02 else "worsening")
        return {
            "type": metadata.get("anomaly_type", "continuous_anomaly"),
            "duration_days": len(daily_results),
            "trend": trend,
            "min_score": min(scores),
            "max_score": max(scores),
            "avg_score": round(statistics.fmean(scores), 4),
            "threshold": metadata.get("threshold", self.patchad.threshold),
        }

    def _format_llm_input(self, case: Dict[str, Any]) -> str:
        """Render the case dict as a Chinese Markdown report used as the LLM prompt.

        Sections: overview, score analysis, current-state table, history table,
        related indicators, and user background. All labels are user-facing
        strings and must not be altered.
        """
        lines = [
            "# 健康异常检测分析报告",
            "",
            "## 异常概览",
            f"**异常类型**:{case['anomaly_pattern']['type']} ",
            f"**持续时长**:{case['anomaly_pattern']['duration_days']}天 ",
            f"**异常趋势**:{case['anomaly_pattern']['trend']} ",
            "",
            "## 异常评分分析",
            f"- **异常分数范围**:{case['anomaly_pattern']['min_score']:.4f} - {case['anomaly_pattern']['max_score']:.4f}",
            f"- **平均异常分数**:{case['anomaly_pattern']['avg_score']:.4f}",
            f"- **检测阈值**:{case['anomaly_pattern']['threshold']:.4f}",
            "",
            "## 当前生理状态评估",
            "| 指标 | 当前值 | 基线值 | 偏离基线 |",
            "|------|--------|--------|----------|",
        ]
        baseline = case["baseline_info"]
        # "Current" HRV is the most recent day's mean; 0.0 when no history.
        current_hrv = case["daily_results"][-1]["hrv_rmssd"] if case["daily_results"] else 0.0
        deviation_pct = 0.0
        if baseline["personal_mean"]:
            deviation_pct = (current_hrv - baseline["personal_mean"]) / baseline["personal_mean"] * 100
        lines.append(f"| HRV RMSSD | {current_hrv:.2f} ms | {baseline['personal_mean']:.2f} ms | {deviation_pct:+.1f}% |")
        lines.append("")

        lines.append("## 历史趋势分析")
        lines.append("| 日期 | HRV (ms) | 心率 (bpm) | 异常分数 |")
        lines.append("|------|----------|------------|----------|")
        for record in case["daily_results"]:
            lines.append(
                f"| {record['date']} | {record['hrv_rmssd']:.2f} | {record['hr']:.2f} | {record['anomaly_score']:.4f} |"
            )

        lines.append("")
        lines.append("## 相关健康指标分析")
        rel = case["related_indicators"]["activity_level"]
        lines.append(f"- 活动水平:{rel['level']}(平均步数={rel['avg_steps']},趋势={rel['trend']})")
        lines.append("- 睡眠质量:数据不可用")
        lines.append("- 压力指标:暂无显著异常")

        lines.append("")
        lines.append("## 用户背景信息")
        profile = case["user_profile"]
        lines.append(f"- 年龄:{profile.get('estimated_age')}岁({profile.get('age_group')})")
        lines.append(f"- 性别:{profile.get('sex')}")
        lines.append(f"- 运动习惯:{profile.get('exercise')}")
        lines.append(f"- 咖啡:{profile.get('coffee')} / 饮酒:{profile.get('drinking')}")
        lines.append(f"- 生物节律:MEQ={profile.get('MEQ')} ({profile.get('MEQ_type')})")

        return "\n".join(lines)

    def build_case(self, payload: Dict[str, Any]) -> Dict[str, Any]:
        """Validate the payload, build the case dict, render the LLM input, and
        optionally confirm the risk with the main time-series model.

        Args:
            payload: normalized input; may be mutated by validate_payload (Mode B).

        Returns:
            Dict with "case", "llm_input", "validation", "risk_confirmed"
            (bool or None when confirmation is off/failed), and
            "main_model_result" (raw detector output or None).

        Raises:
            ValueError: when validation fails; the message joins all errors.
        """
        validation = self.validate_payload(payload)
        if not validation.passed:
            raise ValueError("; ".join(validation.errors))

        normalized_history = self._normalize_history(payload["history_windows"])
        daily_results = self._build_daily_results(normalized_history)
        baseline_info = self._compute_baseline(normalized_history)
        related_indicators = self._related_indicators(normalized_history)
        anomaly_pattern = self._anomaly_pattern(daily_results, payload.get("metadata", {}))

        case = {
            # Reuse the platform event_id when present; otherwise mint a short id.
            "case_id": payload.get("event_id") or f"case_{uuid4().hex[:8]}",
            "user_id": payload["user_id"],
            "generated_at": datetime.now(timezone.utc).isoformat(),
            "anomaly_pattern": anomaly_pattern,
            "baseline_info": baseline_info,
            "related_indicators": related_indicators,
            "daily_results": daily_results,
            "user_profile": payload["user_profile"],
            "metadata": payload.get("metadata", {}),
        }

        llm_input = self._format_llm_input(case)

        # Optional main-model confirmation; None means "not attempted / failed".
        risk_confirmed = None
        main_model_result = None

        if self.config.get("main_model_confirmation", {}).get("enabled", False):
            if self.detector:
                try:
                    # Feed the raw per-day windows to the detector.
                    daily_data_points = []
                    for item in normalized_history:
                        daily_data_points.append(item["window"])

                    if daily_data_points:
                        pattern_result = self.detector.detect_pattern(
                            daily_data_points,
                            days=len(normalized_history),
                            min_duration_days=self.config["main_model_confirmation"].get("min_duration_days", 3)
                        )

                        risk_confirmed = pattern_result.get('anomaly_pattern', {}).get('has_pattern', False)
                        main_model_result = pattern_result
                except Exception as e:
                    # Best effort: confirmation failure degrades to PatchAD-only.
                    print(f"⚠️ 主时序模型确认失败: {e},将只使用 PatchAD 分数")
                    risk_confirmed = None
                    main_model_result = None

        return {
            "case": case,
            "llm_input": llm_input,
            "validation": validation.to_dict(),
            "risk_confirmed": risk_confirmed,
            "main_model_result": main_model_result,
        }
| |
|
| |
|
def save_case_to_file(case_bundle: Dict[str, Any], output_dir: Path) -> Path:
    """Persist a case bundle as pretty-printed UTF-8 JSON named ``<case_id>.json``.

    Creates ``output_dir`` (including parents) if it does not exist and returns
    the path of the written file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    target = output_dir / f"{case_bundle['case']['case_id']}.json"
    target.write_text(
        json.dumps(case_bundle, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    return target
| |
|
| |
|