#!/usr/bin/env python3
"""
simulate_patchad_case_pipeline.py

Demonstrates two platform integration paths using the existing demo data:
1. The platform ships its own PatchTrAD/PatchAD: call build_case directly.
2. The platform relies on the official precheck: call the precheck endpoint
   first, then supply the remaining data to generate the case.

The script prints:
- the precheck score and status
- an excerpt of the case structure returned by build_case
- the generated Markdown snippet
- the validation result
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List
from utils.case_builder import CaseBuilder
from utils.patchad_filter import PatchADFilter, PrecheckServer
# The ``data_storage/users`` directory located next to this script.
DATA_DIR = Path(__file__).parent / "data_storage" / "users"
# Sample JSONL file; used as the default value of the --data-file CLI option.
DEFAULT_FILE = DATA_DIR / "demo_pattern.jsonl"
def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return its decoded objects in file order.

    Args:
        path: Location of a UTF-8 encoded file with one JSON value per line.

    Returns:
        One decoded value per non-blank line. Lines that are empty or contain
        only whitespace are skipped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]
def split_windows(
    records: List[Dict[str, Any]],
    days: int = 3,
    *,
    window_size: int = 12,
) -> List[Dict[str, Any]]:
    """Group records by day and return up to ``days`` fixed-size daily windows.

    Records are bucketed by the first 10 characters of their ``timestamp``
    value (assumes an ISO-like ``YYYY-MM-DD...`` string — TODO confirm against
    the data producer). A missing or null timestamp lands in the ``""``
    bucket. Buckets are visited in ascending key order; a day qualifies only
    when it holds at least ``window_size`` records, and its window is the
    first ``window_size`` of them in input order.

    When at least one day qualifies but fewer than ``days`` do, a warning is
    printed and the last qualifying window is repeated, dated
    ``<date>+copyN``, until ``days`` windows exist.

    Args:
        records: Decoded sample records, each a dict.
        days: Number of daily windows wanted. Values <= 0 yield ``[]``.
        window_size: Records per window (keyword-only, defaults to 12).

    Returns:
        A list of ``{"date": str, "window": list}`` dicts; empty when no day
        qualifies.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError("window_size must be a positive integer")
    if days <= 0:
        # Nothing was requested, so do not emit any window at all.
        return []

    by_date: Dict[str, List[Dict[str, Any]]] = {}
    for rec in records:
        # ``or ""`` makes a present-but-null timestamp behave like a missing
        # one instead of failing on ``None[:10]``.
        ts = rec.get("timestamp") or ""
        by_date.setdefault(ts[:10], []).append(rec)

    windows: List[Dict[str, Any]] = []
    for date in sorted(by_date):
        day_records = by_date[date]
        if len(day_records) < window_size:
            continue
        windows.append({"date": date, "window": day_records[:window_size]})
        if len(windows) >= days:
            break

    if 0 < len(windows) < days:
        print(f"⚠️ 历史窗口不足 {days} 天,自动复制最后一天填充示例。")
        last = windows[-1]
        for idx in range(days - len(windows)):
            windows.append(
                {
                    "date": f"{last['date']}+copy{idx+1}",
                    # Copy the list so padded windows do not alias each other.
                    "window": list(last["window"]),
                }
            )
    return windows
def sample_profile() -> Dict[str, Any]:
    """Return the fixed demo user profile sent along with each case request.

    A new dict is built on every call, so callers may mutate the result.
    """
    profile: Dict[str, Any] = dict(
        age_group="30-35岁",
        estimated_age=33,
        sex="男性",
        exercise="每周5次以上",
        coffee="偶尔",
        smoking="不吸烟",
        drinking="经常饮酒",
        MEQ=62.0,
        MEQ_type="晨型",
    )
    return profile
