# DevaFlow-space / analysis / semantic_drift.py
# (header restored from web-viewer metadata; original upload note:
#  "Upgrade UI: model selection + tasks 1-5 + analysis modules", commit 29e5bf8)
# """
# analysis/semantic_drift.py
# ===========================
# Task 2: Semantic drift metric — how much does the intermediate generation
# diverge from the final output as we walk through diffusion steps T → 0?
#
# Metric: CER between x0_estimate at each step vs the final x0 at t=0.
#
# A well-trained model should show:
# - High drift at t=T-1 (near-random initial estimate)
# - Rapid decrease in drift around t=T//2 (model finds the right structure)
# - Near-zero drift at t=10 (output is stable, only fine corrections remain)
#
# If drift stays high until t=5 then suddenly collapses → model is doing all
# its work in the last few steps → consider reducing T.
#
# Also measures:
# - Token stability: fraction of positions that don't change between steps
# - Lock-in time: first step where each position "commits" to its final token
#
# No retraining required. Uses generate_cached() with intermediate snapshots.
# """
#
# import torch
# import torch.nn.functional as F
# import numpy as np
# from typing import List, Dict, Optional, Tuple
#
#
# def compute_cer_between(pred: str, ref: str) -> float:
# """CER between two strings."""
# if not ref:
# return 1.0 if pred else 0.0
#
# def edit_distance(s1, s2):
# m, n = len(s1), len(s2)
# dp = list(range(n + 1))
# for i in range(1, m + 1):
# prev, dp[0] = dp[0], i
# for j in range(1, n + 1):
# temp = dp[j]
# dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
# prev = temp
# return dp[n]
#
# return edit_distance(pred, ref) / len(ref)
#
#
# @torch.no_grad()
# def capture_intermediate_outputs(
# model,
# src: torch.Tensor,
# tgt_tokenizer,
# capture_every: int = 5,
# temperature: float = 0.8,
# top_k: int = 40,
# ) -> Tuple[Dict[int, str], str]:
# """
# Run generation while recording the decoded x0_estimate at every
# `capture_every` diffusion steps.
#
# Args:
# model : SanskritModel (D3PMCrossAttention)
# src : [1, src_len] IAST token ids (single sample)
# tgt_tokenizer : SanskritTargetTokenizer for decoding intermediate outputs
# capture_every : record every N steps
# temperature : sampling temperature
# top_k : top-k filter
#
# Returns:
# step_outputs : dict mapping t_val → decoded Devanagari string at that step
# final_output : decoded string at t=0 (final result)
# """
# if src.dim() == 1:
# src = src.unsqueeze(0)
#
# inner = model.model
# T = inner.scheduler.num_timesteps
# device = src.device
#
# # Encode source once (KV cache)
# memory, src_pad_mask = inner.encode_source(src)
#
# B = src.shape[0]
# tgt_len = inner.max_seq_len
# mask_id = inner.mask_token_id
#
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
# hint = None
#
# step_outputs: Dict[int, str] = {}
# inner.eval()
#
# for t_val in range(T - 1, -1, -1):
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
# is_last = (t_val == 0)
#
# logits, _ = inner.forward_cached(
# memory, src_pad_mask, x0_est, t,
# x0_hint=hint, inference_mode=True,
# )
#
# logits = logits / max(temperature, 1e-8)
# if top_k > 0:
# V = logits.shape[-1]
# if top_k < V:
# topk_vals, _ = torch.topk(logits, top_k, dim=-1)
# threshold = topk_vals[..., -1].unsqueeze(-1)
# logits = logits.masked_fill(logits < threshold, float('-inf'))
#
# probs = F.softmax(logits, dim=-1)
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
# hint = x0_est
#
# # Capture at this step
# if (T - 1 - t_val) % capture_every == 0 or is_last:
# ids = [x for x in x0_est[0].tolist() if x > 4]
# text = tgt_tokenizer.decode(ids).strip()
# step_outputs[t_val] = text
#
# final_output = step_outputs.get(0, "")
# return step_outputs, final_output
#
#
# def _sample(probs):
# B, L, V = probs.shape
# flat = probs.view(B * L, V).clamp(min=1e-9)
# flat = flat / flat.sum(dim=-1, keepdim=True)
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
#
#
# def compute_drift(
# step_outputs: Dict[int, str],
# final_output: str,
# ) -> Dict[str, object]:
# """
# Compute drift metrics comparing each intermediate output to the final.
#
# Returns dict with:
# t_vals : list of captured timesteps (T-1 → 0)
# cer_to_final: CER between each step's output and the final output
# 0.0 = identical to final, 1.0 = completely different
# lock_in_t : first t_val where CER drops and stays below 0.1
# (step at which output "commits" to final form)
# """
# t_vals = sorted(step_outputs.keys(), reverse=True) # T-1 → 0
# cer_to_final = []
#
# for t_val in t_vals:
# cer = compute_cer_between(step_outputs[t_val], final_output)
# cer_to_final.append(cer)
#
# # Find lock-in: first step where CER stays below threshold for rest of run
# threshold = 0.1
# lock_in_t = 0 # default: never locked in early
# for i, (t_val, cer) in enumerate(zip(t_vals, cer_to_final)):
# if all(c <= threshold for c in cer_to_final[i:]):
# lock_in_t = t_val
# break
#
# return {
# "t_vals": t_vals,
# "cer_to_final": cer_to_final,
# "lock_in_t": lock_in_t,
# "final_output": final_output,
# }
#
#
# def compute_token_stability(
# step_outputs: Dict[int, str],
# final_output: str,
# tgt_tokenizer,
# ) -> Dict[str, object]:
# """
# Token-level stability: for each position, at which diffusion step
# does it first match its final token and stay matched?
#
# Returns:
# position_lock_times: list of t_val at which each position locks in
# mean_lock_t : average lock-in timestep across positions
# """
# T = max(step_outputs.keys())
# t_vals = sorted(step_outputs.keys(), reverse=True) # T-1 → 0
#
# # Encode all intermediate outputs and the final
# def encode(text):
# return tgt_tokenizer.encode(text)
#
# final_ids = encode(final_output)
# L = len(final_ids)
#
# # Build matrix: [n_steps, L]
# step_ids = []
# for t_val in t_vals:
# step_ids.append(encode(step_outputs.get(t_val, "")))
#
# # Pad all to same length
# max_len = max(len(s) for s in step_ids)
# step_ids = [s + [1] * (max_len - len(s)) for s in step_ids] # 1=PAD
# final_ids_padded = final_ids + [1] * (max_len - len(final_ids))
#
# step_arr = np.array(step_ids) # [n_steps, L]
# final_arr = np.array(final_ids_padded) # [L]
#
# # For each position: find first step index where it matches final
# # and stays matched for all subsequent steps
# position_lock_steps = []
# for pos in range(min(L, max_len)):
# col = step_arr[:, pos] # [n_steps]
# fin = final_arr[pos]
# locked_at = len(t_vals) - 1 # default: never locks early
# for i in range(len(t_vals)):
# if all(col[i:] == fin):
# locked_at = i
# break
# position_lock_steps.append(t_vals[locked_at] if locked_at < len(t_vals) else 0)
#
# return {
# "position_lock_times": position_lock_steps,
# "mean_lock_t": float(np.mean(position_lock_steps)),
# "std_lock_t": float(np.std(position_lock_steps)),
# }
#
#
# def plot_drift_curve(
# drift_result: Dict,
# src_text: str = "",
# save_path: Optional[str] = None,
# ):
# """
# Plot CER-to-final vs diffusion step.
# Shows where the model "commits" to the final output.
# """
# try:
# import matplotlib.pyplot as plt
# except ImportError:
# print("pip install matplotlib.")
# return
#
# t_vals = drift_result["t_vals"]
# cers = drift_result["cer_to_final"]
# lock_t = drift_result["lock_in_t"]
#
# fig, ax = plt.subplots(figsize=(12, 4))
# ax.plot(range(len(t_vals)), cers, linewidth=1.8, color='coral', label='CER to final')
# ax.fill_between(range(len(t_vals)), cers, alpha=0.15, color='coral')
#
# # Mark lock-in point
# if lock_t in t_vals:
# lock_idx = t_vals.index(lock_t)
# ax.axvline(lock_idx, color='steelblue', linestyle='--', linewidth=1.2,
# label=f"Lock-in at t={lock_t}")
#
# ax.axhline(0.1, color='gray', linestyle=':', linewidth=1, alpha=0.7)
#
# n = len(t_vals)
# tick_positions = list(range(0, n, max(1, n // 10)))
# ax.set_xticks(tick_positions)
# ax.set_xticklabels([str(t_vals[i]) for i in tick_positions], fontsize=8)
# ax.set_xlabel("Diffusion step t (T-1 → 0)", fontsize=11)
# ax.set_ylabel("CER vs final output", fontsize=11)
# ax.set_ylim(0, 1.05)
# ax.set_xlim(0, n - 1)
# ax.legend(fontsize=10)
#
# title = f"Semantic drift"
# if src_text:
# title += f" | src: {src_text[:50]}"
# ax.set_title(title, fontsize=11)
# plt.tight_layout()
#
# if save_path:
# import os
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
# plt.savefig(save_path, dpi=150, bbox_inches='tight')
# print(f"Saved: {save_path}")
# else:
# plt.show()
# plt.close()
# ============================================================
# TASK 2: Source–Paraphrase Semantic Alignment Trajectory
# ============================================================
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Tuple
from collections import defaultdict
# Optional (install if needed)
# pip install bert-score scikit-learn
from bert_score import score as bertscore
from sklearn.feature_extraction.text import TfidfVectorizer
# ============================================================
# ------------------ ATTENTION HOOK --------------------------
# ============================================================
def register_attention_hooks(model):
    """
    Attach forward hooks that record cross-attention weights from every
    decoder block that owns a `cross_attn` submodule.

    Each hook copies `module.attn_weights` (detached, moved to CPU) into
    a shared list after the module's forward pass. Modules without an
    `attn_weights` attribute are silently skipped.

    Args:
        model: wrapper whose `.model.decoder_blocks` is iterable.

    Returns:
        handles : list of hook handles (call `.remove()` on each when done)
        captured: list the hooks append captured attention tensors to
    """
    decoder = model.model
    captured = []

    def _grab(module, inputs, output):
        # Only record modules that actually expose precomputed weights.
        if hasattr(module, "attn_weights"):
            captured.append(module.attn_weights.detach().cpu())

    handles = [
        blk.cross_attn.register_forward_hook(_grab)
        for blk in decoder.decoder_blocks
        if hasattr(blk, "cross_attn")
    ]
    return handles, captured
