"""Run future-gradient influence diagnostics on top of Qwen hidden states.

For every probe case the script asks ``QwenAnchorOverlay`` for per-token
future-influence scores, condenses them into a per-case result dict,
aggregates the results per expected mode ("stable" / "conflict"), and writes
both a JSON payload and a Markdown report.
"""

from __future__ import annotations

import argparse
from dataclasses import replace
from datetime import datetime, timezone
import json
from pathlib import Path
import sys
from typing import Any

import torch

# Put the directory two levels above this file on sys.path so that the
# ``src.*`` imports below resolve when the script is executed directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.data.qwen_probe_cases import make_qwen_probe_cases
from src.model.config import TOY_CONFIG
from src.model.qwen_anchor_overlay import QwenAnchorOverlay
from src.model.future_span_hints import (
    build_future_hint_candidates,
    compute_span_anchor_overlap,
    extract_high_influence_spans,
    safe_decode_token,
)

# Expected-mode labels that get their own aggregate block in the summary.
_MODES = ("stable", "conflict")

# (summary key suffix, per-case result key, required)
# ``required`` keys are read with ``item[key]`` (a missing key is an error);
# the others are read with ``item.get(key, 0.0)``.
_MODE_MEAN_METRICS = (
    ("mean_future_influence", "mean_future_influence", True),
    ("mean_anchor_future_influence", "anchor_position_mean_future_influence", True),
    ("mean_future_span_overlap", "future_span_overlap_ratio", False),
    ("mean_anchor_span_overlap", "anchor_span_overlap_ratio", False),
    ("mean_auxiliary_proposal_count", "auxiliary_proposal_count", False),
    ("mean_auxiliary_proposal_score", "auxiliary_mean_proposal_score", False),
    (
        "mean_auxiliary_revision_matched_anchor_count",
        "auxiliary_revision_matched_anchor_count",
        False,
    ),
    ("mean_auxiliary_revision_alt_prob", "auxiliary_revision_mean_alt_prob", False),
    ("mean_auxiliary_revision_revise_gain", "auxiliary_revision_revise_gain", False),
    ("mean_auxiliary_revision_retire_delta", "auxiliary_revision_retire_delta", False),
    ("mean_future_loss", "future_loss", True),
)

# (summary gap key, per-mode summary suffix it is derived from, report label)
# Each gap is ``conflict_<suffix> - stable_<suffix>``.
_GAP_METRICS = (
    (
        "future_influence_gap_conflict_minus_stable",
        "mean_future_influence",
        "Conflict minus stable mean future influence gap",
    ),
    (
        "anchor_future_influence_gap_conflict_minus_stable",
        "mean_anchor_future_influence",
        "Conflict minus stable active-anchor future influence gap",
    ),
    (
        "future_span_overlap_gap_conflict_minus_stable",
        "mean_future_span_overlap",
        "Conflict minus stable future-span overlap gap",
    ),
    (
        "auxiliary_proposal_count_gap_conflict_minus_stable",
        "mean_auxiliary_proposal_count",
        "Conflict minus stable auxiliary proposal-count gap",
    ),
    (
        "auxiliary_proposal_score_gap_conflict_minus_stable",
        "mean_auxiliary_proposal_score",
        "Conflict minus stable auxiliary proposal-score gap",
    ),
    (
        "auxiliary_revision_match_gap_conflict_minus_stable",
        "mean_auxiliary_revision_matched_anchor_count",
        "Conflict minus stable auxiliary revision-match gap",
    ),
    (
        "auxiliary_revision_revise_gain_conflict_minus_stable",
        "mean_auxiliary_revision_revise_gain",
        "Conflict minus stable auxiliary revise-gain gap",
    ),
    (
        "auxiliary_revision_retire_delta_conflict_minus_stable",
        "mean_auxiliary_revision_retire_delta",
        "Conflict minus stable auxiliary retire-delta gap",
    ),
)


