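"""Qwen future-influence probe.

Runs the probe cases from ``make_qwen_probe_cases`` through a ``QwenAnchorOverlay``,
scores each token position by its gradient influence on a future autoregressive loss
window, and writes a JSON payload plus a markdown report comparing stable and
conflict cases.
"""
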
from __future__ import annotations

import argparse
from dataclasses import replace
from datetime import datetime, timezone
import json
from pathlib import Path
import sys
from typing import Any

import torch

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from src.data.qwen_probe_cases import make_qwen_probe_cases
from src.model.config import TOY_CONFIG
from src.model.qwen_anchor_overlay import QwenAnchorOverlay
from src.model.future_span_hints import (
    build_future_hint_candidates,
    compute_span_anchor_overlap,
    extract_high_influence_spans,
    safe_decode_token,
)


def collect_case_result(
    overlay: QwenAnchorOverlay,
    case_name: str,
    case_family: str,
    case_description: str,
    case_prompt: str,
    expected_mode: str,
    max_length: int,
    future_window: int,
    top_k: int,
    span_threshold: float,
    top_spans: int,
) -> dict[str, Any]:
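    """Analyze one probe prompt and return a JSON-serializable result record.

    The record bundles anchor diagnostics, future-influence statistics,
    high-influence spans, auxiliary proposals, and span/anchor overlap ratios
    for a single case.
    """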
    out, batch = overlay.analyze_texts_with_future_influence(
        [case_prompt],
        max_length=max_length,
        future_window=future_window,
    )
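    # Unpack the per-batch diagnostics; the probe runs one prompt at a time, so
    # index 0 selects the only sequence in the batch.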
    diag = out["anchor_diagnostics"]
    influence = out["future_influence"]
    aux_diag = out["auxiliary_proposal_diagnostics"]
    aux_revision_diag = out["auxiliary_revision_diagnostics"]
    scores = influence["scores"][0]
    input_ids = batch["input_ids"][0]
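    # Trim scores and token ids to the attended prefix so padding positions
    # never enter the influence statistics.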
    valid_len = int(batch["attention_mask"][0].sum().item()) if "attention_mask" in batch else int(input_ids.numel())
    trimmed_scores = scores[:valid_len]
    trimmed_ids = input_ids[:valid_len]
    k = min(top_k, valid_len)
    top_values, top_indices = torch.topk(trimmed_scores, k=k)
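    # Decode the top-k highest future-influence tokens for the report.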
    top_tokens = [
        {
            "position": int(pos.item()),
            "token_id": int(trimmed_ids[pos].item()),
            "token_text": safe_decode_token(overlay.tokenizer, int(trimmed_ids[pos].item())),
            "score": float(val.item()),
        }
        for val, pos in zip(top_values, top_indices)
    ]
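    # Clamp active anchor spans to the valid prefix; each anchor's end position
    # is scored against the future-influence profile below.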
    active_anchor_spans = [
        {
            "start": max(0, min(int(anchor.start_idx), valid_len - 1)),
            "end": max(0, min(int(anchor.end_idx), valid_len - 1)),
        }
        for anchor in out["active_anchors"][0]
        if valid_len > 0
    ]
    anchor_positions = sorted({span["end"] for span in active_anchor_spans})
    anchor_scores = [float(trimmed_scores[pos].item()) for pos in anchor_positions]
    future_spans = extract_high_influence_spans(
        scores=trimmed_scores,
        input_ids=trimmed_ids,
        tokenizer=overlay.tokenizer,
        min_score=span_threshold,
        top_spans=top_spans,
    )
    future_hint_candidates = build_future_hint_candidates(future_spans, active_anchor_spans)
    overlap = compute_span_anchor_overlap(future_spans, active_anchor_spans)
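    # Flatten auxiliary proposals into plain Python types so they serialize to JSON.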
    auxiliary_proposals = [
        {
            "proposal_type": item["proposal_type"],
            "proposal_score": float(item["proposal_score"]),
            "proposal_span": (
                int(item["proposal_span"][0]),
                int(item["proposal_span"][1]),
            ),
            "proposal_root_token": item["proposal_root_token"],
            "proposal_text": item["proposal_text"],
        }
        for item in out["auxiliary_proposal_batches"][0]
    ]
    return {
        "name": case_name,
        "family": case_family,
        "description": case_description,
        "expected_mode": expected_mode,
        "tokens": valid_len,
        "num_active": int(diag["num_active"]),
        "mean_contradiction_pressure": float(diag["mean_contradiction_pressure"]),
        "mean_viability": float(diag["mean_viability"]),
        "future_loss": float(influence["loss"]),
        "future_window": int(influence["target_window"]),
        "mean_future_influence": float(trimmed_scores.mean().item()),
        "max_future_influence": float(trimmed_scores.max().item()),
        "anchor_position_mean_future_influence": (
            sum(anchor_scores) / len(anchor_scores) if anchor_scores else 0.0
        ),
        "anchor_positions": anchor_positions,
        "active_anchor_spans": active_anchor_spans,
        "future_spans": future_spans,
        "future_hint_candidates": future_hint_candidates,
        "auxiliary_proposals": auxiliary_proposals,
        "auxiliary_proposal_count": int(len(auxiliary_proposals)),
        "auxiliary_mean_proposal_score": (
            sum(float(item["proposal_score"]) for item in auxiliary_proposals)
            / max(len(auxiliary_proposals), 1)
            if auxiliary_proposals
            else 0.0
        ),
        "auxiliary_batch_mean_score": float(aux_diag["mean_proposal_score"]),
        "auxiliary_revision_matched_anchor_count": int(aux_revision_diag["matched_anchor_count"]),
        "auxiliary_revision_mean_alt_prob": float(aux_revision_diag["mean_alt_prob"]),
        "auxiliary_revision_mean_matched_proposal_score": float(aux_revision_diag["mean_matched_proposal_score"]),
        "auxiliary_revision_base_revise_count": int(aux_revision_diag["base_revise_count"]),
        "auxiliary_revision_revise_count": int(aux_revision_diag["auxiliary_revise_count"]),
        "auxiliary_revision_revise_gain": int(aux_revision_diag["auxiliary_revise_gain"]),
        "auxiliary_revision_base_retire_count": int(aux_revision_diag["base_retire_count"]),
        "auxiliary_revision_retire_count": int(aux_revision_diag["auxiliary_retire_count"]),
        "auxiliary_revision_retire_delta": int(aux_revision_diag["auxiliary_retire_delta"]),
        **overlap,
        "top_future_tokens": top_tokens,
    }


def summarize_results(results: list[dict[str, Any]]) -> dict[str, Any]:
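    """Aggregate per-case results into stable/conflict means and conflict-minus-stable gaps."""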
    summary: dict[str, Any] = {
        "case_count": len(results),
        "stable_count": sum(1 for item in results if item["expected_mode"] == "stable"),
        "conflict_count": sum(1 for item in results if item["expected_mode"] == "conflict"),
    }
    for mode in ("stable", "conflict"):
        subset = [item for item in results if item["expected_mode"] == mode]
        if not subset:
            continue
        summary[f"{mode}_mean_future_influence"] = sum(item["mean_future_influence"] for item in subset) / len(subset)
        summary[f"{mode}_mean_anchor_future_influence"] = (
            sum(item["anchor_position_mean_future_influence"] for item in subset) / len(subset)
        )
        summary[f"{mode}_mean_future_span_overlap"] = (
            sum(item.get("future_span_overlap_ratio", 0.0) for item in subset) / len(subset)
        )
        summary[f"{mode}_mean_anchor_span_overlap"] = (
            sum(item.get("anchor_span_overlap_ratio", 0.0) for item in subset) / len(subset)
        )
        summary[f"{mode}_mean_auxiliary_proposal_count"] = (
            sum(item.get("auxiliary_proposal_count", 0) for item in subset) / len(subset)
        )
        summary[f"{mode}_mean_auxiliary_proposal_score"] = (
            sum(item.get("auxiliary_mean_proposal_score", 0.0) for item in subset) / len(subset)
        )
        summary[f"{mode}_mean_auxiliary_revision_matched_anchor_count"] = (
            sum(item.get("auxiliary_revision_matched_anchor_count", 0) for item in subset) / len(subset)
        )
        summary[f"{mode}_mean_auxiliary_revision_alt_prob"] = (
            sum(item.get("auxiliary_revision_mean_alt_prob", 0.0) for item in subset) / len(subset)
        )
        summary[f"{mode}_mean_auxiliary_revision_revise_gain"] = (
            sum(item.get("auxiliary_revision_revise_gain", 0) for item in subset) / len(subset)
        )
        summary[f"{mode}_mean_auxiliary_revision_retire_delta"] = (
            sum(item.get("auxiliary_revision_retire_delta", 0) for item in subset) / len(subset)
        )
        summary[f"{mode}_mean_future_loss"] = sum(item["future_loss"] for item in subset) / len(subset)
| if "stable_mean_future_influence" in summary and "conflict_mean_future_influence" in summary: | |
| summary["future_influence_gap_conflict_minus_stable"] = ( | |
| summary["conflict_mean_future_influence"] - summary["stable_mean_future_influence"] | |
| ) | |
| if "stable_mean_anchor_future_influence" in summary and "conflict_mean_anchor_future_influence" in summary: | |
| summary["anchor_future_influence_gap_conflict_minus_stable"] = ( | |
| summary["conflict_mean_anchor_future_influence"] - summary["stable_mean_anchor_future_influence"] | |
| ) | |
| if "stable_mean_future_span_overlap" in summary and "conflict_mean_future_span_overlap" in summary: | |
| summary["future_span_overlap_gap_conflict_minus_stable"] = ( | |
| summary["conflict_mean_future_span_overlap"] - summary["stable_mean_future_span_overlap"] | |
| ) | |
| if "stable_mean_auxiliary_proposal_count" in summary and "conflict_mean_auxiliary_proposal_count" in summary: | |
| summary["auxiliary_proposal_count_gap_conflict_minus_stable"] = ( | |
| summary["conflict_mean_auxiliary_proposal_count"] - summary["stable_mean_auxiliary_proposal_count"] | |
| ) | |
| if "stable_mean_auxiliary_proposal_score" in summary and "conflict_mean_auxiliary_proposal_score" in summary: | |
| summary["auxiliary_proposal_score_gap_conflict_minus_stable"] = ( | |
| summary["conflict_mean_auxiliary_proposal_score"] - summary["stable_mean_auxiliary_proposal_score"] | |
| ) | |
| if ( | |
| "stable_mean_auxiliary_revision_matched_anchor_count" in summary | |
| and "conflict_mean_auxiliary_revision_matched_anchor_count" in summary | |
| ): | |
| summary["auxiliary_revision_match_gap_conflict_minus_stable"] = ( | |
| summary["conflict_mean_auxiliary_revision_matched_anchor_count"] | |
| - summary["stable_mean_auxiliary_revision_matched_anchor_count"] | |
| ) | |
| if ( | |
| "stable_mean_auxiliary_revision_revise_gain" in summary | |
| and "conflict_mean_auxiliary_revision_revise_gain" in summary | |
| ): | |
| summary["auxiliary_revision_revise_gain_conflict_minus_stable"] = ( | |
| summary["conflict_mean_auxiliary_revision_revise_gain"] | |
| - summary["stable_mean_auxiliary_revision_revise_gain"] | |
| ) | |
| if ( | |
| "stable_mean_auxiliary_revision_retire_delta" in summary | |
| and "conflict_mean_auxiliary_revision_retire_delta" in summary | |
| ): | |
| summary["auxiliary_revision_retire_delta_conflict_minus_stable"] = ( | |
| summary["conflict_mean_auxiliary_revision_retire_delta"] | |
| - summary["stable_mean_auxiliary_revision_retire_delta"] | |
| ) | |
| return summary | |


def build_markdown_report(
    model_name: str,
    device: str,
    max_length: int,
    future_window: int,
    span_threshold: float,
    top_spans: int,
    seed: int,
    results: list[dict[str, Any]],
    summary: dict[str, Any],
) -> str:
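    """Render the probe results and summary as a markdown report string."""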
    lines = [
        "# Qwen Future Influence Probe",
        "",
        f"Date: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M UTC')}",
        f"Model: `{model_name}`",
        f"Device: `{device}`",
        f"Max length: `{max_length}`",
        f"Future window: `{future_window}`",
        f"Span threshold: `{span_threshold:.2f}`",
        f"Top spans per case: `{top_spans}`",
        f"Seed: `{seed}`",
        "",
        "## Summary",
        "",
        f"- Cases: `{summary['case_count']}`",
        f"- Stable cases: `{summary['stable_count']}`",
        f"- Conflict cases: `{summary['conflict_count']}`",
    ]
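    # Gap bullets are optional; they appear only when summarize_results saw both modes.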
| if "future_influence_gap_conflict_minus_stable" in summary: | |
| lines.append( | |
| f"- Conflict minus stable mean future influence gap: `{summary['future_influence_gap_conflict_minus_stable']:.4f}`" | |
| ) | |
| if "anchor_future_influence_gap_conflict_minus_stable" in summary: | |
| lines.append( | |
| f"- Conflict minus stable active-anchor future influence gap: `{summary['anchor_future_influence_gap_conflict_minus_stable']:.4f}`" | |
| ) | |
| if "future_span_overlap_gap_conflict_minus_stable" in summary: | |
| lines.append( | |
| f"- Conflict minus stable future-span overlap gap: `{summary['future_span_overlap_gap_conflict_minus_stable']:.4f}`" | |
| ) | |
| if "auxiliary_proposal_count_gap_conflict_minus_stable" in summary: | |
| lines.append( | |
| f"- Conflict minus stable auxiliary proposal-count gap: `{summary['auxiliary_proposal_count_gap_conflict_minus_stable']:.4f}`" | |
| ) | |
| if "auxiliary_proposal_score_gap_conflict_minus_stable" in summary: | |
| lines.append( | |
| f"- Conflict minus stable auxiliary proposal-score gap: `{summary['auxiliary_proposal_score_gap_conflict_minus_stable']:.4f}`" | |
| ) | |
| if "auxiliary_revision_match_gap_conflict_minus_stable" in summary: | |
| lines.append( | |
| f"- Conflict minus stable auxiliary revision-match gap: `{summary['auxiliary_revision_match_gap_conflict_minus_stable']:.4f}`" | |
| ) | |
| if "auxiliary_revision_revise_gain_conflict_minus_stable" in summary: | |
| lines.append( | |
| f"- Conflict minus stable auxiliary revise-gain gap: `{summary['auxiliary_revision_revise_gain_conflict_minus_stable']:.4f}`" | |
| ) | |
| if "auxiliary_revision_retire_delta_conflict_minus_stable" in summary: | |
| lines.append( | |
| f"- Conflict minus stable auxiliary retire-delta gap: `{summary['auxiliary_revision_retire_delta_conflict_minus_stable']:.4f}`" | |
| ) | |
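    # Case table: header row, alignment row, then one formatted row per case.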
    lines.extend(
        [
            "",
            "## Case table",
            "",
            "| Family | Case | Expected | Tokens | Active | Aux proposals | Aux matches | Aux revise gain | Mean future influence | Anchor-position mean | Span overlap | Max influence | Future loss |",
            "|---|---|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|",
        ]
    )
    for item in results:
        lines.append(
            "| {family} | {name} | {expected_mode} | {tokens} | {num_active} | {auxiliary_proposal_count} | {auxiliary_revision_matched_anchor_count} | {auxiliary_revision_revise_gain:+d} | {mean_future_influence:.4f} | "
            "{anchor_position_mean_future_influence:.4f} | {future_span_overlap_ratio:.4f} | {max_future_influence:.4f} | {future_loss:.4f} |".format(**item)
        )
| lines.extend(["", "## Top future-influence tokens", ""]) | |
| for item in results: | |
| lines.append(f"### {item['name']}") | |
| for token in item["top_future_tokens"]: | |
| lines.append( | |
| f"- pos `{token['position']}` | token `{token['token_text']}` | id `{token['token_id']}` | score `{token['score']:.4f}`" | |
| ) | |
| lines.append("") | |
| lines.extend(["## High future-influence spans", ""]) | |
| for item in results: | |
| lines.append(f"### {item['name']}") | |
| if not item["future_spans"]: | |
| lines.append("- no spans crossed the configured threshold") | |
| for span in item["future_spans"]: | |
| lines.append( | |
| f"- span `{span['start']}-{span['end']}` | mean `{span['mean_score']:.4f}` | max `{span['max_score']:.4f}` | text `{span['text']}`" | |
| ) | |
| if item["active_anchor_spans"]: | |
| anchor_text = ", ".join( | |
| f"{span['start']}-{span['end']}" for span in item["active_anchor_spans"] | |
| ) | |
| lines.append(f"- active anchor spans: `{anchor_text}`") | |
| else: | |
| lines.append("- active anchor spans: none") | |
| lines.append( | |
| f"- future-span overlap ratio: `{item['future_span_overlap_ratio']:.4f}` | anchor-span overlap ratio: `{item['anchor_span_overlap_ratio']:.4f}`" | |
| ) | |
| if item["future_hint_candidates"]: | |
| lines.append("- proposal-like future hint spans:") | |
| for span in item["future_hint_candidates"]: | |
| lines.append( | |
| f" - `{span['start']}-{span['end']}` | mean `{span['mean_score']:.4f}` | text `{span['text']}`" | |
| ) | |
| if item["auxiliary_proposals"]: | |
| lines.append("- auxiliary proposals:") | |
| for proposal in item["auxiliary_proposals"]: | |
| lines.append( | |
| f" - `{proposal['proposal_span'][0]}-{proposal['proposal_span'][1]}` | score `{proposal['proposal_score']:.4f}` | text `{proposal['proposal_text']}`" | |
| ) | |
| lines.append( | |
| f"- auxiliary revision: matches `{item['auxiliary_revision_matched_anchor_count']}`, " | |
| f"mean alt prob `{item['auxiliary_revision_mean_alt_prob']:.4f}`, " | |
| f"revise gain `{item['auxiliary_revision_revise_gain']:+d}`, " | |
| f"retire delta `{item['auxiliary_revision_retire_delta']:+d}`" | |
| ) | |
| lines.append("") | |
| lines.extend( | |
| [ | |
| "## Interpretation", | |
| "", | |
| "- This report is an experimental midpoint between delta-hidden heuristics and full leave-one-out KL.", | |
| "- Scores are based on gradient influence of token positions on a future autoregressive loss window.", | |
| "- High-scoring positions are candidates for semantically important context even when local hidden-state jumps are ambiguous.", | |
| "- Grouped high-influence spans help test whether future-attribution concentrates on the same regions as current active anchors or highlights missed context spans.", | |
| ] | |
| ) | |
| return "\n".join(lines) + "\n" | |


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
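    """Parse CLI flags; each option accepts both snake_case and kebab-case spellings."""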
    parser = argparse.ArgumentParser(description="Run future-gradient influence diagnostics on top of Qwen hidden states.")
    parser.add_argument("--model", "--model-name", dest="model", type=str, default="Qwen/Qwen2.5-1.5B")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max_length", "--max-length", dest="max_length", type=int, default=192)
    parser.add_argument("--future_window", "--future-window", dest="future_window", type=int, default=16)
    parser.add_argument("--top_k", "--top-k", dest="top_k", type=int, default=5)
    parser.add_argument("--span_threshold", "--span-threshold", dest="span_threshold", type=float, default=0.75)
    parser.add_argument("--top_spans", "--top-spans", dest="top_spans", type=int, default=5)
    parser.add_argument("--case_filter", "--case-filter", dest="case_filter", type=str, default="")
    parser.add_argument("--seed", type=int, default=7)
    parser.add_argument(
        "--output_json",
        "--output-json",
        dest="output_json",
        type=Path,
        default=ROOT / "archive" / "qwen_future_influence_probe.json",
    )
    parser.add_argument(
        "--output_md",
        "--output-md",
        dest="output_md",
        type=Path,
        default=ROOT / "docs" / "research" / "qwen_future_influence_probe.md",
    )
    args, _ = parser.parse_known_args(argv)
    return args


def main() -> None:
    args = parse_args()
    torch.manual_seed(args.seed)
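    # Start from the toy config and override the anchor thresholds used by the overlay.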
    cfg = replace(
        TOY_CONFIG,
        anchor_threshold=0.10,
        anchor_revision_threshold=0.35,
        anchor_contradiction_threshold=0.20,
        anchor_dead_end_threshold=0.50,
    )
    overlay = QwenAnchorOverlay.from_pretrained(
        model_name=args.model,
        cfg=cfg,
        device=args.device,
        torch_dtype=torch.float16 if "cuda" in args.device else None,
    )
    overlay.eval()
    print("=== Qwen Future Influence Probe ===")
    print(f"model={args.model}")
    print(f"device={args.device}")
    print()
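    # Optional case filter: comma-separated substrings matched case-insensitively
    # against case names and families.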
    results: list[dict[str, Any]] = []
    case_filters = [part.strip() for part in args.case_filter.split(",") if part.strip()]
    cases = make_qwen_probe_cases()
    if case_filters:
        cases = [
            case
            for case in cases
            if any(
                needle.lower() in case.name.lower() or needle.lower() in case.family.lower()
                for needle in case_filters
            )
        ]
    for case in cases:
        result = collect_case_result(
            overlay=overlay,
            case_name=case.name,
            case_family=case.family,
            case_description=case.description,
            case_prompt=case.prompt,
            expected_mode=case.expected_mode,
            max_length=args.max_length,
            future_window=args.future_window,
            top_k=args.top_k,
            span_threshold=args.span_threshold,
            top_spans=args.top_spans,
        )
        results.append(result)
        print(f"--- {case.name} ---")
        print(f"family={case.family}")
        print(f"expected_mode={case.expected_mode}")
        print(f"mean_future_influence={result['mean_future_influence']:.4f}")
        print(f"anchor_position_mean_future_influence={result['anchor_position_mean_future_influence']:.4f}")
        print(f"future_span_overlap_ratio={result['future_span_overlap_ratio']:.4f}")
        print(f"max_future_influence={result['max_future_influence']:.4f}")
        print(f"future_loss={result['future_loss']:.4f}")
        print()
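    # Aggregate, then persist both the machine-readable payload and the markdown report.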
    summary = summarize_results(results)
    payload = {
        "generated_at": datetime.now(timezone.utc).isoformat(),
        "model": args.model,
        "device": args.device,
        "max_length": args.max_length,
        "future_window": args.future_window,
        "seed": args.seed,
        "results": results,
        "summary": summary,
    }
    report = build_markdown_report(
        model_name=args.model,
        device=args.device,
        max_length=args.max_length,
        future_window=args.future_window,
        span_threshold=args.span_threshold,
        top_spans=args.top_spans,
        seed=args.seed,
        results=results,
        summary=summary,
    )
    args.output_json.parent.mkdir(parents=True, exist_ok=True)
    args.output_md.parent.mkdir(parents=True, exist_ok=True)
    args.output_json.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    args.output_md.write_text(report, encoding="utf-8")
    print(f"saved_json={args.output_json}")
    print(f"saved_md={args.output_md}")


if __name__ == "__main__":
    main()