# ============================================================
# ------------------ CAPTURE TRAJECTORY ----------------------
# ============================================================
@torch.no_grad()
def capture_alignment_trajectory(
    model,
    src_tensor: torch.Tensor,
    src_text: str,
    tgt_tokenizer,
    steps_to_capture: List[int] = None,
):
    """
    Run the reverse diffusion loop once, recording at selected steps:
      - the decoded intermediate output
      - the deepest decoder layer's cross-attention map
    then score every intermediate output against the source with BERTScore.

    Args:
        model           : wrapper whose `.model` exposes encode_source /
                          forward_cached / scheduler / max_seq_len / mask_token_id
        src_tensor      : [B, src_len] source token ids
        src_text        : raw source string (BERTScore reference)
        tgt_tokenizer   : tokenizer with `.decode()` for target ids
        steps_to_capture: timesteps to record; default = every 5th step plus t=0

    Returns:
        dict with "outputs" (t -> text), "attention" (t -> np.ndarray),
        "bert_scores" (t -> F1 float).
    """
    inner = model.model
    device = src_tensor.device
    T = inner.scheduler.num_timesteps
    if steps_to_capture is None:
        steps_to_capture = list(range(T - 1, -1, -5)) + [0]
    capture_set = set(steps_to_capture)  # O(1) membership inside the loop

    hooks, attn_storage = register_attention_hooks(model)
    outputs = {}
    attention_per_step = {}
    try:
        memory, src_pad_mask = inner.encode_source(src_tensor)
        B = src_tensor.shape[0]
        tgt_len = inner.max_seq_len
        mask_id = inner.mask_token_id
        # dtype made explicit: token ids must be long for embedding lookups
        x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
        hint = None
        for t_val in range(T - 1, -1, -1):
            # Keep only this step's hook captures — without this the list
            # grows as T x n_layers tensors and never shrinks.
            attn_storage.clear()
            t = torch.full((B,), t_val, dtype=torch.long, device=device)
            logits, _ = inner.forward_cached(
                memory, src_pad_mask, x0_est, t,
                x0_hint=hint, inference_mode=True
            )
            probs = F.softmax(logits, dim=-1)
            x0_est = torch.argmax(probs, dim=-1)  # greedy decoding at every step
            hint = x0_est
            if t_val in capture_set:
                # ids <= 4 are assumed to be special tokens — TODO confirm
                # against the tokenizer's vocabulary layout
                ids = [x for x in x0_est[0].tolist() if x > 4]
                outputs[t_val] = tgt_tokenizer.decode(ids)
                if attn_storage:
                    # last element = deepest decoder layer of this forward pass
                    attention_per_step[t_val] = attn_storage[-1].numpy()
    finally:
        # Always detach hooks, even if the diffusion loop raises.
        for h in hooks:
            h.remove()

    bert_scores = compute_bert_alignment(src_text, outputs)
    return {
        "outputs": outputs,
        "attention": attention_per_step,
        "bert_scores": bert_scores,
    }
