# abpt/scripts/run_qwen_future_influence_probe.py
# Source revision: 2d32784 ("auto: sync run_qwen_future_influence_probe.py")
from __future__ import annotations
import argparse
from dataclasses import replace
from datetime import datetime, timezone
import json
from pathlib import Path
import sys
from typing import Any
import torch
# ROOT is the directory one level above this script's folder. It is put on
# sys.path so the `src.*` imports that follow resolve when the file is run
# directly as a script; it also anchors the default output paths.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from src.data.qwen_probe_cases import make_qwen_probe_cases
from src.model.config import TOY_CONFIG
from src.model.qwen_anchor_overlay import QwenAnchorOverlay
from src.model.future_span_hints import (
build_future_hint_candidates,
compute_span_anchor_overlap,
extract_high_influence_spans,
safe_decode_token,
)
def collect_case_result(
    overlay: QwenAnchorOverlay,
    case_name: str,
    case_family: str,
    case_description: str,
    case_prompt: str,
    expected_mode: str,
    max_length: int,
    future_window: int,
    top_k: int,
    span_threshold: float,
    top_spans: int,
) -> dict[str, Any]:
    """Probe a single prompt and flatten the overlay output into one record.

    The prompt is sent through ``overlay.analyze_texts_with_future_influence``
    as a one-element batch. Its per-position future-influence scores are
    trimmed to the attended length and summarised together with the active
    anchors, the high-influence spans and the auxiliary proposal/revision
    diagnostics returned by the overlay.

    Args:
        overlay: Overlay used for the analysis; ``overlay.tokenizer`` is used
            to decode token ids.
        case_name: Stored under ``"name"``.
        case_family: Stored under ``"family"``.
        case_description: Stored under ``"description"``.
        case_prompt: Text to analyse.
        expected_mode: Stored under ``"expected_mode"``.
        max_length: Forwarded to the overlay analysis call.
        future_window: Forwarded to the overlay analysis call.
        top_k: Maximum number of top-scoring positions to report. Values
            below one produce an empty ``"top_future_tokens"`` list.
        span_threshold: Passed as ``min_score`` to
            ``extract_high_influence_spans``.
        top_spans: Passed as ``top_spans`` to ``extract_high_influence_spans``.

    Returns:
        A dict describing the case: identification fields, scalar influence
        statistics, anchor/span listings, auxiliary diagnostics, the keys
        produced by ``compute_span_anchor_overlap`` and the top tokens.
        Averages over an empty set are reported as ``0.0``.
    """
    out, batch = overlay.analyze_texts_with_future_influence(
        [case_prompt],
        max_length=max_length,
        future_window=future_window,
    )
    diag = out["anchor_diagnostics"]
    influence = out["future_influence"]
    aux_diag = out["auxiliary_proposal_diagnostics"]
    aux_revision_diag = out["auxiliary_revision_diagnostics"]

    # A single prompt was submitted, so row 0 is the only batch entry.
    scores = influence["scores"][0]
    input_ids = batch["input_ids"][0]
    if "attention_mask" in batch:
        valid_len = int(batch["attention_mask"][0].sum().item())
    else:
        valid_len = int(input_ids.numel())
    # NOTE(review): taking the first `valid_len` positions presumes any padding
    # sits on the right -- confirm against the tokenizer configuration.
    trimmed_scores = scores[:valid_len]
    trimmed_ids = input_ids[:valid_len]

    # Clamp k: torch.topk rejects a negative k and a k larger than the input.
    k = max(0, min(top_k, valid_len))
    top_values, top_indices = torch.topk(trimmed_scores, k=k)
    top_tokens = []
    for val, pos in zip(top_values, top_indices):
        token_id = int(trimmed_ids[pos].item())
        top_tokens.append(
            {
                "position": int(pos.item()),
                "token_id": token_id,
                "token_text": safe_decode_token(overlay.tokenizer, token_id),
                "score": float(val.item()),
            }
        )

    # Anchor bounds are clamped into the attended range; with no attended
    # tokens there is no valid position, so no spans are reported.
    anchors = out["active_anchors"][0]
    last_pos = valid_len - 1
    if valid_len > 0:
        active_anchor_spans = [
            {
                "start": max(0, min(int(anchor.start_idx), last_pos)),
                "end": max(0, min(int(anchor.end_idx), last_pos)),
            }
            for anchor in anchors
        ]
    else:
        active_anchor_spans = []
    anchor_positions = sorted({span["end"] for span in active_anchor_spans})
    anchor_scores = [float(trimmed_scores[pos].item()) for pos in anchor_positions]

    future_spans = extract_high_influence_spans(
        scores=trimmed_scores,
        input_ids=trimmed_ids,
        tokenizer=overlay.tokenizer,
        min_score=span_threshold,
        top_spans=top_spans,
    )
    future_hint_candidates = build_future_hint_candidates(future_spans, active_anchor_spans)
    overlap = compute_span_anchor_overlap(future_spans, active_anchor_spans)

    auxiliary_proposals = [
        {
            "proposal_type": item["proposal_type"],
            "proposal_score": float(item["proposal_score"]),
            "proposal_span": (
                int(item["proposal_span"][0]),
                int(item["proposal_span"][1]),
            ),
            "proposal_root_token": item["proposal_root_token"],
            "proposal_text": item["proposal_text"],
        }
        for item in out["auxiliary_proposal_batches"][0]
    ]

    # Tensor.max() raises on an empty tensor and mean() yields NaN; fall back
    # to 0.0, the same convention used for the other empty averages below.
    has_scores = trimmed_scores.numel() > 0
    mean_influence = float(trimmed_scores.mean().item()) if has_scores else 0.0
    max_influence = float(trimmed_scores.max().item()) if has_scores else 0.0

    if auxiliary_proposals:
        aux_mean_score = sum(
            item["proposal_score"] for item in auxiliary_proposals
        ) / len(auxiliary_proposals)
    else:
        aux_mean_score = 0.0

    return {
        "name": case_name,
        "family": case_family,
        "description": case_description,
        "expected_mode": expected_mode,
        "tokens": valid_len,
        "num_active": int(diag["num_active"]),
        "mean_contradiction_pressure": float(diag["mean_contradiction_pressure"]),
        "mean_viability": float(diag["mean_viability"]),
        "future_loss": float(influence["loss"]),
        "future_window": int(influence["target_window"]),
        "mean_future_influence": mean_influence,
        "max_future_influence": max_influence,
        "anchor_position_mean_future_influence": (
            sum(anchor_scores) / len(anchor_scores) if anchor_scores else 0.0
        ),
        "anchor_positions": anchor_positions,
        "active_anchor_spans": active_anchor_spans,
        "future_spans": future_spans,
        "future_hint_candidates": future_hint_candidates,
        "auxiliary_proposals": auxiliary_proposals,
        "auxiliary_proposal_count": len(auxiliary_proposals),
        "auxiliary_mean_proposal_score": aux_mean_score,
        "auxiliary_batch_mean_score": float(aux_diag["mean_proposal_score"]),
        "auxiliary_revision_matched_anchor_count": int(aux_revision_diag["matched_anchor_count"]),
        "auxiliary_revision_mean_alt_prob": float(aux_revision_diag["mean_alt_prob"]),
        "auxiliary_revision_mean_matched_proposal_score": float(aux_revision_diag["mean_matched_proposal_score"]),
        "auxiliary_revision_base_revise_count": int(aux_revision_diag["base_revise_count"]),
        "auxiliary_revision_revise_count": int(aux_revision_diag["auxiliary_revise_count"]),
        "auxiliary_revision_revise_gain": int(aux_revision_diag["auxiliary_revise_gain"]),
        "auxiliary_revision_base_retire_count": int(aux_revision_diag["base_retire_count"]),
        "auxiliary_revision_retire_count": int(aux_revision_diag["auxiliary_retire_count"]),
        "auxiliary_revision_retire_delta": int(aux_revision_diag["auxiliary_retire_delta"]),
        **overlap,
        "top_future_tokens": top_tokens,
    }