def save_demo_case(save_dir: Path, sample_name: str, mode: str, bundle: Dict[str, Any], precheck: Dict[str, Any]) -> None:
    """Persist one demo case as JSON and register it in ``manifest.json``.

    The case goes to ``<save_dir>/<sample_name>_<mode>.json``. The manifest in
    the same directory is a JSON list with one entry per (sample, mode) pair;
    an existing entry for the same pair is dropped and a fresh one appended.

    Args:
        save_dir: Output directory; created (with parents) if missing.
        sample_name: Sample label used in the file name and manifest entry.
        mode: Integration mode. ``"platform"`` gets the mode-A title, any
            other value gets the mode-B title.
        bundle: Value stored under ``"case_bundle"``.
        precheck: Value stored under ``"precheck"``.
    """
    save_dir.mkdir(parents=True, exist_ok=True)

    case_path = save_dir / f"{sample_name}_{mode}.json"
    document = {
        "mode": mode,
        "sample": sample_name,
        "precheck": precheck,
        "case_bundle": bundle,
    }
    with case_path.open("w", encoding="utf-8") as out:
        json.dump(document, out, ensure_ascii=False, indent=2)

    index_path = save_dir / "manifest.json"
    entries = []
    if index_path.exists():
        entries = json.loads(index_path.read_text(encoding="utf-8"))

    # Keep every entry except a previous one for this exact (mode, sample).
    entries = [
        entry
        for entry in entries
        if entry.get("mode") != mode or entry.get("sample") != sample_name
    ]

    if mode == "platform":
        label = "模式A:平台自带"
    else:
        label = "模式B:官方预筛"
    entries.append(
        {
            "sample": sample_name,
            "mode": mode,
            "title": f"{sample_name} · {label}",
            "file": case_path.name,
        }
    )

    with index_path.open("w", encoding="utf-8") as out:
        json.dump(entries, out, ensure_ascii=False, indent=2)
def run_mode_platform(records: List[Dict[str, Any]], sample_name: str, save_dir: Path | None) -> Dict[str, Any] | None:
    """Run demo mode A: score the window locally, then call ``build_case``.

    The records are split into daily windows and the last window is scored
    with a ``PatchADFilter`` (threshold 0.18). If the score reaches the
    filter's threshold, a payload carrying the window, the sample profile and
    all windows as history is passed to ``CaseBuilder.build_case``.

    Args:
        records: Decoded sample records.
        sample_name: Label used when saving the result.
        save_dir: If truthy, the result is saved there via ``save_demo_case``.

    Returns:
        The ``build_case`` result, or ``None`` when the window scores below
        the threshold.

    Raises:
        RuntimeError: If fewer than two windows can be produced.
    """
    print("\n=== 模式A:平台自带 PatchTrAD → 直接调用 build_case ===")

    history = split_windows(records, days=3)
    if len(history) < 2:
        raise RuntimeError("样例数据不足以生成历史窗口")

    # The last window is the one being screened.
    current_window = history[-1]["window"]

    detector = PatchADFilter(score_threshold=0.18)
    score = detector.score_window(current_window)
    if score >= detector.threshold:
        status = "abnormal"
    else:
        status = "normal"
    print(f"预筛得分: {score:.4f} (threshold={detector.threshold}) -> {status}")

    if status != "abnormal":
        print("窗口未触发异常,示例终止。")
        return None

    request = {
        "event_id": "platform-demo-evt",
        "user_id": current_window[0].get("deviceId", "demo_user"),
        "window_data": current_window,
        "user_profile": sample_profile(),
        "history_windows": history,
        "metadata": {"detector": "platform_patchad", "score": score},
    }
    result = CaseBuilder().build_case(request)

    print(f"验证结果: {result['validation']}")
    print(f"Case ID: {result['case']['case_id']}")
    print(f"LLM 输入预览:\n{result['llm_input'][:400]}...\n")

    if save_dir:
        precheck_info = {
            "status": status,
            "score": score,
            "threshold": detector.threshold,
        }
        save_demo_case(save_dir, sample_name, "platform", result, precheck_info)
        print(f"✅ 已保存示例到 {save_dir}/{sample_name}_platform.json")
    return result