# ============================================================
# ------------------ BERTScore -------------------------------
# ============================================================
def compute_bert_alignment(src_text: str, outputs: Dict[int, str]):
    """
    BERTScore F1 of each intermediate output against the source text.

    All candidate/reference pairs are scored in one batched call:
    `bert_score.score` sets up its scoring model per invocation, so one
    call over all steps is much cheaper than a call per step, while the
    per-pair F1 values are identical.

    Args:
        src_text: source sentence used as the reference
        outputs : mapping t -> decoded intermediate output

    Returns:
        dict mapping t -> BERTScore F1 (float); {} when outputs is empty.
    """
    if not outputs:
        return {}
    t_vals = list(outputs.keys())
    candidates = [outputs[t] for t in t_vals]
    references = [src_text] * len(t_vals)
    _, _, F1 = bertscore(candidates, references, lang="hi", verbose=False)
    return {t: float(f1) for t, f1 in zip(t_vals, F1)}
# ============================================================
# ------------------ SEMANTIC DRIFT --------------------------
# ============================================================
def compute_semantic_drift(bert_scores: Dict[int, float]):
    """
    Drift of each step relative to the best-aligned step.

    drift[t] = max(scores) - scores[t], so the best-aligned step has
    drift 0.0 and larger values mean worse alignment with the source.

    Args:
        bert_scores: mapping t -> BERTScore F1

    Returns:
        dict mapping t -> drift; {} for empty input (previously raised
        ValueError from max() on an empty sequence).
    """
    if not bert_scores:
        return {}
    best = max(bert_scores.values())
    return {t: best - s for t, s in bert_scores.items()}
