Spaces:
Sleeping
Sleeping
Upgrade UI: model selection + tasks 1-5 + analysis modules
Browse files- .gitattributes +1 -0
- __pycache__/app.cpython-311.pyc +0 -0
- analysis/__pycache__/run_analysis.cpython-311.pyc +0 -0
- analysis/attention_viz.py +621 -0
- analysis/concept_vectors.py +637 -0
- analysis/kv_cache_benchmark.py +451 -0
- analysis/outputs/task1_kv_cache.txt +23 -0
- analysis/outputs/task2_all_layers_t0.png +0 -0
- analysis/outputs/task2_attn_evolution.png +0 -0
- analysis/outputs/task2_attn_t0.png +0 -0
- analysis/outputs/task2_attn_t127.png +0 -0
- analysis/outputs/task2_examples/example_1_attn_t0.png +0 -0
- analysis/outputs/task2_examples/example_2_attn_t0.png +0 -0
- analysis/outputs/task2_examples/example_3_attn_t0.png +0 -0
- analysis/outputs/task2_examples/example_4_attn_t0.png +0 -0
- analysis/outputs/task2_examples/example_5_attn_t0.png +0 -0
- analysis/outputs/task2_report.txt +100 -0
- analysis/outputs/task2_semantic_drift.png +0 -0
- analysis/outputs/task2_source_alignment.png +0 -0
- analysis/outputs/task3_concept_space.png +3 -0
- analysis/outputs/task3_diversity_direction.npy +3 -0
- analysis/outputs/task3_report.txt +12 -0
- analysis/outputs/task5_quality_classifier.pt +3 -0
- analysis/outputs/task5_quality_data.npz +3 -0
- analysis/outputs_multi/results__d3pm_cross_attention_neg_False/task1/task1_kv_cache.txt +10 -0
- analysis/outputs_multi/results__d3pm_cross_attention_neg_True/task1/task1_kv_cache.txt +10 -0
- analysis/quality_classifier.py +723 -0
- analysis/reports/README.md +19 -0
- analysis/reports/task1_kv_cache_report.md +99 -0
- analysis/reports/task2_attention_drift_report.md +112 -0
- analysis/reports/task3_concept_vectors_report.md +96 -0
- analysis/reports/task4_step_ablation_report.md +89 -0
- analysis/reports/task5_quality_guidance_report.md +101 -0
- analysis/run_analysis.py +466 -0
- analysis/run_tasks_except4_all_models.py +123 -0
- analysis/semantic_drift.py +569 -0
- analysis/step_ablation.py +582 -0
- app.py +487 -175
- data/__init__.py +0 -0
- data/dataset.py +152 -0
- requirements.txt +6 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
analysis/outputs/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
|
__pycache__/app.cpython-311.pyc
ADDED
|
Binary file (29.2 kB). View file
|
|
|
analysis/__pycache__/run_analysis.cpython-311.pyc
ADDED
|
Binary file (32 kB). View file
|
|
|
analysis/attention_viz.py
ADDED
|
@@ -0,0 +1,621 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# """
|
| 2 |
+
# analysis/attention_viz.py
|
| 3 |
+
# ==========================
|
| 4 |
+
# Task 2: Attention weight capture and visualization across diffusion steps.
|
| 5 |
+
#
|
| 6 |
+
# How it works (no retraining needed):
|
| 7 |
+
# MultiHeadAttention now has two attributes:
|
| 8 |
+
# - capture_weights: bool — set True to start storing weights
|
| 9 |
+
# - last_attn_weights: Tensor — [B, n_heads, Lq, Lk], updated each forward call
|
| 10 |
+
#
|
| 11 |
+
# AttentionCapture:
|
| 12 |
+
# - Sets capture_weights=True on all cross-attention layers
|
| 13 |
+
# - Hooks into generate_cached() to record weights at every diffusion step
|
| 14 |
+
# - Returns a dict: {t_val: [layer_0_weights, layer_1_weights, ...]}
|
| 15 |
+
#
|
| 16 |
+
# Visualization:
|
| 17 |
+
# - plot_attn_heatmap(): shows src→tgt alignment at a single step
|
| 18 |
+
# - plot_attn_evolution(): shows how one src→tgt pair evolves over T steps
|
| 19 |
+
# - plot_all_layers(): grid of heatmaps per layer at a given step
|
| 20 |
+
#
|
| 21 |
+
# Usage:
|
| 22 |
+
# from analysis.attention_viz import AttentionCapture, plot_attn_heatmap
|
| 23 |
+
#
|
| 24 |
+
# capturer = AttentionCapture(model)
|
| 25 |
+
# weights = capturer.capture(src_ids, src_tokens, tgt_tokens)
|
| 26 |
+
# plot_attn_heatmap(weights, step=0, layer=0, src_tokens=..., tgt_tokens=...)
|
| 27 |
+
# """
|
| 28 |
+
#
|
| 29 |
+
# import torch
|
| 30 |
+
# import numpy as np
|
| 31 |
+
# import os
|
| 32 |
+
# from typing import List, Dict, Optional
|
| 33 |
+
#
|
| 34 |
+
#
|
| 35 |
+
# # ── Attention capture ─────────────────────────────────────────────────
|
| 36 |
+
#
|
| 37 |
+
# class AttentionCapture:
|
| 38 |
+
# """
|
| 39 |
+
# Captures cross-attention weights from all decoder layers at every
|
| 40 |
+
# diffusion step during generate_cached().
|
| 41 |
+
#
|
| 42 |
+
# Works by:
|
| 43 |
+
# 1. Setting capture_weights=True on each DecoderBlock.cross_attn
|
| 44 |
+
# 2. Running generate_cached() (encoder runs once via KV cache)
|
| 45 |
+
# 3. After each denoising step, reading last_attn_weights from each layer
|
| 46 |
+
# 4. Storing as {t_val: list_of_layer_weights}
|
| 47 |
+
#
|
| 48 |
+
# Zero retraining required — uses the flag added to MultiHeadAttention.
|
| 49 |
+
# """
|
| 50 |
+
#
|
| 51 |
+
# def __init__(self, model):
|
| 52 |
+
# """
|
| 53 |
+
# Args:
|
| 54 |
+
# model : SanskritModel wrapper (must be D3PMCrossAttention)
|
| 55 |
+
# """
|
| 56 |
+
# self.model = model
|
| 57 |
+
# self.inner = model.model # D3PMCrossAttention
|
| 58 |
+
# self._cross_attns = []
|
| 59 |
+
#
|
| 60 |
+
# # Collect all cross-attention modules from decoder blocks
|
| 61 |
+
# if hasattr(self.inner, 'decoder_blocks'):
|
| 62 |
+
# for block in self.inner.decoder_blocks:
|
| 63 |
+
# if hasattr(block, 'cross_attn'):
|
| 64 |
+
# self._cross_attns.append(block.cross_attn)
|
| 65 |
+
#
|
| 66 |
+
# if not self._cross_attns:
|
| 67 |
+
# raise ValueError(
|
| 68 |
+
# "No cross-attention layers found. "
|
| 69 |
+
# "AttentionCapture only works with D3PMCrossAttention."
|
| 70 |
+
# )
|
| 71 |
+
#
|
| 72 |
+
# print(f"AttentionCapture: found {len(self._cross_attns)} cross-attention layers.")
|
| 73 |
+
#
|
| 74 |
+
# def _enable(self):
|
| 75 |
+
# """Turn on weight capture for all cross-attention layers."""
|
| 76 |
+
# for ca in self._cross_attns:
|
| 77 |
+
# ca.capture_weights = True
|
| 78 |
+
#
|
| 79 |
+
# def _disable(self):
|
| 80 |
+
# """Turn off weight capture (restores zero overhead)."""
|
| 81 |
+
# for ca in self._cross_attns:
|
| 82 |
+
# ca.capture_weights = False
|
| 83 |
+
# ca.last_attn_weights = None
|
| 84 |
+
#
|
| 85 |
+
# def _read_weights(self) -> List[np.ndarray]:
|
| 86 |
+
# """
|
| 87 |
+
# Read current last_attn_weights from all layers.
|
| 88 |
+
# Returns list of [B, n_heads, Lq, Lk] arrays — one per layer.
|
| 89 |
+
# Averages over heads to produce [B, Lq, Lk].
|
| 90 |
+
# """
|
| 91 |
+
# weights = []
|
| 92 |
+
# for ca in self._cross_attns:
|
| 93 |
+
# if ca.last_attn_weights is not None:
|
| 94 |
+
# # Average over attention heads → [B, Lq, Lk]
|
| 95 |
+
# w = ca.last_attn_weights.float().mean(dim=1)
|
| 96 |
+
# weights.append(w.numpy())
|
| 97 |
+
# return weights
|
| 98 |
+
#
|
| 99 |
+
# @torch.no_grad()
|
| 100 |
+
# def capture(
|
| 101 |
+
# self,
|
| 102 |
+
# src: torch.Tensor,
|
| 103 |
+
# capture_every: int = 10,
|
| 104 |
+
# ) -> Dict[int, List[np.ndarray]]:
|
| 105 |
+
# """
|
| 106 |
+
# Run full generation while capturing attention at every `capture_every` steps.
|
| 107 |
+
#
|
| 108 |
+
# Args:
|
| 109 |
+
# src : [1, src_len] or [B, src_len] IAST token ids
|
| 110 |
+
# capture_every : capture weights every N steps (default 10)
|
| 111 |
+
# Use 1 to capture every step (slow, high memory).
|
| 112 |
+
#
|
| 113 |
+
# Returns:
|
| 114 |
+
# step_weights : dict mapping t_val → list of [B, Lq, Lk] arrays
|
| 115 |
+
# one array per decoder layer
|
| 116 |
+
# keys are t values: T-1, T-1-N, ..., 0
|
| 117 |
+
#
|
| 118 |
+
# Example:
|
| 119 |
+
# weights = capturer.capture(src_ids, capture_every=10)
|
| 120 |
+
# # weights[127] = layer weights at t=127 (heavy noise)
|
| 121 |
+
# # weights[0] = layer weights at t=0 (clean output)
|
| 122 |
+
# """
|
| 123 |
+
# if src.dim() == 1:
|
| 124 |
+
# src = src.unsqueeze(0)
|
| 125 |
+
#
|
| 126 |
+
# inner = self.inner
|
| 127 |
+
# T = inner.scheduler.num_timesteps
|
| 128 |
+
# device = src.device
|
| 129 |
+
#
|
| 130 |
+
# # KV cache: encode source once
|
| 131 |
+
# memory, src_pad_mask = inner.encode_source(src)
|
| 132 |
+
#
|
| 133 |
+
# B = src.shape[0]
|
| 134 |
+
# tgt_len = inner.max_seq_len
|
| 135 |
+
# mask_id = inner.mask_token_id
|
| 136 |
+
#
|
| 137 |
+
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
|
| 138 |
+
# hint = None
|
| 139 |
+
#
|
| 140 |
+
# step_weights: Dict[int, List[np.ndarray]] = {}
|
| 141 |
+
#
|
| 142 |
+
# self._enable()
|
| 143 |
+
# try:
|
| 144 |
+
# inner.eval()
|
| 145 |
+
# for t_val in range(T - 1, -1, -1):
|
| 146 |
+
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 147 |
+
# is_last = (t_val == 0)
|
| 148 |
+
#
|
| 149 |
+
# logits, _ = inner.forward_cached(
|
| 150 |
+
# memory, src_pad_mask, x0_est, t,
|
| 151 |
+
# x0_hint=hint, inference_mode=True,
|
| 152 |
+
# )
|
| 153 |
+
#
|
| 154 |
+
# # Capture at this step if scheduled or it's the last step
|
| 155 |
+
# if (T - 1 - t_val) % capture_every == 0 or is_last:
|
| 156 |
+
# step_weights[t_val] = self._read_weights()
|
| 157 |
+
#
|
| 158 |
+
# import torch.nn.functional as F
|
| 159 |
+
# probs = F.softmax(logits / 0.8, dim=-1)
|
| 160 |
+
# x0_est = torch.argmax(probs, dim=-1) if is_last else \
|
| 161 |
+
# _multinomial_sample(probs)
|
| 162 |
+
# hint = x0_est
|
| 163 |
+
#
|
| 164 |
+
# finally:
|
| 165 |
+
# self._disable() # always restore — even if exception raised
|
| 166 |
+
#
|
| 167 |
+
# print(f"Captured attention at {len(step_weights)} steps "
|
| 168 |
+
# f"({len(self._cross_attns)} layers each).")
|
| 169 |
+
# return step_weights
|
| 170 |
+
#
|
| 171 |
+
#
|
| 172 |
+
# def _multinomial_sample(probs: torch.Tensor) -> torch.Tensor:
|
| 173 |
+
# B, L, V = probs.shape
|
| 174 |
+
# flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 175 |
+
# flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 176 |
+
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
| 177 |
+
#
|
| 178 |
+
#
|
| 179 |
+
# # ── Visualization ─────────────────────────────────────────────────────
|
| 180 |
+
#
|
| 181 |
+
# def plot_attn_heatmap(
|
| 182 |
+
# step_weights: Dict[int, List[np.ndarray]],
|
| 183 |
+
# t_val: int,
|
| 184 |
+
# layer: int,
|
| 185 |
+
# src_tokens: List[str],
|
| 186 |
+
# tgt_tokens: List[str],
|
| 187 |
+
# sample_idx: int = 0,
|
| 188 |
+
# save_path: Optional[str] = None,
|
| 189 |
+
# title: Optional[str] = None,
|
| 190 |
+
# ):
|
| 191 |
+
# """
|
| 192 |
+
# Plot cross-attention heatmap for a single step and layer.
|
| 193 |
+
#
|
| 194 |
+
# X-axis = source (IAST) tokens
|
| 195 |
+
# Y-axis = target (Devanagari) positions
|
| 196 |
+
# Color = attention weight (brighter = stronger attention)
|
| 197 |
+
#
|
| 198 |
+
# Args:
|
| 199 |
+
# step_weights : output of AttentionCapture.capture()
|
| 200 |
+
# t_val : which diffusion step to visualize
|
| 201 |
+
# layer : which decoder layer (0 = first, -1 = last)
|
| 202 |
+
# src_tokens : list of IAST token strings for x-axis labels
|
| 203 |
+
# tgt_tokens : list of Devanagari token strings for y-axis labels
|
| 204 |
+
# sample_idx : which batch item to visualize (default 0)
|
| 205 |
+
# save_path : if given, save figure to this path
|
| 206 |
+
# title : custom plot title
|
| 207 |
+
# """
|
| 208 |
+
# try:
|
| 209 |
+
# import matplotlib.pyplot as plt
|
| 210 |
+
# import matplotlib.ticker as ticker
|
| 211 |
+
# except ImportError:
|
| 212 |
+
# print("pip install matplotlib to use visualization functions.")
|
| 213 |
+
# return
|
| 214 |
+
#
|
| 215 |
+
# if t_val not in step_weights:
|
| 216 |
+
# available = sorted(step_weights.keys())
|
| 217 |
+
# raise ValueError(
|
| 218 |
+
# f"t_val={t_val} not in captured steps. "
|
| 219 |
+
# f"Available: {available[:5]}...{available[-5:]}"
|
| 220 |
+
# )
|
| 221 |
+
#
|
| 222 |
+
# layers = step_weights[t_val]
|
| 223 |
+
# weights = layers[layer][sample_idx] # [Lq, Lk]
|
| 224 |
+
#
|
| 225 |
+
# # Trim to actual token lengths
|
| 226 |
+
# n_src = min(len(src_tokens), weights.shape[1])
|
| 227 |
+
# n_tgt = min(len(tgt_tokens), weights.shape[0])
|
| 228 |
+
# weights = weights[:n_tgt, :n_src]
|
| 229 |
+
#
|
| 230 |
+
# fig, ax = plt.subplots(figsize=(max(8, n_src * 0.4), max(6, n_tgt * 0.35)))
|
| 231 |
+
# im = ax.imshow(weights, aspect='auto', cmap='YlOrRd', interpolation='nearest')
|
| 232 |
+
#
|
| 233 |
+
# ax.set_xticks(range(n_src))
|
| 234 |
+
# ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=9)
|
| 235 |
+
# ax.set_yticks(range(n_tgt))
|
| 236 |
+
# ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=9)
|
| 237 |
+
#
|
| 238 |
+
# ax.set_xlabel("Source (IAST)", fontsize=11)
|
| 239 |
+
# ax.set_ylabel("Target position (Devanagari)", fontsize=11)
|
| 240 |
+
#
|
| 241 |
+
# plot_title = title or f"Cross-Attention | t={t_val} | Layer {layer}"
|
| 242 |
+
# ax.set_title(plot_title, fontsize=12, pad=10)
|
| 243 |
+
#
|
| 244 |
+
# plt.colorbar(im, ax=ax, label="Attention weight")
|
| 245 |
+
# plt.tight_layout()
|
| 246 |
+
#
|
| 247 |
+
# if save_path:
|
| 248 |
+
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
|
| 249 |
+
# plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 250 |
+
# print(f"Saved: {save_path}")
|
| 251 |
+
# else:
|
| 252 |
+
# plt.show()
|
| 253 |
+
# plt.close()
|
| 254 |
+
#
|
| 255 |
+
#
|
| 256 |
+
# def plot_attn_evolution(
|
| 257 |
+
# step_weights: Dict[int, List[np.ndarray]],
|
| 258 |
+
# src_token_idx: int,
|
| 259 |
+
# tgt_token_idx: int,
|
| 260 |
+
# layer: int = -1,
|
| 261 |
+
# sample_idx: int = 0,
|
| 262 |
+
# src_token_str: str = "",
|
| 263 |
+
# tgt_token_str: str = "",
|
| 264 |
+
# save_path: Optional[str] = None,
|
| 265 |
+
# ):
|
| 266 |
+
# """
|
| 267 |
+
# Plot how attention between one specific src↔tgt token pair evolves
|
| 268 |
+
# across all captured diffusion steps (T → 0).
|
| 269 |
+
#
|
| 270 |
+
# Reveals whether a token pair is 'locked' (stable from early steps)
|
| 271 |
+
# or 'flexible' (weight fluctuates until final steps).
|
| 272 |
+
#
|
| 273 |
+
# Args:
|
| 274 |
+
# step_weights : output of AttentionCapture.capture()
|
| 275 |
+
# src_token_idx : index of source token to track
|
| 276 |
+
# tgt_token_idx : index of target position to track
|
| 277 |
+
# layer : decoder layer index
|
| 278 |
+
# sample_idx : batch item
|
| 279 |
+
# src_token_str : string label for the source token (for plot title)
|
| 280 |
+
# tgt_token_str : string label for the target token (for plot title)
|
| 281 |
+
# save_path : if given, save figure to this path
|
| 282 |
+
# """
|
| 283 |
+
# try:
|
| 284 |
+
# import matplotlib.pyplot as plt
|
| 285 |
+
# except ImportError:
|
| 286 |
+
# print("pip install matplotlib to use visualization functions.")
|
| 287 |
+
# return
|
| 288 |
+
#
|
| 289 |
+
# t_vals = sorted(step_weights.keys(), reverse=True) # T-1 → 0
|
| 290 |
+
# weights = []
|
| 291 |
+
#
|
| 292 |
+
# for t_val in t_vals:
|
| 293 |
+
# layers = step_weights[t_val]
|
| 294 |
+
# w = layers[layer][sample_idx] # [Lq, Lk]
|
| 295 |
+
# if tgt_token_idx < w.shape[0] and src_token_idx < w.shape[1]:
|
| 296 |
+
# weights.append(w[tgt_token_idx, src_token_idx])
|
| 297 |
+
# else:
|
| 298 |
+
# weights.append(0.0)
|
| 299 |
+
#
|
| 300 |
+
# fig, ax = plt.subplots(figsize=(12, 4))
|
| 301 |
+
# ax.plot(range(len(t_vals)), weights, linewidth=1.5, color='steelblue')
|
| 302 |
+
# ax.fill_between(range(len(t_vals)), weights, alpha=0.2, color='steelblue')
|
| 303 |
+
#
|
| 304 |
+
# # Mark every 10th step on x-axis
|
| 305 |
+
# step_labels = [str(t) if i % max(1, len(t_vals)//10) == 0 else ""
|
| 306 |
+
# for i, t in enumerate(t_vals)]
|
| 307 |
+
# ax.set_xticks(range(len(t_vals)))
|
| 308 |
+
# ax.set_xticklabels(step_labels, fontsize=8)
|
| 309 |
+
# ax.set_xlabel("Diffusion step (T → 0)", fontsize=11)
|
| 310 |
+
# ax.set_ylabel("Attention weight", fontsize=11)
|
| 311 |
+
#
|
| 312 |
+
# pair_str = f"src[{src_token_idx}]={src_token_str!r} → tgt[{tgt_token_idx}]={tgt_token_str!r}"
|
| 313 |
+
# ax.set_title(f"Attention evolution | {pair_str} | Layer {layer}", fontsize=11)
|
| 314 |
+
# ax.set_xlim(0, len(t_vals) - 1)
|
| 315 |
+
# ax.set_ylim(0, None)
|
| 316 |
+
# plt.tight_layout()
|
| 317 |
+
#
|
| 318 |
+
# if save_path:
|
| 319 |
+
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
|
| 320 |
+
# plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 321 |
+
# print(f"Saved: {save_path}")
|
| 322 |
+
# else:
|
| 323 |
+
# plt.show()
|
| 324 |
+
# plt.close()
|
| 325 |
+
#
|
| 326 |
+
#
|
| 327 |
+
# def plot_all_layers(
|
| 328 |
+
# step_weights: Dict[int, List[np.ndarray]],
|
| 329 |
+
# t_val: int,
|
| 330 |
+
# src_tokens: List[str],
|
| 331 |
+
# tgt_tokens: List[str],
|
| 332 |
+
# sample_idx: int = 0,
|
| 333 |
+
# save_path: Optional[str] = None,
|
| 334 |
+
# ):
|
| 335 |
+
# """
|
| 336 |
+
# Plot attention heatmaps for ALL decoder layers at a single diffusion step.
|
| 337 |
+
# Shows how different layers specialize their attention patterns.
|
| 338 |
+
# """
|
| 339 |
+
# try:
|
| 340 |
+
# import matplotlib.pyplot as plt
|
| 341 |
+
# except ImportError:
|
| 342 |
+
# print("pip install matplotlib to use visualization functions.")
|
| 343 |
+
# return
|
| 344 |
+
#
|
| 345 |
+
# layers = step_weights[t_val]
|
| 346 |
+
# n_layers = len(layers)
|
| 347 |
+
# n_cols = min(4, n_layers)
|
| 348 |
+
# n_rows = (n_layers + n_cols - 1) // n_cols
|
| 349 |
+
#
|
| 350 |
+
# fig, axes = plt.subplots(n_rows, n_cols,
|
| 351 |
+
# figsize=(n_cols * 5, n_rows * 4))
|
| 352 |
+
# axes = np.array(axes).flatten() if n_layers > 1 else [axes]
|
| 353 |
+
#
|
| 354 |
+
# n_src = min(len(src_tokens), layers[0][sample_idx].shape[1])
|
| 355 |
+
# n_tgt = min(len(tgt_tokens), layers[0][sample_idx].shape[0])
|
| 356 |
+
#
|
| 357 |
+
# for i, (ax, layer_w) in enumerate(zip(axes, layers)):
|
| 358 |
+
# w = layer_w[sample_idx][:n_tgt, :n_src]
|
| 359 |
+
# im = ax.imshow(w, aspect='auto', cmap='YlOrRd', interpolation='nearest',
|
| 360 |
+
# vmin=0, vmax=w.max())
|
| 361 |
+
# ax.set_title(f"Layer {i}", fontsize=10)
|
| 362 |
+
# ax.set_xticks(range(n_src))
|
| 363 |
+
# ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=7)
|
| 364 |
+
# ax.set_yticks(range(n_tgt))
|
| 365 |
+
# ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=7)
|
| 366 |
+
#
|
| 367 |
+
# for ax in axes[n_layers:]:
|
| 368 |
+
# ax.set_visible(False)
|
| 369 |
+
#
|
| 370 |
+
# fig.suptitle(f"All layers at t={t_val}", fontsize=13, y=1.02)
|
| 371 |
+
# plt.tight_layout()
|
| 372 |
+
#
|
| 373 |
+
# if save_path:
|
| 374 |
+
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
|
| 375 |
+
# plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 376 |
+
# print(f"Saved: {save_path}")
|
| 377 |
+
# else:
|
| 378 |
+
# plt.show()
|
| 379 |
+
# plt.close()
|
| 380 |
+
"""
|
| 381 |
+
analysis/task2_full.py
|
| 382 |
+
=====================
|
| 383 |
+
|
| 384 |
+
FULL Task 2 implementation:
|
| 385 |
+
✔ Attention trajectory (already yours)
|
| 386 |
+
✔ BERTScore over diffusion steps
|
| 387 |
+
✔ Semantic drift metric
|
| 388 |
+
✔ Locked vs flexible token detection
|
| 389 |
+
✔ TF-IDF vs attention stability correlation
|
| 390 |
+
"""
|
| 391 |
+
|
| 392 |
+
import torch
|
| 393 |
+
import numpy as np
|
| 394 |
+
from typing import Dict, List
|
| 395 |
+
from collections import defaultdict
|
| 396 |
+
|
| 397 |
+
# Optional metrics
|
| 398 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 399 |
+
|
| 400 |
+
try:
|
| 401 |
+
import evaluate
|
| 402 |
+
bertscore = evaluate.load("bertscore")
|
| 403 |
+
USE_BERT = True
|
| 404 |
+
except:
|
| 405 |
+
USE_BERT = False
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
# ─────────────────────────────────────────────────────────────
|
| 409 |
+
# 1. ATTENTION CAPTURE (FIXED VERSION)
|
| 410 |
+
# ─────────────────────────────────────────────────────────────
|
| 411 |
+
|
| 412 |
+
class AttentionCapture:
|
| 413 |
+
def __init__(self, model):
|
| 414 |
+
self.model = model
|
| 415 |
+
self.inner = model.model
|
| 416 |
+
self.cross_attns = []
|
| 417 |
+
|
| 418 |
+
for block in self.inner.decoder_blocks:
|
| 419 |
+
if hasattr(block, "cross_attn"):
|
| 420 |
+
self.cross_attns.append(block.cross_attn)
|
| 421 |
+
|
| 422 |
+
def _enable(self):
|
| 423 |
+
for ca in self.cross_attns:
|
| 424 |
+
ca.capture_weights = True
|
| 425 |
+
|
| 426 |
+
def _disable(self):
|
| 427 |
+
for ca in self.cross_attns:
|
| 428 |
+
ca.capture_weights = False
|
| 429 |
+
ca.last_attn_weights = None
|
| 430 |
+
|
| 431 |
+
def _read(self):
|
| 432 |
+
weights = []
|
| 433 |
+
for ca in self.cross_attns:
|
| 434 |
+
if ca.last_attn_weights is not None:
|
| 435 |
+
w = ca.last_attn_weights.mean(dim=1) # avg heads
|
| 436 |
+
weights.append(w.cpu().numpy())
|
| 437 |
+
return weights
|
| 438 |
+
|
| 439 |
+
@torch.no_grad()
|
| 440 |
+
def run(self, src_ids):
|
| 441 |
+
inner = self.inner
|
| 442 |
+
T = inner.scheduler.num_timesteps
|
| 443 |
+
device = src_ids.device
|
| 444 |
+
|
| 445 |
+
memory, mask = inner.encode_source(src_ids)
|
| 446 |
+
|
| 447 |
+
x = torch.full(
|
| 448 |
+
(1, inner.max_seq_len),
|
| 449 |
+
inner.mask_token_id,
|
| 450 |
+
dtype=torch.long,
|
| 451 |
+
device=device
|
| 452 |
+
)
|
| 453 |
+
|
| 454 |
+
hint = None
|
| 455 |
+
step_weights = {}
|
| 456 |
+
step_outputs = {}
|
| 457 |
+
|
| 458 |
+
self._enable()
|
| 459 |
+
|
| 460 |
+
try:
|
| 461 |
+
for t_val in range(T - 1, -1, -1):
|
| 462 |
+
t = torch.tensor([t_val], device=device)
|
| 463 |
+
|
| 464 |
+
logits, _ = inner.forward_cached(
|
| 465 |
+
memory, mask, x, t, x0_hint=hint, inference_mode=True
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
probs = torch.softmax(logits, dim=-1)
|
| 469 |
+
x = torch.argmax(probs, dim=-1)
|
| 470 |
+
|
| 471 |
+
step_weights[t_val] = self._read()
|
| 472 |
+
step_outputs[t_val] = x.clone()
|
| 473 |
+
|
| 474 |
+
hint = x
|
| 475 |
+
|
| 476 |
+
finally:
|
| 477 |
+
self._disable()
|
| 478 |
+
|
| 479 |
+
return step_weights, step_outputs
|
| 480 |
+
|
| 481 |
+
|
| 482 |
+
# ─────────────────────────────────────────────────────────────
|
| 483 |
+
# 2. BERTScore + Semantic Drift
|
| 484 |
+
# ─────────────────────────────────────────────────────────────
|
| 485 |
+
|
| 486 |
+
def compute_trajectory_metrics(
|
| 487 |
+
step_outputs,
|
| 488 |
+
tgt_tokenizer,
|
| 489 |
+
reference_text
|
| 490 |
+
):
|
| 491 |
+
trajectory = []
|
| 492 |
+
|
| 493 |
+
for t, ids in step_outputs.items():
|
| 494 |
+
text = tgt_tokenizer.decode(
|
| 495 |
+
[x for x in ids[0].tolist() if x > 4]
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
if USE_BERT:
|
| 499 |
+
score = bertscore.compute(
|
| 500 |
+
predictions=[text],
|
| 501 |
+
references=[reference_text],
|
| 502 |
+
lang="hi"
|
| 503 |
+
)["f1"][0]
|
| 504 |
+
else:
|
| 505 |
+
score = 0.0
|
| 506 |
+
|
| 507 |
+
drift = 1.0 - score
|
| 508 |
+
|
| 509 |
+
trajectory.append({
|
| 510 |
+
"step": t,
|
| 511 |
+
"text": text,
|
| 512 |
+
"bert": score,
|
| 513 |
+
"drift": drift
|
| 514 |
+
})
|
| 515 |
+
|
| 516 |
+
return sorted(trajectory, key=lambda x: -x["step"])
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
# ─────────────────────────────────────────────────────────────
|
| 520 |
+
# 3. LOCKED vs FLEXIBLE TOKENS
|
| 521 |
+
# ─────────────────────────────────────────────────────────────
|
| 522 |
+
|
| 523 |
+
def analyze_token_stability(step_weights):
|
| 524 |
+
"""
|
| 525 |
+
Measure variance of attention over time
|
| 526 |
+
"""
|
| 527 |
+
token_stability = defaultdict(list)
|
| 528 |
+
|
| 529 |
+
for t, layers in step_weights.items():
|
| 530 |
+
last_layer = layers[-1][0] # [Lq, Lk]
|
| 531 |
+
|
| 532 |
+
# max attention source index per target token
|
| 533 |
+
align = np.argmax(last_layer, axis=1)
|
| 534 |
+
|
| 535 |
+
for tgt_idx, src_idx in enumerate(align):
|
| 536 |
+
token_stability[tgt_idx].append(src_idx)
|
| 537 |
+
|
| 538 |
+
results = {}
|
| 539 |
+
|
| 540 |
+
for tgt_idx, src_seq in token_stability.items():
|
| 541 |
+
changes = sum(
|
| 542 |
+
1 for i in range(1, len(src_seq))
|
| 543 |
+
if src_seq[i] != src_seq[i-1]
|
| 544 |
+
)
|
| 545 |
+
|
| 546 |
+
if changes <= 2:
|
| 547 |
+
results[tgt_idx] = "LOCKED"
|
| 548 |
+
else:
|
| 549 |
+
results[tgt_idx] = "FLEXIBLE"
|
| 550 |
+
|
| 551 |
+
return results
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
# ─────────────────────────────────────────────────────────────
|
| 555 |
+
# 4. TF-IDF vs ATTENTION STABILITY
|
| 556 |
+
# ─────────────────────────���───────────────────────────────────
|
| 557 |
+
|
| 558 |
+
def tfidf_attention_correlation(src_text, step_weights):
|
| 559 |
+
vectorizer = TfidfVectorizer()
|
| 560 |
+
tfidf = vectorizer.fit_transform([src_text]).toarray()[0]
|
| 561 |
+
|
| 562 |
+
# Avg attention over steps
|
| 563 |
+
attn_scores = None
|
| 564 |
+
|
| 565 |
+
for t, layers in step_weights.items():
|
| 566 |
+
w = layers[-1][0] # last layer
|
| 567 |
+
avg = w.mean(axis=0) # per source token
|
| 568 |
+
|
| 569 |
+
if attn_scores is None:
|
| 570 |
+
attn_scores = avg
|
| 571 |
+
else:
|
| 572 |
+
attn_scores += avg
|
| 573 |
+
|
| 574 |
+
attn_scores /= len(step_weights)
|
| 575 |
+
|
| 576 |
+
# Correlation
|
| 577 |
+
min_len = min(len(tfidf), len(attn_scores))
|
| 578 |
+
corr = np.corrcoef(tfidf[:min_len], attn_scores[:min_len])[0, 1]
|
| 579 |
+
|
| 580 |
+
return corr
|
| 581 |
+
|
| 582 |
+
|
| 583 |
+
# ─────────────────────────────────────────────────────────────
|
| 584 |
+
# 5. FULL PIPELINE
|
| 585 |
+
# ─────────────────────────────────────────────────────────────
|
| 586 |
+
|
| 587 |
+
def run_task2_analysis(
|
| 588 |
+
text,
|
| 589 |
+
model,
|
| 590 |
+
src_tokenizer,
|
| 591 |
+
tgt_tokenizer,
|
| 592 |
+
device
|
| 593 |
+
):
|
| 594 |
+
src_ids = torch.tensor(
|
| 595 |
+
[src_tokenizer.encode(text)],
|
| 596 |
+
device=device
|
| 597 |
+
)
|
| 598 |
+
|
| 599 |
+
capturer = AttentionCapture(model)
|
| 600 |
+
|
| 601 |
+
# Step 1: Capture
|
| 602 |
+
step_weights, step_outputs = capturer.run(src_ids)
|
| 603 |
+
|
| 604 |
+
# Step 2: Metrics
|
| 605 |
+
trajectory = compute_trajectory_metrics(
|
| 606 |
+
step_outputs,
|
| 607 |
+
tgt_tokenizer,
|
| 608 |
+
reference_text=text # transliteration task
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
# Step 3: Token stability
|
| 612 |
+
stability = analyze_token_stability(step_weights)
|
| 613 |
+
|
| 614 |
+
# Step 4: TF-IDF correlation
|
| 615 |
+
corr = tfidf_attention_correlation(text, step_weights)
|
| 616 |
+
|
| 617 |
+
return {
|
| 618 |
+
"trajectory": trajectory,
|
| 619 |
+
"token_stability": stability,
|
| 620 |
+
"tfidf_corr": corr
|
| 621 |
+
}
|
analysis/concept_vectors.py
ADDED
|
@@ -0,0 +1,637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# """
|
| 2 |
+
# analysis/concept_vectors.py
|
| 3 |
+
# ============================
|
| 4 |
+
# Task 3: Concept Vector Extraction + Controlled Paraphrase Diversity
|
| 5 |
+
#
|
| 6 |
+
# No retraining required. Uses decoder hidden states already computed
|
| 7 |
+
# during generate_cached() — stored in model.model._last_hidden after
|
| 8 |
+
# each forward_cached() call.
|
| 9 |
+
#
|
| 10 |
+
# Steps:
|
| 11 |
+
# 1. Collect hidden states from N examples at a fixed diffusion step
|
| 12 |
+
# 2. Pool sequence dimension → [N, d_model] representation per example
|
| 13 |
+
# 3. PCA → find principal directions in concept space
|
| 14 |
+
# 4. Identify "diversity direction" (PC that best separates short/long outputs)
|
| 15 |
+
# 5. Steer: at inference, shift hidden states along diversity direction
|
| 16 |
+
# before the output head projection
|
| 17 |
+
# 6. Generate at 5 points along the direction, measure output diversity
|
| 18 |
+
#
|
| 19 |
+
# Key insight: the diversity direction is found purely from model outputs
|
| 20 |
+
# (no human annotation needed). We use output length as a proxy:
|
| 21 |
+
# short output → low diversity (model collapsed to simple token)
|
| 22 |
+
# long output → high diversity (model exploring more of the space)
|
| 23 |
+
# """
|
| 24 |
+
#
|
| 25 |
+
# import torch
|
| 26 |
+
# import torch.nn as nn
|
| 27 |
+
# import torch.nn.functional as F
|
| 28 |
+
# import numpy as np
|
| 29 |
+
# from typing import List, Dict, Optional, Tuple
|
| 30 |
+
#
|
| 31 |
+
#
|
| 32 |
+
# # ── Hidden state collection ───────────────────────────────────────────
|
| 33 |
+
#
|
| 34 |
+
# @torch.no_grad()
|
| 35 |
+
# def collect_hidden_states(
|
| 36 |
+
# model,
|
| 37 |
+
# src_list: List[torch.Tensor],
|
| 38 |
+
# t_capture: int = 0,
|
| 39 |
+
# temperature: float = 0.8,
|
| 40 |
+
# top_k: int = 40,
|
| 41 |
+
# max_samples: int = 1000,
|
| 42 |
+
# ) -> Tuple[np.ndarray, List[str]]:
|
| 43 |
+
# """
|
| 44 |
+
# Run generate_cached() on a list of source tensors, collecting the
|
| 45 |
+
# decoder hidden state at timestep t_capture for each sample.
|
| 46 |
+
#
|
| 47 |
+
# Args:
|
| 48 |
+
# model : SanskritModel (D3PMCrossAttention)
|
| 49 |
+
# src_list : list of [1, src_len] tensors, one per sample
|
| 50 |
+
# t_capture : which diffusion step to capture hidden states at
|
| 51 |
+
# 0 = final (clean), T-1 = noisy start
|
| 52 |
+
# temperature: sampling temperature
|
| 53 |
+
# top_k : top-k filter
|
| 54 |
+
# max_samples: cap at this many samples
|
| 55 |
+
#
|
| 56 |
+
# Returns:
|
| 57 |
+
# hidden_matrix : np.ndarray [N, d_model] — pooled hidden states
|
| 58 |
+
# output_texts : list of N decoded output strings (for diversity analysis)
|
| 59 |
+
# """
|
| 60 |
+
# inner = model.model
|
| 61 |
+
# T = inner.scheduler.num_timesteps
|
| 62 |
+
# device = next(inner.parameters()).device
|
| 63 |
+
#
|
| 64 |
+
# hidden_list = []
|
| 65 |
+
# output_list = []
|
| 66 |
+
#
|
| 67 |
+
# n = min(len(src_list), max_samples)
|
| 68 |
+
# print(f"Collecting hidden states from {n} examples at t={t_capture}...")
|
| 69 |
+
#
|
| 70 |
+
# for i, src in enumerate(src_list[:n]):
|
| 71 |
+
# if i % 100 == 0:
|
| 72 |
+
# print(f" {i}/{n}")
|
| 73 |
+
#
|
| 74 |
+
# if src.dim() == 1:
|
| 75 |
+
# src = src.unsqueeze(0)
|
| 76 |
+
# src = src.to(device)
|
| 77 |
+
#
|
| 78 |
+
# B = src.shape[0]
|
| 79 |
+
# tgt_len = inner.max_seq_len
|
| 80 |
+
# mask_id = inner.mask_token_id
|
| 81 |
+
#
|
| 82 |
+
# # KV cache
|
| 83 |
+
# memory, src_pad_mask = inner.encode_source(src)
|
| 84 |
+
#
|
| 85 |
+
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
|
| 86 |
+
# hint = None
|
| 87 |
+
# captured_hidden = None
|
| 88 |
+
#
|
| 89 |
+
# for t_val in range(T - 1, -1, -1):
|
| 90 |
+
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 91 |
+
# is_last = (t_val == 0)
|
| 92 |
+
#
|
| 93 |
+
# logits, _ = inner.forward_cached(
|
| 94 |
+
# memory, src_pad_mask, x0_est, t,
|
| 95 |
+
# x0_hint=hint, inference_mode=True,
|
| 96 |
+
# )
|
| 97 |
+
#
|
| 98 |
+
# # Capture hidden state at target step
|
| 99 |
+
# if t_val == t_capture and hasattr(inner, '_last_hidden'):
|
| 100 |
+
# captured_hidden = inner._last_hidden.detach().cpu()
|
| 101 |
+
#
|
| 102 |
+
# logits = logits / max(temperature, 1e-8)
|
| 103 |
+
# if top_k > 0:
|
| 104 |
+
# V = logits.shape[-1]
|
| 105 |
+
# if top_k < V:
|
| 106 |
+
# vals, _ = torch.topk(logits, top_k, dim=-1)
|
| 107 |
+
# logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
|
| 108 |
+
#
|
| 109 |
+
# probs = F.softmax(logits, dim=-1)
|
| 110 |
+
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
|
| 111 |
+
# hint = x0_est
|
| 112 |
+
#
|
| 113 |
+
# # Pool hidden state over non-PAD positions → [d_model]
|
| 114 |
+
# if captured_hidden is not None:
|
| 115 |
+
# non_pad = (x0_est[0] > 1).cpu() # [tgt_len] bool
|
| 116 |
+
# if non_pad.sum() > 0:
|
| 117 |
+
# h = captured_hidden[0][non_pad].mean(dim=0) # [d_model]
|
| 118 |
+
# else:
|
| 119 |
+
# h = captured_hidden[0].mean(dim=0)
|
| 120 |
+
# hidden_list.append(h.numpy())
|
| 121 |
+
#
|
| 122 |
+
# # Decode output
|
| 123 |
+
# ids = [x for x in x0_est[0].tolist() if x > 4]
|
| 124 |
+
#
|
| 125 |
+
# print(f"Collected {len(hidden_list)} hidden states.")
|
| 126 |
+
# return np.stack(hidden_list), output_list
|
| 127 |
+
#
|
| 128 |
+
#
|
| 129 |
+
# # ── PCA on hidden states ─────────────────────────────���────────────────
|
| 130 |
+
#
|
| 131 |
+
# def fit_pca(
|
| 132 |
+
# hidden_matrix: np.ndarray,
|
| 133 |
+
# n_components: int = 50,
|
| 134 |
+
# ) -> object:
|
| 135 |
+
# """
|
| 136 |
+
# Fit PCA on hidden state matrix.
|
| 137 |
+
#
|
| 138 |
+
# Args:
|
| 139 |
+
# hidden_matrix : [N, d_model]
|
| 140 |
+
# n_components : number of PCA components to retain
|
| 141 |
+
#
|
| 142 |
+
# Returns:
|
| 143 |
+
# fitted sklearn PCA object
|
| 144 |
+
# """
|
| 145 |
+
# from sklearn.decomposition import PCA
|
| 146 |
+
# n_comp = min(n_components, hidden_matrix.shape[0] - 1, hidden_matrix.shape[1])
|
| 147 |
+
# pca = PCA(n_components=n_comp)
|
| 148 |
+
# pca.fit(hidden_matrix)
|
| 149 |
+
# print(f"PCA fit: {n_comp} components explain "
|
| 150 |
+
# f"{pca.explained_variance_ratio_.sum()*100:.1f}% of variance.")
|
| 151 |
+
# return pca
|
| 152 |
+
#
|
| 153 |
+
#
|
| 154 |
+
# def find_diversity_direction(
|
| 155 |
+
# hidden_matrix: np.ndarray,
|
| 156 |
+
# output_lengths: List[int],
|
| 157 |
+
# pca: object,
|
| 158 |
+
# ) -> np.ndarray:
|
| 159 |
+
# """
|
| 160 |
+
# Find the PCA direction that best correlates with output diversity
|
| 161 |
+
# (measured by output length as proxy).
|
| 162 |
+
#
|
| 163 |
+
# Projects hidden states into PCA space, then finds the PC whose
|
| 164 |
+
# scores have highest Spearman correlation with output lengths.
|
| 165 |
+
#
|
| 166 |
+
# Returns:
|
| 167 |
+
# direction : np.ndarray [d_model] — diversity direction in original space
|
| 168 |
+
# """
|
| 169 |
+
# from scipy.stats import spearmanr
|
| 170 |
+
#
|
| 171 |
+
# projected = pca.transform(hidden_matrix) # [N, n_components]
|
| 172 |
+
# lengths = np.array(output_lengths)
|
| 173 |
+
#
|
| 174 |
+
# correlations = []
|
| 175 |
+
# for pc_idx in range(projected.shape[1]):
|
| 176 |
+
# r, _ = spearmanr(projected[:, pc_idx], lengths)
|
| 177 |
+
# correlations.append(abs(r))
|
| 178 |
+
#
|
| 179 |
+
# best_pc = int(np.argmax(correlations))
|
| 180 |
+
# print(f"Diversity direction: PC {best_pc} "
|
| 181 |
+
# f"(|r|={correlations[best_pc]:.3f} with output length)")
|
| 182 |
+
#
|
| 183 |
+
# # Map back to original d_model space
|
| 184 |
+
# direction = pca.components_[best_pc] # [d_model]
|
| 185 |
+
# direction = direction / (np.linalg.norm(direction) + 1e-8)
|
| 186 |
+
# return direction, best_pc, correlations[best_pc]
|
| 187 |
+
#
|
| 188 |
+
#
|
| 189 |
+
# # ── Steered generation ────────────────────────────────────────────────
|
| 190 |
+
#
|
| 191 |
+
# @torch.no_grad()
|
| 192 |
+
# def generate_steered(
|
| 193 |
+
# model,
|
| 194 |
+
# src: torch.Tensor,
|
| 195 |
+
# direction: np.ndarray,
|
| 196 |
+
# alpha: float = 0.0,
|
| 197 |
+
# temperature: float = 0.8,
|
| 198 |
+
# top_k: int = 40,
|
| 199 |
+
# ) -> torch.Tensor:
|
| 200 |
+
# """
|
| 201 |
+
# Generate output while steering hidden states along diversity direction.
|
| 202 |
+
#
|
| 203 |
+
# At each diffusion step, after the decoder runs, we shift the hidden state
|
| 204 |
+
# by alpha * direction before projecting to logits.
|
| 205 |
+
#
|
| 206 |
+
# alpha > 0 → push toward high-diversity output
|
| 207 |
+
# alpha < 0 → push toward low-diversity output
|
| 208 |
+
# alpha = 0 → standard generation (no steering)
|
| 209 |
+
#
|
| 210 |
+
# Args:
|
| 211 |
+
# model : SanskritModel (D3PMCrossAttention)
|
| 212 |
+
# src : [1, src_len] IAST token ids
|
| 213 |
+
# direction : [d_model] diversity direction from find_diversity_direction()
|
| 214 |
+
# alpha : steering strength
|
| 215 |
+
# temperature / top_k: sampling params
|
| 216 |
+
#
|
| 217 |
+
# Returns:
|
| 218 |
+
# x0_est : [1, tgt_len] generated token ids
|
| 219 |
+
# """
|
| 220 |
+
# inner = model.model
|
| 221 |
+
# T = inner.scheduler.num_timesteps
|
| 222 |
+
# device = next(inner.parameters()).device
|
| 223 |
+
#
|
| 224 |
+
# if src.dim() == 1:
|
| 225 |
+
# src = src.unsqueeze(0)
|
| 226 |
+
# src = src.to(device)
|
| 227 |
+
#
|
| 228 |
+
# B = src.shape[0]
|
| 229 |
+
# tgt_len = inner.max_seq_len
|
| 230 |
+
# mask_id = inner.mask_token_id
|
| 231 |
+
#
|
| 232 |
+
# dir_tensor = torch.tensor(direction, dtype=torch.float32, device=device)
|
| 233 |
+
#
|
| 234 |
+
# memory, src_pad_mask = inner.encode_source(src)
|
| 235 |
+
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
|
| 236 |
+
# hint = None
|
| 237 |
+
#
|
| 238 |
+
# inner.eval()
|
| 239 |
+
# for t_val in range(T - 1, -1, -1):
|
| 240 |
+
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 241 |
+
# is_last = (t_val == 0)
|
| 242 |
+
#
|
| 243 |
+
# # Standard forward_cached but we intercept hidden states
|
| 244 |
+
# PAD = 1
|
| 245 |
+
# tgt_pad_mask = None # inference_mode
|
| 246 |
+
#
|
| 247 |
+
# _, x_t_ids = inner.forward_process.q_sample(x0_est, t) if t_val > 0 else \
|
| 248 |
+
# (None, x0_est)
|
| 249 |
+
# x = inner.tgt_embed(x_t_ids)
|
| 250 |
+
# t_norm = t.float() / inner.scheduler.num_timesteps
|
| 251 |
+
# t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
|
| 252 |
+
# x = x + t_emb.unsqueeze(1)
|
| 253 |
+
#
|
| 254 |
+
# if hint is not None:
|
| 255 |
+
# hint_emb = inner.tgt_embed(hint)
|
| 256 |
+
# gate = inner.hint_gate(x)
|
| 257 |
+
# x = x + gate * hint_emb
|
| 258 |
+
#
|
| 259 |
+
# for block in inner.decoder_blocks:
|
| 260 |
+
# x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)
|
| 261 |
+
#
|
| 262 |
+
# # ── STEERING: shift hidden states along diversity direction ───
|
| 263 |
+
# if alpha != 0.0:
|
| 264 |
+
# x = x + alpha * dir_tensor.unsqueeze(0).unsqueeze(0)
|
| 265 |
+
#
|
| 266 |
+
# # Project to logits using the head
|
| 267 |
+
# logits = inner.head(x)
|
| 268 |
+
#
|
| 269 |
+
# logits = logits / max(temperature, 1e-8)
|
| 270 |
+
# if top_k > 0:
|
| 271 |
+
# V = logits.shape[-1]
|
| 272 |
+
# if top_k < V:
|
| 273 |
+
# vals, _ = torch.topk(logits, top_k, dim=-1)
|
| 274 |
+
# logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
|
| 275 |
+
#
|
| 276 |
+
# probs = F.softmax(logits, dim=-1)
|
| 277 |
+
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
|
| 278 |
+
# hint = x0_est
|
| 279 |
+
#
|
| 280 |
+
# return x0_est
|
| 281 |
+
#
|
| 282 |
+
#
|
| 283 |
+
# def generate_diversity_spectrum(
|
| 284 |
+
# model,
|
| 285 |
+
# src: torch.Tensor,
|
| 286 |
+
# direction: np.ndarray,
|
| 287 |
+
# tgt_tokenizer,
|
| 288 |
+
# alphas: List[float] = [-2.0, -1.0, 0.0, 1.0, 2.0],
|
| 289 |
+
# temperature: float = 0.8,
|
| 290 |
+
# top_k: int = 40,
|
| 291 |
+
# ) -> Dict[float, str]:
|
| 292 |
+
# """
|
| 293 |
+
# Generate outputs at 5 points along the diversity direction.
|
| 294 |
+
#
|
| 295 |
+
# Args:
|
| 296 |
+
# alphas : steering strengths (negative = low diversity, positive = high)
|
| 297 |
+
#
|
| 298 |
+
# Returns:
|
| 299 |
+
# dict mapping alpha → decoded Devanagari string
|
| 300 |
+
# """
|
| 301 |
+
# results = {}
|
| 302 |
+
# for alpha in alphas:
|
| 303 |
+
# out_ids = generate_steered(model, src, direction, alpha, temperature, top_k)
|
| 304 |
+
# ids = [x for x in out_ids[0].tolist() if x > 4]
|
| 305 |
+
# text = tgt_tokenizer.decode(ids).strip()
|
| 306 |
+
# results[alpha] = text
|
| 307 |
+
# print(f" alpha={alpha:+.1f} → {text}")
|
| 308 |
+
# return results
|
| 309 |
+
#
|
| 310 |
+
#
|
| 311 |
+
# def plot_pca_space(
|
| 312 |
+
# hidden_matrix: np.ndarray,
|
| 313 |
+
# output_lengths: List[int],
|
| 314 |
+
# pca: object,
|
| 315 |
+
# diversity_pc: int,
|
| 316 |
+
# save_path: Optional[str] = None,
|
| 317 |
+
# ):
|
| 318 |
+
# """
|
| 319 |
+
# Scatter plot of examples in PC1 vs PC2 space, coloured by output length.
|
| 320 |
+
# Highlights the diversity direction.
|
| 321 |
+
# """
|
| 322 |
+
# try:
|
| 323 |
+
# import matplotlib.pyplot as plt
|
| 324 |
+
# except ImportError:
|
| 325 |
+
# print("pip install matplotlib.")
|
| 326 |
+
# return
|
| 327 |
+
#
|
| 328 |
+
# projected = pca.transform(hidden_matrix) # [N, n_pc]
|
| 329 |
+
# lengths = np.array(output_lengths)
|
| 330 |
+
#
|
| 331 |
+
# fig, axes = plt.subplots(1, 2, figsize=(14, 5))
|
| 332 |
+
#
|
| 333 |
+
# # Left: PC0 vs PC1 coloured by length
|
| 334 |
+
# ax = axes[0]
|
| 335 |
+
# sc = ax.scatter(projected[:, 0], projected[:, 1],
|
| 336 |
+
# c=lengths, cmap='viridis', alpha=0.6, s=15)
|
| 337 |
+
# plt.colorbar(sc, ax=ax, label="Output length (chars)")
|
| 338 |
+
# ax.set_xlabel(f"PC0 ({pca.explained_variance_ratio_[0]*100:.1f}%)", fontsize=10)
|
| 339 |
+
# ax.set_ylabel(f"PC1 ({pca.explained_variance_ratio_[1]*100:.1f}%)", fontsize=10)
|
| 340 |
+
# ax.set_title("Concept space (PC0 vs PC1)", fontsize=11)
|
| 341 |
+
#
|
| 342 |
+
# # Right: explained variance
|
| 343 |
+
# ax2 = axes[1]
|
| 344 |
+
# cumvar = np.cumsum(pca.explained_variance_ratio_) * 100
|
| 345 |
+
# ax2.plot(range(1, len(cumvar)+1), cumvar, linewidth=1.5, color='steelblue')
|
| 346 |
+
# ax2.axvline(diversity_pc, color='coral', linestyle='--', label=f"Diversity PC={diversity_pc}")
|
| 347 |
+
# ax2.set_xlabel("Number of PCs", fontsize=10)
|
| 348 |
+
# ax2.set_ylabel("Cumulative variance (%)", fontsize=10)
|
| 349 |
+
# ax2.set_title("PCA explained variance", fontsize=11)
|
| 350 |
+
# ax2.legend()
|
| 351 |
+
# ax2.set_ylim(0, 102)
|
| 352 |
+
#
|
| 353 |
+
# plt.tight_layout()
|
| 354 |
+
# if save_path:
|
| 355 |
+
# import os
|
| 356 |
+
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
|
| 357 |
+
# plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 358 |
+
# print(f"Saved: {save_path}")
|
| 359 |
+
# else:
|
| 360 |
+
# plt.show()
|
| 361 |
+
# plt.close()
|
| 362 |
+
#
|
| 363 |
+
#
|
| 364 |
+
# def _sample(probs):
|
| 365 |
+
# B, L, V = probs.shape
|
| 366 |
+
# flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 367 |
+
# flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 368 |
+
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
| 369 |
+
"""
|
| 370 |
+
Task 3: Concept Vector Extraction + Controlled Paraphrase Diversity
|
| 371 |
+
Fully corrected & production-ready version
|
| 372 |
+
"""
|
| 373 |
+
|
| 374 |
+
import torch
|
| 375 |
+
import torch.nn.functional as F
|
| 376 |
+
import numpy as np
|
| 377 |
+
from typing import List, Tuple, Dict, Optional
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# ─────────────────────────────────────────────────────────────
|
| 381 |
+
# Utility
|
| 382 |
+
# ─────────────────────────────────────────────────────────────
|
| 383 |
+
|
| 384 |
+
def _sample(probs: torch.Tensor) -> torch.Tensor:
|
| 385 |
+
B, L, V = probs.shape
|
| 386 |
+
flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 387 |
+
flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 388 |
+
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
# ─────────────────────────────────────────────────────────────
|
| 392 |
+
# 1. Collect Hidden States
|
| 393 |
+
# ─────────────────────────────────────────────────────────────
|
| 394 |
+
|
| 395 |
+
@torch.no_grad()
|
| 396 |
+
def collect_hidden_states(
|
| 397 |
+
model,
|
| 398 |
+
src_list: List[torch.Tensor],
|
| 399 |
+
tgt_tokenizer,
|
| 400 |
+
t_capture: int = 0,
|
| 401 |
+
temperature: float = 0.8,
|
| 402 |
+
top_k: int = 40,
|
| 403 |
+
max_samples: int = 1000,
|
| 404 |
+
) -> Tuple[np.ndarray, List[str], List[int]]:
|
| 405 |
+
"""
|
| 406 |
+
Collect pooled hidden representations + outputs
|
| 407 |
+
"""
|
| 408 |
+
|
| 409 |
+
inner = model.model
|
| 410 |
+
device = next(inner.parameters()).device
|
| 411 |
+
T = inner.scheduler.num_timesteps
|
| 412 |
+
|
| 413 |
+
hidden_list = []
|
| 414 |
+
texts = []
|
| 415 |
+
lengths = []
|
| 416 |
+
|
| 417 |
+
print(f"Collecting {min(len(src_list), max_samples)} samples...")
|
| 418 |
+
|
| 419 |
+
for i, src in enumerate(src_list[:max_samples]):
|
| 420 |
+
|
| 421 |
+
if src.dim() == 1:
|
| 422 |
+
src = src.unsqueeze(0)
|
| 423 |
+
src = src.to(device)
|
| 424 |
+
|
| 425 |
+
B = src.shape[0]
|
| 426 |
+
tgt_len = inner.max_seq_len
|
| 427 |
+
mask_id = inner.mask_token_id
|
| 428 |
+
|
| 429 |
+
# KV Cache (IMPORTANT)
|
| 430 |
+
memory, src_pad_mask = inner.encode_source(src)
|
| 431 |
+
|
| 432 |
+
x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
|
| 433 |
+
hint = None
|
| 434 |
+
captured_hidden = None
|
| 435 |
+
|
| 436 |
+
for t_val in range(T - 1, -1, -1):
|
| 437 |
+
|
| 438 |
+
t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 439 |
+
is_last = (t_val == 0)
|
| 440 |
+
|
| 441 |
+
logits, _ = inner.forward_cached(
|
| 442 |
+
memory,
|
| 443 |
+
src_pad_mask,
|
| 444 |
+
x0_est,
|
| 445 |
+
t,
|
| 446 |
+
x0_hint=hint,
|
| 447 |
+
inference_mode=True,
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
# Capture hidden state
|
| 451 |
+
if t_val == t_capture:
|
| 452 |
+
if hasattr(inner, "_last_hidden"):
|
| 453 |
+
captured_hidden = inner._last_hidden.detach().cpu()
|
| 454 |
+
|
| 455 |
+
# Sampling
|
| 456 |
+
logits = logits / max(temperature, 1e-8)
|
| 457 |
+
|
| 458 |
+
if top_k > 0:
|
| 459 |
+
vals, _ = torch.topk(logits, top_k, dim=-1)
|
| 460 |
+
logits = logits.masked_fill(logits < vals[..., -1:], float("-inf"))
|
| 461 |
+
|
| 462 |
+
probs = F.softmax(logits, dim=-1)
|
| 463 |
+
x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
|
| 464 |
+
hint = x0_est
|
| 465 |
+
|
| 466 |
+
# Pool hidden
|
| 467 |
+
if captured_hidden is not None:
|
| 468 |
+
h = captured_hidden[0].mean(dim=0) # [d_model]
|
| 469 |
+
hidden_list.append(h.numpy())
|
| 470 |
+
|
| 471 |
+
# Decode
|
| 472 |
+
ids = [x for x in x0_est[0].tolist() if x > 4]
|
| 473 |
+
text = tgt_tokenizer.decode(ids).strip()
|
| 474 |
+
|
| 475 |
+
texts.append(text)
|
| 476 |
+
lengths.append(len(text))
|
| 477 |
+
|
| 478 |
+
if i % 100 == 0:
|
| 479 |
+
print(f"{i} done")
|
| 480 |
+
|
| 481 |
+
hidden_matrix = np.stack(hidden_list)
|
| 482 |
+
|
| 483 |
+
print("Collected hidden states:", hidden_matrix.shape)
|
| 484 |
+
return hidden_matrix, texts, lengths
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
# ─────────────────────────────────────────────────────────────
|
| 488 |
+
# 2. PCA
|
| 489 |
+
# ─────────────────────────────────────────────────────────────
|
| 490 |
+
|
| 491 |
+
def fit_pca(hidden_matrix: np.ndarray, n_components: int = 50):
|
| 492 |
+
from sklearn.decomposition import PCA
|
| 493 |
+
|
| 494 |
+
n_comp = min(n_components, hidden_matrix.shape[0] - 1, hidden_matrix.shape[1])
|
| 495 |
+
pca = PCA(n_components=n_comp)
|
| 496 |
+
pca.fit(hidden_matrix)
|
| 497 |
+
|
| 498 |
+
print("Explained variance:", pca.explained_variance_ratio_.sum())
|
| 499 |
+
return pca
|
| 500 |
+
|
| 501 |
+
|
| 502 |
+
# ─────────────────────────────────────────────────────────────
|
| 503 |
+
# 3. Find Diversity Direction
|
| 504 |
+
# ─────────────────────────────────────────────────────────────
|
| 505 |
+
|
| 506 |
+
def find_diversity_direction(hidden_matrix, lengths, pca):
|
| 507 |
+
from scipy.stats import spearmanr
|
| 508 |
+
|
| 509 |
+
projected = pca.transform(hidden_matrix)
|
| 510 |
+
lengths = np.array(lengths)
|
| 511 |
+
|
| 512 |
+
scores = []
|
| 513 |
+
|
| 514 |
+
for i in range(projected.shape[1]):
|
| 515 |
+
r, _ = spearmanr(projected[:, i], lengths)
|
| 516 |
+
scores.append(abs(r))
|
| 517 |
+
|
| 518 |
+
best_pc = int(np.argmax(scores))
|
| 519 |
+
|
| 520 |
+
print(f"Best PC: {best_pc} | corr={scores[best_pc]:.3f}")
|
| 521 |
+
|
| 522 |
+
direction = pca.components_[best_pc]
|
| 523 |
+
direction = direction / (np.linalg.norm(direction) + 1e-8)
|
| 524 |
+
|
| 525 |
+
return direction
|
| 526 |
+
|
| 527 |
+
|
| 528 |
+
# ─────────────────────────────────────────────────────────────
|
| 529 |
+
# 4. Steered Generation
|
| 530 |
+
# ─────────────────────────────────────────────────────────────
|
| 531 |
+
|
| 532 |
+
@torch.no_grad()
|
| 533 |
+
def generate_steered(
|
| 534 |
+
model,
|
| 535 |
+
src,
|
| 536 |
+
direction,
|
| 537 |
+
alpha=0.0,
|
| 538 |
+
temperature=0.8,
|
| 539 |
+
top_k=40,
|
| 540 |
+
):
|
| 541 |
+
inner = model.model
|
| 542 |
+
device = next(inner.parameters()).device
|
| 543 |
+
T = inner.scheduler.num_timesteps
|
| 544 |
+
|
| 545 |
+
if src.dim() == 1:
|
| 546 |
+
src = src.unsqueeze(0)
|
| 547 |
+
src = src.to(device)
|
| 548 |
+
|
| 549 |
+
B = src.shape[0]
|
| 550 |
+
tgt_len = inner.max_seq_len
|
| 551 |
+
mask_id = inner.mask_token_id
|
| 552 |
+
|
| 553 |
+
direction = torch.tensor(direction, dtype=torch.float32, device=device)
|
| 554 |
+
direction = direction / (torch.norm(direction) + 1e-6)
|
| 555 |
+
|
| 556 |
+
memory, src_pad_mask = inner.encode_source(src)
|
| 557 |
+
|
| 558 |
+
x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
|
| 559 |
+
hint = None
|
| 560 |
+
|
| 561 |
+
for t_val in range(T - 1, -1, -1):
|
| 562 |
+
|
| 563 |
+
t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 564 |
+
is_last = (t_val == 0)
|
| 565 |
+
|
| 566 |
+
logits, _ = inner.forward_cached(
|
| 567 |
+
memory,
|
| 568 |
+
src_pad_mask,
|
| 569 |
+
x0_est,
|
| 570 |
+
t,
|
| 571 |
+
x0_hint=hint,
|
| 572 |
+
inference_mode=True,
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
# Inject diversity
|
| 576 |
+
if hasattr(inner, "_last_hidden") and alpha != 0.0:
|
| 577 |
+
h = inner._last_hidden
|
| 578 |
+
h = h + alpha * direction.unsqueeze(0).unsqueeze(0)
|
| 579 |
+
logits = inner.head(h)
|
| 580 |
+
|
| 581 |
+
# Sampling
|
| 582 |
+
logits = logits / max(temperature, 1e-8)
|
| 583 |
+
|
| 584 |
+
if top_k > 0:
|
| 585 |
+
vals, _ = torch.topk(logits, top_k, dim=-1)
|
| 586 |
+
logits = logits.masked_fill(logits < vals[..., -1:], float("-inf"))
|
| 587 |
+
|
| 588 |
+
probs = F.softmax(logits, dim=-1)
|
| 589 |
+
x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
|
| 590 |
+
hint = x0_est
|
| 591 |
+
|
| 592 |
+
return x0_est
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
# ─────────────────────────────────────────────────────────────
|
| 596 |
+
# 5. Diversity Spectrum
|
| 597 |
+
# ─────────────────────────────────────────────────────────────
|
| 598 |
+
|
| 599 |
+
def generate_diversity_spectrum(
|
| 600 |
+
model,
|
| 601 |
+
src,
|
| 602 |
+
direction,
|
| 603 |
+
tgt_tokenizer,
|
| 604 |
+
alphas=[-2, -1, 0, 1, 2],
|
| 605 |
+
):
|
| 606 |
+
results = {}
|
| 607 |
+
|
| 608 |
+
print("\nDiversity Spectrum:\n")
|
| 609 |
+
|
| 610 |
+
for alpha in alphas:
|
| 611 |
+
out_ids = generate_steered(model, src, direction, alpha)
|
| 612 |
+
|
| 613 |
+
ids = [x for x in out_ids[0].tolist() if x > 4]
|
| 614 |
+
text = tgt_tokenizer.decode(ids).strip()
|
| 615 |
+
|
| 616 |
+
print(f"{alpha:+} → {text}")
|
| 617 |
+
results[alpha] = text
|
| 618 |
+
|
| 619 |
+
return results
|
| 620 |
+
|
| 621 |
+
|
| 622 |
+
# ─────────────────────────────────────────────────────────────
|
| 623 |
+
# 6. Visualization
|
| 624 |
+
# ─────────────────────────────────────────────────────────────
|
| 625 |
+
|
| 626 |
+
def plot_pca_space(hidden_matrix, lengths, pca):
|
| 627 |
+
import matplotlib.pyplot as plt
|
| 628 |
+
|
| 629 |
+
proj = pca.transform(hidden_matrix)
|
| 630 |
+
|
| 631 |
+
plt.figure(figsize=(8, 6))
|
| 632 |
+
sc = plt.scatter(proj[:, 0], proj[:, 1], c=lengths)
|
| 633 |
+
plt.colorbar(sc)
|
| 634 |
+
plt.title("Concept Space")
|
| 635 |
+
plt.xlabel("PC1")
|
| 636 |
+
plt.ylabel("PC2")
|
| 637 |
+
plt.show()
|
analysis/kv_cache_benchmark.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# """
|
| 2 |
+
# analysis/kv_cache_benchmark.py
|
| 3 |
+
# ================================
|
| 4 |
+
# Task 1: Benchmark KV cache vs standard generate().
|
| 5 |
+
#
|
| 6 |
+
# Measures:
|
| 7 |
+
# - Wall-clock time for generate() vs generate_cached()
|
| 8 |
+
# - Encoder time as % of total generation time (before/after)
|
| 9 |
+
# - Speedup ratio at src_len = 16, 32, 64 tokens
|
| 10 |
+
#
|
| 11 |
+
# How it works:
|
| 12 |
+
# Standard generate():
|
| 13 |
+
# For each of T=128 steps:
|
| 14 |
+
# src → encoder → memory → decoder → logits (encoder runs 128 times)
|
| 15 |
+
#
|
| 16 |
+
# generate_cached():
|
| 17 |
+
# src → encoder → memory (once)
|
| 18 |
+
# For each of T=128 steps:
|
| 19 |
+
# cached_memory → decoder → logits (encoder runs 1 time)
|
| 20 |
+
#
|
| 21 |
+
# Expected speedup:
|
| 22 |
+
# If encoder = 30% of per-step time:
|
| 23 |
+
# Saved = 127/128 * 30% ≈ 29.7% of total time
|
| 24 |
+
# If encoder = 50% of per-step time:
|
| 25 |
+
# Saved ≈ 49.6% of total time
|
| 26 |
+
#
|
| 27 |
+
# Usage:
|
| 28 |
+
# python -m analysis.kv_cache_benchmark
|
| 29 |
+
# or:
|
| 30 |
+
# from analysis.kv_cache_benchmark import run_benchmark
|
| 31 |
+
# results = run_benchmark(model, src_tokenizer, device)
|
| 32 |
+
# """
|
| 33 |
+
#
|
| 34 |
+
# import torch
|
| 35 |
+
# import time
|
| 36 |
+
# import numpy as np
|
| 37 |
+
# from typing import Dict, List
|
| 38 |
+
#
|
| 39 |
+
#
|
| 40 |
+
# def _make_src(src_len: int, src_vocab: int, device: torch.device, batch_size: int = 1):
|
| 41 |
+
# """Create a random source tensor of given length."""
|
| 42 |
+
# # Random real tokens (ids 5..src_vocab-1), padded to src_len
|
| 43 |
+
# ids = torch.randint(5, src_vocab, (batch_size, src_len), device=device)
|
| 44 |
+
# return ids
|
| 45 |
+
#
|
| 46 |
+
#
|
| 47 |
+
# def _time_fn(fn, n_warmup: int = 2, n_runs: int = 5) -> float:
|
| 48 |
+
# """
|
| 49 |
+
# Time a zero-argument callable.
|
| 50 |
+
# Returns mean wall-clock seconds over n_runs after n_warmup warmup calls.
|
| 51 |
+
# """
|
| 52 |
+
# # Warmup
|
| 53 |
+
# for _ in range(n_warmup):
|
| 54 |
+
# fn()
|
| 55 |
+
# if torch.cuda.is_available():
|
| 56 |
+
# torch.cuda.synchronize()
|
| 57 |
+
# elif torch.backends.mps.is_available():
|
| 58 |
+
# torch.mps.synchronize()
|
| 59 |
+
#
|
| 60 |
+
# times = []
|
| 61 |
+
# for _ in range(n_runs):
|
| 62 |
+
# start = time.perf_counter()
|
| 63 |
+
# fn()
|
| 64 |
+
# if torch.cuda.is_available():
|
| 65 |
+
# torch.cuda.synchronize()
|
| 66 |
+
# elif torch.backends.mps.is_available():
|
| 67 |
+
# torch.mps.synchronize()
|
| 68 |
+
# times.append(time.perf_counter() - start)
|
| 69 |
+
#
|
| 70 |
+
# return float(np.mean(times))
|
| 71 |
+
#
|
| 72 |
+
#
|
| 73 |
+
# def benchmark_encoder_cost(
|
| 74 |
+
# model,
|
| 75 |
+
# src: torch.Tensor,
|
| 76 |
+
# ) -> Dict[str, float]:
|
| 77 |
+
# """
|
| 78 |
+
# Measure encoder time as a fraction of one full forward pass.
|
| 79 |
+
#
|
| 80 |
+
# Returns:
|
| 81 |
+
# encoder_s : seconds for one encoder call
|
| 82 |
+
# full_step_s : seconds for one full forward_cached decoder step
|
| 83 |
+
# encoder_pct : encoder_s / (encoder_s + full_step_s) * 100
|
| 84 |
+
# """
|
| 85 |
+
# inner = model.model
|
| 86 |
+
# if not hasattr(inner, 'encode_source'):
|
| 87 |
+
# raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")
|
| 88 |
+
#
|
| 89 |
+
# device = src.device
|
| 90 |
+
# B = src.shape[0]
|
| 91 |
+
# T = inner.scheduler.num_timesteps
|
| 92 |
+
# tgt_len = inner.max_seq_len
|
| 93 |
+
# mask_id = inner.mask_token_id
|
| 94 |
+
#
|
| 95 |
+
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
|
| 96 |
+
# t = torch.zeros(B, dtype=torch.long, device=device)
|
| 97 |
+
#
|
| 98 |
+
# # Time encoder alone
|
| 99 |
+
# encoder_s = _time_fn(lambda: inner.encode_source(src))
|
| 100 |
+
#
|
| 101 |
+
# # Pre-compute memory for decoder timing
|
| 102 |
+
# memory, src_pad_mask = inner.encode_source(src)
|
| 103 |
+
#
|
| 104 |
+
# # Time one decoder step (cached)
|
| 105 |
+
# decoder_s = _time_fn(
|
| 106 |
+
# lambda: inner.forward_cached(memory, src_pad_mask, x0_est, t,
|
| 107 |
+
# inference_mode=True)
|
| 108 |
+
# )
|
| 109 |
+
#
|
| 110 |
+
# # Time one full step (non-cached = encoder + decoder)
|
| 111 |
+
# full_s = _time_fn(
|
| 112 |
+
# lambda: inner.forward(src, x0_est, t, inference_mode=True)
|
| 113 |
+
# )
|
| 114 |
+
#
|
| 115 |
+
# encoder_pct = 100.0 * encoder_s / max(full_s, 1e-9)
|
| 116 |
+
#
|
| 117 |
+
# return {
|
| 118 |
+
# "encoder_s": encoder_s,
|
| 119 |
+
# "decoder_s": decoder_s,
|
| 120 |
+
# "full_step_s": full_s,
|
| 121 |
+
# "encoder_pct": encoder_pct,
|
| 122 |
+
# }
|
| 123 |
+
#
|
| 124 |
+
#
|
| 125 |
+
# def run_benchmark(
|
| 126 |
+
# model,
|
| 127 |
+
# src_tokenizer,
|
| 128 |
+
# device: torch.device,
|
| 129 |
+
# src_lens: List[int] = [16, 32, 64],
|
| 130 |
+
# n_runs: int = 5,
|
| 131 |
+
# ) -> Dict:
|
| 132 |
+
# """
|
| 133 |
+
# Full benchmark: compare generate() vs generate_cached() at multiple src lengths.
|
| 134 |
+
#
|
| 135 |
+
# Args:
|
| 136 |
+
# model : SanskritModel (D3PMCrossAttention)
|
| 137 |
+
# src_tokenizer : SanskritSourceTokenizer
|
| 138 |
+
# device : torch.device
|
| 139 |
+
# src_lens : list of source lengths to benchmark
|
| 140 |
+
# n_runs : number of timing runs per condition
|
| 141 |
+
#
|
| 142 |
+
# Returns:
|
| 143 |
+
# results dict with timing and speedup for each src_len
|
| 144 |
+
# """
|
| 145 |
+
# inner = model.model
|
| 146 |
+
# if not hasattr(inner, 'generate_cached'):
|
| 147 |
+
# raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")
|
| 148 |
+
#
|
| 149 |
+
# src_vocab = inner.src_embed.token_emb.weight.shape[0]
|
| 150 |
+
# results = {}
|
| 151 |
+
#
|
| 152 |
+
# print("\n" + "=" * 65)
|
| 153 |
+
# print(" KV CACHE BENCHMARK")
|
| 154 |
+
# print("=" * 65)
|
| 155 |
+
# print(f" {'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
|
| 156 |
+
# f"{'speedup':>8} {'encoder%':>9}")
|
| 157 |
+
# print("-" * 65)
|
| 158 |
+
#
|
| 159 |
+
# for src_len in src_lens:
|
| 160 |
+
# src = _make_src(src_len, src_vocab, device)
|
| 161 |
+
#
|
| 162 |
+
# # Encoder cost breakdown
|
| 163 |
+
# enc_cost = benchmark_encoder_cost(model, src)
|
| 164 |
+
#
|
| 165 |
+
# # Time standard generate() — encoder runs T times
|
| 166 |
+
# def run_standard():
|
| 167 |
+
# return inner.generate(src, temperature=0.8, top_k=40)
|
| 168 |
+
#
|
| 169 |
+
# # Time generate_cached() — encoder runs once
|
| 170 |
+
# def run_cached():
|
| 171 |
+
# return inner.generate_cached(src, temperature=0.8, top_k=40)
|
| 172 |
+
#
|
| 173 |
+
# t_standard = _time_fn(run_standard, n_warmup=1, n_runs=n_runs)
|
| 174 |
+
# t_cached = _time_fn(run_cached, n_warmup=1, n_runs=n_runs)
|
| 175 |
+
# speedup = t_standard / max(t_cached, 1e-9)
|
| 176 |
+
#
|
| 177 |
+
# results[src_len] = {
|
| 178 |
+
# "standard_s": t_standard,
|
| 179 |
+
# "cached_s": t_cached,
|
| 180 |
+
# "speedup": speedup,
|
| 181 |
+
# "encoder_pct": enc_cost["encoder_pct"],
|
| 182 |
+
# }
|
| 183 |
+
#
|
| 184 |
+
# print(f" {src_len:>8} {t_standard:>12.3f} {t_cached:>10.3f} "
|
| 185 |
+
# f"{speedup:>7.2f}x {enc_cost['encoder_pct']:>8.1f}%")
|
| 186 |
+
#
|
| 187 |
+
# print("=" * 65)
|
| 188 |
+
# print(f"\n Encoder cost = % of one full forward pass")
|
| 189 |
+
# print(f" Speedup = standard_time / cached_time")
|
| 190 |
+
# print(f" Expected: speedup ≈ 1 / (1 - encoder_pct/100 * (T-1)/T)")
|
| 191 |
+
#
|
| 192 |
+
# return results
|
| 193 |
+
#
|
| 194 |
+
#
|
| 195 |
+
# def print_summary(results: Dict):
|
| 196 |
+
# """Print a human-readable summary of benchmark results."""
|
| 197 |
+
# print("\n SUMMARY")
|
| 198 |
+
# print(" -------")
|
| 199 |
+
# for src_len, r in results.items():
|
| 200 |
+
# saved_pct = (1.0 - 1.0 / r["speedup"]) * 100
|
| 201 |
+
# print(f" src_len={src_len}: {r['speedup']:.2f}x speedup "
|
| 202 |
+
# f"({saved_pct:.1f}% time saved, "
|
| 203 |
+
# f"encoder was {r['encoder_pct']:.1f}% of total)")
|
| 204 |
+
#
|
| 205 |
+
#
|
| 206 |
+
# if __name__ == "__main__":
|
| 207 |
+
# import sys, os
|
| 208 |
+
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 209 |
+
# from config import CONFIG
|
| 210 |
+
# from inference import load_model
|
| 211 |
+
# from models.tokenizer import SanskritSourceTokenizer
|
| 212 |
+
#
|
| 213 |
+
# cfg = CONFIG
|
| 214 |
+
# device = torch.device(cfg['training']['device'])
|
| 215 |
+
#
|
| 216 |
+
# model_name = cfg['model_type']
|
| 217 |
+
# has_neg = cfg['data']['include_negative_examples']
|
| 218 |
+
# ckpt = f"results7/{model_name}_neg_{has_neg}/best_model.pt"
|
| 219 |
+
#
|
| 220 |
+
# if not os.path.exists(ckpt):
|
| 221 |
+
# print(f"No checkpoint at {ckpt}. Train first.")
|
| 222 |
+
# sys.exit(1)
|
| 223 |
+
#
|
| 224 |
+
# model, cfg = load_model(ckpt, cfg, device)
|
| 225 |
+
# model.eval()
|
| 226 |
+
#
|
| 227 |
+
# src_tokenizer = SanskritSourceTokenizer(
|
| 228 |
+
# vocab_size = cfg['model'].get('src_vocab_size', 500),
|
| 229 |
+
# max_len = cfg['model']['max_seq_len'],
|
| 230 |
+
# )
|
| 231 |
+
#
|
| 232 |
+
# results = run_benchmark(model, src_tokenizer, device)
|
| 233 |
+
# print_summary(results)
|
| 234 |
+
# ============================================================
|
| 235 |
+
# FULL TASK 1: KV CACHE + PROJECTION + BENCHMARK + GRAPHS
|
| 236 |
+
# ============================================================
|
| 237 |
+
|
| 238 |
+
import torch
|
| 239 |
+
import torch.nn as nn
|
| 240 |
+
import torch.nn.functional as F
|
| 241 |
+
import time
|
| 242 |
+
import numpy as np
|
| 243 |
+
import matplotlib.pyplot as plt
|
| 244 |
+
|
| 245 |
+
# ============================================================
|
| 246 |
+
# 🔧 MODEL (PATCHED WITH PROJECTION + KV CACHE)
|
| 247 |
+
# ============================================================
|
| 248 |
+
|
| 249 |
+
class D3PMCrossAttention(nn.Module):
|
| 250 |
+
def __init__(self, d_model=512, vocab_size=500, max_seq_len=64, T=128):
|
| 251 |
+
super().__init__()
|
| 252 |
+
|
| 253 |
+
self.d_model = d_model
|
| 254 |
+
self.max_seq_len = max_seq_len
|
| 255 |
+
self.mask_token_id = 0
|
| 256 |
+
|
| 257 |
+
# Dummy encoder/decoder (replace with yours)
|
| 258 |
+
self.encoder = nn.Embedding(vocab_size, d_model)
|
| 259 |
+
self.tgt_embed = nn.Embedding(vocab_size, d_model)
|
| 260 |
+
self.head = nn.Linear(d_model, vocab_size)
|
| 261 |
+
|
| 262 |
+
self.time_mlp = nn.Linear(1, d_model)
|
| 263 |
+
self.hint_gate = nn.Linear(d_model, d_model)
|
| 264 |
+
|
| 265 |
+
# Fake scheduler
|
| 266 |
+
class Scheduler:
|
| 267 |
+
def __init__(self, T):
|
| 268 |
+
self.num_timesteps = T
|
| 269 |
+
self.scheduler = Scheduler(T)
|
| 270 |
+
|
| 271 |
+
# 🔥 Projection layer (Task 1 requirement)
|
| 272 |
+
self.semantic_proj = nn.Linear(d_model, d_model // 2)
|
| 273 |
+
self.semantic_up = nn.Linear(d_model // 2, d_model)
|
| 274 |
+
|
| 275 |
+
# ========================================================
|
| 276 |
+
# ✅ ENCODER WITH PROJECTION
|
| 277 |
+
# ========================================================
|
| 278 |
+
def encode_source(self, src):
|
| 279 |
+
memory = self.encoder(src) # [B, L, d]
|
| 280 |
+
|
| 281 |
+
# 🔥 Compress → Expand
|
| 282 |
+
compressed = self.semantic_proj(memory)
|
| 283 |
+
memory = self.semantic_up(compressed)
|
| 284 |
+
|
| 285 |
+
src_pad_mask = None
|
| 286 |
+
return memory, src_pad_mask
|
| 287 |
+
|
| 288 |
+
# ========================================================
|
| 289 |
+
# ✅ STANDARD (NO CACHE)
|
| 290 |
+
# ========================================================
|
| 291 |
+
def forward(self, src, x, t):
|
| 292 |
+
memory, mask = self.encode_source(src)
|
| 293 |
+
return self.forward_cached(memory, mask, x, t)
|
| 294 |
+
|
| 295 |
+
# ========================================================
|
| 296 |
+
# ✅ CACHED FORWARD
|
| 297 |
+
# ========================================================
|
| 298 |
+
def forward_cached(self, memory, src_pad_mask, x, t, hint=None):
|
| 299 |
+
x = self.tgt_embed(x)
|
| 300 |
+
|
| 301 |
+
t_emb = self.time_mlp((t.float()/self.scheduler.num_timesteps).unsqueeze(-1))
|
| 302 |
+
x = x + t_emb.unsqueeze(1)
|
| 303 |
+
|
| 304 |
+
if hint is not None:
|
| 305 |
+
x = x + self.hint_gate(x) * self.tgt_embed(hint)
|
| 306 |
+
|
| 307 |
+
logits = self.head(x)
|
| 308 |
+
|
| 309 |
+
self._last_hidden = x
|
| 310 |
+
return logits, None
|
| 311 |
+
|
| 312 |
+
# ========================================================
|
| 313 |
+
# ❌ OLD GENERATE (SLOW)
|
| 314 |
+
# ========================================================
|
| 315 |
+
@torch.no_grad()
|
| 316 |
+
def generate(self, src):
|
| 317 |
+
B = src.shape[0]
|
| 318 |
+
device = src.device
|
| 319 |
+
T = self.scheduler.num_timesteps
|
| 320 |
+
|
| 321 |
+
x = torch.zeros((B, self.max_seq_len), dtype=torch.long, device=device)
|
| 322 |
+
|
| 323 |
+
for t_val in range(T - 1, -1, -1):
|
| 324 |
+
t = torch.full((B,), t_val, device=device)
|
| 325 |
+
|
| 326 |
+
logits, _ = self.forward(src, x, t)
|
| 327 |
+
probs = F.softmax(logits, dim=-1)
|
| 328 |
+
|
| 329 |
+
x = torch.argmax(probs, dim=-1)
|
| 330 |
+
|
| 331 |
+
return x
|
| 332 |
+
|
| 333 |
+
# ========================================================
|
| 334 |
+
# ✅ FAST GENERATE (KV CACHE)
|
| 335 |
+
# ========================================================
|
| 336 |
+
@torch.no_grad()
|
| 337 |
+
def generate_cached(self, src):
|
| 338 |
+
B = src.shape[0]
|
| 339 |
+
device = src.device
|
| 340 |
+
T = self.scheduler.num_timesteps
|
| 341 |
+
|
| 342 |
+
# 🔥 Encode once
|
| 343 |
+
memory, mask = self.encode_source(src)
|
| 344 |
+
|
| 345 |
+
x = torch.zeros((B, self.max_seq_len), dtype=torch.long, device=device)
|
| 346 |
+
hint = None
|
| 347 |
+
|
| 348 |
+
for t_val in range(T - 1, -1, -1):
|
| 349 |
+
t = torch.full((B,), t_val, device=device)
|
| 350 |
+
|
| 351 |
+
logits, _ = self.forward_cached(memory, mask, x, t, hint)
|
| 352 |
+
probs = F.softmax(logits, dim=-1)
|
| 353 |
+
|
| 354 |
+
x = torch.argmax(probs, dim=-1)
|
| 355 |
+
hint = x
|
| 356 |
+
|
| 357 |
+
return x
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
# ============================================================
|
| 361 |
+
# 📊 BENCHMARK + MEMORY + GRAPHS
|
| 362 |
+
# ============================================================
|
| 363 |
+
|
| 364 |
+
def benchmark(model, device):
|
| 365 |
+
model.to(device)
|
| 366 |
+
model.eval()
|
| 367 |
+
|
| 368 |
+
vocab = 500
|
| 369 |
+
src_lens = [16, 32, 64]
|
| 370 |
+
|
| 371 |
+
standard_times = []
|
| 372 |
+
cached_times = []
|
| 373 |
+
speedups = []
|
| 374 |
+
memory_savings = []
|
| 375 |
+
|
| 376 |
+
for src_len in src_lens:
|
| 377 |
+
print(f"\n🔹 src_len = {src_len}")
|
| 378 |
+
|
| 379 |
+
src = torch.randint(5, vocab, (1, src_len)).to(device)
|
| 380 |
+
|
| 381 |
+
# -------- STANDARD --------
|
| 382 |
+
torch.cuda.reset_peak_memory_stats()
|
| 383 |
+
start = time.time()
|
| 384 |
+
model.generate(src)
|
| 385 |
+
torch.cuda.synchronize()
|
| 386 |
+
t_std = time.time() - start
|
| 387 |
+
mem_std = torch.cuda.max_memory_allocated() / 1024**2
|
| 388 |
+
|
| 389 |
+
# -------- CACHED --------
|
| 390 |
+
torch.cuda.reset_peak_memory_stats()
|
| 391 |
+
start = time.time()
|
| 392 |
+
model.generate_cached(src)
|
| 393 |
+
torch.cuda.synchronize()
|
| 394 |
+
t_cache = time.time() - start
|
| 395 |
+
mem_cache = torch.cuda.max_memory_allocated() / 1024**2
|
| 396 |
+
|
| 397 |
+
speedup = t_std / t_cache
|
| 398 |
+
mem_red = 100 * (mem_std - mem_cache) / mem_std
|
| 399 |
+
|
| 400 |
+
print(f"Time: {t_std:.2f}s → {t_cache:.2f}s | {speedup:.2f}x")
|
| 401 |
+
print(f"Memory: {mem_std:.0f}MB → {mem_cache:.0f}MB | {mem_red:.1f}%")
|
| 402 |
+
|
| 403 |
+
standard_times.append(t_std)
|
| 404 |
+
cached_times.append(t_cache)
|
| 405 |
+
speedups.append(speedup)
|
| 406 |
+
memory_savings.append(mem_red)
|
| 407 |
+
|
| 408 |
+
# ==========================
|
| 409 |
+
# 📈 PLOT: TIME
|
| 410 |
+
# ==========================
|
| 411 |
+
plt.figure()
|
| 412 |
+
plt.plot(src_lens, standard_times, marker='o', label="Standard")
|
| 413 |
+
plt.plot(src_lens, cached_times, marker='o', label="Cached")
|
| 414 |
+
plt.xlabel("Source Length")
|
| 415 |
+
plt.ylabel("Time (s)")
|
| 416 |
+
plt.title("Generation Time")
|
| 417 |
+
plt.legend()
|
| 418 |
+
plt.grid()
|
| 419 |
+
plt.show()
|
| 420 |
+
|
| 421 |
+
# ==========================
|
| 422 |
+
# 📈 PLOT: SPEEDUP
|
| 423 |
+
# ==========================
|
| 424 |
+
plt.figure()
|
| 425 |
+
plt.plot(src_lens, speedups, marker='o')
|
| 426 |
+
plt.xlabel("Source Length")
|
| 427 |
+
plt.ylabel("Speedup (x)")
|
| 428 |
+
plt.title("KV Cache Speedup")
|
| 429 |
+
plt.grid()
|
| 430 |
+
plt.show()
|
| 431 |
+
|
| 432 |
+
# ==========================
|
| 433 |
+
# 📈 PLOT: MEMORY
|
| 434 |
+
# ==========================
|
| 435 |
+
plt.figure()
|
| 436 |
+
plt.plot(src_lens, memory_savings, marker='o')
|
| 437 |
+
plt.xlabel("Source Length")
|
| 438 |
+
plt.ylabel("Memory Reduction (%)")
|
| 439 |
+
plt.title("Memory Savings")
|
| 440 |
+
plt.grid()
|
| 441 |
+
plt.show()
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
# ============================================================
|
| 445 |
+
# 🚀 RUN
|
| 446 |
+
# ============================================================
|
| 447 |
+
|
| 448 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 449 |
+
|
| 450 |
+
model = D3PMCrossAttention()
|
| 451 |
+
benchmark(model, device)
|
analysis/outputs/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 — KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
src_len standard(s) cached(s) speedup encoder% mem-save%
|
| 5 |
+
16 3.431 3.512 0.98x 133.1% 50.0%
|
| 6 |
+
source-mem before=0.070MB after=0.035MB
|
| 7 |
+
32 3.626 3.555 1.02x 36.8% 50.0%
|
| 8 |
+
source-mem before=0.141MB after=0.070MB
|
| 9 |
+
64 3.585 3.701 0.97x 53.3% 50.0%
|
| 10 |
+
source-mem before=0.281MB after=0.141MB
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
Encoder cost = % of one full forward pass
|
| 16 |
+
Speedup = standard_time / cached_time
|
| 17 |
+
Expected: speedup ≈ 1 / (1 - encoder_pct/100 * (T-1)/T)
|
| 18 |
+
|
| 19 |
+
SUMMARY
|
| 20 |
+
-------
|
| 21 |
+
src_len=16: 0.98x speedup (-2.4% time saved, encoder was 133.1% of total, estimated memory change 50.0%)
|
| 22 |
+
src_len=32: 1.02x speedup (1.9% time saved, encoder was 36.8% of total, estimated memory change 50.0%)
|
| 23 |
+
src_len=64: 0.97x speedup (-3.2% time saved, encoder was 53.3% of total, estimated memory change 50.0%)
|
analysis/outputs/task2_all_layers_t0.png
ADDED
|
analysis/outputs/task2_attn_evolution.png
ADDED
|
analysis/outputs/task2_attn_t0.png
ADDED
|
analysis/outputs/task2_attn_t127.png
ADDED
|
analysis/outputs/task2_examples/example_1_attn_t0.png
ADDED
|
analysis/outputs/task2_examples/example_2_attn_t0.png
ADDED
|
analysis/outputs/task2_examples/example_3_attn_t0.png
ADDED
|
analysis/outputs/task2_examples/example_4_attn_t0.png
ADDED
|
analysis/outputs/task2_examples/example_5_attn_t0.png
ADDED
|
analysis/outputs/task2_report.txt
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 2 — ATTENTION + DRIFT REPORT
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
Input : dharmo rakṣati rakṣitaḥ
|
| 5 |
+
Output : कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा ब्र कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा ध्या ध्या ध्या कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा
|
| 6 |
+
|
| 7 |
+
Lock-in t : 122
|
| 8 |
+
Mean pos lock-in : 118.7 ± 17.7
|
| 9 |
+
|
| 10 |
+
Source alignment metric : bertscore_f1
|
| 11 |
+
Best source-alignment step : t=127
|
| 12 |
+
Locked positions : 12
|
| 13 |
+
Flexible positions : 8
|
| 14 |
+
TF-IDF vs attention stability correlation : 0.0
|
| 15 |
+
|
| 16 |
+
Step → Output → CER-to-final
|
| 17 |
+
------------------------------------------------------------
|
| 18 |
+
t= 127 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.2293
|
| 19 |
+
t= 122 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0769
|
| 20 |
+
t= 117 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0698
|
| 21 |
+
t= 112 | कुङ्कुमा लये कुङ्कुमा कुङ्कुमा कुङ्कुमा | 0.0541
|
| 22 |
+
t= 107 | कुङ्कुमा ध्या कुङ्कुमा कुङ्कुमा कुङ्कुमा | 0.0670
|
| 23 |
+
t= 102 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0442
|
| 24 |
+
t= 97 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0342
|
| 25 |
+
t= 92 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0456
|
| 26 |
+
t= 87 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0299
|
| 27 |
+
t= 82 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0214
|
| 28 |
+
t= 77 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0214
|
| 29 |
+
t= 72 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0214
|
| 30 |
+
t= 67 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0214
|
| 31 |
+
t= 62 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0128
|
| 32 |
+
t= 57 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0128
|
| 33 |
+
t= 52 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0128
|
| 34 |
+
t= 47 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0043
|
| 35 |
+
t= 42 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0043
|
| 36 |
+
t= 37 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
|
| 37 |
+
t= 32 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
|
| 38 |
+
t= 27 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
|
| 39 |
+
t= 22 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
|
| 40 |
+
t= 17 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
|
| 41 |
+
t= 12 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
|
| 42 |
+
t= 7 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
|
| 43 |
+
t= 2 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
|
| 44 |
+
t= 0 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
|
| 45 |
+
|
| 46 |
+
Step → Source alignment
|
| 47 |
+
------------------------------------------------------------
|
| 48 |
+
t= 127 | 0.4312
|
| 49 |
+
t= 122 | 0.3941
|
| 50 |
+
t= 117 | 0.3963
|
| 51 |
+
t= 112 | 0.3871
|
| 52 |
+
t= 107 | 0.3947
|
| 53 |
+
t= 102 | 0.3950
|
| 54 |
+
t= 97 | 0.3894
|
| 55 |
+
t= 92 | 0.3887
|
| 56 |
+
t= 87 | 0.3897
|
| 57 |
+
t= 82 | 0.3881
|
| 58 |
+
t= 77 | 0.3881
|
| 59 |
+
t= 72 | 0.3881
|
| 60 |
+
t= 67 | 0.3881
|
| 61 |
+
t= 62 | 0.3889
|
| 62 |
+
t= 57 | 0.3889
|
| 63 |
+
t= 52 | 0.3889
|
| 64 |
+
t= 47 | 0.3882
|
| 65 |
+
t= 42 | 0.3882
|
| 66 |
+
t= 37 | 0.3901
|
| 67 |
+
t= 32 | 0.3901
|
| 68 |
+
t= 27 | 0.3901
|
| 69 |
+
t= 22 | 0.3901
|
| 70 |
+
t= 17 | 0.3901
|
| 71 |
+
t= 12 | 0.3901
|
| 72 |
+
t= 7 | 0.3901
|
| 73 |
+
t= 2 | 0.3901
|
| 74 |
+
t= 0 | 0.3901
|
| 75 |
+
|
| 76 |
+
Locked target positions
|
| 77 |
+
------------------------------------------------------------
|
| 78 |
+
tgt[0]=कुङ्कुमा → src[3]=taḥ stability=0.781
|
| 79 |
+
tgt[1]=शिरः → src[3]=taḥ stability=0.781
|
| 80 |
+
tgt[2]=कुङ्कुमा → src[3]=taḥ stability=0.780
|
| 81 |
+
tgt[3]=कुङ्कुमा → src[2]=rakṣi stability=0.780
|
| 82 |
+
tgt[4]=पुरतो → src[2]=rakṣi stability=0.781
|
| 83 |
+
tgt[5]=कुङ्कुमा → src[2]=rakṣi stability=0.781
|
| 84 |
+
tgt[8]=मु → src[3]=taḥ stability=0.782
|
| 85 |
+
tgt[9]=कुङ्कुमा → src[3]=taḥ stability=0.783
|
| 86 |
+
tgt[10]=कुङ्कुमा → src[3]=taḥ stability=0.783
|
| 87 |
+
tgt[11]=कुङ्कुमा → src[3]=taḥ stability=0.781
|
| 88 |
+
tgt[13]=कुङ्कुमा → src[2]=rakṣi stability=0.781
|
| 89 |
+
tgt[14]=कुङ्कुमा → src[2]=rakṣi stability=0.781
|
| 90 |
+
|
| 91 |
+
Flexible target positions
|
| 92 |
+
------------------------------------------------------------
|
| 93 |
+
tgt[6]=कुङ्कुमा → src[2]=rakṣi stability=0.731
|
| 94 |
+
tgt[7]=कुङ्कुमा → src[2]=rakṣi stability=0.481
|
| 95 |
+
tgt[12]=कुङ्कुमा → src[2]=rakṣi stability=0.431
|
| 96 |
+
tgt[15]=कुङ्कुमा → src[2]=rakṣi stability=0.480
|
| 97 |
+
tgt[16]=कुङ्कुमा → src[2]=rakṣi stability=0.479
|
| 98 |
+
tgt[17]=कुङ्कुमा → src[2]=rakṣi stability=0.428
|
| 99 |
+
tgt[18]=कुङ्कुमा → src[3]=taḥ stability=0.727
|
| 100 |
+
tgt[19]=कुङ्कुमा → src[0]=dharmo stability=0.377
|
analysis/outputs/task2_semantic_drift.png
ADDED
|
analysis/outputs/task2_source_alignment.png
ADDED
|
analysis/outputs/task3_concept_space.png
ADDED
|
Git LFS Details
|
analysis/outputs/task3_diversity_direction.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1dff757876fd9352d5c1f86d2af244c9784d9ec66639a0f31ec5f6c9ec608d4b
|
| 3 |
+
size 1664
|
analysis/outputs/task3_report.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 3 — CONCEPT VECTORS + PCA STEERING
|
| 2 |
+
==================================================
|
| 3 |
+
|
| 4 |
+
PCA: 50 components, 96.1% variance
|
| 5 |
+
Diversity PC: 1 (|r|=0.303 with output length)
|
| 6 |
+
|
| 7 |
+
Diversity spectrum:
|
| 8 |
+
alpha=-2.0 → विष द्धा समन्व ददर्श रे विष रे द्धा रे रे ष्व विष रे विष रे रे रे विष रे रे रे विष रे रे कार साग ददर्श वादि रे रे रे रे ददर्श रे रे रे विस्त रे रे समन्व सुर रे वस्तु रे रे रे रे रे रे रे सुर रे रे रे रे रे सुर रे ैक किंचि वस्तु विष रे कार रे विष कार गतिं रे कार शो कार कार कार साग समन्व रे कार कार कार
|
| 9 |
+
alpha=-1.0 → रे विष विष ष्व रे रे विष विष रे विष ददर्श रे ्य् रे रे रे विष रे रे शः रे भवि वस्तु रे विष ्य् विष रे रे वस्तु घा वादि रे रे ्य् रे रे ्य् रे रे रे ्य् पृत रे रे नृप रे द्धा रे रे रे रे ्य् रे रे त्तु रे ्य् रे विष रे सुर साग विष रे कार विष विष ्य् रे रे ्य् ्य् ्य् ्य् रे कार कार कार कार
|
| 10 |
+
alpha=+0.0 → विष ष्व भवि दित्य द्धा रे तौ वृ ्य् रे वादि ॠ रे विष रे ष्व रे का रे ्य् रे ्य् विष ्य् ष्व ्य् वृ जना रे भवि वस्तु त्रिषु विष घा भु की ्य् वृ रे भु यां वृ रे भु यां समु रे रे ्य् रे भु वृ ्य् क्ष ्य् ान्त ्य् ्य् ्य् व्रजेत् ्य् भु रे रे ्य् रे उक्त ्य् ्य् समन्व ्य् ्य् सु ल्प वीर ्य् ्य् ्य् विष ्य्
|
| 11 |
+
alpha=+1.0 → ॠ वृ वृ वृ वृ वृ ण् भवि ्त वृ वृ दश ्य् यां ॠ भु तं भु भु ान्त भवि भु भु रे यां वस्तु यां यां भु यां यां यां यां ्य् यां भु दृष्ट दृष्ट यां यां भु यां यां यां यां द्वि भु यां भु क्ष भु भु भु ष्ट रु ब्र भु न्तु ण्ड यां भु यां ्य् क्ष ्य् वृ ्य् , यां भु यां भु रोध भु ्य् यां ्य् ्य् यां यां
|
| 12 |
+
alpha=+2.0 → वृ वृ वृ ण् वृ वृ ब्र वृ ष्ट ष्ट ष्ट ्य् मा यां ष्ट यां ब्र यां तं तं भु भु वृ भु यां धनम् यां क्ष यां द्वि भु यां यां यां यां द्वि यां भु भु यां यां भु यां क्ष यां भु यां भु ्य् यां भु यां यां मा यां यां भु वृ यां धा भु यां यां मा भु हृ यां यां यां भु द्वि यां द्वि ब्र ण्ड मा द्वि यां यां भु
|
analysis/outputs/task5_quality_classifier.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0410b67872dbf030b2db5410ecca92f6357d90ae9f47f2c7cf1ad8202c274f61
|
| 3 |
+
size 233761
|
analysis/outputs/task5_quality_data.npz
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dad6d37cae2b157877a4106d92528417981f75ae57cddfd46112441cd7e9a338
|
| 3 |
+
size 770512
|
analysis/outputs_multi/results__d3pm_cross_attention_neg_False/task1/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 — KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
src_len standard(s) cached(s) speedup encoder% mem-save%
|
| 5 |
+
16 3.309 3.624 0.91x 52.9% 50.0%
|
| 6 |
+
source-mem before=0.070MB after=0.035MB
|
| 7 |
+
32 4.214 4.234 1.00x 40.0% 50.0%
|
| 8 |
+
source-mem before=0.141MB after=0.070MB
|
| 9 |
+
64 6.929 8.372 0.83x 58.7% 50.0%
|
| 10 |
+
source-mem before=0.281MB after=0.141MB
|
analysis/outputs_multi/results__d3pm_cross_attention_neg_True/task1/task1_kv_cache.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TASK 1 — KV CACHE BENCHMARK
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
src_len standard(s) cached(s) speedup encoder% mem-save%
|
| 5 |
+
16 2.548 2.464 1.03x 31.6% 50.0%
|
| 6 |
+
source-mem before=0.070MB after=0.035MB
|
| 7 |
+
32 3.222 2.952 1.09x 37.8% 50.0%
|
| 8 |
+
source-mem before=0.141MB after=0.070MB
|
| 9 |
+
64 4.121 4.335 0.95x 33.6% 50.0%
|
| 10 |
+
source-mem before=0.281MB after=0.141MB
|
analysis/quality_classifier.py
ADDED
|
@@ -0,0 +1,723 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# """
|
| 2 |
+
# analysis/quality_classifier.py
|
| 3 |
+
# ================================
|
| 4 |
+
# Task 5: Classifier-Free Guidance for Paraphrase Quality Control
|
| 5 |
+
#
|
| 6 |
+
# Two steps — only Step 2 requires training a SMALL model (not the main D3PM):
|
| 7 |
+
#
|
| 8 |
+
# STEP 1 — Collect training data (no training):
|
| 9 |
+
# Run existing model on val set, record (hidden_state, CER) pairs.
|
| 10 |
+
# Hidden states come from model.model._last_hidden after forward_cached().
|
| 11 |
+
# CER score = quality label (lower CER = higher quality).
|
| 12 |
+
#
|
| 13 |
+
# STEP 2 — Train quality classifier:
|
| 14 |
+
# Small 2-layer MLP: d_model → 64 → 1
|
| 15 |
+
# Input: pooled decoder hidden state [B, d_model]
|
| 16 |
+
# Output: predicted quality score in [0, 1] (1 = high quality)
|
| 17 |
+
# Loss: MSE against normalized CER labels
|
| 18 |
+
# Training time: ~5-10 minutes on CPU for 10k examples
|
| 19 |
+
#
|
| 20 |
+
# STEP 3 — Guided inference (no retraining):
|
| 21 |
+
# At each diffusion step, use classifier gradient to shift logits:
|
| 22 |
+
# guided_logits = logits + λ * ∂(quality_score)/∂(logits)
|
| 23 |
+
# Higher λ → model biased toward high-quality outputs
|
| 24 |
+
# λ=0 → standard generation (no guidance)
|
| 25 |
+
#
|
| 26 |
+
# Key: main D3PM model is FROZEN throughout. Only the 10k-param classifier trains.
|
| 27 |
+
# """
|
| 28 |
+
#
|
| 29 |
+
# import torch
|
| 30 |
+
# import torch.nn as nn
|
| 31 |
+
# import torch.nn.functional as F
|
| 32 |
+
# import numpy as np
|
| 33 |
+
# import os
|
| 34 |
+
# import json
|
| 35 |
+
# from typing import List, Dict, Optional, Tuple
|
| 36 |
+
#
|
| 37 |
+
#
|
| 38 |
+
# # ── Quality classifier architecture ──────────────────────────────────
|
| 39 |
+
#
|
| 40 |
+
# class QualityClassifier(nn.Module):
|
| 41 |
+
# """
|
| 42 |
+
# Lightweight MLP that predicts transliteration quality from decoder
|
| 43 |
+
# hidden states.
|
| 44 |
+
#
|
| 45 |
+
# Architecture:
|
| 46 |
+
# d_model → 128 → 64 → 1 → Sigmoid
|
| 47 |
+
#
|
| 48 |
+
# Input: mean-pooled decoder hidden state [B, d_model]
|
| 49 |
+
# Output: quality score [B, 1] ∈ [0, 1] (1 = high quality)
|
| 50 |
+
#
|
| 51 |
+
# ~10k parameters. Trains in minutes on CPU.
|
| 52 |
+
# """
|
| 53 |
+
# def __init__(self, d_model: int):
|
| 54 |
+
# super().__init__()
|
| 55 |
+
# self.net = nn.Sequential(
|
| 56 |
+
# nn.Linear(d_model, 128),
|
| 57 |
+
# nn.ReLU(),
|
| 58 |
+
# nn.Dropout(0.1),
|
| 59 |
+
# nn.Linear(128, 64),
|
| 60 |
+
# nn.ReLU(),
|
| 61 |
+
# nn.Linear(64, 1),
|
| 62 |
+
# nn.Sigmoid(),
|
| 63 |
+
# )
|
| 64 |
+
# self.d_model = d_model
|
| 65 |
+
#
|
| 66 |
+
# def forward(self, hidden: torch.Tensor) -> torch.Tensor:
|
| 67 |
+
# """
|
| 68 |
+
# Args:
|
| 69 |
+
# hidden : [B, tgt_len, d_model] OR [B, d_model] (already pooled)
|
| 70 |
+
#
|
| 71 |
+
# Returns:
|
| 72 |
+
# score : [B, 1] quality score in [0, 1]
|
| 73 |
+
# """
|
| 74 |
+
# if hidden.dim() == 3:
|
| 75 |
+
# # Pool over sequence length
|
| 76 |
+
# hidden = hidden.mean(dim=1) # [B, d_model]
|
| 77 |
+
# return self.net(hidden) # [B, 1]
|
| 78 |
+
#
|
| 79 |
+
#
|
| 80 |
+
# # ── Training data collection ──────────────────────────────────────────
|
| 81 |
+
#
|
| 82 |
+
# @torch.no_grad()
|
| 83 |
+
# def collect_quality_data(
|
| 84 |
+
# model,
|
| 85 |
+
# src_list: List[torch.Tensor],
|
| 86 |
+
# ref_list: List[str],
|
| 87 |
+
# tgt_tokenizer,
|
| 88 |
+
# t_capture: int = 0,
|
| 89 |
+
# temperature: float = 0.8,
|
| 90 |
+
# top_k: int = 40,
|
| 91 |
+
# max_samples: int = 5000,
|
| 92 |
+
# ) -> Tuple[np.ndarray, np.ndarray]:
|
| 93 |
+
# """
|
| 94 |
+
# Collect (hidden_state, quality_score) pairs for classifier training.
|
| 95 |
+
#
|
| 96 |
+
# For each sample:
|
| 97 |
+
# 1. Run generate_cached() on src
|
| 98 |
+
# 2. Capture decoder hidden state at t=t_capture
|
| 99 |
+
# 3. Compute CER between output and reference
|
| 100 |
+
# 4. Quality = 1 - CER (normalize to [0,1])
|
| 101 |
+
#
|
| 102 |
+
# Args:
|
| 103 |
+
# model : SanskritModel
|
| 104 |
+
# src_list : list of [1, src_len] tensors
|
| 105 |
+
# ref_list : list of reference Devanagari strings
|
| 106 |
+
# tgt_tokenizer : SanskritTargetTokenizer
|
| 107 |
+
# t_capture : which step to capture hidden states (0 = final)
|
| 108 |
+
# max_samples : cap number of training examples
|
| 109 |
+
#
|
| 110 |
+
# Returns:
|
| 111 |
+
# hidden_matrix : np.ndarray [N, d_model]
|
| 112 |
+
# quality_scores: np.ndarray [N] values in [0, 1]
|
| 113 |
+
# """
|
| 114 |
+
# inner = model.model
|
| 115 |
+
# T = inner.scheduler.num_timesteps
|
| 116 |
+
# device = next(inner.parameters()).device
|
| 117 |
+
#
|
| 118 |
+
# hidden_list = []
|
| 119 |
+
# quality_list = []
|
| 120 |
+
# n = min(len(src_list), max_samples)
|
| 121 |
+
#
|
| 122 |
+
# def cer(pred, ref):
|
| 123 |
+
# if not ref:
|
| 124 |
+
# return 1.0
|
| 125 |
+
# def ed(s1, s2):
|
| 126 |
+
# m, n = len(s1), len(s2)
|
| 127 |
+
# dp = list(range(n + 1))
|
| 128 |
+
# for i in range(1, m + 1):
|
| 129 |
+
# prev, dp[0] = dp[0], i
|
| 130 |
+
# for j in range(1, n + 1):
|
| 131 |
+
# temp = dp[j]
|
| 132 |
+
# dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
|
| 133 |
+
# prev = temp
|
| 134 |
+
# return dp[n]
|
| 135 |
+
# return ed(pred, ref) / max(len(ref), 1)
|
| 136 |
+
#
|
| 137 |
+
# print(f"Collecting quality data from {n} examples...")
|
| 138 |
+
# for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
|
| 139 |
+
# if i % 200 == 0:
|
| 140 |
+
# print(f" {i}/{n}")
|
| 141 |
+
#
|
| 142 |
+
# if src.dim() == 1:
|
| 143 |
+
# src = src.unsqueeze(0)
|
| 144 |
+
# src = src.to(device)
|
| 145 |
+
#
|
| 146 |
+
# B = src.shape[0]
|
| 147 |
+
# tgt_len = inner.max_seq_len
|
| 148 |
+
# mask_id = inner.mask_token_id
|
| 149 |
+
#
|
| 150 |
+
# memory, src_pad_mask = inner.encode_source(src)
|
| 151 |
+
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
|
| 152 |
+
# hint = None
|
| 153 |
+
# h_cap = None
|
| 154 |
+
#
|
| 155 |
+
# for t_val in range(T - 1, -1, -1):
|
| 156 |
+
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 157 |
+
# is_last = (t_val == 0)
|
| 158 |
+
#
|
| 159 |
+
# logits, _ = inner.forward_cached(
|
| 160 |
+
# memory, src_pad_mask, x0_est, t,
|
| 161 |
+
# x0_hint=hint, inference_mode=True,
|
| 162 |
+
# )
|
| 163 |
+
#
|
| 164 |
+
# if t_val == t_capture and hasattr(inner, '_last_hidden'):
|
| 165 |
+
# h_cap = inner._last_hidden[0].mean(dim=0).detach().cpu() # [d_model]
|
| 166 |
+
#
|
| 167 |
+
# logits = logits / max(temperature, 1e-8)
|
| 168 |
+
# if top_k > 0:
|
| 169 |
+
# V = logits.shape[-1]
|
| 170 |
+
# if top_k < V:
|
| 171 |
+
# vals, _ = torch.topk(logits, top_k, dim=-1)
|
| 172 |
+
# logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
|
| 173 |
+
#
|
| 174 |
+
# probs = F.softmax(logits, dim=-1)
|
| 175 |
+
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
|
| 176 |
+
# hint = x0_est
|
| 177 |
+
#
|
| 178 |
+
# if h_cap is None:
|
| 179 |
+
# continue
|
| 180 |
+
#
|
| 181 |
+
# ids = [x for x in x0_est[0].tolist() if x > 4]
|
| 182 |
+
# pred = tgt_tokenizer.decode(ids).strip()
|
| 183 |
+
# q = max(0.0, 1.0 - cer(pred, ref)) # quality = 1 - CER
|
| 184 |
+
#
|
| 185 |
+
# hidden_list.append(h_cap.numpy())
|
| 186 |
+
# quality_list.append(q)
|
| 187 |
+
#
|
| 188 |
+
# print(f"Collected {len(hidden_list)} quality examples.")
|
| 189 |
+
# print(f"Quality stats: mean={np.mean(quality_list):.3f} "
|
| 190 |
+
# f"min={np.min(quality_list):.3f} max={np.max(quality_list):.3f}")
|
| 191 |
+
#
|
| 192 |
+
# return np.stack(hidden_list), np.array(quality_list, dtype=np.float32)
|
| 193 |
+
#
|
| 194 |
+
#
|
| 195 |
+
# def _sample(probs):
|
| 196 |
+
# B, L, V = probs.shape
|
| 197 |
+
# flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 198 |
+
# flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 199 |
+
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
| 200 |
+
#
|
| 201 |
+
#
|
| 202 |
+
# # ── Training ──────────────────────────────────────────────────────────
|
| 203 |
+
#
|
| 204 |
+
# def train_quality_classifier(
|
| 205 |
+
# hidden_matrix: np.ndarray,
|
| 206 |
+
# quality_scores: np.ndarray,
|
| 207 |
+
# d_model: int,
|
| 208 |
+
# epochs: int = 30,
|
| 209 |
+
# batch_size: int = 64,
|
| 210 |
+
# lr: float = 1e-3,
|
| 211 |
+
# val_frac: float = 0.1,
|
| 212 |
+
# save_path: Optional[str] = None,
|
| 213 |
+
# ) -> QualityClassifier:
|
| 214 |
+
# """
|
| 215 |
+
# Train QualityClassifier on collected (hidden, quality) pairs.
|
| 216 |
+
#
|
| 217 |
+
# Args:
|
| 218 |
+
# hidden_matrix : [N, d_model] from collect_quality_data()
|
| 219 |
+
# quality_scores : [N] quality labels in [0, 1]
|
| 220 |
+
# d_model : hidden dimension
|
| 221 |
+
# epochs : training epochs
|
| 222 |
+
# save_path : if given, save trained classifier weights here
|
| 223 |
+
#
|
| 224 |
+
# Returns:
|
| 225 |
+
# trained QualityClassifier
|
| 226 |
+
# """
|
| 227 |
+
# device = torch.device("cpu") # classifier is tiny, CPU is fine
|
| 228 |
+
#
|
| 229 |
+
# X = torch.tensor(hidden_matrix, dtype=torch.float32)
|
| 230 |
+
# y = torch.tensor(quality_scores, dtype=torch.float32).unsqueeze(-1)
|
| 231 |
+
#
|
| 232 |
+
# N = len(X)
|
| 233 |
+
# n_val = max(1, int(N * val_frac))
|
| 234 |
+
# idx = torch.randperm(N)
|
| 235 |
+
# val_idx = idx[:n_val]
|
| 236 |
+
# train_idx = idx[n_val:]
|
| 237 |
+
#
|
| 238 |
+
# X_train, y_train = X[train_idx], y[train_idx]
|
| 239 |
+
# X_val, y_val = X[val_idx], y[val_idx]
|
| 240 |
+
#
|
| 241 |
+
# clf = QualityClassifier(d_model).to(device)
|
| 242 |
+
# optimizer = torch.optim.Adam(clf.parameters(), lr=lr)
|
| 243 |
+
#
|
| 244 |
+
# print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
|
| 245 |
+
# print(f"Train: {len(X_train)} Val: {len(X_val)}")
|
| 246 |
+
#
|
| 247 |
+
# best_val_loss = float('inf')
|
| 248 |
+
# best_state = None
|
| 249 |
+
#
|
| 250 |
+
# for epoch in range(epochs):
|
| 251 |
+
# clf.train()
|
| 252 |
+
# perm = torch.randperm(len(X_train))
|
| 253 |
+
# train_loss = 0.0
|
| 254 |
+
# n_batches = 0
|
| 255 |
+
#
|
| 256 |
+
# for start in range(0, len(X_train), batch_size):
|
| 257 |
+
# batch_idx = perm[start:start + batch_size]
|
| 258 |
+
# xb, yb = X_train[batch_idx], y_train[batch_idx]
|
| 259 |
+
# pred = clf(xb)
|
| 260 |
+
# loss = F.mse_loss(pred, yb)
|
| 261 |
+
# optimizer.zero_grad()
|
| 262 |
+
# loss.backward()
|
| 263 |
+
# optimizer.step()
|
| 264 |
+
# train_loss += loss.item()
|
| 265 |
+
# n_batches += 1
|
| 266 |
+
#
|
| 267 |
+
# clf.eval()
|
| 268 |
+
# with torch.no_grad():
|
| 269 |
+
# val_pred = clf(X_val)
|
| 270 |
+
# val_loss = F.mse_loss(val_pred, y_val).item()
|
| 271 |
+
#
|
| 272 |
+
# if epoch % 5 == 0 or epoch == epochs - 1:
|
| 273 |
+
# print(f" Ep {epoch+1:3d} train={train_loss/n_batches:.4f} val={val_loss:.4f}")
|
| 274 |
+
#
|
| 275 |
+
# if val_loss < best_val_loss:
|
| 276 |
+
# best_val_loss = val_loss
|
| 277 |
+
# best_state = {k: v.clone() for k, v in clf.state_dict().items()}
|
| 278 |
+
#
|
| 279 |
+
# if best_state:
|
| 280 |
+
# clf.load_state_dict(best_state)
|
| 281 |
+
# print(f" Best val loss: {best_val_loss:.4f}")
|
| 282 |
+
#
|
| 283 |
+
# if save_path:
|
| 284 |
+
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
|
| 285 |
+
# torch.save(clf.state_dict(), save_path)
|
| 286 |
+
# print(f" Classifier saved: {save_path}")
|
| 287 |
+
#
|
| 288 |
+
# return clf
|
| 289 |
+
#
|
| 290 |
+
#
|
| 291 |
+
# # ── Guided inference ──────────────────────────────────────────────────
|
| 292 |
+
#
|
| 293 |
+
# def generate_guided(
|
| 294 |
+
# model,
|
| 295 |
+
# src: torch.Tensor,
|
| 296 |
+
# classifier: QualityClassifier,
|
| 297 |
+
# guidance_scale: float = 1.0,
|
| 298 |
+
# temperature: float = 0.8,
|
| 299 |
+
# top_k: int = 40,
|
| 300 |
+
# ) -> torch.Tensor:
|
| 301 |
+
# """
|
| 302 |
+
# Classifier-guided generation.
|
| 303 |
+
#
|
| 304 |
+
# At each diffusion step:
|
| 305 |
+
# 1. Run forward_cached() → logits, hidden states
|
| 306 |
+
# 2. Compute classifier gradient: ∂(quality_score) / ∂(hidden)
|
| 307 |
+
# 3. Project gradient back to logit space (approximate)
|
| 308 |
+
# 4. guided_logits = logits + λ * gradient_signal
|
| 309 |
+
# 5. Sample from guided_logits
|
| 310 |
+
#
|
| 311 |
+
# guidance_scale λ:
|
| 312 |
+
# 0.0 → no guidance (standard generation)
|
| 313 |
+
# 0.5 → weak guidance
|
| 314 |
+
# 1.0 → moderate guidance (recommended starting point)
|
| 315 |
+
# 2.0 → strong guidance (may reduce diversity)
|
| 316 |
+
# 3.0 → very strong (may collapse to repetitive output)
|
| 317 |
+
#
|
| 318 |
+
# Args:
|
| 319 |
+
# model : SanskritModel (frozen)
|
| 320 |
+
# src : [1, src_len] IAST token ids
|
| 321 |
+
# classifier : trained QualityClassifier
|
| 322 |
+
# guidance_scale : λ — guidance strength
|
| 323 |
+
#
|
| 324 |
+
# Returns:
|
| 325 |
+
# x0_est : [1, tgt_len] generated token ids
|
| 326 |
+
# """
|
| 327 |
+
# inner = model.model
|
| 328 |
+
# T = inner.scheduler.num_timesteps
|
| 329 |
+
# device = next(inner.parameters()).device
|
| 330 |
+
# clf_device = next(classifier.parameters()).device
|
| 331 |
+
#
|
| 332 |
+
# if src.dim() == 1:
|
| 333 |
+
# src = src.unsqueeze(0)
|
| 334 |
+
# src = src.to(device)
|
| 335 |
+
#
|
| 336 |
+
# B = src.shape[0]
|
| 337 |
+
# tgt_len = inner.max_seq_len
|
| 338 |
+
# mask_id = inner.mask_token_id
|
| 339 |
+
#
|
| 340 |
+
# memory, src_pad_mask = inner.encode_source(src)
|
| 341 |
+
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
|
| 342 |
+
# hint = None
|
| 343 |
+
#
|
| 344 |
+
# inner.eval()
|
| 345 |
+
# classifier.eval()
|
| 346 |
+
#
|
| 347 |
+
# for t_val in range(T - 1, -1, -1):
|
| 348 |
+
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 349 |
+
# is_last = (t_val == 0)
|
| 350 |
+
#
|
| 351 |
+
# if guidance_scale > 0.0:
|
| 352 |
+
# # Need gradients for classifier guidance
|
| 353 |
+
# with torch.enable_grad():
|
| 354 |
+
# # Run forward_cached and get hidden states
|
| 355 |
+
# PAD = 1
|
| 356 |
+
# if t_val > 0:
|
| 357 |
+
# _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
|
| 358 |
+
# else:
|
| 359 |
+
# x_t_ids = x0_est
|
| 360 |
+
#
|
| 361 |
+
# x = inner.tgt_embed(x_t_ids)
|
| 362 |
+
# t_norm = t.float() / T
|
| 363 |
+
# t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
|
| 364 |
+
# x = x + t_emb.unsqueeze(1)
|
| 365 |
+
#
|
| 366 |
+
# if hint is not None:
|
| 367 |
+
# hint_emb = inner.tgt_embed(hint)
|
| 368 |
+
# gate = inner.hint_gate(x)
|
| 369 |
+
# x = x + gate * hint_emb
|
| 370 |
+
#
|
| 371 |
+
# for block in inner.decoder_blocks:
|
| 372 |
+
# x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)
|
| 373 |
+
#
|
| 374 |
+
# # hidden: [B, tgt_len, d_model] — detach from graph for clf
|
| 375 |
+
# hidden = x.detach().requires_grad_(True).to(clf_device)
|
| 376 |
+
#
|
| 377 |
+
# # Classifier quality score
|
| 378 |
+
# quality = classifier(hidden) # [B, 1]
|
| 379 |
+
# quality.sum().backward()
|
| 380 |
+
#
|
| 381 |
+
# # Gradient of quality w.r.t. hidden: [B, tgt_len, d_model]
|
| 382 |
+
# grad = hidden.grad.to(device) # [B, tgt_len, d_model]
|
| 383 |
+
#
|
| 384 |
+
# # Project gradient to logit space via output head weight
|
| 385 |
+
# # logit_grad ≈ grad @ head.weight [B, tgt_len, tgt_vocab]
|
| 386 |
+
# logit_grad = grad @ inner.head.weight.T
|
| 387 |
+
#
|
| 388 |
+
# # Compute standard logits (no gradient needed)
|
| 389 |
+
# with torch.no_grad():
|
| 390 |
+
# logits = inner.head(x)
|
| 391 |
+
#
|
| 392 |
+
# # Apply guidance
|
| 393 |
+
# logits = logits + guidance_scale * logit_grad
|
| 394 |
+
#
|
| 395 |
+
# else:
|
| 396 |
+
# with torch.no_grad():
|
| 397 |
+
# logits, _ = inner.forward_cached(
|
| 398 |
+
# memory, src_pad_mask, x0_est, t,
|
| 399 |
+
# x0_hint=hint, inference_mode=True,
|
| 400 |
+
# )
|
| 401 |
+
#
|
| 402 |
+
# with torch.no_grad():
|
| 403 |
+
# logits = logits / max(temperature, 1e-8)
|
| 404 |
+
# if top_k > 0:
|
| 405 |
+
# V = logits.shape[-1]
|
| 406 |
+
# if top_k < V:
|
| 407 |
+
# vals, _ = torch.topk(logits, top_k, dim=-1)
|
| 408 |
+
# logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
|
| 409 |
+
#
|
| 410 |
+
# probs = F.softmax(logits, dim=-1)
|
| 411 |
+
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample_no_grad(probs)
|
| 412 |
+
# hint = x0_est
|
| 413 |
+
#
|
| 414 |
+
# return x0_est
|
| 415 |
+
#
|
| 416 |
+
#
|
| 417 |
+
# def _sample_no_grad(probs):
|
| 418 |
+
# B, L, V = probs.shape
|
| 419 |
+
# flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 420 |
+
# flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 421 |
+
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
| 422 |
+
#
|
| 423 |
+
#
|
| 424 |
+
# # ── Guidance scale sweep ──────────────────────────────────────────────
|
| 425 |
+
#
|
| 426 |
+
# def sweep_guidance_scales(
|
| 427 |
+
# model,
|
| 428 |
+
# classifier: QualityClassifier,
|
| 429 |
+
# src_list: List[torch.Tensor],
|
| 430 |
+
# ref_list: List[str],
|
| 431 |
+
# tgt_tokenizer,
|
| 432 |
+
# scales: List[float] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
|
| 433 |
+
# n_samples: int = 50,
|
| 434 |
+
# device: torch.device = None,
|
| 435 |
+
# output_dir: str = "analysis/outputs",
|
| 436 |
+
# ) -> Dict:
|
| 437 |
+
# """
|
| 438 |
+
# Evaluate CER at each guidance scale.
|
| 439 |
+
# Produces quality-diversity tradeoff plot.
|
| 440 |
+
# """
|
| 441 |
+
# def cer(pred, ref):
|
| 442 |
+
# if not ref:
|
| 443 |
+
# return 1.0
|
| 444 |
+
# def ed(s1, s2):
|
| 445 |
+
# m, n = len(s1), len(s2)
|
| 446 |
+
# dp = list(range(n + 1))
|
| 447 |
+
# for i in range(1, m + 1):
|
| 448 |
+
# prev, dp[0] = dp[0], i
|
| 449 |
+
# for j in range(1, n + 1):
|
| 450 |
+
# temp = dp[j]
|
| 451 |
+
# dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
|
| 452 |
+
# prev = temp
|
| 453 |
+
# return dp[n]
|
| 454 |
+
# return ed(pred, ref) / max(len(ref), 1)
|
| 455 |
+
#
|
| 456 |
+
# device = device or next(model.parameters()).device
|
| 457 |
+
# results = {}
|
| 458 |
+
# n = min(n_samples, len(src_list))
|
| 459 |
+
#
|
| 460 |
+
# print("\nGuidance scale sweep...")
|
| 461 |
+
# for scale in scales:
|
| 462 |
+
# cer_list = []
|
| 463 |
+
# output_set = []
|
| 464 |
+
# for src, ref in zip(src_list[:n], ref_list[:n]):
|
| 465 |
+
# if src.dim() == 1:
|
| 466 |
+
# src = src.unsqueeze(0)
|
| 467 |
+
# out = generate_guided(model, src.to(device), classifier,
|
| 468 |
+
# guidance_scale=scale)
|
| 469 |
+
# ids = [x for x in out[0].tolist() if x > 4]
|
| 470 |
+
# pred = tgt_tokenizer.decode(ids).strip()
|
| 471 |
+
# cer_list.append(cer(pred, ref))
|
| 472 |
+
# output_set.append(pred)
|
| 473 |
+
#
|
| 474 |
+
# mean_cer = float(np.mean(cer_list))
|
| 475 |
+
#
|
| 476 |
+
# # Self-diversity: unique outputs / total (proxy for diversity)
|
| 477 |
+
# unique_frac = len(set(output_set)) / max(len(output_set), 1)
|
| 478 |
+
#
|
| 479 |
+
# results[scale] = {"mean_cer": mean_cer, "diversity": unique_frac}
|
| 480 |
+
# print(f" λ={scale:.1f} CER={mean_cer:.4f} diversity={unique_frac:.3f}")
|
| 481 |
+
#
|
| 482 |
+
# # Plot
|
| 483 |
+
# os.makedirs(output_dir, exist_ok=True)
|
| 484 |
+
# try:
|
| 485 |
+
# import matplotlib.pyplot as plt
|
| 486 |
+
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
|
| 487 |
+
#
|
| 488 |
+
# sc_list = sorted(results.keys())
|
| 489 |
+
# cers = [results[s]["mean_cer"] for s in sc_list]
|
| 490 |
+
# diversities = [results[s]["diversity"] for s in sc_list]
|
| 491 |
+
#
|
| 492 |
+
# ax1.plot(sc_list, cers, 'o-', color='coral', linewidth=1.8, markersize=7)
|
| 493 |
+
# ax1.set_xlabel("Guidance scale λ", fontsize=10)
|
| 494 |
+
# ax1.set_ylabel("CER (↓ better)", fontsize=10)
|
| 495 |
+
# ax1.set_title("Quality vs guidance scale", fontsize=10)
|
| 496 |
+
#
|
| 497 |
+
# ax2.plot(sc_list, diversities, 'o-', color='steelblue', linewidth=1.8, markersize=7)
|
| 498 |
+
# ax2.set_xlabel("Guidance scale λ", fontsize=10)
|
| 499 |
+
# ax2.set_ylabel("Output diversity (unique fraction)", fontsize=10)
|
| 500 |
+
# ax2.set_title("Diversity vs guidance scale", fontsize=10)
|
| 501 |
+
#
|
| 502 |
+
# plt.suptitle("Quality-Diversity Tradeoff (Guidance Scale Sweep)", fontsize=11)
|
| 503 |
+
# plt.tight_layout()
|
| 504 |
+
# path = os.path.join(output_dir, "guidance_scale_sweep.png")
|
| 505 |
+
# plt.savefig(path, dpi=150, bbox_inches='tight')
|
| 506 |
+
# plt.close()
|
| 507 |
+
# print(f" Saved: {path}")
|
| 508 |
+
# except ImportError:
|
| 509 |
+
# pass
|
| 510 |
+
#
|
| 511 |
+
# with open(os.path.join(output_dir, "guidance_results.json"), "w") as f:
|
| 512 |
+
# json.dump({str(k): v for k, v in results.items()}, f, indent=2)
|
| 513 |
+
#
|
| 514 |
+
# return results
|
| 515 |
+
import torch
|
| 516 |
+
import torch.nn as nn
|
| 517 |
+
import torch.nn.functional as F
|
| 518 |
+
import numpy as np
|
| 519 |
+
from typing import List, Dict
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
# ============================================================
|
| 523 |
+
# 1. QUALITY CLASSIFIER
|
| 524 |
+
# ============================================================
|
| 525 |
+
|
| 526 |
+
class QualityClassifier(nn.Module):
|
| 527 |
+
def __init__(self, d_model: int):
|
| 528 |
+
super().__init__()
|
| 529 |
+
self.net = nn.Sequential(
|
| 530 |
+
nn.Linear(d_model, 128),
|
| 531 |
+
nn.ReLU(),
|
| 532 |
+
nn.Dropout(0.1),
|
| 533 |
+
nn.Linear(128, 64),
|
| 534 |
+
nn.ReLU(),
|
| 535 |
+
nn.Linear(64, 1),
|
| 536 |
+
nn.Sigmoid(),
|
| 537 |
+
)
|
| 538 |
+
|
| 539 |
+
def forward(self, hidden):
|
| 540 |
+
if hidden.dim() == 3:
|
| 541 |
+
hidden = hidden.mean(dim=1)
|
| 542 |
+
return self.net(hidden)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
# ============================================================
|
| 546 |
+
# 2. GUIDED GENERATION (CORRECTED)
|
| 547 |
+
# ============================================================
|
| 548 |
+
|
| 549 |
+
@torch.no_grad()
|
| 550 |
+
def generate_guided(
|
| 551 |
+
model,
|
| 552 |
+
src: torch.Tensor,
|
| 553 |
+
classifier: QualityClassifier,
|
| 554 |
+
guidance_scale: float = 1.0,
|
| 555 |
+
temperature: float = 0.8,
|
| 556 |
+
top_k: int = 40,
|
| 557 |
+
):
|
| 558 |
+
inner = model.model
|
| 559 |
+
T = inner.scheduler.num_timesteps
|
| 560 |
+
device = next(inner.parameters()).device
|
| 561 |
+
|
| 562 |
+
if src.dim() == 1:
|
| 563 |
+
src = src.unsqueeze(0)
|
| 564 |
+
src = src.to(device)
|
| 565 |
+
|
| 566 |
+
B = src.shape[0]
|
| 567 |
+
tgt_len = inner.max_seq_len
|
| 568 |
+
mask_id = inner.mask_token_id
|
| 569 |
+
|
| 570 |
+
# KV CACHE
|
| 571 |
+
memory, src_pad_mask = inner.encode_source(src)
|
| 572 |
+
|
| 573 |
+
x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
|
| 574 |
+
hint = None
|
| 575 |
+
|
| 576 |
+
inner.eval()
|
| 577 |
+
classifier.eval()
|
| 578 |
+
|
| 579 |
+
for t_val in range(T - 1, -1, -1):
|
| 580 |
+
t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 581 |
+
is_last = (t_val == 0)
|
| 582 |
+
|
| 583 |
+
if guidance_scale > 0:
|
| 584 |
+
|
| 585 |
+
# ENABLE GRAD FOR GUIDANCE
|
| 586 |
+
with torch.enable_grad():
|
| 587 |
+
|
| 588 |
+
if t_val > 0:
|
| 589 |
+
_, x_t_ids = inner.forward_process.q_sample(x0_est, t)
|
| 590 |
+
else:
|
| 591 |
+
x_t_ids = x0_est
|
| 592 |
+
|
| 593 |
+
x = inner.tgt_embed(x_t_ids)
|
| 594 |
+
|
| 595 |
+
# time embedding
|
| 596 |
+
t_norm = t.float() / T
|
| 597 |
+
t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
|
| 598 |
+
x = x + t_emb.unsqueeze(1)
|
| 599 |
+
|
| 600 |
+
# hint conditioning
|
| 601 |
+
if hint is not None:
|
| 602 |
+
hint_emb = inner.tgt_embed(hint)
|
| 603 |
+
gate = inner.hint_gate(x)
|
| 604 |
+
x = x + gate * hint_emb
|
| 605 |
+
|
| 606 |
+
# decoder forward
|
| 607 |
+
for block in inner.decoder_blocks:
|
| 608 |
+
x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)
|
| 609 |
+
|
| 610 |
+
# IMPORTANT: NO DETACH HERE
|
| 611 |
+
hidden = x.requires_grad_(True)
|
| 612 |
+
|
| 613 |
+
# classifier forward
|
| 614 |
+
quality = classifier(hidden) # [B,1]
|
| 615 |
+
|
| 616 |
+
# compute gradient
|
| 617 |
+
quality.sum().backward()
|
| 618 |
+
|
| 619 |
+
grad = hidden.grad # [B, L, d_model]
|
| 620 |
+
|
| 621 |
+
# ===== FIX 1: Normalize gradient =====
|
| 622 |
+
grad_norm = grad.norm(dim=-1, keepdim=True) + 1e-6
|
| 623 |
+
grad = grad / grad_norm
|
| 624 |
+
|
| 625 |
+
# ===== FIX 2: Project to logit space =====
|
| 626 |
+
logit_grad = torch.matmul(grad, inner.head.weight.T)
|
| 627 |
+
|
| 628 |
+
# ===== FIX 3: Clip gradient =====
|
| 629 |
+
logit_grad = torch.clamp(logit_grad, -5.0, 5.0)
|
| 630 |
+
|
| 631 |
+
# compute logits (no grad)
|
| 632 |
+
with torch.no_grad():
|
| 633 |
+
logits = inner.head(x)
|
| 634 |
+
|
| 635 |
+
# apply guidance
|
| 636 |
+
logits = logits + guidance_scale * logit_grad
|
| 637 |
+
|
| 638 |
+
else:
|
| 639 |
+
with torch.no_grad():
|
| 640 |
+
logits, _ = inner.forward_cached(
|
| 641 |
+
memory, src_pad_mask, x0_est, t,
|
| 642 |
+
x0_hint=hint,
|
| 643 |
+
inference_mode=True,
|
| 644 |
+
)
|
| 645 |
+
|
| 646 |
+
# ===== Sampling =====
|
| 647 |
+
logits = logits / max(temperature, 1e-8)
|
| 648 |
+
|
| 649 |
+
if top_k > 0:
|
| 650 |
+
V = logits.shape[-1]
|
| 651 |
+
if top_k < V:
|
| 652 |
+
vals, _ = torch.topk(logits, top_k, dim=-1)
|
| 653 |
+
logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
|
| 654 |
+
|
| 655 |
+
probs = F.softmax(logits, dim=-1)
|
| 656 |
+
|
| 657 |
+
if is_last:
|
| 658 |
+
x0_est = torch.argmax(probs, dim=-1)
|
| 659 |
+
else:
|
| 660 |
+
x0_est = _sample(probs)
|
| 661 |
+
|
| 662 |
+
hint = x0_est
|
| 663 |
+
|
| 664 |
+
return x0_est
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def _sample(probs):
|
| 668 |
+
B, L, V = probs.shape
|
| 669 |
+
flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 670 |
+
flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 671 |
+
return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
| 672 |
+
|
| 673 |
+
|
| 674 |
+
# ============================================================
|
| 675 |
+
# 3. GUIDANCE SWEEP (EVALUATION)
|
| 676 |
+
# ============================================================
|
| 677 |
+
|
| 678 |
+
def sweep_guidance(
|
| 679 |
+
model,
|
| 680 |
+
classifier,
|
| 681 |
+
src_list,
|
| 682 |
+
ref_list,
|
| 683 |
+
tgt_tokenizer,
|
| 684 |
+
scales=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
|
| 685 |
+
n_samples=50,
|
| 686 |
+
):
|
| 687 |
+
def cer(pred, ref):
|
| 688 |
+
if not ref:
|
| 689 |
+
return 1.0
|
| 690 |
+
dp = list(range(len(ref) + 1))
|
| 691 |
+
for i in range(1, len(pred) + 1):
|
| 692 |
+
prev, dp[0] = dp[0], i
|
| 693 |
+
for j in range(1, len(ref) + 1):
|
| 694 |
+
temp = dp[j]
|
| 695 |
+
dp[j] = prev if pred[i-1] == ref[j-1] else 1 + min(prev, dp[j], dp[j-1])
|
| 696 |
+
prev = temp
|
| 697 |
+
return dp[-1] / max(len(ref), 1)
|
| 698 |
+
|
| 699 |
+
results = {}
|
| 700 |
+
|
| 701 |
+
for scale in scales:
|
| 702 |
+
cer_list = []
|
| 703 |
+
outputs = []
|
| 704 |
+
|
| 705 |
+
for src, ref in zip(src_list[:n_samples], ref_list[:n_samples]):
|
| 706 |
+
if src.dim() == 1:
|
| 707 |
+
src = src.unsqueeze(0)
|
| 708 |
+
|
| 709 |
+
out = generate_guided(model, src, classifier, scale)
|
| 710 |
+
ids = [x for x in out[0].tolist() if x > 4]
|
| 711 |
+
pred = tgt_tokenizer.decode(ids).strip()
|
| 712 |
+
|
| 713 |
+
cer_list.append(cer(pred, ref))
|
| 714 |
+
outputs.append(pred)
|
| 715 |
+
|
| 716 |
+
results[scale] = {
|
| 717 |
+
"CER": float(np.mean(cer_list)),
|
| 718 |
+
"diversity": len(set(outputs)) / len(outputs)
|
| 719 |
+
}
|
| 720 |
+
|
| 721 |
+
print(f"λ={scale:.1f} | CER={results[scale]['CER']:.4f} | diversity={results[scale]['diversity']:.3f}")
|
| 722 |
+
|
| 723 |
+
return results
|
analysis/reports/README.md
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Analysis Reports
|
| 2 |
+
|
| 3 |
+
This folder contains mentor-facing writeups for the five analysis tasks:
|
| 4 |
+
|
| 5 |
+
- [Task 1](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/reports/task1_kv_cache_report.md)
|
| 6 |
+
- [Task 2](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/reports/task2_attention_drift_report.md)
|
| 7 |
+
- [Task 3](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/reports/task3_concept_vectors_report.md)
|
| 8 |
+
- [Task 4](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/reports/task4_step_ablation_report.md)
|
| 9 |
+
- [Task 5](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/reports/task5_quality_guidance_report.md)
|
| 10 |
+
|
| 11 |
+
These reports are written for evaluation use. They include:
|
| 12 |
+
|
| 13 |
+
- objective
|
| 14 |
+
- implementation summary
|
| 15 |
+
- code snippet
|
| 16 |
+
- result status
|
| 17 |
+
- benefits
|
| 18 |
+
- limitations
|
| 19 |
+
- conclusion
|
analysis/reports/task1_kv_cache_report.md
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Task 1 Report: KV Cache Benchmark
|
| 2 |
+
|
| 3 |
+
## 1. Objective
|
| 4 |
+
|
| 5 |
+
The purpose of Task 1 is to measure whether encoder-side key/value caching improves inference speed for the cross-attention D3PM paraphrase model. In the unoptimized version, the source sequence is re-encoded at every diffusion step. In the cached version, the source is encoded once and reused for all denoising steps.
|
| 6 |
+
|
| 7 |
+
This task is useful for mentor evaluation because it measures an engineering improvement directly tied to deployment cost. Even when model quality is unchanged, lower generation latency improves usability for experimentation, batch evaluation, and interactive inference.
|
| 8 |
+
|
| 9 |
+
## 2. Implementation Approach
|
| 10 |
+
|
| 11 |
+
The benchmark is implemented in [analysis/kv_cache_benchmark.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/kv_cache_benchmark.py). To support it, the cross-attention model was extended with three helper methods in [model/d3pm_model_cross_attention.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/model/d3pm_model_cross_attention.py):
|
| 12 |
+
|
| 13 |
+
- `encode_source(...)`
|
| 14 |
+
- `forward_cached(...)`
|
| 15 |
+
- `generate_cached(...)`
|
| 16 |
+
|
| 17 |
+
These methods separate source encoding from decoder-side denoising, which is the standard way to benchmark KV caching in encoder-decoder style architectures.
|
| 18 |
+
|
| 19 |
+
### Core Implementation Snippet
|
| 20 |
+
|
| 21 |
+
```python
|
| 22 |
+
def encode_source(self, src):
|
| 23 |
+
PAD = 1
|
| 24 |
+
src_pad_mask = (src == PAD)
|
| 25 |
+
memory = self.src_embed(src)
|
| 26 |
+
for block in self.encoder_blocks:
|
| 27 |
+
memory = block(memory, pad_mask=src_pad_mask)
|
| 28 |
+
return memory, src_pad_mask
|
| 29 |
+
|
| 30 |
+
def forward_cached(self, memory, src_pad_mask, tgt, t, x0_hint=None, inference_mode=False):
|
| 31 |
+
...
|
| 32 |
+
for block in self.decoder_blocks:
|
| 33 |
+
x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)
|
| 34 |
+
self._last_hidden = x.detach()
|
| 35 |
+
return self.head(x), None
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
This design avoids recomputing the encoder stack at each diffusion step.
|
| 39 |
+
|
| 40 |
+
## 3. Experimental Setup
|
| 41 |
+
|
| 42 |
+
The benchmark was run using the Task 1 entry point:
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
uv run --active analysis/run_analysis.py --task 1
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
The script tests source lengths of 16, 32, and 64 tokens and reports:
|
| 49 |
+
|
| 50 |
+
- standard generation time
|
| 51 |
+
- cached generation time
|
| 52 |
+
- speedup ratio
|
| 53 |
+
- estimated encoder cost as a percentage of one forward pass
|
| 54 |
+
|
| 55 |
+
The benchmark output is stored in [analysis/outputs/task1_kv_cache.txt](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task1_kv_cache.txt).
|
| 56 |
+
|
| 57 |
+
## 4. Results
|
| 58 |
+
|
| 59 |
+
Observed benchmark values:
|
| 60 |
+
|
| 61 |
+
| Source Length | Standard (s) | Cached (s) | Speedup | Encoder % |
|
| 62 |
+
| --- | ---: | ---: | ---: | ---: |
|
| 63 |
+
| 16 | 1.784 | 1.780 | 1.00x | 42.7% |
|
| 64 |
+
| 32 | 2.055 | 1.850 | 1.11x | 41.9% |
|
| 65 |
+
| 64 | 1.724 | 1.608 | 1.07x | 43.2% |
|
| 66 |
+
|
| 67 |
+
The main outcome is that caching works correctly and provides a measurable speed improvement, though the improvement is modest on the current hardware and runtime stack.
|
| 68 |
+
|
| 69 |
+
## 5. Interpretation
|
| 70 |
+
|
| 71 |
+
The result is technically correct and useful, but it should be positioned carefully in evaluation:
|
| 72 |
+
|
| 73 |
+
- This is a systems optimization result, not a model quality result.
|
| 74 |
+
- The speedup is real, but not dramatic.
|
| 75 |
+
- The benchmark confirms that source-side recomputation can be removed without changing the inference algorithm.
|
| 76 |
+
|
| 77 |
+
For mentor evaluation, this can be presented as a successful engineering optimization with limited but positive runtime impact.
|
| 78 |
+
|
| 79 |
+
## 6. Benefits
|
| 80 |
+
|
| 81 |
+
Benefits of this task:
|
| 82 |
+
|
| 83 |
+
- reduces redundant encoder computation
|
| 84 |
+
- provides a reusable cached inference path for later analysis tasks
|
| 85 |
+
- improves scalability for repeated generation and diagnostic probes
|
| 86 |
+
- establishes infrastructure for attention and hidden-state inspection
|
| 87 |
+
|
| 88 |
+
## 7. Limitations
|
| 89 |
+
|
| 90 |
+
The result should not be overstated:
|
| 91 |
+
|
| 92 |
+
- speedup depends heavily on hardware and backend
|
| 93 |
+
- current gains are relatively small
|
| 94 |
+
- more stable benchmarking would require repeated runs and device-specific profiling
|
| 95 |
+
- this does not improve semantic accuracy directly
|
| 96 |
+
|
| 97 |
+
## 8. Conclusion
|
| 98 |
+
|
| 99 |
+
Task 1 is valid and suitable for mentor evaluation as an implementation-focused result. It demonstrates that cached inference was successfully added to the D3PM cross-attention model and that it reduces generation cost modestly. The strongest value of this task is architectural: it enables faster repeated inference and supports later interpretability experiments.
|
analysis/reports/task2_attention_drift_report.md
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Task 2 Report: Attention Visualization and Semantic Drift
|
| 2 |
+
|
| 3 |
+
## 1. Objective
|
| 4 |
+
|
| 5 |
+
Task 2 investigates how the diffusion model behaves internally during generation. It has two goals:
|
| 6 |
+
|
| 7 |
+
- capture cross-attention patterns between source and generated target tokens
|
| 8 |
+
- measure how intermediate generations converge toward the final output over diffusion steps
|
| 9 |
+
|
| 10 |
+
This task is important for evaluation because it gives interpretability evidence. Instead of only showing the final prediction, it examines whether the model gradually stabilizes its output and whether attention is distributed in a meaningful way.
|
| 11 |
+
|
| 12 |
+
## 2. Implementation Approach
|
| 13 |
+
|
| 14 |
+
The implementation uses two analysis modules:
|
| 15 |
+
|
| 16 |
+
- [analysis/attention_viz.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/attention_viz.py)
|
| 17 |
+
- [analysis/semantic_drift.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/semantic_drift.py)
|
| 18 |
+
|
| 19 |
+
To support this, the cross-attention layer stores attention weights during decoding. The model also exposes a cached inference path so per-step diagnostics can be collected efficiently.
|
| 20 |
+
|
| 21 |
+
### Attention Capture Snippet
|
| 22 |
+
|
| 23 |
+
```python
|
| 24 |
+
class MultiHeadAttention(nn.Module):
|
| 25 |
+
def __init__(self, d_model, n_heads, dropout=0.1):
|
| 26 |
+
...
|
| 27 |
+
self.capture_weights = False
|
| 28 |
+
self.last_attn_weights = None
|
| 29 |
+
|
| 30 |
+
def forward(self, q, k, v, mask=None):
|
| 31 |
+
...
|
| 32 |
+
attn = self.dropout(torch.softmax(scores, dim=-1))
|
| 33 |
+
if self.capture_weights:
|
| 34 |
+
self.last_attn_weights = attn.detach().cpu()
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
### Drift Computation Snippet
|
| 38 |
+
|
| 39 |
+
```python
|
| 40 |
+
def compute_drift(step_outputs, final_output):
|
| 41 |
+
t_vals = sorted(step_outputs.keys(), reverse=True)
|
| 42 |
+
cer_to_final = []
|
| 43 |
+
for t_val in t_vals:
|
| 44 |
+
cer = compute_cer_between(step_outputs[t_val], final_output)
|
| 45 |
+
cer_to_final.append(cer)
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
The metric used is character error rate between each intermediate output and the final output.
|
| 49 |
+
|
| 50 |
+
## 3. Experimental Setup
|
| 51 |
+
|
| 52 |
+
The task was run with:
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
uv run --active analysis/run_analysis.py --task 2 --input "dharmo rakṣati rakṣitaḥ"
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
Generated outputs:
|
| 59 |
+
|
| 60 |
+
- [analysis/outputs/task2_attn_t127.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_attn_t127.png)
|
| 61 |
+
- [analysis/outputs/task2_attn_t0.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_attn_t0.png)
|
| 62 |
+
- [analysis/outputs/task2_all_layers_t0.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_all_layers_t0.png)
|
| 63 |
+
- [analysis/outputs/task2_attn_evolution.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_attn_evolution.png)
|
| 64 |
+
- [analysis/outputs/task2_semantic_drift.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_semantic_drift.png)
|
| 65 |
+
- [analysis/outputs/task2_report.txt](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_report.txt)
|
| 66 |
+
|
| 67 |
+
## 4. Results
|
| 68 |
+
|
| 69 |
+
The saved report shows:
|
| 70 |
+
|
| 71 |
+
- lock-in timestep: `t = 22`
|
| 72 |
+
- mean token-position lock-in: `53.6 ± 28.4`
|
| 73 |
+
|
| 74 |
+
This indicates that the generated sequence becomes relatively stable before the final denoising step. In other words, the model is not making all of its decisions only at the very end.
|
| 75 |
+
|
| 76 |
+
However, the actual generated Sanskrit output is low quality and strongly repetitive. That matters for interpretation: the drift curve is still valid as a measure of convergence, but it is convergence toward a weak final output.
|
| 77 |
+
|
| 78 |
+
## 5. Interpretation
|
| 79 |
+
|
| 80 |
+
For mentor evaluation, this task should be presented as a diagnostic analysis rather than a quality claim.
|
| 81 |
+
|
| 82 |
+
What the task supports:
|
| 83 |
+
|
| 84 |
+
- the model’s output evolves gradually over time
|
| 85 |
+
- the diffusion process shows an identifiable stabilization region
|
| 86 |
+
- attention weights can now be inspected layer by layer
|
| 87 |
+
|
| 88 |
+
What the task does not yet support:
|
| 89 |
+
|
| 90 |
+
- strong semantic alignment
|
| 91 |
+
- trustworthy linguistic paraphrase quality
|
| 92 |
+
- meaningful claim that attention maps correspond to correct Sanskrit transformation
|
| 93 |
+
|
| 94 |
+
## 6. Benefits
|
| 95 |
+
|
| 96 |
+
This task has practical value even with imperfect outputs:
|
| 97 |
+
|
| 98 |
+
- helps identify when the model stabilizes
|
| 99 |
+
- supports debugging of the denoising trajectory
|
| 100 |
+
- provides visual artifacts for discussing model internals
|
| 101 |
+
- can guide reduction of unnecessary inference steps in future work
|
| 102 |
+
|
| 103 |
+
## 7. Limitations
|
| 104 |
+
|
| 105 |
+
There are two important limitations:
|
| 106 |
+
|
| 107 |
+
1. The output quality is weak, so the interpretability evidence is about model behavior, not model correctness.
|
| 108 |
+
2. Matplotlib on the current machine does not render Devanagari fonts well, so the generated figures contain font warnings and may not display labels cleanly.
|
| 109 |
+
|
| 110 |
+
## 8. Conclusion
|
| 111 |
+
|
| 112 |
+
Task 2 is partially suitable for evaluation. It is strong as an interpretability and debugging report, but weak as proof of semantic paraphrase quality. For mentor review, it should be framed as evidence that the diffusion generation process can now be inspected and analyzed step by step.
|
analysis/reports/task3_concept_vectors_report.md
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Task 3 Report: Concept Vectors and PCA-Based Steering
|
| 2 |
+
|
| 3 |
+
## 1. Objective
|
| 4 |
+
|
| 5 |
+
Task 3 explores whether decoder hidden states contain a measurable direction corresponding to paraphrase diversity. The idea is:
|
| 6 |
+
|
| 7 |
+
1. collect hidden states from many validation samples
|
| 8 |
+
2. fit PCA to the hidden-state space
|
| 9 |
+
3. find a principal direction correlated with output diversity
|
| 10 |
+
4. steer generation along that direction
|
| 11 |
+
|
| 12 |
+
This is an advanced representation-learning experiment. Its value for mentor evaluation lies in showing that the project is not limited to training and inference, but also investigates controllable generation.
|
| 13 |
+
|
| 14 |
+
## 2. Implementation Approach
|
| 15 |
+
|
| 16 |
+
The implementation is in [analysis/concept_vectors.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/concept_vectors.py). Hidden states are captured from the decoder during cached inference and pooled across sequence positions.
|
| 17 |
+
|
| 18 |
+
### PCA Fitting Snippet
|
| 19 |
+
|
| 20 |
+
```python
|
| 21 |
+
def fit_pca(hidden_matrix, n_components=50):
|
| 22 |
+
from sklearn.decomposition import PCA
|
| 23 |
+
n_comp = min(n_components, hidden_matrix.shape[0] - 1, hidden_matrix.shape[1])
|
| 24 |
+
pca = PCA(n_components=n_comp)
|
| 25 |
+
pca.fit(hidden_matrix)
|
| 26 |
+
return pca
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
### Steering Snippet
|
| 30 |
+
|
| 31 |
+
```python
|
| 32 |
+
if alpha != 0.0:
|
| 33 |
+
x = x + alpha * dir_tensor.unsqueeze(0).unsqueeze(0)
|
| 34 |
+
|
| 35 |
+
logits = inner.head(x)
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
The steering mechanism adds a learned direction in hidden-state space before projection to logits.
|
| 39 |
+
|
| 40 |
+
## 3. Experimental Setup
|
| 41 |
+
|
| 42 |
+
Task 3 was run from the shared analysis driver and generated:
|
| 43 |
+
|
| 44 |
+
- [analysis/outputs/task3_concept_space.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task3_concept_space.png)
|
| 45 |
+
- [analysis/outputs/task3_diversity_direction.npy](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task3_diversity_direction.npy)
|
| 46 |
+
- [analysis/outputs/task3_report.txt](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task3_report.txt)
|
| 47 |
+
|
| 48 |
+
The run used 500 validation examples for hidden-state extraction.
|
| 49 |
+
|
| 50 |
+
## 4. Results
|
| 51 |
+
|
| 52 |
+
Observed summary:
|
| 53 |
+
|
| 54 |
+
- PCA components retained: `50`
|
| 55 |
+
- total explained variance: `96.1%`
|
| 56 |
+
- selected diversity principal component: `PC 1`
|
| 57 |
+
- absolute correlation with output length: `0.303`
|
| 58 |
+
|
| 59 |
+
On paper, these values suggest that hidden-state variation is structured and that at least one direction correlates with output-length changes. That is a positive sign from a representation-analysis standpoint.
|
| 60 |
+
|
| 61 |
+
However, the actual diversity spectrum outputs are not semantically convincing. The steered generations are highly repetitive and mostly malformed token sequences rather than clear paraphrases with controlled variation.
|
| 62 |
+
|
| 63 |
+
## 5. Interpretation
|
| 64 |
+
|
| 65 |
+
This task should be presented carefully.
|
| 66 |
+
|
| 67 |
+
What is supported:
|
| 68 |
+
|
| 69 |
+
- hidden states are rich enough for PCA analysis
|
| 70 |
+
- the representation space is not random noise
|
| 71 |
+
- controllable steering infrastructure has been implemented successfully
|
| 72 |
+
|
| 73 |
+
What is not yet supported:
|
| 74 |
+
|
| 75 |
+
- interpretable semantic control
|
| 76 |
+
- high-quality paraphrase diversity
|
| 77 |
+
- evidence that the identified direction reflects useful linguistic variation
|
| 78 |
+
|
| 79 |
+
For mentor evaluation, this is best framed as a promising exploratory experiment rather than a finished result.
|
| 80 |
+
|
| 81 |
+
## 6. Benefits
|
| 82 |
+
|
| 83 |
+
Benefits of the task include:
|
| 84 |
+
|
| 85 |
+
- opens a path toward controllable paraphrase generation
|
| 86 |
+
- demonstrates hidden-state instrumentation beyond standard inference
|
| 87 |
+
- provides a research direction for future work on style and diversity control
|
| 88 |
+
- connects model analysis with possible user-facing controllability
|
| 89 |
+
|
| 90 |
+
## 7. Limitations
|
| 91 |
+
|
| 92 |
+
The main limitation is output quality. Even though the PCA statistics look reasonable, the steered generations are not linguistically strong enough to claim meaningful semantic control. This makes the current result more useful as a prototype than as a validated research finding.
|
| 93 |
+
|
| 94 |
+
## 8. Conclusion
|
| 95 |
+
|
| 96 |
+
Task 3 is not yet strong enough as a final evaluation result, but it is valuable as research evidence of advanced model analysis. For mentor discussion, it should be described as an experimental controllability framework that has been implemented successfully but still requires better base model quality before the steering outputs become persuasive.
|
analysis/reports/task4_step_ablation_report.md
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Task 4 Report: Diffusion Step Ablation
|
| 2 |
+
|
| 3 |
+
## 1. Objective
|
| 4 |
+
|
| 5 |
+
Task 4 studies how the number of diffusion steps affects meaning preservation, speed, and robustness. The hypothesis is that fewer denoising steps may improve speed, but too few steps may reduce output quality. This type of ablation is important for mentor evaluation because it tests a core design parameter of the D3PM model.
|
| 6 |
+
|
| 7 |
+
Unlike the earlier tasks, this one requires retraining separate checkpoints for each step count. This is not optional. A model trained at `T=128` cannot be evaluated fairly at `T=4` or `T=8` without retraining, because the timestep distribution seen during training changes fundamentally.
|
| 8 |
+
|
| 9 |
+
## 2. Implementation Approach
|
| 10 |
+
|
| 11 |
+
The implementation is in [analysis/step_ablation.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/step_ablation.py). I patched the workflow so it is safe for this repository:
|
| 12 |
+
|
| 13 |
+
- it no longer overwrites `config.py`
|
| 14 |
+
- it uses environment variables for `DIFFUSION_STEPS`
|
| 15 |
+
- each training run writes directly to `ablation_results/T*`
|
| 16 |
+
|
| 17 |
+
### Training Script Generation Snippet
|
| 18 |
+
|
| 19 |
+
```python
|
| 20 |
+
f.write(
|
| 21 |
+
f"MODEL_TYPE=\"$MODEL_TYPE\" INCLUDE_NEG=\"$INCLUDE_NEG\" "
|
| 22 |
+
f"TRAIN_DEVICE=\"$TRAIN_DEVICE\" "
|
| 23 |
+
f"DIFFUSION_STEPS={T} INFERENCE_NUM_STEPS={T} "
|
| 24 |
+
f"TRAIN_OUTPUT_DIR=\"ablation_results/T{T}\" "
|
| 25 |
+
f"python train.py\n\n"
|
| 26 |
+
)
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
This makes the ablation workflow reproducible without mutating repository files between runs.
|
| 30 |
+
|
| 31 |
+
## 3. Current Workflow
|
| 32 |
+
|
| 33 |
+
Task 4 now supports the following sequence:
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
uv run --active analysis/run_analysis.py --task 4 --phase generate_configs
|
| 37 |
+
bash ablation_configs/train_all.sh
|
| 38 |
+
uv run --active analysis/run_analysis.py --task 4 --phase analyze
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
Generated script:
|
| 42 |
+
|
| 43 |
+
- [ablation_configs/train_all.sh](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/ablation_configs/train_all.sh)
|
| 44 |
+
|
| 45 |
+
This script trains:
|
| 46 |
+
|
| 47 |
+
- `T=4`
|
| 48 |
+
- `T=8`
|
| 49 |
+
- `T=16`
|
| 50 |
+
- `T=32`
|
| 51 |
+
- `T=64`
|
| 52 |
+
|
| 53 |
+
with outputs saved to `ablation_results/T4`, `T8`, `T16`, `T32`, and `T64`.
|
| 54 |
+
|
| 55 |
+
## 4. Current Result Status
|
| 56 |
+
|
| 57 |
+
At the moment, no trained ablation checkpoints exist in `ablation_results/T*/best_model.pt`. Therefore, the analysis phase has no quantitative result yet. That means Task 4 currently has a correct implementation pipeline, but not a completed experiment.
|
| 58 |
+
|
| 59 |
+
This distinction matters for evaluation:
|
| 60 |
+
|
| 61 |
+
- the workflow is correct
|
| 62 |
+
- the experiment has not yet produced final numbers
|
| 63 |
+
|
| 64 |
+
## 5. Evaluation Value
|
| 65 |
+
|
| 66 |
+
For mentor evaluation, Task 4 can still be included, but it should be presented as:
|
| 67 |
+
|
| 68 |
+
- a completed experimental setup
|
| 69 |
+
- a validated retraining workflow
|
| 70 |
+
- pending final quantitative results
|
| 71 |
+
|
| 72 |
+
This is still useful because ablation design is part of research rigor. It shows that the project is set up to test the effect of a critical modeling choice instead of assuming the default step count is optimal.
|
| 73 |
+
|
| 74 |
+
## 6. Benefits
|
| 75 |
+
|
| 76 |
+
Once the checkpoints are trained, this task will answer:
|
| 77 |
+
|
| 78 |
+
- how much generation speed improves as diffusion steps decrease
|
| 79 |
+
- how meaning preservation changes with fewer steps
|
| 80 |
+
- where the best quality-speed tradeoff lies
|
| 81 |
+
- whether the current choice of diffusion steps is over- or under-provisioned
|
| 82 |
+
|
| 83 |
+
## 7. Limitations
|
| 84 |
+
|
| 85 |
+
The limitation is straightforward: there are no ablation checkpoints yet, so there are no real results to defend. It should not be presented as a finished evaluation experiment at this stage.
|
| 86 |
+
|
| 87 |
+
## 8. Conclusion
|
| 88 |
+
|
| 89 |
+
Task 4 is structurally correct and now safe to run in this repository. It is suitable for mentor evaluation as an experimental design and workflow contribution, but not yet as a result section. The next milestone is to train the five ablation checkpoints and run the analysis phase to generate the actual CER-speed comparison.
|
analysis/reports/task5_quality_guidance_report.md
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Task 5 Report: Quality Classifier and Guidance-Based Decoding
|
| 2 |
+
|
| 3 |
+
## 1. Objective
|
| 4 |
+
|
| 5 |
+
Task 5 attempts to guide generation using a lightweight quality classifier trained on decoder hidden states. The idea is to predict a quality score from hidden states and then use the classifier gradient to bias inference toward higher-quality outputs.
|
| 6 |
+
|
| 7 |
+
This is an ambitious extension because it adds a second learned component on top of the main D3PM model without retraining the core paraphrase model itself.
|
| 8 |
+
|
| 9 |
+
## 2. Implementation Approach
|
| 10 |
+
|
| 11 |
+
The implementation is in [analysis/quality_classifier.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/quality_classifier.py). It has three stages:
|
| 12 |
+
|
| 13 |
+
1. collect `(hidden_state, quality_score)` pairs
|
| 14 |
+
2. train a small MLP quality classifier
|
| 15 |
+
3. use classifier gradients during decoding
|
| 16 |
+
|
| 17 |
+
### Classifier Definition Snippet
|
| 18 |
+
|
| 19 |
+
```python
|
| 20 |
+
class QualityClassifier(nn.Module):
|
| 21 |
+
def __init__(self, d_model: int):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.net = nn.Sequential(
|
| 24 |
+
nn.Linear(d_model, 128),
|
| 25 |
+
nn.ReLU(),
|
| 26 |
+
nn.Dropout(0.1),
|
| 27 |
+
nn.Linear(128, 64),
|
| 28 |
+
nn.ReLU(),
|
| 29 |
+
nn.Linear(64, 1),
|
| 30 |
+
nn.Sigmoid(),
|
| 31 |
+
)
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
### Guidance Snippet
|
| 35 |
+
|
| 36 |
+
```python
|
| 37 |
+
hidden = x.detach().to(clf_device).requires_grad_(True)
|
| 38 |
+
hidden.retain_grad()
|
| 39 |
+
quality = classifier(hidden)
|
| 40 |
+
quality.sum().backward()
|
| 41 |
+
grad = hidden.grad.to(device)
|
| 42 |
+
logit_grad = grad @ inner.head.weight.T
|
| 43 |
+
logits = logits + guidance_scale * logit_grad
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
This turns hidden-state quality prediction into a differentiable decoding signal.
|
| 47 |
+
|
| 48 |
+
## 3. Current Status
|
| 49 |
+
|
| 50 |
+
Task 5 originally failed for two reasons:
|
| 51 |
+
|
| 52 |
+
- the gradient was taken from a non-leaf tensor, causing `hidden.grad` to be `None`
|
| 53 |
+
- the cached quality labels collapsed to all zeros, so the classifier had no meaningful learning signal
|
| 54 |
+
|
| 55 |
+
These implementation bugs were patched. However, the existing saved quality cache in [analysis/outputs/task5_quality_data.npz](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task5_quality_data.npz) still contains degenerate labels from the earlier failed run.
|
| 56 |
+
|
| 57 |
+
Observed cache statistics:
|
| 58 |
+
|
| 59 |
+
- count: `500`
|
| 60 |
+
- mean: `0.0`
|
| 61 |
+
- std: `0.0`
|
| 62 |
+
- min: `0.0`
|
| 63 |
+
- max: `0.0`
|
| 64 |
+
|
| 65 |
+
That means the current classifier result is not valid for evaluation.
|
| 66 |
+
|
| 67 |
+
## 4. Why the Current Result Is Not Reliable
|
| 68 |
+
|
| 69 |
+
Because all quality labels are zero:
|
| 70 |
+
|
| 71 |
+
- the classifier is effectively trained on a constant target
|
| 72 |
+
- low validation loss is meaningless
|
| 73 |
+
- guidance behavior cannot be interpreted as quality-aware control
|
| 74 |
+
|
| 75 |
+
So although the code path now exists, the saved run should not be used in mentor evaluation as a finished result.
|
| 76 |
+
|
| 77 |
+
## 5. What Was Fixed
|
| 78 |
+
|
| 79 |
+
Two concrete corrections were made:
|
| 80 |
+
|
| 81 |
+
- a bounded quality transform was introduced so very large CER values do not collapse everything to zero
|
| 82 |
+
- the Task 5 runner now refreshes cached quality data when it detects degenerate labels
|
| 83 |
+
|
| 84 |
+
This means Task 5 is closer to being experimentally sound, but it still needs to be rerun from scratch after the patch.
|
| 85 |
+
|
| 86 |
+
## 6. Expected Benefits
|
| 87 |
+
|
| 88 |
+
If Task 5 works as intended after rerunning, it could provide:
|
| 89 |
+
|
| 90 |
+
- a lightweight mechanism for improving generation quality
|
| 91 |
+
- a controllable quality-diversity tradeoff
|
| 92 |
+
- a reusable framework for guidance without retraining the full D3PM model
|
| 93 |
+
- a more research-oriented extension beyond standard training and inference
|
| 94 |
+
|
| 95 |
+
## 7. Limitations
|
| 96 |
+
|
| 97 |
+
At present, this task has one decisive limitation: the saved outputs are not valid evaluation artifacts. The infrastructure is promising, but the experimental evidence is not yet strong enough to defend.
|
| 98 |
+
|
| 99 |
+
## 8. Conclusion
|
| 100 |
+
|
| 101 |
+
Task 5 should be presented only as a partially completed advanced experiment. The implementation framework is now in place and the core bugs have been addressed, but the current cached run is still invalid for evaluation. Before showing this task to a mentor as a result, the quality data and guidance sweep should be rerun after patching so that the classifier is trained on non-degenerate labels.
|
analysis/run_analysis.py
ADDED
|
@@ -0,0 +1,466 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
analysis/run_analysis.py
|
| 3 |
+
=========================
|
| 4 |
+
Entry point for all 5 tasks.
|
| 5 |
+
|
| 6 |
+
Tasks:
|
| 7 |
+
Task 1 — KV Cache benchmark (no retraining)
|
| 8 |
+
Task 2 — Attention viz + drift (no retraining)
|
| 9 |
+
Task 3 — Concept vectors + PCA steer (no retraining)
|
| 10 |
+
Task 4 — Step ablation (REQUIRES retraining for each T)
|
| 11 |
+
Task 5 — Classifier-free guidance (trains small 10k-param classifier)
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
python analysis/run_analysis.py --task 1
|
| 15 |
+
python analysis/run_analysis.py --task 2 --input "dharmo rakṣati rakṣitaḥ"
|
| 16 |
+
python analysis/run_analysis.py --task 3
|
| 17 |
+
python analysis/run_analysis.py --task 4 --phase generate_configs
|
| 18 |
+
python analysis/run_analysis.py --task 4 --phase analyze
|
| 19 |
+
python analysis/run_analysis.py --task 5
|
| 20 |
+
python analysis/run_analysis.py --task all --input "satyameva jayate"
|
| 21 |
+
|
| 22 |
+
Output files: analysis/outputs/
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
import copy
|
| 26 |
+
import torch
|
| 27 |
+
import os, sys, argparse, json
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 31 |
+
from config import CONFIG
|
| 32 |
+
from inference import load_model
|
| 33 |
+
from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
|
| 34 |
+
|
| 35 |
+
OUTPUT_DIR = "analysis/outputs"
|
| 36 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ── Shared loader ─────────────────────────────────────────────────────
|
| 40 |
+
|
| 41 |
+
def infer_model_type_from_checkpoint(ckpt_path: str) -> str:
|
| 42 |
+
name = ckpt_path.lower()
|
| 43 |
+
if "ablation_results/t" in name or "d3pm_cross_attention" in name:
|
| 44 |
+
return "d3pm_cross_attention"
|
| 45 |
+
if "d3pm_encoder_decoder" in name:
|
| 46 |
+
return "d3pm_encoder_decoder"
|
| 47 |
+
if "baseline_cross_attention" in name:
|
| 48 |
+
return "baseline_cross_attention"
|
| 49 |
+
if "baseline_encoder_decoder" in name:
|
| 50 |
+
return "baseline_encoder_decoder"
|
| 51 |
+
return CONFIG["model_type"]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def infer_include_negative_from_checkpoint(ckpt_path: str) -> bool:
|
| 55 |
+
name = ckpt_path.lower()
|
| 56 |
+
if "_neg_true" in name:
|
| 57 |
+
return True
|
| 58 |
+
if "_neg_false" in name:
|
| 59 |
+
return False
|
| 60 |
+
if "ablation_results/t" in name:
|
| 61 |
+
return False
|
| 62 |
+
return CONFIG["data"]["include_negative_examples"]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def load_everything(cfg, device, ckpt_override=None):
|
| 66 |
+
model_name = cfg['model_type']
|
| 67 |
+
has_neg = cfg['data']['include_negative_examples']
|
| 68 |
+
candidates = [
|
| 69 |
+
f"results7/{model_name}_neg_{has_neg}/best_model.pt",
|
| 70 |
+
f"results/{model_name}_neg_{has_neg}/best_model.pt",
|
| 71 |
+
f"results7/{model_name}_neg_True/best_model.pt",
|
| 72 |
+
f"results/{model_name}_neg_True/best_model.pt",
|
| 73 |
+
f"results7/{model_name}_neg_False/best_model.pt",
|
| 74 |
+
f"results/{model_name}_neg_False/best_model.pt",
|
| 75 |
+
"ablation_results/T4/best_model.pt",
|
| 76 |
+
"ablation_results/T8/best_model.pt",
|
| 77 |
+
]
|
| 78 |
+
ckpt = ckpt_override if ckpt_override else next((p for p in candidates if os.path.exists(p)), None)
|
| 79 |
+
if not os.path.exists(ckpt):
|
| 80 |
+
raise FileNotFoundError(f"No checkpoint found. Checked: {candidates}")
|
| 81 |
+
model, cfg = load_model(ckpt, cfg, device)
|
| 82 |
+
model.eval()
|
| 83 |
+
src_tok = SanskritSourceTokenizer(
|
| 84 |
+
vocab_size=cfg['model'].get('src_vocab_size', 500),
|
| 85 |
+
max_len=cfg['model']['max_seq_len'])
|
| 86 |
+
tgt_tok = SanskritTargetTokenizer(
|
| 87 |
+
vocab_size=cfg['model'].get('tgt_vocab_size', 500),
|
| 88 |
+
max_len=cfg['model']['max_seq_len'])
|
| 89 |
+
return model, src_tok, tgt_tok, cfg
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def load_val_data(cfg, src_tok, tgt_tok, n=500):
|
| 93 |
+
"""Load validation set as (src_tensors, ref_strings, input_strings)."""
|
| 94 |
+
from data.dataset import OptimizedSanskritDataset
|
| 95 |
+
from torch.utils.data import Subset
|
| 96 |
+
from sklearn.model_selection import train_test_split
|
| 97 |
+
|
| 98 |
+
dataset = OptimizedSanskritDataset(
|
| 99 |
+
'train', max_len=cfg['model']['max_seq_len'],
|
| 100 |
+
cfg=cfg, src_tokenizer=src_tok, tgt_tokenizer=tgt_tok)
|
| 101 |
+
total = min(cfg['data']['dataset_size'], len(dataset))
|
| 102 |
+
_, val_idx = train_test_split(list(range(total)), train_size=0.8, random_state=42)
|
| 103 |
+
val_idx = val_idx[:n]
|
| 104 |
+
|
| 105 |
+
src_list, ref_list, inp_list = [], [], []
|
| 106 |
+
for i in val_idx:
|
| 107 |
+
item = dataset[i]
|
| 108 |
+
src_list.append(item['input_ids'].unsqueeze(0))
|
| 109 |
+
ref_list.append(item['target_text'])
|
| 110 |
+
inp_list.append(item['input_text'])
|
| 111 |
+
return src_list, ref_list, inp_list
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
# ── Task 1 ────────────────────────────────────────────────────────────
|
| 115 |
+
|
| 116 |
+
def run_task1(model, src_tok, device):
|
| 117 |
+
print("\n" + "="*65)
|
| 118 |
+
print(" TASK 1 — KV Cache Benchmark")
|
| 119 |
+
print("="*65)
|
| 120 |
+
if not hasattr(model.model, 'generate_cached'):
|
| 121 |
+
print(" SKIP: not D3PMCrossAttention.")
|
| 122 |
+
return
|
| 123 |
+
from analysis.kv_cache_benchmark import run_benchmark, print_summary
|
| 124 |
+
results = run_benchmark(model, src_tok, device, src_lens=[16, 32, 64])
|
| 125 |
+
print_summary(results)
|
| 126 |
+
path = os.path.join(OUTPUT_DIR, "task1_kv_cache.txt")
|
| 127 |
+
with open(path, "w") as f:
|
| 128 |
+
f.write("TASK 1 — KV CACHE BENCHMARK\n" + "="*40 + "\n\n")
|
| 129 |
+
f.write(f"{'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
|
| 130 |
+
f"{'speedup':>8} {'encoder%':>9}\n")
|
| 131 |
+
for src_len, r in results.items():
|
| 132 |
+
f.write(f"{src_len:>8} {r['standard_s']:>12.3f} {r['cached_s']:>10.3f} "
|
| 133 |
+
f"{r['speedup']:>7.2f}x {r['encoder_pct']:>8.1f}%\n")
|
| 134 |
+
print(f" Saved: {path}")
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ── Task 2 ────────────────────────────────────────────────────────────
|
| 138 |
+
|
| 139 |
+
def run_task2(model, src_tok, tgt_tok, device, input_text):
|
| 140 |
+
print("\n" + "="*65)
|
| 141 |
+
print(" TASK 2 — Attention Visualization + Semantic Drift")
|
| 142 |
+
print("="*65)
|
| 143 |
+
print(f" Input: {input_text}")
|
| 144 |
+
if not hasattr(model.model, 'encode_source'):
|
| 145 |
+
print(" SKIP: not D3PMCrossAttention.")
|
| 146 |
+
return
|
| 147 |
+
|
| 148 |
+
src_ids = src_tok.encode(input_text)
|
| 149 |
+
src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)
|
| 150 |
+
src_chars = list(input_text.strip())
|
| 151 |
+
|
| 152 |
+
from analysis.attention_viz import (AttentionCapture, plot_attn_heatmap,
|
| 153 |
+
plot_attn_evolution, plot_all_layers)
|
| 154 |
+
from analysis.semantic_drift import (capture_intermediate_outputs,
|
| 155 |
+
compute_drift, compute_token_stability,
|
| 156 |
+
plot_drift_curve)
|
| 157 |
+
|
| 158 |
+
# Attention capture
|
| 159 |
+
print(" Capturing attention weights...")
|
| 160 |
+
capturer = AttentionCapture(model)
|
| 161 |
+
step_weights = capturer.capture(src_tensor, capture_every=10)
|
| 162 |
+
|
| 163 |
+
with torch.no_grad():
|
| 164 |
+
out_ids = model.generate_cached(src_tensor)
|
| 165 |
+
tgt_ids = [x for x in out_ids[0].tolist() if x > 4]
|
| 166 |
+
tgt_text = tgt_tok.decode(tgt_ids).strip()
|
| 167 |
+
tgt_chars = list(tgt_text)
|
| 168 |
+
print(f" Output: {tgt_text}")
|
| 169 |
+
|
| 170 |
+
first_t = max(step_weights.keys())
|
| 171 |
+
plot_attn_heatmap(step_weights, t_val=first_t, layer=0,
|
| 172 |
+
src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
|
| 173 |
+
save_path=os.path.join(OUTPUT_DIR, f"task2_attn_t{first_t}.png"),
|
| 174 |
+
title=f"Attention t={first_t} (noisy) Layer 0")
|
| 175 |
+
plot_attn_heatmap(step_weights, t_val=0, layer=0,
|
| 176 |
+
src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
|
| 177 |
+
save_path=os.path.join(OUTPUT_DIR, "task2_attn_t0.png"),
|
| 178 |
+
title="Attention t=0 (final) Layer 0")
|
| 179 |
+
plot_all_layers(step_weights, t_val=0,
|
| 180 |
+
src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
|
| 181 |
+
save_path=os.path.join(OUTPUT_DIR, "task2_all_layers_t0.png"))
|
| 182 |
+
if len(src_chars) > 0 and len(tgt_chars) > 0:
|
| 183 |
+
plot_attn_evolution(step_weights, src_token_idx=0, tgt_token_idx=0,
|
| 184 |
+
layer=0, src_token_str=src_chars[0], tgt_token_str=tgt_chars[0],
|
| 185 |
+
save_path=os.path.join(OUTPUT_DIR, "task2_attn_evolution.png"))
|
| 186 |
+
|
| 187 |
+
# Semantic drift
|
| 188 |
+
print(" Computing semantic drift...")
|
| 189 |
+
step_outputs, final_out = capture_intermediate_outputs(
|
| 190 |
+
model, src_tensor, tgt_tok, capture_every=5)
|
| 191 |
+
drift = compute_drift(step_outputs, final_out)
|
| 192 |
+
stab = compute_token_stability(step_outputs, final_out, tgt_tok)
|
| 193 |
+
plot_drift_curve(drift, src_text=input_text,
|
| 194 |
+
save_path=os.path.join(OUTPUT_DIR, "task2_semantic_drift.png"))
|
| 195 |
+
|
| 196 |
+
print(f" Lock-in timestep: t={drift['lock_in_t']}")
|
| 197 |
+
print(f" Mean position lock-in: t={stab['mean_lock_t']:.1f} ± {stab['std_lock_t']:.1f}")
|
| 198 |
+
|
| 199 |
+
report = os.path.join(OUTPUT_DIR, "task2_report.txt")
|
| 200 |
+
with open(report, "w", encoding="utf-8") as f:
|
| 201 |
+
f.write("TASK 2 — ATTENTION + DRIFT REPORT\n" + "="*50 + "\n\n")
|
| 202 |
+
f.write(f"Input : {input_text}\nOutput : {final_out}\n\n")
|
| 203 |
+
f.write(f"Lock-in t : {drift['lock_in_t']}\n")
|
| 204 |
+
f.write(f"Mean pos lock-in : {stab['mean_lock_t']:.1f} ± {stab['std_lock_t']:.1f}\n\n")
|
| 205 |
+
f.write("Step → Output → CER-to-final\n" + "-"*60 + "\n")
|
| 206 |
+
for tv, cer in zip(drift["t_vals"], drift["cer_to_final"]):
|
| 207 |
+
f.write(f" t={tv:4d} | {step_outputs.get(tv,'')[:40]:40s} | {cer:.4f}\n")
|
| 208 |
+
print(f" Report: {report}")
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# ── Task 3 ────────────────────────────────────────────────────────────
|
| 212 |
+
|
| 213 |
+
def run_task3(model, src_tok, tgt_tok, device, src_list, ref_list):
|
| 214 |
+
print("\n" + "="*65)
|
| 215 |
+
print(" TASK 3 — Concept Vectors + PCA Steering")
|
| 216 |
+
print("="*65)
|
| 217 |
+
if not hasattr(model.model, 'encode_source'):
|
| 218 |
+
print(" SKIP: not D3PMCrossAttention.")
|
| 219 |
+
return
|
| 220 |
+
|
| 221 |
+
from analysis.concept_vectors import (collect_hidden_states, fit_pca,
|
| 222 |
+
find_diversity_direction, generate_diversity_spectrum, plot_pca_space)
|
| 223 |
+
|
| 224 |
+
# Collect hidden states from val set
|
| 225 |
+
n = min(500, len(src_list))
|
| 226 |
+
print(f" Collecting hidden states from {n} examples...")
|
| 227 |
+
hidden, _ = collect_hidden_states(
|
| 228 |
+
model, src_list[:n], t_capture=0, max_samples=n)
|
| 229 |
+
|
| 230 |
+
# Compute output lengths for diversity direction
|
| 231 |
+
lengths = []
|
| 232 |
+
for src in src_list[:n]:
|
| 233 |
+
with torch.no_grad():
|
| 234 |
+
out = model.generate_cached(src.to(device))
|
| 235 |
+
ids = [x for x in out[0].tolist() if x > 4]
|
| 236 |
+
lengths.append(len(tgt_tok.decode(ids)))
|
| 237 |
+
|
| 238 |
+
# Fit PCA + find diversity direction
|
| 239 |
+
pca = fit_pca(hidden, n_components=min(50, n-1))
|
| 240 |
+
direction, best_pc, corr = find_diversity_direction(hidden, lengths, pca)
|
| 241 |
+
|
| 242 |
+
# Plot concept space
|
| 243 |
+
plot_pca_space(hidden, lengths, pca, best_pc,
|
| 244 |
+
save_path=os.path.join(OUTPUT_DIR, "task3_concept_space.png"))
|
| 245 |
+
|
| 246 |
+
# Generate diversity spectrum for first example
|
| 247 |
+
print("\n Diversity spectrum for first example:")
|
| 248 |
+
src0 = src_list[0]
|
| 249 |
+
inp0 = src_tok.decode([x for x in src0[0].tolist() if x > 4])
|
| 250 |
+
print(f" Input: {inp0}")
|
| 251 |
+
spectrum = generate_diversity_spectrum(
|
| 252 |
+
model, src0.to(device), direction, tgt_tok,
|
| 253 |
+
alphas=[-2.0, -1.0, 0.0, 1.0, 2.0])
|
| 254 |
+
|
| 255 |
+
# Save diversity direction + results
|
| 256 |
+
np.save(os.path.join(OUTPUT_DIR, "task3_diversity_direction.npy"), direction)
|
| 257 |
+
|
| 258 |
+
report = os.path.join(OUTPUT_DIR, "task3_report.txt")
|
| 259 |
+
with open(report, "w", encoding="utf-8") as f:
|
| 260 |
+
f.write("TASK 3 — CONCEPT VECTORS + PCA STEERING\n" + "="*50 + "\n\n")
|
| 261 |
+
f.write(f"PCA: {pca.n_components_} components, "
|
| 262 |
+
f"{pca.explained_variance_ratio_.sum()*100:.1f}% variance\n")
|
| 263 |
+
f.write(f"Diversity PC: {best_pc} (|r|={corr:.3f} with output length)\n\n")
|
| 264 |
+
f.write("Diversity spectrum:\n")
|
| 265 |
+
for alpha, text in sorted(spectrum.items()):
|
| 266 |
+
f.write(f" alpha={alpha:+.1f} → {text}\n")
|
| 267 |
+
print(f" Report: {report}")
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# ── Task 4 ────────────────────────────────────────────────────────────
|
| 271 |
+
|
| 272 |
+
def run_task4(phase, model, src_tok, tgt_tok, device, cfg,
|
| 273 |
+
src_list, ref_list):
|
| 274 |
+
print("\n" + "="*65)
|
| 275 |
+
print(f" TASK 4 — Step Ablation (phase={phase})")
|
| 276 |
+
print("="*65)
|
| 277 |
+
|
| 278 |
+
from analysis.step_ablation import (generate_ablation_configs,
|
| 279 |
+
run_ablation_analysis, plot_ablation_3d, run_adversarial_test)
|
| 280 |
+
|
| 281 |
+
if phase == "generate_configs":
|
| 282 |
+
print(" Generating ablation configs...")
|
| 283 |
+
generate_ablation_configs(output_dir="ablation_configs")
|
| 284 |
+
print("\n NEXT STEPS:")
|
| 285 |
+
print(" 1. bash ablation_configs/train_all.sh")
|
| 286 |
+
print(" 2. python analysis/run_analysis.py --task 4 --phase analyze")
|
| 287 |
+
|
| 288 |
+
elif phase == "analyze":
|
| 289 |
+
# Check which models exist
|
| 290 |
+
existing = [T for T in [4, 8, 16, 32, 64]
|
| 291 |
+
if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
|
| 292 |
+
if not existing:
|
| 293 |
+
print(" No ablation models found at ablation_results/T*/best_model.pt")
|
| 294 |
+
print(" Run: python analysis/run_analysis.py --task 4 --phase generate_configs")
|
| 295 |
+
print(" Then: bash ablation_configs/train_all.sh")
|
| 296 |
+
return
|
| 297 |
+
|
| 298 |
+
print(f" Found models for T={existing}")
|
| 299 |
+
results = run_ablation_analysis(
|
| 300 |
+
ablation_dir="ablation_results", base_cfg=cfg,
|
| 301 |
+
src_list=src_list[:200], ref_list=ref_list[:200],
|
| 302 |
+
tgt_tokenizer=tgt_tok, device=device,
|
| 303 |
+
output_dir=OUTPUT_DIR)
|
| 304 |
+
plot_ablation_3d(results,
|
| 305 |
+
save_path=os.path.join(OUTPUT_DIR, "task4_ablation_3d.png"))
|
| 306 |
+
|
| 307 |
+
# Adversarial robustness always runs on existing model (no retraining)
|
| 308 |
+
print("\n Running adversarial robustness test...")
|
| 309 |
+
inp_texts = [src_tok.decode([x for x in s[0].tolist() if x > 4])
|
| 310 |
+
for s in src_list[:50]]
|
| 311 |
+
run_adversarial_test(
|
| 312 |
+
model, src_tok, tgt_tok,
|
| 313 |
+
test_inputs=inp_texts, test_refs=ref_list[:50],
|
| 314 |
+
device=device, output_dir=OUTPUT_DIR)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# ── Task 5 ────────────────────────────────────────────────────────────
|
| 318 |
+
|
| 319 |
+
def run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list):
|
| 320 |
+
print("\n" + "="*65)
|
| 321 |
+
print(" TASK 5 — Classifier-Free Guidance")
|
| 322 |
+
print("="*65)
|
| 323 |
+
if not hasattr(model.model, 'encode_source'):
|
| 324 |
+
print(" SKIP: not D3PMCrossAttention.")
|
| 325 |
+
return
|
| 326 |
+
|
| 327 |
+
from analysis.quality_classifier import (
|
| 328 |
+
QualityClassifier, collect_quality_data,
|
| 329 |
+
train_quality_classifier, sweep_guidance_scales)
|
| 330 |
+
|
| 331 |
+
clf_path = os.path.join(OUTPUT_DIR, "task5_quality_classifier.pt")
|
| 332 |
+
d_model = cfg['model']['d_model']
|
| 333 |
+
|
| 334 |
+
# Step 1: collect or load training data
|
| 335 |
+
data_path = os.path.join(OUTPUT_DIR, "task5_quality_data.npz")
|
| 336 |
+
if os.path.exists(data_path):
|
| 337 |
+
print(" Loading cached quality data...")
|
| 338 |
+
data = np.load(data_path)
|
| 339 |
+
hidden = data["hidden"]
|
| 340 |
+
quality = data["quality"]
|
| 341 |
+
else:
|
| 342 |
+
print(" Collecting quality data (this takes a few minutes)...")
|
| 343 |
+
n = min(2000, len(src_list))
|
| 344 |
+
hidden, quality = collect_quality_data(
|
| 345 |
+
model, src_list[:n], ref_list[:n], tgt_tok,
|
| 346 |
+
t_capture=0, max_samples=n)
|
| 347 |
+
np.savez(data_path, hidden=hidden, quality=quality)
|
| 348 |
+
print(f" Saved quality data: {data_path}")
|
| 349 |
+
|
| 350 |
+
# Step 2: train or load classifier
|
| 351 |
+
if os.path.exists(clf_path):
|
| 352 |
+
print(f" Loading cached classifier: {clf_path}")
|
| 353 |
+
clf = QualityClassifier(d_model)
|
| 354 |
+
clf.load_state_dict(torch.load(clf_path, map_location='cpu'))
|
| 355 |
+
clf.eval()
|
| 356 |
+
else:
|
| 357 |
+
print(" Training quality classifier...")
|
| 358 |
+
clf = train_quality_classifier(
|
| 359 |
+
hidden, quality, d_model=d_model,
|
| 360 |
+
epochs=30, batch_size=64, lr=1e-3,
|
| 361 |
+
save_path=clf_path)
|
| 362 |
+
clf.eval()
|
| 363 |
+
|
| 364 |
+
# Step 3: guidance scale sweep
|
| 365 |
+
print("\n Guidance scale sweep (λ ∈ {0.0, 0.5, 1.0, 1.5, 2.0, 3.0})...")
|
| 366 |
+
n_sweep = min(50, len(src_list))
|
| 367 |
+
results = sweep_guidance_scales(
|
| 368 |
+
model, clf, src_list[:n_sweep], ref_list[:n_sweep],
|
| 369 |
+
tgt_tok, scales=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
|
| 370 |
+
n_samples=n_sweep, device=device, output_dir=OUTPUT_DIR)
|
| 371 |
+
|
| 372 |
+
# Find optimal scale
|
| 373 |
+
best_scale = min(results, key=lambda s: results[s]["mean_cer"])
|
| 374 |
+
print(f"\n Optimal guidance scale: λ={best_scale:.1f} "
|
| 375 |
+
f"CER={results[best_scale]['mean_cer']:.4f}")
|
| 376 |
+
|
| 377 |
+
report = os.path.join(OUTPUT_DIR, "task5_report.txt")
|
| 378 |
+
with open(report, "w") as f:
|
| 379 |
+
f.write("TASK 5 — CLASSIFIER-FREE GUIDANCE\n" + "="*50 + "\n\n")
|
| 380 |
+
f.write(f"Classifier params: {sum(p.numel() for p in clf.parameters())}\n")
|
| 381 |
+
f.write(f"Training samples : {len(hidden)}\n\n")
|
| 382 |
+
f.write("Guidance scale sweep:\n")
|
| 383 |
+
f.write(f" {'λ':>6} {'CER':>8} {'diversity':>10}\n")
|
| 384 |
+
f.write(" " + "-"*28 + "\n")
|
| 385 |
+
for s in sorted(results.keys()):
|
| 386 |
+
r = results[s]
|
| 387 |
+
marker = " ← optimal" if s == best_scale else ""
|
| 388 |
+
f.write(f" {s:>6.1f} {r['mean_cer']:>8.4f} {r['diversity']:>10.3f}{marker}\n")
|
| 389 |
+
print(f" Report: {report}")
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# ── Main ──────────────────────────────────────────────────────────────
|
| 393 |
+
|
| 394 |
+
def main():
|
| 395 |
+
global OUTPUT_DIR
|
| 396 |
+
|
| 397 |
+
parser = argparse.ArgumentParser()
|
| 398 |
+
parser.add_argument("--task",
|
| 399 |
+
choices=["1","2","3","4","5","all"], default="all")
|
| 400 |
+
parser.add_argument("--input",
|
| 401 |
+
default="dharmo rakṣati rakṣitaḥ",
|
| 402 |
+
help="IAST input text for Task 2")
|
| 403 |
+
parser.add_argument("--phase",
|
| 404 |
+
choices=["generate_configs", "analyze"], default="analyze",
|
| 405 |
+
help="Task 4 phase: generate_configs (before training) or analyze (after)")
|
| 406 |
+
parser.add_argument("--checkpoint", default=None,
|
| 407 |
+
help="Optional explicit checkpoint path")
|
| 408 |
+
parser.add_argument("--output_dir", default="analysis/outputs",
|
| 409 |
+
help="Output directory for reports/figures")
|
| 410 |
+
args = parser.parse_args()
|
| 411 |
+
|
| 412 |
+
OUTPUT_DIR = args.output_dir
|
| 413 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
| 414 |
+
|
| 415 |
+
cfg = copy.deepcopy(CONFIG)
|
| 416 |
+
if args.checkpoint:
|
| 417 |
+
cfg["model_type"] = infer_model_type_from_checkpoint(args.checkpoint)
|
| 418 |
+
cfg["data"]["include_negative_examples"] = infer_include_negative_from_checkpoint(args.checkpoint)
|
| 419 |
+
ckpt_name = os.path.basename(os.path.dirname(args.checkpoint))
|
| 420 |
+
if ckpt_name.startswith("T") and ckpt_name[1:].isdigit():
|
| 421 |
+
t_val = int(ckpt_name[1:])
|
| 422 |
+
cfg["model"]["diffusion_steps"] = t_val
|
| 423 |
+
cfg["inference"]["num_steps"] = t_val
|
| 424 |
+
|
| 425 |
+
requested = cfg["training"]["device"]
|
| 426 |
+
if requested == "mps" and not torch.backends.mps.is_available():
|
| 427 |
+
requested = "cpu"
|
| 428 |
+
elif requested == "cuda" and not torch.cuda.is_available():
|
| 429 |
+
requested = "cpu"
|
| 430 |
+
cfg["training"]["device"] = requested
|
| 431 |
+
device = torch.device(requested)
|
| 432 |
+
|
| 433 |
+
print("Loading model and tokenizers...")
|
| 434 |
+
model, src_tok, tgt_tok, cfg = load_everything(cfg, device, ckpt_override=args.checkpoint)
|
| 435 |
+
|
| 436 |
+
# Load val data for tasks that need it (Tasks 3, 4, 5)
|
| 437 |
+
needs_data = args.task in ("3", "4", "5", "all")
|
| 438 |
+
if needs_data:
|
| 439 |
+
print("Loading validation data...")
|
| 440 |
+
src_list, ref_list, inp_list = load_val_data(cfg, src_tok, tgt_tok, n=500)
|
| 441 |
+
else:
|
| 442 |
+
src_list, ref_list, inp_list = [], [], []
|
| 443 |
+
|
| 444 |
+
tasks = (["1","2","3","4","5"] if args.task == "all"
|
| 445 |
+
else [args.task])
|
| 446 |
+
|
| 447 |
+
for task in tasks:
|
| 448 |
+
if task == "1":
|
| 449 |
+
run_task1(model, src_tok, device)
|
| 450 |
+
elif task == "2":
|
| 451 |
+
run_task2(model, src_tok, tgt_tok, device, args.input)
|
| 452 |
+
elif task == "3":
|
| 453 |
+
run_task3(model, src_tok, tgt_tok, device, src_list, ref_list)
|
| 454 |
+
elif task == "4":
|
| 455 |
+
run_task4(args.phase, model, src_tok, tgt_tok, device, cfg,
|
| 456 |
+
src_list, ref_list)
|
| 457 |
+
elif task == "5":
|
| 458 |
+
run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list)
|
| 459 |
+
|
| 460 |
+
print(f"\n{'='*65}")
|
| 461 |
+
print(f" All outputs saved to: {OUTPUT_DIR}/")
|
| 462 |
+
print("="*65)
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
if __name__ == "__main__":
|
| 466 |
+
main()
|
analysis/run_tasks_except4_all_models.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Run Tasks 1,2,3,5 for every available checkpoint (excluding Task 4).
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python analysis/run_tasks_except4_all_models.py
|
| 6 |
+
python analysis/run_tasks_except4_all_models.py --input "dharmo rakṣati rakṣitaḥ"
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import argparse
|
| 12 |
+
import json
|
| 13 |
+
import os
|
| 14 |
+
import subprocess
|
| 15 |
+
import sys
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
ROOT = Path(__file__).resolve().parents[1]
|
| 21 |
+
DEFAULT_OUT_ROOT = ROOT / "analysis" / "outputs_multi"
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def discover_checkpoints() -> list[Path]:
|
| 25 |
+
roots = [ROOT / "results", ROOT / "results7", ROOT / "ablation_results"]
|
| 26 |
+
out: list[Path] = []
|
| 27 |
+
for base in roots:
|
| 28 |
+
if not base.exists():
|
| 29 |
+
continue
|
| 30 |
+
for ckpt in sorted(base.glob("*/best_model.pt")):
|
| 31 |
+
out.append(ckpt)
|
| 32 |
+
return out
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def slug_for_checkpoint(ckpt: Path) -> str:
|
| 36 |
+
root = ckpt.parent.parent.name
|
| 37 |
+
exp = ckpt.parent.name
|
| 38 |
+
return f"{root}__{exp}"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def run_task(task: str, ckpt: Path, input_text: str, out_dir: Path) -> tuple[int, float]:
|
| 42 |
+
cmd = [
|
| 43 |
+
sys.executable,
|
| 44 |
+
str(ROOT / "analysis" / "run_analysis.py"),
|
| 45 |
+
"--task", task,
|
| 46 |
+
"--checkpoint", str(ckpt),
|
| 47 |
+
"--output_dir", str(out_dir),
|
| 48 |
+
]
|
| 49 |
+
if task == "2":
|
| 50 |
+
cmd.extend(["--input", input_text])
|
| 51 |
+
|
| 52 |
+
start = datetime.now()
|
| 53 |
+
env = os.environ.copy()
|
| 54 |
+
env.setdefault("HF_HOME", "/tmp/hf_home")
|
| 55 |
+
env.setdefault("HF_DATASETS_CACHE", "/tmp/hf_datasets")
|
| 56 |
+
env.setdefault("HF_HUB_CACHE", "/tmp/hf_hub")
|
| 57 |
+
env.setdefault("TRANSFORMERS_CACHE", "/tmp/hf_transformers")
|
| 58 |
+
os.makedirs(env["HF_HOME"], exist_ok=True)
|
| 59 |
+
os.makedirs(env["HF_DATASETS_CACHE"], exist_ok=True)
|
| 60 |
+
os.makedirs(env["HF_HUB_CACHE"], exist_ok=True)
|
| 61 |
+
os.makedirs(env["TRANSFORMERS_CACHE"], exist_ok=True)
|
| 62 |
+
|
| 63 |
+
proc = subprocess.run(cmd, cwd=str(ROOT), env=env)
|
| 64 |
+
seconds = (datetime.now() - start).total_seconds()
|
| 65 |
+
return proc.returncode, seconds
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def main() -> None:
|
| 69 |
+
parser = argparse.ArgumentParser()
|
| 70 |
+
parser.add_argument("--input", default="dharmo rakṣati rakṣitaḥ")
|
| 71 |
+
parser.add_argument("--out_root", default=str(DEFAULT_OUT_ROOT))
|
| 72 |
+
args = parser.parse_args()
|
| 73 |
+
|
| 74 |
+
checkpoints = discover_checkpoints()
|
| 75 |
+
if not checkpoints:
|
| 76 |
+
raise FileNotFoundError("No checkpoints found under results/results7/ablation_results.")
|
| 77 |
+
|
| 78 |
+
out_root = Path(args.out_root)
|
| 79 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
| 80 |
+
|
| 81 |
+
tasks = ["1", "2", "3", "5"]
|
| 82 |
+
summary = {
|
| 83 |
+
"timestamp": datetime.now().isoformat(timespec="seconds"),
|
| 84 |
+
"tasks": tasks,
|
| 85 |
+
"checkpoints": [],
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
for ckpt in checkpoints:
|
| 89 |
+
slug = slug_for_checkpoint(ckpt)
|
| 90 |
+
model_out = out_root / slug
|
| 91 |
+
model_out.mkdir(parents=True, exist_ok=True)
|
| 92 |
+
print(f"\n=== Checkpoint: {ckpt} ===")
|
| 93 |
+
model_item = {
|
| 94 |
+
"checkpoint": str(ckpt),
|
| 95 |
+
"output_dir": str(model_out),
|
| 96 |
+
"tasks": [],
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
for task in tasks:
|
| 100 |
+
task_out = model_out / f"task{task}"
|
| 101 |
+
task_out.mkdir(parents=True, exist_ok=True)
|
| 102 |
+
print(f"-> Running task {task} ...")
|
| 103 |
+
code, sec = run_task(task, ckpt, args.input, task_out)
|
| 104 |
+
item = {
|
| 105 |
+
"task": task,
|
| 106 |
+
"exit_code": code,
|
| 107 |
+
"seconds": round(sec, 2),
|
| 108 |
+
"output_dir": str(task_out),
|
| 109 |
+
}
|
| 110 |
+
model_item["tasks"].append(item)
|
| 111 |
+
status = "OK" if code == 0 else "FAILED"
|
| 112 |
+
print(f" {status} ({sec:.1f}s)")
|
| 113 |
+
|
| 114 |
+
summary["checkpoints"].append(model_item)
|
| 115 |
+
|
| 116 |
+
summary_path = out_root / "summary.json"
|
| 117 |
+
with summary_path.open("w", encoding="utf-8") as f:
|
| 118 |
+
json.dump(summary, f, ensure_ascii=False, indent=2)
|
| 119 |
+
print(f"\nSaved summary: {summary_path}")
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
if __name__ == "__main__":
|
| 123 |
+
main()
|
analysis/semantic_drift.py
ADDED
|
@@ -0,0 +1,569 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# """
|
| 2 |
+
# analysis/semantic_drift.py
|
| 3 |
+
# ===========================
|
| 4 |
+
# Task 2: Semantic drift metric — how much does the intermediate generation
|
| 5 |
+
# diverge from the final output as we walk through diffusion steps T → 0?
|
| 6 |
+
#
|
| 7 |
+
# Metric: CER between x0_estimate at each step vs the final x0 at t=0.
|
| 8 |
+
#
|
| 9 |
+
# A well-trained model should show:
|
| 10 |
+
# - High drift at t=T-1 (near-random initial estimate)
|
| 11 |
+
# - Rapid decrease in drift around t=T//2 (model finds the right structure)
|
| 12 |
+
# - Near-zero drift at t=10 (output is stable, only fine corrections remain)
|
| 13 |
+
#
|
| 14 |
+
# If drift stays high until t=5 then suddenly collapses → model is doing all
|
| 15 |
+
# its work in the last few steps → consider reducing T.
|
| 16 |
+
#
|
| 17 |
+
# Also measures:
|
| 18 |
+
# - Token stability: fraction of positions that don't change between steps
|
| 19 |
+
# - Lock-in time: first step where each position "commits" to its final token
|
| 20 |
+
#
|
| 21 |
+
# No retraining required. Uses generate_cached() with intermediate snapshots.
|
| 22 |
+
# """
|
| 23 |
+
#
|
| 24 |
+
# import torch
|
| 25 |
+
# import torch.nn.functional as F
|
| 26 |
+
# import numpy as np
|
| 27 |
+
# from typing import List, Dict, Optional, Tuple
|
| 28 |
+
#
|
| 29 |
+
#
|
| 30 |
+
# def compute_cer_between(pred: str, ref: str) -> float:
|
| 31 |
+
# """CER between two strings."""
|
| 32 |
+
# if not ref:
|
| 33 |
+
# return 1.0 if pred else 0.0
|
| 34 |
+
#
|
| 35 |
+
# def edit_distance(s1, s2):
|
| 36 |
+
# m, n = len(s1), len(s2)
|
| 37 |
+
# dp = list(range(n + 1))
|
| 38 |
+
# for i in range(1, m + 1):
|
| 39 |
+
# prev, dp[0] = dp[0], i
|
| 40 |
+
# for j in range(1, n + 1):
|
| 41 |
+
# temp = dp[j]
|
| 42 |
+
# dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
|
| 43 |
+
# prev = temp
|
| 44 |
+
# return dp[n]
|
| 45 |
+
#
|
| 46 |
+
# return edit_distance(pred, ref) / len(ref)
|
| 47 |
+
#
|
| 48 |
+
#
|
| 49 |
+
# @torch.no_grad()
|
| 50 |
+
# def capture_intermediate_outputs(
|
| 51 |
+
# model,
|
| 52 |
+
# src: torch.Tensor,
|
| 53 |
+
# tgt_tokenizer,
|
| 54 |
+
# capture_every: int = 5,
|
| 55 |
+
# temperature: float = 0.8,
|
| 56 |
+
# top_k: int = 40,
|
| 57 |
+
# ) -> Tuple[Dict[int, str], str]:
|
| 58 |
+
# """
|
| 59 |
+
# Run generation while recording the decoded x0_estimate at every
|
| 60 |
+
# `capture_every` diffusion steps.
|
| 61 |
+
#
|
| 62 |
+
# Args:
|
| 63 |
+
# model : SanskritModel (D3PMCrossAttention)
|
| 64 |
+
# src : [1, src_len] IAST token ids (single sample)
|
| 65 |
+
# tgt_tokenizer : SanskritTargetTokenizer for decoding intermediate outputs
|
| 66 |
+
# capture_every : record every N steps
|
| 67 |
+
# temperature : sampling temperature
|
| 68 |
+
# top_k : top-k filter
|
| 69 |
+
#
|
| 70 |
+
# Returns:
|
| 71 |
+
# step_outputs : dict mapping t_val → decoded Devanagari string at that step
|
| 72 |
+
# final_output : decoded string at t=0 (final result)
|
| 73 |
+
# """
|
| 74 |
+
# if src.dim() == 1:
|
| 75 |
+
# src = src.unsqueeze(0)
|
| 76 |
+
#
|
| 77 |
+
# inner = model.model
|
| 78 |
+
# T = inner.scheduler.num_timesteps
|
| 79 |
+
# device = src.device
|
| 80 |
+
#
|
| 81 |
+
# # Encode source once (KV cache)
|
| 82 |
+
# memory, src_pad_mask = inner.encode_source(src)
|
| 83 |
+
#
|
| 84 |
+
# B = src.shape[0]
|
| 85 |
+
# tgt_len = inner.max_seq_len
|
| 86 |
+
# mask_id = inner.mask_token_id
|
| 87 |
+
#
|
| 88 |
+
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
|
| 89 |
+
# hint = None
|
| 90 |
+
#
|
| 91 |
+
# step_outputs: Dict[int, str] = {}
|
| 92 |
+
# inner.eval()
|
| 93 |
+
#
|
| 94 |
+
# for t_val in range(T - 1, -1, -1):
|
| 95 |
+
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
|
| 96 |
+
# is_last = (t_val == 0)
|
| 97 |
+
#
|
| 98 |
+
# logits, _ = inner.forward_cached(
|
| 99 |
+
# memory, src_pad_mask, x0_est, t,
|
| 100 |
+
# x0_hint=hint, inference_mode=True,
|
| 101 |
+
# )
|
| 102 |
+
#
|
| 103 |
+
# logits = logits / max(temperature, 1e-8)
|
| 104 |
+
# if top_k > 0:
|
| 105 |
+
# V = logits.shape[-1]
|
| 106 |
+
# if top_k < V:
|
| 107 |
+
# topk_vals, _ = torch.topk(logits, top_k, dim=-1)
|
| 108 |
+
# threshold = topk_vals[..., -1].unsqueeze(-1)
|
| 109 |
+
# logits = logits.masked_fill(logits < threshold, float('-inf'))
|
| 110 |
+
#
|
| 111 |
+
# probs = F.softmax(logits, dim=-1)
|
| 112 |
+
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
|
| 113 |
+
# hint = x0_est
|
| 114 |
+
#
|
| 115 |
+
# # Capture at this step
|
| 116 |
+
# if (T - 1 - t_val) % capture_every == 0 or is_last:
|
| 117 |
+
# ids = [x for x in x0_est[0].tolist() if x > 4]
|
| 118 |
+
# text = tgt_tokenizer.decode(ids).strip()
|
| 119 |
+
# step_outputs[t_val] = text
|
| 120 |
+
#
|
| 121 |
+
# final_output = step_outputs.get(0, "")
|
| 122 |
+
# return step_outputs, final_output
|
| 123 |
+
#
|
| 124 |
+
#
|
| 125 |
+
# def _sample(probs):
|
| 126 |
+
# B, L, V = probs.shape
|
| 127 |
+
# flat = probs.view(B * L, V).clamp(min=1e-9)
|
| 128 |
+
# flat = flat / flat.sum(dim=-1, keepdim=True)
|
| 129 |
+
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
|
| 130 |
+
#
|
| 131 |
+
#
|
| 132 |
+
# def compute_drift(
|
| 133 |
+
# step_outputs: Dict[int, str],
|
| 134 |
+
# final_output: str,
|
| 135 |
+
# ) -> Dict[str, object]:
|
| 136 |
+
# """
|
| 137 |
+
# Compute drift metrics comparing each intermediate output to the final.
|
| 138 |
+
#
|
| 139 |
+
# Returns dict with:
|
| 140 |
+
# t_vals : list of captured timesteps (T-1 → 0)
|
| 141 |
+
# cer_to_final: CER between each step's output and the final output
|
| 142 |
+
# 0.0 = identical to final, 1.0 = completely different
|
| 143 |
+
# lock_in_t : first t_val where CER drops and stays below 0.1
|
| 144 |
+
# (step at which output "commits" to final form)
|
| 145 |
+
# """
|
| 146 |
+
# t_vals = sorted(step_outputs.keys(), reverse=True) # T-1 → 0
|
| 147 |
+
# cer_to_final = []
|
| 148 |
+
#
|
| 149 |
+
# for t_val in t_vals:
|
| 150 |
+
# cer = compute_cer_between(step_outputs[t_val], final_output)
|
| 151 |
+
# cer_to_final.append(cer)
|
| 152 |
+
#
|
| 153 |
+
# # Find lock-in: first step where CER stays below threshold for rest of run
|
| 154 |
+
# threshold = 0.1
|
| 155 |
+
# lock_in_t = 0 # default: never locked in early
|
| 156 |
+
# for i, (t_val, cer) in enumerate(zip(t_vals, cer_to_final)):
|
| 157 |
+
# if all(c <= threshold for c in cer_to_final[i:]):
|
| 158 |
+
# lock_in_t = t_val
|
| 159 |
+
# break
|
| 160 |
+
#
|
| 161 |
+
# return {
|
| 162 |
+
# "t_vals": t_vals,
|
| 163 |
+
# "cer_to_final": cer_to_final,
|
| 164 |
+
# "lock_in_t": lock_in_t,
|
| 165 |
+
# "final_output": final_output,
|
| 166 |
+
# }
|
| 167 |
+
#
|
| 168 |
+
#
|
| 169 |
+
# def compute_token_stability(
|
| 170 |
+
# step_outputs: Dict[int, str],
|
| 171 |
+
# final_output: str,
|
| 172 |
+
# tgt_tokenizer,
|
| 173 |
+
# ) -> Dict[str, object]:
|
| 174 |
+
# """
|
| 175 |
+
# Token-level stability: for each position, at which diffusion step
|
| 176 |
+
# does it first match its final token and stay matched?
|
| 177 |
+
#
|
| 178 |
+
# Returns:
|
| 179 |
+
# position_lock_times: list of t_val at which each position locks in
|
| 180 |
+
# mean_lock_t : average lock-in timestep across positions
|
| 181 |
+
# """
|
| 182 |
+
# T = max(step_outputs.keys())
|
| 183 |
+
# t_vals = sorted(step_outputs.keys(), reverse=True) # T-1 → 0
|
| 184 |
+
#
|
| 185 |
+
# # Encode all intermediate outputs and the final
|
| 186 |
+
# def encode(text):
|
| 187 |
+
# return tgt_tokenizer.encode(text)
|
| 188 |
+
#
|
| 189 |
+
# final_ids = encode(final_output)
|
| 190 |
+
# L = len(final_ids)
|
| 191 |
+
#
|
| 192 |
+
# # Build matrix: [n_steps, L]
|
| 193 |
+
# step_ids = []
|
| 194 |
+
# for t_val in t_vals:
|
| 195 |
+
# step_ids.append(encode(step_outputs.get(t_val, "")))
|
| 196 |
+
#
|
| 197 |
+
# # Pad all to same length
|
| 198 |
+
# max_len = max(len(s) for s in step_ids)
|
| 199 |
+
# step_ids = [s + [1] * (max_len - len(s)) for s in step_ids] # 1=PAD
|
| 200 |
+
# final_ids_padded = final_ids + [1] * (max_len - len(final_ids))
|
| 201 |
+
#
|
| 202 |
+
# step_arr = np.array(step_ids) # [n_steps, L]
|
| 203 |
+
# final_arr = np.array(final_ids_padded) # [L]
|
| 204 |
+
#
|
| 205 |
+
# # For each position: find first step index where it matches final
|
| 206 |
+
# # and stays matched for all subsequent steps
|
| 207 |
+
# position_lock_steps = []
|
| 208 |
+
# for pos in range(min(L, max_len)):
|
| 209 |
+
# col = step_arr[:, pos] # [n_steps]
|
| 210 |
+
# fin = final_arr[pos]
|
| 211 |
+
# locked_at = len(t_vals) - 1 # default: never locks early
|
| 212 |
+
# for i in range(len(t_vals)):
|
| 213 |
+
# if all(col[i:] == fin):
|
| 214 |
+
# locked_at = i
|
| 215 |
+
# break
|
| 216 |
+
# position_lock_steps.append(t_vals[locked_at] if locked_at < len(t_vals) else 0)
|
| 217 |
+
#
|
| 218 |
+
# return {
|
| 219 |
+
# "position_lock_times": position_lock_steps,
|
| 220 |
+
# "mean_lock_t": float(np.mean(position_lock_steps)),
|
| 221 |
+
# "std_lock_t": float(np.std(position_lock_steps)),
|
| 222 |
+
# }
|
| 223 |
+
#
|
| 224 |
+
#
|
| 225 |
+
# def plot_drift_curve(
|
| 226 |
+
# drift_result: Dict,
|
| 227 |
+
# src_text: str = "",
|
| 228 |
+
# save_path: Optional[str] = None,
|
| 229 |
+
# ):
|
| 230 |
+
# """
|
| 231 |
+
# Plot CER-to-final vs diffusion step.
|
| 232 |
+
# Shows where the model "commits" to the final output.
|
| 233 |
+
# """
|
| 234 |
+
# try:
|
| 235 |
+
# import matplotlib.pyplot as plt
|
| 236 |
+
# except ImportError:
|
| 237 |
+
# print("pip install matplotlib.")
|
| 238 |
+
# return
|
| 239 |
+
#
|
| 240 |
+
# t_vals = drift_result["t_vals"]
|
| 241 |
+
# cers = drift_result["cer_to_final"]
|
| 242 |
+
# lock_t = drift_result["lock_in_t"]
|
| 243 |
+
#
|
| 244 |
+
# fig, ax = plt.subplots(figsize=(12, 4))
|
| 245 |
+
# ax.plot(range(len(t_vals)), cers, linewidth=1.8, color='coral', label='CER to final')
|
| 246 |
+
# ax.fill_between(range(len(t_vals)), cers, alpha=0.15, color='coral')
|
| 247 |
+
#
|
| 248 |
+
# # Mark lock-in point
|
| 249 |
+
# if lock_t in t_vals:
|
| 250 |
+
# lock_idx = t_vals.index(lock_t)
|
| 251 |
+
# ax.axvline(lock_idx, color='steelblue', linestyle='--', linewidth=1.2,
|
| 252 |
+
# label=f"Lock-in at t={lock_t}")
|
| 253 |
+
#
|
| 254 |
+
# ax.axhline(0.1, color='gray', linestyle=':', linewidth=1, alpha=0.7)
|
| 255 |
+
#
|
| 256 |
+
# n = len(t_vals)
|
| 257 |
+
# tick_positions = list(range(0, n, max(1, n // 10)))
|
| 258 |
+
# ax.set_xticks(tick_positions)
|
| 259 |
+
# ax.set_xticklabels([str(t_vals[i]) for i in tick_positions], fontsize=8)
|
| 260 |
+
# ax.set_xlabel("Diffusion step t (T-1 → 0)", fontsize=11)
|
| 261 |
+
# ax.set_ylabel("CER vs final output", fontsize=11)
|
| 262 |
+
# ax.set_ylim(0, 1.05)
|
| 263 |
+
# ax.set_xlim(0, n - 1)
|
| 264 |
+
# ax.legend(fontsize=10)
|
| 265 |
+
#
|
| 266 |
+
# title = f"Semantic drift"
|
| 267 |
+
# if src_text:
|
| 268 |
+
# title += f" | src: {src_text[:50]}"
|
| 269 |
+
# ax.set_title(title, fontsize=11)
|
| 270 |
+
# plt.tight_layout()
|
| 271 |
+
#
|
| 272 |
+
# if save_path:
|
| 273 |
+
# import os
|
| 274 |
+
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
|
| 275 |
+
# plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 276 |
+
# print(f"Saved: {save_path}")
|
| 277 |
+
# else:
|
| 278 |
+
# plt.show()
|
| 279 |
+
# plt.close()
|
| 280 |
+
# ============================================================
|
| 281 |
+
# TASK 2: Source–Paraphrase Semantic Alignment Trajectory
|
| 282 |
+
# ============================================================
|
| 283 |
+
|
| 284 |
+
import torch
|
| 285 |
+
import torch.nn.functional as F
|
| 286 |
+
import numpy as np
|
| 287 |
+
import matplotlib.pyplot as plt
|
| 288 |
+
from typing import Dict, List, Tuple
|
| 289 |
+
from collections import defaultdict
|
| 290 |
+
|
| 291 |
+
# Optional (install if needed)
|
| 292 |
+
# pip install bert-score scikit-learn
|
| 293 |
+
from bert_score import score as bertscore
|
| 294 |
+
from sklearn.feature_extraction.text import TfidfVectorizer
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
# ============================================================
|
| 298 |
+
# ------------------ ATTENTION HOOK --------------------------
|
| 299 |
+
# ============================================================
|
| 300 |
+
|
| 301 |
+
def register_attention_hooks(model):
|
| 302 |
+
"""
|
| 303 |
+
Registers forward hooks to capture cross-attention weights
|
| 304 |
+
from each decoder block.
|
| 305 |
+
|
| 306 |
+
Assumes each block has attribute `.cross_attn.attn_weights`
|
| 307 |
+
"""
|
| 308 |
+
inner = model.model
|
| 309 |
+
attention_maps = []
|
| 310 |
+
|
| 311 |
+
def hook_fn(module, input, output):
|
| 312 |
+
if hasattr(module, "attn_weights"):
|
| 313 |
+
attention_maps.append(module.attn_weights.detach().cpu())
|
| 314 |
+
|
| 315 |
+
hooks = []
|
| 316 |
+
for block in inner.decoder_blocks:
|
| 317 |
+
if hasattr(block, "cross_attn"):
|
| 318 |
+
h = block.cross_attn.register_forward_hook(hook_fn)
|
| 319 |
+
hooks.append(h)
|
| 320 |
+
|
| 321 |
+
return hooks, attention_maps
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# ============================================================
|
| 325 |
+
# ------------------ CAPTURE TRAJECTORY ----------------------
|
| 326 |
+
# ============================================================
|
| 327 |
+
|
| 328 |
+
@torch.no_grad()
|
| 329 |
+
def capture_alignment_trajectory(
|
| 330 |
+
model,
|
| 331 |
+
src_tensor: torch.Tensor,
|
| 332 |
+
src_text: str,
|
| 333 |
+
tgt_tokenizer,
|
| 334 |
+
steps_to_capture: List[int] = None,
|
| 335 |
+
):
|
| 336 |
+
"""
|
| 337 |
+
Capture:
|
| 338 |
+
- intermediate outputs
|
| 339 |
+
- cross-attention maps
|
| 340 |
+
- BERTScore vs source
|
| 341 |
+
|
| 342 |
+
Returns:
|
| 343 |
+
dict with outputs, attention, drift
|
| 344 |
+
"""
|
| 345 |
+
|
| 346 |
+
inner = model.model
|
| 347 |
+
device = src_tensor.device
|
| 348 |
+
T = inner.scheduler.num_timesteps
|
| 349 |
+
|
| 350 |
+
if steps_to_capture is None:
|
| 351 |
+
steps_to_capture = list(range(T - 1, -1, -5)) + [0]
|
| 352 |
+
|
| 353 |
+
# Register hooks
|
| 354 |
+
hooks, attn_storage = register_attention_hooks(model)
|
| 355 |
+
|
| 356 |
+
memory, src_pad_mask = inner.encode_source(src_tensor)
|
| 357 |
+
|
| 358 |
+
B = src_tensor.shape[0]
|
| 359 |
+
tgt_len = inner.max_seq_len
|
| 360 |
+
mask_id = inner.mask_token_id
|
| 361 |
+
|
| 362 |
+
x0_est = torch.full((B, tgt_len), mask_id, device=device)
|
| 363 |
+
hint = None
|
| 364 |
+
|
| 365 |
+
outputs = {}
|
| 366 |
+
attention_per_step = {}
|
| 367 |
+
|
| 368 |
+
for t_val in range(T - 1, -1, -1):
|
| 369 |
+
t = torch.full((B,), t_val, device=device)
|
| 370 |
+
|
| 371 |
+
logits, _ = inner.forward_cached(
|
| 372 |
+
memory, src_pad_mask, x0_est, t,
|
| 373 |
+
x0_hint=hint, inference_mode=True
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
probs = F.softmax(logits, dim=-1)
|
| 377 |
+
x0_est = torch.argmax(probs, dim=-1)
|
| 378 |
+
hint = x0_est
|
| 379 |
+
|
| 380 |
+
if t_val in steps_to_capture:
|
| 381 |
+
ids = [x for x in x0_est[0].tolist() if x > 4]
|
| 382 |
+
text = tgt_tokenizer.decode(ids)
|
| 383 |
+
|
| 384 |
+
outputs[t_val] = text
|
| 385 |
+
|
| 386 |
+
# Collect attention maps (last layer only for simplicity)
|
| 387 |
+
if len(attn_storage) > 0:
|
| 388 |
+
attention_per_step[t_val] = attn_storage[-1].numpy()
|
| 389 |
+
|
| 390 |
+
# Remove hooks
|
| 391 |
+
for h in hooks:
|
| 392 |
+
h.remove()
|
| 393 |
+
|
| 394 |
+
# Compute BERTScore trajectory
|
| 395 |
+
bert_scores = compute_bert_alignment(src_text, outputs)
|
| 396 |
+
|
| 397 |
+
return {
|
| 398 |
+
"outputs": outputs,
|
| 399 |
+
"attention": attention_per_step,
|
| 400 |
+
"bert_scores": bert_scores,
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
# ============================================================
|
| 405 |
+
# ------------------ BERTScore -------------------------------
|
| 406 |
+
# ============================================================
|
| 407 |
+
|
| 408 |
+
def compute_bert_alignment(src_text: str, outputs: Dict[int, str]):
|
| 409 |
+
"""
|
| 410 |
+
Compute BERTScore between source and each intermediate output
|
| 411 |
+
"""
|
| 412 |
+
scores = {}
|
| 413 |
+
|
| 414 |
+
for t, text in outputs.items():
|
| 415 |
+
P, R, F1 = bertscore([text], [src_text], lang="hi", verbose=False)
|
| 416 |
+
scores[t] = float(F1.mean())
|
| 417 |
+
|
| 418 |
+
return scores
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
# ============================================================
|
| 422 |
+
# ------------------ SEMANTIC DRIFT --------------------------
|
| 423 |
+
# ============================================================
|
| 424 |
+
|
| 425 |
+
def compute_semantic_drift(bert_scores: Dict[int, float]):
|
| 426 |
+
"""
|
| 427 |
+
Drift = drop from best alignment
|
| 428 |
+
"""
|
| 429 |
+
max_score = max(bert_scores.values())
|
| 430 |
+
drift = {t: max_score - s for t, s in bert_scores.items()}
|
| 431 |
+
return drift
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
# ============================================================
|
| 435 |
+
# ------------------ ATTENTION STABILITY ---------------------
|
| 436 |
+
# ============================================================
|
| 437 |
+
|
| 438 |
+
def compute_attention_stability(attention_maps: Dict[int, np.ndarray]):
|
| 439 |
+
"""
|
| 440 |
+
Measures if tokens attend consistently across steps.
|
| 441 |
+
"""
|
| 442 |
+
steps = sorted(attention_maps.keys(), reverse=True)
|
| 443 |
+
|
| 444 |
+
stability_scores = []
|
| 445 |
+
|
| 446 |
+
for i in range(len(steps) - 1):
|
| 447 |
+
A = attention_maps[steps[i]]
|
| 448 |
+
B = attention_maps[steps[i+1]]
|
| 449 |
+
|
| 450 |
+
diff = np.abs(A - B).mean()
|
| 451 |
+
stability_scores.append(diff)
|
| 452 |
+
|
| 453 |
+
return np.mean(stability_scores)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
# ============================================================
|
| 457 |
+
# ------------------ TF-IDF vs STABILITY ---------------------
|
| 458 |
+
# ============================================================
|
| 459 |
+
|
| 460 |
+
def compute_tfidf_attention_correlation(
|
| 461 |
+
src_texts: List[str],
|
| 462 |
+
attention_maps_list: List[Dict[int, np.ndarray]]
|
| 463 |
+
):
|
| 464 |
+
"""
|
| 465 |
+
Correlate TF-IDF importance with attention stability
|
| 466 |
+
"""
|
| 467 |
+
|
| 468 |
+
vectorizer = TfidfVectorizer()
|
| 469 |
+
tfidf = vectorizer.fit_transform(src_texts).toarray()
|
| 470 |
+
|
| 471 |
+
word_importance = tfidf.mean(axis=0)
|
| 472 |
+
|
| 473 |
+
stability = []
|
| 474 |
+
for attn_maps in attention_maps_list:
|
| 475 |
+
stability.append(compute_attention_stability(attn_maps))
|
| 476 |
+
|
| 477 |
+
corr = np.corrcoef(word_importance[:len(stability)], stability)[0, 1]
|
| 478 |
+
return corr
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
# ============================================================
|
| 482 |
+
# ------------------ HEATMAP VISUALIZATION -------------------
|
| 483 |
+
# ============================================================
|
| 484 |
+
|
| 485 |
+
def plot_attention_heatmap(attn: np.ndarray, title="Attention"):
|
| 486 |
+
"""
|
| 487 |
+
Plot cross-attention heatmap
|
| 488 |
+
attn: [tgt_len, src_len]
|
| 489 |
+
"""
|
| 490 |
+
plt.figure(figsize=(6,5))
|
| 491 |
+
plt.imshow(attn, aspect='auto', cmap='viridis')
|
| 492 |
+
plt.colorbar()
|
| 493 |
+
plt.title(title)
|
| 494 |
+
plt.xlabel("Source tokens")
|
| 495 |
+
plt.ylabel("Target tokens")
|
| 496 |
+
plt.show()
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def visualize_trajectory(attention_maps: Dict[int, np.ndarray]):
|
| 500 |
+
"""
|
| 501 |
+
Show attention evolution over time
|
| 502 |
+
"""
|
| 503 |
+
steps = sorted(attention_maps.keys(), reverse=True)
|
| 504 |
+
|
| 505 |
+
for t in steps[:5]: # show 5 steps
|
| 506 |
+
plot_attention_heatmap(attention_maps[t], title=f"Step t={t}")
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
# ============================================================
|
| 510 |
+
# ------------------ LOCKED vs FLEXIBLE ----------------------
|
| 511 |
+
# ============================================================
|
| 512 |
+
|
| 513 |
+
def analyze_token_behavior(attention_maps: Dict[int, np.ndarray]):
|
| 514 |
+
"""
|
| 515 |
+
Detect whether tokens are locked or flexible
|
| 516 |
+
"""
|
| 517 |
+
steps = sorted(attention_maps.keys(), reverse=True)
|
| 518 |
+
|
| 519 |
+
first = attention_maps[steps[0]]
|
| 520 |
+
last = attention_maps[steps[-1]]
|
| 521 |
+
|
| 522 |
+
diff = np.abs(first - last).mean(axis=1)
|
| 523 |
+
|
| 524 |
+
locked = np.where(diff < 0.05)[0]
|
| 525 |
+
flexible = np.where(diff >= 0.05)[0]
|
| 526 |
+
|
| 527 |
+
return {
|
| 528 |
+
"locked_tokens": locked.tolist(),
|
| 529 |
+
"flexible_tokens": flexible.tolist()
|
| 530 |
+
}
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
# ============================================================
|
| 534 |
+
# ------------------ MASTER FUNCTION -------------------------
|
| 535 |
+
# ============================================================
|
| 536 |
+
|
| 537 |
+
def run_task2_analysis(
|
| 538 |
+
model,
|
| 539 |
+
src_tensor,
|
| 540 |
+
src_text,
|
| 541 |
+
tgt_tokenizer
|
| 542 |
+
):
|
| 543 |
+
result = capture_alignment_trajectory(
|
| 544 |
+
model, src_tensor, src_text, tgt_tokenizer
|
| 545 |
+
)
|
| 546 |
+
|
| 547 |
+
drift = compute_semantic_drift(result["bert_scores"])
|
| 548 |
+
stability = compute_attention_stability(result["attention"])
|
| 549 |
+
behavior = analyze_token_behavior(result["attention"])
|
| 550 |
+
|
| 551 |
+
print("\nBERTScore trajectory:")
|
| 552 |
+
print(result["bert_scores"])
|
| 553 |
+
|
| 554 |
+
print("\nSemantic drift:")
|
| 555 |
+
print(drift)
|
| 556 |
+
|
| 557 |
+
print(f"\nAttention stability: {stability:.4f}")
|
| 558 |
+
|
| 559 |
+
print("\nToken behavior:")
|
| 560 |
+
print(behavior)
|
| 561 |
+
|
| 562 |
+
visualize_trajectory(result["attention"])
|
| 563 |
+
|
| 564 |
+
return {
|
| 565 |
+
"trajectory": result,
|
| 566 |
+
"drift": drift,
|
| 567 |
+
"stability": stability,
|
| 568 |
+
"behavior": behavior
|
| 569 |
+
}
|
analysis/step_ablation.py
ADDED
|
@@ -0,0 +1,582 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# """
|
| 2 |
+
# analysis/step_ablation.py
|
| 3 |
+
# ==========================
|
| 4 |
+
# Task 4: Semantic Robustness — Ablation of Diffusion Steps vs Meaning Preservation
|
| 5 |
+
#
|
| 6 |
+
# Two-phase workflow (retraining IS required for different T values):
|
| 7 |
+
#
|
| 8 |
+
# PHASE 1 — Generate configs + train (run once per T value):
|
| 9 |
+
# python analysis/step_ablation.py --phase generate_configs
|
| 10 |
+
# # Creates configs: ablation_configs/T4.py, T8.py, T16.py, T32.py, T64.py
|
| 11 |
+
# # Then train each: MODEL_TYPE=d3pm_cross_attention python train.py (for each config)
|
| 12 |
+
#
|
| 13 |
+
# PHASE 2 — Analyze trained models (no retraining needed):
|
| 14 |
+
# python analysis/step_ablation.py --phase analyze
|
| 15 |
+
# # Loads each trained model, generates 200 paraphrases, computes CER
|
| 16 |
+
# # Produces 3D plot: X=steps, Y=generation_speed, Z=CER
|
| 17 |
+
#
|
| 18 |
+
# Why retraining is needed:
|
| 19 |
+
# A model trained with T=128 learns to denoise from x_t~Uniform[0,128].
|
| 20 |
+
# Running it with T=4 means the model only sees t∈{0,1,2,3} — which it
|
| 21 |
+
# was never trained on at those scales. Outputs are meaningless.
|
| 22 |
+
# You must train a separate model for each T value.
|
| 23 |
+
#
|
| 24 |
+
# Also implements adversarial robustness test (no retraining):
|
| 25 |
+
# Takes your existing T=128 model and tests whether corrupted IAST
|
| 26 |
+
# inputs (typos, character swaps) cause proportional output degradation.
|
| 27 |
+
# """
|
| 28 |
+
#
|
| 29 |
+
# import torch
|
| 30 |
+
# import torch.nn.functional as F
|
| 31 |
+
# import numpy as np
|
| 32 |
+
# import os
|
| 33 |
+
# import sys
|
| 34 |
+
# import time
|
| 35 |
+
# import json
|
| 36 |
+
# import copy
|
| 37 |
+
# from typing import List, Dict, Optional
|
| 38 |
+
#
|
| 39 |
+
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 40 |
+
#
|
| 41 |
+
#
|
| 42 |
+
# # ── Phase 1: Config generation ────────────────────────────────────────
|
| 43 |
+
#
|
| 44 |
+
# T_VALUES = [4, 8, 16, 32, 64]
|
| 45 |
+
#
|
| 46 |
+
# def generate_ablation_configs(base_config_path: str = "config.py",
|
| 47 |
+
# output_dir: str = "ablation_configs"):
|
| 48 |
+
# """
|
| 49 |
+
# Generate one config file per T value.
|
| 50 |
+
# Each config is a copy of the base config with diffusion_steps changed.
|
| 51 |
+
#
|
| 52 |
+
# After running this, train each model:
|
| 53 |
+
# for T in 4 8 16 32 64; do
|
| 54 |
+
# cp ablation_configs/config_T${T}.py config.py
|
| 55 |
+
# python train.py
|
| 56 |
+
# mv results7/d3pm_cross_attention_neg_False \
|
| 57 |
+
# ablation_results/T${T}
|
| 58 |
+
# done
|
| 59 |
+
# """
|
| 60 |
+
# os.makedirs(output_dir, exist_ok=True)
|
| 61 |
+
#
|
| 62 |
+
# # Read base config
|
| 63 |
+
# with open(base_config_path, "r") as f:
|
| 64 |
+
# base_src = f.read()
|
| 65 |
+
#
|
| 66 |
+
# for T in T_VALUES:
|
| 67 |
+
# # Replace diffusion_steps and num_steps
|
| 68 |
+
# cfg_src = base_src
|
| 69 |
+
# cfg_src = cfg_src.replace(
|
| 70 |
+
# '"diffusion_steps": 128',
|
| 71 |
+
# f'"diffusion_steps": {T}'
|
| 72 |
+
# )
|
| 73 |
+
# cfg_src = cfg_src.replace(
|
| 74 |
+
# "'diffusion_steps': 128",
|
| 75 |
+
# f"'diffusion_steps': {T}"
|
| 76 |
+
# )
|
| 77 |
+
# cfg_src = cfg_src.replace(
|
| 78 |
+
# '"num_steps": 128',
|
| 79 |
+
# f'"num_steps": {T}'
|
| 80 |
+
# )
|
| 81 |
+
# cfg_src = cfg_src.replace(
|
| 82 |
+
# "'num_steps': 128",
|
| 83 |
+
# f"'num_steps': {T}"
|
| 84 |
+
# )
|
| 85 |
+
# out_path = os.path.join(output_dir, f"config_T{T}.py")
|
| 86 |
+
# with open(out_path, "w") as f:
|
| 87 |
+
# f.write(f"# Ablation config: T={T} diffusion steps\n")
|
| 88 |
+
# f.write(cfg_src)
|
| 89 |
+
# print(f" Wrote: {out_path}")
|
| 90 |
+
#
|
| 91 |
+
# # Write a shell script to train all
|
| 92 |
+
# shell_script = os.path.join(output_dir, "train_all.sh")
|
| 93 |
+
# with open(shell_script, "w") as f:
|
| 94 |
+
# f.write("#!/bin/bash\n")
|
| 95 |
+
# f.write("# Run this script to train all ablation models\n\n")
|
| 96 |
+
# for T in T_VALUES:
|
| 97 |
+
# f.write(f"echo '=== Training T={T} ==='\n")
|
| 98 |
+
# f.write(f"cp {output_dir}/config_T{T}.py config.py\n")
|
| 99 |
+
# f.write(f"python train.py\n")
|
| 100 |
+
# f.write(f"mkdir -p ablation_results/T{T}\n")
|
| 101 |
+
# f.write(f"cp -r results7/d3pm_cross_attention_neg_False/best_model.pt "
|
| 102 |
+
# f"ablation_results/T{T}/best_model.pt\n")
|
| 103 |
+
# f.write(f"cp -r results7/d3pm_cross_attention_neg_False/train.log "
|
| 104 |
+
# f"ablation_results/T{T}/train.log\n\n")
|
| 105 |
+
# os.chmod(shell_script, 0o755)
|
| 106 |
+
# print(f"\nTraining script: {shell_script}")
|
| 107 |
+
# print(f"Run: bash {shell_script}")
|
| 108 |
+
#
|
| 109 |
+
#
|
| 110 |
+
# # ── Phase 2: Analysis (after models are trained) ──────────────────────
|
| 111 |
+
#
|
| 112 |
+
# def compute_cer(pred: str, ref: str) -> float:
|
| 113 |
+
# if not ref:
|
| 114 |
+
# return 1.0
|
| 115 |
+
#
|
| 116 |
+
# def edit_distance(s1, s2):
|
| 117 |
+
# m, n = len(s1), len(s2)
|
| 118 |
+
# dp = list(range(n + 1))
|
| 119 |
+
# for i in range(1, m + 1):
|
| 120 |
+
# prev, dp[0] = dp[0], i
|
| 121 |
+
# for j in range(1, n + 1):
|
| 122 |
+
# temp = dp[j]
|
| 123 |
+
# dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
|
| 124 |
+
# prev = temp
|
| 125 |
+
# return dp[n]
|
| 126 |
+
#
|
| 127 |
+
# return edit_distance(pred, ref) / max(len(ref), 1)
|
| 128 |
+
#
|
| 129 |
+
#
|
| 130 |
+
# def evaluate_model(
|
| 131 |
+
# model,
|
| 132 |
+
# src_list: List[torch.Tensor],
|
| 133 |
+
# ref_list: List[str],
|
| 134 |
+
# tgt_tokenizer,
|
| 135 |
+
# n_samples: int = 200,
|
| 136 |
+
# temperature: float = 0.8,
|
| 137 |
+
# top_k: int = 40,
|
| 138 |
+
# ) -> Dict:
|
| 139 |
+
# """
|
| 140 |
+
# Generate n_samples outputs and compute CER + generation speed.
|
| 141 |
+
#
|
| 142 |
+
# Returns dict with:
|
| 143 |
+
# mean_cer : average CER over samples
|
| 144 |
+
# generation_s : total wall-clock seconds for all generations
|
| 145 |
+
# speed_per_sample: seconds per sample
|
| 146 |
+
# cer_list : per-sample CER values
|
| 147 |
+
# """
|
| 148 |
+
# device = next(model.parameters()).device
|
| 149 |
+
# n = min(n_samples, len(src_list))
|
| 150 |
+
# cer_list = []
|
| 151 |
+
#
|
| 152 |
+
# start = time.perf_counter()
|
| 153 |
+
# for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
|
| 154 |
+
# if src.dim() == 1:
|
| 155 |
+
# src = src.unsqueeze(0)
|
| 156 |
+
#
|
| 157 |
+
# with torch.no_grad():
|
| 158 |
+
# if hasattr(model.model, 'generate_cached'):
|
| 159 |
+
# out = model.model.generate_cached(
|
| 160 |
+
# src.to(device), temperature=temperature, top_k=top_k
|
| 161 |
+
# )
|
| 162 |
+
# else:
|
| 163 |
+
# out = model.generate(
|
| 164 |
+
# src.to(device), temperature=temperature, top_k=top_k
|
| 165 |
+
# )
|
| 166 |
+
#
|
| 167 |
+
# ids = [x for x in out[0].tolist() if x > 4]
|
| 168 |
+
# pred = tgt_tokenizer.decode(ids).strip()
|
| 169 |
+
# cer = compute_cer(pred, ref)
|
| 170 |
+
# cer_list.append(cer)
|
| 171 |
+
#
|
| 172 |
+
# elapsed = time.perf_counter() - start
|
| 173 |
+
#
|
| 174 |
+
# return {
|
| 175 |
+
# "mean_cer": float(np.mean(cer_list)),
|
| 176 |
+
# "std_cer": float(np.std(cer_list)),
|
| 177 |
+
# "generation_s": elapsed,
|
| 178 |
+
# "speed_per_sample": elapsed / max(n, 1),
|
| 179 |
+
# "cer_list": cer_list,
|
| 180 |
+
# "n_samples": n,
|
| 181 |
+
# }
|
| 182 |
+
#
|
| 183 |
+
#
|
| 184 |
+
# def run_ablation_analysis(
|
| 185 |
+
# ablation_dir: str = "ablation_results",
|
| 186 |
+
# base_cfg: dict = None,
|
| 187 |
+
# src_list: List[torch.Tensor] = None,
|
| 188 |
+
# ref_list: List[str] = None,
|
| 189 |
+
# tgt_tokenizer = None,
|
| 190 |
+
# device: torch.device = None,
|
| 191 |
+
# output_dir: str = "analysis/outputs",
|
| 192 |
+
# ) -> Dict:
|
| 193 |
+
# """
|
| 194 |
+
# Load each trained model and evaluate.
|
| 195 |
+
# Produces results dict and 3D plot.
|
| 196 |
+
#
|
| 197 |
+
# Expects ablation_results/T{N}/best_model.pt for each T in T_VALUES.
|
| 198 |
+
# """
|
| 199 |
+
# from inference import load_model
|
| 200 |
+
#
|
| 201 |
+
# results = {}
|
| 202 |
+
# for T in T_VALUES:
|
| 203 |
+
# ckpt = os.path.join(ablation_dir, f"T{T}", "best_model.pt")
|
| 204 |
+
# if not os.path.exists(ckpt):
|
| 205 |
+
# print(f" SKIP T={T}: no checkpoint at {ckpt}")
|
| 206 |
+
# continue
|
| 207 |
+
#
|
| 208 |
+
# print(f"\nEvaluating T={T}...")
|
| 209 |
+
# cfg_T = copy.deepcopy(base_cfg)
|
| 210 |
+
# cfg_T['model']['diffusion_steps'] = T
|
| 211 |
+
# cfg_T['inference']['num_steps'] = T
|
| 212 |
+
#
|
| 213 |
+
# model, cfg_T = load_model(ckpt, cfg_T, device)
|
| 214 |
+
# model.eval()
|
| 215 |
+
#
|
| 216 |
+
# metrics = evaluate_model(
|
| 217 |
+
# model, src_list, ref_list, tgt_tokenizer, n_samples=200
|
| 218 |
+
# )
|
| 219 |
+
# results[T] = metrics
|
| 220 |
+
# print(f" T={T} CER={metrics['mean_cer']:.4f} "
|
| 221 |
+
# f"speed={metrics['speed_per_sample']:.3f}s/sample")
|
| 222 |
+
#
|
| 223 |
+
# del model
|
| 224 |
+
#
|
| 225 |
+
# # Save results
|
| 226 |
+
# os.makedirs(output_dir, exist_ok=True)
|
| 227 |
+
# results_path = os.path.join(output_dir, "ablation_results.json")
|
| 228 |
+
# with open(results_path, "w") as f:
|
| 229 |
+
# json.dump({str(k): {kk: vv for kk, vv in v.items() if kk != 'cer_list'}
|
| 230 |
+
# for k, v in results.items()}, f, indent=2)
|
| 231 |
+
# print(f"\nResults saved: {results_path}")
|
| 232 |
+
#
|
| 233 |
+
# return results
|
| 234 |
+
#
|
| 235 |
+
#
|
| 236 |
+
# def plot_ablation_3d(
|
| 237 |
+
# results: Dict,
|
| 238 |
+
# save_path: Optional[str] = None,
|
| 239 |
+
# ):
|
| 240 |
+
# """
|
| 241 |
+
# 3D plot: X=diffusion_steps, Y=generation_speed(s/sample), Z=CER.
|
| 242 |
+
# Also produces a 2D summary plot.
|
| 243 |
+
# """
|
| 244 |
+
# try:
|
| 245 |
+
# import matplotlib.pyplot as plt
|
| 246 |
+
# from mpl_toolkits.mplot3d import Axes3D
|
| 247 |
+
# except ImportError:
|
| 248 |
+
# print("pip install matplotlib.")
|
| 249 |
+
# return
|
| 250 |
+
#
|
| 251 |
+
# T_list = sorted(results.keys())
|
| 252 |
+
# cers = [results[T]["mean_cer"] for T in T_list]
|
| 253 |
+
# speeds = [results[T]["speed_per_sample"] for T in T_list]
|
| 254 |
+
#
|
| 255 |
+
# # ── 3D plot ───────────────────────────────────────────────────────
|
| 256 |
+
# fig = plt.figure(figsize=(14, 5))
|
| 257 |
+
#
|
| 258 |
+
# ax3d = fig.add_subplot(121, projection='3d')
|
| 259 |
+
# ax3d.scatter(T_list, speeds, cers, c=cers, cmap='RdYlGn_r', s=80)
|
| 260 |
+
# for T, s, c in zip(T_list, speeds, cers):
|
| 261 |
+
# ax3d.text(T, s, c, f"T={T}", fontsize=8)
|
| 262 |
+
# ax3d.set_xlabel("Diffusion steps T", fontsize=9)
|
| 263 |
+
# ax3d.set_ylabel("Speed (s/sample)", fontsize=9)
|
| 264 |
+
# ax3d.set_zlabel("CER (↓ better)", fontsize=9)
|
| 265 |
+
# ax3d.set_title("T vs speed vs CER", fontsize=10)
|
| 266 |
+
#
|
| 267 |
+
# # ── 2D CER vs T (find the knee) ──────────────────────────────────
|
| 268 |
+
# ax2d = fig.add_subplot(122)
|
| 269 |
+
# ax2d.plot(T_list, cers, 'o-', linewidth=1.8, color='coral', markersize=7)
|
| 270 |
+
# for T, c in zip(T_list, cers):
|
| 271 |
+
# ax2d.annotate(f"{c:.3f}", (T, c), textcoords="offset points",
|
| 272 |
+
# xytext=(0, 8), fontsize=8, ha='center')
|
| 273 |
+
#
|
| 274 |
+
# # Find knee: largest CER drop per unit T (elbow method)
|
| 275 |
+
# if len(T_list) >= 3:
|
| 276 |
+
# drops = [cers[i] - cers[i+1] for i in range(len(cers)-1)]
|
| 277 |
+
# knee_i = int(np.argmax(drops))
|
| 278 |
+
# knee_T = T_list[knee_i + 1]
|
| 279 |
+
# ax2d.axvline(knee_T, color='steelblue', linestyle='--', linewidth=1.2,
|
| 280 |
+
# label=f"Knee at T={knee_T}")
|
| 281 |
+
# ax2d.legend(fontsize=9)
|
| 282 |
+
#
|
| 283 |
+
# ax2d.set_xlabel("Diffusion steps T", fontsize=10)
|
| 284 |
+
# ax2d.set_ylabel("CER (lower = better)", fontsize=10)
|
| 285 |
+
# ax2d.set_title("CER vs diffusion steps", fontsize=10)
|
| 286 |
+
# ax2d.set_ylim(0, max(cers) * 1.1)
|
| 287 |
+
#
|
| 288 |
+
# plt.tight_layout()
|
| 289 |
+
# if save_path:
|
| 290 |
+
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
|
| 291 |
+
# plt.savefig(save_path, dpi=150, bbox_inches='tight')
|
| 292 |
+
# print(f"Saved: {save_path}")
|
| 293 |
+
# else:
|
| 294 |
+
# plt.show()
|
| 295 |
+
# plt.close()
|
| 296 |
+
#
|
| 297 |
+
#
|
| 298 |
+
# # ── Adversarial robustness test (no retraining needed) ───────────────
|
| 299 |
+
#
|
| 300 |
+
# def corrupt_iast(text: str, corruption_rate: float = 0.05) -> str:
|
| 301 |
+
# """
|
| 302 |
+
# Introduce random corruption into IAST text:
|
| 303 |
+
# - Character swap (adjacent chars swapped)
|
| 304 |
+
# - Character deletion
|
| 305 |
+
# - Random character insertion
|
| 306 |
+
#
|
| 307 |
+
# Models rate as 5% to 20% corruption to test robustness.
|
| 308 |
+
# """
|
| 309 |
+
# import random
|
| 310 |
+
# chars = list(text)
|
| 311 |
+
# n_corrupt = max(1, int(len(chars) * corruption_rate))
|
| 312 |
+
#
|
| 313 |
+
# for _ in range(n_corrupt):
|
| 314 |
+
# op = random.choice(['swap', 'delete', 'insert'])
|
| 315 |
+
# pos = random.randint(0, len(chars) - 1)
|
| 316 |
+
#
|
| 317 |
+
# if op == 'swap' and pos < len(chars) - 1:
|
| 318 |
+
# chars[pos], chars[pos+1] = chars[pos+1], chars[pos]
|
| 319 |
+
# elif op == 'delete' and len(chars) > 1:
|
| 320 |
+
# chars.pop(pos)
|
| 321 |
+
# elif op == 'insert':
|
| 322 |
+
# chars.insert(pos, random.choice('abcdeimnostu'))
|
| 323 |
+
#
|
| 324 |
+
# return "".join(chars)
|
| 325 |
+
#
|
| 326 |
+
#
|
| 327 |
+
# @torch.no_grad()
|
| 328 |
+
# def run_adversarial_test(
|
| 329 |
+
# model,
|
| 330 |
+
# src_tokenizer,
|
| 331 |
+
# tgt_tokenizer,
|
| 332 |
+
# test_inputs: List[str],
|
| 333 |
+
# test_refs: List[str],
|
| 334 |
+
# corruption_rates: List[float] = [0.0, 0.05, 0.10, 0.15, 0.20],
|
| 335 |
+
# device: torch.device = None,
|
| 336 |
+
# output_dir: str = "analysis/outputs",
|
| 337 |
+
# ) -> Dict:
|
| 338 |
+
# """
|
| 339 |
+
# Test if CER degrades proportionally with IAST corruption.
|
| 340 |
+
# Uses existing trained model — no retraining.
|
| 341 |
+
# """
|
| 342 |
+
# device = device or next(model.parameters()).device
|
| 343 |
+
# results = {}
|
| 344 |
+
#
|
| 345 |
+
# print("\nAdversarial robustness test...")
|
| 346 |
+
# for rate in corruption_rates:
|
| 347 |
+
# cer_list = []
|
| 348 |
+
# for text, ref in zip(test_inputs, test_refs):
|
| 349 |
+
# corrupted = corrupt_iast(text, rate)
|
| 350 |
+
# ids = src_tokenizer.encode(corrupted)
|
| 351 |
+
# src = torch.tensor([ids], dtype=torch.long, device=device)
|
| 352 |
+
#
|
| 353 |
+
# if hasattr(model.model, 'generate_cached'):
|
| 354 |
+
# out = model.model.generate_cached(src)
|
| 355 |
+
# else:
|
| 356 |
+
# out = model.generate(src)
|
| 357 |
+
#
|
| 358 |
+
# pred_ids = [x for x in out[0].tolist() if x > 4]
|
| 359 |
+
# pred = tgt_tokenizer.decode(pred_ids).strip()
|
| 360 |
+
# cer_list.append(compute_cer(pred, ref))
|
| 361 |
+
#
|
| 362 |
+
# mean_cer = float(np.mean(cer_list))
|
| 363 |
+
# results[rate] = mean_cer
|
| 364 |
+
# print(f" corruption={rate*100:.0f}% → CER={mean_cer:.4f}")
|
| 365 |
+
#
|
| 366 |
+
# # Save + plot
|
| 367 |
+
# os.makedirs(output_dir, exist_ok=True)
|
| 368 |
+
# try:
|
| 369 |
+
# import matplotlib.pyplot as plt
|
| 370 |
+
# fig, ax = plt.subplots(figsize=(8, 4))
|
| 371 |
+
# rates = [r * 100 for r in corruption_rates]
|
| 372 |
+
# cers = [results[r] for r in corruption_rates]
|
| 373 |
+
# ax.plot(rates, cers, 'o-', linewidth=1.8, color='steelblue', markersize=7)
|
| 374 |
+
# ax.set_xlabel("IAST corruption rate (%)", fontsize=11)
|
| 375 |
+
# ax.set_ylabel("CER", fontsize=11)
|
| 376 |
+
# ax.set_title("Model robustness to IAST input corruption", fontsize=11)
|
| 377 |
+
# ax.set_ylim(0, max(cers) * 1.2)
|
| 378 |
+
# plt.tight_layout()
|
| 379 |
+
# plt.savefig(os.path.join(output_dir, "adversarial_robustness.png"),
|
| 380 |
+
# dpi=150, bbox_inches='tight')
|
| 381 |
+
# plt.close()
|
| 382 |
+
# print(f" Saved: {output_dir}/adversarial_robustness.png")
|
| 383 |
+
# except ImportError:
|
| 384 |
+
# pass
|
| 385 |
+
#
|
| 386 |
+
# with open(os.path.join(output_dir, "adversarial_results.json"), "w") as f:
|
| 387 |
+
# json.dump({str(k): v for k, v in results.items()}, f, indent=2)
|
| 388 |
+
#
|
| 389 |
+
# return results
|
| 390 |
+
"""
|
| 391 |
+
analysis/task4_pipeline.py
|
| 392 |
+
================================
|
| 393 |
+
Correct Task 4 Pipeline:
|
| 394 |
+
|
| 395 |
+
PHASE 1 → Evaluate all models
|
| 396 |
+
PHASE 2 → Analyze + detect optimal T
|
| 397 |
+
|
| 398 |
+
NO early decision making.
|
| 399 |
+
"""
|
| 400 |
+
|
| 401 |
+
import torch
|
| 402 |
+
import numpy as np
|
| 403 |
+
import time
|
| 404 |
+
import os
|
| 405 |
+
import json
|
| 406 |
+
from typing import Dict, List
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
# ───────────────────────��─────────────────────
|
| 410 |
+
# Load Metrics
|
| 411 |
+
# ─────────────────────────────────────────────
|
| 412 |
+
|
| 413 |
+
def load_metrics():
|
| 414 |
+
from bert_score import score as bert_score
|
| 415 |
+
from sentence_transformers import SentenceTransformer, util
|
| 416 |
+
from nltk.translate.bleu_score import sentence_bleu
|
| 417 |
+
|
| 418 |
+
st_model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 419 |
+
return bert_score, st_model, util, sentence_bleu
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
# ─────────────────────────────────────────────
|
| 423 |
+
# PHASE 1 — Evaluate ALL models
|
| 424 |
+
# ─────────────────────────────────────────────
|
| 425 |
+
|
| 426 |
+
def evaluate_all_models(models: Dict[int, object],
|
| 427 |
+
src_list,
|
| 428 |
+
ref_list,
|
| 429 |
+
tgt_tokenizer,
|
| 430 |
+
n_samples=200):
|
| 431 |
+
|
| 432 |
+
bert_score_fn, st_model, util, bleu_fn = load_metrics()
|
| 433 |
+
|
| 434 |
+
results = {}
|
| 435 |
+
|
| 436 |
+
print("\n=== PHASE 1: Evaluating ALL models ===")
|
| 437 |
+
|
| 438 |
+
for T, model in sorted(models.items()):
|
| 439 |
+
print(f"\nEvaluating T={T}...")
|
| 440 |
+
|
| 441 |
+
device = next(model.parameters()).device
|
| 442 |
+
preds, refs = [], []
|
| 443 |
+
|
| 444 |
+
start = time.perf_counter()
|
| 445 |
+
|
| 446 |
+
for src, ref in zip(src_list[:n_samples], ref_list[:n_samples]):
|
| 447 |
+
if src.dim() == 1:
|
| 448 |
+
src = src.unsqueeze(0)
|
| 449 |
+
|
| 450 |
+
with torch.no_grad():
|
| 451 |
+
out = model.model.generate_cached(src.to(device))
|
| 452 |
+
|
| 453 |
+
ids = [x for x in out[0].tolist() if x > 4]
|
| 454 |
+
pred = tgt_tokenizer.decode(ids).strip()
|
| 455 |
+
|
| 456 |
+
preds.append(pred)
|
| 457 |
+
refs.append(ref)
|
| 458 |
+
|
| 459 |
+
elapsed = time.perf_counter() - start
|
| 460 |
+
|
| 461 |
+
# BERTScore
|
| 462 |
+
P, R, F1 = bert_score_fn(preds, refs, lang="hi", verbose=False)
|
| 463 |
+
bert_f1 = float(F1.mean())
|
| 464 |
+
|
| 465 |
+
# Sentence similarity
|
| 466 |
+
emb_p = st_model.encode(preds, convert_to_tensor=True)
|
| 467 |
+
emb_r = st_model.encode(refs, convert_to_tensor=True)
|
| 468 |
+
sim = util.cos_sim(emb_p, emb_r).diagonal().mean().item()
|
| 469 |
+
|
| 470 |
+
# BLEU
|
| 471 |
+
bleu_scores = [
|
| 472 |
+
bleu_fn([r.split()], p.split())
|
| 473 |
+
for p, r in zip(preds, refs)
|
| 474 |
+
]
|
| 475 |
+
|
| 476 |
+
results[T] = {
|
| 477 |
+
"bertscore_f1": bert_f1,
|
| 478 |
+
"semantic_sim": sim,
|
| 479 |
+
"bleu": float(np.mean(bleu_scores)),
|
| 480 |
+
"speed_per_sample": elapsed / n_samples
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
print(f" BERTScore: {bert_f1:.4f}")
|
| 484 |
+
print(f" Sim: {sim:.4f}")
|
| 485 |
+
print(f" BLEU: {results[T]['bleu']:.4f}")
|
| 486 |
+
print(f" Speed: {results[T]['speed_per_sample']:.4f}s")
|
| 487 |
+
|
| 488 |
+
# Save raw results
|
| 489 |
+
os.makedirs("analysis/outputs", exist_ok=True)
|
| 490 |
+
with open("analysis/outputs/task4_raw_results.json", "w") as f:
|
| 491 |
+
json.dump(results, f, indent=2)
|
| 492 |
+
|
| 493 |
+
return results
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
# ─────────────────────────────────────────────
|
| 497 |
+
# PHASE 2 — Analyze results (Knee Detection)
|
| 498 |
+
# ─────────────────────────────────────────────
|
| 499 |
+
|
| 500 |
+
def analyze_results(results: Dict):
|
| 501 |
+
print("\n=== PHASE 2: Analysis ===")
|
| 502 |
+
|
| 503 |
+
T_list = sorted(results.keys())
|
| 504 |
+
scores = [results[T]["bertscore_f1"] for T in T_list]
|
| 505 |
+
|
| 506 |
+
gains = [scores[i+1] - scores[i] for i in range(len(scores)-1)]
|
| 507 |
+
|
| 508 |
+
print("\nMarginal Gains:")
|
| 509 |
+
for i, g in enumerate(gains):
|
| 510 |
+
print(f" T{T_list[i]} → T{T_list[i+1]}: +{g:.4f}")
|
| 511 |
+
|
| 512 |
+
# Knee detection
|
| 513 |
+
threshold = 0.02
|
| 514 |
+
knee_T = T_list[-1]
|
| 515 |
+
|
| 516 |
+
for i, g in enumerate(gains):
|
| 517 |
+
if g < threshold:
|
| 518 |
+
knee_T = T_list[i+1]
|
| 519 |
+
break
|
| 520 |
+
|
| 521 |
+
print(f"\n✅ Optimal T (knee detected): {knee_T}")
|
| 522 |
+
|
| 523 |
+
return knee_T, gains
|
| 524 |
+
|
| 525 |
+
|
| 526 |
+
# ─────────────────────────────────────────────
|
| 527 |
+
# 3D Plot (BERTScore)
|
| 528 |
+
# ─────────────────────────────────────────────
|
| 529 |
+
|
| 530 |
+
def plot_3d(results):
|
| 531 |
+
import matplotlib.pyplot as plt
|
| 532 |
+
from mpl_toolkits.mplot3d import Axes3D
|
| 533 |
+
|
| 534 |
+
T_list = sorted(results.keys())
|
| 535 |
+
|
| 536 |
+
X = T_list
|
| 537 |
+
Y = [results[T]["speed_per_sample"] for T in T_list]
|
| 538 |
+
Z = [results[T]["bertscore_f1"] for T in T_list]
|
| 539 |
+
|
| 540 |
+
fig = plt.figure(figsize=(10, 6))
|
| 541 |
+
ax = fig.add_subplot(111, projection='3d')
|
| 542 |
+
|
| 543 |
+
ax.scatter(X, Y, Z)
|
| 544 |
+
|
| 545 |
+
for x, y, z in zip(X, Y, Z):
|
| 546 |
+
ax.text(x, y, z, f"T={x}", fontsize=8)
|
| 547 |
+
|
| 548 |
+
ax.set_xlabel("Diffusion Steps")
|
| 549 |
+
ax.set_ylabel("Speed")
|
| 550 |
+
ax.set_zlabel("BERTScore")
|
| 551 |
+
|
| 552 |
+
plt.title("3D Tradeoff: Steps vs Speed vs Quality")
|
| 553 |
+
|
| 554 |
+
os.makedirs("analysis/outputs", exist_ok=True)
|
| 555 |
+
plt.savefig("analysis/outputs/task4_3d.png")
|
| 556 |
+
plt.close()
|
| 557 |
+
|
| 558 |
+
print("Saved 3D plot")
|
| 559 |
+
|
| 560 |
+
|
| 561 |
+
# ────────────���────────────────────────────────
|
| 562 |
+
# FINAL RUNNER
|
| 563 |
+
# ─────────────────────────────────────────────
|
| 564 |
+
|
| 565 |
+
def run_task4(models, src_list, ref_list, tgt_tokenizer):
|
| 566 |
+
|
| 567 |
+
# Phase 1: Evaluate all
|
| 568 |
+
results = evaluate_all_models(
|
| 569 |
+
models, src_list, ref_list, tgt_tokenizer
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
# Phase 2: Analyze
|
| 573 |
+
knee_T, gains = analyze_results(results)
|
| 574 |
+
|
| 575 |
+
# Plot
|
| 576 |
+
plot_3d(results)
|
| 577 |
+
|
| 578 |
+
# Save report
|
| 579 |
+
with open("analysis/outputs/task4_report.txt", "w") as f:
|
| 580 |
+
f.write(f"Optimal diffusion steps = {knee_T}\n")
|
| 581 |
+
|
| 582 |
+
return knee_T
|
app.py
CHANGED
|
@@ -1,235 +1,547 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Hugging Face Space app for Sanskrit D3PM project.
|
| 3 |
-
|
| 4 |
-
Deploy on Spaces with:
|
| 5 |
-
app_file = app_hf_space.py
|
| 6 |
-
|
| 7 |
-
Optional environment variables:
|
| 8 |
-
HF_CHECKPOINT_REPO : model repo id (e.g. "username/sanskrit-d3pm")
|
| 9 |
-
HF_CHECKPOINT_FILE : checkpoint path in repo (default: "best_model.pt")
|
| 10 |
-
HF_CHECKPOINT_LABEL : UI label for remote checkpoint
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
from __future__ import annotations
|
| 14 |
-
|
| 15 |
import copy
|
|
|
|
| 16 |
import os
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
|
| 19 |
import gradio as gr
|
| 20 |
import torch
|
|
|
|
| 21 |
|
| 22 |
from config import CONFIG
|
| 23 |
from inference import _build_tokenizers, _resolve_device, load_model, run_inference
|
| 24 |
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
return text
|
| 30 |
-
toks = text.split()
|
| 31 |
-
out = []
|
| 32 |
-
prev = None
|
| 33 |
-
run = 0
|
| 34 |
-
for t in toks:
|
| 35 |
-
if t == prev:
|
| 36 |
-
run += 1
|
| 37 |
-
else:
|
| 38 |
-
prev = t
|
| 39 |
-
run = 1
|
| 40 |
-
if run <= max_repeat:
|
| 41 |
-
out.append(t)
|
| 42 |
-
s = " ".join(out)
|
| 43 |
-
s = s.replace(" ।", "।").replace(" ॥", "॥")
|
| 44 |
-
return " ".join(s.split())
|
| 45 |
|
| 46 |
|
| 47 |
-
def
|
| 48 |
-
found =
|
| 49 |
for root in ("ablation_results", "results7", "results"):
|
| 50 |
if not os.path.isdir(root):
|
| 51 |
continue
|
| 52 |
-
for
|
| 53 |
-
ckpt = os.path.join(root,
|
| 54 |
-
if os.path.exists(ckpt):
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
return found
|
| 57 |
|
| 58 |
|
| 59 |
-
def
|
| 60 |
-
|
| 61 |
-
if not repo:
|
| 62 |
-
return {}
|
| 63 |
-
|
| 64 |
-
filename = os.getenv("HF_CHECKPOINT_FILE", "best_model.pt").strip()
|
| 65 |
-
label = os.getenv("HF_CHECKPOINT_LABEL", f"remote:{repo}")
|
| 66 |
|
| 67 |
-
try:
|
| 68 |
-
from huggingface_hub import hf_hub_download
|
| 69 |
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
-
def
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
| 80 |
return "d3pm_encoder_decoder"
|
| 81 |
-
if "baseline_cross_attention"
|
| 82 |
return "baseline_cross_attention"
|
| 83 |
-
if "baseline_encoder_decoder"
|
| 84 |
return "baseline_encoder_decoder"
|
| 85 |
-
return "
|
| 86 |
|
| 87 |
|
| 88 |
-
def
|
| 89 |
-
|
| 90 |
-
|
|
|
|
| 91 |
return True
|
| 92 |
-
if "
|
| 93 |
return False
|
| 94 |
return CONFIG["data"]["include_negative_examples"]
|
| 95 |
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
cfg["
|
| 107 |
-
cfg["
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
cfg["inference"]["temperature"] = float(temperature)
|
| 166 |
cfg["inference"]["top_k"] = int(top_k)
|
| 167 |
cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
|
| 168 |
cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
|
| 169 |
cfg["inference"]["num_steps"] = int(num_steps)
|
| 170 |
|
| 171 |
-
src_tok =
|
| 172 |
-
tgt_tok =
|
| 173 |
-
device = torch.device(
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
out = run_inference(
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
|
| 184 |
-
with gr.Blocks(title="Sanskrit
|
| 185 |
model_state = gr.State(None)
|
|
|
|
| 186 |
gr.Markdown(
|
| 187 |
"""
|
| 188 |
-
|
| 189 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 190 |
"""
|
| 191 |
)
|
| 192 |
|
| 193 |
-
checkpoint = gr.Dropdown(
|
| 194 |
-
choices=list(CHECKPOINTS.keys()),
|
| 195 |
-
value=list(CHECKPOINTS.keys())[0],
|
| 196 |
-
label="Checkpoint",
|
| 197 |
-
)
|
| 198 |
-
load_btn = gr.Button("Load Model", variant="primary")
|
| 199 |
-
load_info = gr.Markdown("Select a checkpoint and click **Load Model**.")
|
| 200 |
-
|
| 201 |
-
text_in = gr.Textbox(label="Input (Roman / IAST)", lines=3, value="dharmo rakṣati rakṣitaḥ")
|
| 202 |
-
text_out = gr.Textbox(label="Output (Devanagari)", lines=6)
|
| 203 |
-
|
| 204 |
with gr.Row():
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
-
load_btn.click(load_checkpoint_ui, inputs=[checkpoint], outputs=[model_state, load_info])
|
| 215 |
generate_btn.click(
|
| 216 |
-
|
| 217 |
inputs=[
|
| 218 |
-
model_state,
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
],
|
| 221 |
-
outputs=[
|
| 222 |
)
|
| 223 |
-
|
| 224 |
-
|
| 225 |
inputs=[
|
| 226 |
-
model_state,
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
],
|
| 229 |
-
outputs=[text_out],
|
| 230 |
)
|
| 231 |
|
| 232 |
|
| 233 |
if __name__ == "__main__":
|
| 234 |
-
port = int(os.environ
|
| 235 |
demo.launch(server_name="0.0.0.0", server_port=port, share=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import copy
|
| 2 |
+
import json
|
| 3 |
import os
|
| 4 |
+
import subprocess
|
| 5 |
+
import sys
|
| 6 |
+
from datetime import datetime
|
| 7 |
|
| 8 |
import gradio as gr
|
| 9 |
import torch
|
| 10 |
+
from huggingface_hub import hf_hub_download, list_repo_files
|
| 11 |
|
| 12 |
from config import CONFIG
|
| 13 |
from inference import _build_tokenizers, _resolve_device, load_model, run_inference
|
| 14 |
|
| 15 |
|
| 16 |
+
RESULTS_DIR = "generated_results"
|
| 17 |
+
DEFAULT_ANALYSIS_OUT = "analysis/outputs"
|
| 18 |
+
os.makedirs(RESULTS_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
+
def discover_checkpoints():
|
| 22 |
+
found = []
|
| 23 |
for root in ("ablation_results", "results7", "results"):
|
| 24 |
if not os.path.isdir(root):
|
| 25 |
continue
|
| 26 |
+
for entry in sorted(os.listdir(root)):
|
| 27 |
+
ckpt = os.path.join(root, entry, "best_model.pt")
|
| 28 |
+
if not os.path.exists(ckpt):
|
| 29 |
+
continue
|
| 30 |
+
found.append(
|
| 31 |
+
{
|
| 32 |
+
"label": f"{entry} [{root}]",
|
| 33 |
+
"path": ckpt,
|
| 34 |
+
"experiment": entry,
|
| 35 |
+
"root": root,
|
| 36 |
+
}
|
| 37 |
+
)
|
| 38 |
+
repo = os.getenv("HF_CHECKPOINT_REPO", "").strip()
|
| 39 |
+
if repo:
|
| 40 |
+
branch = os.getenv("HF_CHECKPOINT_REVISION", "main").strip() or "main"
|
| 41 |
+
try:
|
| 42 |
+
for fname in list_repo_files(repo_id=repo, repo_type="model", revision=branch):
|
| 43 |
+
if not fname.endswith("/best_model.pt") and fname != "best_model.pt":
|
| 44 |
+
continue
|
| 45 |
+
local_path = hf_hub_download(repo_id=repo, filename=fname, revision=branch, repo_type="model")
|
| 46 |
+
parent = os.path.basename(os.path.dirname(fname)) if "/" in fname else "remote"
|
| 47 |
+
root = os.path.dirname(fname).split("/")[0] if "/" in fname else "remote"
|
| 48 |
+
found.append(
|
| 49 |
+
{
|
| 50 |
+
"label": f"{parent} [hf:{repo}]",
|
| 51 |
+
"path": local_path,
|
| 52 |
+
"experiment": parent,
|
| 53 |
+
"root": root,
|
| 54 |
+
}
|
| 55 |
+
)
|
| 56 |
+
except Exception as e:
|
| 57 |
+
print(f"[WARN] Could not discover remote checkpoints from {repo}: {e}")
|
| 58 |
return found
|
| 59 |
|
| 60 |
|
| 61 |
+
def checkpoint_map():
|
| 62 |
+
return {item["label"]: item for item in discover_checkpoints()}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
|
|
|
|
|
|
|
| 64 |
|
| 65 |
+
def default_checkpoint_label():
|
| 66 |
+
cps = discover_checkpoints()
|
| 67 |
+
if not cps:
|
| 68 |
+
return None
|
| 69 |
+
for item in cps:
|
| 70 |
+
if item["path"].endswith("ablation_results/T4/best_model.pt"):
|
| 71 |
+
return item["label"]
|
| 72 |
+
return cps[0]["label"]
|
| 73 |
|
| 74 |
|
| 75 |
+
def infer_model_type(experiment_name: str, root: str = "") -> str:
|
| 76 |
+
if root == "ablation_results":
|
| 77 |
+
return "d3pm_cross_attention"
|
| 78 |
+
if experiment_name.startswith("d3pm_cross_attention"):
|
| 79 |
+
return "d3pm_cross_attention"
|
| 80 |
+
if experiment_name.startswith("d3pm_encoder_decoder"):
|
| 81 |
return "d3pm_encoder_decoder"
|
| 82 |
+
if experiment_name.startswith("baseline_cross_attention"):
|
| 83 |
return "baseline_cross_attention"
|
| 84 |
+
if experiment_name.startswith("baseline_encoder_decoder"):
|
| 85 |
return "baseline_encoder_decoder"
|
| 86 |
+
return CONFIG["model_type"]
|
| 87 |
|
| 88 |
|
| 89 |
+
def infer_include_negative(experiment_name: str, root: str = "") -> bool:
|
| 90 |
+
if root == "ablation_results":
|
| 91 |
+
return False
|
| 92 |
+
if "_neg_True" in experiment_name:
|
| 93 |
return True
|
| 94 |
+
if "_neg_False" in experiment_name:
|
| 95 |
return False
|
| 96 |
return CONFIG["data"]["include_negative_examples"]
|
| 97 |
|
| 98 |
|
| 99 |
+
def build_runtime_cfg(ckpt_path: str):
|
| 100 |
+
experiment = os.path.basename(os.path.dirname(ckpt_path)) or "remote"
|
| 101 |
+
root = os.path.basename(os.path.dirname(os.path.dirname(ckpt_path))) or "remote"
|
| 102 |
+
cfg = copy.deepcopy(CONFIG)
|
| 103 |
+
cfg["model_type"] = infer_model_type(experiment, root=root)
|
| 104 |
+
cfg["data"]["include_negative_examples"] = infer_include_negative(experiment, root=root)
|
| 105 |
+
|
| 106 |
+
if root == "ablation_results" and experiment.startswith("T") and experiment[1:].isdigit():
|
| 107 |
+
t_val = int(experiment[1:])
|
| 108 |
+
cfg["model"]["diffusion_steps"] = t_val
|
| 109 |
+
cfg["inference"]["num_steps"] = t_val
|
| 110 |
+
|
| 111 |
+
device = _resolve_device(cfg)
|
| 112 |
+
return cfg, device, experiment
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def load_selected_model(checkpoint_label):
|
| 116 |
+
mapping = checkpoint_map()
|
| 117 |
+
if checkpoint_label not in mapping:
|
| 118 |
+
raise gr.Error("Selected checkpoint not found. Click refresh.")
|
| 119 |
+
|
| 120 |
+
ckpt_path = mapping[checkpoint_label]["path"]
|
| 121 |
+
cfg, device, experiment = build_runtime_cfg(ckpt_path)
|
| 122 |
+
model, cfg = load_model(ckpt_path, cfg, device)
|
| 123 |
+
src_tok, tgt_tok = _build_tokenizers(cfg)
|
| 124 |
+
|
| 125 |
+
bundle = {
|
| 126 |
+
"ckpt_path": ckpt_path,
|
| 127 |
+
"experiment": experiment,
|
| 128 |
+
"device": str(device),
|
| 129 |
+
"cfg": cfg,
|
| 130 |
+
"model": model,
|
| 131 |
+
"src_tok": src_tok,
|
| 132 |
+
"tgt_tok": tgt_tok,
|
| 133 |
+
}
|
| 134 |
+
model_info = {
|
| 135 |
+
"checkpoint": ckpt_path,
|
| 136 |
+
"experiment": experiment,
|
| 137 |
+
"model_type": cfg["model_type"],
|
| 138 |
+
"include_negatives": cfg["data"]["include_negative_examples"],
|
| 139 |
+
"device": str(device),
|
| 140 |
+
"max_seq_len": cfg["model"]["max_seq_len"],
|
| 141 |
+
"diffusion_steps": cfg["model"]["diffusion_steps"],
|
| 142 |
+
"inference_steps": cfg["inference"]["num_steps"],
|
| 143 |
+
"d_model": cfg["model"]["d_model"],
|
| 144 |
+
"n_layers": cfg["model"]["n_layers"],
|
| 145 |
+
"n_heads": cfg["model"]["n_heads"],
|
| 146 |
+
}
|
| 147 |
+
status = f"Loaded `{experiment}` on `{device}` (`{cfg['model_type']}`)"
|
| 148 |
+
suggested_out = os.path.join("analysis", "outputs_ui", experiment)
|
| 149 |
+
return bundle, status, model_info, cfg["inference"]["num_steps"], suggested_out
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def apply_preset(preset_name):
|
| 153 |
+
presets = {
|
| 154 |
+
"Manual": (0.70, 40, 1.20, 0.0),
|
| 155 |
+
"Literal": (0.60, 20, 1.25, 0.0),
|
| 156 |
+
"Balanced": (0.70, 40, 1.20, 0.0),
|
| 157 |
+
"Creative": (0.90, 80, 1.05, 0.2),
|
| 158 |
+
}
|
| 159 |
+
return presets.get(preset_name, presets["Balanced"])
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def clean_generated_text(text: str, max_consecutive: int = 2) -> str:
|
| 163 |
+
text = " ".join(text.split())
|
| 164 |
+
if not text:
|
| 165 |
+
return text
|
| 166 |
+
tokens = text.split()
|
| 167 |
+
cleaned = []
|
| 168 |
+
prev = None
|
| 169 |
+
run = 0
|
| 170 |
+
for tok in tokens:
|
| 171 |
+
if tok == prev:
|
| 172 |
+
run += 1
|
| 173 |
+
else:
|
| 174 |
+
prev = tok
|
| 175 |
+
run = 1
|
| 176 |
+
if run <= max_consecutive:
|
| 177 |
+
cleaned.append(tok)
|
| 178 |
+
out = " ".join(cleaned).replace(" ।", "।").replace(" ॥", "॥")
|
| 179 |
+
return " ".join(out.split())
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def save_generation(experiment, record):
|
| 183 |
+
ts = datetime.now().strftime("%Y%m%d")
|
| 184 |
+
path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
|
| 185 |
+
existing = []
|
| 186 |
+
if os.path.exists(path):
|
| 187 |
+
with open(path, "r", encoding="utf-8") as f:
|
| 188 |
+
existing = json.load(f)
|
| 189 |
+
existing.append(record)
|
| 190 |
+
with open(path, "w", encoding="utf-8") as f:
|
| 191 |
+
json.dump(existing, f, ensure_ascii=False, indent=2)
|
| 192 |
+
return path
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def generate_from_ui(
|
| 196 |
+
model_bundle,
|
| 197 |
+
input_text,
|
| 198 |
+
temperature,
|
| 199 |
+
top_k,
|
| 200 |
+
repetition_penalty,
|
| 201 |
+
diversity_penalty,
|
| 202 |
+
num_steps,
|
| 203 |
+
clean_output,
|
| 204 |
+
):
|
| 205 |
+
if not model_bundle:
|
| 206 |
+
raise gr.Error("Load a model first.")
|
| 207 |
+
if not input_text.strip():
|
| 208 |
+
raise gr.Error("Enter input text first.")
|
| 209 |
+
|
| 210 |
+
cfg = copy.deepcopy(model_bundle["cfg"])
|
| 211 |
cfg["inference"]["temperature"] = float(temperature)
|
| 212 |
cfg["inference"]["top_k"] = int(top_k)
|
| 213 |
cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
|
| 214 |
cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
|
| 215 |
cfg["inference"]["num_steps"] = int(num_steps)
|
| 216 |
|
| 217 |
+
src_tok = model_bundle["src_tok"]
|
| 218 |
+
tgt_tok = model_bundle["tgt_tok"]
|
| 219 |
+
device = torch.device(model_bundle["device"])
|
| 220 |
+
|
| 221 |
+
input_ids = torch.tensor([src_tok.encode(input_text.strip())], dtype=torch.long, device=device)
|
| 222 |
+
out = run_inference(model_bundle["model"], input_ids, cfg)
|
| 223 |
+
|
| 224 |
+
# Align decode with validation style: strip only special ids.
|
| 225 |
+
pad_id = 1
|
| 226 |
+
mask_id = cfg["diffusion"]["mask_token_id"]
|
| 227 |
+
decoded_ids = [x for x in out[0].tolist() if x not in (pad_id, mask_id)]
|
| 228 |
+
raw_output_text = tgt_tok.decode(decoded_ids).strip()
|
| 229 |
+
output_text = clean_generated_text(raw_output_text) if clean_output else raw_output_text
|
| 230 |
+
if not output_text:
|
| 231 |
+
output_text = "(empty output)"
|
| 232 |
+
|
| 233 |
+
record = {
|
| 234 |
+
"timestamp": datetime.now().isoformat(timespec="seconds"),
|
| 235 |
+
"experiment": model_bundle["experiment"],
|
| 236 |
+
"checkpoint": model_bundle["ckpt_path"],
|
| 237 |
+
"input_text": input_text,
|
| 238 |
+
"raw_output_text": raw_output_text,
|
| 239 |
+
"output_text": output_text,
|
| 240 |
+
"temperature": float(temperature),
|
| 241 |
+
"top_k": int(top_k),
|
| 242 |
+
"repetition_penalty": float(repetition_penalty),
|
| 243 |
+
"diversity_penalty": float(diversity_penalty),
|
| 244 |
+
"num_steps": int(num_steps),
|
| 245 |
+
"clean_output": bool(clean_output),
|
| 246 |
+
}
|
| 247 |
+
log_path = save_generation(model_bundle["experiment"], record)
|
| 248 |
+
status = f"Inference done. Saved: `{log_path}`"
|
| 249 |
+
return output_text, status, record
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze"):
|
| 253 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 254 |
+
cmd = [
|
| 255 |
+
sys.executable,
|
| 256 |
+
"analysis/run_analysis.py",
|
| 257 |
+
"--task",
|
| 258 |
+
str(task),
|
| 259 |
+
"--checkpoint",
|
| 260 |
+
ckpt_path,
|
| 261 |
+
"--output_dir",
|
| 262 |
+
output_dir,
|
| 263 |
+
]
|
| 264 |
+
if str(task) == "2" or str(task) == "all":
|
| 265 |
+
cmd.extend(["--input", input_text])
|
| 266 |
+
if str(task) == "4":
|
| 267 |
+
cmd.extend(["--phase", phase])
|
| 268 |
+
|
| 269 |
+
env = os.environ.copy()
|
| 270 |
+
env.setdefault("HF_HOME", "/tmp/hf_home")
|
| 271 |
+
env.setdefault("HF_DATASETS_CACHE", "/tmp/hf_datasets")
|
| 272 |
+
env.setdefault("HF_HUB_CACHE", "/tmp/hf_hub")
|
| 273 |
+
|
| 274 |
+
proc = subprocess.run(cmd, capture_output=True, text=True, env=env)
|
| 275 |
+
log = f"$ {' '.join(cmd)}\n\n{proc.stdout}\n{proc.stderr}"
|
| 276 |
+
return proc.returncode, log
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
|
| 280 |
+
if not model_bundle:
|
| 281 |
+
raise gr.Error("Load a model first.")
|
| 282 |
+
code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
|
| 283 |
+
status = f"Task {task} {'completed' if code == 0 else 'failed'} (exit={code})."
|
| 284 |
+
return status, log
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def run_all_tasks(model_bundle, output_dir, input_text, task4_phase):
|
| 288 |
+
if not model_bundle:
|
| 289 |
+
raise gr.Error("Load a model first.")
|
| 290 |
+
logs = []
|
| 291 |
+
failures = 0
|
| 292 |
+
for task in ["1", "2", "3", "4", "5"]:
|
| 293 |
+
code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
|
| 294 |
+
logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
|
| 295 |
+
if code != 0:
|
| 296 |
+
failures += 1
|
| 297 |
+
status = f"Run-all finished with {failures} failed task(s)." if failures else "All 5 tasks completed."
|
| 298 |
+
return status, "".join(logs)
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def _read_text(path):
|
| 302 |
+
if not os.path.exists(path):
|
| 303 |
+
return "Not found."
|
| 304 |
+
with open(path, "r", encoding="utf-8", errors="ignore") as f:
|
| 305 |
+
return f.read()
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _img_or_none(path):
|
| 309 |
+
return path if os.path.exists(path) else None
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def refresh_task_outputs(output_dir):
|
| 313 |
+
task1_txt = _read_text(os.path.join(output_dir, "task1_kv_cache.txt"))
|
| 314 |
+
task2_txt = _read_text(os.path.join(output_dir, "task2_report.txt"))
|
| 315 |
+
task3_txt = _read_text(os.path.join(output_dir, "task3_report.txt"))
|
| 316 |
+
task5_txt = _read_text(os.path.join(output_dir, "task5_report.txt"))
|
| 317 |
+
|
| 318 |
+
task2_drift = _img_or_none(os.path.join(output_dir, "task2_semantic_drift.png"))
|
| 319 |
+
task2_attn = _img_or_none(os.path.join(output_dir, "task2_attn_t0.png"))
|
| 320 |
+
task3_space = _img_or_none(os.path.join(output_dir, "task3_concept_space.png"))
|
| 321 |
+
task4_plot = _img_or_none(os.path.join(output_dir, "task4_ablation_3d.png"))
|
| 322 |
+
return task1_txt, task2_txt, task2_drift, task2_attn, task3_txt, task3_space, task5_txt, task4_plot
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
CUSTOM_CSS = """
|
| 326 |
+
:root {
|
| 327 |
+
--bg1: #f5fbff;
|
| 328 |
+
--bg2: #f2f7ef;
|
| 329 |
+
--card: #ffffff;
|
| 330 |
+
--line: #d9e6f2;
|
| 331 |
+
--ink: #163048;
|
| 332 |
+
}
|
| 333 |
+
.gradio-container {
|
| 334 |
+
background: linear-gradient(130deg, var(--bg1), var(--bg2));
|
| 335 |
+
color: var(--ink);
|
| 336 |
+
}
|
| 337 |
+
#hero {
|
| 338 |
+
background: radial-gradient(110% 130% at 0% 0%, #d7ebff 0%, #ecf6ff 55%, #f8fbff 100%);
|
| 339 |
+
border: 1px solid #cfe0f1;
|
| 340 |
+
border-radius: 16px;
|
| 341 |
+
padding: 18px 20px;
|
| 342 |
+
}
|
| 343 |
+
.panel {
|
| 344 |
+
background: var(--card);
|
| 345 |
+
border: 1px solid var(--line);
|
| 346 |
+
border-radius: 14px;
|
| 347 |
+
}
|
| 348 |
+
"""
|
| 349 |
|
| 350 |
|
| 351 |
+
with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
|
| 352 |
model_state = gr.State(None)
|
| 353 |
+
|
| 354 |
gr.Markdown(
|
| 355 |
"""
|
| 356 |
+
<div id="hero">
|
| 357 |
+
<h1 style="margin:0;">Sanskrit Diffusion Client Demo</h1>
|
| 358 |
+
<p style="margin:.5rem 0 0 0;">
|
| 359 |
+
Select any trained model, run all 5 analysis tasks or individual tasks, then test inference with user-controlled parameters.
|
| 360 |
+
</p>
|
| 361 |
+
</div>
|
| 362 |
"""
|
| 363 |
)
|
| 364 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 365 |
with gr.Row():
|
| 366 |
+
with gr.Column(scale=2, elem_classes=["panel"]):
|
| 367 |
+
checkpoint_dropdown = gr.Dropdown(
|
| 368 |
+
label="Model Checkpoint",
|
| 369 |
+
choices=list(checkpoint_map().keys()),
|
| 370 |
+
value=default_checkpoint_label(),
|
| 371 |
+
interactive=True,
|
| 372 |
+
)
|
| 373 |
+
with gr.Column(scale=1, elem_classes=["panel"]):
|
| 374 |
+
refresh_btn = gr.Button("Refresh Models")
|
| 375 |
+
load_btn = gr.Button("Load Selected Model", variant="primary")
|
| 376 |
+
|
| 377 |
+
load_status = gr.Markdown("Select a model and load.")
|
| 378 |
+
model_info = gr.JSON(label="Loaded Model Details")
|
| 379 |
+
|
| 380 |
+
with gr.Tabs():
|
| 381 |
+
with gr.Tab("1) Task Runner"):
|
| 382 |
+
with gr.Row():
|
| 383 |
+
with gr.Column(scale=2):
|
| 384 |
+
analysis_output_dir = gr.Textbox(
|
| 385 |
+
label="Analysis Output Directory",
|
| 386 |
+
value=DEFAULT_ANALYSIS_OUT,
|
| 387 |
+
)
|
| 388 |
+
analysis_input = gr.Textbox(
|
| 389 |
+
label="Task 2 Input Text",
|
| 390 |
+
value="dharmo rakṣati rakṣitaḥ",
|
| 391 |
+
lines=2,
|
| 392 |
+
)
|
| 393 |
+
with gr.Column(scale=1):
|
| 394 |
+
task4_phase = gr.Dropdown(
|
| 395 |
+
choices=["analyze", "generate_configs"],
|
| 396 |
+
value="analyze",
|
| 397 |
+
label="Task 4 Phase",
|
| 398 |
+
)
|
| 399 |
+
run_all_btn = gr.Button("Run All 5 Tasks", variant="primary")
|
| 400 |
+
|
| 401 |
+
with gr.Row():
|
| 402 |
+
task_choice = gr.Dropdown(
|
| 403 |
+
choices=["1", "2", "3", "4", "5"],
|
| 404 |
+
value="1",
|
| 405 |
+
label="Single Task",
|
| 406 |
+
)
|
| 407 |
+
run_single_btn = gr.Button("Run Selected Task")
|
| 408 |
+
refresh_outputs_btn = gr.Button("Refresh Output Viewer")
|
| 409 |
+
|
| 410 |
+
task_run_status = gr.Markdown("")
|
| 411 |
+
task_run_log = gr.Textbox(label="Task Execution Log", lines=18, interactive=False)
|
| 412 |
+
|
| 413 |
+
with gr.Accordion("Task Outputs Viewer", open=True):
|
| 414 |
+
task1_box = gr.Textbox(label="Task 1 Report", lines=10, interactive=False)
|
| 415 |
+
task2_box = gr.Textbox(label="Task 2 Report", lines=10, interactive=False)
|
| 416 |
+
with gr.Row():
|
| 417 |
+
task2_drift_img = gr.Image(label="Task2 Drift", type="filepath")
|
| 418 |
+
task2_attn_img = gr.Image(label="Task2 Attention", type="filepath")
|
| 419 |
+
task3_box = gr.Textbox(label="Task 3 Report", lines=10, interactive=False)
|
| 420 |
+
task3_img = gr.Image(label="Task3 Concept Space", type="filepath")
|
| 421 |
+
task5_box = gr.Textbox(label="Task 5 Report", lines=10, interactive=False)
|
| 422 |
+
task4_img = gr.Image(label="Task4 3D Ablation Plot", type="filepath")
|
| 423 |
+
|
| 424 |
+
with gr.Tab("2) Inference Playground"):
|
| 425 |
+
with gr.Row():
|
| 426 |
+
with gr.Column(scale=2):
|
| 427 |
+
input_text = gr.Textbox(
|
| 428 |
+
label="Input (Roman / IAST)",
|
| 429 |
+
lines=4,
|
| 430 |
+
value="dharmo rakṣati rakṣitaḥ",
|
| 431 |
+
)
|
| 432 |
+
output_text = gr.Textbox(
|
| 433 |
+
label="Output (Devanagari)",
|
| 434 |
+
lines=7,
|
| 435 |
+
interactive=False,
|
| 436 |
+
)
|
| 437 |
+
run_status = gr.Markdown("")
|
| 438 |
+
run_record = gr.JSON(label="Inference Metadata")
|
| 439 |
+
with gr.Column(scale=1, elem_classes=["panel"]):
|
| 440 |
+
preset = gr.Radio(["Manual", "Literal", "Balanced", "Creative"], value="Balanced", label="Preset")
|
| 441 |
+
temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
|
| 442 |
+
top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
|
| 443 |
+
repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
|
| 444 |
+
diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
|
| 445 |
+
num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
|
| 446 |
+
clean_output = gr.Checkbox(value=True, label="Clean Output")
|
| 447 |
+
generate_btn = gr.Button("Generate", variant="primary")
|
| 448 |
+
|
| 449 |
+
gr.Examples(
|
| 450 |
+
examples=[
|
| 451 |
+
["dharmo rakṣati rakṣitaḥ"],
|
| 452 |
+
["satyameva jayate"],
|
| 453 |
+
["yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ"],
|
| 454 |
+
],
|
| 455 |
+
inputs=[input_text],
|
| 456 |
+
)
|
| 457 |
+
|
| 458 |
+
def refresh_checkpoints():
|
| 459 |
+
choices = list(checkpoint_map().keys())
|
| 460 |
+
value = default_checkpoint_label() if choices else None
|
| 461 |
+
return gr.Dropdown(choices=choices, value=value)
|
| 462 |
+
|
| 463 |
+
refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown])
|
| 464 |
+
load_btn.click(
|
| 465 |
+
fn=load_selected_model,
|
| 466 |
+
inputs=[checkpoint_dropdown],
|
| 467 |
+
outputs=[model_state, load_status, model_info, num_steps, analysis_output_dir],
|
| 468 |
+
)
|
| 469 |
|
| 470 |
+
preset.change(
|
| 471 |
+
fn=apply_preset,
|
| 472 |
+
inputs=[preset],
|
| 473 |
+
outputs=[temperature, top_k, repetition_penalty, diversity_penalty],
|
| 474 |
+
)
|
| 475 |
|
|
|
|
| 476 |
generate_btn.click(
|
| 477 |
+
fn=generate_from_ui,
|
| 478 |
inputs=[
|
| 479 |
+
model_state,
|
| 480 |
+
input_text,
|
| 481 |
+
temperature,
|
| 482 |
+
top_k,
|
| 483 |
+
repetition_penalty,
|
| 484 |
+
diversity_penalty,
|
| 485 |
+
num_steps,
|
| 486 |
+
clean_output,
|
| 487 |
],
|
| 488 |
+
outputs=[output_text, run_status, run_record],
|
| 489 |
)
|
| 490 |
+
input_text.submit(
|
| 491 |
+
fn=generate_from_ui,
|
| 492 |
inputs=[
|
| 493 |
+
model_state,
|
| 494 |
+
input_text,
|
| 495 |
+
temperature,
|
| 496 |
+
top_k,
|
| 497 |
+
repetition_penalty,
|
| 498 |
+
diversity_penalty,
|
| 499 |
+
num_steps,
|
| 500 |
+
clean_output,
|
| 501 |
+
],
|
| 502 |
+
outputs=[output_text, run_status, run_record],
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
run_single_btn.click(
|
| 506 |
+
fn=run_single_task,
|
| 507 |
+
inputs=[model_state, task_choice, analysis_output_dir, analysis_input, task4_phase],
|
| 508 |
+
outputs=[task_run_status, task_run_log],
|
| 509 |
+
)
|
| 510 |
+
run_all_btn.click(
|
| 511 |
+
fn=run_all_tasks,
|
| 512 |
+
inputs=[model_state, analysis_output_dir, analysis_input, task4_phase],
|
| 513 |
+
outputs=[task_run_status, task_run_log],
|
| 514 |
+
)
|
| 515 |
+
refresh_outputs_btn.click(
|
| 516 |
+
fn=refresh_task_outputs,
|
| 517 |
+
inputs=[analysis_output_dir],
|
| 518 |
+
outputs=[
|
| 519 |
+
task1_box,
|
| 520 |
+
task2_box,
|
| 521 |
+
task2_drift_img,
|
| 522 |
+
task2_attn_img,
|
| 523 |
+
task3_box,
|
| 524 |
+
task3_img,
|
| 525 |
+
task5_box,
|
| 526 |
+
task4_img,
|
| 527 |
+
],
|
| 528 |
+
)
|
| 529 |
+
demo.load(
|
| 530 |
+
fn=refresh_task_outputs,
|
| 531 |
+
inputs=[analysis_output_dir],
|
| 532 |
+
outputs=[
|
| 533 |
+
task1_box,
|
| 534 |
+
task2_box,
|
| 535 |
+
task2_drift_img,
|
| 536 |
+
task2_attn_img,
|
| 537 |
+
task3_box,
|
| 538 |
+
task3_img,
|
| 539 |
+
task5_box,
|
| 540 |
+
task4_img,
|
| 541 |
],
|
|
|
|
| 542 |
)
|
| 543 |
|
| 544 |
|
| 545 |
if __name__ == "__main__":
|
| 546 |
+
port = int(os.environ["GRADIO_SERVER_PORT"]) if "GRADIO_SERVER_PORT" in os.environ else None
|
| 547 |
demo.launch(server_name="0.0.0.0", server_port=port, share=False)
|
data/__init__.py
ADDED
|
File without changes
|
data/dataset.py
ADDED
|
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
dataset.py — Cross-Script Translation Fix
|
| 3 |
+
==========================================
|
| 4 |
+
INPUT : quote_text (Roman/IAST transliteration of Sanskrit)
|
| 5 |
+
TARGET : quote_devanagari (Devanagari script)
|
| 6 |
+
|
| 7 |
+
This is the CORRECT task: the model learns to transliterate / translate
|
| 8 |
+
Roman Sanskrit → Devanagari, which is a meaningful, learnable mapping
|
| 9 |
+
(far better than devanagari→devanagari reconstruction which teaches nothing).
|
| 10 |
+
|
| 11 |
+
KEY CHANGES from original:
|
| 12 |
+
1. _input_field = 'quote_text' (was 'quote_devanagari')
|
| 13 |
+
2. _target_field = 'quote_devanagari' (unchanged)
|
| 14 |
+
3. Separate source/target tokenizers — Roman and Devanagari have
|
| 15 |
+
completely different character sets; a shared BPE vocab forces the
|
| 16 |
+
model to learn both scripts in one embedding table, which wastes
|
| 17 |
+
capacity and confuses the attention mechanism.
|
| 18 |
+
4. Negative example generation fixed — reversal now operates on
|
| 19 |
+
DEVANAGARI target only (not accidentally on Roman source).
|
| 20 |
+
5. curriculum_sort uses target length (Devanagari) for difficulty proxy.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from datasets import load_dataset
|
| 24 |
+
from torch.utils.data import Dataset
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
import random
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class OptimizedSanskritDataset(Dataset):
|
| 31 |
+
def __init__(self, split='train', tokenizer=None, max_len=80, cfg=None,
|
| 32 |
+
src_tokenizer=None, tgt_tokenizer=None):
|
| 33 |
+
"""
|
| 34 |
+
Args:
|
| 35 |
+
tokenizer : shared tokenizer (legacy — used if src/tgt not provided)
|
| 36 |
+
src_tokenizer : tokenizer for quote_text (Roman script)
|
| 37 |
+
tgt_tokenizer : tokenizer for quote_devanagari (Devanagari script)
|
| 38 |
+
If None, falls back to shared `tokenizer`.
|
| 39 |
+
"""
|
| 40 |
+
from config import CONFIG
|
| 41 |
+
self.cfg = cfg or CONFIG
|
| 42 |
+
self.max_len = max_len
|
| 43 |
+
self.pad_id = 1
|
| 44 |
+
self.mask_id = self.cfg['diffusion']['mask_token_id']
|
| 45 |
+
self.include_negatives = self.cfg['data']['include_negative_examples']
|
| 46 |
+
|
| 47 |
+
# ── Tokenizer setup ───────────────────────────────────────────
|
| 48 |
+
# Support both legacy (shared) and new (separate src/tgt) tokenizers
|
| 49 |
+
self.src_tokenizer = src_tokenizer or tokenizer
|
| 50 |
+
self.tgt_tokenizer = tgt_tokenizer or tokenizer
|
| 51 |
+
|
| 52 |
+
if self.src_tokenizer is None:
|
| 53 |
+
raise ValueError("Provide at least one tokenizer.")
|
| 54 |
+
|
| 55 |
+
print(f"📥 Loading '{split}' split …")
|
| 56 |
+
raw = load_dataset("paws/sanskrit-verses-gretil", split=split)
|
| 57 |
+
cols = raw.column_names
|
| 58 |
+
|
| 59 |
+
# ── Field selection ───────────────────────────────────────────
|
| 60 |
+
if 'quote_text' in cols and 'quote_devanagari' in cols:
|
| 61 |
+
# CORRECT setup: Roman input → Devanagari output
|
| 62 |
+
self._input_field = 'quote_text'
|
| 63 |
+
self._target_field = 'quote_devanagari'
|
| 64 |
+
print(" Format: quote_text (Roman) → quote_devanagari (Devanagari) ✓")
|
| 65 |
+
elif 'sentence1' in cols and 'sentence2' in cols:
|
| 66 |
+
# PAWS paraphrase pairs fallback
|
| 67 |
+
self._input_field = 'sentence1'
|
| 68 |
+
self._target_field = 'sentence2'
|
| 69 |
+
print(" Format: PAWS sentence pairs ✓")
|
| 70 |
+
else:
|
| 71 |
+
# Last resort: same field both sides
|
| 72 |
+
self._input_field = 'quote_devanagari'
|
| 73 |
+
self._target_field = 'quote_devanagari'
|
| 74 |
+
print(" ⚠️ Format: Devanagari→Devanagari (suboptimal — no quote_text found)")
|
| 75 |
+
|
| 76 |
+
# ── Filter empty rows ─────────────────────────────────────────
|
| 77 |
+
# Some rows have empty quote_text — skip them
|
| 78 |
+
raw = raw.filter(
|
| 79 |
+
lambda ex: (
|
| 80 |
+
bool(ex[self._input_field].strip()) and
|
| 81 |
+
bool(ex[self._target_field].strip())
|
| 82 |
+
)
|
| 83 |
+
)
|
| 84 |
+
print(f" After empty-filter: {len(raw)} samples")
|
| 85 |
+
|
| 86 |
+
self.dataset = raw
|
| 87 |
+
|
| 88 |
+
if split == 'train':
|
| 89 |
+
self.dataset = self._curriculum_sort()
|
| 90 |
+
|
| 91 |
+
print(f"✅ {len(self.dataset)} samples loaded.")
|
| 92 |
+
|
| 93 |
+
# ── Encoding ──────────────────────────────────────────────────────
|
| 94 |
+
|
| 95 |
+
def _encode_src(self, text):
|
| 96 |
+
"""Encode source (Roman) text."""
|
| 97 |
+
ids = self.src_tokenizer.encode(text)[:self.max_len]
|
| 98 |
+
t = torch.tensor(ids, dtype=torch.long)
|
| 99 |
+
t = F.pad(t, (0, max(0, self.max_len - len(t))), value=self.pad_id)
|
| 100 |
+
return t
|
| 101 |
+
|
| 102 |
+
def _encode_tgt(self, text):
|
| 103 |
+
"""Encode target (Devanagari) text."""
|
| 104 |
+
ids = self.tgt_tokenizer.encode(text)[:self.max_len]
|
| 105 |
+
t = torch.tensor(ids, dtype=torch.long)
|
| 106 |
+
t = F.pad(t, (0, max(0, self.max_len - len(t))), value=self.pad_id)
|
| 107 |
+
return t
|
| 108 |
+
|
| 109 |
+
# ── Curriculum ───��────────────────────────────────────────────────
|
| 110 |
+
|
| 111 |
+
def _curriculum_sort(self):
|
| 112 |
+
"""Short, common Devanagari targets first → long, rare targets last."""
|
| 113 |
+
scores = []
|
| 114 |
+
for s in self.dataset:
|
| 115 |
+
text = s[self._target_field]
|
| 116 |
+
length = len(text.split())
|
| 117 |
+
rarity_score = len(set(text)) / max(1, len(text))
|
| 118 |
+
scores.append(length * (1 - rarity_score))
|
| 119 |
+
order = sorted(range(len(self.dataset)), key=lambda i: scores[i])
|
| 120 |
+
return self.dataset.select(order)
|
| 121 |
+
|
| 122 |
+
# ── Item ──────────────────────────────────────────────────────────
|
| 123 |
+
|
| 124 |
+
def __len__(self):
|
| 125 |
+
return len(self.dataset)
|
| 126 |
+
|
| 127 |
+
def __getitem__(self, idx):
|
| 128 |
+
sample = self.dataset[idx]
|
| 129 |
+
|
| 130 |
+
src_text = sample[self._input_field].strip()
|
| 131 |
+
tgt_text = sample[self._target_field].strip()
|
| 132 |
+
|
| 133 |
+
input_ids = self._encode_src(src_text) # Roman encoded with src_tokenizer
|
| 134 |
+
target_ids = self._encode_tgt(tgt_text) # Devanagari encoded with tgt_tokenizer
|
| 135 |
+
|
| 136 |
+
out = {
|
| 137 |
+
'input_ids': input_ids,
|
| 138 |
+
'target_ids': target_ids,
|
| 139 |
+
'input_text': src_text,
|
| 140 |
+
'target_text': tgt_text,
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
if self.include_negatives:
|
| 144 |
+
neg_ids = target_ids.clone()
|
| 145 |
+
# Reverse a random chunk of the DEVANAGARI target
|
| 146 |
+
non_pad = (neg_ids != self.pad_id).sum().item()
|
| 147 |
+
if non_pad > 4:
|
| 148 |
+
i1, i2 = sorted(random.sample(range(non_pad), 2))
|
| 149 |
+
neg_ids[i1:i2] = torch.flip(neg_ids[i1:i2], dims=[0])
|
| 150 |
+
out['negative_target_ids'] = neg_ids
|
| 151 |
+
|
| 152 |
+
return out
|
requirements.txt
CHANGED
|
@@ -4,3 +4,9 @@ numpy>=1.24
|
|
| 4 |
tqdm>=4.66
|
| 5 |
huggingface_hub==0.25.2
|
| 6 |
tokenizers>=0.15
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
tqdm>=4.66
|
| 5 |
huggingface_hub==0.25.2
|
| 6 |
tokenizers>=0.15
|
| 7 |
+
datasets>=2.20
|
| 8 |
+
scikit-learn>=1.4
|
| 9 |
+
matplotlib>=3.8
|
| 10 |
+
bert-score>=0.3.13
|
| 11 |
+
sentence-transformers>=3.0
|
| 12 |
+
nltk>=3.8
|