Wearable_TimeSeries_Health_Monitor / simulate_patchad_case_pipeline.py
oscarzhang's picture
Upload folder using huggingface_hub
23bb099 verified
#!/usr/bin/env python3
"""
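simulate_patchad_case_pipeline.py

Demonstrates two platform integration paths using the existing demo data:
1. The platform ships its own PatchTrAD/PatchAD: call build_case directly.
2. The platform relies on the official precheck: call the precheck endpoint
   first, then submit the remaining data to generate the case.

The script prints:
- the precheck score and status
- an excerpt of the case structure returned by build_case
- the generated Markdown snippet
- the validation result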
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List
from utils.case_builder import CaseBuilder
from utils.patchad_filter import PatchADFilter, PrecheckServer
# Directory holding the per-user demo data, resolved relative to this script.
DATA_DIR = Path(__file__).parent / "data_storage" / "users"
# JSONL sample loaded when --data-file is not given on the command line.
DEFAULT_FILE = DATA_DIR / "demo_pattern.jsonl"
def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Parse a JSON Lines file into a list of decoded records.

    Lines that are empty or contain only whitespace are skipped; every other
    line is decoded with ``json.loads``.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Returns:
        The decoded objects, in the order they appear in the file.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]
def split_windows(
    records: List[Dict[str, Any]],
    days: int = 3,
    window_size: int = 12,
) -> List[Dict[str, Any]]:
    """Group records by calendar day and return up to ``days`` daily windows.

    The day key is the first 10 characters of each record's ``"timestamp"``
    string. Days are visited in ascending key order; a day contributes a
    window made of its first ``window_size`` records only if it has at least
    that many. If at least one window was found but fewer than ``days``, the
    last window is duplicated (with a ``+copyN`` date suffix) to pad the
    result, and a warning is printed.

    Records whose timestamp is missing, empty, or not a string are ignored:
    they cannot be assigned to a day and would otherwise be grouped under an
    empty date key.

    Args:
        records: Raw sample records, each a dict with a ``"timestamp"`` key.
        days: Number of daily windows wanted.
        window_size: Records per window (expected to be positive).

    Returns:
        A list of ``{"date": str, "window": list}`` dicts; empty when no day
        has enough records.
    """
    by_date: Dict[str, List[Dict[str, Any]]] = {}
    for rec in records:
        ts = rec.get("timestamp")
        if not isinstance(ts, str) or not ts:
            # No usable day key; skip instead of creating a "" date bucket.
            continue
        by_date.setdefault(ts[:10], []).append(rec)

    windows: List[Dict[str, Any]] = []
    for date in sorted(by_date):
        day_records = by_date[date][:window_size]
        if len(day_records) >= window_size:
            windows.append({"date": date, "window": day_records})
        if len(windows) >= days:
            break

    if 0 < len(windows) < days:
        print(f"⚠️ 历史窗口不足 {days} 天,自动复制最后一天填充示例。")
        last = windows[-1]
        for idx in range(days - len(windows)):
            windows.append(
                {
                    "date": f"{last['date']}+copy{idx+1}",
                    # Copy the list so padded windows do not alias the
                    # original window object.
                    "window": list(last["window"]),
                }
            )
    return windows
def sample_profile() -> Dict[str, Any]:
    """Build the fixed demo user profile sent along with each case request.

    Returns:
        A freshly created dict of hard-coded profile fields (age group,
        estimated age, sex, lifestyle habits, and MEQ score/type).
    """
    profile: Dict[str, Any] = dict(
        age_group="30-35岁",
        estimated_age=33,
        sex="男性",
        exercise="每周5次以上",
        coffee="偶尔",
        smoking="不吸烟",
        drinking="经常饮酒",
        MEQ=62.0,
        MEQ_type="晨型",
    )
    return profile
def save_demo_case(
    save_dir: Path,
    sample_name: str,
    mode: str,
    bundle: Dict[str, Any],
    precheck: Dict[str, Any],
) -> None:
    """Write one demo case to disk and register it in ``manifest.json``.

    The case is stored as ``<sample_name>_<mode>.json`` inside ``save_dir``
    (created if missing). The manifest in the same directory is a JSON list;
    any existing entry with the same mode and sample is replaced by a new
    entry appended at the end.

    Args:
        save_dir: Target directory for the case file and the manifest.
        sample_name: Sample identifier used in the file name and manifest.
        mode: Integration mode; ``"platform"`` selects the mode-A title,
            anything else the mode-B title.
        bundle: Case data stored under ``"case_bundle"``.
        precheck: Precheck information stored under ``"precheck"``.
    """
    save_dir.mkdir(parents=True, exist_ok=True)

    case_path = save_dir / f"{sample_name}_{mode}.json"
    with case_path.open("w", encoding="utf-8") as out:
        json.dump(
            {
                "mode": mode,
                "sample": sample_name,
                "precheck": precheck,
                "case_bundle": bundle,
            },
            out,
            ensure_ascii=False,
            indent=2,
        )

    index_path = save_dir / "manifest.json"
    if index_path.exists():
        with index_path.open("r", encoding="utf-8") as src:
            previous = json.load(src)
    else:
        previous = []

    # Keep every entry except the one this call is about to replace.
    entries = []
    for entry in previous:
        if entry.get("mode") == mode and entry.get("sample") == sample_name:
            continue
        entries.append(entry)

    if mode == "platform":
        label = "模式A:平台自带"
    else:
        label = "模式B:官方预筛"
    entries.append(
        {
            "sample": sample_name,
            "mode": mode,
            "title": f"{sample_name} · {label}",
            "file": case_path.name,
        }
    )

    with index_path.open("w", encoding="utf-8") as out:
        json.dump(entries, out, ensure_ascii=False, indent=2)
def run_mode_platform(
    records: List[Dict[str, Any]],
    sample_name: str,
    save_dir: Path | None,
) -> Dict[str, Any] | None:
    """Run mode A: score the window locally, then call ``build_case``.

    The records are split into three daily windows; the last one is scored
    with a ``PatchADFilter`` and, if its score reaches the filter's
    threshold, a case is built with all windows passed as history.

    Args:
        records: Sample records loaded from the JSONL file.
        sample_name: Name used for the saved example file.
        save_dir: Directory to save the example into, or ``None`` to skip
            saving.

    Returns:
        The value returned by ``CaseBuilder.build_case``, or ``None`` when
        the window is not flagged as abnormal.

    Raises:
        RuntimeError: If fewer than two windows can be produced.
    """
    print("\n=== 模式A:平台自带 PatchTrAD → 直接调用 build_case ===")

    windows = split_windows(records, days=3)
    if len(windows) < 2:
        raise RuntimeError("样例数据不足以生成历史窗口")

    # The most recent window is the one being screened.
    current_window = windows[-1]["window"]
    detector = PatchADFilter(score_threshold=0.18)
    score = detector.score_window(current_window)
    if score >= detector.threshold:
        status = "abnormal"
    else:
        status = "normal"
    print(f"预筛得分: {score:.4f} (threshold={detector.threshold}) -> {status}")

    if status != "abnormal":
        print("窗口未触发异常,示例终止。")
        return None

    request = {
        "event_id": "platform-demo-evt",
        "user_id": current_window[0].get("deviceId", "demo_user"),
        "window_data": current_window,
        "user_profile": sample_profile(),
        "history_windows": windows,
        "metadata": {"detector": "platform_patchad", "score": score},
    }
    result = CaseBuilder().build_case(request)

    print(f"验证结果: {result['validation']}")
    print(f"Case ID: {result['case']['case_id']}")
    print(f"LLM 输入预览:\n{result['llm_input'][:400]}...\n")

    if not save_dir:
        return result

    precheck_info = {
        "status": status,
        "score": score,
        "threshold": detector.threshold,
    }
    save_demo_case(save_dir, sample_name, "platform", result, precheck_info)
    print(f"✅ 已保存示例到 {save_dir}/{sample_name}_platform.json")
    return result
def run_mode_official(
    records: List[Dict[str, Any]],
    sample_name: str,
    save_dir: Path | None,
) -> Dict[str, Any] | None:
    """Run mode B: call the precheck first, then ``build_case`` by event id.

    The last 12 records (or all of them when fewer exist) are sent to a
    ``PrecheckServer``. If it reports ``"abnormal"``, history windows are
    assembled and ``build_case`` is called with the returned ``event_id`` and
    without ``window_data``; the ``CaseBuilder`` is given the same server
    instance.

    Args:
        records: Sample records loaded from the JSONL file.
        sample_name: Name used for the saved example file.
        save_dir: Directory to save the example into, or ``None`` to skip
            saving.

    Returns:
        The value returned by ``CaseBuilder.build_case``, or ``None`` when
        the precheck does not report an abnormal window.

    Raises:
        RuntimeError: If fewer than two history windows can be produced.
    """
    print("\n=== 模式B:官方 precheck + build_case ===")
    print("步骤1:平台传入当前窗口到 precheck 接口")

    if len(records) >= 12:
        current_window = records[-12:]
    else:
        current_window = records

    user_id = "demo_user"
    if current_window:
        user_id = current_window[0].get("deviceId", "demo_user")

    server = PrecheckServer(score_threshold=0.18)
    pre_result = server.precheck(user_id, current_window)
    print(f"预筛接口返回: {pre_result}")
    if pre_result["status"] != "abnormal":
        print("预筛未触发异常,示例终止。")
        return None

    print(f"\n步骤2:平台收到 event_id={pre_result['event_id']},平台内部合成历史窗口数据")
    # Simulate the platform assembling history windows on its own
    # (as if fetched from its historical data store).
    windows = split_windows(records, days=3)
    if len(windows) < 2:
        raise RuntimeError("样例数据不足以生成历史窗口")
    print(f"平台合成历史窗口: {len(windows)} 天")

    print("\n步骤3:平台调用 build_case,传入 event_id 和合成的历史窗口(不传 window_data,系统从缓存获取)")
    request = {
        # Key point: only event_id is sent, window_data is omitted.
        "event_id": pre_result["event_id"],
        "user_id": user_id,
        "user_profile": sample_profile(),
        # History windows assembled by the platform itself.
        "history_windows": windows,
        "metadata": {
            "detector": "official_patchad",
            "patchad_score": pre_result["score"],
            "threshold": pre_result["threshold"],
        },
    }
    # Hand the PrecheckServer instance to the builder.
    result = CaseBuilder(precheck_server=server).build_case(request)

    print(f"验证结果: {result['validation']}")
    print(f"Case ID: {result['case']['case_id']}")
    print(f"LLM 输入预览:\n{result['llm_input'][:400]}...\n")

    if not save_dir:
        return result

    save_demo_case(save_dir, sample_name, "official", result, pre_result)
    print(f"✅ 已保存示例到 {save_dir}/{sample_name}_official.json")
    return result
def run_negative_test(records: List[Dict[str, Any]]) -> None:
    """Show the validation path by calling ``build_case`` with no history.

    A payload with an empty ``history_windows`` list is submitted. A
    ``ValueError`` is the expected outcome and its message is printed. If
    the call returns instead, that is reported as well, so the check never
    ends silently.

    Args:
        records: Sample records; the first 12 are used as ``window_data``.
    """
    print("\n=== 校验示例:缺少 history_windows ===")
    builder = CaseBuilder()
    invalid_payload = {
        "event_id": "invalid",
        "user_id": "demo_user",
        "window_data": records[:12],
        "user_profile": sample_profile(),
        "history_windows": [],
    }
    try:
        result = builder.build_case(invalid_payload)
    except ValueError as exc:
        print(f"校验失败(预期):{exc}")
    else:
        # NOTE(review): other callers in this script read result["validation"],
        # so the return value is presumably a dict - confirm against CaseBuilder.
        validation = result.get("validation") if isinstance(result, dict) else result
        print(f"⚠️ 未抛出 ValueError(非预期),验证结果: {validation}")
def main() -> None:
    """Parse command-line options and run the selected demo mode(s).

    Loads the sample JSONL file, runs the platform and/or official mode
    depending on ``--mode``, and finally runs the negative validation
    example.
    """
    parser = argparse.ArgumentParser(
        description="模拟 PatchTrAD → build_case 两种接入模式",
    )
    parser.add_argument(
        "--data-file",
        type=Path,
        default=DEFAULT_FILE,
        help="样例 JSONL 文件",
    )
    parser.add_argument(
        "--mode",
        choices=["platform", "official", "all"],
        default="all",
    )
    parser.add_argument(
        "--save-dir",
        type=Path,
        default=None,
        help="如指定则保存 case 示例到该目录",
    )
    parser.add_argument(
        "--sample-name",
        type=str,
        default=None,
        help="示例名称(默认取文件名)",
    )
    args = parser.parse_args()

    records = load_jsonl(args.data_file)

    # Fall back to the data file's stem when no sample name was supplied.
    sample_name = args.sample_name
    if not sample_name:
        sample_name = args.data_file.stem

    runners = {
        "platform": run_mode_platform,
        "official": run_mode_official,
    }
    for mode_name, runner in runners.items():
        if args.mode in (mode_name, "all"):
            runner(records, sample_name, args.save_dir)

    run_negative_test(records)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()