# ============================================================
# ------------------ ATTENTION STABILITY ---------------------
# ============================================================
def compute_attention_stability(attention_maps: Dict[int, np.ndarray]):
    """
    Mean absolute change of the attention map between consecutive
    captured steps (walking from high t toward t=0).

    NOTE: despite the name, LARGER values mean LESS stable attention —
    this is the average step-to-step difference, not a similarity.

    Args:
        attention_maps: mapping t -> attention array; arrays are assumed
                        to share one shape across steps — TODO confirm

    Returns:
        float mean difference; 0.0 when fewer than two steps were
        captured (previously np.mean([]) returned NaN with a warning).
    """
    steps = sorted(attention_maps.keys(), reverse=True)
    if len(steps) < 2:
        return 0.0
    diffs = [
        np.abs(attention_maps[a] - attention_maps[b]).mean()
        for a, b in zip(steps, steps[1:])
    ]
    return float(np.mean(diffs))
# ============================================================
# ------------------ TF-IDF vs STABILITY ---------------------
# ============================================================
def compute_tfidf_attention_correlation(
    src_texts: List[str],
    attention_maps_list: List[Dict[int, np.ndarray]]
):
    """
    Correlate TF-IDF importance with attention stability.

    Args:
        src_texts          : one source string per sample
        attention_maps_list: per-sample dict of t -> attention array,
                             as produced by capture_alignment_trajectory

    Returns:
        Pearson correlation coefficient (scalar from np.corrcoef), or
        NaN when fewer than two samples are provided.

    NOTE(review): `word_importance` is indexed per vocabulary word while
    `stability` is indexed per SAMPLE; truncating the former to the
    latter's length correlates unrelated axes. Confirm the intended
    pairing (e.g. per-sample mean TF-IDF vs per-sample stability) before
    trusting this metric.
    """
    vectorizer = TfidfVectorizer()
    tfidf = vectorizer.fit_transform(src_texts).toarray()
    # mean TF-IDF weight of each vocabulary word across all samples
    word_importance = tfidf.mean(axis=0)
    stability = []
    for attn_maps in attention_maps_list:
        stability.append(compute_attention_stability(attn_maps))
    # Truncation forces equal lengths; see NOTE(review) in docstring.
    corr = np.corrcoef(word_importance[:len(stability)], stability)[0, 1]
    return corr