def collect_case_result(
    overlay: QwenAnchorOverlay,
    case_name: str,
    case_family: str,
    case_description: str,
    case_prompt: str,
    expected_mode: str,
    max_length: int,
    future_window: int,
    top_k: int,
    span_threshold: float,
    top_spans: int,
) -> dict[str, Any]:
    """Analyze one prompt and condense the overlay output into a result dict.

    Args:
        overlay: Overlay whose ``analyze_texts_with_future_influence`` is called
            with the single prompt and whose ``tokenizer`` decodes token ids.
        case_name: Stored under ``"name"``.
        case_family: Stored under ``"family"``.
        case_description: Stored under ``"description"``.
        case_prompt: Text that is analyzed.
        expected_mode: Stored under ``"expected_mode"``.
        max_length: Forwarded to the overlay.
        future_window: Forwarded to the overlay.
        top_k: Maximum number of highest-scoring positions to report; clamped
            to the number of available scores (and to zero from below).
        span_threshold: ``min_score`` passed to ``extract_high_influence_spans``.
        top_spans: ``top_spans`` passed to ``extract_high_influence_spans``.

    Returns:
        A dict with the case metadata, scalar diagnostics converted to ``int``
        / ``float``, the top-scoring tokens, active anchor spans, future spans,
        hint candidates, auxiliary proposals, and every key returned by
        ``compute_span_anchor_overlap``.
    """
    out, batch = overlay.analyze_texts_with_future_influence(
        [case_prompt],
        max_length=max_length,
        future_window=future_window,
    )
    diag = out["anchor_diagnostics"]
    influence = out["future_influence"]
    aux_diag = out["auxiliary_proposal_diagnostics"]
    aux_revision_diag = out["auxiliary_revision_diagnostics"]

    # Only one prompt was submitted, so batch row 0 is the row of interest.
    scores = influence["scores"][0]
    input_ids = batch["input_ids"][0]
    if "attention_mask" in batch:
        valid_len = int(batch["attention_mask"][0].sum().item())
    else:
        valid_len = int(input_ids.numel())
    # NOTE(review): slicing the prefix assumes any padding is on the right —
    # confirm against the tokenizer configuration.
    trimmed_scores = scores[:valid_len]
    trimmed_ids = input_ids[:valid_len]
    score_count = int(trimmed_scores.numel())

    # torch.topk rejects k < 0 and k larger than the input, so clamp both ways.
    k = max(0, min(top_k, score_count))
    top_values, top_indices = torch.topk(trimmed_scores, k=k)
    top_tokens = []
    for val, pos in zip(top_values, top_indices):
        position = int(pos.item())
        token_id = int(trimmed_ids[position].item())
        top_tokens.append(
            {
                "position": position,
                "token_id": token_id,
                "token_text": safe_decode_token(overlay.tokenizer, token_id),
                "score": float(val.item()),
            }
        )

    # Clamp anchor boundaries into the valid token range; with no valid tokens
    # there is nothing to clamp to, so no spans are reported.
    active_anchor_spans = [
        {
            "start": max(0, min(int(anchor.start_idx), valid_len - 1)),
            "end": max(0, min(int(anchor.end_idx), valid_len - 1)),
        }
        for anchor in out["active_anchors"][0]
        if valid_len > 0
    ]
    # Each anchor is represented by the score at its (deduplicated) end position.
    anchor_positions = sorted({span["end"] for span in active_anchor_spans})
    anchor_scores = [float(trimmed_scores[pos].item()) for pos in anchor_positions]

    future_spans = extract_high_influence_spans(
        scores=trimmed_scores,
        input_ids=trimmed_ids,
        tokenizer=overlay.tokenizer,
        min_score=span_threshold,
        top_spans=top_spans,
    )
    future_hint_candidates = build_future_hint_candidates(future_spans, active_anchor_spans)
    overlap = compute_span_anchor_overlap(future_spans, active_anchor_spans)

    auxiliary_proposals = [
        {
            "proposal_type": item["proposal_type"],
            "proposal_score": float(item["proposal_score"]),
            "proposal_span": (
                int(item["proposal_span"][0]),
                int(item["proposal_span"][1]),
            ),
            "proposal_root_token": item["proposal_root_token"],
            "proposal_text": item["proposal_text"],
        }
        for item in out["auxiliary_proposal_batches"][0]
    ]

    # An empty score tensor would make .max() raise and .mean() return NaN
    # (which json.dumps writes as the non-standard token NaN); report 0.0.
    if score_count > 0:
        mean_future_influence = float(trimmed_scores.mean().item())
        max_future_influence = float(trimmed_scores.max().item())
    else:
        mean_future_influence = 0.0
        max_future_influence = 0.0

    if anchor_scores:
        anchor_mean_influence = sum(anchor_scores) / len(anchor_scores)
    else:
        anchor_mean_influence = 0.0

    if auxiliary_proposals:
        auxiliary_mean_score = sum(
            item["proposal_score"] for item in auxiliary_proposals
        ) / len(auxiliary_proposals)
    else:
        auxiliary_mean_score = 0.0

    return {
        "name": case_name,
        "family": case_family,
        "description": case_description,
        "expected_mode": expected_mode,
        "tokens": valid_len,
        "num_active": int(diag["num_active"]),
        "mean_contradiction_pressure": float(diag["mean_contradiction_pressure"]),
        "mean_viability": float(diag["mean_viability"]),
        "future_loss": float(influence["loss"]),
        "future_window": int(influence["target_window"]),
        "mean_future_influence": mean_future_influence,
        "max_future_influence": max_future_influence,
        "anchor_position_mean_future_influence": anchor_mean_influence,
        "anchor_positions": anchor_positions,
        "active_anchor_spans": active_anchor_spans,
        "future_spans": future_spans,
        "future_hint_candidates": future_hint_candidates,
        "auxiliary_proposals": auxiliary_proposals,
        "auxiliary_proposal_count": int(len(auxiliary_proposals)),
        "auxiliary_mean_proposal_score": auxiliary_mean_score,
        "auxiliary_batch_mean_score": float(aux_diag["mean_proposal_score"]),
        "auxiliary_revision_matched_anchor_count": int(aux_revision_diag["matched_anchor_count"]),
        "auxiliary_revision_mean_alt_prob": float(aux_revision_diag["mean_alt_prob"]),
        "auxiliary_revision_mean_matched_proposal_score": float(
            aux_revision_diag["mean_matched_proposal_score"]
        ),
        "auxiliary_revision_base_revise_count": int(aux_revision_diag["base_revise_count"]),
        "auxiliary_revision_revise_count": int(aux_revision_diag["auxiliary_revise_count"]),
        "auxiliary_revision_revise_gain": int(aux_revision_diag["auxiliary_revise_gain"]),
        "auxiliary_revision_base_retire_count": int(aux_revision_diag["base_retire_count"]),
        "auxiliary_revision_retire_count": int(aux_revision_diag["auxiliary_retire_count"]),
        "auxiliary_revision_retire_delta": int(aux_revision_diag["auxiliary_retire_delta"]),
        **overlap,
        "top_future_tokens": top_tokens,
    }


