# shadow_analysis.py
# Provenance: uploaded by Echo9Zulu ("Upload shadow_analysis.py", commit 560de10, verified)
"""
Anticipation Analysis - Temporal View
======================================
True planning events only: tokens that appeared in top-k (ranks 2-20)
but were NOT chosen, then later BECAME the chosen token.
NEW: Plots anticipation events across sequence position to reveal
how planning behavior evolves through generation.
"""
import json
import math
from collections import defaultdict
from pathlib import Path
from typing import Optional, List, Dict, Any, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize
from matplotlib.cm import ScalarMappable
try:
# Optional dependency: improves label placement in dense plots
from adjustText import adjust_text
except ImportError:
adjust_text = None
# =============================================================================
# CONFIGURATION CONSTANTS
# =============================================================================
# =============================================================================
# CORE ANALYSIS PARAMETERS
# =============================================================================
# Top-k candidate analysis
TOP_K = 20 # Number of top candidates to consider at each position (rank 1-20)
# Temporal hierarchy / persistence settings
HIERARCHY_MIN_EPISODES_FOR_CLUSTER = 5 # Minimum (token, chosen_position) episodes to include a token in clustering
HIERARCHY_N_CLUSTERS = 50 # Number of clusters for token hierarchy analysis
HIERARCHY_TOP_EPISODES_PER_COMPLETION = 5000 # How many longest-horizon episodes to visualize per completion
HIERARCHY_MAX_COMPLETIONS_TO_PLOT = 10 # Limit plots when analyzing many completions
# =============================================================================
# INPUT / OUTPUT & FILTERING
# =============================================================================
OUTPUT_DIR = Path("anticipation_analysis_output_gemma1b_basemodel_poems")
INPUT_FILE = r"T:\AI\gemma\logprobs\gemma3_1b_basemodel_poems.jsonl"
MAX_FUTURE_SEARCH = 8192 # Max tokens to search backward for anticipation (lookback window despite the "FUTURE" name)
MIN_SEQ_LEN = 100 # Minimum sequence length to analyze
MIN_DISTANCE = 5 # Minimum distance between anticipation and selection
NUM_COMPLETIONS = 3000 # Number of completions to process
MAX_LABELS_PER_PANEL = 80 # Max token labels per visualization panel
# Minimum anticipation events required to include a completion in analysis
MIN_EVENTS_PER_COMPLETION = 5
# Topic-anchor controls
KEEP_ONLY_TOKENS_CHOSEN_ONCE = False # True => restrict to tokens chosen exactly once in completion
EXCLUDE_PROMPT_TOKENS = False # True => remove tokens that appear in the prompt (anchor-y)
# Context window for source/target context extraction
CTX_WINDOW = 5
# =============================================================================
# DIAGNOSTIC THRESHOLDS
# =============================================================================
# Horizon threshold for "long-range" planning diagnostic output
LONG_HORIZON_THRESHOLD = 200
# Number of bins for the events-per-token ratio analysis (EventSeqLenAnalysis)
EVENT_SEQLEN_N_BINS = 10
def load_data(filepath: str, max_records: Optional[int] = None) -> List[Dict[str, Any]]:
    """Load JSONL records from file.

    Blank lines are skipped and do not count toward the record limit.

    Args:
        filepath: Path to a JSONL file (one JSON object per line).
        max_records: If not None, stop after loading this many records.
            Checked with ``is not None`` so an explicit 0 loads nothing;
            the previous truthiness check made 0 mean "no limit".

    Returns:
        List of parsed JSON records, in file order.
    """
    records: List[Dict[str, Any]] = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            # Stop once the requested number of *records* has been loaded
            # (previously the limit counted raw lines, including blanks).
            if max_records is not None and len(records) >= max_records:
                break
            if line.strip():
                records.append(json.loads(line))
    return records
def get_true_token(step_data: Dict[str, Any]) -> str:
    """Return the candidate token with the highest log-probability at this step."""
    # Ties keep the first-seen candidate, matching dict insertion order.
    best_token, _best_logprob = max(
        step_data['top_candidates'].items(), key=lambda kv: kv[1]
    )
    return best_token