# ============================================================
# ------------------ HEATMAP VISUALIZATION -------------------
# ============================================================
def plot_attention_heatmap(attn: np.ndarray, title="Attention"):
    """
    Render one cross-attention map as a heatmap and display it.

    Args:
        attn : 2-D array [tgt_len, src_len]; rows are target tokens,
               columns are source tokens
        title: figure title
    """
    plt.figure(figsize=(6, 5))
    plt.imshow(attn, cmap='viridis', aspect='auto')
    plt.colorbar()
    plt.xlabel("Source tokens")
    plt.ylabel("Target tokens")
    plt.title(title)
    plt.show()
def visualize_trajectory(attention_maps: Dict[int, np.ndarray]):
    """
    Plot attention heatmaps for the first few captured steps, walking
    from high t (noisy) toward t=0 (final). Capped at 5 figures.
    """
    ordered = sorted(attention_maps, reverse=True)[:5]
    for step in ordered:
        plot_attention_heatmap(attention_maps[step], title=f"Step t={step}")
# ============================================================
# ------------------ LOCKED vs FLEXIBLE ----------------------
# ============================================================
def analyze_token_behavior(attention_maps: Dict[int, np.ndarray], threshold: float = 0.05):
    """
    Classify target positions as "locked" (attention barely changes
    between the earliest and final captured step) or "flexible".

    Args:
        attention_maps: mapping t -> attention array; rows are assumed
                        to be target positions — TODO confirm shape
        threshold     : max mean per-row change to count as locked
                        (previously hard-coded to 0.05; default unchanged)

    Returns:
        dict with "locked_tokens" and "flexible_tokens" position lists;
        both empty when no maps were captured (previously raised on an
        empty dict).
    """
    if not attention_maps:
        return {"locked_tokens": [], "flexible_tokens": []}
    steps = sorted(attention_maps.keys(), reverse=True)
    first = attention_maps[steps[0]]
    last = attention_maps[steps[-1]]
    # mean absolute change per target row between first and final step
    diff = np.abs(first - last).mean(axis=1)
    locked = np.where(diff < threshold)[0]
    flexible = np.where(diff >= threshold)[0]
    return {
        "locked_tokens": locked.tolist(),
        "flexible_tokens": flexible.tolist()
    }
# ============================================================
# ------------------ MASTER FUNCTION -------------------------
# ============================================================
def run_task2_analysis(
    model,
    src_tensor,
    src_text,
    tgt_tokenizer
):
    """
    End-to-end Task 2 driver: capture the alignment trajectory, derive
    drift / stability / per-token behavior, print a summary, and plot
    the attention evolution.

    Returns:
        dict bundling the captured trajectory and all derived metrics.
    """
    trajectory = capture_alignment_trajectory(
        model, src_tensor, src_text, tgt_tokenizer
    )
    bert_scores = trajectory["bert_scores"]
    attn = trajectory["attention"]

    drift_by_step = compute_semantic_drift(bert_scores)
    attn_stability = compute_attention_stability(attn)
    token_behavior = analyze_token_behavior(attn)

    print("\nBERTScore trajectory:")
    print(bert_scores)
    print("\nSemantic drift:")
    print(drift_by_step)
    print(f"\nAttention stability: {attn_stability:.4f}")
    print("\nToken behavior:")
    print(token_behavior)

    visualize_trajectory(attn)

    return {
        "trajectory": trajectory,
        "drift": drift_by_step,
        "stability": attn_stability,
        "behavior": token_behavior
    }