def summarize_results(results: list[dict[str, Any]]) -> dict[str, Any]:
    """Aggregate per-case results into per-mode means and conflict-minus-stable gaps.

    Args:
        results: Per-case dicts as produced by ``collect_case_result``. Every
            entry must have ``"expected_mode"``.

    Returns:
        A dict with ``case_count``, ``stable_count`` and ``conflict_count``;
        for each mode that has at least one case, one ``<mode>_<metric>`` mean
        per entry of ``_MODE_MEAN_METRICS``; and, where both modes produced the
        underlying mean, the gap keys listed in ``_GAP_METRICS``.
    """
    summary: dict[str, Any] = {
        "case_count": len(results),
        "stable_count": sum(1 for item in results if item["expected_mode"] == "stable"),
        "conflict_count": sum(1 for item in results if item["expected_mode"] == "conflict"),
    }

    for mode in _MODES:
        subset = [item for item in results if item["expected_mode"] == mode]
        if not subset:
            # No cases of this mode: emit no means, which also suppresses gaps.
            continue
        for suffix, result_key, required in _MODE_MEAN_METRICS:
            if required:
                total = sum(item[result_key] for item in subset)
            else:
                total = sum(item.get(result_key, 0.0) for item in subset)
            summary[f"{mode}_{suffix}"] = total / len(subset)

    for gap_key, suffix, _label in _GAP_METRICS:
        stable_key = f"stable_{suffix}"
        conflict_key = f"conflict_{suffix}"
        if stable_key in summary and conflict_key in summary:
            summary[gap_key] = summary[conflict_key] - summary[stable_key]

    return summary


