| | |
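"""
simulate_patchad_case_pipeline.py

Demonstrates the two platform integration paths using the existing demo data:

1. The platform ships its own PatchTrAD/PatchAD: call build_case directly.
2. The platform relies on the official precheck: call the precheck endpoint
   first, then upload the remaining data to generate the case.

The script prints:
- the precheck score and status
- an excerpt of the case structure returned by build_case
- the generated Markdown snippet
- the validation result
"""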
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import json |
| | from pathlib import Path |
| | from typing import Any, Dict, List |
| |
|
| | from utils.case_builder import CaseBuilder |
| | from utils.patchad_filter import PatchADFilter, PrecheckServer |
| |
|
# Demo inputs live next to this script, under data_storage/users/.
DATA_DIR = Path(__file__).parent.joinpath("data_storage", "users")
# JSONL sample used when --data-file is not given on the command line.
DEFAULT_FILE = DATA_DIR.joinpath("demo_pattern.jsonl")
| |
|
| |
|
def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return its records in file order.

    Lines that are empty or whitespace-only are skipped; every other line is
    decoded with ``json.loads``.

    Args:
        path: Location of the UTF-8 encoded JSONL file.

    Returns:
        One decoded value per non-blank line.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]
| |
|
| |
|
def split_windows(
    records: List[Dict[str, Any]],
    days: int = 3,
    *,
    window_size: int = 12,
) -> List[Dict[str, Any]]:
    """Group records by calendar day and return up to ``days`` full windows.

    Records are bucketed by the first 10 characters of their ``timestamp``
    value (the ``YYYY-MM-DD`` part of an ISO-style string -- assumed format,
    confirm against the data producer). Days are visited in ascending
    string order; a day contributes a window only when it holds at least
    ``window_size`` records, and only its first ``window_size`` records are
    kept.

    If at least one window was found but fewer than ``days``, the last window
    is duplicated (with a ``+copyN`` date suffix) until ``days`` entries exist,
    and a warning is printed.

    Args:
        records: Raw records; each is expected to be a dict with a
            ``timestamp`` entry.
        days: Number of daily windows wanted. Values <= 0 yield ``[]``.
        window_size: Records required per daily window (default 12).

    Returns:
        A list of ``{"date": str, "window": list}`` dicts, empty when no day
        has enough records.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be >= 1")
    if days <= 0:
        # Without this guard the first full day would be appended before the
        # "enough windows" check runs, returning one window too many.
        return []

    by_date: Dict[str, List[Dict[str, Any]]] = {}
    for rec in records:
        # ``or ""`` covers an explicit ``"timestamp": None``; ``str`` keeps
        # non-string timestamps from raising on the slice below.
        stamp = rec.get("timestamp") or ""
        by_date.setdefault(str(stamp)[:10], []).append(rec)

    windows: List[Dict[str, Any]] = []
    for date in sorted(by_date):
        day_records = by_date[date][:window_size]
        if len(day_records) < window_size:
            continue
        windows.append({"date": date, "window": day_records})
        if len(windows) >= days:
            break

    missing = days - len(windows)
    if windows and missing > 0:
        print(f"⚠️ 历史窗口不足 {days} 天,自动复制最后一天填充示例。")
        last = windows[-1]
        for idx in range(missing):
            windows.append(
                {
                    "date": f"{last['date']}+copy{idx + 1}",
                    # Copy the list so padded windows do not alias each other.
                    "window": list(last["window"]),
                }
            )
    return windows
| |
|
| |
|
def sample_profile() -> Dict[str, Any]:
    """Return a freshly built, hard-coded demo user profile.

    The values are fixed sample data used as the ``user_profile`` entry of the
    payloads built in this script; a new dict is created on every call.
    """
    profile: Dict[str, Any] = dict(
        age_group="30-35岁",
        estimated_age=33,
        sex="男性",
        exercise="每周5次以上",
        coffee="偶尔",
        smoking="不吸烟",
        drinking="经常饮酒",
        MEQ=62.0,
        MEQ_type="晨型",
    )
    return profile
| |
|
| |
|
def save_demo_case(save_dir: Path, sample_name: str, mode: str, bundle: Dict[str, Any], precheck: Dict[str, Any]) -> None:
    """Write one demo case to disk and register it in ``manifest.json``.

    The case is stored as ``<sample_name>_<mode>.json`` inside ``save_dir``
    (created if missing). The manifest in the same directory is a JSON list;
    any existing entry with the same sample and mode is replaced by a new one
    appended at the end.

    Args:
        save_dir: Output directory.
        sample_name: Name of the sample; used in the file name and manifest.
        mode: Integration mode; ``"platform"`` selects the mode-A title, any
            other value selects the mode-B title.
        bundle: Value stored under ``case_bundle``.
        precheck: Value stored under ``precheck``.
    """
    save_dir.mkdir(parents=True, exist_ok=True)

    case_file = save_dir / f"{sample_name}_{mode}.json"
    case_doc = {
        "mode": mode,
        "sample": sample_name,
        "precheck": precheck,
        "case_bundle": bundle,
    }
    case_file.write_text(json.dumps(case_doc, ensure_ascii=False, indent=2), encoding="utf-8")

    index_file = save_dir / "manifest.json"
    entries = []
    if index_file.exists():
        # Keep every entry except the one this call is about to replace.
        for entry in json.loads(index_file.read_text(encoding="utf-8")):
            if entry.get("mode") == mode and entry.get("sample") == sample_name:
                continue
            entries.append(entry)

    if mode == "platform":
        label = "模式A:平台自带"
    else:
        label = "模式B:官方预筛"
    entries.append(
        {
            "sample": sample_name,
            "mode": mode,
            "title": f"{sample_name} · {label}",
            "file": case_file.name,
        }
    )
    index_file.write_text(json.dumps(entries, ensure_ascii=False, indent=2), encoding="utf-8")
| |
|
| |
|
def run_mode_platform(records: List[Dict[str, Any]], sample_name: str, save_dir: Path | None) -> Dict[str, Any] | None:
    """Demo path A: score the window locally, then call ``build_case`` directly.

    Splits ``records`` into daily windows, scores the last window with a
    ``PatchADFilter`` (threshold 0.18) and, when the score reaches the filter's
    threshold, builds a case that carries the window data inline.

    Args:
        records: Demo records loaded from the JSONL sample.
        sample_name: Label used for the saved example file.
        save_dir: If set, the result is persisted via ``save_demo_case``.

    Returns:
        The value returned by ``CaseBuilder.build_case``, or ``None`` when the
        window is not flagged as abnormal.

    Raises:
        RuntimeError: If fewer than two history windows could be produced.
    """
    print("\n=== 模式A:平台自带 PatchTrAD → 直接调用 build_case ===")

    history = split_windows(records, days=3)
    if len(history) < 2:
        raise RuntimeError("样例数据不足以生成历史窗口")

    # The last window plays the role of the "current" window in this demo.
    latest = history[-1]["window"]
    detector = PatchADFilter(score_threshold=0.18)
    score = detector.score_window(latest)
    if score >= detector.threshold:
        status = "abnormal"
    else:
        status = "normal"
    print(f"预筛得分: {score:.4f} (threshold={detector.threshold}) -> {status}")

    if status != "abnormal":
        print("窗口未触发异常,示例终止。")
        return None

    request = {
        "event_id": "platform-demo-evt",
        "user_id": latest[0].get("deviceId", "demo_user"),
        "window_data": latest,
        "user_profile": sample_profile(),
        "history_windows": history,
        "metadata": {"detector": "platform_patchad", "score": score},
    }
    result = CaseBuilder().build_case(request)

    print(f"验证结果: {result['validation']}")
    print(f"Case ID: {result['case']['case_id']}")
    print(f"LLM 输入预览:\n{result['llm_input'][:400]}...\n")

    if not save_dir:
        return result

    summary = {"status": status, "score": score, "threshold": detector.threshold}
    save_demo_case(save_dir, sample_name, "platform", result, summary)
    print(f"✅ 已保存示例到 {save_dir}/{sample_name}_platform.json")
    return result
| |
|
| |
|
def run_mode_official(records: List[Dict[str, Any]], sample_name: str, save_dir: Path | None) -> Dict[str, Any] | None:
    """Demo path B: call the official precheck first, then ``build_case``.

    Step 1 sends the last 12 records (or all of them when fewer exist) to a
    ``PrecheckServer`` (threshold 0.18). If the reported status is
    ``"abnormal"``, step 2 derives history windows from ``records`` and step 3
    calls ``build_case`` with the precheck's ``event_id`` and no
    ``window_data``. Per the step-3 message the window is expected to come
    from the precheck cache -- that lookup is not visible in this file.

    Args:
        records: Demo records loaded from the JSONL sample.
        sample_name: Label used for the saved example file.
        save_dir: If set, the result is persisted via ``save_demo_case``.

    Returns:
        The value returned by ``CaseBuilder.build_case``, or ``None`` when the
        precheck does not report an abnormal window.

    Raises:
        RuntimeError: If fewer than two history windows could be produced.
    """
    print("\n=== 模式B:官方 precheck + build_case ===")
    print("步骤1:平台传入当前窗口到 precheck 接口")

    if len(records) >= 12:
        current_window = records[-12:]
    else:
        current_window = records
    if current_window:
        user_id = current_window[0].get("deviceId", "demo_user")
    else:
        user_id = "demo_user"

    server = PrecheckServer(score_threshold=0.18)
    pre_result = server.precheck(user_id, current_window)
    print(f"预筛接口返回: {pre_result}")
    if pre_result["status"] != "abnormal":
        print("预筛未触发异常,示例终止。")
        return None

    print(f"\n步骤2:平台收到 event_id={pre_result['event_id']},平台内部合成历史窗口数据")
    history = split_windows(records, days=3)
    if len(history) < 2:
        raise RuntimeError("样例数据不足以生成历史窗口")
    print(f"平台合成历史窗口: {len(history)} 天")

    print("\n步骤3:平台调用 build_case,传入 event_id 和合成的历史窗口(不传 window_data,系统从缓存获取)")
    detector_meta = {
        "detector": "official_patchad",
        "patchad_score": pre_result["score"],
        "threshold": pre_result["threshold"],
    }
    request = {
        "event_id": pre_result["event_id"],
        "user_id": user_id,
        "user_profile": sample_profile(),
        "history_windows": history,
        "metadata": detector_meta,
    }

    # The builder receives the same server instance that handled the precheck.
    result = CaseBuilder(precheck_server=server).build_case(request)

    print(f"验证结果: {result['validation']}")
    print(f"Case ID: {result['case']['case_id']}")
    print(f"LLM 输入预览:\n{result['llm_input'][:400]}...\n")

    if not save_dir:
        return result

    save_demo_case(save_dir, sample_name, "official", result, pre_result)
    print(f"✅ 已保存示例到 {save_dir}/{sample_name}_official.json")
    return result
| |
|
| |
|
def run_negative_test(records: List[Dict[str, Any]]) -> None:
    """Show that ``build_case`` rejects a payload with empty ``history_windows``.

    Builds an otherwise complete payload from the first 12 records and expects
    ``CaseBuilder.build_case`` to raise ``ValueError``. The outcome is printed
    either way, so a validation that unexpectedly passes is visible instead of
    ending silently.

    Args:
        records: Demo records; the first 12 are used as ``window_data``.
    """
    print("\n=== 校验示例:缺少 history_windows ===")
    builder = CaseBuilder()
    invalid_payload = {
        "event_id": "invalid",
        "user_id": "demo_user",
        "window_data": records[:12],
        "user_profile": sample_profile(),
        "history_windows": [],
    }
    # Only the call under test sits inside the try block.
    try:
        builder.build_case(invalid_payload)
    except ValueError as exc:
        print(f"校验失败(预期):{exc}")
    else:
        # Without this branch a missing validation would go unnoticed.
        print("⚠️ 校验未失败(非预期):空的 history_windows 未被拒绝")
| |
|
| |
|
def main() -> None:
    """Parse the command line and run the selected demo mode(s).

    Options:
        --data-file: JSONL sample to load (defaults to ``DEFAULT_FILE``).
        --mode: ``platform``, ``official`` or ``all`` (default ``all``).
        --save-dir: Optional directory in which case examples are saved.
        --sample-name: Example name; defaults to the data file's stem.

    The negative validation example runs after the selected modes.
    """
    cli = argparse.ArgumentParser(description="模拟 PatchTrAD → build_case 两种接入模式")
    cli.add_argument("--data-file", type=Path, default=DEFAULT_FILE, help="样例 JSONL 文件")
    cli.add_argument("--mode", choices=["platform", "official", "all"], default="all")
    cli.add_argument("--save-dir", type=Path, default=None, help="如指定则保存 case 示例到该目录")
    cli.add_argument("--sample-name", type=str, default=None, help="示例名称(默认取文件名)")
    opts = cli.parse_args()

    records = load_jsonl(opts.data_file)
    sample_name = opts.sample_name or opts.data_file.stem

    run_everything = opts.mode == "all"
    if run_everything or opts.mode == "platform":
        run_mode_platform(records, sample_name, opts.save_dir)
    if run_everything or opts.mode == "official":
        run_mode_official(records, sample_name, opts.save_dir)
    run_negative_test(records)
| |
|
| |
|
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
| |
|
| |
|