Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """ | |
| 第二波测试:聚焦可能导致 score 全为 0 的边界场景。 | |
| """ | |
| import json | |
| import urllib.request | |
| import urllib.error | |
| import math | |
| API_BASE = "http://127.0.0.1:5001" | |
| # --------------------------------------------------------------------------- | |
| # 边界场景测试 | |
| # --------------------------------------------------------------------------- | |
| edge_cases = [ | |
| # --- 模型不"认识"的 target(低概率 target)--- | |
| ("法国的首都是", "华盛顿"), # 错误的预测 | |
| ("苹果是一种", "水果"), # 正确 | |
| ("苹果是一种", "橘子"), # 错误的预测 | |
| ("1+1=", "3"), # 错误答案 | |
| # --- target 是概率极低的 token(模型根本不会预测这个)--- | |
| ("上海是中国的", "首都"), # 上海不是首都 | |
| ("中国的首都是", "上海"), # 错误城市 | |
| ("太阳从西边", "升起"), # 错误方向 | |
| ("企鹅是一种", "哺乳动物"), # 企鹅不是哺乳动物 | |
| # --- 非常短的输入(可能没有"有效"token)--- | |
| ("A", " B"), | |
| ("I", " am"), | |
| ("是", "的"), | |
| ("好", "人"), | |
| # --- 非常长的输入 --- | |
| ("A" * 100, " B"), | |
| ("测试" * 50, "通"), | |
| # --- 中英混合 --- | |
| ("ML is short for 机器", "学习"), | |
| ("AI代表人工", "智能"), | |
| # --- 多字 target(取首 token,可能语义变化)--- | |
| ("中国的首都是", "北京欢迎你"), | |
| ("1+1=2是", "数学常识"), | |
| # --- 特殊标点 --- | |
| ("Hello!!!", " World"), | |
| ("你好???", "?"), | |
| # --- target_prob 本身就近乎 0 的场景(越少见的 token 越危险)--- | |
| ("床前明月", "的"), # 常见但可能不是 top-1 | |
| ("The quick brown fox jumps over the lazy", " cat"), # 经典句子 | |
| ("天地玄黄 宇宙洪荒 日月盈昃 辰宿列张 寒来暑往 秋收冬藏", "闰"), # 千字文 | |
| # --- 纯英文连续高频词 --- | |
| ("I like to eat", " pizza"), | |
| ("She went to the", " store"), | |
| ("They are going to", " school"), | |
| # --- token 拼接可能导致的问题 --- | |
| ("Big", " Apple"), # "Big Apple" = New York 但分开 tokenize 可能不同 | |
| ("New", " York"), | |
| ("San", " Francisco"), | |
| # --- 和一些肯定能正常工作的对比 --- | |
| ("台湾是中国不可分割的一部分", "中国"), # 政治敏感词 | |
| ("天空是", "蓝"), | |
| ("太阳是", "恒星"), | |
| ] | |
| def test_via_api(context, target_prediction): | |
| payload = { | |
| "context": context, | |
| "model": "base", | |
| "source_page": "attribution", | |
| } | |
| if target_prediction is not None: | |
| payload["target_prediction"] = target_prediction | |
| req = urllib.request.Request( | |
| f"{API_BASE}/api/ablation-attribute", | |
| data=json.dumps(payload).encode("utf-8"), | |
| headers={"Content-Type": "application/json"}, | |
| method="POST", | |
| ) | |
| try: | |
| with urllib.request.urlopen(req, timeout=120) as resp: | |
| return json.loads(resp.read().decode("utf-8")) | |
| except urllib.error.HTTPError as e: | |
| body = e.read().decode("utf-8") | |
| return {"error": f"HTTP {e.code}: {body}", "success": False} | |
| except Exception as e: | |
| return {"error": str(e), "success": False} | |
| def main(): | |
| print(f"🚀 消融归因 — 边界场景测试 ({len(edge_cases)} 个用例)") | |
| print("=" * 80) | |
| zero_score_cases = [] | |
| non_zero_cases = [] | |
| borderline_cases = [] | |
| error_cases = [] | |
| for i, (context, target) in enumerate(edge_cases): | |
| short_ctx = (context[:50] + "...") if len(context) > 50 else context | |
| print(f"\n[{i+1}/{len(edge_cases)}] context={repr(short_ctx)}, target={target}") | |
| r = test_via_api(context, target) | |
| if "error" in r: | |
| print(f" ⚠️ error: {r['error'][:80]}") | |
| error_cases.append((i, context, target, r["error"])) | |
| continue | |
| if "token_attribution" not in r: | |
| print(f" ⚠️ no token_attribution: {json.dumps(r)[:100]}") | |
| error_cases.append((i, context, target, "No token_attribution")) | |
| continue | |
| scores = [e["score"] for e in r["token_attribution"]] | |
| target_prob = r.get("target_prob", "?") | |
| target_token = r.get("target_token", "?") | |
| debug_info = r.get("debug_info", {}) | |
| topk_tokens = debug_info.get("topk_tokens", []) | |
| topk_probs = debug_info.get("topk_probs", []) | |
| non_zero = [s for s in scores if abs(s) > 1e-10] | |
| total = len(scores) | |
| # 判断是否在 top-10 中 | |
| in_top10 = target_token in topk_tokens if topk_tokens else "unknown" | |
| info = { | |
| "idx": i, "context": context, "target": target, | |
| "target_token": target_token, "target_prob": target_prob, | |
| "n_tokens": total, "n_non_zero": len(non_zero), | |
| "max_score": max(scores) if scores else 0, | |
| "min_score": min(scores) if scores else 0, | |
| "all_scores": scores, | |
| "in_top10": in_top10, | |
| "topk_tokens": topk_tokens, | |
| "topk_probs": topk_probs, | |
| } | |
| if len(non_zero) == 0: | |
| print(f" ❌ 全为 0! target_prob={target_prob}, target_token={target_token!r}") | |
| print(f" in_top10={in_top10}, top10={topk_tokens[:5]}") | |
| zero_score_cases.append(info) | |
| elif max(scores) < 1e-5 and target_prob is not None and target_prob < 1e-4: | |
| # 虽然非零但非常小,且 target_prob 也很小(边界情况) | |
| print(f" ⚠️ 边界: scores 极小 (max={max(scores):.3e}), target_prob={target_prob}, target_token={target_token!r}") | |
| print(f" in_top10={in_top10}, top10={topk_tokens[:5]}") | |
| borderline_cases.append(info) | |
| else: | |
| print(f" ✅ {len(non_zero)}/{total} 非零, max={max(scores):.6f}, target_prob={target_prob}") | |
| non_zero_cases.append(info) | |
| # --- 报告 --- | |
| print("\n" + "=" * 80) | |
| print("📊 边界场景测试报告") | |
| print("=" * 80) | |
| print(f"总用例: {len(edge_cases)}") | |
| print(f" ✅ 正常非零: {len(non_zero_cases)}") | |
| print(f" ⚠️ 边界(值极小): {len(borderline_cases)}") | |
| print(f" ❌ 全为 0: {len(zero_score_cases)}") | |
| print(f" ❌ 错误: {len(error_cases)}") | |
| if zero_score_cases: | |
| print("\n" + "-" * 60) | |
| print("🔴 全为 0 的用例:") | |
| for info in zero_score_cases: | |
| print(f" [{info['idx']}] ctx={repr(info['context'][:40])}, target={info['target']}") | |
| print(f" target_token={info['target_token']!r}, target_prob={info['target_prob']}") | |
| print(f" in_top10={info['in_top10']}, topk={info['topk_tokens'][:3]}") | |
| print(f" scores={info['all_scores']}") | |
| if borderline_cases: | |
| print("\n" + "-" * 60) | |
| print("🟡 边界用例(scores 极小):") | |
| for info in borderline_cases: | |
| print(f" [{info['idx']}] ctx={repr(info['context'][:40])}, target={info['target']}") | |
| print(f" target_token={info['target_token']!r}, target_prob={info['target_prob']}") | |
| print(f" in_top10={info['in_top10']}, topk={info['topk_tokens'][:5]}") | |
| print(f" scores={info['all_scores'][:6]}") | |
| # --- 总结分析 --- | |
| print("\n" + "=" * 60) | |
| print("📈 相关性分析") | |
| print("=" * 60) | |
| all_bad = zero_score_cases + borderline_cases | |
| if all_bad: | |
| print(f"共 {len(all_bad)} 个有问题的用例:") | |
| not_in_top10 = [c for c in all_bad if c.get("in_top10") == False or c.get("in_top10") == "unknown"] | |
| prob_low = [c for c in all_bad if c.get("target_prob") is not None and c["target_prob"] < 0.01] | |
| print(f" 不在 top-10 中: {len(not_in_top10)}") | |
| print(f" target_prob < 0.01: {len(prob_low)}") | |
| for c in all_bad: | |
| print(f" [{c['idx']}] target='{c['target']}' → token='{c['target_token']!r}' prob={c['target_prob']:.6e} in_top10={c['in_top10']}") | |
| if __name__ == "__main__": | |
| main() |