def build_markdown_report(
    model_name: str,
    device: str,
    max_length: int,
    future_window: int,
    span_threshold: float,
    top_spans: int,
    seed: int,
    results: list[dict[str, Any]],
    summary: dict[str, Any],
) -> str:
    """Render the run configuration, summary, and per-case details as Markdown.

    Args:
        model_name: Shown in the header.
        device: Shown in the header.
        max_length: Shown in the header.
        future_window: Shown in the header.
        span_threshold: Shown in the header with two decimals.
        top_spans: Shown in the header.
        seed: Shown in the header.
        results: Per-case dicts as produced by ``collect_case_result``.
        summary: Aggregate dict as produced by ``summarize_results``.

    Returns:
        The report text, terminated by a single trailing newline.

    Raises:
        KeyError: If a result lacks a key the report formats directly, e.g.
            ``future_span_overlap_ratio``.
    """
    lines = [
        "# Qwen Future Influence Probe",
        "",
        f"Date: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')}",
        f"Model: `{model_name}`",
        f"Device: `{device}`",
        f"Max length: `{max_length}`",
        f"Future window: `{future_window}`",
        f"Span threshold: `{span_threshold:.2f}`",
        f"Top spans per case: `{top_spans}`",
        f"Seed: `{seed}`",
        "",
        "## Summary",
        "",
        f"- Cases: `{summary['case_count']}`",
        f"- Stable cases: `{summary['stable_count']}`",
        f"- Conflict cases: `{summary['conflict_count']}`",
    ]
    # Gaps exist only when both modes had at least one case.
    for gap_key, _suffix, label in _GAP_METRICS:
        if gap_key in summary:
            lines.append(f"- {label}: `{summary[gap_key]:.4f}`")

    lines.extend(
        [
            "",
            "## Case table",
            "",
            "| Family | Case | Expected | Tokens | Active | Aux proposals | Aux matches | Aux revise gain | Mean future influence | Anchor-position mean | Span overlap | Max influence | Future loss |",
            "|---|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|",
        ]
    )
    for item in results:
        lines.append(
            "| {family} | {name} | {expected_mode} | {tokens} | {num_active} | {auxiliary_proposal_count} | {auxiliary_revision_matched_anchor_count} | {auxiliary_revision_revise_gain:+d} | {mean_future_influence:.4f} | "
            "{anchor_position_mean_future_influence:.4f} | {future_span_overlap_ratio:.4f} | {max_future_influence:.4f} | {future_loss:.4f} |".format(**item)
        )

    lines.extend(["", "## Top future-influence tokens", ""])
    for item in results:
        lines.append(f"### {item['name']}")
        for token in item["top_future_tokens"]:
            lines.append(
                f"- pos `{token['position']}` | token `{token['token_text']}` | id `{token['token_id']}` | score `{token['score']:.4f}`"
            )
        lines.append("")

    lines.extend(["## High future-influence spans", ""])
    for item in results:
        lines.append(f"### {item['name']}")
        if not item["future_spans"]:
            lines.append("- no spans crossed the configured threshold")
        for span in item["future_spans"]:
            lines.append(
                f"- span `{span['start']}-{span['end']}` | mean `{span['mean_score']:.4f}` | max `{span['max_score']:.4f}` | text `{span['text']}`"
            )
        if item["active_anchor_spans"]:
            anchor_text = ", ".join(
                f"{span['start']}-{span['end']}" for span in item["active_anchor_spans"]
            )
            lines.append(f"- active anchor spans: `{anchor_text}`")
        else:
            lines.append("- active anchor spans: none")
        lines.append(
            f"- future-span overlap ratio: `{item['future_span_overlap_ratio']:.4f}` | anchor-span overlap ratio: `{item['anchor_span_overlap_ratio']:.4f}`"
        )
        # Sub-bullets are indented two spaces so they nest under the "- " parent.
        if item["future_hint_candidates"]:
            lines.append("- proposal-like future hint spans:")
            for span in item["future_hint_candidates"]:
                lines.append(
                    f"  - `{span['start']}-{span['end']}` | mean `{span['mean_score']:.4f}` | text `{span['text']}`"
                )
        if item["auxiliary_proposals"]:
            lines.append("- auxiliary proposals:")
            for proposal in item["auxiliary_proposals"]:
                lines.append(
                    f"  - `{proposal['proposal_span'][0]}-{proposal['proposal_span'][1]}` | score `{proposal['proposal_score']:.4f}` | text `{proposal['proposal_text']}`"
                )
        lines.append(
            f"- auxiliary revision: matches `{item['auxiliary_revision_matched_anchor_count']}`, "
            f"mean alt prob `{item['auxiliary_revision_mean_alt_prob']:.4f}`, "
            f"revise gain `{item['auxiliary_revision_revise_gain']:+d}`, "
            f"retire delta `{item['auxiliary_revision_retire_delta']:+d}`"
        )
        lines.append("")

    lines.extend(
        [
            "## Interpretation",
            "",
            "- This report is an experimental midpoint between delta-hidden heuristics and full leave-one-out KL.",
            "- Scores are based on gradient influence of token positions on a future autoregressive loss window.",
            "- High-scoring positions are candidates for semantically important context even when local hidden-state jumps are ambiguous.",
            "- Grouped high-influence spans help test whether future-attribution concentrates on the same regions as current active anchors or highlights missed context spans.",
        ]
    )
    return "\n".join(lines) + "\n"


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the probe.

    Every multi-word option is accepted in both ``--snake_case`` and
    ``--kebab-case`` spelling. Unrecognized arguments are ignored rather than
    rejected because parsing uses ``parse_known_args``.

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        The parsed namespace.
    """
    parser = argparse.ArgumentParser(
        description="Run future-gradient influence diagnostics on top of Qwen hidden states."
    )
    parser.add_argument("--model", "--model-name", dest="model", type=str, default="Qwen/Qwen2.5-1.5B")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max_length", "--max-length", dest="max_length", type=int, default=192)
    parser.add_argument("--future_window", "--future-window", dest="future_window", type=int, default=16)
    parser.add_argument("--top_k", "--top-k", dest="top_k", type=int, default=5)
    parser.add_argument("--span_threshold", "--span-threshold", dest="span_threshold", type=float, default=0.75)
    parser.add_argument("--top_spans", "--top-spans", dest="top_spans", type=int, default=5)
    parser.add_argument("--case_filter", "--case-filter", dest="case_filter", type=str, default="")
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument(
        "--output_json",
        "--output-json",
        dest="output_json",
        type=Path,
        default=ROOT / "archive" / "qwen_future_influence_probe.json",
    )
    parser.add_argument(
        "--output_md",
        "--output-md",
        dest="output_md",
        type=Path,
        default=ROOT / "docs" / "research" / "qwen_future_influence_probe.md",
    )
    args, _ = parser.parse_known_args(argv)
    return args


def main() -> None:
    """Run the probe over the selected cases and write the JSON and Markdown outputs."""
    args = parse_args()
    torch.manual_seed(args.seed)

    cfg = replace(
        TOY_CONFIG,
        anchor_threshold=0.10,
        anchor_revision_threshold=0.35,
        anchor_contradiction_threshold=0.20,
        anchor_dead_end_threshold=0.50,
    )
    overlay = QwenAnchorOverlay.from_pretrained(
        model_name=args.model,
        cfg=cfg,
        device=args.device,
        # Half precision is requested only for CUDA devices.
        torch_dtype=torch.float16 if "cuda" in args.device else None,
    )
    overlay.eval()

    print("=== Qwen Future Influence Probe ===")
    print(f"model={args.model}")
    print(f"device={args.device}")
    print()

    # --case_filter is a comma-separated list of case-insensitive substrings
    # matched against the case name or family; empty entries are dropped.
    case_filters = [
        part.strip().lower() for part in args.case_filter.split(",") if part.strip()
    ]
    cases = make_qwen_probe_cases()
    if case_filters:
        cases = [
            case
            for case in cases
            if any(
                needle in case.name.lower() or needle in case.family.lower()
                for needle in case_filters
            )
        ]

    results: list[dict[str, Any]] = []
    for case in cases:
        result = collect_case_result(
            overlay=overlay,
            case_name=case.name,
            case_family=case.family,
            case_description=case.description,
            case_prompt=case.prompt,
            expected_mode=case.expected_mode,
            max_length=args.max_length,
            future_window=args.future_window,
            top_k=args.top_k,
            span_threshold=args.span_threshold,
            top_spans=args.top_spans,
        )
        results.append(result)
        print(f"--- {case.name} ---")
        print(f"family={case.family}")
        print(f"expected_mode={case.expected_mode}")
        print(f"mean_future_influence={result['mean_future_influence']:.4f}")
        print(f"anchor_position_mean_future_influence={result['anchor_position_mean_future_influence']:.4f}")
        print(f"future_span_overlap_ratio={result['future_span_overlap_ratio']:.4f}")
        print(f"max_future_influence={result['max_future_influence']:.4f}")
        print(f"future_loss={result['future_loss']:.4f}")
        print()

    summary = summarize_results(results)
    payload = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "model": args.model,
        "device": args.device,
        "max_length": args.max_length,
        "future_window": args.future_window,
        # Recorded so a JSON file alone identifies the settings behind its results.
        "top_k": args.top_k,
        "span_threshold": args.span_threshold,
        "top_spans": args.top_spans,
        "case_filter": args.case_filter,
        "seed": args.seed,
        "results": results,
        "summary": summary,
    }
    report = build_markdown_report(
        model_name=args.model,
        device=args.device,
        max_length=args.max_length,
        future_window=args.future_window,
        span_threshold=args.span_threshold,
        top_spans=args.top_spans,
        seed=args.seed,
        results=results,
        summary=summary,
    )

    args.output_json.parent.mkdir(parents=True, exist_ok=True)
    args.output_md.parent.mkdir(parents=True, exist_ok=True)
    args.output_json.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    args.output_md.write_text(report, encoding="utf-8")
    print(f"saved_json={args.output_json}")
    print(f"saved_md={args.output_md}")


if __name__ == "__main__":
    main()