def summarize_results(results: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-case records into per-mode means and conflict-vs-stable gaps.

    Args:
        results: Per-case dicts. Each must carry ``"expected_mode"``; records
            in the ``"stable"`` or ``"conflict"`` group must also carry
            ``"mean_future_influence"``,
            ``"anchor_position_mean_future_influence"`` and ``"future_loss"``.
            The remaining metrics are optional and count as zero when absent.

    Returns:
        A dict with the case counts, a ``"<mode>_mean_*"`` entry per metric for
        every mode that has at least one case, and a ``"*_conflict_minus_stable"``
        gap for selected metrics when both modes are present.
    """
    modes = ("stable", "conflict")
    required = object()  # marks fields that must exist on every record

    # (summary key suffix, record field, default when the field is missing)
    mean_specs = (
        ("mean_future_influence", "mean_future_influence", required),
        ("mean_anchor_future_influence", "anchor_position_mean_future_influence", required),
        ("mean_future_span_overlap", "future_span_overlap_ratio", 0.0),
        ("mean_anchor_span_overlap", "anchor_span_overlap_ratio", 0.0),
        ("mean_auxiliary_proposal_count", "auxiliary_proposal_count", 0),
        ("mean_auxiliary_proposal_score", "auxiliary_mean_proposal_score", 0.0),
        ("mean_auxiliary_revision_matched_anchor_count", "auxiliary_revision_matched_anchor_count", 0),
        ("mean_auxiliary_revision_alt_prob", "auxiliary_revision_mean_alt_prob", 0.0),
        ("mean_auxiliary_revision_revise_gain", "auxiliary_revision_revise_gain", 0),
        ("mean_auxiliary_revision_retire_delta", "auxiliary_revision_retire_delta", 0),
        ("mean_future_loss", "future_loss", required),
    )
    # (gap key, per-mode mean suffix the gap is computed from)
    gap_specs = (
        ("future_influence_gap_conflict_minus_stable", "mean_future_influence"),
        ("anchor_future_influence_gap_conflict_minus_stable", "mean_anchor_future_influence"),
        ("future_span_overlap_gap_conflict_minus_stable", "mean_future_span_overlap"),
        ("auxiliary_proposal_count_gap_conflict_minus_stable", "mean_auxiliary_proposal_count"),
        ("auxiliary_proposal_score_gap_conflict_minus_stable", "mean_auxiliary_proposal_score"),
        ("auxiliary_revision_match_gap_conflict_minus_stable", "mean_auxiliary_revision_matched_anchor_count"),
        ("auxiliary_revision_revise_gain_conflict_minus_stable", "mean_auxiliary_revision_revise_gain"),
        ("auxiliary_revision_retire_delta_conflict_minus_stable", "mean_auxiliary_revision_retire_delta"),
    )

    grouped = {
        mode: [row for row in results if row["expected_mode"] == mode]
        for mode in modes
    }

    summary: dict[str, Any] = {"case_count": len(results)}
    for mode in modes:
        summary[f"{mode}_count"] = len(grouped[mode])

    for mode in modes:
        rows = grouped[mode]
        if rows:
            for suffix, field, default in mean_specs:
                if default is required:
                    total = sum(row[field] for row in rows)
                else:
                    total = sum(row.get(field, default) for row in rows)
                summary[f"{mode}_{suffix}"] = total / len(rows)

    # A gap only exists when both modes contributed a mean.
    for gap_key, suffix in gap_specs:
        stable_key = f"stable_{suffix}"
        conflict_key = f"conflict_{suffix}"
        if stable_key in summary and conflict_key in summary:
            summary[gap_key] = summary[conflict_key] - summary[stable_key]

    return summary
def build_markdown_report(
    model_name: str,
    device: str,
    max_length: int,
    future_window: int,
    span_threshold: float,
    top_spans: int,
    seed: int,
    results: list[dict[str, Any]],
    summary: dict[str, Any],
) -> str:
    """Render the probe run as a Markdown document.

    The document contains, in order: a run header stamped with the current
    UTC time, the summary counts and any conflict-minus-stable gaps present
    in ``summary``, a per-case table, the top-scoring tokens per case, the
    span/anchor/auxiliary details per case, and a fixed interpretation note.

    Args:
        model_name: Shown in the header.
        device: Shown in the header.
        max_length: Shown in the header.
        future_window: Shown in the header.
        span_threshold: Shown in the header with two decimals.
        top_spans: Shown in the header.
        seed: Shown in the header.
        results: Per-case records as produced by ``collect_case_result``.
        summary: Aggregates as produced by ``summarize_results``.

    Returns:
        The Markdown text, terminated by a single newline.
    """
    stamp = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M UTC")
    doc: list[str] = []

    # --- Run header and summary counts ---------------------------------
    doc.extend(
        [
            "# Qwen Future Influence Probe",
            "",
            f"Date: {stamp}",
            f"Model: `{model_name}`",
            f"Device: `{device}`",
            f"Max length: `{max_length}`",
            f"Future window: `{future_window}`",
            f"Span threshold: `{span_threshold:.2f}`",
            f"Top spans per case: `{top_spans}`",
            f"Seed: `{seed}`",
            "",
            "## Summary",
            "",
            f"- Cases: `{summary['case_count']}`",
            f"- Stable cases: `{summary['stable_count']}`",
            f"- Conflict cases: `{summary['conflict_count']}`",
        ]
    )

    # --- Optional gap bullets (only those present in the summary) ------
    gap_labels = (
        ("future_influence_gap_conflict_minus_stable", "mean future influence"),
        ("anchor_future_influence_gap_conflict_minus_stable", "active-anchor future influence"),
        ("future_span_overlap_gap_conflict_minus_stable", "future-span overlap"),
        ("auxiliary_proposal_count_gap_conflict_minus_stable", "auxiliary proposal-count"),
        ("auxiliary_proposal_score_gap_conflict_minus_stable", "auxiliary proposal-score"),
        ("auxiliary_revision_match_gap_conflict_minus_stable", "auxiliary revision-match"),
        ("auxiliary_revision_revise_gain_conflict_minus_stable", "auxiliary revise-gain"),
        ("auxiliary_revision_retire_delta_conflict_minus_stable", "auxiliary retire-delta"),
    )
    doc.extend(
        f"- Conflict minus stable {label} gap: `{summary[key]:.4f}`"
        for key, label in gap_labels
        if key in summary
    )

    # --- Case table ----------------------------------------------------
    doc.extend(
        [
            "",
            "## Case table",
            "",
            "| Family | Case | Expected | Tokens | Active | Aux proposals | Aux matches | Aux revise gain | Mean future influence | Anchor-position mean | Span overlap | Max influence | Future loss |",
            "|---|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|",
        ]
    )
    for row in results:
        cells = (
            f"{row['family']}",
            f"{row['name']}",
            f"{row['expected_mode']}",
            f"{row['tokens']}",
            f"{row['num_active']}",
            f"{row['auxiliary_proposal_count']}",
            f"{row['auxiliary_revision_matched_anchor_count']}",
            f"{row['auxiliary_revision_revise_gain']:+d}",
            f"{row['mean_future_influence']:.4f}",
            f"{row['anchor_position_mean_future_influence']:.4f}",
            f"{row['future_span_overlap_ratio']:.4f}",
            f"{row['max_future_influence']:.4f}",
            f"{row['future_loss']:.4f}",
        )
        doc.append("| " + " | ".join(cells) + " |")

    # --- Top tokens per case -------------------------------------------
    doc.extend(["", "## Top future-influence tokens", ""])
    for row in results:
        doc.append(f"### {row['name']}")
        doc.extend(
            f"- pos `{tok['position']}` | token `{tok['token_text']}` | id `{tok['token_id']}` | score `{tok['score']:.4f}`"
            for tok in row["top_future_tokens"]
        )
        doc.append("")

    # --- Spans, anchors and auxiliary details per case -----------------
    doc.extend(["## High future-influence spans", ""])
    for row in results:
        doc.append(f"### {row['name']}")

        spans = row["future_spans"]
        if not spans:
            doc.append("- no spans crossed the configured threshold")
        doc.extend(
            f"- span `{sp['start']}-{sp['end']}` | mean `{sp['mean_score']:.4f}` | max `{sp['max_score']:.4f}` | text `{sp['text']}`"
            for sp in spans
        )

        anchor_spans = row["active_anchor_spans"]
        if anchor_spans:
            joined = ", ".join(f"{sp['start']}-{sp['end']}" for sp in anchor_spans)
            doc.append(f"- active anchor spans: `{joined}`")
        else:
            doc.append("- active anchor spans: none")

        doc.append(
            f"- future-span overlap ratio: `{row['future_span_overlap_ratio']:.4f}` | anchor-span overlap ratio: `{row['anchor_span_overlap_ratio']:.4f}`"
        )

        hints = row["future_hint_candidates"]
        if hints:
            doc.append("- proposal-like future hint spans:")
            doc.extend(
                f" - `{sp['start']}-{sp['end']}` | mean `{sp['mean_score']:.4f}` | text `{sp['text']}`"
                for sp in hints
            )

        proposals = row["auxiliary_proposals"]
        if proposals:
            doc.append("- auxiliary proposals:")
            doc.extend(
                f" - `{pr['proposal_span'][0]}-{pr['proposal_span'][1]}` | score `{pr['proposal_score']:.4f}` | text `{pr['proposal_text']}`"
                for pr in proposals
            )

        doc.append(
            f"- auxiliary revision: matches `{row['auxiliary_revision_matched_anchor_count']}`, "
            f"mean alt prob `{row['auxiliary_revision_mean_alt_prob']:.4f}`, "
            f"revise gain `{row['auxiliary_revision_revise_gain']:+d}`, "
            f"retire delta `{row['auxiliary_revision_retire_delta']:+d}`"
        )
        doc.append("")

    # --- Fixed interpretation note -------------------------------------
    doc.extend(
        [
            "## Interpretation",
            "",
            "- This report is an experimental midpoint between delta-hidden heuristics and full leave-one-out KL.",
            "- Scores are based on gradient influence of token positions on a future autoregressive loss window.",
            "- High-scoring positions are candidates for semantically important context even when local hidden-state jumps are ambiguous.",
            "- Grouped high-influence spans help test whether future-attribution concentrates on the same regions as current active anchors or highlights missed context spans.",
        ]
    )
    return "\n".join(doc) + "\n"
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options of the probe script.

    Every multi-word option is accepted in both ``--snake_case`` and
    ``--kebab-case`` spelling. Unrecognized arguments do not abort the run,
    but they are reported on stderr so that a mistyped flag does not silently
    fall back to its default.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read
            ``sys.argv[1:]``.

    Returns:
        The namespace of recognized options.
    """
    parser = argparse.ArgumentParser(description="Run future-gradient influence diagnostics on top of Qwen hidden states.")
    parser.add_argument("--model", "--model-name", dest="model", type=str, default="Qwen/Qwen2.5-1.5B")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max_length", "--max-length", dest="max_length", type=int, default=192)
    parser.add_argument("--future_window", "--future-window", dest="future_window", type=int, default=16)
    parser.add_argument("--top_k", "--top-k", dest="top_k", type=int, default=5)
    parser.add_argument("--span_threshold", "--span-threshold", dest="span_threshold", type=float, default=0.75)
    parser.add_argument("--top_spans", "--top-spans", dest="top_spans", type=int, default=5)
    parser.add_argument("--case_filter", "--case-filter", dest="case_filter", type=str, default="")
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument(
        "--output_json",
        "--output-json",
        dest="output_json",
        type=Path,
        default=ROOT / "archive" / "qwen_future_influence_probe.json",
    )
    parser.add_argument(
        "--output_md",
        "--output-md",
        dest="output_md",
        type=Path,
        default=ROOT / "docs" / "research" / "qwen_future_influence_probe.md",
    )
    args, unknown = parser.parse_known_args(argv)
    if unknown:
        # Extra arguments stay tolerated (the original contract), but dropping
        # them without a trace hides typos such as `--span-treshold`.
        print(
            f"warning: ignoring unrecognized arguments: {' '.join(unknown)}",
            file=sys.stderr,
        )
    return args
def main() -> None:
    """Run the probe over the selected cases and write the JSON and Markdown outputs.

    Steps: parse the CLI options, seed torch, build the overlay from
    ``TOY_CONFIG`` with overridden anchor thresholds, analyse every selected
    probe case while printing its headline metrics, then write the JSON
    payload and the Markdown report to the configured paths.

    Raises:
        SystemExit: If no probe case is selected, so that existing output
            files are not replaced by an empty report.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    cfg = replace(
        TOY_CONFIG,
        anchor_threshold=0.10,
        anchor_revision_threshold=0.35,
        anchor_contradiction_threshold=0.20,
        anchor_dead_end_threshold=0.50,
    )
    overlay = QwenAnchorOverlay.from_pretrained(
        model_name=args.model,
        cfg=cfg,
        device=args.device,
        torch_dtype=torch.float16 if "cuda" in args.device else None,
    )
    overlay.eval()

    print("=== Qwen Future Influence Probe ===")
    print(f"model={args.model}")
    print(f"device={args.device}")
    print()

    # --case_filter is a comma-separated list of case-insensitive substrings
    # matched against each case's name and family; empty means "all cases".
    needles = [part.strip().lower() for part in args.case_filter.split(",") if part.strip()]
    cases = make_qwen_probe_cases()
    if needles:
        cases = [
            case
            for case in cases
            if any(
                needle in case.name.lower() or needle in case.family.lower()
                for needle in needles
            )
        ]
    if not cases:
        raise SystemExit(
            f"no probe cases selected (case_filter={args.case_filter!r}); outputs left untouched"
        )

    results: list[dict[str, Any]] = []
    for case in cases:
        result = collect_case_result(
            overlay=overlay,
            case_name=case.name,
            case_family=case.family,
            case_description=case.description,
            case_prompt=case.prompt,
            expected_mode=case.expected_mode,
            max_length=args.max_length,
            future_window=args.future_window,
            top_k=args.top_k,
            span_threshold=args.span_threshold,
            top_spans=args.top_spans,
        )
        results.append(result)
        print(f"--- {case.name} ---")
        print(f"family={case.family}")
        print(f"expected_mode={case.expected_mode}")
        print(f"mean_future_influence={result['mean_future_influence']:.4f}")
        print(f"anchor_position_mean_future_influence={result['anchor_position_mean_future_influence']:.4f}")
        print(f"future_span_overlap_ratio={result['future_span_overlap_ratio']:.4f}")
        print(f"max_future_influence={result['max_future_influence']:.4f}")
        print(f"future_loss={result['future_loss']:.4f}")
        print()

    summary = summarize_results(results)
    # Record every option that shapes the results so the JSON artifact is
    # self-describing, in line with the header of the Markdown report.
    payload = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "model": args.model,
        "device": args.device,
        "max_length": args.max_length,
        "future_window": args.future_window,
        "top_k": args.top_k,
        "span_threshold": args.span_threshold,
        "top_spans": args.top_spans,
        "case_filter": args.case_filter,
        "seed": args.seed,
        "results": results,
        "summary": summary,
    }
    report = build_markdown_report(
        model_name=args.model,
        device=args.device,
        max_length=args.max_length,
        future_window=args.future_window,
        span_threshold=args.span_threshold,
        top_spans=args.top_spans,
        seed=args.seed,
        results=results,
        summary=summary,
    )

    args.output_json.parent.mkdir(parents=True, exist_ok=True)
    args.output_md.parent.mkdir(parents=True, exist_ok=True)
    args.output_json.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    args.output_md.write_text(report, encoding="utf-8")
    print(f"saved_json={args.output_json}")
    print(f"saved_md={args.output_md}")


if __name__ == "__main__":
    main()