class EventSeqLenAnalysis:
    """
    Analyzes the ratio of anticipation events to sequence length.
    Creates 10 dynamically-sized bins based on the range of ratios observed,
    outputs each bin to its own markdown file, and creates a summary file
    with user prompts organized by bin.
    Output structure:
    event_seq_len/
    bin_01_0.00_to_0.15.md
    bin_02_0.15_to_0.30.md
    ...
    bin_10_1.35_to_1.50.md
    prompts_by_bin.md
    """
    def __init__(self, n_bins: int = EVENT_SEQLEN_N_BINS):
        # Number of ratio bins to partition completions into.
        self.n_bins = n_bins
        # One dict per completion, populated via add_completion().
        self.completions: List[Dict[str, Any]] = []
    def add_completion(
        self,
        completion_idx: int,
        n_events: int,
        seq_len: int,
        original_messages: List[Dict],
        completion_text: str,
        prompt: str = "",
        response: str = "",
    ) -> None:
        """Add a completion's data for later binning analysis."""
        # Events per generated token; guarded against zero-length sequences.
        ratio = n_events / seq_len if seq_len > 0 else 0.0
        self.completions.append({
            'completion_idx': completion_idx,
            'n_events': n_events,
            'seq_len': seq_len,
            'ratio': ratio,
            'original_messages': original_messages,
            'completion_text': completion_text,
            'prompt': prompt,
            'response': response,
        })
    def _compute_bin_edges(self) -> np.ndarray:
        """Compute bin edges dynamically based on observed ratio range.

        Returns n_bins + 1 evenly spaced edges covering [min, max] ratios,
        slightly widened so no observed value sits exactly on the outer edges.
        """
        if not self.completions:
            # Degenerate fallback: unit range when nothing was collected.
            return np.linspace(0, 1, self.n_bins + 1)
        ratios = [c['ratio'] for c in self.completions]
        min_ratio = min(ratios)
        max_ratio = max(ratios)
        # Add small padding to ensure all values fall within bins
        # (fixed 0.1 when all ratios are identical, to avoid a zero-width range).
        padding = (max_ratio - min_ratio) * 0.001 if max_ratio > min_ratio else 0.1
        return np.linspace(min_ratio - padding, max_ratio + padding, self.n_bins + 1)
    def _assign_bins(self, bin_edges: np.ndarray) -> List[int]:
        """Assign each completion to a bin based on its ratio.

        Returns a list of bin indices (0-based) parallel to self.completions.
        """
        bins = []
        for c in self.completions:
            ratio = c['ratio']
            # Find which bin this ratio falls into (searching the upper edges).
            bin_idx = np.searchsorted(bin_edges[1:], ratio, side='right')
            bin_idx = min(bin_idx, self.n_bins - 1)  # Clamp to valid range
            bins.append(bin_idx)
        return bins
    def _get_user_message(self, original_messages: List[Dict]) -> str:
        """Extract the first user-role message from original_messages ("" if none)."""
        for msg in original_messages:
            if msg.get('role', '').lower() == 'user':
                return msg.get('content', '')
        return ""
    def write_analysis(self, output_dir: Path) -> None:
        """
        Write all analysis files to the event_seq_len subfolder.
        Creates:
        - One markdown file per bin with full completion details
        - prompts_by_bin.md with user messages organized by bin
        """
        if not self.completions:
            print("No completions to analyze for event/seqlen binning.")
            return
        # Create subfolder
        event_dir = output_dir / "event_seq_len"
        event_dir.mkdir(exist_ok=True)
        # Compute bins
        bin_edges = self._compute_bin_edges()
        bin_assignments = self._assign_bins(bin_edges)
        # Organize completions by bin
        bins_data: Dict[int, List[Dict]] = defaultdict(list)
        for comp, bin_idx in zip(self.completions, bin_assignments):
            bins_data[bin_idx].append(comp)
        # Write individual bin files (one markdown file per bin, even if empty)
        for bin_idx in range(self.n_bins):
            low = bin_edges[bin_idx]
            high = bin_edges[bin_idx + 1]
            filename = f"bin_{bin_idx + 1:02d}_{low:.2f}_to_{high:.2f}.md"
            filepath = event_dir / filename
            bin_completions = bins_data.get(bin_idx, [])
            with open(filepath, 'w', encoding='utf-8') as f:
                f.write(f"# Event/SeqLen Ratio Bin {bin_idx + 1}\n\n")
                f.write(f"**Ratio Range:** {low:.4f} to {high:.4f}\n\n")
                f.write(f"**Completions in Bin:** {len(bin_completions)}\n\n")
                f.write("---\n\n")
                if not bin_completions:
                    f.write("*No completions in this bin.*\n")
                else:
                    # Sort by ratio within bin
                    bin_completions_sorted = sorted(bin_completions, key=lambda x: x['ratio'])
                    for i, comp in enumerate(bin_completions_sorted, 1):
                        f.write(f"## Completion {comp['completion_idx']} (Sample {i}/{len(bin_completions)})\n\n")
                        f.write(f"- **Events:** {comp['n_events']}\n")
                        f.write(f"- **Seq Length:** {comp['seq_len']}\n")
                        f.write(f"- **Ratio:** {comp['ratio']:.4f}\n\n")
                        f.write("### Original Messages\n\n")
                        for msg in comp['original_messages']:
                            role = msg.get('role', 'unknown')
                            content = msg.get('content', '')
                            f.write(f"**{role}:**\n```\n{content}\n```\n\n")
                        f.write("### Completion Text\n\n")
                        f.write(f"```\n{comp['completion_text']}\n```\n\n")
                        f.write("---\n\n")
            print(f"Saved: {filepath} ({len(bin_completions)} completions)")
        # Write prompts_by_bin.md (summary: user prompts grouped by bin)
        prompts_path = event_dir / "prompts_by_bin.md"
        with open(prompts_path, 'w', encoding='utf-8') as f:
            f.write("# User Prompts by Event/SeqLen Ratio Bin\n\n")
            f.write("This file contains all user prompts organized by their event/seqlen ratio bin.\n\n")
            # Summary statistics
            ratios = [c['ratio'] for c in self.completions]
            f.write("## Summary Statistics\n\n")
            f.write(f"- **Total Completions:** {len(self.completions)}\n")
            f.write(f"- **Min Ratio:** {min(ratios):.4f}\n")
            f.write(f"- **Max Ratio:** {max(ratios):.4f}\n")
            f.write(f"- **Mean Ratio:** {np.mean(ratios):.4f}\n")
            f.write(f"- **Median Ratio:** {np.median(ratios):.4f}\n\n")
            f.write("---\n\n")
            for bin_idx in range(self.n_bins):
                low = bin_edges[bin_idx]
                high = bin_edges[bin_idx + 1]
                bin_completions = bins_data.get(bin_idx, [])
                f.write(f"# Bin {bin_idx + 1}: {low:.4f} to {high:.4f}\n\n")
                f.write(f"**Count:** {len(bin_completions)} completions\n\n")
                if not bin_completions:
                    f.write("*No completions in this bin.*\n\n")
                else:
                    # Sort by ratio
                    bin_completions_sorted = sorted(bin_completions, key=lambda x: x['ratio'])
                    for comp in bin_completions_sorted:
                        user_msg = self._get_user_message(comp['original_messages'])
                        if user_msg:
                            # Truncate very long prompts for readability
                            display_msg = user_msg[:500] + "..." if len(user_msg) > 500 else user_msg
                            f.write(f"- **[Completion {comp['completion_idx']}, ratio={comp['ratio']:.3f}]:** {display_msg}\n\n")
                        else:
                            f.write(f"- **[Completion {comp['completion_idx']}, ratio={comp['ratio']:.3f}]:** *(no user message)*\n\n")
                f.write("---\n\n")
        print(f"Saved: {prompts_path}")
        # Print bin distribution summary (one █ per 5 completions)
        print("\n--- Event/SeqLen Ratio Bin Distribution ---")
        for bin_idx in range(self.n_bins):
            low = bin_edges[bin_idx]
            high = bin_edges[bin_idx + 1]
            count = len(bins_data.get(bin_idx, []))
            bar = "█" * (count // 5) if count > 0 else ""
            print(f" Bin {bin_idx + 1:2d} [{low:6.3f} - {high:6.3f}]: {count:4d} {bar}")
class ShadowTokens:
    """
    Finds and analyzes shadow tokens: tokens that appear in the top-k candidates
    (ranks 2-20) but are not chosen, then later become the chosen token.
    This represents the model's "anticipation" of future tokens.
    """
    def __init__(
        self,
        max_future_search: int = MAX_FUTURE_SEARCH,
        min_distance: int = MIN_DISTANCE,
        min_seq_len: int = MIN_SEQ_LEN,
        top_k: int = TOP_K,
        ctx_window: int = CTX_WINDOW,
        keep_only_tokens_chosen_once: bool = KEEP_ONLY_TOKENS_CHOSEN_ONCE,
        exclude_prompt_tokens: bool = EXCLUDE_PROMPT_TOKENS,
    ):
        # Maximum lookback window, in tokens, when searching for earlier tail
        # mentions (despite the "future" name, find_events searches backward).
        self.max_future_search = max_future_search
        # Minimum gap between a tail mention and the eventual chosen position.
        self.min_distance = min_distance
        # Completions shorter than this are skipped entirely.
        self.min_seq_len = min_seq_len
        # Only candidate ranks in [2, top_k] count as "tail" mentions.
        self.top_k = top_k
        # Half-width of the context snippet extracted around source/target positions.
        self.ctx_window = ctx_window
        # Optional topic-anchor filters (see module-level toggle constants).
        self.keep_only_tokens_chosen_once = keep_only_tokens_chosen_once
        self.exclude_prompt_tokens = exclude_prompt_tokens
    def _extract_prompt(self, record: Dict[str, Any]) -> str:
        """Extract prompt text from record.

        Prefers the precomputed 'full_prompt_text' field; otherwise falls back
        to a "[role]: content" rendering of 'original_messages'.
        """
        prompt = record.get("full_prompt_text", "")
        if not prompt:
            original_messages = record.get("original_messages", [])
            if original_messages:
                prompt = "\n\n".join(
                    f"[{msg.get('role', 'unknown')}]: {msg.get('content', '')}"
                    for msg in original_messages
                )
        return prompt or ""
    def _build_position_candidates(
        self, logprobs_data: List[Dict]
    ) -> Tuple[List[Dict[str, Tuple[int, float]]], List[float]]:
        """
        Build index: for each position, map token -> (rank, logprob),
        and store kth logprob as a "top-k cutoff" reference for margins.
        """
        position_candidates = []
        position_kth_logprob = []
        for step in logprobs_data:
            candidates = step.get('top_candidates', {}) or {}
            cleaned = []
            for tok, lp in candidates.items():
                try:
                    lp_f = float(lp)
                except (TypeError, ValueError):
                    # Skip malformed logprob values rather than failing the record.
                    continue
                cleaned.append((tok, lp_f))
            # Highest logprob first, so list index + 1 is the candidate's rank.
            sorted_cands = sorted(cleaned, key=lambda x: x[1], reverse=True)
            token_to_rank = {}
            for rank, (tok, lp) in enumerate(sorted_cands, start=1):
                token_to_rank[tok] = (rank, float(lp))
            k = len(sorted_cands)
            # Logprob of the top_k-th (or last available) candidate; -inf when empty.
            kth_lp = float(sorted_cands[min(self.top_k, k) - 1][1]) if k > 0 else -np.inf
            position_candidates.append(token_to_rank)
            position_kth_logprob.append(kth_lp)
        return position_candidates, position_kth_logprob
    def find_events(
        self, record: Dict[str, Any], record_idx: int
    ) -> Tuple[Optional[List[Dict]], int]:
        """
        Find anticipation events:
        For each chosen token at position j, look back to find where that token
        appeared in the tail (lower-ranked candidates) at earlier positions.
        Definition: At position j, token T is chosen. If T appeared in the
        candidates (rank > 1) at an earlier position i, that's anticipation.
        Note: every qualifying (i, j) pair is emitted, so one chosen position
        can produce multiple events (one per earlier tail mention). The scan
        is O(seq_len * max_future_search) in the worst case.
        Returns:
            Tuple of (list of event dicts or None, sequence length)
        """
        logprobs_data = record.get("logprobs_data", [])
        seq_len = len(logprobs_data)
        prompt = self._extract_prompt(record)
        if seq_len < self.min_seq_len:
            # Too short to analyze; still report the observed length.
            return None, seq_len
        # Build sequence of chosen tokens
        chosen_sequence = [get_true_token(step) for step in logprobs_data]
        # Count how often each token is chosen in this completion (topic anchors tend to repeat)
        chosen_counts = defaultdict(int)
        for tok in chosen_sequence:
            chosen_counts[tok] += 1
        prompt_text = prompt
        # Response is reconstructed from the chosen tokens in logprobs
        response = ''.join(chosen_sequence)
        position_candidates, position_kth_logprob = self._build_position_candidates(logprobs_data)
        results = []
        # For each position j where a token is chosen
        for j in range(self.min_distance, seq_len):
            chosen_token = chosen_sequence[j]
            # --- Topic-anchor filters ---
            chosen_count = int(chosen_counts.get(chosen_token, 0))
            if self.keep_only_tokens_chosen_once and chosen_count != 1:
                continue
            token_stripped = chosen_token.strip()
            # NOTE(review): substring containment — short tokens may match the
            # prompt incidentally (e.g. "a"); confirm this is intended.
            token_in_prompt = bool(token_stripped and (token_stripped in prompt_text))
            if self.exclude_prompt_tokens and token_in_prompt:
                continue
            # Look back at earlier positions i < j (bounded by the lookback window)
            search_start = max(0, j - self.max_future_search)
            for i in range(search_start, j - self.min_distance + 1):
                # Check if the chosen token at j appeared in candidates at i
                if chosen_token in position_candidates[i]:
                    rank_at_i, logprob_at_i = position_candidates[i][chosen_token]
                    # Enforce tail ranks 2-TOP_K explicitly:
                    # 1. Rank must be in [2, TOP_K] (not rank-1, not beyond top-k)
                    # 2. Token must NOT be the chosen token at position i
                    # This prevents counting rank-1 alternatives and prevents counting
                    # the chosen token even if it was sampled from a lower rank
                    chosen_at_i = chosen_sequence[i]
                    if 2 <= rank_at_i <= self.top_k and chosen_at_i != chosen_token:
                        prob = math.exp(logprob_at_i)
                        distance = j - i
                        kth_lp = position_kth_logprob[i]
                        # Margin above the top-k cutoff; NaN when the cutoff is -inf.
                        margin_to_k = float(logprob_at_i - kth_lp) if np.isfinite(kth_lp) else np.nan
                        # Get context (±ctx_window chosen tokens around i and j)
                        ctx_window = int(self.ctx_window)
                        source_ctx_start = max(0, i - ctx_window)
                        source_ctx_end = min(seq_len, i + ctx_window + 1)
                        source_context = ''.join(chosen_sequence[source_ctx_start:source_ctx_end])
                        target_ctx_start = max(0, j - ctx_window)
                        target_ctx_end = min(seq_len, j + ctx_window + 1)
                        target_context = ''.join(chosen_sequence[target_ctx_start:target_ctx_end])
                        results.append({
                            'completion_idx': record_idx,
                            'seq_len': seq_len,
                            'source_pos': i,
                            'rel_pos': i / seq_len,
                            'token': chosen_token,
                            'token_display': chosen_token.strip()[:15],
                            'rank_when_anticipated': rank_at_i,
                            'prob_when_anticipated': prob,
                            'chosen_at_source': chosen_at_i,
                            'future_pos': j,
                            'distance': distance,
                            'source_context': source_context,
                            'target_context': target_context,
                            'prompt': prompt,
                            'response': response,
                            'logprob_at_source': float(logprob_at_i),
                            'kth_logprob_at_source': float(kth_lp) if np.isfinite(kth_lp) else np.nan,
                            'logprob_margin_to_k': margin_to_k,
                            'token_chosen_count_in_completion': chosen_count,
                            'token_in_prompt': token_in_prompt,
                        })
        return results, seq_len
    def find_forward_realizations(
        self, record: Dict[str, Any], record_idx: int
    ) -> Tuple[Optional[List[Dict]], int]:
        """
        Forward-looking baseline:
        At each position i, for each tail candidate token (rank 2-20),
        check whether it becomes the chosen token at some future position j.
        Record the earliest such j (first realization).
        Unlike find_events, each (position, tail token) contributes at most
        one event (its first realization).
        Returns:
            Tuple of (list of event dicts or None, sequence length)
        """
        logprobs_data = record.get("logprobs_data", [])
        seq_len = len(logprobs_data)
        if seq_len < self.min_seq_len:
            return None, seq_len
        prompt = self._extract_prompt(record)
        prompt_text = prompt
        chosen_sequence = [get_true_token(step) for step in logprobs_data]
        # Precompute chosen positions + counts for each token
        chosen_positions = defaultdict(list)
        chosen_counts = defaultdict(int)
        for j, tok in enumerate(chosen_sequence):
            chosen_positions[tok].append(j)
            chosen_counts[tok] += 1
        results = []
        for i, step in enumerate(logprobs_data):
            candidates = step.get("top_candidates", {}) or {}
            cleaned = []
            for tok, lp in candidates.items():
                try:
                    lp_f = float(lp)
                except (TypeError, ValueError):
                    # Skip malformed logprob values rather than failing the record.
                    continue
                cleaned.append((tok, lp_f))
            sorted_cands = sorted(cleaned, key=lambda x: x[1], reverse=True)
            # kth cutoff for margin diagnostics
            k = len(sorted_cands)
            kth_lp = float(sorted_cands[min(self.top_k, k) - 1][1]) if k > 0 else -np.inf
            # Enforce tail ranks 2-TOP_K explicitly:
            # Only consider tokens with rank in [2, TOP_K] that are NOT the chosen token
            chosen_at_i = chosen_sequence[i]
            tail = []
            for rank, (tok, lp) in enumerate(sorted_cands[:self.top_k], start=1):
                # Skip rank 1 (would be the chosen token in greedy decoding)
                if rank < 2:
                    continue
                # Skip if this token is actually the chosen token (handles non-greedy sampling)
                if tok == chosen_at_i:
                    continue
                tail.append((tok, rank, float(lp)))
            if not tail:
                continue
            for tok, rank, lp in tail:
                prob = float(np.exp(lp))
                margin_to_k = float(lp - kth_lp) if np.isfinite(kth_lp) else np.nan
                # --- Optional topic-anchor filters (same toggles as backward analysis) ---
                chosen_count = int(chosen_counts.get(tok, 0))
                if self.keep_only_tokens_chosen_once and chosen_count != 1:
                    continue
                tok_stripped = str(tok).strip()
                # NOTE(review): substring containment, same caveat as find_events.
                token_in_prompt = bool(tok_stripped and (tok_stripped in prompt_text))
                if self.exclude_prompt_tokens and token_in_prompt:
                    continue
                # find first chosen position j > i+min_dist
                js = chosen_positions.get(tok, [])
                j_future = None
                for j in js:
                    if j >= i + self.min_distance:
                        j_future = j
                        break
                if j_future is None:
                    # Token never realized far enough in the future.
                    continue
                results.append({
                    "completion_idx": record_idx,
                    "source_pos": i,
                    "future_pos": int(j_future),
                    "distance": int(j_future - i),
                    "token": tok,
                    "rank_when_anticipated": int(rank),
                    "prob_when_anticipated": prob,
                    "logprob_at_source": float(lp),
                    "kth_logprob_at_source": float(kth_lp) if np.isfinite(kth_lp) else np.nan,
                    "logprob_margin_to_k": margin_to_k,
                    "token_chosen_count_in_completion": chosen_count,
                    "token_in_prompt": token_in_prompt,
                })
        return results, seq_len
class PlanningEpisodes:
    """
    Analyzes planning episodes: aggregating shadow token events into
    persistence/horizon metrics that capture how tokens evolve from
    first appearance in the tail to eventual selection.
    """
    def __init__(
        self,
        min_episodes_for_cluster: int = HIERARCHY_MIN_EPISODES_FOR_CLUSTER,
        n_clusters: int = HIERARCHY_N_CLUSTERS,
        top_episodes_per_completion: int = HIERARCHY_TOP_EPISODES_PER_COMPLETION,
        max_completions_to_plot: int = HIERARCHY_MAX_COMPLETIONS_TO_PLOT,
        long_horizon_threshold: int = LONG_HORIZON_THRESHOLD,
    ):
        # Tokens with fewer episodes than this stay unclustered (cluster = -1).
        self.min_episodes_for_cluster = min_episodes_for_cluster
        # Target number of clusters for the Ward/maxclust cut.
        self.n_clusters = n_clusters
        self.top_episodes_per_completion = top_episodes_per_completion
        self.max_completions_to_plot = max_completions_to_plot
        self.long_horizon_threshold = long_horizon_threshold
    def build(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Collapse anticipation events into "commit episodes":
        one row per (completion_idx, token, future_pos), aggregating all earlier
        source positions where the token appeared as a non-chosen candidate.
        This captures "persistence" (how often the same token reappears in the tail
        between first appearance and eventual selection).
        Returns an empty DataFrame if required columns are missing.
        """
        required = {"completion_idx", "token", "future_pos", "source_pos", "seq_len", "distance", "rank_when_anticipated"}
        missing = required - set(df.columns)
        if missing:
            # Input is not event-shaped; fail soft rather than raising.
            return pd.DataFrame()
        gcols = ["completion_idx", "token", "future_pos"]
        df2 = df.copy()
        # Coerce numeric columns and drop rows that fail conversion.
        for col in ["completion_idx", "future_pos", "source_pos", "seq_len", "distance", "rank_when_anticipated"]:
            df2[col] = pd.to_numeric(df2[col], errors="coerce")
        df2 = df2.dropna(subset=["completion_idx", "future_pos", "source_pos", "seq_len", "distance", "rank_when_anticipated"])
        def _agg(group: pd.DataFrame) -> pd.Series:
            # One group == one (completion, token, chosen-position) episode.
            seq_len = float(group["seq_len"].iloc[0])
            future_pos = float(group["future_pos"].iloc[0])
            first_source = float(group["source_pos"].min())
            last_source = float(group["source_pos"].max())
            n_mentions = int(len(group))
            horizon = float(future_pos - first_source)  # == max distance in group
            horizon = float(max(horizon, 0.0))
            last_gap = float(future_pos - last_source)  # how long after the last tail-mention the token is chosen
            # Ratios use max(horizon, 1.0) to avoid division by zero.
            recency_ratio = float(last_gap / max(horizon, 1.0))
            persistence_span = float(max(0.0, last_source - first_source))
            density = float(n_mentions / max(horizon, 1.0))
            best_rank = float(group["rank_when_anticipated"].min())
            mean_rank = float(group["rank_when_anticipated"].mean())
            out = {
                "completion_idx": int(group["completion_idx"].iloc[0]),
                "token": group["token"].iloc[0],
                "seq_len": int(seq_len),
                "future_pos": int(future_pos),
                "future_rel_pos": float(future_pos / seq_len) if seq_len > 0 else np.nan,
                "first_source_pos": int(first_source),
                "last_source_pos": int(last_source),
                "first_rel_pos": float(first_source / seq_len) if seq_len > 0 else np.nan,
                "horizon": horizon,
                "last_gap": last_gap,
                "recency_ratio": recency_ratio,
                "persistence_span": persistence_span,
                "n_mentions": n_mentions,
                "density": density,
                "best_rank": best_rank,
                "mean_rank": mean_rank,
                "mean_prob": float(group["prob_when_anticipated"].mean()) if "prob_when_anticipated" in group.columns else np.nan,
            }
            return pd.Series(out)
        episodes = df2.groupby(gcols, sort=False).apply(_agg).reset_index(drop=True)
        # Longest-horizon episodes first within each completion.
        episodes = episodes.sort_values(["completion_idx", "horizon"], ascending=[True, False]).reset_index(drop=True)
        return episodes
    def summarize_hierarchy(self, episodes: pd.DataFrame) -> pd.DataFrame:
        """
        Aggregate episode-level persistence/horizon stats to token-level summaries.
        Returns one row per token, sorted by max_horizon then episode count.
        """
        if episodes is None or len(episodes) == 0:
            return pd.DataFrame()
        token_stats = (
            episodes.groupby("token", sort=False)
            .agg(
                n_episodes=("future_pos", "size"),
                mean_horizon=("horizon", "mean"),
                max_horizon=("horizon", "max"),
                median_horizon=("horizon", "median"),
                mean_mentions=("n_mentions", "mean"),
                max_mentions=("n_mentions", "max"),
                mean_density=("density", "mean"),
                mean_best_rank=("best_rank", "mean"),
                mean_rank=("mean_rank", "mean"),
                mean_first_rel_pos=("first_rel_pos", "mean"),
                mean_future_rel_pos=("future_rel_pos", "mean"),
            )
            .reset_index()
        )
        token_stats = token_stats.sort_values(["max_horizon", "n_episodes"], ascending=[False, False]).reset_index(drop=True)
        return token_stats
    def cluster_hierarchy(self, token_stats: pd.DataFrame) -> pd.DataFrame:
        """
        Cluster tokens by persistence/horizon features to surface "temporal tiers".
        Uses hierarchical clustering (Ward) on standardized features.
        Tokens that are ineligible (too few episodes / NaN features) or any
        failure to import scipy leave cluster = -1 for all rows.
        """
        if token_stats is None or len(token_stats) == 0:
            return pd.DataFrame()
        out = token_stats.copy()
        out["cluster"] = -1
        # Filter to tokens with enough episodes (stable estimates)
        eligible = out[out["n_episodes"] >= self.min_episodes_for_cluster].copy()
        if len(eligible) < max(self.n_clusters, 2):
            return out
        # Features for clustering: horizon, density, rank
        feature_cols = ["mean_horizon", "mean_density", "mean_rank"]
        for c in feature_cols:
            eligible[c] = pd.to_numeric(eligible[c], errors="coerce")
        eligible2 = eligible.dropna(subset=feature_cols).copy()
        if len(eligible2) < max(self.n_clusters, 2):
            return out
        # Standardize (z-score) each feature before Ward linkage.
        X = eligible2[feature_cols].to_numpy(dtype=float)
        mu = np.nanmean(X, axis=0)
        sigma = np.nanstd(X, axis=0)
        sigma[sigma == 0] = 1.0
        Xz = (X - mu) / sigma
        try:
            # scipy is an optional dependency here; bail out gracefully.
            from scipy.cluster.hierarchy import fcluster, linkage
        except Exception:
            return out
        Z = linkage(Xz, method="ward")
        labels = fcluster(Z, t=self.n_clusters, criterion="maxclust")
        eligible2["cluster"] = labels.astype(int)
        # Merge labels back; ineligible tokens keep their -1 sentinel.
        out = out.merge(eligible2[["token", "cluster"]], on="token", how="left", suffixes=("", "_new"))
        out["cluster"] = out["cluster_new"].fillna(out["cluster"]).astype(int)
        out = out.drop(columns=["cluster_new"])
        return out
    def plot_hierarchy(self, token_stats: pd.DataFrame, output_dir: Path) -> None:
        """
        Plot token-level hierarchy: horizon vs persistence density, colored by cluster.
        Saves token_temporal_hierarchy.png into output_dir.
        """
        if token_stats is None or len(token_stats) == 0:
            return
        dfp = token_stats.copy()
        dfp["mean_horizon"] = pd.to_numeric(dfp["mean_horizon"], errors="coerce")
        dfp["mean_density"] = pd.to_numeric(dfp["mean_density"], errors="coerce")
        dfp["cluster"] = pd.to_numeric(dfp.get("cluster", -1), errors="coerce").fillna(-1).astype(int)
        dfp = dfp.dropna(subset=["mean_horizon", "mean_density"])
        if len(dfp) == 0:
            return
        fig, ax = plt.subplots(figsize=(12, 8))
        clusters = sorted(dfp["cluster"].unique())
        cmap = plt.cm.tab10
        for idx, c in enumerate(clusters):
            sub = dfp[dfp["cluster"] == c]
            # Unclustered tokens (c == -1) are rendered gray and faded.
            color = cmap(idx % 10) if c >= 0 else (0.5, 0.5, 0.5, 0.35)
            ax.scatter(
                sub["mean_horizon"],
                sub["mean_density"],
                # Marker area scales with episode count, clamped to [20, 200].
                s=np.clip(sub["n_episodes"].to_numpy(dtype=float) * 2.0, 20, 200),
                alpha=0.75 if c >= 0 else 0.25,
                label=f"Cluster {c}" if c >= 0 else "Unclustered",
                color=color,
                edgecolors="none",
            )
        ax.set_xlabel("Mean planning horizon (max distance per episode)", fontsize=12)
        ax.set_ylabel("Persistence density (mentions / horizon)", fontsize=12)
        ax.set_title("Token temporal hierarchy: horizon vs persistence density", fontsize=13, fontweight="bold")
        ax.grid(True, alpha=0.25)
        ax.legend(loc="upper right", fontsize=9, frameon=True)
        # Label a few extreme tokens for interpretability
        to_label = dfp.nlargest(20, "mean_horizon")
        texts = []
        for _, row in to_label.iterrows():
            tok = str(row["token"]).replace("\n", "\\n")
            texts.append(
                ax.text(
                    row["mean_horizon"],
                    row["mean_density"],
                    tok.strip()[:18],
                    fontsize=8,
                    alpha=0.9,
                    ha="left",
                    va="bottom",
                )
            )
        # adjustText is optional; label overlap resolution is best-effort.
        if adjust_text is not None:
            try:
                adjust_text(texts, ax=ax, arrowprops=dict(arrowstyle="-", color="gray", alpha=0.25), time_lim=3)
            except Exception:
                pass
        plt.tight_layout()
        plt.savefig(output_dir / "token_temporal_hierarchy.png", dpi=150, bbox_inches="tight")
        plt.close()
    def plot_gantt(self, df: pd.DataFrame, output_dir: Path, completion_idx: int, top_n: int = 60) -> None:
        """
        Visualize "first-seen → chosen" spans for the longest-horizon episodes in a completion.
        Each row is one (token, future_pos) episode; points mark repeated candidate mentions.
        Outputs to shadow_lifecycle/ subfolder.
        """
        if df is None or len(df) == 0:
            return
        comp_df = df[df["completion_idx"] == completion_idx].copy()
        if len(comp_df) == 0:
            return
        # Create shadow_lifecycle subfolder
        lifecycle_dir = output_dir / "shadow_lifecycle"
        lifecycle_dir.mkdir(exist_ok=True)
        # Build per-episode aggregates and keep the underlying source positions
        gcols = ["token", "future_pos"]
        comp_df["future_pos"] = pd.to_numeric(comp_df["future_pos"], errors="coerce")
        comp_df["source_pos"] = pd.to_numeric(comp_df["source_pos"], errors="coerce")
        comp_df["rank_when_anticipated"] = pd.to_numeric(comp_df["rank_when_anticipated"], errors="coerce")
        comp_df = comp_df.dropna(subset=["future_pos", "source_pos", "rank_when_anticipated"])
        agg = (
            comp_df.groupby(gcols, sort=False)
            .agg(
                seq_len=("seq_len", "first"),
                first_source=("source_pos", "min"),
                last_source=("source_pos", "max"),
                n_mentions=("source_pos", "size"),
                best_rank=("rank_when_anticipated", "min"),
                mean_rank=("rank_when_anticipated", "mean"),
            )
            .reset_index()
        )
        # Keep only the top_n longest-horizon episodes.
        agg["horizon"] = agg["future_pos"] - agg["first_source"]
        agg = agg.sort_values("horizon", ascending=False).head(int(top_n)).reset_index(drop=True)
        if len(agg) == 0:
            return
        # Plot: figure height grows with number of episode rows.
        fig_h = max(8, 0.22 * len(agg))
        fig, ax = plt.subplots(figsize=(14, fig_h))
        norm = Normalize(vmin=2, vmax=20)
        cmap = plt.cm.viridis_r
        for y, row in enumerate(agg.itertuples(index=False)):
            tok = str(row.token)
            future = float(row.future_pos)
            first = float(row.first_source)
            best_rank = float(row.best_rank) if np.isfinite(row.best_rank) else 20.0
            color = cmap(norm(best_rank))
            # Span line
            ax.plot([first, future], [y, y], color=color, alpha=0.7, linewidth=2.2)
            # Candidate mentions as x's (more visible than dots)
            # NOTE(review): the title string still says "dots" — markers are x's.
            mentions = comp_df[(comp_df["token"] == tok) & (comp_df["future_pos"] == future)]["source_pos"].to_numpy(dtype=float)
            if mentions.size > 0:
                ax.scatter(
                    mentions,
                    np.full_like(mentions, y),
                    s=38,
                    marker="x",
                    color=color,
                    alpha=0.85,
                    linewidths=1.2,
                )
            # Chosen position marker
            ax.scatter([future], [y], s=60, marker="*", color="black", alpha=0.9, zorder=3)
            label = tok.replace("\n", "\\n").strip()
            label = label if label else "<space>"
            ax.text(future + 2, y, label[:28], fontsize=8, va="center", alpha=0.95)
        seq_len = int(agg["seq_len"].iloc[0]) if "seq_len" in agg.columns and pd.notna(agg["seq_len"].iloc[0]) else None
        ax.set_yticks([])
        ax.set_xlabel("Token position in sequence", fontsize=11)
        ax.set_title(
            f"Completion {completion_idx + 1}: longest-horizon planning episodes (top {len(agg)})\n"
            "Line = first time token appears in tail → chosen; dots = repeated candidate mentions; star = chosen position",
            fontsize=12,
            fontweight="bold",
        )
        if seq_len is not None:
            ax.set_xlim(0, seq_len)
        ax.grid(True, alpha=0.2, axis="x")
        # Shared colorbar keyed to best rank during the episode.
        sm = ScalarMappable(norm=norm, cmap=cmap)
        sm.set_array([])
        plt.colorbar(sm, ax=ax, label="Best rank during episode (2=high prob)", shrink=0.8)
        plt.tight_layout()
        plt.savefig(lifecycle_dir / f"planning_episode_hierarchy_completion_{completion_idx + 1}.png", dpi=150, bbox_inches="tight")
        plt.close()
def plot_temporal_sequence(df: pd.DataFrame, output_dir: Path, cmap=plt.cm.viridis_r):
    """
    Plot anticipation events across token sequence position.
    X-axis: source position (where anticipation occurred)
    Y-axis: distance to chosen
    Color: rank when anticipated

    Produces three figure sets in output_dir:
      1. One scatter per completion (anticipation_temporal_completion_N.png)
      2. A 2x2 aggregate panel (anticipation_temporal_aggregate.png)
      3. A position x distance heatmap (anticipation_heatmap.png)

    Side effect: adds a 'pos_bin' column to ``df`` in place.

    Args:
        df: Event-level DataFrame from ShadowTokens.find_events.
        output_dir: Directory PNGs are written to.
        cmap: Colormap for rank coloring.

    Returns:
        Per-position-bin statistics DataFrame
        (pos_bin, mean_dist, std_dist, count, mean_rank, rel_pos_center).
    """
    n_completions = df['completion_idx'].nunique()
    # --- Panel 1: Individual completions - temporal view (SEPARATE IMAGES) ---
    for comp_idx in range(n_completions):
        comp_df = df[df['completion_idx'] == comp_idx].copy()
        if len(comp_df) == 0:
            continue
        seq_len = comp_df['seq_len'].iloc[0]
        # Create individual figure for this completion
        fig, ax = plt.subplots(figsize=(14, 10))
        # Plot: x = source position, y = distance, color = rank
        scatter = ax.scatter(
            comp_df['source_pos'],
            comp_df['distance'],
            c=comp_df['rank_when_anticipated'],
            cmap=cmap,
            alpha=0.6,
            s=60,
            edgecolors='white',
            linewidth=0.5,
            vmin=2,
            vmax=20
        )
        # Add token labels on each dot
        texts = []
        for _, row in comp_df.iterrows():
            if len(row['token_display']) > 0:
                txt = ax.text(
                    row['source_pos'],
                    row['distance'],
                    row['token_display'],
                    fontsize=7,
                    alpha=0.9,
                    fontweight='bold',
                    ha='left',
                    va='bottom',
                )
                texts.append(txt)
        # Adjust text positions to avoid overlap. adjustText is an optional
        # dependency and may be None (previously it was called unguarded and
        # the resulting TypeError was swallowed by a bare `except:`).
        if adjust_text is not None:
            try:
                adjust_text(texts, ax=ax,
                            arrowprops=dict(arrowstyle='-', color='gray', alpha=0.3),
                            expand_points=(1.3, 1.3), force_text=(0.3, 0.5),
                            time_lim=5)
            except Exception:
                # Label adjustment is cosmetic; never fail the plot over it.
                pass
        # Add diagonal line showing "maximum possible distance" (can't anticipate beyond seq end)
        x_line = np.arange(0, seq_len)
        y_max_line = seq_len - x_line
        ax.plot(x_line, y_max_line, 'k--', alpha=0.2, label='Max possible distance')
        ax.set_xlabel('Source Position (token index)', fontsize=12)
        ax.set_ylabel('Distance to Chosen (tokens ahead)', fontsize=12)
        ax.set_title(f'TEMPORAL VIEW: Completion {comp_idx + 1} (len={seq_len}): {len(comp_df)} anticipation events\n'
                     'Where in generation does the model "look ahead" and how far?',
                     fontsize=13, fontweight='bold')
        ax.set_xlim(0, seq_len)
        ax.grid(True, alpha=0.3)
        # Add trend line (only with enough events for a meaningful fit)
        if len(comp_df) > 10:
            z = np.polyfit(comp_df['source_pos'], comp_df['distance'], 1)
            p = np.poly1d(z)
            x_trend = np.linspace(0, seq_len, 100)
            ax.plot(x_trend, p(x_trend), 'r-', alpha=0.5, linewidth=2,
                    label=f'Trend: {z[0]:.3f}x + {z[1]:.1f}')
        ax.legend(loc='upper right', fontsize=10)
        plt.colorbar(scatter, ax=ax, label='Rank (2=high prob)', shrink=0.8)
        plt.tight_layout()
        filename = f'anticipation_temporal_completion_{comp_idx + 1}.png'
        plt.savefig(output_dir / filename, dpi=150, bbox_inches='tight')
        # Fixed: previously printed the literal "(unknown)" instead of the filename.
        print(f"Saved: {output_dir}/{filename}")
        plt.close()
    # --- Panel 2: Aggregated view with density ---
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    # 2a: All completions overlaid (normalized position)
    ax = axes[0, 0]
    scatter = ax.scatter(
        df['rel_pos'],
        df['distance'],
        c=df['rank_when_anticipated'],
        cmap=cmap,
        alpha=0.4,
        s=25,
        edgecolors='none',
        vmin=2,
        vmax=20
    )
    ax.set_xlabel('Relative Position in Sequence (0=start, 1=end)', fontsize=10)
    ax.set_ylabel('Distance to Chosen', fontsize=10)
    ax.set_title('All Completions: Planning Horizon by Relative Position', fontsize=11)
    ax.set_xlim(0, 1)
    plt.colorbar(scatter, ax=ax, label='Rank')
    ax.grid(True, alpha=0.3)
    # 2b: Mean distance by position bins
    ax = axes[0, 1]
    # NOTE: mutates the caller's DataFrame (adds 'pos_bin').
    df['pos_bin'] = pd.cut(df['rel_pos'], bins=20, labels=False)
    bin_stats = df.groupby('pos_bin').agg({
        'distance': ['mean', 'std', 'count'],
        'rank_when_anticipated': 'mean'
    }).reset_index()
    bin_stats.columns = ['pos_bin', 'mean_dist', 'std_dist', 'count', 'mean_rank']
    bin_stats['rel_pos_center'] = (bin_stats['pos_bin'] + 0.5) / 20
    ax.errorbar(
        bin_stats['rel_pos_center'],
        bin_stats['mean_dist'],
        yerr=bin_stats['std_dist'] / np.sqrt(bin_stats['count']),  # standard error
        fmt='o-',
        capsize=3,
        color='steelblue',
        markersize=6
    )
    ax.set_xlabel('Relative Position (binned)', fontsize=10)
    ax.set_ylabel('Mean Distance (± SE)', fontsize=10)
    ax.set_title('Planning Horizon: Does it change across generation?', fontsize=11)
    ax.grid(True, alpha=0.3)
    # 2c: Rank by position
    ax = axes[1, 0]
    ax.scatter(
        df['rel_pos'],
        df['rank_when_anticipated'],
        c=df['distance'],
        cmap='plasma',
        alpha=0.4,
        s=25,
        edgecolors='none'
    )
    ax.set_xlabel('Relative Position', fontsize=10)
    ax.set_ylabel('Rank When Anticipated', fontsize=10)
    ax.set_title('Anticipation Confidence by Position (color=distance)', fontsize=11)
    ax.set_xlim(0, 1)
    ax.set_ylim(1.5, 20.5)
    ax.grid(True, alpha=0.3)
    # Add mean rank trend on a secondary axis
    ax2 = ax.twinx()
    ax2.plot(bin_stats['rel_pos_center'], bin_stats['mean_rank'], 'r-', linewidth=2, label='Mean rank')
    ax2.set_ylabel('Mean Rank', color='red', fontsize=10)
    ax2.tick_params(axis='y', labelcolor='red')
    # 2d: Event density by position
    ax = axes[1, 1]
    ax.bar(bin_stats['rel_pos_center'], bin_stats['count'], width=0.045,
           color='steelblue', alpha=0.7, edgecolor='white')
    ax.set_xlabel('Relative Position', fontsize=10)
    ax.set_ylabel('Number of Anticipation Events', fontsize=10)
    ax.set_title('Where Does Planning Happen Most?', fontsize=11)
    ax.set_xlim(0, 1)
    ax.grid(True, alpha=0.3, axis='y')
    plt.suptitle('AGGREGATED TEMPORAL ANALYSIS\n'
                 'How planning behavior evolves through generation',
                 fontsize=13, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_dir / 'anticipation_temporal_aggregate.png', dpi=150, bbox_inches='tight')
    print(f"Saved: {output_dir}/anticipation_temporal_aggregate.png")
    plt.close()
    # --- Panel 3: Heatmap view ---
    fig, ax = plt.subplots(figsize=(14, 6))
    # 2D histogram: position vs distance (cmin=1 hides empty cells)
    h = ax.hist2d(
        df['rel_pos'],
        df['distance'],
        bins=[40, 50],
        cmap='YlOrRd',
        cmin=1
    )
    plt.colorbar(h[3], ax=ax, label='Event count')
    ax.set_xlabel('Relative Position in Sequence', fontsize=11)
    ax.set_ylabel('Distance to Chosen (tokens)', fontsize=11)
    ax.set_title('Planning Horizon Heatmap: Position × Distance\n'
                 'Where are the "hot spots" of long-range anticipation?',
                 fontsize=12)
    plt.tight_layout()
    plt.savefig(output_dir / 'anticipation_heatmap.png', dpi=150, bbox_inches='tight')
    print(f"Saved: {output_dir}/anticipation_heatmap.png")
    plt.close()
    return bin_stats
def _bin_by_position(df: pd.DataFrame) -> pd.DataFrame:
    """Aggregate event statistics over 20 relative-position bins.

    Adds a ``pos_bin`` column to *df* in place (the statistics printout
    reuses it later) and returns one row per bin with mean/std distance,
    event count, mean anticipation rank, and the bin's center position.
    """
    df['pos_bin'] = pd.cut(df['rel_pos'], bins=20, labels=False)
    bin_stats = df.groupby('pos_bin').agg({
        'distance': ['mean', 'std', 'count'],
        'rank_when_anticipated': 'mean'
    }).reset_index()
    bin_stats.columns = ['pos_bin', 'mean_dist', 'std_dist', 'count', 'mean_rank']
    bin_stats['rel_pos_center'] = (bin_stats['pos_bin'] + 0.5) / 20
    return bin_stats


def _run_temporal_hierarchy(df: pd.DataFrame, output_dir: Path) -> None:
    """Run the token-persistence (planning-episode) analysis and save outputs.

    Writes planning_episodes.csv, token_temporal_hierarchy.csv/.png, and
    per-completion episode plots under the shadow_lifecycle/ subfolder.
    """
    print("\n--- Temporal Hierarchy (Token Persistence) Analysis ---")
    planner = PlanningEpisodes()
    episodes_df = planner.build(df)
    # BUGFIX: token_hierarchy must exist even when no episodes were built;
    # previously an empty episodes_df left it unassigned and the
    # `len(token_hierarchy)` check below raised NameError.
    token_hierarchy = pd.DataFrame()
    if len(episodes_df) > 0:
        long_eps = episodes_df[episodes_df["horizon"] >= planner.long_horizon_threshold].copy()
        if len(long_eps) > 0:
            print(f"\nGating diagnostic for long-horizon episodes (horizon >= {planner.long_horizon_threshold}):")
            print(long_eps[["horizon", "last_gap", "recency_ratio", "n_mentions", "best_rank", "mean_rank"]].describe().to_string())
        token_hierarchy = planner.summarize_hierarchy(episodes_df)
        token_hierarchy = planner.cluster_hierarchy(token_hierarchy)
        episodes_df.to_csv(output_dir / "planning_episodes.csv", index=False)
        print(f"Saved: {output_dir}/planning_episodes.csv")
    else:
        print("(No planning episodes computed — missing columns or empty df.)")
    if len(token_hierarchy) > 0:
        token_hierarchy.to_csv(output_dir / "token_temporal_hierarchy.csv", index=False)
        print(f"Saved: {output_dir}/token_temporal_hierarchy.csv")
        print("\nTop tokens by max planning horizon (episodes aggregated):")
        print(token_hierarchy[["token", "n_episodes", "max_horizon", "mean_horizon", "mean_density", "cluster"]].head(20).to_string(index=False))
        planner.plot_hierarchy(token_hierarchy, output_dir)
        print(f"Saved: {output_dir}/token_temporal_hierarchy.png")
    else:
        print("(No token hierarchy summary available.)")
    # Per-completion episode visualization (limited to avoid huge batches).
    # Outputs to shadow_lifecycle/ subfolder.
    try:
        n_comp = int(df["completion_idx"].nunique())
    except (KeyError, TypeError, ValueError):
        # Missing column or non-numeric nunique result: skip the plots.
        n_comp = 0
    for comp_idx in range(min(n_comp, planner.max_completions_to_plot)):
        planner.plot_gantt(
            df,
            output_dir,
            completion_idx=comp_idx,
            top_n=planner.top_episodes_per_completion,
        )
        print(f"Saved: {output_dir}/shadow_lifecycle/planning_episode_hierarchy_completion_{comp_idx + 1}.png")


def _write_completions_markdown(df: pd.DataFrame, input_file: Path, output_dir: Path) -> Path:
    """Write each completion's prompt/response plus event counts to markdown.

    Returns the path of the written report.
    """
    md_path = output_dir / 'completions_prompts_responses.md'
    with open(md_path, 'w', encoding='utf-8') as f:
        f.write("# Anticipation Analysis - Prompts and Responses\n\n")
        f.write(f"Source file: `{input_file}`\n\n")
        f.write("---\n\n")
        # completion_idx values are assigned contiguously from 0 during
        # collection, so range(nunique) visits every completion.
        for comp_idx in range(df['completion_idx'].nunique()):
            comp_df = df[df['completion_idx'] == comp_idx]
            if len(comp_df) == 0:
                continue
            prompt = comp_df['prompt'].iloc[0] if 'prompt' in comp_df.columns else "(not available)"
            response = comp_df['response'].iloc[0] if 'response' in comp_df.columns else "(not available)"
            # Guard seq_len the same way as prompt/response so a missing
            # column degrades gracefully instead of raising KeyError.
            seq_len = comp_df['seq_len'].iloc[0] if 'seq_len' in comp_df.columns else "(not available)"
            n_events = len(comp_df)
            f.write(f"## Completion {comp_idx + 1}\n\n")
            f.write(f"- **Sequence length:** {seq_len}\n")
            f.write(f"- **Anticipation events:** {n_events}\n")
            f.write("### Prompt\n\n")
            f.write(f"```\n{prompt}\n```\n\n")
            f.write("### Response\n\n")
            f.write(f"```\n{response}\n```\n\n")
            f.write("---\n\n")
    return md_path


def _print_statistics(df: pd.DataFrame) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Print summary statistics for the event table.

    Returns:
        (token_stats, rank_stats): per-token and per-rank aggregate tables,
        with token_stats sorted by event count descending.
    """
    print(f"\n{'='*60}")
    print("ANTICIPATION STATISTICS")
    print('='*60)
    print(f"\nTotal anticipation events: {len(df):,}")
    print(f"Unique tokens anticipated: {df['token'].nunique()}")
    # Distance distribution by the rank the token held when anticipated
    rank_stats = df.groupby('rank_when_anticipated')['distance'].agg(['mean', 'median', 'count'])
    print("\nDistance by rank when anticipated:")
    print(rank_stats.to_string())
    # Distance by position bin (pos_bin was added by _bin_by_position)
    print("\nDistance by position in sequence:")
    pos_stats = df.groupby('pos_bin').agg({
        'distance': ['mean', 'median', 'count'],
        'rank_when_anticipated': 'mean'
    })
    print(pos_stats.to_string())
    # Pairwise Pearson correlations between position, distance and rank
    corr_pos_dist = df['rel_pos'].corr(df['distance'])
    corr_pos_rank = df['rel_pos'].corr(df['rank_when_anticipated'])
    corr_rank_dist = df['rank_when_anticipated'].corr(df['distance'])
    print("\n--- Correlation Analysis ---")
    print(f"Position ↔ Distance: {corr_pos_dist:.3f}")
    print(f"Position ↔ Rank: {corr_pos_rank:.3f}")
    print(f"Rank ↔ Distance: {corr_rank_dist:.3f}")
    # Per-token aggregates: how far ahead, how confidently, and where
    # in the sequence each token tends to be anticipated
    token_stats = df.groupby('token').agg({
        'distance': ['mean', 'max', 'count'],
        'rank_when_anticipated': 'mean',
        'prob_when_anticipated': 'mean',
        'rel_pos': 'mean',  # where in sequence these tokens are anticipated
    }).reset_index()
    token_stats.columns = ['token', 'mean_dist', 'max_dist', 'count', 'mean_rank', 'mean_prob', 'mean_pos']
    token_stats = token_stats.sort_values('count', ascending=False)
    print("\nTop 20 most frequently anticipated tokens:")
    print(token_stats.head(20).to_string())
    print("\nTop 20 tokens by max anticipation distance:")
    print(token_stats.nlargest(20, 'max_dist')[['token', 'max_dist', 'count', 'mean_rank', 'mean_pos']].to_string())
    return token_stats, rank_stats


def main(input_path: Optional[str] = None, num_completions: int = NUM_COMPLETIONS):
    """
    Main entry point for anticipation analysis.

    Loads logprob records, extracts anticipation events (non-chosen
    candidates that later become the chosen token), runs the forward
    baseline and temporal-hierarchy analyses, and writes CSV/markdown
    reports to OUTPUT_DIR.

    Args:
        input_path: Path to JSONL file with logprobs data. Uses INPUT_FILE if None.
        num_completions: Number of completions to process.

    Returns:
        Dict with 'df', 'token_stats', 'rank_stats', 'bin_stats', or None
        when the input file is missing or no anticipation events are found.
    """
    input_file = Path(input_path) if input_path else Path(INPUT_FILE)
    if not input_file.exists():
        print(f"Error: {input_file} not found")
        return
    output_dir = OUTPUT_DIR
    output_dir.mkdir(exist_ok=True)
    print(f"Loading completions from {input_file}...")
    print("Finding ANTICIPATION events: non-chosen candidates that later become chosen\n")
    # Load all records - process until we have enough valid completions
    records = load_data(input_file, max_records=None)
    # Initialize analyzers
    shadow_analyzer = ShadowTokens()
    event_seqlen_analyzer = EventSeqLenAnalysis()
    all_results = []
    completions_processed = 0
    for record in records:
        if completions_processed >= num_completions:
            break
        results, seq_len = shadow_analyzer.find_events(record, completions_processed)
        if results and len(results) > MIN_EVENTS_PER_COMPLETION:
            all_results.extend(results)
            completions_processed += 1
            n_events = len(results)
            print(f" Completion {completions_processed}: {n_events} anticipation events (seq_len={seq_len})")
            # Collect data for event/seqlen binning analysis. `results`
            # is known non-empty in this branch, so results[0] is safe.
            event_seqlen_analyzer.add_completion(
                completion_idx=completions_processed,
                n_events=n_events,
                seq_len=seq_len,
                original_messages=record.get('original_messages', []),
                completion_text=record.get('completion_text', record.get('complete_text', '')),
                prompt=results[0].get('prompt', ''),
                response=results[0].get('response', ''),
            )
    print(f"\nTotal anticipation events: {len(all_results):,}")
    print(f"Completions processed: {completions_processed}")
    # Forward baseline (candidate -> eventual realization)
    forward_all = []
    forward_processed = 0
    for record in records:
        if forward_processed >= num_completions:
            break
        fwd, _ = shadow_analyzer.find_forward_realizations(record, forward_processed)
        if fwd and len(fwd) > MIN_EVENTS_PER_COMPLETION:
            forward_all.extend(fwd)
            forward_processed += 1
    if forward_all:
        fdf = pd.DataFrame(forward_all)
        fdf.to_csv(output_dir / "forward_realizations.csv", index=False)
        print(f"\nSaved: {output_dir}/forward_realizations.csv")
        print("\nForward baseline: distance stats")
        print(fdf["distance"].describe().to_string())
    # Write event/seqlen ratio analysis
    event_seqlen_analyzer.write_analysis(output_dir)
    if not all_results:
        print("No anticipation events found!")
        return
    df = pd.DataFrame(all_results)
    # Position bins (used for non-plot statistics + CSV output); also adds
    # the 'pos_bin' column that _print_statistics groups on.
    bin_stats = _bin_by_position(df)
    # Temporal hierarchy / persistence analysis using PlanningEpisodes class
    _run_temporal_hierarchy(df, output_dir)
    # Save prompts and responses to markdown
    md_path = _write_completions_markdown(df, input_file, output_dir)
    print(f"Saved: {md_path}")
    # Summary statistics
    token_stats, rank_stats = _print_statistics(df)
    # Save data
    df.to_csv(output_dir / 'anticipation_events.csv', index=False)
    token_stats.to_csv(output_dir / 'anticipated_tokens_summary.csv', index=False)
    rank_stats.to_csv(output_dir / 'rank_statistics.csv')
    bin_stats.to_csv(output_dir / 'position_bin_statistics.csv', index=False)
    print(f"\nAll outputs saved to {output_dir}/")
    plt.close('all')
    return {
        'df': df,
        'token_stats': token_stats,
        'rank_stats': rank_stats,
        'bin_stats': bin_stats,
    }
# Script entry point: run the full analysis with module-level defaults
# (INPUT_FILE, NUM_COMPLETIONS, OUTPUT_DIR).
if __name__ == "__main__":
    main()