def run_mode_official(records: List[Dict[str, Any]], sample_name: str, save_dir: Path | None) -> Dict[str, Any] | None:
    """Run demo mode B: call the precheck first, then ``build_case``.

    Step 1 sends the last 12 records (or all of them when fewer exist) to
    ``PrecheckServer.precheck``. If the returned status is ``"abnormal"``,
    step 2 builds the history windows and step 3 calls
    ``CaseBuilder.build_case`` with the precheck's ``event_id`` and no
    ``window_data``.

    Args:
        records: Decoded sample records.
        sample_name: Label used when saving the result.
        save_dir: If truthy, the result is saved there via ``save_demo_case``.

    Returns:
        The ``build_case`` result, or ``None`` when the precheck status is not
        ``"abnormal"``.

    Raises:
        RuntimeError: If fewer than two history windows can be produced.
    """
    print("\n=== 模式B:官方 precheck + build_case ===")
    print("步骤1:平台传入当前窗口到 precheck 接口")

    current_window = records[-12:] if len(records) >= 12 else records
    if current_window:
        user_id = current_window[0].get("deviceId", "demo_user")
    else:
        user_id = "demo_user"

    server = PrecheckServer(score_threshold=0.18)
    pre_result = server.precheck(user_id, current_window)
    print(f"预筛接口返回: {pre_result}")

    if pre_result["status"] != "abnormal":
        print("预筛未触发异常,示例终止。")
        return None

    print(f"\n步骤2:平台收到 event_id={pre_result['event_id']},平台内部合成历史窗口数据")
    # Simulate the platform assembling history windows on its own side.
    history = split_windows(records, days=3)
    if len(history) < 2:
        raise RuntimeError("样例数据不足以生成历史窗口")
    print(f"平台合成历史窗口: {len(history)} 天")

    print(f"\n步骤3:平台调用 build_case,传入 event_id 和合成的历史窗口(不传 window_data,系统从缓存获取)")
    request = {
        # Key point: only event_id is sent; window_data is deliberately omitted.
        "event_id": pre_result["event_id"],
        "user_id": user_id,
        "user_profile": sample_profile(),
        # History windows assembled by the platform itself.
        "history_windows": history,
        "metadata": {
            "detector": "official_patchad",
            "patchad_score": pre_result["score"],
            "threshold": pre_result["threshold"],
        },
    }
    # The builder gets the same PrecheckServer instance that handled step 1;
    # presumably it looks the window up there by event_id — verify in CaseBuilder.
    builder = CaseBuilder(precheck_server=server)
    result = builder.build_case(request)

    print(f"验证结果: {result['validation']}")
    print(f"Case ID: {result['case']['case_id']}")
    print(f"LLM 输入预览:\n{result['llm_input'][:400]}...\n")

    if save_dir:
        save_demo_case(save_dir, sample_name, "official", result, pre_result)
        print(f"✅ 已保存示例到 {save_dir}/{sample_name}_official.json")
    return result
def run_negative_test(records: List[Dict[str, Any]]) -> None:
    """Show the validation path by submitting a payload with empty history.

    Calls ``CaseBuilder.build_case`` with ``history_windows`` set to ``[]``.
    A ``ValueError`` is the expected outcome and is printed as such. If the
    call returns without raising, a warning is printed so the missing
    validation does not go unnoticed.

    Args:
        records: Decoded sample records; the first 12 are sent as
            ``window_data``.
    """
    print("\n=== 校验示例:缺少 history_windows ===")
    builder = CaseBuilder()
    invalid_payload = {
        "event_id": "invalid",
        "user_id": "demo_user",
        "window_data": records[:12],
        "user_profile": sample_profile(),
        "history_windows": [],
    }
    # Only the call under test sits inside ``try``.
    try:
        builder.build_case(invalid_payload)
    except ValueError as exc:
        print(f"校验失败(预期):{exc}")
    else:
        # Without this branch a builder that accepts the payload would make
        # the negative check pass silently.
        print("⚠️ 未触发预期的校验失败:build_case 接受了空的 history_windows。")
def main() -> None:
    """Parse the CLI options, load the sample records and run the demos.

    ``--mode`` selects mode A (``platform``), mode B (``official``) or both
    (``all``, the default). The negative validation example always runs last,
    whatever mode was chosen.
    """
    parser = argparse.ArgumentParser(
        description="模拟 PatchTrAD → build_case 两种接入模式",
    )
    parser.add_argument(
        "--data-file",
        type=Path,
        default=DEFAULT_FILE,
        help="样例 JSONL 文件",
    )
    parser.add_argument(
        "--mode",
        choices=["platform", "official", "all"],
        default="all",
    )
    parser.add_argument(
        "--save-dir",
        type=Path,
        default=None,
        help="如指定则保存 case 示例到该目录",
    )
    parser.add_argument(
        "--sample-name",
        type=str,
        default=None,
        help="示例名称(默认取文件名)",
    )
    args = parser.parse_args()

    records = load_jsonl(args.data_file)
    # Fall back to the data file's stem when no explicit name is given.
    sample_name = args.sample_name if args.sample_name else args.data_file.stem

    runners = (
        ("platform", run_mode_platform),
        ("official", run_mode_official),
    )
    for mode_name, runner in runners:
        if args.mode in (mode_name, "all"):
            runner(records, sample_name, args.save_dir)

    run_negative_test(records)
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|