TokenTrace / scripts /test_ablation_massive.py
cccmmd
init: TokenTrace - LLM interpretability toolbox
7985065
Raw
History Blame Contribute Delete
12.8 kB
#!/usr/bin/env python3
"""
大规模测试消融归因:测试几十个不同场景的 case,观察是否出现"全部为 0"的情况。
"""
import json
import sys
import os
import urllib.request
import urllib.error
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
API_BASE = "http://127.0.0.1:5001"
# ---------------------------------------------------------------------------
# Test cases: diverse scenarios
# ---------------------------------------------------------------------------
test_cases = [
# --- 基础事实 ---
("中国的首都是", "北京"),
("法国的首都是", "巴黎"),
("地球绕着什么转", "太阳"),
("1+1=", "2"),
("水的化学式是", "H2O"),
("太阳从什么方向升起", "东"),
# --- 单 token 目标 ---
("The capital of France is", " Paris"),
("The color of the sky is", " blue"),
("Apple's founder was", " Steve"),
("Python is a programming", " language"),
# --- 否定 / 对比 ---
("猫不属于", "鱼类"),
("企鹅虽然不会飞,但属于", "鸟类"),
("鲸鱼生活在水中,但它不是鱼,而是", "哺乳动物"),
# --- 逻辑推理 ---
("如果今天是周一,那么明天是", "周二"),
("一年有十二个月,一个月大约有30", "天"),
("人喝水是为了", "解渴"),
("汽车没有油就", "无法"),
# --- 中文长句子 ---
("人工智能是计算机科学的一个分支,它企图了解智能的实质,并生产出一种新的能以人类智能相似的方式做出反应的智能机器。该领域的研究包括", "机器人"),
("机器学习是人工智能的子领域,它专注于让计算机通过经验自动改善性能。监督学习是最常见的机器学习范式之一,它使用", "标注"),
("深度学习是机器学习的一个分支,它使用多层神经网络来学习数据的表示。卷积神经网络特别适合处理", "图像"),
# --- 英文长句子 ---
("Artificial intelligence is a branch of computer science that aims to create machines that can perform tasks that typically require human intelligence. One of the most popular applications is", " machine"),
("Natural language processing is a subfield of AI that focuses on the interaction between computers and humans through natural language. A key challenge is", " understanding"),
("The transformer architecture introduced in 'Attention is All You Need' has become the foundation of most modern language models. It relies primarily on", " attention"),
# --- 重复结构 / 模板 ---
("I think therefore I", " am"),
("To be or not to", " be"),
("Hello world is the first program every", " programmer"),
# --- 数字 / 数学 ---
("The square root of 9 is", " 3"),
("2 multiplied by 6 equals", " 12"),
("The year after 2023 is", " 2024"),
# --- 知识问答 ---
("爱因斯坦提出了", "相对论"),
("牛顿发现了", "万有引力"),
("莎士比亚是著名的", "戏剧"),
("莫扎特是一位", "作曲"),
# --- 短文本 / 单字目标 ---
("天对地,雨对风,大陆对长", "空"),
("床前明月", "光"),
# --- ambiguous / 开放 ---
("明天会不会下雨取决于", "天气"),
("要成功就必须付出", "努力"),
# --- top-1(不指定 target)的自动模式 ---
("我今天中午吃了", None), # top-1
("The best thing about summer is", None), # top-1
("在我看来,最重要的品质是", None), # top-1
]
def test_via_api(context, target_prediction):
"""通过 HTTP API 测试"""
payload = {
"context": context,
"model": "base",
"source_page": "attribution",
}
if target_prediction is not None:
payload["target_prediction"] = target_prediction
req = urllib.request.Request(
f"{API_BASE}/api/ablation-attribute",
data=json.dumps(payload).encode("utf-8"),
headers={"Content-Type": "application/json"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=120) as resp:
data = json.loads(resp.read().decode("utf-8"))
return data
except urllib.error.HTTPError as e:
body = e.read().decode("utf-8")
return {"error": f"HTTP {e.code}: {body}", "success": False}
except Exception as e:
return {"error": str(e), "success": False}
def analyze_results(results):
"""分析测试结果,找出全为 0 的情况"""
zero_score_cases = []
non_zero_cases = [] # has at least one non-zero score
error_cases = []
for i, (case, r) in enumerate(results):
context, target = case
if "error" in r:
error_cases.append((i, context, target, r["error"]))
continue
if "token_attribution" not in r:
error_cases.append((i, context, target, "No token_attribution in response: " + json.dumps(r)[:200]))
continue
scores = [e["score"] for e in r["token_attribution"]]
non_zero_scores = [s for s in scores if abs(s) > 1e-10]
info = {
"case_idx": i,
"context": repr((context[:50] + "...") if len(context) > 50 else context),
"target": target,
"target_token": r.get("target_token", "?"),
"target_prob": r.get("target_prob"),
"n_tokens": len(r["token_attribution"]),
"n_non_zero": len(non_zero_scores),
"max_score": max(scores) if scores else 0,
"min_score": min(scores) if scores else 0,
"all_scores": scores,
"all_deltas": [e.get("delta_logit", "N/A") for e in r["token_attribution"]],
}
if len(non_zero_scores) == 0:
zero_score_cases.append(info)
else:
non_zero_cases.append(info)
return {
"zero_score_cases": zero_score_cases,
"non_zero_cases": non_zero_cases,
"error_cases": error_cases,
"total": len(results),
"zero_count": len(zero_score_cases),
"non_zero_count": len(non_zero_cases),
"error_count": len(error_cases),
}
def print_analysis(stats):
print("\n" + "=" * 80)
print("📊 消融归因大规模测试报告")
print("=" * 80)
print(f"总测试用例: {stats['total']}")
print(f" ✅ 有非零分: {stats['non_zero_count']}")
print(f" ❌ 全为 0: {stats['zero_count']}")
print(f" ⚠️ 错误: {stats['error_count']}")
print()
if stats['error_cases']:
print("--- 错误用例 ---")
for idx, ctx, tgt, err in stats['error_cases']:
print(f" [{idx}] context={repr(str(ctx)[:40])}, target={tgt}")
print(f" error: {err}")
print()
if stats['zero_score_cases']:
print("--- 🚨 全为 0 的用例 ---")
for info in stats['zero_score_cases']:
print(f" [{info['case_idx']}] context={info['context']}")
print(f" target={info['target']}, target_token={info['target_token']!r}")
print(f" target_prob={info['target_prob']}, n_tokens={info['n_tokens']}")
print(f" scores: {info['all_scores']}")
# Check delta_logits
deltas = [d for d in info['all_deltas'] if d != "N/A"]
if deltas:
non_zero_deltas = [d for d in deltas if abs(d) > 1e-10]
print(f" delta_logits: {info['all_deltas'][:10]}...")
if non_zero_deltas:
print(f" ⚠️ scores 全 0,但 delta_logits 有非零值! count={len(non_zero_deltas)}")
else:
print(f" delta_logits 也全 0")
print()
if stats['non_zero_cases']:
print("--- ✅ 有非零分的用例(前 8 个)---")
for info in stats['non_zero_cases'][:8]:
non_zero_ratio = f"{info['n_non_zero']}/{info['n_tokens']}"
print(f" [{info['case_idx']}] context={info['context']}")
print(f" target={info['target']}, non_zero_ratio={non_zero_ratio}")
print(f" max_score={info['max_score']:.6f}, min_score={info['min_score']:.4f}, target_prob={info['target_prob']}")
print(f" scores={info['all_scores']}")
print()
def main():
print(f"🚀 启动消融归因大规模测试 ({len(test_cases)} 个用例)")
print(f"目标服务器: {API_BASE}")
print()
# 先检查服务器是否可达
try:
with urllib.request.urlopen(f"{API_BASE}/", timeout=5) as _:
print("✅ 服务器可达")
except Exception as e:
print(f"⚠️ 服务器不可达 ({e}),继续尝试...")
print()
results = []
for i, (context, target) in enumerate(test_cases):
short_ctx = (context[:40] + "...") if len(context) > 40 else context
print(f"[{i+1}/{len(test_cases)}] testing: context={repr(short_ctx)}, target={target}")
try:
r = test_via_api(context, target)
results.append(((context, target), r))
# 检查是否有 score
if "token_attribution" in r:
scores = [e["score"] for e in r["token_attribution"]]
non_zero = sum(1 for s in scores if abs(s) > 1e-10)
total = len(scores)
if non_zero == 0:
print(f" ❌ 全为 0! target_prob={r.get('target_prob')}, target_token={r.get('target_token')!r}")
else:
max_s = max(scores)
print(f" ✅ {non_zero}/{total} 非零, max_score={max_s:.6f}")
elif "error" in r:
print(f" ⚠️ error: {r['error'][:80]}")
except Exception as e:
print(f" ❌ 异常: {e}")
results.append(((context, target), {"error": str(e)}))
stats = analyze_results(results)
print_analysis(stats)
# 深入分析全零根因
if stats['zero_count'] > 0:
first_zero = stats['zero_score_cases'][0]
ctx = test_cases[first_zero['case_idx']][0]
tgt = test_cases[first_zero['case_idx']][1]
print("=" * 80)
print("🔍 对第一个全零 case 做深度诊断")
print("=" * 80)
diagnose_ablation(ctx, tgt)
def diagnose_ablation(context, target_prediction=None):
"""对单个 case 做详细诊断"""
payload = {
"context": context,
"model": "base",
"source_page": "attribution",
}
if target_prediction is not None:
payload["target_prediction"] = target_prediction
req = urllib.request.Request(
f"{API_BASE}/api/ablation-attribute",
data=json.dumps(payload).encode("utf-8"),
headers={"Content-Type": "application/json"},
method="POST",
)
try:
with urllib.request.urlopen(req, timeout=120) as resp:
data = json.loads(resp.read().decode("utf-8"))
except Exception as e:
print(f" API 调用失败: {e}")
return
if "token_attribution" in data:
print(f" token_attribution 条目数: {len(data['token_attribution'])}")
print(f" target_prob: {data['target_prob']}")
print(f" target_token: {data['target_token']!r}")
print(f" debug_info: {data.get('debug_info')}")
scores = [e["score"] for e in data["token_attribution"]]
deltas = [e.get("delta_logit", "N/A") for e in data["token_attribution"]]
raws = [e["raw"] for e in data["token_attribution"]]
print(f" scores: {scores}")
print(f" delta_logits: {deltas}")
print(f" raws: {raws}")
# 如果 scores 全 0 但 delta_logit 非全 0,说明 prob 变化在 round_to_sig_figs 之下
if all(abs(s) < 1e-10 for s in scores):
numeric_deltas = [d for d in deltas if isinstance(d, (int, float))]
non_zero_deltas = [d for d in numeric_deltas if abs(d) > 1e-10]
if non_zero_deltas:
print(f"\n 🔑 关键发现: scores 全部 round 到 0,但 delta_logits 有非零值!")
print(f" 说明模型对遮挡确实有反应,但概率变化太小 (< 5e-8)")
print(f" 非零 deltas 示例: {non_zero_deltas[:10]}")
else:
print(f"\n 🔑 delta_logits 也是全 0 (或 {len(numeric_deltas)}/{len(deltas)} 为数值型),说明遮挡完全不影响预测")
print(f"\n 📝 debug_info topk_tokens: {data.get('debug_info', {}).get('topk_tokens', [])}")
print(f" 📝 debug_info topk_probs: {data.get('debug_info', {}).get('topk_probs', [])}")
if __name__ == "__main__":
main()