bhsinghgrid commited on
Commit
29e5bf8
·
verified ·
1 Parent(s): 35a343d

Upgrade UI: model selection + tasks 1-5 + analysis modules

Browse files
Files changed (41) hide show
  1. .gitattributes +1 -0
  2. __pycache__/app.cpython-311.pyc +0 -0
  3. analysis/__pycache__/run_analysis.cpython-311.pyc +0 -0
  4. analysis/attention_viz.py +621 -0
  5. analysis/concept_vectors.py +637 -0
  6. analysis/kv_cache_benchmark.py +451 -0
  7. analysis/outputs/task1_kv_cache.txt +23 -0
  8. analysis/outputs/task2_all_layers_t0.png +0 -0
  9. analysis/outputs/task2_attn_evolution.png +0 -0
  10. analysis/outputs/task2_attn_t0.png +0 -0
  11. analysis/outputs/task2_attn_t127.png +0 -0
  12. analysis/outputs/task2_examples/example_1_attn_t0.png +0 -0
  13. analysis/outputs/task2_examples/example_2_attn_t0.png +0 -0
  14. analysis/outputs/task2_examples/example_3_attn_t0.png +0 -0
  15. analysis/outputs/task2_examples/example_4_attn_t0.png +0 -0
  16. analysis/outputs/task2_examples/example_5_attn_t0.png +0 -0
  17. analysis/outputs/task2_report.txt +100 -0
  18. analysis/outputs/task2_semantic_drift.png +0 -0
  19. analysis/outputs/task2_source_alignment.png +0 -0
  20. analysis/outputs/task3_concept_space.png +3 -0
  21. analysis/outputs/task3_diversity_direction.npy +3 -0
  22. analysis/outputs/task3_report.txt +12 -0
  23. analysis/outputs/task5_quality_classifier.pt +3 -0
  24. analysis/outputs/task5_quality_data.npz +3 -0
  25. analysis/outputs_multi/results__d3pm_cross_attention_neg_False/task1/task1_kv_cache.txt +10 -0
  26. analysis/outputs_multi/results__d3pm_cross_attention_neg_True/task1/task1_kv_cache.txt +10 -0
  27. analysis/quality_classifier.py +723 -0
  28. analysis/reports/README.md +19 -0
  29. analysis/reports/task1_kv_cache_report.md +99 -0
  30. analysis/reports/task2_attention_drift_report.md +112 -0
  31. analysis/reports/task3_concept_vectors_report.md +96 -0
  32. analysis/reports/task4_step_ablation_report.md +89 -0
  33. analysis/reports/task5_quality_guidance_report.md +101 -0
  34. analysis/run_analysis.py +466 -0
  35. analysis/run_tasks_except4_all_models.py +123 -0
  36. analysis/semantic_drift.py +569 -0
  37. analysis/step_ablation.py +582 -0
  38. app.py +487 -175
  39. data/__init__.py +0 -0
  40. data/dataset.py +152 -0
  41. requirements.txt +6 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ analysis/outputs/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
__pycache__/app.cpython-311.pyc ADDED
Binary file (29.2 kB). View file
 
analysis/__pycache__/run_analysis.cpython-311.pyc ADDED
Binary file (32 kB). View file
 
analysis/attention_viz.py ADDED
@@ -0,0 +1,621 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # analysis/attention_viz.py
3
+ # ==========================
4
+ # Task 2: Attention weight capture and visualization across diffusion steps.
5
+ #
6
+ # How it works (no retraining needed):
7
+ # MultiHeadAttention now has two attributes:
8
+ # - capture_weights: bool — set True to start storing weights
9
+ # - last_attn_weights: Tensor — [B, n_heads, Lq, Lk], updated each forward call
10
+ #
11
+ # AttentionCapture:
12
+ # - Sets capture_weights=True on all cross-attention layers
13
+ # - Hooks into generate_cached() to record weights at every diffusion step
14
+ # - Returns a dict: {t_val: [layer_0_weights, layer_1_weights, ...]}
15
+ #
16
+ # Visualization:
17
+ # - plot_attn_heatmap(): shows src→tgt alignment at a single step
18
+ # - plot_attn_evolution(): shows how one src→tgt pair evolves over T steps
19
+ # - plot_all_layers(): grid of heatmaps per layer at a given step
20
+ #
21
+ # Usage:
22
+ # from analysis.attention_viz import AttentionCapture, plot_attn_heatmap
23
+ #
24
+ # capturer = AttentionCapture(model)
25
+ # weights = capturer.capture(src_ids, src_tokens, tgt_tokens)
26
+ # plot_attn_heatmap(weights, step=0, layer=0, src_tokens=..., tgt_tokens=...)
27
+ # """
28
+ #
29
+ # import torch
30
+ # import numpy as np
31
+ # import os
32
+ # from typing import List, Dict, Optional
33
+ #
34
+ #
35
+ # # ── Attention capture ─────────────────────────────────────────────────
36
+ #
37
+ # class AttentionCapture:
38
+ # """
39
+ # Captures cross-attention weights from all decoder layers at every
40
+ # diffusion step during generate_cached().
41
+ #
42
+ # Works by:
43
+ # 1. Setting capture_weights=True on each DecoderBlock.cross_attn
44
+ # 2. Running generate_cached() (encoder runs once via KV cache)
45
+ # 3. After each denoising step, reading last_attn_weights from each layer
46
+ # 4. Storing as {t_val: list_of_layer_weights}
47
+ #
48
+ # Zero retraining required — uses the flag added to MultiHeadAttention.
49
+ # """
50
+ #
51
+ # def __init__(self, model):
52
+ # """
53
+ # Args:
54
+ # model : SanskritModel wrapper (must be D3PMCrossAttention)
55
+ # """
56
+ # self.model = model
57
+ # self.inner = model.model # D3PMCrossAttention
58
+ # self._cross_attns = []
59
+ #
60
+ # # Collect all cross-attention modules from decoder blocks
61
+ # if hasattr(self.inner, 'decoder_blocks'):
62
+ # for block in self.inner.decoder_blocks:
63
+ # if hasattr(block, 'cross_attn'):
64
+ # self._cross_attns.append(block.cross_attn)
65
+ #
66
+ # if not self._cross_attns:
67
+ # raise ValueError(
68
+ # "No cross-attention layers found. "
69
+ # "AttentionCapture only works with D3PMCrossAttention."
70
+ # )
71
+ #
72
+ # print(f"AttentionCapture: found {len(self._cross_attns)} cross-attention layers.")
73
+ #
74
+ # def _enable(self):
75
+ # """Turn on weight capture for all cross-attention layers."""
76
+ # for ca in self._cross_attns:
77
+ # ca.capture_weights = True
78
+ #
79
+ # def _disable(self):
80
+ # """Turn off weight capture (restores zero overhead)."""
81
+ # for ca in self._cross_attns:
82
+ # ca.capture_weights = False
83
+ # ca.last_attn_weights = None
84
+ #
85
+ # def _read_weights(self) -> List[np.ndarray]:
86
+ # """
87
+ # Read current last_attn_weights from all layers.
88
+ # Returns list of [B, n_heads, Lq, Lk] arrays — one per layer.
89
+ # Averages over heads to produce [B, Lq, Lk].
90
+ # """
91
+ # weights = []
92
+ # for ca in self._cross_attns:
93
+ # if ca.last_attn_weights is not None:
94
+ # # Average over attention heads → [B, Lq, Lk]
95
+ # w = ca.last_attn_weights.float().mean(dim=1)
96
+ # weights.append(w.numpy())
97
+ # return weights
98
+ #
99
+ # @torch.no_grad()
100
+ # def capture(
101
+ # self,
102
+ # src: torch.Tensor,
103
+ # capture_every: int = 10,
104
+ # ) -> Dict[int, List[np.ndarray]]:
105
+ # """
106
+ # Run full generation while capturing attention at every `capture_every` steps.
107
+ #
108
+ # Args:
109
+ # src : [1, src_len] or [B, src_len] IAST token ids
110
+ # capture_every : capture weights every N steps (default 10)
111
+ # Use 1 to capture every step (slow, high memory).
112
+ #
113
+ # Returns:
114
+ # step_weights : dict mapping t_val → list of [B, Lq, Lk] arrays
115
+ # one array per decoder layer
116
+ # keys are t values: T-1, T-1-N, ..., 0
117
+ #
118
+ # Example:
119
+ # weights = capturer.capture(src_ids, capture_every=10)
120
+ # # weights[127] = layer weights at t=127 (heavy noise)
121
+ # # weights[0] = layer weights at t=0 (clean output)
122
+ # """
123
+ # if src.dim() == 1:
124
+ # src = src.unsqueeze(0)
125
+ #
126
+ # inner = self.inner
127
+ # T = inner.scheduler.num_timesteps
128
+ # device = src.device
129
+ #
130
+ # # KV cache: encode source once
131
+ # memory, src_pad_mask = inner.encode_source(src)
132
+ #
133
+ # B = src.shape[0]
134
+ # tgt_len = inner.max_seq_len
135
+ # mask_id = inner.mask_token_id
136
+ #
137
+ # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
138
+ # hint = None
139
+ #
140
+ # step_weights: Dict[int, List[np.ndarray]] = {}
141
+ #
142
+ # self._enable()
143
+ # try:
144
+ # inner.eval()
145
+ # for t_val in range(T - 1, -1, -1):
146
+ # t = torch.full((B,), t_val, dtype=torch.long, device=device)
147
+ # is_last = (t_val == 0)
148
+ #
149
+ # logits, _ = inner.forward_cached(
150
+ # memory, src_pad_mask, x0_est, t,
151
+ # x0_hint=hint, inference_mode=True,
152
+ # )
153
+ #
154
+ # # Capture at this step if scheduled or it's the last step
155
+ # if (T - 1 - t_val) % capture_every == 0 or is_last:
156
+ # step_weights[t_val] = self._read_weights()
157
+ #
158
+ # import torch.nn.functional as F
159
+ # probs = F.softmax(logits / 0.8, dim=-1)
160
+ # x0_est = torch.argmax(probs, dim=-1) if is_last else \
161
+ # _multinomial_sample(probs)
162
+ # hint = x0_est
163
+ #
164
+ # finally:
165
+ # self._disable() # always restore — even if exception raised
166
+ #
167
+ # print(f"Captured attention at {len(step_weights)} steps "
168
+ # f"({len(self._cross_attns)} layers each).")
169
+ # return step_weights
170
+ #
171
+ #
172
+ # def _multinomial_sample(probs: torch.Tensor) -> torch.Tensor:
173
+ # B, L, V = probs.shape
174
+ # flat = probs.view(B * L, V).clamp(min=1e-9)
175
+ # flat = flat / flat.sum(dim=-1, keepdim=True)
176
+ # return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
177
+ #
178
+ #
179
+ # # ── Visualization ─────────────────────────────────────────────────────
180
+ #
181
+ # def plot_attn_heatmap(
182
+ # step_weights: Dict[int, List[np.ndarray]],
183
+ # t_val: int,
184
+ # layer: int,
185
+ # src_tokens: List[str],
186
+ # tgt_tokens: List[str],
187
+ # sample_idx: int = 0,
188
+ # save_path: Optional[str] = None,
189
+ # title: Optional[str] = None,
190
+ # ):
191
+ # """
192
+ # Plot cross-attention heatmap for a single step and layer.
193
+ #
194
+ # X-axis = source (IAST) tokens
195
+ # Y-axis = target (Devanagari) positions
196
+ # Color = attention weight (brighter = stronger attention)
197
+ #
198
+ # Args:
199
+ # step_weights : output of AttentionCapture.capture()
200
+ # t_val : which diffusion step to visualize
201
+ # layer : which decoder layer (0 = first, -1 = last)
202
+ # src_tokens : list of IAST token strings for x-axis labels
203
+ # tgt_tokens : list of Devanagari token strings for y-axis labels
204
+ # sample_idx : which batch item to visualize (default 0)
205
+ # save_path : if given, save figure to this path
206
+ # title : custom plot title
207
+ # """
208
+ # try:
209
+ # import matplotlib.pyplot as plt
210
+ # import matplotlib.ticker as ticker
211
+ # except ImportError:
212
+ # print("pip install matplotlib to use visualization functions.")
213
+ # return
214
+ #
215
+ # if t_val not in step_weights:
216
+ # available = sorted(step_weights.keys())
217
+ # raise ValueError(
218
+ # f"t_val={t_val} not in captured steps. "
219
+ # f"Available: {available[:5]}...{available[-5:]}"
220
+ # )
221
+ #
222
+ # layers = step_weights[t_val]
223
+ # weights = layers[layer][sample_idx] # [Lq, Lk]
224
+ #
225
+ # # Trim to actual token lengths
226
+ # n_src = min(len(src_tokens), weights.shape[1])
227
+ # n_tgt = min(len(tgt_tokens), weights.shape[0])
228
+ # weights = weights[:n_tgt, :n_src]
229
+ #
230
+ # fig, ax = plt.subplots(figsize=(max(8, n_src * 0.4), max(6, n_tgt * 0.35)))
231
+ # im = ax.imshow(weights, aspect='auto', cmap='YlOrRd', interpolation='nearest')
232
+ #
233
+ # ax.set_xticks(range(n_src))
234
+ # ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=9)
235
+ # ax.set_yticks(range(n_tgt))
236
+ # ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=9)
237
+ #
238
+ # ax.set_xlabel("Source (IAST)", fontsize=11)
239
+ # ax.set_ylabel("Target position (Devanagari)", fontsize=11)
240
+ #
241
+ # plot_title = title or f"Cross-Attention | t={t_val} | Layer {layer}"
242
+ # ax.set_title(plot_title, fontsize=12, pad=10)
243
+ #
244
+ # plt.colorbar(im, ax=ax, label="Attention weight")
245
+ # plt.tight_layout()
246
+ #
247
+ # if save_path:
248
+ # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
249
+ # plt.savefig(save_path, dpi=150, bbox_inches='tight')
250
+ # print(f"Saved: {save_path}")
251
+ # else:
252
+ # plt.show()
253
+ # plt.close()
254
+ #
255
+ #
256
+ # def plot_attn_evolution(
257
+ # step_weights: Dict[int, List[np.ndarray]],
258
+ # src_token_idx: int,
259
+ # tgt_token_idx: int,
260
+ # layer: int = -1,
261
+ # sample_idx: int = 0,
262
+ # src_token_str: str = "",
263
+ # tgt_token_str: str = "",
264
+ # save_path: Optional[str] = None,
265
+ # ):
266
+ # """
267
+ # Plot how attention between one specific src↔tgt token pair evolves
268
+ # across all captured diffusion steps (T → 0).
269
+ #
270
+ # Reveals whether a token pair is 'locked' (stable from early steps)
271
+ # or 'flexible' (weight fluctuates until final steps).
272
+ #
273
+ # Args:
274
+ # step_weights : output of AttentionCapture.capture()
275
+ # src_token_idx : index of source token to track
276
+ # tgt_token_idx : index of target position to track
277
+ # layer : decoder layer index
278
+ # sample_idx : batch item
279
+ # src_token_str : string label for the source token (for plot title)
280
+ # tgt_token_str : string label for the target token (for plot title)
281
+ # save_path : if given, save figure to this path
282
+ # """
283
+ # try:
284
+ # import matplotlib.pyplot as plt
285
+ # except ImportError:
286
+ # print("pip install matplotlib to use visualization functions.")
287
+ # return
288
+ #
289
+ # t_vals = sorted(step_weights.keys(), reverse=True) # T-1 → 0
290
+ # weights = []
291
+ #
292
+ # for t_val in t_vals:
293
+ # layers = step_weights[t_val]
294
+ # w = layers[layer][sample_idx] # [Lq, Lk]
295
+ # if tgt_token_idx < w.shape[0] and src_token_idx < w.shape[1]:
296
+ # weights.append(w[tgt_token_idx, src_token_idx])
297
+ # else:
298
+ # weights.append(0.0)
299
+ #
300
+ # fig, ax = plt.subplots(figsize=(12, 4))
301
+ # ax.plot(range(len(t_vals)), weights, linewidth=1.5, color='steelblue')
302
+ # ax.fill_between(range(len(t_vals)), weights, alpha=0.2, color='steelblue')
303
+ #
304
+ # # Mark every 10th step on x-axis
305
+ # step_labels = [str(t) if i % max(1, len(t_vals)//10) == 0 else ""
306
+ # for i, t in enumerate(t_vals)]
307
+ # ax.set_xticks(range(len(t_vals)))
308
+ # ax.set_xticklabels(step_labels, fontsize=8)
309
+ # ax.set_xlabel("Diffusion step (T → 0)", fontsize=11)
310
+ # ax.set_ylabel("Attention weight", fontsize=11)
311
+ #
312
+ # pair_str = f"src[{src_token_idx}]={src_token_str!r} → tgt[{tgt_token_idx}]={tgt_token_str!r}"
313
+ # ax.set_title(f"Attention evolution | {pair_str} | Layer {layer}", fontsize=11)
314
+ # ax.set_xlim(0, len(t_vals) - 1)
315
+ # ax.set_ylim(0, None)
316
+ # plt.tight_layout()
317
+ #
318
+ # if save_path:
319
+ # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
320
+ # plt.savefig(save_path, dpi=150, bbox_inches='tight')
321
+ # print(f"Saved: {save_path}")
322
+ # else:
323
+ # plt.show()
324
+ # plt.close()
325
+ #
326
+ #
327
+ # def plot_all_layers(
328
+ # step_weights: Dict[int, List[np.ndarray]],
329
+ # t_val: int,
330
+ # src_tokens: List[str],
331
+ # tgt_tokens: List[str],
332
+ # sample_idx: int = 0,
333
+ # save_path: Optional[str] = None,
334
+ # ):
335
+ # """
336
+ # Plot attention heatmaps for ALL decoder layers at a single diffusion step.
337
+ # Shows how different layers specialize their attention patterns.
338
+ # """
339
+ # try:
340
+ # import matplotlib.pyplot as plt
341
+ # except ImportError:
342
+ # print("pip install matplotlib to use visualization functions.")
343
+ # return
344
+ #
345
+ # layers = step_weights[t_val]
346
+ # n_layers = len(layers)
347
+ # n_cols = min(4, n_layers)
348
+ # n_rows = (n_layers + n_cols - 1) // n_cols
349
+ #
350
+ # fig, axes = plt.subplots(n_rows, n_cols,
351
+ # figsize=(n_cols * 5, n_rows * 4))
352
+ # axes = np.array(axes).flatten() if n_layers > 1 else [axes]
353
+ #
354
+ # n_src = min(len(src_tokens), layers[0][sample_idx].shape[1])
355
+ # n_tgt = min(len(tgt_tokens), layers[0][sample_idx].shape[0])
356
+ #
357
+ # for i, (ax, layer_w) in enumerate(zip(axes, layers)):
358
+ # w = layer_w[sample_idx][:n_tgt, :n_src]
359
+ # im = ax.imshow(w, aspect='auto', cmap='YlOrRd', interpolation='nearest',
360
+ # vmin=0, vmax=w.max())
361
+ # ax.set_title(f"Layer {i}", fontsize=10)
362
+ # ax.set_xticks(range(n_src))
363
+ # ax.set_xticklabels(src_tokens[:n_src], rotation=45, ha='right', fontsize=7)
364
+ # ax.set_yticks(range(n_tgt))
365
+ # ax.set_yticklabels(tgt_tokens[:n_tgt], fontsize=7)
366
+ #
367
+ # for ax in axes[n_layers:]:
368
+ # ax.set_visible(False)
369
+ #
370
+ # fig.suptitle(f"All layers at t={t_val}", fontsize=13, y=1.02)
371
+ # plt.tight_layout()
372
+ #
373
+ # if save_path:
374
+ # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
375
+ # plt.savefig(save_path, dpi=150, bbox_inches='tight')
376
+ # print(f"Saved: {save_path}")
377
+ # else:
378
+ # plt.show()
379
+ # plt.close()
380
+ """
381
+ analysis/task2_full.py
382
+ =====================
383
+
384
+ FULL Task 2 implementation:
385
+ ✔ Attention trajectory (already yours)
386
+ ✔ BERTScore over diffusion steps
387
+ ✔ Semantic drift metric
388
+ ✔ Locked vs flexible token detection
389
+ ✔ TF-IDF vs attention stability correlation
390
+ """
391
+
392
+ import torch
393
+ import numpy as np
394
+ from typing import Dict, List
395
+ from collections import defaultdict
396
+
397
+ # Optional metrics
398
+ from sklearn.feature_extraction.text import TfidfVectorizer
399
+
400
+ try:
401
+ import evaluate
402
+ bertscore = evaluate.load("bertscore")
403
+ USE_BERT = True
404
+ except:
405
+ USE_BERT = False
406
+
407
+
408
+ # ─────────────────────────────────────────────────────────────
409
+ # 1. ATTENTION CAPTURE (FIXED VERSION)
410
+ # ─────────────────────────────────────────────────────────────
411
+
412
+ class AttentionCapture:
413
+ def __init__(self, model):
414
+ self.model = model
415
+ self.inner = model.model
416
+ self.cross_attns = []
417
+
418
+ for block in self.inner.decoder_blocks:
419
+ if hasattr(block, "cross_attn"):
420
+ self.cross_attns.append(block.cross_attn)
421
+
422
+ def _enable(self):
423
+ for ca in self.cross_attns:
424
+ ca.capture_weights = True
425
+
426
+ def _disable(self):
427
+ for ca in self.cross_attns:
428
+ ca.capture_weights = False
429
+ ca.last_attn_weights = None
430
+
431
+ def _read(self):
432
+ weights = []
433
+ for ca in self.cross_attns:
434
+ if ca.last_attn_weights is not None:
435
+ w = ca.last_attn_weights.mean(dim=1) # avg heads
436
+ weights.append(w.cpu().numpy())
437
+ return weights
438
+
439
+ @torch.no_grad()
440
+ def run(self, src_ids):
441
+ inner = self.inner
442
+ T = inner.scheduler.num_timesteps
443
+ device = src_ids.device
444
+
445
+ memory, mask = inner.encode_source(src_ids)
446
+
447
+ x = torch.full(
448
+ (1, inner.max_seq_len),
449
+ inner.mask_token_id,
450
+ dtype=torch.long,
451
+ device=device
452
+ )
453
+
454
+ hint = None
455
+ step_weights = {}
456
+ step_outputs = {}
457
+
458
+ self._enable()
459
+
460
+ try:
461
+ for t_val in range(T - 1, -1, -1):
462
+ t = torch.tensor([t_val], device=device)
463
+
464
+ logits, _ = inner.forward_cached(
465
+ memory, mask, x, t, x0_hint=hint, inference_mode=True
466
+ )
467
+
468
+ probs = torch.softmax(logits, dim=-1)
469
+ x = torch.argmax(probs, dim=-1)
470
+
471
+ step_weights[t_val] = self._read()
472
+ step_outputs[t_val] = x.clone()
473
+
474
+ hint = x
475
+
476
+ finally:
477
+ self._disable()
478
+
479
+ return step_weights, step_outputs
480
+
481
+
482
+ # ─────────────────────────────────────────────────────────────
483
+ # 2. BERTScore + Semantic Drift
484
+ # ─────────────────────────────────────────────────────────────
485
+
486
+ def compute_trajectory_metrics(
487
+ step_outputs,
488
+ tgt_tokenizer,
489
+ reference_text
490
+ ):
491
+ trajectory = []
492
+
493
+ for t, ids in step_outputs.items():
494
+ text = tgt_tokenizer.decode(
495
+ [x for x in ids[0].tolist() if x > 4]
496
+ )
497
+
498
+ if USE_BERT:
499
+ score = bertscore.compute(
500
+ predictions=[text],
501
+ references=[reference_text],
502
+ lang="hi"
503
+ )["f1"][0]
504
+ else:
505
+ score = 0.0
506
+
507
+ drift = 1.0 - score
508
+
509
+ trajectory.append({
510
+ "step": t,
511
+ "text": text,
512
+ "bert": score,
513
+ "drift": drift
514
+ })
515
+
516
+ return sorted(trajectory, key=lambda x: -x["step"])
517
+
518
+
519
+ # ─────────────────────────────────────────────────────────────
520
+ # 3. LOCKED vs FLEXIBLE TOKENS
521
+ # ─────────────────────────────────────────────────────────────
522
+
523
+ def analyze_token_stability(step_weights):
524
+ """
525
+ Measure variance of attention over time
526
+ """
527
+ token_stability = defaultdict(list)
528
+
529
+ for t, layers in step_weights.items():
530
+ last_layer = layers[-1][0] # [Lq, Lk]
531
+
532
+ # max attention source index per target token
533
+ align = np.argmax(last_layer, axis=1)
534
+
535
+ for tgt_idx, src_idx in enumerate(align):
536
+ token_stability[tgt_idx].append(src_idx)
537
+
538
+ results = {}
539
+
540
+ for tgt_idx, src_seq in token_stability.items():
541
+ changes = sum(
542
+ 1 for i in range(1, len(src_seq))
543
+ if src_seq[i] != src_seq[i-1]
544
+ )
545
+
546
+ if changes <= 2:
547
+ results[tgt_idx] = "LOCKED"
548
+ else:
549
+ results[tgt_idx] = "FLEXIBLE"
550
+
551
+ return results
552
+
553
+
554
+ # ─────────────────────────────────────────────────────────────
555
+ # 4. TF-IDF vs ATTENTION STABILITY
556
+ # ─────────────────────────���───────────────────────────────────
557
+
558
+ def tfidf_attention_correlation(src_text, step_weights):
559
+ vectorizer = TfidfVectorizer()
560
+ tfidf = vectorizer.fit_transform([src_text]).toarray()[0]
561
+
562
+ # Avg attention over steps
563
+ attn_scores = None
564
+
565
+ for t, layers in step_weights.items():
566
+ w = layers[-1][0] # last layer
567
+ avg = w.mean(axis=0) # per source token
568
+
569
+ if attn_scores is None:
570
+ attn_scores = avg
571
+ else:
572
+ attn_scores += avg
573
+
574
+ attn_scores /= len(step_weights)
575
+
576
+ # Correlation
577
+ min_len = min(len(tfidf), len(attn_scores))
578
+ corr = np.corrcoef(tfidf[:min_len], attn_scores[:min_len])[0, 1]
579
+
580
+ return corr
581
+
582
+
583
+ # ─────────────────────────────────────────────────────────────
584
+ # 5. FULL PIPELINE
585
+ # ─────────────────────────────────────────────────────────────
586
+
587
+ def run_task2_analysis(
588
+ text,
589
+ model,
590
+ src_tokenizer,
591
+ tgt_tokenizer,
592
+ device
593
+ ):
594
+ src_ids = torch.tensor(
595
+ [src_tokenizer.encode(text)],
596
+ device=device
597
+ )
598
+
599
+ capturer = AttentionCapture(model)
600
+
601
+ # Step 1: Capture
602
+ step_weights, step_outputs = capturer.run(src_ids)
603
+
604
+ # Step 2: Metrics
605
+ trajectory = compute_trajectory_metrics(
606
+ step_outputs,
607
+ tgt_tokenizer,
608
+ reference_text=text # transliteration task
609
+ )
610
+
611
+ # Step 3: Token stability
612
+ stability = analyze_token_stability(step_weights)
613
+
614
+ # Step 4: TF-IDF correlation
615
+ corr = tfidf_attention_correlation(text, step_weights)
616
+
617
+ return {
618
+ "trajectory": trajectory,
619
+ "token_stability": stability,
620
+ "tfidf_corr": corr
621
+ }
analysis/concept_vectors.py ADDED
@@ -0,0 +1,637 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # analysis/concept_vectors.py
3
+ # ============================
4
+ # Task 3: Concept Vector Extraction + Controlled Paraphrase Diversity
5
+ #
6
+ # No retraining required. Uses decoder hidden states already computed
7
+ # during generate_cached() — stored in model.model._last_hidden after
8
+ # each forward_cached() call.
9
+ #
10
+ # Steps:
11
+ # 1. Collect hidden states from N examples at a fixed diffusion step
12
+ # 2. Pool sequence dimension → [N, d_model] representation per example
13
+ # 3. PCA → find principal directions in concept space
14
+ # 4. Identify "diversity direction" (PC that best separates short/long outputs)
15
+ # 5. Steer: at inference, shift hidden states along diversity direction
16
+ # before the output head projection
17
+ # 6. Generate at 5 points along the direction, measure output diversity
18
+ #
19
+ # Key insight: the diversity direction is found purely from model outputs
20
+ # (no human annotation needed). We use output length as a proxy:
21
+ # short output → low diversity (model collapsed to simple token)
22
+ # long output → high diversity (model exploring more of the space)
23
+ # """
24
+ #
25
+ # import torch
26
+ # import torch.nn as nn
27
+ # import torch.nn.functional as F
28
+ # import numpy as np
29
+ # from typing import List, Dict, Optional, Tuple
30
+ #
31
+ #
32
+ # # ── Hidden state collection ───────────────────────────────────────────
33
+ #
34
+ # @torch.no_grad()
35
+ # def collect_hidden_states(
36
+ # model,
37
+ # src_list: List[torch.Tensor],
38
+ # t_capture: int = 0,
39
+ # temperature: float = 0.8,
40
+ # top_k: int = 40,
41
+ # max_samples: int = 1000,
42
+ # ) -> Tuple[np.ndarray, List[str]]:
43
+ # """
44
+ # Run generate_cached() on a list of source tensors, collecting the
45
+ # decoder hidden state at timestep t_capture for each sample.
46
+ #
47
+ # Args:
48
+ # model : SanskritModel (D3PMCrossAttention)
49
+ # src_list : list of [1, src_len] tensors, one per sample
50
+ # t_capture : which diffusion step to capture hidden states at
51
+ # 0 = final (clean), T-1 = noisy start
52
+ # temperature: sampling temperature
53
+ # top_k : top-k filter
54
+ # max_samples: cap at this many samples
55
+ #
56
+ # Returns:
57
+ # hidden_matrix : np.ndarray [N, d_model] — pooled hidden states
58
+ # output_texts : list of N decoded output strings (for diversity analysis)
59
+ # """
60
+ # inner = model.model
61
+ # T = inner.scheduler.num_timesteps
62
+ # device = next(inner.parameters()).device
63
+ #
64
+ # hidden_list = []
65
+ # output_list = []
66
+ #
67
+ # n = min(len(src_list), max_samples)
68
+ # print(f"Collecting hidden states from {n} examples at t={t_capture}...")
69
+ #
70
+ # for i, src in enumerate(src_list[:n]):
71
+ # if i % 100 == 0:
72
+ # print(f" {i}/{n}")
73
+ #
74
+ # if src.dim() == 1:
75
+ # src = src.unsqueeze(0)
76
+ # src = src.to(device)
77
+ #
78
+ # B = src.shape[0]
79
+ # tgt_len = inner.max_seq_len
80
+ # mask_id = inner.mask_token_id
81
+ #
82
+ # # KV cache
83
+ # memory, src_pad_mask = inner.encode_source(src)
84
+ #
85
+ # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
86
+ # hint = None
87
+ # captured_hidden = None
88
+ #
89
+ # for t_val in range(T - 1, -1, -1):
90
+ # t = torch.full((B,), t_val, dtype=torch.long, device=device)
91
+ # is_last = (t_val == 0)
92
+ #
93
+ # logits, _ = inner.forward_cached(
94
+ # memory, src_pad_mask, x0_est, t,
95
+ # x0_hint=hint, inference_mode=True,
96
+ # )
97
+ #
98
+ # # Capture hidden state at target step
99
+ # if t_val == t_capture and hasattr(inner, '_last_hidden'):
100
+ # captured_hidden = inner._last_hidden.detach().cpu()
101
+ #
102
+ # logits = logits / max(temperature, 1e-8)
103
+ # if top_k > 0:
104
+ # V = logits.shape[-1]
105
+ # if top_k < V:
106
+ # vals, _ = torch.topk(logits, top_k, dim=-1)
107
+ # logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
108
+ #
109
+ # probs = F.softmax(logits, dim=-1)
110
+ # x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
111
+ # hint = x0_est
112
+ #
113
+ # # Pool hidden state over non-PAD positions → [d_model]
114
+ # if captured_hidden is not None:
115
+ # non_pad = (x0_est[0] > 1).cpu() # [tgt_len] bool
116
+ # if non_pad.sum() > 0:
117
+ # h = captured_hidden[0][non_pad].mean(dim=0) # [d_model]
118
+ # else:
119
+ # h = captured_hidden[0].mean(dim=0)
120
+ # hidden_list.append(h.numpy())
121
+ #
122
+ # # Decode output
123
+ # ids = [x for x in x0_est[0].tolist() if x > 4]
124
+ #
125
+ # print(f"Collected {len(hidden_list)} hidden states.")
126
+ # return np.stack(hidden_list), output_list
127
+ #
128
+ #
129
+ # # ── PCA on hidden states ─────────────────────────────���────────────────
130
+ #
131
+ # def fit_pca(
132
+ # hidden_matrix: np.ndarray,
133
+ # n_components: int = 50,
134
+ # ) -> object:
135
+ # """
136
+ # Fit PCA on hidden state matrix.
137
+ #
138
+ # Args:
139
+ # hidden_matrix : [N, d_model]
140
+ # n_components : number of PCA components to retain
141
+ #
142
+ # Returns:
143
+ # fitted sklearn PCA object
144
+ # """
145
+ # from sklearn.decomposition import PCA
146
+ # n_comp = min(n_components, hidden_matrix.shape[0] - 1, hidden_matrix.shape[1])
147
+ # pca = PCA(n_components=n_comp)
148
+ # pca.fit(hidden_matrix)
149
+ # print(f"PCA fit: {n_comp} components explain "
150
+ # f"{pca.explained_variance_ratio_.sum()*100:.1f}% of variance.")
151
+ # return pca
152
+ #
153
+ #
154
+ # def find_diversity_direction(
155
+ # hidden_matrix: np.ndarray,
156
+ # output_lengths: List[int],
157
+ # pca: object,
158
+ # ) -> np.ndarray:
159
+ # """
160
+ # Find the PCA direction that best correlates with output diversity
161
+ # (measured by output length as proxy).
162
+ #
163
+ # Projects hidden states into PCA space, then finds the PC whose
164
+ # scores have highest Spearman correlation with output lengths.
165
+ #
166
+ # Returns:
167
+ # direction : np.ndarray [d_model] — diversity direction in original space
168
+ # """
169
+ # from scipy.stats import spearmanr
170
+ #
171
+ # projected = pca.transform(hidden_matrix) # [N, n_components]
172
+ # lengths = np.array(output_lengths)
173
+ #
174
+ # correlations = []
175
+ # for pc_idx in range(projected.shape[1]):
176
+ # r, _ = spearmanr(projected[:, pc_idx], lengths)
177
+ # correlations.append(abs(r))
178
+ #
179
+ # best_pc = int(np.argmax(correlations))
180
+ # print(f"Diversity direction: PC {best_pc} "
181
+ # f"(|r|={correlations[best_pc]:.3f} with output length)")
182
+ #
183
+ # # Map back to original d_model space
184
+ # direction = pca.components_[best_pc] # [d_model]
185
+ # direction = direction / (np.linalg.norm(direction) + 1e-8)
186
+ # return direction, best_pc, correlations[best_pc]
187
+ #
188
+ #
189
+ # # ── Steered generation ────────────────────────────────────────────────
190
+ #
191
+ # @torch.no_grad()
192
+ # def generate_steered(
193
+ # model,
194
+ # src: torch.Tensor,
195
+ # direction: np.ndarray,
196
+ # alpha: float = 0.0,
197
+ # temperature: float = 0.8,
198
+ # top_k: int = 40,
199
+ # ) -> torch.Tensor:
200
+ # """
201
+ # Generate output while steering hidden states along diversity direction.
202
+ #
203
+ # At each diffusion step, after the decoder runs, we shift the hidden state
204
+ # by alpha * direction before projecting to logits.
205
+ #
206
+ # alpha > 0 → push toward high-diversity output
207
+ # alpha < 0 → push toward low-diversity output
208
+ # alpha = 0 → standard generation (no steering)
209
+ #
210
+ # Args:
211
+ # model : SanskritModel (D3PMCrossAttention)
212
+ # src : [1, src_len] IAST token ids
213
+ # direction : [d_model] diversity direction from find_diversity_direction()
214
+ # alpha : steering strength
215
+ # temperature / top_k: sampling params
216
+ #
217
+ # Returns:
218
+ # x0_est : [1, tgt_len] generated token ids
219
+ # """
220
+ # inner = model.model
221
+ # T = inner.scheduler.num_timesteps
222
+ # device = next(inner.parameters()).device
223
+ #
224
+ # if src.dim() == 1:
225
+ # src = src.unsqueeze(0)
226
+ # src = src.to(device)
227
+ #
228
+ # B = src.shape[0]
229
+ # tgt_len = inner.max_seq_len
230
+ # mask_id = inner.mask_token_id
231
+ #
232
+ # dir_tensor = torch.tensor(direction, dtype=torch.float32, device=device)
233
+ #
234
+ # memory, src_pad_mask = inner.encode_source(src)
235
+ # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
236
+ # hint = None
237
+ #
238
+ # inner.eval()
239
+ # for t_val in range(T - 1, -1, -1):
240
+ # t = torch.full((B,), t_val, dtype=torch.long, device=device)
241
+ # is_last = (t_val == 0)
242
+ #
243
+ # # Standard forward_cached but we intercept hidden states
244
+ # PAD = 1
245
+ # tgt_pad_mask = None # inference_mode
246
+ #
247
+ # _, x_t_ids = inner.forward_process.q_sample(x0_est, t) if t_val > 0 else \
248
+ # (None, x0_est)
249
+ # x = inner.tgt_embed(x_t_ids)
250
+ # t_norm = t.float() / inner.scheduler.num_timesteps
251
+ # t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
252
+ # x = x + t_emb.unsqueeze(1)
253
+ #
254
+ # if hint is not None:
255
+ # hint_emb = inner.tgt_embed(hint)
256
+ # gate = inner.hint_gate(x)
257
+ # x = x + gate * hint_emb
258
+ #
259
+ # for block in inner.decoder_blocks:
260
+ # x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)
261
+ #
262
+ # # ── STEERING: shift hidden states along diversity direction ───
263
+ # if alpha != 0.0:
264
+ # x = x + alpha * dir_tensor.unsqueeze(0).unsqueeze(0)
265
+ #
266
+ # # Project to logits using the head
267
+ # logits = inner.head(x)
268
+ #
269
+ # logits = logits / max(temperature, 1e-8)
270
+ # if top_k > 0:
271
+ # V = logits.shape[-1]
272
+ # if top_k < V:
273
+ # vals, _ = torch.topk(logits, top_k, dim=-1)
274
+ # logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
275
+ #
276
+ # probs = F.softmax(logits, dim=-1)
277
+ # x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
278
+ # hint = x0_est
279
+ #
280
+ # return x0_est
281
+ #
282
+ #
283
+ # def generate_diversity_spectrum(
284
+ # model,
285
+ # src: torch.Tensor,
286
+ # direction: np.ndarray,
287
+ # tgt_tokenizer,
288
+ # alphas: List[float] = [-2.0, -1.0, 0.0, 1.0, 2.0],
289
+ # temperature: float = 0.8,
290
+ # top_k: int = 40,
291
+ # ) -> Dict[float, str]:
292
+ # """
293
+ # Generate outputs at 5 points along the diversity direction.
294
+ #
295
+ # Args:
296
+ # alphas : steering strengths (negative = low diversity, positive = high)
297
+ #
298
+ # Returns:
299
+ # dict mapping alpha → decoded Devanagari string
300
+ # """
301
+ # results = {}
302
+ # for alpha in alphas:
303
+ # out_ids = generate_steered(model, src, direction, alpha, temperature, top_k)
304
+ # ids = [x for x in out_ids[0].tolist() if x > 4]
305
+ # text = tgt_tokenizer.decode(ids).strip()
306
+ # results[alpha] = text
307
+ # print(f" alpha={alpha:+.1f} → {text}")
308
+ # return results
309
+ #
310
+ #
311
+ # def plot_pca_space(
312
+ # hidden_matrix: np.ndarray,
313
+ # output_lengths: List[int],
314
+ # pca: object,
315
+ # diversity_pc: int,
316
+ # save_path: Optional[str] = None,
317
+ # ):
318
+ # """
319
+ # Scatter plot of examples in PC1 vs PC2 space, coloured by output length.
320
+ # Highlights the diversity direction.
321
+ # """
322
+ # try:
323
+ # import matplotlib.pyplot as plt
324
+ # except ImportError:
325
+ # print("pip install matplotlib.")
326
+ # return
327
+ #
328
+ # projected = pca.transform(hidden_matrix) # [N, n_pc]
329
+ # lengths = np.array(output_lengths)
330
+ #
331
+ # fig, axes = plt.subplots(1, 2, figsize=(14, 5))
332
+ #
333
+ # # Left: PC0 vs PC1 coloured by length
334
+ # ax = axes[0]
335
+ # sc = ax.scatter(projected[:, 0], projected[:, 1],
336
+ # c=lengths, cmap='viridis', alpha=0.6, s=15)
337
+ # plt.colorbar(sc, ax=ax, label="Output length (chars)")
338
+ # ax.set_xlabel(f"PC0 ({pca.explained_variance_ratio_[0]*100:.1f}%)", fontsize=10)
339
+ # ax.set_ylabel(f"PC1 ({pca.explained_variance_ratio_[1]*100:.1f}%)", fontsize=10)
340
+ # ax.set_title("Concept space (PC0 vs PC1)", fontsize=11)
341
+ #
342
+ # # Right: explained variance
343
+ # ax2 = axes[1]
344
+ # cumvar = np.cumsum(pca.explained_variance_ratio_) * 100
345
+ # ax2.plot(range(1, len(cumvar)+1), cumvar, linewidth=1.5, color='steelblue')
346
+ # ax2.axvline(diversity_pc, color='coral', linestyle='--', label=f"Diversity PC={diversity_pc}")
347
+ # ax2.set_xlabel("Number of PCs", fontsize=10)
348
+ # ax2.set_ylabel("Cumulative variance (%)", fontsize=10)
349
+ # ax2.set_title("PCA explained variance", fontsize=11)
350
+ # ax2.legend()
351
+ # ax2.set_ylim(0, 102)
352
+ #
353
+ # plt.tight_layout()
354
+ # if save_path:
355
+ # import os
356
+ # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
357
+ # plt.savefig(save_path, dpi=150, bbox_inches='tight')
358
+ # print(f"Saved: {save_path}")
359
+ # else:
360
+ # plt.show()
361
+ # plt.close()
362
+ #
363
+ #
364
+ # def _sample(probs):
365
+ # B, L, V = probs.shape
366
+ # flat = probs.view(B * L, V).clamp(min=1e-9)
367
+ # flat = flat / flat.sum(dim=-1, keepdim=True)
368
+ # return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
369
+ """
370
+ Task 3: Concept Vector Extraction + Controlled Paraphrase Diversity
371
+ Fully corrected & production-ready version
372
+ """
373
+
374
+ import torch
375
+ import torch.nn.functional as F
376
+ import numpy as np
377
+ from typing import List, Tuple, Dict, Optional
378
+
379
+
380
+ # ─────────────────────────────────────────────────────────────
381
+ # Utility
382
+ # ─────────────────────────────────────────────────────────────
383
+
384
+ def _sample(probs: torch.Tensor) -> torch.Tensor:
385
+ B, L, V = probs.shape
386
+ flat = probs.view(B * L, V).clamp(min=1e-9)
387
+ flat = flat / flat.sum(dim=-1, keepdim=True)
388
+ return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
389
+
390
+
391
+ # ─────────────────────────────────────────────────────────────
392
+ # 1. Collect Hidden States
393
+ # ─────────────────────────────────────────────────────────────
394
+
395
+ @torch.no_grad()
396
+ def collect_hidden_states(
397
+ model,
398
+ src_list: List[torch.Tensor],
399
+ tgt_tokenizer,
400
+ t_capture: int = 0,
401
+ temperature: float = 0.8,
402
+ top_k: int = 40,
403
+ max_samples: int = 1000,
404
+ ) -> Tuple[np.ndarray, List[str], List[int]]:
405
+ """
406
+ Collect pooled hidden representations + outputs
407
+ """
408
+
409
+ inner = model.model
410
+ device = next(inner.parameters()).device
411
+ T = inner.scheduler.num_timesteps
412
+
413
+ hidden_list = []
414
+ texts = []
415
+ lengths = []
416
+
417
+ print(f"Collecting {min(len(src_list), max_samples)} samples...")
418
+
419
+ for i, src in enumerate(src_list[:max_samples]):
420
+
421
+ if src.dim() == 1:
422
+ src = src.unsqueeze(0)
423
+ src = src.to(device)
424
+
425
+ B = src.shape[0]
426
+ tgt_len = inner.max_seq_len
427
+ mask_id = inner.mask_token_id
428
+
429
+ # KV Cache (IMPORTANT)
430
+ memory, src_pad_mask = inner.encode_source(src)
431
+
432
+ x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
433
+ hint = None
434
+ captured_hidden = None
435
+
436
+ for t_val in range(T - 1, -1, -1):
437
+
438
+ t = torch.full((B,), t_val, dtype=torch.long, device=device)
439
+ is_last = (t_val == 0)
440
+
441
+ logits, _ = inner.forward_cached(
442
+ memory,
443
+ src_pad_mask,
444
+ x0_est,
445
+ t,
446
+ x0_hint=hint,
447
+ inference_mode=True,
448
+ )
449
+
450
+ # Capture hidden state
451
+ if t_val == t_capture:
452
+ if hasattr(inner, "_last_hidden"):
453
+ captured_hidden = inner._last_hidden.detach().cpu()
454
+
455
+ # Sampling
456
+ logits = logits / max(temperature, 1e-8)
457
+
458
+ if top_k > 0:
459
+ vals, _ = torch.topk(logits, top_k, dim=-1)
460
+ logits = logits.masked_fill(logits < vals[..., -1:], float("-inf"))
461
+
462
+ probs = F.softmax(logits, dim=-1)
463
+ x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
464
+ hint = x0_est
465
+
466
+ # Pool hidden
467
+ if captured_hidden is not None:
468
+ h = captured_hidden[0].mean(dim=0) # [d_model]
469
+ hidden_list.append(h.numpy())
470
+
471
+ # Decode
472
+ ids = [x for x in x0_est[0].tolist() if x > 4]
473
+ text = tgt_tokenizer.decode(ids).strip()
474
+
475
+ texts.append(text)
476
+ lengths.append(len(text))
477
+
478
+ if i % 100 == 0:
479
+ print(f"{i} done")
480
+
481
+ hidden_matrix = np.stack(hidden_list)
482
+
483
+ print("Collected hidden states:", hidden_matrix.shape)
484
+ return hidden_matrix, texts, lengths
485
+
486
+
487
+ # ─────────────────────────────────────────────────────────────
488
+ # 2. PCA
489
+ # ─────────────────────────────────────────────────────────────
490
+
491
+ def fit_pca(hidden_matrix: np.ndarray, n_components: int = 50):
492
+ from sklearn.decomposition import PCA
493
+
494
+ n_comp = min(n_components, hidden_matrix.shape[0] - 1, hidden_matrix.shape[1])
495
+ pca = PCA(n_components=n_comp)
496
+ pca.fit(hidden_matrix)
497
+
498
+ print("Explained variance:", pca.explained_variance_ratio_.sum())
499
+ return pca
500
+
501
+
502
+ # ─────────────────────────────────────────────────────────────
503
+ # 3. Find Diversity Direction
504
+ # ─────────────────────────────────────────────────────────────
505
+
506
+ def find_diversity_direction(hidden_matrix, lengths, pca):
507
+ from scipy.stats import spearmanr
508
+
509
+ projected = pca.transform(hidden_matrix)
510
+ lengths = np.array(lengths)
511
+
512
+ scores = []
513
+
514
+ for i in range(projected.shape[1]):
515
+ r, _ = spearmanr(projected[:, i], lengths)
516
+ scores.append(abs(r))
517
+
518
+ best_pc = int(np.argmax(scores))
519
+
520
+ print(f"Best PC: {best_pc} | corr={scores[best_pc]:.3f}")
521
+
522
+ direction = pca.components_[best_pc]
523
+ direction = direction / (np.linalg.norm(direction) + 1e-8)
524
+
525
+ return direction
526
+
527
+
528
+ # ─────────────────────────────────────────────────────────────
529
+ # 4. Steered Generation
530
+ # ─────────────────────────────────────────────────────────────
531
+
532
+ @torch.no_grad()
533
+ def generate_steered(
534
+ model,
535
+ src,
536
+ direction,
537
+ alpha=0.0,
538
+ temperature=0.8,
539
+ top_k=40,
540
+ ):
541
+ inner = model.model
542
+ device = next(inner.parameters()).device
543
+ T = inner.scheduler.num_timesteps
544
+
545
+ if src.dim() == 1:
546
+ src = src.unsqueeze(0)
547
+ src = src.to(device)
548
+
549
+ B = src.shape[0]
550
+ tgt_len = inner.max_seq_len
551
+ mask_id = inner.mask_token_id
552
+
553
+ direction = torch.tensor(direction, dtype=torch.float32, device=device)
554
+ direction = direction / (torch.norm(direction) + 1e-6)
555
+
556
+ memory, src_pad_mask = inner.encode_source(src)
557
+
558
+ x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
559
+ hint = None
560
+
561
+ for t_val in range(T - 1, -1, -1):
562
+
563
+ t = torch.full((B,), t_val, dtype=torch.long, device=device)
564
+ is_last = (t_val == 0)
565
+
566
+ logits, _ = inner.forward_cached(
567
+ memory,
568
+ src_pad_mask,
569
+ x0_est,
570
+ t,
571
+ x0_hint=hint,
572
+ inference_mode=True,
573
+ )
574
+
575
+ # Inject diversity
576
+ if hasattr(inner, "_last_hidden") and alpha != 0.0:
577
+ h = inner._last_hidden
578
+ h = h + alpha * direction.unsqueeze(0).unsqueeze(0)
579
+ logits = inner.head(h)
580
+
581
+ # Sampling
582
+ logits = logits / max(temperature, 1e-8)
583
+
584
+ if top_k > 0:
585
+ vals, _ = torch.topk(logits, top_k, dim=-1)
586
+ logits = logits.masked_fill(logits < vals[..., -1:], float("-inf"))
587
+
588
+ probs = F.softmax(logits, dim=-1)
589
+ x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
590
+ hint = x0_est
591
+
592
+ return x0_est
593
+
594
+
595
+ # ─────────────────────────────────────────────────────────────
596
+ # 5. Diversity Spectrum
597
+ # ─────────────────────────────────────────────────────────────
598
+
599
+ def generate_diversity_spectrum(
600
+ model,
601
+ src,
602
+ direction,
603
+ tgt_tokenizer,
604
+ alphas=[-2, -1, 0, 1, 2],
605
+ ):
606
+ results = {}
607
+
608
+ print("\nDiversity Spectrum:\n")
609
+
610
+ for alpha in alphas:
611
+ out_ids = generate_steered(model, src, direction, alpha)
612
+
613
+ ids = [x for x in out_ids[0].tolist() if x > 4]
614
+ text = tgt_tokenizer.decode(ids).strip()
615
+
616
+ print(f"{alpha:+} → {text}")
617
+ results[alpha] = text
618
+
619
+ return results
620
+
621
+
622
+ # ─────────────────────────────────────────────────────────────
623
+ # 6. Visualization
624
+ # ─────────────────────────────────────────────────────────────
625
+
626
+ def plot_pca_space(hidden_matrix, lengths, pca):
627
+ import matplotlib.pyplot as plt
628
+
629
+ proj = pca.transform(hidden_matrix)
630
+
631
+ plt.figure(figsize=(8, 6))
632
+ sc = plt.scatter(proj[:, 0], proj[:, 1], c=lengths)
633
+ plt.colorbar(sc)
634
+ plt.title("Concept Space")
635
+ plt.xlabel("PC1")
636
+ plt.ylabel("PC2")
637
+ plt.show()
analysis/kv_cache_benchmark.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # analysis/kv_cache_benchmark.py
3
+ # ================================
4
+ # Task 1: Benchmark KV cache vs standard generate().
5
+ #
6
+ # Measures:
7
+ # - Wall-clock time for generate() vs generate_cached()
8
+ # - Encoder time as % of total generation time (before/after)
9
+ # - Speedup ratio at src_len = 16, 32, 64 tokens
10
+ #
11
+ # How it works:
12
+ # Standard generate():
13
+ # For each of T=128 steps:
14
+ # src → encoder → memory → decoder → logits (encoder runs 128 times)
15
+ #
16
+ # generate_cached():
17
+ # src → encoder → memory (once)
18
+ # For each of T=128 steps:
19
+ # cached_memory → decoder → logits (encoder runs 1 time)
20
+ #
21
+ # Expected speedup:
22
+ # If encoder = 30% of per-step time:
23
+ # Saved = 127/128 * 30% ≈ 29.7% of total time
24
+ # If encoder = 50% of per-step time:
25
+ # Saved ≈ 49.6% of total time
26
+ #
27
+ # Usage:
28
+ # python -m analysis.kv_cache_benchmark
29
+ # or:
30
+ # from analysis.kv_cache_benchmark import run_benchmark
31
+ # results = run_benchmark(model, src_tokenizer, device)
32
+ # """
33
+ #
34
+ # import torch
35
+ # import time
36
+ # import numpy as np
37
+ # from typing import Dict, List
38
+ #
39
+ #
40
+ # def _make_src(src_len: int, src_vocab: int, device: torch.device, batch_size: int = 1):
41
+ # """Create a random source tensor of given length."""
42
+ # # Random real tokens (ids 5..src_vocab-1), padded to src_len
43
+ # ids = torch.randint(5, src_vocab, (batch_size, src_len), device=device)
44
+ # return ids
45
+ #
46
+ #
47
+ # def _time_fn(fn, n_warmup: int = 2, n_runs: int = 5) -> float:
48
+ # """
49
+ # Time a zero-argument callable.
50
+ # Returns mean wall-clock seconds over n_runs after n_warmup warmup calls.
51
+ # """
52
+ # # Warmup
53
+ # for _ in range(n_warmup):
54
+ # fn()
55
+ # if torch.cuda.is_available():
56
+ # torch.cuda.synchronize()
57
+ # elif torch.backends.mps.is_available():
58
+ # torch.mps.synchronize()
59
+ #
60
+ # times = []
61
+ # for _ in range(n_runs):
62
+ # start = time.perf_counter()
63
+ # fn()
64
+ # if torch.cuda.is_available():
65
+ # torch.cuda.synchronize()
66
+ # elif torch.backends.mps.is_available():
67
+ # torch.mps.synchronize()
68
+ # times.append(time.perf_counter() - start)
69
+ #
70
+ # return float(np.mean(times))
71
+ #
72
+ #
73
+ # def benchmark_encoder_cost(
74
+ # model,
75
+ # src: torch.Tensor,
76
+ # ) -> Dict[str, float]:
77
+ # """
78
+ # Measure encoder time as a fraction of one full forward pass.
79
+ #
80
+ # Returns:
81
+ # encoder_s : seconds for one encoder call
82
+ # full_step_s : seconds for one full forward_cached decoder step
83
+ # encoder_pct : encoder_s / (encoder_s + full_step_s) * 100
84
+ # """
85
+ # inner = model.model
86
+ # if not hasattr(inner, 'encode_source'):
87
+ # raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")
88
+ #
89
+ # device = src.device
90
+ # B = src.shape[0]
91
+ # T = inner.scheduler.num_timesteps
92
+ # tgt_len = inner.max_seq_len
93
+ # mask_id = inner.mask_token_id
94
+ #
95
+ # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
96
+ # t = torch.zeros(B, dtype=torch.long, device=device)
97
+ #
98
+ # # Time encoder alone
99
+ # encoder_s = _time_fn(lambda: inner.encode_source(src))
100
+ #
101
+ # # Pre-compute memory for decoder timing
102
+ # memory, src_pad_mask = inner.encode_source(src)
103
+ #
104
+ # # Time one decoder step (cached)
105
+ # decoder_s = _time_fn(
106
+ # lambda: inner.forward_cached(memory, src_pad_mask, x0_est, t,
107
+ # inference_mode=True)
108
+ # )
109
+ #
110
+ # # Time one full step (non-cached = encoder + decoder)
111
+ # full_s = _time_fn(
112
+ # lambda: inner.forward(src, x0_est, t, inference_mode=True)
113
+ # )
114
+ #
115
+ # encoder_pct = 100.0 * encoder_s / max(full_s, 1e-9)
116
+ #
117
+ # return {
118
+ # "encoder_s": encoder_s,
119
+ # "decoder_s": decoder_s,
120
+ # "full_step_s": full_s,
121
+ # "encoder_pct": encoder_pct,
122
+ # }
123
+ #
124
+ #
125
+ # def run_benchmark(
126
+ # model,
127
+ # src_tokenizer,
128
+ # device: torch.device,
129
+ # src_lens: List[int] = [16, 32, 64],
130
+ # n_runs: int = 5,
131
+ # ) -> Dict:
132
+ # """
133
+ # Full benchmark: compare generate() vs generate_cached() at multiple src lengths.
134
+ #
135
+ # Args:
136
+ # model : SanskritModel (D3PMCrossAttention)
137
+ # src_tokenizer : SanskritSourceTokenizer
138
+ # device : torch.device
139
+ # src_lens : list of source lengths to benchmark
140
+ # n_runs : number of timing runs per condition
141
+ #
142
+ # Returns:
143
+ # results dict with timing and speedup for each src_len
144
+ # """
145
+ # inner = model.model
146
+ # if not hasattr(inner, 'generate_cached'):
147
+ # raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")
148
+ #
149
+ # src_vocab = inner.src_embed.token_emb.weight.shape[0]
150
+ # results = {}
151
+ #
152
+ # print("\n" + "=" * 65)
153
+ # print(" KV CACHE BENCHMARK")
154
+ # print("=" * 65)
155
+ # print(f" {'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
156
+ # f"{'speedup':>8} {'encoder%':>9}")
157
+ # print("-" * 65)
158
+ #
159
+ # for src_len in src_lens:
160
+ # src = _make_src(src_len, src_vocab, device)
161
+ #
162
+ # # Encoder cost breakdown
163
+ # enc_cost = benchmark_encoder_cost(model, src)
164
+ #
165
+ # # Time standard generate() — encoder runs T times
166
+ # def run_standard():
167
+ # return inner.generate(src, temperature=0.8, top_k=40)
168
+ #
169
+ # # Time generate_cached() — encoder runs once
170
+ # def run_cached():
171
+ # return inner.generate_cached(src, temperature=0.8, top_k=40)
172
+ #
173
+ # t_standard = _time_fn(run_standard, n_warmup=1, n_runs=n_runs)
174
+ # t_cached = _time_fn(run_cached, n_warmup=1, n_runs=n_runs)
175
+ # speedup = t_standard / max(t_cached, 1e-9)
176
+ #
177
+ # results[src_len] = {
178
+ # "standard_s": t_standard,
179
+ # "cached_s": t_cached,
180
+ # "speedup": speedup,
181
+ # "encoder_pct": enc_cost["encoder_pct"],
182
+ # }
183
+ #
184
+ # print(f" {src_len:>8} {t_standard:>12.3f} {t_cached:>10.3f} "
185
+ # f"{speedup:>7.2f}x {enc_cost['encoder_pct']:>8.1f}%")
186
+ #
187
+ # print("=" * 65)
188
+ # print(f"\n Encoder cost = % of one full forward pass")
189
+ # print(f" Speedup = standard_time / cached_time")
190
+ # print(f" Expected: speedup ≈ 1 / (1 - encoder_pct/100 * (T-1)/T)")
191
+ #
192
+ # return results
193
+ #
194
+ #
195
+ # def print_summary(results: Dict):
196
+ # """Print a human-readable summary of benchmark results."""
197
+ # print("\n SUMMARY")
198
+ # print(" -------")
199
+ # for src_len, r in results.items():
200
+ # saved_pct = (1.0 - 1.0 / r["speedup"]) * 100
201
+ # print(f" src_len={src_len}: {r['speedup']:.2f}x speedup "
202
+ # f"({saved_pct:.1f}% time saved, "
203
+ # f"encoder was {r['encoder_pct']:.1f}% of total)")
204
+ #
205
+ #
206
+ # if __name__ == "__main__":
207
+ # import sys, os
208
+ # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
209
+ # from config import CONFIG
210
+ # from inference import load_model
211
+ # from models.tokenizer import SanskritSourceTokenizer
212
+ #
213
+ # cfg = CONFIG
214
+ # device = torch.device(cfg['training']['device'])
215
+ #
216
+ # model_name = cfg['model_type']
217
+ # has_neg = cfg['data']['include_negative_examples']
218
+ # ckpt = f"results7/{model_name}_neg_{has_neg}/best_model.pt"
219
+ #
220
+ # if not os.path.exists(ckpt):
221
+ # print(f"No checkpoint at {ckpt}. Train first.")
222
+ # sys.exit(1)
223
+ #
224
+ # model, cfg = load_model(ckpt, cfg, device)
225
+ # model.eval()
226
+ #
227
+ # src_tokenizer = SanskritSourceTokenizer(
228
+ # vocab_size = cfg['model'].get('src_vocab_size', 500),
229
+ # max_len = cfg['model']['max_seq_len'],
230
+ # )
231
+ #
232
+ # results = run_benchmark(model, src_tokenizer, device)
233
+ # print_summary(results)
234
+ # ============================================================
235
+ # FULL TASK 1: KV CACHE + PROJECTION + BENCHMARK + GRAPHS
236
+ # ============================================================
237
+
238
+ import torch
239
+ import torch.nn as nn
240
+ import torch.nn.functional as F
241
+ import time
242
+ import numpy as np
243
+ import matplotlib.pyplot as plt
244
+
245
+ # ============================================================
246
+ # 🔧 MODEL (PATCHED WITH PROJECTION + KV CACHE)
247
+ # ============================================================
248
+
249
+ class D3PMCrossAttention(nn.Module):
250
+ def __init__(self, d_model=512, vocab_size=500, max_seq_len=64, T=128):
251
+ super().__init__()
252
+
253
+ self.d_model = d_model
254
+ self.max_seq_len = max_seq_len
255
+ self.mask_token_id = 0
256
+
257
+ # Dummy encoder/decoder (replace with yours)
258
+ self.encoder = nn.Embedding(vocab_size, d_model)
259
+ self.tgt_embed = nn.Embedding(vocab_size, d_model)
260
+ self.head = nn.Linear(d_model, vocab_size)
261
+
262
+ self.time_mlp = nn.Linear(1, d_model)
263
+ self.hint_gate = nn.Linear(d_model, d_model)
264
+
265
+ # Fake scheduler
266
+ class Scheduler:
267
+ def __init__(self, T):
268
+ self.num_timesteps = T
269
+ self.scheduler = Scheduler(T)
270
+
271
+ # 🔥 Projection layer (Task 1 requirement)
272
+ self.semantic_proj = nn.Linear(d_model, d_model // 2)
273
+ self.semantic_up = nn.Linear(d_model // 2, d_model)
274
+
275
+ # ========================================================
276
+ # ✅ ENCODER WITH PROJECTION
277
+ # ========================================================
278
+ def encode_source(self, src):
279
+ memory = self.encoder(src) # [B, L, d]
280
+
281
+ # 🔥 Compress → Expand
282
+ compressed = self.semantic_proj(memory)
283
+ memory = self.semantic_up(compressed)
284
+
285
+ src_pad_mask = None
286
+ return memory, src_pad_mask
287
+
288
+ # ========================================================
289
+ # ✅ STANDARD (NO CACHE)
290
+ # ========================================================
291
+ def forward(self, src, x, t):
292
+ memory, mask = self.encode_source(src)
293
+ return self.forward_cached(memory, mask, x, t)
294
+
295
+ # ========================================================
296
+ # ✅ CACHED FORWARD
297
+ # ========================================================
298
+ def forward_cached(self, memory, src_pad_mask, x, t, hint=None):
299
+ x = self.tgt_embed(x)
300
+
301
+ t_emb = self.time_mlp((t.float()/self.scheduler.num_timesteps).unsqueeze(-1))
302
+ x = x + t_emb.unsqueeze(1)
303
+
304
+ if hint is not None:
305
+ x = x + self.hint_gate(x) * self.tgt_embed(hint)
306
+
307
+ logits = self.head(x)
308
+
309
+ self._last_hidden = x
310
+ return logits, None
311
+
312
+ # ========================================================
313
+ # ❌ OLD GENERATE (SLOW)
314
+ # ========================================================
315
+ @torch.no_grad()
316
+ def generate(self, src):
317
+ B = src.shape[0]
318
+ device = src.device
319
+ T = self.scheduler.num_timesteps
320
+
321
+ x = torch.zeros((B, self.max_seq_len), dtype=torch.long, device=device)
322
+
323
+ for t_val in range(T - 1, -1, -1):
324
+ t = torch.full((B,), t_val, device=device)
325
+
326
+ logits, _ = self.forward(src, x, t)
327
+ probs = F.softmax(logits, dim=-1)
328
+
329
+ x = torch.argmax(probs, dim=-1)
330
+
331
+ return x
332
+
333
+ # ========================================================
334
+ # ✅ FAST GENERATE (KV CACHE)
335
+ # ========================================================
336
+ @torch.no_grad()
337
+ def generate_cached(self, src):
338
+ B = src.shape[0]
339
+ device = src.device
340
+ T = self.scheduler.num_timesteps
341
+
342
+ # 🔥 Encode once
343
+ memory, mask = self.encode_source(src)
344
+
345
+ x = torch.zeros((B, self.max_seq_len), dtype=torch.long, device=device)
346
+ hint = None
347
+
348
+ for t_val in range(T - 1, -1, -1):
349
+ t = torch.full((B,), t_val, device=device)
350
+
351
+ logits, _ = self.forward_cached(memory, mask, x, t, hint)
352
+ probs = F.softmax(logits, dim=-1)
353
+
354
+ x = torch.argmax(probs, dim=-1)
355
+ hint = x
356
+
357
+ return x
358
+
359
+
360
+ # ============================================================
361
+ # 📊 BENCHMARK + MEMORY + GRAPHS
362
+ # ============================================================
363
+
364
+ def benchmark(model, device):
365
+ model.to(device)
366
+ model.eval()
367
+
368
+ vocab = 500
369
+ src_lens = [16, 32, 64]
370
+
371
+ standard_times = []
372
+ cached_times = []
373
+ speedups = []
374
+ memory_savings = []
375
+
376
+ for src_len in src_lens:
377
+ print(f"\n🔹 src_len = {src_len}")
378
+
379
+ src = torch.randint(5, vocab, (1, src_len)).to(device)
380
+
381
+ # -------- STANDARD --------
382
+ torch.cuda.reset_peak_memory_stats()
383
+ start = time.time()
384
+ model.generate(src)
385
+ torch.cuda.synchronize()
386
+ t_std = time.time() - start
387
+ mem_std = torch.cuda.max_memory_allocated() / 1024**2
388
+
389
+ # -------- CACHED --------
390
+ torch.cuda.reset_peak_memory_stats()
391
+ start = time.time()
392
+ model.generate_cached(src)
393
+ torch.cuda.synchronize()
394
+ t_cache = time.time() - start
395
+ mem_cache = torch.cuda.max_memory_allocated() / 1024**2
396
+
397
+ speedup = t_std / t_cache
398
+ mem_red = 100 * (mem_std - mem_cache) / mem_std
399
+
400
+ print(f"Time: {t_std:.2f}s → {t_cache:.2f}s | {speedup:.2f}x")
401
+ print(f"Memory: {mem_std:.0f}MB → {mem_cache:.0f}MB | {mem_red:.1f}%")
402
+
403
+ standard_times.append(t_std)
404
+ cached_times.append(t_cache)
405
+ speedups.append(speedup)
406
+ memory_savings.append(mem_red)
407
+
408
+ # ==========================
409
+ # 📈 PLOT: TIME
410
+ # ==========================
411
+ plt.figure()
412
+ plt.plot(src_lens, standard_times, marker='o', label="Standard")
413
+ plt.plot(src_lens, cached_times, marker='o', label="Cached")
414
+ plt.xlabel("Source Length")
415
+ plt.ylabel("Time (s)")
416
+ plt.title("Generation Time")
417
+ plt.legend()
418
+ plt.grid()
419
+ plt.show()
420
+
421
+ # ==========================
422
+ # 📈 PLOT: SPEEDUP
423
+ # ==========================
424
+ plt.figure()
425
+ plt.plot(src_lens, speedups, marker='o')
426
+ plt.xlabel("Source Length")
427
+ plt.ylabel("Speedup (x)")
428
+ plt.title("KV Cache Speedup")
429
+ plt.grid()
430
+ plt.show()
431
+
432
+ # ==========================
433
+ # 📈 PLOT: MEMORY
434
+ # ==========================
435
+ plt.figure()
436
+ plt.plot(src_lens, memory_savings, marker='o')
437
+ plt.xlabel("Source Length")
438
+ plt.ylabel("Memory Reduction (%)")
439
+ plt.title("Memory Savings")
440
+ plt.grid()
441
+ plt.show()
442
+
443
+
444
+ # ============================================================
445
+ # 🚀 RUN
446
+ # ============================================================
447
+
448
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
449
+
450
+ model = D3PMCrossAttention()
451
+ benchmark(model, device)
analysis/outputs/task1_kv_cache.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 1 — KV CACHE BENCHMARK
2
+ ========================================
3
+
4
+ src_len standard(s) cached(s) speedup encoder% mem-save%
5
+ 16 3.431 3.512 0.98x 133.1% 50.0%
6
+ source-mem before=0.070MB after=0.035MB
7
+ 32 3.626 3.555 1.02x 36.8% 50.0%
8
+ source-mem before=0.141MB after=0.070MB
9
+ 64 3.585 3.701 0.97x 53.3% 50.0%
10
+ source-mem before=0.281MB after=0.141MB
11
+
12
+
13
+
14
+
15
+ Encoder cost = % of one full forward pass
16
+ Speedup = standard_time / cached_time
17
+ Expected: speedup ≈ 1 / (1 - encoder_pct/100 * (T-1)/T)
18
+
19
+ SUMMARY
20
+ -------
21
+ src_len=16: 0.98x speedup (-2.4% time saved, encoder was 133.1% of total, estimated memory change 50.0%)
22
+ src_len=32: 1.02x speedup (1.9% time saved, encoder was 36.8% of total, estimated memory change 50.0%)
23
+ src_len=64: 0.97x speedup (-3.2% time saved, encoder was 53.3% of total, estimated memory change 50.0%)
analysis/outputs/task2_all_layers_t0.png ADDED
analysis/outputs/task2_attn_evolution.png ADDED
analysis/outputs/task2_attn_t0.png ADDED
analysis/outputs/task2_attn_t127.png ADDED
analysis/outputs/task2_examples/example_1_attn_t0.png ADDED
analysis/outputs/task2_examples/example_2_attn_t0.png ADDED
analysis/outputs/task2_examples/example_3_attn_t0.png ADDED
analysis/outputs/task2_examples/example_4_attn_t0.png ADDED
analysis/outputs/task2_examples/example_5_attn_t0.png ADDED
analysis/outputs/task2_report.txt ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 2 — ATTENTION + DRIFT REPORT
2
+ ==================================================
3
+
4
+ Input : dharmo rakṣati rakṣitaḥ
5
+ Output : कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा ब्र कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा ध्या ध्या ध्या कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा
6
+
7
+ Lock-in t : 122
8
+ Mean pos lock-in : 118.7 ± 17.7
9
+
10
+ Source alignment metric : bertscore_f1
11
+ Best source-alignment step : t=127
12
+ Locked positions : 12
13
+ Flexible positions : 8
14
+ TF-IDF vs attention stability correlation : 0.0
15
+
16
+ Step → Output → CER-to-final
17
+ ------------------------------------------------------------
18
+ t= 127 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.2293
19
+ t= 122 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0769
20
+ t= 117 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0698
21
+ t= 112 | कुङ्कुमा लये कुङ्कुमा कुङ्कुमा कुङ्कुमा | 0.0541
22
+ t= 107 | कुङ्कुमा ध्या कुङ्कुमा कुङ्कुमा कुङ्कुमा | 0.0670
23
+ t= 102 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0442
24
+ t= 97 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0342
25
+ t= 92 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0456
26
+ t= 87 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0299
27
+ t= 82 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0214
28
+ t= 77 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0214
29
+ t= 72 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0214
30
+ t= 67 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0214
31
+ t= 62 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0128
32
+ t= 57 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0128
33
+ t= 52 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0128
34
+ t= 47 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0043
35
+ t= 42 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0043
36
+ t= 37 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
37
+ t= 32 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
38
+ t= 27 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
39
+ t= 22 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
40
+ t= 17 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
41
+ t= 12 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
42
+ t= 7 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
43
+ t= 2 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
44
+ t= 0 | कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ्कुमा कुङ् | 0.0000
45
+
46
+ Step → Source alignment
47
+ ------------------------------------------------------------
48
+ t= 127 | 0.4312
49
+ t= 122 | 0.3941
50
+ t= 117 | 0.3963
51
+ t= 112 | 0.3871
52
+ t= 107 | 0.3947
53
+ t= 102 | 0.3950
54
+ t= 97 | 0.3894
55
+ t= 92 | 0.3887
56
+ t= 87 | 0.3897
57
+ t= 82 | 0.3881
58
+ t= 77 | 0.3881
59
+ t= 72 | 0.3881
60
+ t= 67 | 0.3881
61
+ t= 62 | 0.3889
62
+ t= 57 | 0.3889
63
+ t= 52 | 0.3889
64
+ t= 47 | 0.3882
65
+ t= 42 | 0.3882
66
+ t= 37 | 0.3901
67
+ t= 32 | 0.3901
68
+ t= 27 | 0.3901
69
+ t= 22 | 0.3901
70
+ t= 17 | 0.3901
71
+ t= 12 | 0.3901
72
+ t= 7 | 0.3901
73
+ t= 2 | 0.3901
74
+ t= 0 | 0.3901
75
+
76
+ Locked target positions
77
+ ------------------------------------------------------------
78
+ tgt[0]=कुङ्कुमा → src[3]=taḥ stability=0.781
79
+ tgt[1]=शिरः → src[3]=taḥ stability=0.781
80
+ tgt[2]=कुङ्कुमा → src[3]=taḥ stability=0.780
81
+ tgt[3]=कुङ्कुमा → src[2]=rakṣi stability=0.780
82
+ tgt[4]=पुरतो → src[2]=rakṣi stability=0.781
83
+ tgt[5]=कुङ्कुमा → src[2]=rakṣi stability=0.781
84
+ tgt[8]=मु → src[3]=taḥ stability=0.782
85
+ tgt[9]=कुङ्कुमा → src[3]=taḥ stability=0.783
86
+ tgt[10]=कुङ्कुमा → src[3]=taḥ stability=0.783
87
+ tgt[11]=कुङ्कुमा → src[3]=taḥ stability=0.781
88
+ tgt[13]=कुङ्कुमा → src[2]=rakṣi stability=0.781
89
+ tgt[14]=कुङ्कुमा → src[2]=rakṣi stability=0.781
90
+
91
+ Flexible target positions
92
+ ------------------------------------------------------------
93
+ tgt[6]=कुङ्कुमा → src[2]=rakṣi stability=0.731
94
+ tgt[7]=कुङ्कुमा → src[2]=rakṣi stability=0.481
95
+ tgt[12]=कुङ्कुमा → src[2]=rakṣi stability=0.431
96
+ tgt[15]=कुङ्कुमा → src[2]=rakṣi stability=0.480
97
+ tgt[16]=कुङ्कुमा → src[2]=rakṣi stability=0.479
98
+ tgt[17]=कुङ्कुमा → src[2]=rakṣi stability=0.428
99
+ tgt[18]=कुङ्कुमा → src[3]=taḥ stability=0.727
100
+ tgt[19]=कुङ्कुमा → src[0]=dharmo stability=0.377
analysis/outputs/task2_semantic_drift.png ADDED
analysis/outputs/task2_source_alignment.png ADDED
analysis/outputs/task3_concept_space.png ADDED

Git LFS Details

  • SHA256: 22933b0a457dfd10d659987574594b5dd8e88c8b25b5bb3f9cd5f9517f9f4865
  • Pointer size: 131 Bytes
  • Size of remote file: 202 kB
analysis/outputs/task3_diversity_direction.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dff757876fd9352d5c1f86d2af244c9784d9ec66639a0f31ec5f6c9ec608d4b
3
+ size 1664
analysis/outputs/task3_report.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 3 — CONCEPT VECTORS + PCA STEERING
2
+ ==================================================
3
+
4
+ PCA: 50 components, 96.1% variance
5
+ Diversity PC: 1 (|r|=0.303 with output length)
6
+
7
+ Diversity spectrum:
8
+ alpha=-2.0 → विष द्धा समन्व ददर्श रे विष रे द्धा रे रे ष्व विष रे विष रे रे रे विष रे रे रे विष रे रे कार साग ददर्श वादि रे रे रे रे ददर्श रे रे रे विस्त रे रे समन्व सुर रे वस्तु रे रे रे रे रे रे रे सुर रे रे रे रे रे सुर रे ैक किंचि वस्तु विष रे कार रे विष कार गतिं रे कार शो कार कार कार साग समन्व रे कार कार कार
9
+ alpha=-1.0 → रे विष विष ष्व रे रे विष विष रे विष ददर्श रे ्य् रे रे रे विष रे रे शः रे भवि वस्तु रे विष ्य् विष रे रे वस्तु घा वादि रे रे ्य् रे रे ्य् रे रे रे ्य् पृत रे रे नृप रे द्धा रे रे रे रे ्य् रे रे त्तु रे ्य् रे विष रे सुर साग विष रे कार विष विष ्य् रे रे ्य् ्य् ्य् ्य् रे कार कार कार कार
10
+ alpha=+0.0 → विष ष्व भवि दित्य द्धा रे तौ वृ ्य् रे वादि ॠ रे विष रे ष्व रे का रे ्य् रे ्य् विष ्य् ष्व ्य् वृ जना रे भवि वस्तु त्रिषु विष घा भु की ्य् वृ रे भु यां वृ रे भु यां समु रे रे ्य् रे भु वृ ्य् क्ष ्य् ान्त ्य् ्य् ्य् व्रजेत् ्य् भु रे रे ्य् रे उक्त ्य् ्य् समन्व ्य् ्य् सु ल्प वीर ्य् ्य् ्य् विष ्य्
11
+ alpha=+1.0 → ॠ वृ वृ वृ वृ वृ ण् भवि ्त वृ वृ दश ्य् यां ॠ भु तं भु भु ान्त भवि भु भु रे यां वस्तु यां यां भु यां यां यां यां ्य् यां भु दृष्ट दृष्ट यां यां भु यां यां यां यां द्वि भु यां भु क्ष भु भु भु ष्ट रु ब्र भु न्तु ण्ड यां भु यां ्य् क्ष ्य् वृ ्य् , यां भु यां भु रोध भु ्य् यां ्य् ्य् यां यां
12
+ alpha=+2.0 → वृ वृ वृ ण् वृ वृ ब्र वृ ष्ट ष्ट ष्ट ्य् मा यां ष्ट यां ब्र यां तं तं भु भु वृ भु यां धनम् यां क्ष यां द्वि भु यां यां यां यां द्वि यां भु भु यां यां भु यां क्ष यां भु यां भु ्य् यां भु यां यां मा यां यां भु वृ यां धा भु यां यां मा भु हृ यां यां यां भु द्वि यां द्वि ब्र ण्ड मा द्वि यां यां भु
analysis/outputs/task5_quality_classifier.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0410b67872dbf030b2db5410ecca92f6357d90ae9f47f2c7cf1ad8202c274f61
3
+ size 233761
analysis/outputs/task5_quality_data.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dad6d37cae2b157877a4106d92528417981f75ae57cddfd46112441cd7e9a338
3
+ size 770512
analysis/outputs_multi/results__d3pm_cross_attention_neg_False/task1/task1_kv_cache.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 1 — KV CACHE BENCHMARK
2
+ ========================================
3
+
4
+ src_len standard(s) cached(s) speedup encoder% mem-save%
5
+ 16 3.309 3.624 0.91x 52.9% 50.0%
6
+ source-mem before=0.070MB after=0.035MB
7
+ 32 4.214 4.234 1.00x 40.0% 50.0%
8
+ source-mem before=0.141MB after=0.070MB
9
+ 64 6.929 8.372 0.83x 58.7% 50.0%
10
+ source-mem before=0.281MB after=0.141MB
analysis/outputs_multi/results__d3pm_cross_attention_neg_True/task1/task1_kv_cache.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 1 — KV CACHE BENCHMARK
2
+ ========================================
3
+
4
+ src_len standard(s) cached(s) speedup encoder% mem-save%
5
+ 16 2.548 2.464 1.03x 31.6% 50.0%
6
+ source-mem before=0.070MB after=0.035MB
7
+ 32 3.222 2.952 1.09x 37.8% 50.0%
8
+ source-mem before=0.141MB after=0.070MB
9
+ 64 4.121 4.335 0.95x 33.6% 50.0%
10
+ source-mem before=0.281MB after=0.141MB
analysis/quality_classifier.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # analysis/quality_classifier.py
3
+ # ================================
4
+ # Task 5: Classifier-Free Guidance for Paraphrase Quality Control
5
+ #
6
+ # Two steps — only Step 2 requires training a SMALL model (not the main D3PM):
7
+ #
8
+ # STEP 1 — Collect training data (no training):
9
+ # Run existing model on val set, record (hidden_state, CER) pairs.
10
+ # Hidden states come from model.model._last_hidden after forward_cached().
11
+ # CER score = quality label (lower CER = higher quality).
12
+ #
13
+ # STEP 2 — Train quality classifier:
14
+ # Small 2-layer MLP: d_model → 64 → 1
15
+ # Input: pooled decoder hidden state [B, d_model]
16
+ # Output: predicted quality score in [0, 1] (1 = high quality)
17
+ # Loss: MSE against normalized CER labels
18
+ # Training time: ~5-10 minutes on CPU for 10k examples
19
+ #
20
+ # STEP 3 — Guided inference (no retraining):
21
+ # At each diffusion step, use classifier gradient to shift logits:
22
+ # guided_logits = logits + λ * ∂(quality_score)/∂(logits)
23
+ # Higher λ → model biased toward high-quality outputs
24
+ # λ=0 → standard generation (no guidance)
25
+ #
26
+ # Key: main D3PM model is FROZEN throughout. Only the 10k-param classifier trains.
27
+ # """
28
+ #
29
+ # import torch
30
+ # import torch.nn as nn
31
+ # import torch.nn.functional as F
32
+ # import numpy as np
33
+ # import os
34
+ # import json
35
+ # from typing import List, Dict, Optional, Tuple
36
+ #
37
+ #
38
+ # # ── Quality classifier architecture ──────────────────────────────────
39
+ #
40
+ # class QualityClassifier(nn.Module):
41
+ # """
42
+ # Lightweight MLP that predicts transliteration quality from decoder
43
+ # hidden states.
44
+ #
45
+ # Architecture:
46
+ # d_model → 128 → 64 → 1 → Sigmoid
47
+ #
48
+ # Input: mean-pooled decoder hidden state [B, d_model]
49
+ # Output: quality score [B, 1] ∈ [0, 1] (1 = high quality)
50
+ #
51
+ # ~10k parameters. Trains in minutes on CPU.
52
+ # """
53
+ # def __init__(self, d_model: int):
54
+ # super().__init__()
55
+ # self.net = nn.Sequential(
56
+ # nn.Linear(d_model, 128),
57
+ # nn.ReLU(),
58
+ # nn.Dropout(0.1),
59
+ # nn.Linear(128, 64),
60
+ # nn.ReLU(),
61
+ # nn.Linear(64, 1),
62
+ # nn.Sigmoid(),
63
+ # )
64
+ # self.d_model = d_model
65
+ #
66
+ # def forward(self, hidden: torch.Tensor) -> torch.Tensor:
67
+ # """
68
+ # Args:
69
+ # hidden : [B, tgt_len, d_model] OR [B, d_model] (already pooled)
70
+ #
71
+ # Returns:
72
+ # score : [B, 1] quality score in [0, 1]
73
+ # """
74
+ # if hidden.dim() == 3:
75
+ # # Pool over sequence length
76
+ # hidden = hidden.mean(dim=1) # [B, d_model]
77
+ # return self.net(hidden) # [B, 1]
78
+ #
79
+ #
80
+ # # ── Training data collection ──────────────────────────────────────────
81
+ #
82
+ # @torch.no_grad()
83
+ # def collect_quality_data(
84
+ # model,
85
+ # src_list: List[torch.Tensor],
86
+ # ref_list: List[str],
87
+ # tgt_tokenizer,
88
+ # t_capture: int = 0,
89
+ # temperature: float = 0.8,
90
+ # top_k: int = 40,
91
+ # max_samples: int = 5000,
92
+ # ) -> Tuple[np.ndarray, np.ndarray]:
93
+ # """
94
+ # Collect (hidden_state, quality_score) pairs for classifier training.
95
+ #
96
+ # For each sample:
97
+ # 1. Run generate_cached() on src
98
+ # 2. Capture decoder hidden state at t=t_capture
99
+ # 3. Compute CER between output and reference
100
+ # 4. Quality = 1 - CER (normalize to [0,1])
101
+ #
102
+ # Args:
103
+ # model : SanskritModel
104
+ # src_list : list of [1, src_len] tensors
105
+ # ref_list : list of reference Devanagari strings
106
+ # tgt_tokenizer : SanskritTargetTokenizer
107
+ # t_capture : which step to capture hidden states (0 = final)
108
+ # max_samples : cap number of training examples
109
+ #
110
+ # Returns:
111
+ # hidden_matrix : np.ndarray [N, d_model]
112
+ # quality_scores: np.ndarray [N] values in [0, 1]
113
+ # """
114
+ # inner = model.model
115
+ # T = inner.scheduler.num_timesteps
116
+ # device = next(inner.parameters()).device
117
+ #
118
+ # hidden_list = []
119
+ # quality_list = []
120
+ # n = min(len(src_list), max_samples)
121
+ #
122
+ # def cer(pred, ref):
123
+ # if not ref:
124
+ # return 1.0
125
+ # def ed(s1, s2):
126
+ # m, n = len(s1), len(s2)
127
+ # dp = list(range(n + 1))
128
+ # for i in range(1, m + 1):
129
+ # prev, dp[0] = dp[0], i
130
+ # for j in range(1, n + 1):
131
+ # temp = dp[j]
132
+ # dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
133
+ # prev = temp
134
+ # return dp[n]
135
+ # return ed(pred, ref) / max(len(ref), 1)
136
+ #
137
+ # print(f"Collecting quality data from {n} examples...")
138
+ # for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
139
+ # if i % 200 == 0:
140
+ # print(f" {i}/{n}")
141
+ #
142
+ # if src.dim() == 1:
143
+ # src = src.unsqueeze(0)
144
+ # src = src.to(device)
145
+ #
146
+ # B = src.shape[0]
147
+ # tgt_len = inner.max_seq_len
148
+ # mask_id = inner.mask_token_id
149
+ #
150
+ # memory, src_pad_mask = inner.encode_source(src)
151
+ # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
152
+ # hint = None
153
+ # h_cap = None
154
+ #
155
+ # for t_val in range(T - 1, -1, -1):
156
+ # t = torch.full((B,), t_val, dtype=torch.long, device=device)
157
+ # is_last = (t_val == 0)
158
+ #
159
+ # logits, _ = inner.forward_cached(
160
+ # memory, src_pad_mask, x0_est, t,
161
+ # x0_hint=hint, inference_mode=True,
162
+ # )
163
+ #
164
+ # if t_val == t_capture and hasattr(inner, '_last_hidden'):
165
+ # h_cap = inner._last_hidden[0].mean(dim=0).detach().cpu() # [d_model]
166
+ #
167
+ # logits = logits / max(temperature, 1e-8)
168
+ # if top_k > 0:
169
+ # V = logits.shape[-1]
170
+ # if top_k < V:
171
+ # vals, _ = torch.topk(logits, top_k, dim=-1)
172
+ # logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
173
+ #
174
+ # probs = F.softmax(logits, dim=-1)
175
+ # x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
176
+ # hint = x0_est
177
+ #
178
+ # if h_cap is None:
179
+ # continue
180
+ #
181
+ # ids = [x for x in x0_est[0].tolist() if x > 4]
182
+ # pred = tgt_tokenizer.decode(ids).strip()
183
+ # q = max(0.0, 1.0 - cer(pred, ref)) # quality = 1 - CER
184
+ #
185
+ # hidden_list.append(h_cap.numpy())
186
+ # quality_list.append(q)
187
+ #
188
+ # print(f"Collected {len(hidden_list)} quality examples.")
189
+ # print(f"Quality stats: mean={np.mean(quality_list):.3f} "
190
+ # f"min={np.min(quality_list):.3f} max={np.max(quality_list):.3f}")
191
+ #
192
+ # return np.stack(hidden_list), np.array(quality_list, dtype=np.float32)
193
+ #
194
+ #
195
+ # def _sample(probs):
196
+ # B, L, V = probs.shape
197
+ # flat = probs.view(B * L, V).clamp(min=1e-9)
198
+ # flat = flat / flat.sum(dim=-1, keepdim=True)
199
+ # return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
200
+ #
201
+ #
202
+ # # ── Training ──────────────────────────────────────────────────────────
203
+ #
204
+ # def train_quality_classifier(
205
+ # hidden_matrix: np.ndarray,
206
+ # quality_scores: np.ndarray,
207
+ # d_model: int,
208
+ # epochs: int = 30,
209
+ # batch_size: int = 64,
210
+ # lr: float = 1e-3,
211
+ # val_frac: float = 0.1,
212
+ # save_path: Optional[str] = None,
213
+ # ) -> QualityClassifier:
214
+ # """
215
+ # Train QualityClassifier on collected (hidden, quality) pairs.
216
+ #
217
+ # Args:
218
+ # hidden_matrix : [N, d_model] from collect_quality_data()
219
+ # quality_scores : [N] quality labels in [0, 1]
220
+ # d_model : hidden dimension
221
+ # epochs : training epochs
222
+ # save_path : if given, save trained classifier weights here
223
+ #
224
+ # Returns:
225
+ # trained QualityClassifier
226
+ # """
227
+ # device = torch.device("cpu") # classifier is tiny, CPU is fine
228
+ #
229
+ # X = torch.tensor(hidden_matrix, dtype=torch.float32)
230
+ # y = torch.tensor(quality_scores, dtype=torch.float32).unsqueeze(-1)
231
+ #
232
+ # N = len(X)
233
+ # n_val = max(1, int(N * val_frac))
234
+ # idx = torch.randperm(N)
235
+ # val_idx = idx[:n_val]
236
+ # train_idx = idx[n_val:]
237
+ #
238
+ # X_train, y_train = X[train_idx], y[train_idx]
239
+ # X_val, y_val = X[val_idx], y[val_idx]
240
+ #
241
+ # clf = QualityClassifier(d_model).to(device)
242
+ # optimizer = torch.optim.Adam(clf.parameters(), lr=lr)
243
+ #
244
+ # print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
245
+ # print(f"Train: {len(X_train)} Val: {len(X_val)}")
246
+ #
247
+ # best_val_loss = float('inf')
248
+ # best_state = None
249
+ #
250
+ # for epoch in range(epochs):
251
+ # clf.train()
252
+ # perm = torch.randperm(len(X_train))
253
+ # train_loss = 0.0
254
+ # n_batches = 0
255
+ #
256
+ # for start in range(0, len(X_train), batch_size):
257
+ # batch_idx = perm[start:start + batch_size]
258
+ # xb, yb = X_train[batch_idx], y_train[batch_idx]
259
+ # pred = clf(xb)
260
+ # loss = F.mse_loss(pred, yb)
261
+ # optimizer.zero_grad()
262
+ # loss.backward()
263
+ # optimizer.step()
264
+ # train_loss += loss.item()
265
+ # n_batches += 1
266
+ #
267
+ # clf.eval()
268
+ # with torch.no_grad():
269
+ # val_pred = clf(X_val)
270
+ # val_loss = F.mse_loss(val_pred, y_val).item()
271
+ #
272
+ # if epoch % 5 == 0 or epoch == epochs - 1:
273
+ # print(f" Ep {epoch+1:3d} train={train_loss/n_batches:.4f} val={val_loss:.4f}")
274
+ #
275
+ # if val_loss < best_val_loss:
276
+ # best_val_loss = val_loss
277
+ # best_state = {k: v.clone() for k, v in clf.state_dict().items()}
278
+ #
279
+ # if best_state:
280
+ # clf.load_state_dict(best_state)
281
+ # print(f" Best val loss: {best_val_loss:.4f}")
282
+ #
283
+ # if save_path:
284
+ # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
285
+ # torch.save(clf.state_dict(), save_path)
286
+ # print(f" Classifier saved: {save_path}")
287
+ #
288
+ # return clf
289
+ #
290
+ #
291
+ # # ── Guided inference ──────────────────────────────────────────────────
292
+ #
293
+ # def generate_guided(
294
+ # model,
295
+ # src: torch.Tensor,
296
+ # classifier: QualityClassifier,
297
+ # guidance_scale: float = 1.0,
298
+ # temperature: float = 0.8,
299
+ # top_k: int = 40,
300
+ # ) -> torch.Tensor:
301
+ # """
302
+ # Classifier-guided generation.
303
+ #
304
+ # At each diffusion step:
305
+ # 1. Run forward_cached() → logits, hidden states
306
+ # 2. Compute classifier gradient: ∂(quality_score) / ∂(hidden)
307
+ # 3. Project gradient back to logit space (approximate)
308
+ # 4. guided_logits = logits + λ * gradient_signal
309
+ # 5. Sample from guided_logits
310
+ #
311
+ # guidance_scale λ:
312
+ # 0.0 → no guidance (standard generation)
313
+ # 0.5 → weak guidance
314
+ # 1.0 → moderate guidance (recommended starting point)
315
+ # 2.0 → strong guidance (may reduce diversity)
316
+ # 3.0 → very strong (may collapse to repetitive output)
317
+ #
318
+ # Args:
319
+ # model : SanskritModel (frozen)
320
+ # src : [1, src_len] IAST token ids
321
+ # classifier : trained QualityClassifier
322
+ # guidance_scale : λ — guidance strength
323
+ #
324
+ # Returns:
325
+ # x0_est : [1, tgt_len] generated token ids
326
+ # """
327
+ # inner = model.model
328
+ # T = inner.scheduler.num_timesteps
329
+ # device = next(inner.parameters()).device
330
+ # clf_device = next(classifier.parameters()).device
331
+ #
332
+ # if src.dim() == 1:
333
+ # src = src.unsqueeze(0)
334
+ # src = src.to(device)
335
+ #
336
+ # B = src.shape[0]
337
+ # tgt_len = inner.max_seq_len
338
+ # mask_id = inner.mask_token_id
339
+ #
340
+ # memory, src_pad_mask = inner.encode_source(src)
341
+ # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
342
+ # hint = None
343
+ #
344
+ # inner.eval()
345
+ # classifier.eval()
346
+ #
347
+ # for t_val in range(T - 1, -1, -1):
348
+ # t = torch.full((B,), t_val, dtype=torch.long, device=device)
349
+ # is_last = (t_val == 0)
350
+ #
351
+ # if guidance_scale > 0.0:
352
+ # # Need gradients for classifier guidance
353
+ # with torch.enable_grad():
354
+ # # Run forward_cached and get hidden states
355
+ # PAD = 1
356
+ # if t_val > 0:
357
+ # _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
358
+ # else:
359
+ # x_t_ids = x0_est
360
+ #
361
+ # x = inner.tgt_embed(x_t_ids)
362
+ # t_norm = t.float() / T
363
+ # t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
364
+ # x = x + t_emb.unsqueeze(1)
365
+ #
366
+ # if hint is not None:
367
+ # hint_emb = inner.tgt_embed(hint)
368
+ # gate = inner.hint_gate(x)
369
+ # x = x + gate * hint_emb
370
+ #
371
+ # for block in inner.decoder_blocks:
372
+ # x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)
373
+ #
374
+ # # hidden: [B, tgt_len, d_model] — detach from graph for clf
375
+ # hidden = x.detach().requires_grad_(True).to(clf_device)
376
+ #
377
+ # # Classifier quality score
378
+ # quality = classifier(hidden) # [B, 1]
379
+ # quality.sum().backward()
380
+ #
381
+ # # Gradient of quality w.r.t. hidden: [B, tgt_len, d_model]
382
+ # grad = hidden.grad.to(device) # [B, tgt_len, d_model]
383
+ #
384
+ # # Project gradient to logit space via output head weight
385
+ # # logit_grad ≈ grad @ head.weight [B, tgt_len, tgt_vocab]
386
+ # logit_grad = grad @ inner.head.weight.T
387
+ #
388
+ # # Compute standard logits (no gradient needed)
389
+ # with torch.no_grad():
390
+ # logits = inner.head(x)
391
+ #
392
+ # # Apply guidance
393
+ # logits = logits + guidance_scale * logit_grad
394
+ #
395
+ # else:
396
+ # with torch.no_grad():
397
+ # logits, _ = inner.forward_cached(
398
+ # memory, src_pad_mask, x0_est, t,
399
+ # x0_hint=hint, inference_mode=True,
400
+ # )
401
+ #
402
+ # with torch.no_grad():
403
+ # logits = logits / max(temperature, 1e-8)
404
+ # if top_k > 0:
405
+ # V = logits.shape[-1]
406
+ # if top_k < V:
407
+ # vals, _ = torch.topk(logits, top_k, dim=-1)
408
+ # logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
409
+ #
410
+ # probs = F.softmax(logits, dim=-1)
411
+ # x0_est = torch.argmax(probs, dim=-1) if is_last else _sample_no_grad(probs)
412
+ # hint = x0_est
413
+ #
414
+ # return x0_est
415
+ #
416
+ #
417
+ # def _sample_no_grad(probs):
418
+ # B, L, V = probs.shape
419
+ # flat = probs.view(B * L, V).clamp(min=1e-9)
420
+ # flat = flat / flat.sum(dim=-1, keepdim=True)
421
+ # return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
422
+ #
423
+ #
424
+ # # ── Guidance scale sweep ──────────────────────────────────────────────
425
+ #
426
+ # def sweep_guidance_scales(
427
+ # model,
428
+ # classifier: QualityClassifier,
429
+ # src_list: List[torch.Tensor],
430
+ # ref_list: List[str],
431
+ # tgt_tokenizer,
432
+ # scales: List[float] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
433
+ # n_samples: int = 50,
434
+ # device: torch.device = None,
435
+ # output_dir: str = "analysis/outputs",
436
+ # ) -> Dict:
437
+ # """
438
+ # Evaluate CER at each guidance scale.
439
+ # Produces quality-diversity tradeoff plot.
440
+ # """
441
+ # def cer(pred, ref):
442
+ # if not ref:
443
+ # return 1.0
444
+ # def ed(s1, s2):
445
+ # m, n = len(s1), len(s2)
446
+ # dp = list(range(n + 1))
447
+ # for i in range(1, m + 1):
448
+ # prev, dp[0] = dp[0], i
449
+ # for j in range(1, n + 1):
450
+ # temp = dp[j]
451
+ # dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
452
+ # prev = temp
453
+ # return dp[n]
454
+ # return ed(pred, ref) / max(len(ref), 1)
455
+ #
456
+ # device = device or next(model.parameters()).device
457
+ # results = {}
458
+ # n = min(n_samples, len(src_list))
459
+ #
460
+ # print("\nGuidance scale sweep...")
461
+ # for scale in scales:
462
+ # cer_list = []
463
+ # output_set = []
464
+ # for src, ref in zip(src_list[:n], ref_list[:n]):
465
+ # if src.dim() == 1:
466
+ # src = src.unsqueeze(0)
467
+ # out = generate_guided(model, src.to(device), classifier,
468
+ # guidance_scale=scale)
469
+ # ids = [x for x in out[0].tolist() if x > 4]
470
+ # pred = tgt_tokenizer.decode(ids).strip()
471
+ # cer_list.append(cer(pred, ref))
472
+ # output_set.append(pred)
473
+ #
474
+ # mean_cer = float(np.mean(cer_list))
475
+ #
476
+ # # Self-diversity: unique outputs / total (proxy for diversity)
477
+ # unique_frac = len(set(output_set)) / max(len(output_set), 1)
478
+ #
479
+ # results[scale] = {"mean_cer": mean_cer, "diversity": unique_frac}
480
+ # print(f" λ={scale:.1f} CER={mean_cer:.4f} diversity={unique_frac:.3f}")
481
+ #
482
+ # # Plot
483
+ # os.makedirs(output_dir, exist_ok=True)
484
+ # try:
485
+ # import matplotlib.pyplot as plt
486
+ # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
487
+ #
488
+ # sc_list = sorted(results.keys())
489
+ # cers = [results[s]["mean_cer"] for s in sc_list]
490
+ # diversities = [results[s]["diversity"] for s in sc_list]
491
+ #
492
+ # ax1.plot(sc_list, cers, 'o-', color='coral', linewidth=1.8, markersize=7)
493
+ # ax1.set_xlabel("Guidance scale λ", fontsize=10)
494
+ # ax1.set_ylabel("CER (↓ better)", fontsize=10)
495
+ # ax1.set_title("Quality vs guidance scale", fontsize=10)
496
+ #
497
+ # ax2.plot(sc_list, diversities, 'o-', color='steelblue', linewidth=1.8, markersize=7)
498
+ # ax2.set_xlabel("Guidance scale λ", fontsize=10)
499
+ # ax2.set_ylabel("Output diversity (unique fraction)", fontsize=10)
500
+ # ax2.set_title("Diversity vs guidance scale", fontsize=10)
501
+ #
502
+ # plt.suptitle("Quality-Diversity Tradeoff (Guidance Scale Sweep)", fontsize=11)
503
+ # plt.tight_layout()
504
+ # path = os.path.join(output_dir, "guidance_scale_sweep.png")
505
+ # plt.savefig(path, dpi=150, bbox_inches='tight')
506
+ # plt.close()
507
+ # print(f" Saved: {path}")
508
+ # except ImportError:
509
+ # pass
510
+ #
511
+ # with open(os.path.join(output_dir, "guidance_results.json"), "w") as f:
512
+ # json.dump({str(k): v for k, v in results.items()}, f, indent=2)
513
+ #
514
+ # return results
515
+ import torch
516
+ import torch.nn as nn
517
+ import torch.nn.functional as F
518
+ import numpy as np
519
+ from typing import List, Dict
520
+
521
+
522
+ # ============================================================
523
+ # 1. QUALITY CLASSIFIER
524
+ # ============================================================
525
+
526
+ class QualityClassifier(nn.Module):
527
+ def __init__(self, d_model: int):
528
+ super().__init__()
529
+ self.net = nn.Sequential(
530
+ nn.Linear(d_model, 128),
531
+ nn.ReLU(),
532
+ nn.Dropout(0.1),
533
+ nn.Linear(128, 64),
534
+ nn.ReLU(),
535
+ nn.Linear(64, 1),
536
+ nn.Sigmoid(),
537
+ )
538
+
539
+ def forward(self, hidden):
540
+ if hidden.dim() == 3:
541
+ hidden = hidden.mean(dim=1)
542
+ return self.net(hidden)
543
+
544
+
545
+ # ============================================================
546
+ # 2. GUIDED GENERATION (CORRECTED)
547
+ # ============================================================
548
+
549
+ @torch.no_grad()
550
+ def generate_guided(
551
+ model,
552
+ src: torch.Tensor,
553
+ classifier: QualityClassifier,
554
+ guidance_scale: float = 1.0,
555
+ temperature: float = 0.8,
556
+ top_k: int = 40,
557
+ ):
558
+ inner = model.model
559
+ T = inner.scheduler.num_timesteps
560
+ device = next(inner.parameters()).device
561
+
562
+ if src.dim() == 1:
563
+ src = src.unsqueeze(0)
564
+ src = src.to(device)
565
+
566
+ B = src.shape[0]
567
+ tgt_len = inner.max_seq_len
568
+ mask_id = inner.mask_token_id
569
+
570
+ # KV CACHE
571
+ memory, src_pad_mask = inner.encode_source(src)
572
+
573
+ x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
574
+ hint = None
575
+
576
+ inner.eval()
577
+ classifier.eval()
578
+
579
+ for t_val in range(T - 1, -1, -1):
580
+ t = torch.full((B,), t_val, dtype=torch.long, device=device)
581
+ is_last = (t_val == 0)
582
+
583
+ if guidance_scale > 0:
584
+
585
+ # ENABLE GRAD FOR GUIDANCE
586
+ with torch.enable_grad():
587
+
588
+ if t_val > 0:
589
+ _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
590
+ else:
591
+ x_t_ids = x0_est
592
+
593
+ x = inner.tgt_embed(x_t_ids)
594
+
595
+ # time embedding
596
+ t_norm = t.float() / T
597
+ t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
598
+ x = x + t_emb.unsqueeze(1)
599
+
600
+ # hint conditioning
601
+ if hint is not None:
602
+ hint_emb = inner.tgt_embed(hint)
603
+ gate = inner.hint_gate(x)
604
+ x = x + gate * hint_emb
605
+
606
+ # decoder forward
607
+ for block in inner.decoder_blocks:
608
+ x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)
609
+
610
+ # IMPORTANT: NO DETACH HERE
611
+ hidden = x.requires_grad_(True)
612
+
613
+ # classifier forward
614
+ quality = classifier(hidden) # [B,1]
615
+
616
+ # compute gradient
617
+ quality.sum().backward()
618
+
619
+ grad = hidden.grad # [B, L, d_model]
620
+
621
+ # ===== FIX 1: Normalize gradient =====
622
+ grad_norm = grad.norm(dim=-1, keepdim=True) + 1e-6
623
+ grad = grad / grad_norm
624
+
625
+ # ===== FIX 2: Project to logit space =====
626
+ logit_grad = torch.matmul(grad, inner.head.weight.T)
627
+
628
+ # ===== FIX 3: Clip gradient =====
629
+ logit_grad = torch.clamp(logit_grad, -5.0, 5.0)
630
+
631
+ # compute logits (no grad)
632
+ with torch.no_grad():
633
+ logits = inner.head(x)
634
+
635
+ # apply guidance
636
+ logits = logits + guidance_scale * logit_grad
637
+
638
+ else:
639
+ with torch.no_grad():
640
+ logits, _ = inner.forward_cached(
641
+ memory, src_pad_mask, x0_est, t,
642
+ x0_hint=hint,
643
+ inference_mode=True,
644
+ )
645
+
646
+ # ===== Sampling =====
647
+ logits = logits / max(temperature, 1e-8)
648
+
649
+ if top_k > 0:
650
+ V = logits.shape[-1]
651
+ if top_k < V:
652
+ vals, _ = torch.topk(logits, top_k, dim=-1)
653
+ logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
654
+
655
+ probs = F.softmax(logits, dim=-1)
656
+
657
+ if is_last:
658
+ x0_est = torch.argmax(probs, dim=-1)
659
+ else:
660
+ x0_est = _sample(probs)
661
+
662
+ hint = x0_est
663
+
664
+ return x0_est
665
+
666
+
667
+ def _sample(probs):
668
+ B, L, V = probs.shape
669
+ flat = probs.view(B * L, V).clamp(min=1e-9)
670
+ flat = flat / flat.sum(dim=-1, keepdim=True)
671
+ return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
672
+
673
+
674
+ # ============================================================
675
+ # 3. GUIDANCE SWEEP (EVALUATION)
676
+ # ============================================================
677
+
678
+ def sweep_guidance(
679
+ model,
680
+ classifier,
681
+ src_list,
682
+ ref_list,
683
+ tgt_tokenizer,
684
+ scales=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
685
+ n_samples=50,
686
+ ):
687
+ def cer(pred, ref):
688
+ if not ref:
689
+ return 1.0
690
+ dp = list(range(len(ref) + 1))
691
+ for i in range(1, len(pred) + 1):
692
+ prev, dp[0] = dp[0], i
693
+ for j in range(1, len(ref) + 1):
694
+ temp = dp[j]
695
+ dp[j] = prev if pred[i-1] == ref[j-1] else 1 + min(prev, dp[j], dp[j-1])
696
+ prev = temp
697
+ return dp[-1] / max(len(ref), 1)
698
+
699
+ results = {}
700
+
701
+ for scale in scales:
702
+ cer_list = []
703
+ outputs = []
704
+
705
+ for src, ref in zip(src_list[:n_samples], ref_list[:n_samples]):
706
+ if src.dim() == 1:
707
+ src = src.unsqueeze(0)
708
+
709
+ out = generate_guided(model, src, classifier, scale)
710
+ ids = [x for x in out[0].tolist() if x > 4]
711
+ pred = tgt_tokenizer.decode(ids).strip()
712
+
713
+ cer_list.append(cer(pred, ref))
714
+ outputs.append(pred)
715
+
716
+ results[scale] = {
717
+ "CER": float(np.mean(cer_list)),
718
+ "diversity": len(set(outputs)) / len(outputs)
719
+ }
720
+
721
+ print(f"λ={scale:.1f} | CER={results[scale]['CER']:.4f} | diversity={results[scale]['diversity']:.3f}")
722
+
723
+ return results
analysis/reports/README.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Analysis Reports
2
+
3
+ This folder contains mentor-facing writeups for the five analysis tasks:
4
+
5
+ - [Task 1](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/reports/task1_kv_cache_report.md)
6
+ - [Task 2](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/reports/task2_attention_drift_report.md)
7
+ - [Task 3](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/reports/task3_concept_vectors_report.md)
8
+ - [Task 4](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/reports/task4_step_ablation_report.md)
9
+ - [Task 5](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/reports/task5_quality_guidance_report.md)
10
+
11
+ These reports are written for evaluation use. They include:
12
+
13
+ - objective
14
+ - implementation summary
15
+ - code snippet
16
+ - result status
17
+ - benefits
18
+ - limitations
19
+ - conclusion
analysis/reports/task1_kv_cache_report.md ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Task 1 Report: KV Cache Benchmark
2
+
3
+ ## 1. Objective
4
+
5
+ The purpose of Task 1 is to measure whether encoder-side key/value caching improves inference speed for the cross-attention D3PM paraphrase model. In the unoptimized version, the source sequence is re-encoded at every diffusion step. In the cached version, the source is encoded once and reused for all denoising steps.
6
+
7
+ This task is useful for mentor evaluation because it measures an engineering improvement directly tied to deployment cost. Even when model quality is unchanged, lower generation latency improves usability for experimentation, batch evaluation, and interactive inference.
8
+
9
+ ## 2. Implementation Approach
10
+
11
+ The benchmark is implemented in [analysis/kv_cache_benchmark.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/kv_cache_benchmark.py). To support it, the cross-attention model was extended with three helper methods in [model/d3pm_model_cross_attention.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/model/d3pm_model_cross_attention.py):
12
+
13
+ - `encode_source(...)`
14
+ - `forward_cached(...)`
15
+ - `generate_cached(...)`
16
+
17
+ These methods separate source encoding from decoder-side denoising, which is the standard way to benchmark KV caching in encoder-decoder style architectures.
18
+
19
+ ### Core Implementation Snippet
20
+
21
+ ```python
22
+ def encode_source(self, src):
23
+ PAD = 1
24
+ src_pad_mask = (src == PAD)
25
+ memory = self.src_embed(src)
26
+ for block in self.encoder_blocks:
27
+ memory = block(memory, pad_mask=src_pad_mask)
28
+ return memory, src_pad_mask
29
+
30
+ def forward_cached(self, memory, src_pad_mask, tgt, t, x0_hint=None, inference_mode=False):
31
+ ...
32
+ for block in self.decoder_blocks:
33
+ x = block(x, memory, tgt_pad_mask=tgt_pad_mask, src_pad_mask=src_pad_mask)
34
+ self._last_hidden = x.detach()
35
+ return self.head(x), None
36
+ ```
37
+
38
+ This design avoids recomputing the encoder stack at each diffusion step.
39
+
40
+ ## 3. Experimental Setup
41
+
42
+ The benchmark was run using the Task 1 entry point:
43
+
44
+ ```bash
45
+ uv run --active analysis/run_analysis.py --task 1
46
+ ```
47
+
48
+ The script tests source lengths of 16, 32, and 64 tokens and reports:
49
+
50
+ - standard generation time
51
+ - cached generation time
52
+ - speedup ratio
53
+ - estimated encoder cost as a percentage of one forward pass
54
+
55
+ The benchmark output is stored in [analysis/outputs/task1_kv_cache.txt](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task1_kv_cache.txt).
56
+
57
+ ## 4. Results
58
+
59
+ Observed benchmark values:
60
+
61
+ | Source Length | Standard (s) | Cached (s) | Speedup | Encoder % |
62
+ | --- | ---: | ---: | ---: | ---: |
63
+ | 16 | 1.784 | 1.780 | 1.00x | 42.7% |
64
+ | 32 | 2.055 | 1.850 | 1.11x | 41.9% |
65
+ | 64 | 1.724 | 1.608 | 1.07x | 43.2% |
66
+
67
+ The main outcome is that caching works correctly and provides a measurable speed improvement, though the improvement is modest on the current hardware and runtime stack.
68
+
69
+ ## 5. Interpretation
70
+
71
+ The result is technically correct and useful, but it should be positioned carefully in evaluation:
72
+
73
+ - This is a systems optimization result, not a model quality result.
74
+ - The speedup is real, but not dramatic.
75
+ - The benchmark confirms that source-side recomputation can be removed without changing the inference algorithm.
76
+
77
+ For mentor evaluation, this can be presented as a successful engineering optimization with limited but positive runtime impact.
78
+
79
+ ## 6. Benefits
80
+
81
+ Benefits of this task:
82
+
83
+ - reduces redundant encoder computation
84
+ - provides a reusable cached inference path for later analysis tasks
85
+ - improves scalability for repeated generation and diagnostic probes
86
+ - establishes infrastructure for attention and hidden-state inspection
87
+
88
+ ## 7. Limitations
89
+
90
+ The result should not be overstated:
91
+
92
+ - speedup depends heavily on hardware and backend
93
+ - current gains are relatively small
94
+ - more stable benchmarking would require repeated runs and device-specific profiling
95
+ - this does not improve semantic accuracy directly
96
+
97
+ ## 8. Conclusion
98
+
99
+ Task 1 is valid and suitable for mentor evaluation as an implementation-focused result. It demonstrates that cached inference was successfully added to the D3PM cross-attention model and that it reduces generation cost modestly. The strongest value of this task is architectural: it enables faster repeated inference and supports later interpretability experiments.
analysis/reports/task2_attention_drift_report.md ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Task 2 Report: Attention Visualization and Semantic Drift
2
+
3
+ ## 1. Objective
4
+
5
+ Task 2 investigates how the diffusion model behaves internally during generation. It has two goals:
6
+
7
+ - capture cross-attention patterns between source and generated target tokens
8
+ - measure how intermediate generations converge toward the final output over diffusion steps
9
+
10
+ This task is important for evaluation because it gives interpretability evidence. Instead of only showing the final prediction, it examines whether the model gradually stabilizes its output and whether attention is distributed in a meaningful way.
11
+
12
+ ## 2. Implementation Approach
13
+
14
+ The implementation uses two analysis modules:
15
+
16
+ - [analysis/attention_viz.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/attention_viz.py)
17
+ - [analysis/semantic_drift.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/semantic_drift.py)
18
+
19
+ To support this, the cross-attention layer stores attention weights during decoding. The model also exposes a cached inference path so per-step diagnostics can be collected efficiently.
20
+
21
+ ### Attention Capture Snippet
22
+
23
+ ```python
24
+ class MultiHeadAttention(nn.Module):
25
+ def __init__(self, d_model, n_heads, dropout=0.1):
26
+ ...
27
+ self.capture_weights = False
28
+ self.last_attn_weights = None
29
+
30
+ def forward(self, q, k, v, mask=None):
31
+ ...
32
+ attn = self.dropout(torch.softmax(scores, dim=-1))
33
+ if self.capture_weights:
34
+ self.last_attn_weights = attn.detach().cpu()
35
+ ```
36
+
37
+ ### Drift Computation Snippet
38
+
39
+ ```python
40
+ def compute_drift(step_outputs, final_output):
41
+ t_vals = sorted(step_outputs.keys(), reverse=True)
42
+ cer_to_final = []
43
+ for t_val in t_vals:
44
+ cer = compute_cer_between(step_outputs[t_val], final_output)
45
+ cer_to_final.append(cer)
46
+ ```
47
+
48
+ The metric used is character error rate between each intermediate output and the final output.
49
+
50
+ ## 3. Experimental Setup
51
+
52
+ The task was run with:
53
+
54
+ ```bash
55
+ uv run --active analysis/run_analysis.py --task 2 --input "dharmo rakṣati rakṣitaḥ"
56
+ ```
57
+
58
+ Generated outputs:
59
+
60
+ - [analysis/outputs/task2_attn_t127.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_attn_t127.png)
61
+ - [analysis/outputs/task2_attn_t0.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_attn_t0.png)
62
+ - [analysis/outputs/task2_all_layers_t0.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_all_layers_t0.png)
63
+ - [analysis/outputs/task2_attn_evolution.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_attn_evolution.png)
64
+ - [analysis/outputs/task2_semantic_drift.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_semantic_drift.png)
65
+ - [analysis/outputs/task2_report.txt](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task2_report.txt)
66
+
67
+ ## 4. Results
68
+
69
+ The saved report shows:
70
+
71
+ - lock-in timestep: `t = 22`
72
+ - mean token-position lock-in: `53.6 ± 28.4`
73
+
74
+ This indicates that the generated sequence becomes relatively stable before the final denoising step. In other words, the model is not making all of its decisions only at the very end.
75
+
76
+ However, the actual generated Sanskrit output is low quality and strongly repetitive. That matters for interpretation: the drift curve is still valid as a measure of convergence, but it is convergence toward a weak final output.
77
+
78
+ ## 5. Interpretation
79
+
80
+ For mentor evaluation, this task should be presented as a diagnostic analysis rather than a quality claim.
81
+
82
+ What the task supports:
83
+
84
+ - the model’s output evolves gradually over time
85
+ - the diffusion process shows an identifiable stabilization region
86
+ - attention weights can now be inspected layer by layer
87
+
88
+ What the task does not yet support:
89
+
90
+ - strong semantic alignment
91
+ - trustworthy linguistic paraphrase quality
92
+ - meaningful claim that attention maps correspond to correct Sanskrit transformation
93
+
94
+ ## 6. Benefits
95
+
96
+ This task has practical value even with imperfect outputs:
97
+
98
+ - helps identify when the model stabilizes
99
+ - supports debugging of the denoising trajectory
100
+ - provides visual artifacts for discussing model internals
101
+ - can guide reduction of unnecessary inference steps in future work
102
+
103
+ ## 7. Limitations
104
+
105
+ There are two important limitations:
106
+
107
+ 1. The output quality is weak, so the interpretability evidence is about model behavior, not model correctness.
108
+ 2. Matplotlib on the current machine does not render Devanagari fonts well, so the generated figures contain font warnings and may not display labels cleanly.
109
+
110
+ ## 8. Conclusion
111
+
112
+ Task 2 is partially suitable for evaluation. It is strong as an interpretability and debugging report, but weak as proof of semantic paraphrase quality. For mentor review, it should be framed as evidence that the diffusion generation process can now be inspected and analyzed step by step.
analysis/reports/task3_concept_vectors_report.md ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Task 3 Report: Concept Vectors and PCA-Based Steering
2
+
3
+ ## 1. Objective
4
+
5
+ Task 3 explores whether decoder hidden states contain a measurable direction corresponding to paraphrase diversity. The idea is:
6
+
7
+ 1. collect hidden states from many validation samples
8
+ 2. fit PCA to the hidden-state space
9
+ 3. find a principal direction correlated with output diversity
10
+ 4. steer generation along that direction
11
+
12
+ This is an advanced representation-learning experiment. Its value for mentor evaluation lies in showing that the project is not limited to training and inference, but also investigates controllable generation.
13
+
14
+ ## 2. Implementation Approach
15
+
16
+ The implementation is in [analysis/concept_vectors.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/concept_vectors.py). Hidden states are captured from the decoder during cached inference and pooled across sequence positions.
17
+
18
+ ### PCA Fitting Snippet
19
+
20
+ ```python
21
+ def fit_pca(hidden_matrix, n_components=50):
22
+ from sklearn.decomposition import PCA
23
+ n_comp = min(n_components, hidden_matrix.shape[0] - 1, hidden_matrix.shape[1])
24
+ pca = PCA(n_components=n_comp)
25
+ pca.fit(hidden_matrix)
26
+ return pca
27
+ ```
28
+
29
+ ### Steering Snippet
30
+
31
+ ```python
32
+ if alpha != 0.0:
33
+ x = x + alpha * dir_tensor.unsqueeze(0).unsqueeze(0)
34
+
35
+ logits = inner.head(x)
36
+ ```
37
+
38
+ The steering mechanism adds a learned direction in hidden-state space before projection to logits.
39
+
40
+ ## 3. Experimental Setup
41
+
42
+ Task 3 was run from the shared analysis driver and generated:
43
+
44
+ - [analysis/outputs/task3_concept_space.png](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task3_concept_space.png)
45
+ - [analysis/outputs/task3_diversity_direction.npy](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task3_diversity_direction.npy)
46
+ - [analysis/outputs/task3_report.txt](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task3_report.txt)
47
+
48
+ The run used 500 validation examples for hidden-state extraction.
49
+
50
+ ## 4. Results
51
+
52
+ Observed summary:
53
+
54
+ - PCA components retained: `50`
55
+ - total explained variance: `96.1%`
56
+ - selected diversity principal component: `PC 1`
57
+ - absolute correlation with output length: `0.303`
58
+
59
+ On paper, these values suggest that hidden-state variation is structured and that at least one direction correlates with output-length changes. That is a positive sign from a representation-analysis standpoint.
60
+
61
+ However, the actual diversity spectrum outputs are not semantically convincing. The steered generations are highly repetitive and mostly malformed token sequences rather than clear paraphrases with controlled variation.
62
+
63
+ ## 5. Interpretation
64
+
65
+ This task should be presented carefully.
66
+
67
+ What is supported:
68
+
69
+ - hidden states are rich enough for PCA analysis
70
+ - the representation space is not random noise
71
+ - controllable steering infrastructure has been implemented successfully
72
+
73
+ What is not yet supported:
74
+
75
+ - interpretable semantic control
76
+ - high-quality paraphrase diversity
77
+ - evidence that the identified direction reflects useful linguistic variation
78
+
79
+ For mentor evaluation, this is best framed as a promising exploratory experiment rather than a finished result.
80
+
81
+ ## 6. Benefits
82
+
83
+ Benefits of the task include:
84
+
85
+ - opens a path toward controllable paraphrase generation
86
+ - demonstrates hidden-state instrumentation beyond standard inference
87
+ - provides a research direction for future work on style and diversity control
88
+ - connects model analysis with possible user-facing controllability
89
+
90
+ ## 7. Limitations
91
+
92
+ The main limitation is output quality. Even though the PCA statistics look reasonable, the steered generations are not linguistically strong enough to claim meaningful semantic control. This makes the current result more useful as a prototype than as a validated research finding.
93
+
94
+ ## 8. Conclusion
95
+
96
+ Task 3 is not yet strong enough as a final evaluation result, but it is valuable as research evidence of advanced model analysis. For mentor discussion, it should be described as an experimental controllability framework that has been implemented successfully but still requires better base model quality before the steering outputs become persuasive.
analysis/reports/task4_step_ablation_report.md ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Task 4 Report: Diffusion Step Ablation
2
+
3
+ ## 1. Objective
4
+
5
+ Task 4 studies how the number of diffusion steps affects meaning preservation, speed, and robustness. The hypothesis is that fewer denoising steps may improve speed, but too few steps may reduce output quality. This type of ablation is important for mentor evaluation because it tests a core design parameter of the D3PM model.
6
+
7
+ Unlike the earlier tasks, this one requires retraining separate checkpoints for each step count. This is not optional. A model trained at `T=128` cannot be evaluated fairly at `T=4` or `T=8` without retraining, because the timestep distribution seen during training changes fundamentally.
8
+
9
+ ## 2. Implementation Approach
10
+
11
+ The implementation is in [analysis/step_ablation.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/step_ablation.py). I patched the workflow so it is safe for this repository:
12
+
13
+ - it no longer overwrites `config.py`
14
+ - it uses environment variables for `DIFFUSION_STEPS`
15
+ - each training run writes directly to `ablation_results/T*`
16
+
17
+ ### Training Script Generation Snippet
18
+
19
+ ```python
20
+ f.write(
21
+ f"MODEL_TYPE=\"$MODEL_TYPE\" INCLUDE_NEG=\"$INCLUDE_NEG\" "
22
+ f"TRAIN_DEVICE=\"$TRAIN_DEVICE\" "
23
+ f"DIFFUSION_STEPS={T} INFERENCE_NUM_STEPS={T} "
24
+ f"TRAIN_OUTPUT_DIR=\"ablation_results/T{T}\" "
25
+ f"python train.py\n\n"
26
+ )
27
+ ```
28
+
29
+ This makes the ablation workflow reproducible without mutating repository files between runs.
30
+
31
+ ## 3. Current Workflow
32
+
33
+ Task 4 now supports the following sequence:
34
+
35
+ ```bash
36
+ uv run --active analysis/run_analysis.py --task 4 --phase generate_configs
37
+ bash ablation_configs/train_all.sh
38
+ uv run --active analysis/run_analysis.py --task 4 --phase analyze
39
+ ```
40
+
41
+ Generated script:
42
+
43
+ - [ablation_configs/train_all.sh](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/ablation_configs/train_all.sh)
44
+
45
+ This script trains:
46
+
47
+ - `T=4`
48
+ - `T=8`
49
+ - `T=16`
50
+ - `T=32`
51
+ - `T=64`
52
+
53
+ with outputs saved to `ablation_results/T4`, `T8`, `T16`, `T32`, and `T64`.
54
+
55
+ ## 4. Current Result Status
56
+
57
+ At the moment, no trained ablation checkpoints exist in `ablation_results/T*/best_model.pt`. Therefore, the analysis phase has no quantitative result yet. That means Task 4 currently has a correct implementation pipeline, but not a completed experiment.
58
+
59
+ This distinction matters for evaluation:
60
+
61
+ - the workflow is correct
62
+ - the experiment has not yet produced final numbers
63
+
64
+ ## 5. Evaluation Value
65
+
66
+ For mentor evaluation, Task 4 can still be included, but it should be presented as:
67
+
68
+ - a completed experimental setup
69
+ - a validated retraining workflow
70
+ - pending final quantitative results
71
+
72
+ This is still useful because ablation design is part of research rigor. It shows that the project is set up to test the effect of a critical modeling choice instead of assuming the default step count is optimal.
73
+
74
+ ## 6. Benefits
75
+
76
+ Once the checkpoints are trained, this task will answer:
77
+
78
+ - how much generation speed improves as diffusion steps decrease
79
+ - how meaning preservation changes with fewer steps
80
+ - where the best quality-speed tradeoff lies
81
+ - whether the current choice of diffusion steps is over- or under-provisioned
82
+
83
+ ## 7. Limitations
84
+
85
+ The limitation is straightforward: there are no ablation checkpoints yet, so there are no real results to defend. It should not be presented as a finished evaluation experiment at this stage.
86
+
87
+ ## 8. Conclusion
88
+
89
+ Task 4 is structurally correct and now safe to run in this repository. It is suitable for mentor evaluation as an experimental design and workflow contribution, but not yet as a result section. The next milestone is to train the five ablation checkpoints and run the analysis phase to generate the actual CER-speed comparison.
analysis/reports/task5_quality_guidance_report.md ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Task 5 Report: Quality Classifier and Guidance-Based Decoding
2
+
3
+ ## 1. Objective
4
+
5
+ Task 5 attempts to guide generation using a lightweight quality classifier trained on decoder hidden states. The idea is to predict a quality score from hidden states and then use the classifier gradient to bias inference toward higher-quality outputs.
6
+
7
+ This is an ambitious extension because it adds a second learned component on top of the main D3PM model without retraining the core paraphrase model itself.
8
+
9
+ ## 2. Implementation Approach
10
+
11
+ The implementation is in [analysis/quality_classifier.py](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/quality_classifier.py). It has three stages:
12
+
13
+ 1. collect `(hidden_state, quality_score)` pairs
14
+ 2. train a small MLP quality classifier
15
+ 3. use classifier gradients during decoding
16
+
17
+ ### Classifier Definition Snippet
18
+
19
+ ```python
20
+ class QualityClassifier(nn.Module):
21
+ def __init__(self, d_model: int):
22
+ super().__init__()
23
+ self.net = nn.Sequential(
24
+ nn.Linear(d_model, 128),
25
+ nn.ReLU(),
26
+ nn.Dropout(0.1),
27
+ nn.Linear(128, 64),
28
+ nn.ReLU(),
29
+ nn.Linear(64, 1),
30
+ nn.Sigmoid(),
31
+ )
32
+ ```
33
+
34
+ ### Guidance Snippet
35
+
36
+ ```python
37
+ hidden = x.detach().to(clf_device).requires_grad_(True)
38
+ hidden.retain_grad()
39
+ quality = classifier(hidden)
40
+ quality.sum().backward()
41
+ grad = hidden.grad.to(device)
42
+ logit_grad = grad @ inner.head.weight.T
43
+ logits = logits + guidance_scale * logit_grad
44
+ ```
45
+
46
+ This turns hidden-state quality prediction into a differentiable decoding signal.
47
+
48
+ ## 3. Current Status
49
+
50
+ Task 5 originally failed for two reasons:
51
+
52
+ - the gradient was taken from a non-leaf tensor, causing `hidden.grad` to be `None`
53
+ - the cached quality labels collapsed to all zeros, so the classifier had no meaningful learning signal
54
+
55
+ These implementation bugs were patched. However, the existing saved quality cache in [analysis/outputs/task5_quality_data.npz](/Users/bhsingh/Documents/Final_Paraphrase/Exclude_Negative/analysis/outputs/task5_quality_data.npz) still contains degenerate labels from the earlier failed run.
56
+
57
+ Observed cache statistics:
58
+
59
+ - count: `500`
60
+ - mean: `0.0`
61
+ - std: `0.0`
62
+ - min: `0.0`
63
+ - max: `0.0`
64
+
65
+ That means the current classifier result is not valid for evaluation.
66
+
67
+ ## 4. Why the Current Result Is Not Reliable
68
+
69
+ Because all quality labels are zero:
70
+
71
+ - the classifier is effectively trained on a constant target
72
+ - low validation loss is meaningless
73
+ - guidance behavior cannot be interpreted as quality-aware control
74
+
75
+ So although the code path now exists, the saved run should not be used in mentor evaluation as a finished result.
76
+
77
+ ## 5. What Was Fixed
78
+
79
+ Two concrete corrections were made:
80
+
81
+ - a bounded quality transform was introduced so very large CER values do not collapse everything to zero
82
+ - the Task 5 runner now refreshes cached quality data when it detects degenerate labels
83
+
84
+ This means Task 5 is closer to being experimentally sound, but it still needs to be rerun from scratch after the patch.
85
+
86
+ ## 6. Expected Benefits
87
+
88
+ If Task 5 works as intended after rerunning, it could provide:
89
+
90
+ - a lightweight mechanism for improving generation quality
91
+ - a controllable quality-diversity tradeoff
92
+ - a reusable framework for guidance without retraining the full D3PM model
93
+ - a more research-oriented extension beyond standard training and inference
94
+
95
+ ## 7. Limitations
96
+
97
+ At present, this task has one decisive limitation: the saved outputs are not valid evaluation artifacts. The infrastructure is promising, but the experimental evidence is not yet strong enough to defend.
98
+
99
+ ## 8. Conclusion
100
+
101
+ Task 5 should be presented only as a partially completed advanced experiment. The implementation framework is now in place and the core bugs have been addressed, but the current cached run is still invalid for evaluation. Before showing this task to a mentor as a result, the quality data and guidance sweep should be rerun after patching so that the classifier is trained on non-degenerate labels.
analysis/run_analysis.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analysis/run_analysis.py
3
+ =========================
4
+ Entry point for all 5 tasks.
5
+
6
+ Tasks:
7
+ Task 1 — KV Cache benchmark (no retraining)
8
+ Task 2 — Attention viz + drift (no retraining)
9
+ Task 3 — Concept vectors + PCA steer (no retraining)
10
+ Task 4 — Step ablation (REQUIRES retraining for each T)
11
+ Task 5 — Classifier-free guidance (trains small 10k-param classifier)
12
+
13
+ Usage:
14
+ python analysis/run_analysis.py --task 1
15
+ python analysis/run_analysis.py --task 2 --input "dharmo rakṣati rakṣitaḥ"
16
+ python analysis/run_analysis.py --task 3
17
+ python analysis/run_analysis.py --task 4 --phase generate_configs
18
+ python analysis/run_analysis.py --task 4 --phase analyze
19
+ python analysis/run_analysis.py --task 5
20
+ python analysis/run_analysis.py --task all --input "satyameva jayate"
21
+
22
+ Output files: analysis/outputs/
23
+ """
24
+
25
+ import copy
26
+ import torch
27
+ import os, sys, argparse, json
28
+ import numpy as np
29
+
30
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
31
+ from config import CONFIG
32
+ from inference import load_model
33
+ from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
34
+
35
+ OUTPUT_DIR = "analysis/outputs"
36
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
37
+
38
+
39
+ # ── Shared loader ─────────────────────────────────────────────────────
40
+
41
+ def infer_model_type_from_checkpoint(ckpt_path: str) -> str:
42
+ name = ckpt_path.lower()
43
+ if "ablation_results/t" in name or "d3pm_cross_attention" in name:
44
+ return "d3pm_cross_attention"
45
+ if "d3pm_encoder_decoder" in name:
46
+ return "d3pm_encoder_decoder"
47
+ if "baseline_cross_attention" in name:
48
+ return "baseline_cross_attention"
49
+ if "baseline_encoder_decoder" in name:
50
+ return "baseline_encoder_decoder"
51
+ return CONFIG["model_type"]
52
+
53
+
54
+ def infer_include_negative_from_checkpoint(ckpt_path: str) -> bool:
55
+ name = ckpt_path.lower()
56
+ if "_neg_true" in name:
57
+ return True
58
+ if "_neg_false" in name:
59
+ return False
60
+ if "ablation_results/t" in name:
61
+ return False
62
+ return CONFIG["data"]["include_negative_examples"]
63
+
64
+
65
+ def load_everything(cfg, device, ckpt_override=None):
66
+ model_name = cfg['model_type']
67
+ has_neg = cfg['data']['include_negative_examples']
68
+ candidates = [
69
+ f"results7/{model_name}_neg_{has_neg}/best_model.pt",
70
+ f"results/{model_name}_neg_{has_neg}/best_model.pt",
71
+ f"results7/{model_name}_neg_True/best_model.pt",
72
+ f"results/{model_name}_neg_True/best_model.pt",
73
+ f"results7/{model_name}_neg_False/best_model.pt",
74
+ f"results/{model_name}_neg_False/best_model.pt",
75
+ "ablation_results/T4/best_model.pt",
76
+ "ablation_results/T8/best_model.pt",
77
+ ]
78
+ ckpt = ckpt_override if ckpt_override else next((p for p in candidates if os.path.exists(p)), None)
79
+ if not os.path.exists(ckpt):
80
+ raise FileNotFoundError(f"No checkpoint found. Checked: {candidates}")
81
+ model, cfg = load_model(ckpt, cfg, device)
82
+ model.eval()
83
+ src_tok = SanskritSourceTokenizer(
84
+ vocab_size=cfg['model'].get('src_vocab_size', 500),
85
+ max_len=cfg['model']['max_seq_len'])
86
+ tgt_tok = SanskritTargetTokenizer(
87
+ vocab_size=cfg['model'].get('tgt_vocab_size', 500),
88
+ max_len=cfg['model']['max_seq_len'])
89
+ return model, src_tok, tgt_tok, cfg
90
+
91
+
92
+ def load_val_data(cfg, src_tok, tgt_tok, n=500):
93
+ """Load validation set as (src_tensors, ref_strings, input_strings)."""
94
+ from data.dataset import OptimizedSanskritDataset
95
+ from torch.utils.data import Subset
96
+ from sklearn.model_selection import train_test_split
97
+
98
+ dataset = OptimizedSanskritDataset(
99
+ 'train', max_len=cfg['model']['max_seq_len'],
100
+ cfg=cfg, src_tokenizer=src_tok, tgt_tokenizer=tgt_tok)
101
+ total = min(cfg['data']['dataset_size'], len(dataset))
102
+ _, val_idx = train_test_split(list(range(total)), train_size=0.8, random_state=42)
103
+ val_idx = val_idx[:n]
104
+
105
+ src_list, ref_list, inp_list = [], [], []
106
+ for i in val_idx:
107
+ item = dataset[i]
108
+ src_list.append(item['input_ids'].unsqueeze(0))
109
+ ref_list.append(item['target_text'])
110
+ inp_list.append(item['input_text'])
111
+ return src_list, ref_list, inp_list
112
+
113
+
114
+ # ── Task 1 ────────────────────────────────────────────────────────────
115
+
116
+ def run_task1(model, src_tok, device):
117
+ print("\n" + "="*65)
118
+ print(" TASK 1 — KV Cache Benchmark")
119
+ print("="*65)
120
+ if not hasattr(model.model, 'generate_cached'):
121
+ print(" SKIP: not D3PMCrossAttention.")
122
+ return
123
+ from analysis.kv_cache_benchmark import run_benchmark, print_summary
124
+ results = run_benchmark(model, src_tok, device, src_lens=[16, 32, 64])
125
+ print_summary(results)
126
+ path = os.path.join(OUTPUT_DIR, "task1_kv_cache.txt")
127
+ with open(path, "w") as f:
128
+ f.write("TASK 1 — KV CACHE BENCHMARK\n" + "="*40 + "\n\n")
129
+ f.write(f"{'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
130
+ f"{'speedup':>8} {'encoder%':>9}\n")
131
+ for src_len, r in results.items():
132
+ f.write(f"{src_len:>8} {r['standard_s']:>12.3f} {r['cached_s']:>10.3f} "
133
+ f"{r['speedup']:>7.2f}x {r['encoder_pct']:>8.1f}%\n")
134
+ print(f" Saved: {path}")
135
+
136
+
137
+ # ── Task 2 ────────────────────────────────────────────────────────────
138
+
139
+ def run_task2(model, src_tok, tgt_tok, device, input_text):
140
+ print("\n" + "="*65)
141
+ print(" TASK 2 — Attention Visualization + Semantic Drift")
142
+ print("="*65)
143
+ print(f" Input: {input_text}")
144
+ if not hasattr(model.model, 'encode_source'):
145
+ print(" SKIP: not D3PMCrossAttention.")
146
+ return
147
+
148
+ src_ids = src_tok.encode(input_text)
149
+ src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)
150
+ src_chars = list(input_text.strip())
151
+
152
+ from analysis.attention_viz import (AttentionCapture, plot_attn_heatmap,
153
+ plot_attn_evolution, plot_all_layers)
154
+ from analysis.semantic_drift import (capture_intermediate_outputs,
155
+ compute_drift, compute_token_stability,
156
+ plot_drift_curve)
157
+
158
+ # Attention capture
159
+ print(" Capturing attention weights...")
160
+ capturer = AttentionCapture(model)
161
+ step_weights = capturer.capture(src_tensor, capture_every=10)
162
+
163
+ with torch.no_grad():
164
+ out_ids = model.generate_cached(src_tensor)
165
+ tgt_ids = [x for x in out_ids[0].tolist() if x > 4]
166
+ tgt_text = tgt_tok.decode(tgt_ids).strip()
167
+ tgt_chars = list(tgt_text)
168
+ print(f" Output: {tgt_text}")
169
+
170
+ first_t = max(step_weights.keys())
171
+ plot_attn_heatmap(step_weights, t_val=first_t, layer=0,
172
+ src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
173
+ save_path=os.path.join(OUTPUT_DIR, f"task2_attn_t{first_t}.png"),
174
+ title=f"Attention t={first_t} (noisy) Layer 0")
175
+ plot_attn_heatmap(step_weights, t_val=0, layer=0,
176
+ src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
177
+ save_path=os.path.join(OUTPUT_DIR, "task2_attn_t0.png"),
178
+ title="Attention t=0 (final) Layer 0")
179
+ plot_all_layers(step_weights, t_val=0,
180
+ src_tokens=src_chars[:20], tgt_tokens=tgt_chars[:20],
181
+ save_path=os.path.join(OUTPUT_DIR, "task2_all_layers_t0.png"))
182
+ if len(src_chars) > 0 and len(tgt_chars) > 0:
183
+ plot_attn_evolution(step_weights, src_token_idx=0, tgt_token_idx=0,
184
+ layer=0, src_token_str=src_chars[0], tgt_token_str=tgt_chars[0],
185
+ save_path=os.path.join(OUTPUT_DIR, "task2_attn_evolution.png"))
186
+
187
+ # Semantic drift
188
+ print(" Computing semantic drift...")
189
+ step_outputs, final_out = capture_intermediate_outputs(
190
+ model, src_tensor, tgt_tok, capture_every=5)
191
+ drift = compute_drift(step_outputs, final_out)
192
+ stab = compute_token_stability(step_outputs, final_out, tgt_tok)
193
+ plot_drift_curve(drift, src_text=input_text,
194
+ save_path=os.path.join(OUTPUT_DIR, "task2_semantic_drift.png"))
195
+
196
+ print(f" Lock-in timestep: t={drift['lock_in_t']}")
197
+ print(f" Mean position lock-in: t={stab['mean_lock_t']:.1f} ± {stab['std_lock_t']:.1f}")
198
+
199
+ report = os.path.join(OUTPUT_DIR, "task2_report.txt")
200
+ with open(report, "w", encoding="utf-8") as f:
201
+ f.write("TASK 2 — ATTENTION + DRIFT REPORT\n" + "="*50 + "\n\n")
202
+ f.write(f"Input : {input_text}\nOutput : {final_out}\n\n")
203
+ f.write(f"Lock-in t : {drift['lock_in_t']}\n")
204
+ f.write(f"Mean pos lock-in : {stab['mean_lock_t']:.1f} ± {stab['std_lock_t']:.1f}\n\n")
205
+ f.write("Step → Output → CER-to-final\n" + "-"*60 + "\n")
206
+ for tv, cer in zip(drift["t_vals"], drift["cer_to_final"]):
207
+ f.write(f" t={tv:4d} | {step_outputs.get(tv,'')[:40]:40s} | {cer:.4f}\n")
208
+ print(f" Report: {report}")
209
+
210
+
211
+ # ── Task 3 ────────────────────────────────────────────────────────────
212
+
213
+ def run_task3(model, src_tok, tgt_tok, device, src_list, ref_list):
214
+ print("\n" + "="*65)
215
+ print(" TASK 3 — Concept Vectors + PCA Steering")
216
+ print("="*65)
217
+ if not hasattr(model.model, 'encode_source'):
218
+ print(" SKIP: not D3PMCrossAttention.")
219
+ return
220
+
221
+ from analysis.concept_vectors import (collect_hidden_states, fit_pca,
222
+ find_diversity_direction, generate_diversity_spectrum, plot_pca_space)
223
+
224
+ # Collect hidden states from val set
225
+ n = min(500, len(src_list))
226
+ print(f" Collecting hidden states from {n} examples...")
227
+ hidden, _ = collect_hidden_states(
228
+ model, src_list[:n], t_capture=0, max_samples=n)
229
+
230
+ # Compute output lengths for diversity direction
231
+ lengths = []
232
+ for src in src_list[:n]:
233
+ with torch.no_grad():
234
+ out = model.generate_cached(src.to(device))
235
+ ids = [x for x in out[0].tolist() if x > 4]
236
+ lengths.append(len(tgt_tok.decode(ids)))
237
+
238
+ # Fit PCA + find diversity direction
239
+ pca = fit_pca(hidden, n_components=min(50, n-1))
240
+ direction, best_pc, corr = find_diversity_direction(hidden, lengths, pca)
241
+
242
+ # Plot concept space
243
+ plot_pca_space(hidden, lengths, pca, best_pc,
244
+ save_path=os.path.join(OUTPUT_DIR, "task3_concept_space.png"))
245
+
246
+ # Generate diversity spectrum for first example
247
+ print("\n Diversity spectrum for first example:")
248
+ src0 = src_list[0]
249
+ inp0 = src_tok.decode([x for x in src0[0].tolist() if x > 4])
250
+ print(f" Input: {inp0}")
251
+ spectrum = generate_diversity_spectrum(
252
+ model, src0.to(device), direction, tgt_tok,
253
+ alphas=[-2.0, -1.0, 0.0, 1.0, 2.0])
254
+
255
+ # Save diversity direction + results
256
+ np.save(os.path.join(OUTPUT_DIR, "task3_diversity_direction.npy"), direction)
257
+
258
+ report = os.path.join(OUTPUT_DIR, "task3_report.txt")
259
+ with open(report, "w", encoding="utf-8") as f:
260
+ f.write("TASK 3 — CONCEPT VECTORS + PCA STEERING\n" + "="*50 + "\n\n")
261
+ f.write(f"PCA: {pca.n_components_} components, "
262
+ f"{pca.explained_variance_ratio_.sum()*100:.1f}% variance\n")
263
+ f.write(f"Diversity PC: {best_pc} (|r|={corr:.3f} with output length)\n\n")
264
+ f.write("Diversity spectrum:\n")
265
+ for alpha, text in sorted(spectrum.items()):
266
+ f.write(f" alpha={alpha:+.1f} → {text}\n")
267
+ print(f" Report: {report}")
268
+
269
+
270
+ # ── Task 4 ────────────────────────────────────────────────────────────
271
+
272
+ def run_task4(phase, model, src_tok, tgt_tok, device, cfg,
273
+ src_list, ref_list):
274
+ print("\n" + "="*65)
275
+ print(f" TASK 4 — Step Ablation (phase={phase})")
276
+ print("="*65)
277
+
278
+ from analysis.step_ablation import (generate_ablation_configs,
279
+ run_ablation_analysis, plot_ablation_3d, run_adversarial_test)
280
+
281
+ if phase == "generate_configs":
282
+ print(" Generating ablation configs...")
283
+ generate_ablation_configs(output_dir="ablation_configs")
284
+ print("\n NEXT STEPS:")
285
+ print(" 1. bash ablation_configs/train_all.sh")
286
+ print(" 2. python analysis/run_analysis.py --task 4 --phase analyze")
287
+
288
+ elif phase == "analyze":
289
+ # Check which models exist
290
+ existing = [T for T in [4, 8, 16, 32, 64]
291
+ if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
292
+ if not existing:
293
+ print(" No ablation models found at ablation_results/T*/best_model.pt")
294
+ print(" Run: python analysis/run_analysis.py --task 4 --phase generate_configs")
295
+ print(" Then: bash ablation_configs/train_all.sh")
296
+ return
297
+
298
+ print(f" Found models for T={existing}")
299
+ results = run_ablation_analysis(
300
+ ablation_dir="ablation_results", base_cfg=cfg,
301
+ src_list=src_list[:200], ref_list=ref_list[:200],
302
+ tgt_tokenizer=tgt_tok, device=device,
303
+ output_dir=OUTPUT_DIR)
304
+ plot_ablation_3d(results,
305
+ save_path=os.path.join(OUTPUT_DIR, "task4_ablation_3d.png"))
306
+
307
+ # Adversarial robustness always runs on existing model (no retraining)
308
+ print("\n Running adversarial robustness test...")
309
+ inp_texts = [src_tok.decode([x for x in s[0].tolist() if x > 4])
310
+ for s in src_list[:50]]
311
+ run_adversarial_test(
312
+ model, src_tok, tgt_tok,
313
+ test_inputs=inp_texts, test_refs=ref_list[:50],
314
+ device=device, output_dir=OUTPUT_DIR)
315
+
316
+
317
+ # ── Task 5 ────────────────────────────────────────────────────────────
318
+
319
+ def run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list):
320
+ print("\n" + "="*65)
321
+ print(" TASK 5 — Classifier-Free Guidance")
322
+ print("="*65)
323
+ if not hasattr(model.model, 'encode_source'):
324
+ print(" SKIP: not D3PMCrossAttention.")
325
+ return
326
+
327
+ from analysis.quality_classifier import (
328
+ QualityClassifier, collect_quality_data,
329
+ train_quality_classifier, sweep_guidance_scales)
330
+
331
+ clf_path = os.path.join(OUTPUT_DIR, "task5_quality_classifier.pt")
332
+ d_model = cfg['model']['d_model']
333
+
334
+ # Step 1: collect or load training data
335
+ data_path = os.path.join(OUTPUT_DIR, "task5_quality_data.npz")
336
+ if os.path.exists(data_path):
337
+ print(" Loading cached quality data...")
338
+ data = np.load(data_path)
339
+ hidden = data["hidden"]
340
+ quality = data["quality"]
341
+ else:
342
+ print(" Collecting quality data (this takes a few minutes)...")
343
+ n = min(2000, len(src_list))
344
+ hidden, quality = collect_quality_data(
345
+ model, src_list[:n], ref_list[:n], tgt_tok,
346
+ t_capture=0, max_samples=n)
347
+ np.savez(data_path, hidden=hidden, quality=quality)
348
+ print(f" Saved quality data: {data_path}")
349
+
350
+ # Step 2: train or load classifier
351
+ if os.path.exists(clf_path):
352
+ print(f" Loading cached classifier: {clf_path}")
353
+ clf = QualityClassifier(d_model)
354
+ clf.load_state_dict(torch.load(clf_path, map_location='cpu'))
355
+ clf.eval()
356
+ else:
357
+ print(" Training quality classifier...")
358
+ clf = train_quality_classifier(
359
+ hidden, quality, d_model=d_model,
360
+ epochs=30, batch_size=64, lr=1e-3,
361
+ save_path=clf_path)
362
+ clf.eval()
363
+
364
+ # Step 3: guidance scale sweep
365
+ print("\n Guidance scale sweep (λ ∈ {0.0, 0.5, 1.0, 1.5, 2.0, 3.0})...")
366
+ n_sweep = min(50, len(src_list))
367
+ results = sweep_guidance_scales(
368
+ model, clf, src_list[:n_sweep], ref_list[:n_sweep],
369
+ tgt_tok, scales=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
370
+ n_samples=n_sweep, device=device, output_dir=OUTPUT_DIR)
371
+
372
+ # Find optimal scale
373
+ best_scale = min(results, key=lambda s: results[s]["mean_cer"])
374
+ print(f"\n Optimal guidance scale: λ={best_scale:.1f} "
375
+ f"CER={results[best_scale]['mean_cer']:.4f}")
376
+
377
+ report = os.path.join(OUTPUT_DIR, "task5_report.txt")
378
+ with open(report, "w") as f:
379
+ f.write("TASK 5 — CLASSIFIER-FREE GUIDANCE\n" + "="*50 + "\n\n")
380
+ f.write(f"Classifier params: {sum(p.numel() for p in clf.parameters())}\n")
381
+ f.write(f"Training samples : {len(hidden)}\n\n")
382
+ f.write("Guidance scale sweep:\n")
383
+ f.write(f" {'λ':>6} {'CER':>8} {'diversity':>10}\n")
384
+ f.write(" " + "-"*28 + "\n")
385
+ for s in sorted(results.keys()):
386
+ r = results[s]
387
+ marker = " ← optimal" if s == best_scale else ""
388
+ f.write(f" {s:>6.1f} {r['mean_cer']:>8.4f} {r['diversity']:>10.3f}{marker}\n")
389
+ print(f" Report: {report}")
390
+
391
+
392
+ # ── Main ──────────────────────────────────────────────────────────────
393
+
394
+ def main():
395
+ global OUTPUT_DIR
396
+
397
+ parser = argparse.ArgumentParser()
398
+ parser.add_argument("--task",
399
+ choices=["1","2","3","4","5","all"], default="all")
400
+ parser.add_argument("--input",
401
+ default="dharmo rakṣati rakṣitaḥ",
402
+ help="IAST input text for Task 2")
403
+ parser.add_argument("--phase",
404
+ choices=["generate_configs", "analyze"], default="analyze",
405
+ help="Task 4 phase: generate_configs (before training) or analyze (after)")
406
+ parser.add_argument("--checkpoint", default=None,
407
+ help="Optional explicit checkpoint path")
408
+ parser.add_argument("--output_dir", default="analysis/outputs",
409
+ help="Output directory for reports/figures")
410
+ args = parser.parse_args()
411
+
412
+ OUTPUT_DIR = args.output_dir
413
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
414
+
415
+ cfg = copy.deepcopy(CONFIG)
416
+ if args.checkpoint:
417
+ cfg["model_type"] = infer_model_type_from_checkpoint(args.checkpoint)
418
+ cfg["data"]["include_negative_examples"] = infer_include_negative_from_checkpoint(args.checkpoint)
419
+ ckpt_name = os.path.basename(os.path.dirname(args.checkpoint))
420
+ if ckpt_name.startswith("T") and ckpt_name[1:].isdigit():
421
+ t_val = int(ckpt_name[1:])
422
+ cfg["model"]["diffusion_steps"] = t_val
423
+ cfg["inference"]["num_steps"] = t_val
424
+
425
+ requested = cfg["training"]["device"]
426
+ if requested == "mps" and not torch.backends.mps.is_available():
427
+ requested = "cpu"
428
+ elif requested == "cuda" and not torch.cuda.is_available():
429
+ requested = "cpu"
430
+ cfg["training"]["device"] = requested
431
+ device = torch.device(requested)
432
+
433
+ print("Loading model and tokenizers...")
434
+ model, src_tok, tgt_tok, cfg = load_everything(cfg, device, ckpt_override=args.checkpoint)
435
+
436
+ # Load val data for tasks that need it (Tasks 3, 4, 5)
437
+ needs_data = args.task in ("3", "4", "5", "all")
438
+ if needs_data:
439
+ print("Loading validation data...")
440
+ src_list, ref_list, inp_list = load_val_data(cfg, src_tok, tgt_tok, n=500)
441
+ else:
442
+ src_list, ref_list, inp_list = [], [], []
443
+
444
+ tasks = (["1","2","3","4","5"] if args.task == "all"
445
+ else [args.task])
446
+
447
+ for task in tasks:
448
+ if task == "1":
449
+ run_task1(model, src_tok, device)
450
+ elif task == "2":
451
+ run_task2(model, src_tok, tgt_tok, device, args.input)
452
+ elif task == "3":
453
+ run_task3(model, src_tok, tgt_tok, device, src_list, ref_list)
454
+ elif task == "4":
455
+ run_task4(args.phase, model, src_tok, tgt_tok, device, cfg,
456
+ src_list, ref_list)
457
+ elif task == "5":
458
+ run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list)
459
+
460
+ print(f"\n{'='*65}")
461
+ print(f" All outputs saved to: {OUTPUT_DIR}/")
462
+ print("="*65)
463
+
464
+
465
+ if __name__ == "__main__":
466
+ main()
analysis/run_tasks_except4_all_models.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Run Tasks 1,2,3,5 for every available checkpoint (excluding Task 4).
3
+
4
+ Usage:
5
+ python analysis/run_tasks_except4_all_models.py
6
+ python analysis/run_tasks_except4_all_models.py --input "dharmo rakṣati rakṣitaḥ"
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import argparse
12
+ import json
13
+ import os
14
+ import subprocess
15
+ import sys
16
+ from datetime import datetime
17
+ from pathlib import Path
18
+
19
+
20
+ ROOT = Path(__file__).resolve().parents[1]
21
+ DEFAULT_OUT_ROOT = ROOT / "analysis" / "outputs_multi"
22
+
23
+
24
+ def discover_checkpoints() -> list[Path]:
25
+ roots = [ROOT / "results", ROOT / "results7", ROOT / "ablation_results"]
26
+ out: list[Path] = []
27
+ for base in roots:
28
+ if not base.exists():
29
+ continue
30
+ for ckpt in sorted(base.glob("*/best_model.pt")):
31
+ out.append(ckpt)
32
+ return out
33
+
34
+
35
+ def slug_for_checkpoint(ckpt: Path) -> str:
36
+ root = ckpt.parent.parent.name
37
+ exp = ckpt.parent.name
38
+ return f"{root}__{exp}"
39
+
40
+
41
+ def run_task(task: str, ckpt: Path, input_text: str, out_dir: Path) -> tuple[int, float]:
42
+ cmd = [
43
+ sys.executable,
44
+ str(ROOT / "analysis" / "run_analysis.py"),
45
+ "--task", task,
46
+ "--checkpoint", str(ckpt),
47
+ "--output_dir", str(out_dir),
48
+ ]
49
+ if task == "2":
50
+ cmd.extend(["--input", input_text])
51
+
52
+ start = datetime.now()
53
+ env = os.environ.copy()
54
+ env.setdefault("HF_HOME", "/tmp/hf_home")
55
+ env.setdefault("HF_DATASETS_CACHE", "/tmp/hf_datasets")
56
+ env.setdefault("HF_HUB_CACHE", "/tmp/hf_hub")
57
+ env.setdefault("TRANSFORMERS_CACHE", "/tmp/hf_transformers")
58
+ os.makedirs(env["HF_HOME"], exist_ok=True)
59
+ os.makedirs(env["HF_DATASETS_CACHE"], exist_ok=True)
60
+ os.makedirs(env["HF_HUB_CACHE"], exist_ok=True)
61
+ os.makedirs(env["TRANSFORMERS_CACHE"], exist_ok=True)
62
+
63
+ proc = subprocess.run(cmd, cwd=str(ROOT), env=env)
64
+ seconds = (datetime.now() - start).total_seconds()
65
+ return proc.returncode, seconds
66
+
67
+
68
+ def main() -> None:
69
+ parser = argparse.ArgumentParser()
70
+ parser.add_argument("--input", default="dharmo rakṣati rakṣitaḥ")
71
+ parser.add_argument("--out_root", default=str(DEFAULT_OUT_ROOT))
72
+ args = parser.parse_args()
73
+
74
+ checkpoints = discover_checkpoints()
75
+ if not checkpoints:
76
+ raise FileNotFoundError("No checkpoints found under results/results7/ablation_results.")
77
+
78
+ out_root = Path(args.out_root)
79
+ out_root.mkdir(parents=True, exist_ok=True)
80
+
81
+ tasks = ["1", "2", "3", "5"]
82
+ summary = {
83
+ "timestamp": datetime.now().isoformat(timespec="seconds"),
84
+ "tasks": tasks,
85
+ "checkpoints": [],
86
+ }
87
+
88
+ for ckpt in checkpoints:
89
+ slug = slug_for_checkpoint(ckpt)
90
+ model_out = out_root / slug
91
+ model_out.mkdir(parents=True, exist_ok=True)
92
+ print(f"\n=== Checkpoint: {ckpt} ===")
93
+ model_item = {
94
+ "checkpoint": str(ckpt),
95
+ "output_dir": str(model_out),
96
+ "tasks": [],
97
+ }
98
+
99
+ for task in tasks:
100
+ task_out = model_out / f"task{task}"
101
+ task_out.mkdir(parents=True, exist_ok=True)
102
+ print(f"-> Running task {task} ...")
103
+ code, sec = run_task(task, ckpt, args.input, task_out)
104
+ item = {
105
+ "task": task,
106
+ "exit_code": code,
107
+ "seconds": round(sec, 2),
108
+ "output_dir": str(task_out),
109
+ }
110
+ model_item["tasks"].append(item)
111
+ status = "OK" if code == 0 else "FAILED"
112
+ print(f" {status} ({sec:.1f}s)")
113
+
114
+ summary["checkpoints"].append(model_item)
115
+
116
+ summary_path = out_root / "summary.json"
117
+ with summary_path.open("w", encoding="utf-8") as f:
118
+ json.dump(summary, f, ensure_ascii=False, indent=2)
119
+ print(f"\nSaved summary: {summary_path}")
120
+
121
+
122
+ if __name__ == "__main__":
123
+ main()
analysis/semantic_drift.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # analysis/semantic_drift.py
3
+ # ===========================
4
+ # Task 2: Semantic drift metric — how much does the intermediate generation
5
+ # diverge from the final output as we walk through diffusion steps T → 0?
6
+ #
7
+ # Metric: CER between x0_estimate at each step vs the final x0 at t=0.
8
+ #
9
+ # A well-trained model should show:
10
+ # - High drift at t=T-1 (near-random initial estimate)
11
+ # - Rapid decrease in drift around t=T//2 (model finds the right structure)
12
+ # - Near-zero drift at t=10 (output is stable, only fine corrections remain)
13
+ #
14
+ # If drift stays high until t=5 then suddenly collapses → model is doing all
15
+ # its work in the last few steps → consider reducing T.
16
+ #
17
+ # Also measures:
18
+ # - Token stability: fraction of positions that don't change between steps
19
+ # - Lock-in time: first step where each position "commits" to its final token
20
+ #
21
+ # No retraining required. Uses generate_cached() with intermediate snapshots.
22
+ # """
23
+ #
24
+ # import torch
25
+ # import torch.nn.functional as F
26
+ # import numpy as np
27
+ # from typing import List, Dict, Optional, Tuple
28
+ #
29
+ #
30
+ # def compute_cer_between(pred: str, ref: str) -> float:
31
+ # """CER between two strings."""
32
+ # if not ref:
33
+ # return 1.0 if pred else 0.0
34
+ #
35
+ # def edit_distance(s1, s2):
36
+ # m, n = len(s1), len(s2)
37
+ # dp = list(range(n + 1))
38
+ # for i in range(1, m + 1):
39
+ # prev, dp[0] = dp[0], i
40
+ # for j in range(1, n + 1):
41
+ # temp = dp[j]
42
+ # dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
43
+ # prev = temp
44
+ # return dp[n]
45
+ #
46
+ # return edit_distance(pred, ref) / len(ref)
47
+ #
48
+ #
49
+ # @torch.no_grad()
50
+ # def capture_intermediate_outputs(
51
+ # model,
52
+ # src: torch.Tensor,
53
+ # tgt_tokenizer,
54
+ # capture_every: int = 5,
55
+ # temperature: float = 0.8,
56
+ # top_k: int = 40,
57
+ # ) -> Tuple[Dict[int, str], str]:
58
+ # """
59
+ # Run generation while recording the decoded x0_estimate at every
60
+ # `capture_every` diffusion steps.
61
+ #
62
+ # Args:
63
+ # model : SanskritModel (D3PMCrossAttention)
64
+ # src : [1, src_len] IAST token ids (single sample)
65
+ # tgt_tokenizer : SanskritTargetTokenizer for decoding intermediate outputs
66
+ # capture_every : record every N steps
67
+ # temperature : sampling temperature
68
+ # top_k : top-k filter
69
+ #
70
+ # Returns:
71
+ # step_outputs : dict mapping t_val → decoded Devanagari string at that step
72
+ # final_output : decoded string at t=0 (final result)
73
+ # """
74
+ # if src.dim() == 1:
75
+ # src = src.unsqueeze(0)
76
+ #
77
+ # inner = model.model
78
+ # T = inner.scheduler.num_timesteps
79
+ # device = src.device
80
+ #
81
+ # # Encode source once (KV cache)
82
+ # memory, src_pad_mask = inner.encode_source(src)
83
+ #
84
+ # B = src.shape[0]
85
+ # tgt_len = inner.max_seq_len
86
+ # mask_id = inner.mask_token_id
87
+ #
88
+ # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
89
+ # hint = None
90
+ #
91
+ # step_outputs: Dict[int, str] = {}
92
+ # inner.eval()
93
+ #
94
+ # for t_val in range(T - 1, -1, -1):
95
+ # t = torch.full((B,), t_val, dtype=torch.long, device=device)
96
+ # is_last = (t_val == 0)
97
+ #
98
+ # logits, _ = inner.forward_cached(
99
+ # memory, src_pad_mask, x0_est, t,
100
+ # x0_hint=hint, inference_mode=True,
101
+ # )
102
+ #
103
+ # logits = logits / max(temperature, 1e-8)
104
+ # if top_k > 0:
105
+ # V = logits.shape[-1]
106
+ # if top_k < V:
107
+ # topk_vals, _ = torch.topk(logits, top_k, dim=-1)
108
+ # threshold = topk_vals[..., -1].unsqueeze(-1)
109
+ # logits = logits.masked_fill(logits < threshold, float('-inf'))
110
+ #
111
+ # probs = F.softmax(logits, dim=-1)
112
+ # x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
113
+ # hint = x0_est
114
+ #
115
+ # # Capture at this step
116
+ # if (T - 1 - t_val) % capture_every == 0 or is_last:
117
+ # ids = [x for x in x0_est[0].tolist() if x > 4]
118
+ # text = tgt_tokenizer.decode(ids).strip()
119
+ # step_outputs[t_val] = text
120
+ #
121
+ # final_output = step_outputs.get(0, "")
122
+ # return step_outputs, final_output
123
+ #
124
+ #
125
+ # def _sample(probs):
126
+ # B, L, V = probs.shape
127
+ # flat = probs.view(B * L, V).clamp(min=1e-9)
128
+ # flat = flat / flat.sum(dim=-1, keepdim=True)
129
+ # return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
130
+ #
131
+ #
132
+ # def compute_drift(
133
+ # step_outputs: Dict[int, str],
134
+ # final_output: str,
135
+ # ) -> Dict[str, object]:
136
+ # """
137
+ # Compute drift metrics comparing each intermediate output to the final.
138
+ #
139
+ # Returns dict with:
140
+ # t_vals : list of captured timesteps (T-1 → 0)
141
+ # cer_to_final: CER between each step's output and the final output
142
+ # 0.0 = identical to final, 1.0 = completely different
143
+ # lock_in_t : first t_val where CER drops and stays below 0.1
144
+ # (step at which output "commits" to final form)
145
+ # """
146
+ # t_vals = sorted(step_outputs.keys(), reverse=True) # T-1 → 0
147
+ # cer_to_final = []
148
+ #
149
+ # for t_val in t_vals:
150
+ # cer = compute_cer_between(step_outputs[t_val], final_output)
151
+ # cer_to_final.append(cer)
152
+ #
153
+ # # Find lock-in: first step where CER stays below threshold for rest of run
154
+ # threshold = 0.1
155
+ # lock_in_t = 0 # default: never locked in early
156
+ # for i, (t_val, cer) in enumerate(zip(t_vals, cer_to_final)):
157
+ # if all(c <= threshold for c in cer_to_final[i:]):
158
+ # lock_in_t = t_val
159
+ # break
160
+ #
161
+ # return {
162
+ # "t_vals": t_vals,
163
+ # "cer_to_final": cer_to_final,
164
+ # "lock_in_t": lock_in_t,
165
+ # "final_output": final_output,
166
+ # }
167
+ #
168
+ #
169
+ # def compute_token_stability(
170
+ # step_outputs: Dict[int, str],
171
+ # final_output: str,
172
+ # tgt_tokenizer,
173
+ # ) -> Dict[str, object]:
174
+ # """
175
+ # Token-level stability: for each position, at which diffusion step
176
+ # does it first match its final token and stay matched?
177
+ #
178
+ # Returns:
179
+ # position_lock_times: list of t_val at which each position locks in
180
+ # mean_lock_t : average lock-in timestep across positions
181
+ # """
182
+ # T = max(step_outputs.keys())
183
+ # t_vals = sorted(step_outputs.keys(), reverse=True) # T-1 → 0
184
+ #
185
+ # # Encode all intermediate outputs and the final
186
+ # def encode(text):
187
+ # return tgt_tokenizer.encode(text)
188
+ #
189
+ # final_ids = encode(final_output)
190
+ # L = len(final_ids)
191
+ #
192
+ # # Build matrix: [n_steps, L]
193
+ # step_ids = []
194
+ # for t_val in t_vals:
195
+ # step_ids.append(encode(step_outputs.get(t_val, "")))
196
+ #
197
+ # # Pad all to same length
198
+ # max_len = max(len(s) for s in step_ids)
199
+ # step_ids = [s + [1] * (max_len - len(s)) for s in step_ids] # 1=PAD
200
+ # final_ids_padded = final_ids + [1] * (max_len - len(final_ids))
201
+ #
202
+ # step_arr = np.array(step_ids) # [n_steps, L]
203
+ # final_arr = np.array(final_ids_padded) # [L]
204
+ #
205
+ # # For each position: find first step index where it matches final
206
+ # # and stays matched for all subsequent steps
207
+ # position_lock_steps = []
208
+ # for pos in range(min(L, max_len)):
209
+ # col = step_arr[:, pos] # [n_steps]
210
+ # fin = final_arr[pos]
211
+ # locked_at = len(t_vals) - 1 # default: never locks early
212
+ # for i in range(len(t_vals)):
213
+ # if all(col[i:] == fin):
214
+ # locked_at = i
215
+ # break
216
+ # position_lock_steps.append(t_vals[locked_at] if locked_at < len(t_vals) else 0)
217
+ #
218
+ # return {
219
+ # "position_lock_times": position_lock_steps,
220
+ # "mean_lock_t": float(np.mean(position_lock_steps)),
221
+ # "std_lock_t": float(np.std(position_lock_steps)),
222
+ # }
223
+ #
224
+ #
225
+ # def plot_drift_curve(
226
+ # drift_result: Dict,
227
+ # src_text: str = "",
228
+ # save_path: Optional[str] = None,
229
+ # ):
230
+ # """
231
+ # Plot CER-to-final vs diffusion step.
232
+ # Shows where the model "commits" to the final output.
233
+ # """
234
+ # try:
235
+ # import matplotlib.pyplot as plt
236
+ # except ImportError:
237
+ # print("pip install matplotlib.")
238
+ # return
239
+ #
240
+ # t_vals = drift_result["t_vals"]
241
+ # cers = drift_result["cer_to_final"]
242
+ # lock_t = drift_result["lock_in_t"]
243
+ #
244
+ # fig, ax = plt.subplots(figsize=(12, 4))
245
+ # ax.plot(range(len(t_vals)), cers, linewidth=1.8, color='coral', label='CER to final')
246
+ # ax.fill_between(range(len(t_vals)), cers, alpha=0.15, color='coral')
247
+ #
248
+ # # Mark lock-in point
249
+ # if lock_t in t_vals:
250
+ # lock_idx = t_vals.index(lock_t)
251
+ # ax.axvline(lock_idx, color='steelblue', linestyle='--', linewidth=1.2,
252
+ # label=f"Lock-in at t={lock_t}")
253
+ #
254
+ # ax.axhline(0.1, color='gray', linestyle=':', linewidth=1, alpha=0.7)
255
+ #
256
+ # n = len(t_vals)
257
+ # tick_positions = list(range(0, n, max(1, n // 10)))
258
+ # ax.set_xticks(tick_positions)
259
+ # ax.set_xticklabels([str(t_vals[i]) for i in tick_positions], fontsize=8)
260
+ # ax.set_xlabel("Diffusion step t (T-1 → 0)", fontsize=11)
261
+ # ax.set_ylabel("CER vs final output", fontsize=11)
262
+ # ax.set_ylim(0, 1.05)
263
+ # ax.set_xlim(0, n - 1)
264
+ # ax.legend(fontsize=10)
265
+ #
266
+ # title = f"Semantic drift"
267
+ # if src_text:
268
+ # title += f" | src: {src_text[:50]}"
269
+ # ax.set_title(title, fontsize=11)
270
+ # plt.tight_layout()
271
+ #
272
+ # if save_path:
273
+ # import os
274
+ # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
275
+ # plt.savefig(save_path, dpi=150, bbox_inches='tight')
276
+ # print(f"Saved: {save_path}")
277
+ # else:
278
+ # plt.show()
279
+ # plt.close()
280
+ # ============================================================
281
+ # TASK 2: Source–Paraphrase Semantic Alignment Trajectory
282
+ # ============================================================
283
+
284
+ import torch
285
+ import torch.nn.functional as F
286
+ import numpy as np
287
+ import matplotlib.pyplot as plt
288
+ from typing import Dict, List, Tuple
289
+ from collections import defaultdict
290
+
291
+ # Optional (install if needed)
292
+ # pip install bert-score scikit-learn
293
+ from bert_score import score as bertscore
294
+ from sklearn.feature_extraction.text import TfidfVectorizer
295
+
296
+
297
+ # ============================================================
298
+ # ------------------ ATTENTION HOOK --------------------------
299
+ # ============================================================
300
+
301
+ def register_attention_hooks(model):
302
+ """
303
+ Registers forward hooks to capture cross-attention weights
304
+ from each decoder block.
305
+
306
+ Assumes each block has attribute `.cross_attn.attn_weights`
307
+ """
308
+ inner = model.model
309
+ attention_maps = []
310
+
311
+ def hook_fn(module, input, output):
312
+ if hasattr(module, "attn_weights"):
313
+ attention_maps.append(module.attn_weights.detach().cpu())
314
+
315
+ hooks = []
316
+ for block in inner.decoder_blocks:
317
+ if hasattr(block, "cross_attn"):
318
+ h = block.cross_attn.register_forward_hook(hook_fn)
319
+ hooks.append(h)
320
+
321
+ return hooks, attention_maps
322
+
323
+
324
+ # ============================================================
325
+ # ------------------ CAPTURE TRAJECTORY ----------------------
326
+ # ============================================================
327
+
328
+ @torch.no_grad()
329
+ def capture_alignment_trajectory(
330
+ model,
331
+ src_tensor: torch.Tensor,
332
+ src_text: str,
333
+ tgt_tokenizer,
334
+ steps_to_capture: List[int] = None,
335
+ ):
336
+ """
337
+ Capture:
338
+ - intermediate outputs
339
+ - cross-attention maps
340
+ - BERTScore vs source
341
+
342
+ Returns:
343
+ dict with outputs, attention, drift
344
+ """
345
+
346
+ inner = model.model
347
+ device = src_tensor.device
348
+ T = inner.scheduler.num_timesteps
349
+
350
+ if steps_to_capture is None:
351
+ steps_to_capture = list(range(T - 1, -1, -5)) + [0]
352
+
353
+ # Register hooks
354
+ hooks, attn_storage = register_attention_hooks(model)
355
+
356
+ memory, src_pad_mask = inner.encode_source(src_tensor)
357
+
358
+ B = src_tensor.shape[0]
359
+ tgt_len = inner.max_seq_len
360
+ mask_id = inner.mask_token_id
361
+
362
+ x0_est = torch.full((B, tgt_len), mask_id, device=device)
363
+ hint = None
364
+
365
+ outputs = {}
366
+ attention_per_step = {}
367
+
368
+ for t_val in range(T - 1, -1, -1):
369
+ t = torch.full((B,), t_val, device=device)
370
+
371
+ logits, _ = inner.forward_cached(
372
+ memory, src_pad_mask, x0_est, t,
373
+ x0_hint=hint, inference_mode=True
374
+ )
375
+
376
+ probs = F.softmax(logits, dim=-1)
377
+ x0_est = torch.argmax(probs, dim=-1)
378
+ hint = x0_est
379
+
380
+ if t_val in steps_to_capture:
381
+ ids = [x for x in x0_est[0].tolist() if x > 4]
382
+ text = tgt_tokenizer.decode(ids)
383
+
384
+ outputs[t_val] = text
385
+
386
+ # Collect attention maps (last layer only for simplicity)
387
+ if len(attn_storage) > 0:
388
+ attention_per_step[t_val] = attn_storage[-1].numpy()
389
+
390
+ # Remove hooks
391
+ for h in hooks:
392
+ h.remove()
393
+
394
+ # Compute BERTScore trajectory
395
+ bert_scores = compute_bert_alignment(src_text, outputs)
396
+
397
+ return {
398
+ "outputs": outputs,
399
+ "attention": attention_per_step,
400
+ "bert_scores": bert_scores,
401
+ }
402
+
403
+
404
+ # ============================================================
405
+ # ------------------ BERTScore -------------------------------
406
+ # ============================================================
407
+
408
+ def compute_bert_alignment(src_text: str, outputs: Dict[int, str]):
409
+ """
410
+ Compute BERTScore between source and each intermediate output
411
+ """
412
+ scores = {}
413
+
414
+ for t, text in outputs.items():
415
+ P, R, F1 = bertscore([text], [src_text], lang="hi", verbose=False)
416
+ scores[t] = float(F1.mean())
417
+
418
+ return scores
419
+
420
+
421
+ # ============================================================
422
+ # ------------------ SEMANTIC DRIFT --------------------------
423
+ # ============================================================
424
+
425
+ def compute_semantic_drift(bert_scores: Dict[int, float]):
426
+ """
427
+ Drift = drop from best alignment
428
+ """
429
+ max_score = max(bert_scores.values())
430
+ drift = {t: max_score - s for t, s in bert_scores.items()}
431
+ return drift
432
+
433
+
434
+ # ============================================================
435
+ # ------------------ ATTENTION STABILITY ---------------------
436
+ # ============================================================
437
+
438
+ def compute_attention_stability(attention_maps: Dict[int, np.ndarray]):
439
+ """
440
+ Measures if tokens attend consistently across steps.
441
+ """
442
+ steps = sorted(attention_maps.keys(), reverse=True)
443
+
444
+ stability_scores = []
445
+
446
+ for i in range(len(steps) - 1):
447
+ A = attention_maps[steps[i]]
448
+ B = attention_maps[steps[i+1]]
449
+
450
+ diff = np.abs(A - B).mean()
451
+ stability_scores.append(diff)
452
+
453
+ return np.mean(stability_scores)
454
+
455
+
456
+ # ============================================================
457
+ # ------------------ TF-IDF vs STABILITY ---------------------
458
+ # ============================================================
459
+
460
+ def compute_tfidf_attention_correlation(
461
+ src_texts: List[str],
462
+ attention_maps_list: List[Dict[int, np.ndarray]]
463
+ ):
464
+ """
465
+ Correlate TF-IDF importance with attention stability
466
+ """
467
+
468
+ vectorizer = TfidfVectorizer()
469
+ tfidf = vectorizer.fit_transform(src_texts).toarray()
470
+
471
+ word_importance = tfidf.mean(axis=0)
472
+
473
+ stability = []
474
+ for attn_maps in attention_maps_list:
475
+ stability.append(compute_attention_stability(attn_maps))
476
+
477
+ corr = np.corrcoef(word_importance[:len(stability)], stability)[0, 1]
478
+ return corr
479
+
480
+
481
+ # ============================================================
482
+ # ------------------ HEATMAP VISUALIZATION -------------------
483
+ # ============================================================
484
+
485
+ def plot_attention_heatmap(attn: np.ndarray, title="Attention"):
486
+ """
487
+ Plot cross-attention heatmap
488
+ attn: [tgt_len, src_len]
489
+ """
490
+ plt.figure(figsize=(6,5))
491
+ plt.imshow(attn, aspect='auto', cmap='viridis')
492
+ plt.colorbar()
493
+ plt.title(title)
494
+ plt.xlabel("Source tokens")
495
+ plt.ylabel("Target tokens")
496
+ plt.show()
497
+
498
+
499
+ def visualize_trajectory(attention_maps: Dict[int, np.ndarray]):
500
+ """
501
+ Show attention evolution over time
502
+ """
503
+ steps = sorted(attention_maps.keys(), reverse=True)
504
+
505
+ for t in steps[:5]: # show 5 steps
506
+ plot_attention_heatmap(attention_maps[t], title=f"Step t={t}")
507
+
508
+
509
+ # ============================================================
510
+ # ------------------ LOCKED vs FLEXIBLE ----------------------
511
+ # ============================================================
512
+
513
+ def analyze_token_behavior(attention_maps: Dict[int, np.ndarray]):
514
+ """
515
+ Detect whether tokens are locked or flexible
516
+ """
517
+ steps = sorted(attention_maps.keys(), reverse=True)
518
+
519
+ first = attention_maps[steps[0]]
520
+ last = attention_maps[steps[-1]]
521
+
522
+ diff = np.abs(first - last).mean(axis=1)
523
+
524
+ locked = np.where(diff < 0.05)[0]
525
+ flexible = np.where(diff >= 0.05)[0]
526
+
527
+ return {
528
+ "locked_tokens": locked.tolist(),
529
+ "flexible_tokens": flexible.tolist()
530
+ }
531
+
532
+
533
+ # ============================================================
534
+ # ------------------ MASTER FUNCTION -------------------------
535
+ # ============================================================
536
+
537
+ def run_task2_analysis(
538
+ model,
539
+ src_tensor,
540
+ src_text,
541
+ tgt_tokenizer
542
+ ):
543
+ result = capture_alignment_trajectory(
544
+ model, src_tensor, src_text, tgt_tokenizer
545
+ )
546
+
547
+ drift = compute_semantic_drift(result["bert_scores"])
548
+ stability = compute_attention_stability(result["attention"])
549
+ behavior = analyze_token_behavior(result["attention"])
550
+
551
+ print("\nBERTScore trajectory:")
552
+ print(result["bert_scores"])
553
+
554
+ print("\nSemantic drift:")
555
+ print(drift)
556
+
557
+ print(f"\nAttention stability: {stability:.4f}")
558
+
559
+ print("\nToken behavior:")
560
+ print(behavior)
561
+
562
+ visualize_trajectory(result["attention"])
563
+
564
+ return {
565
+ "trajectory": result,
566
+ "drift": drift,
567
+ "stability": stability,
568
+ "behavior": behavior
569
+ }
analysis/step_ablation.py ADDED
@@ -0,0 +1,582 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # analysis/step_ablation.py
3
+ # ==========================
4
+ # Task 4: Semantic Robustness — Ablation of Diffusion Steps vs Meaning Preservation
5
+ #
6
+ # Two-phase workflow (retraining IS required for different T values):
7
+ #
8
+ # PHASE 1 — Generate configs + train (run once per T value):
9
+ # python analysis/step_ablation.py --phase generate_configs
10
+ # # Creates configs: ablation_configs/T4.py, T8.py, T16.py, T32.py, T64.py
11
+ # # Then train each: MODEL_TYPE=d3pm_cross_attention python train.py (for each config)
12
+ #
13
+ # PHASE 2 — Analyze trained models (no retraining needed):
14
+ # python analysis/step_ablation.py --phase analyze
15
+ # # Loads each trained model, generates 200 paraphrases, computes CER
16
+ # # Produces 3D plot: X=steps, Y=generation_speed, Z=CER
17
+ #
18
+ # Why retraining is needed:
19
+ # A model trained with T=128 learns to denoise from x_t~Uniform[0,128].
20
+ # Running it with T=4 means the model only sees t∈{0,1,2,3} — which it
21
+ # was never trained on at those scales. Outputs are meaningless.
22
+ # You must train a separate model for each T value.
23
+ #
24
+ # Also implements adversarial robustness test (no retraining):
25
+ # Takes your existing T=128 model and tests whether corrupted IAST
26
+ # inputs (typos, character swaps) cause proportional output degradation.
27
+ # """
28
+ #
29
+ # import torch
30
+ # import torch.nn.functional as F
31
+ # import numpy as np
32
+ # import os
33
+ # import sys
34
+ # import time
35
+ # import json
36
+ # import copy
37
+ # from typing import List, Dict, Optional
38
+ #
39
+ # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
40
+ #
41
+ #
42
+ # # ── Phase 1: Config generation ────────────────────────────────────────
43
+ #
44
+ # T_VALUES = [4, 8, 16, 32, 64]
45
+ #
46
+ # def generate_ablation_configs(base_config_path: str = "config.py",
47
+ # output_dir: str = "ablation_configs"):
48
+ # """
49
+ # Generate one config file per T value.
50
+ # Each config is a copy of the base config with diffusion_steps changed.
51
+ #
52
+ # After running this, train each model:
53
+ # for T in 4 8 16 32 64; do
54
+ # cp ablation_configs/config_T${T}.py config.py
55
+ # python train.py
56
+ # mv results7/d3pm_cross_attention_neg_False \
57
+ # ablation_results/T${T}
58
+ # done
59
+ # """
60
+ # os.makedirs(output_dir, exist_ok=True)
61
+ #
62
+ # # Read base config
63
+ # with open(base_config_path, "r") as f:
64
+ # base_src = f.read()
65
+ #
66
+ # for T in T_VALUES:
67
+ # # Replace diffusion_steps and num_steps
68
+ # cfg_src = base_src
69
+ # cfg_src = cfg_src.replace(
70
+ # '"diffusion_steps": 128',
71
+ # f'"diffusion_steps": {T}'
72
+ # )
73
+ # cfg_src = cfg_src.replace(
74
+ # "'diffusion_steps': 128",
75
+ # f"'diffusion_steps': {T}"
76
+ # )
77
+ # cfg_src = cfg_src.replace(
78
+ # '"num_steps": 128',
79
+ # f'"num_steps": {T}'
80
+ # )
81
+ # cfg_src = cfg_src.replace(
82
+ # "'num_steps': 128",
83
+ # f"'num_steps': {T}"
84
+ # )
85
+ # out_path = os.path.join(output_dir, f"config_T{T}.py")
86
+ # with open(out_path, "w") as f:
87
+ # f.write(f"# Ablation config: T={T} diffusion steps\n")
88
+ # f.write(cfg_src)
89
+ # print(f" Wrote: {out_path}")
90
+ #
91
+ # # Write a shell script to train all
92
+ # shell_script = os.path.join(output_dir, "train_all.sh")
93
+ # with open(shell_script, "w") as f:
94
+ # f.write("#!/bin/bash\n")
95
+ # f.write("# Run this script to train all ablation models\n\n")
96
+ # for T in T_VALUES:
97
+ # f.write(f"echo '=== Training T={T} ==='\n")
98
+ # f.write(f"cp {output_dir}/config_T{T}.py config.py\n")
99
+ # f.write(f"python train.py\n")
100
+ # f.write(f"mkdir -p ablation_results/T{T}\n")
101
+ # f.write(f"cp -r results7/d3pm_cross_attention_neg_False/best_model.pt "
102
+ # f"ablation_results/T{T}/best_model.pt\n")
103
+ # f.write(f"cp -r results7/d3pm_cross_attention_neg_False/train.log "
104
+ # f"ablation_results/T{T}/train.log\n\n")
105
+ # os.chmod(shell_script, 0o755)
106
+ # print(f"\nTraining script: {shell_script}")
107
+ # print(f"Run: bash {shell_script}")
108
+ #
109
+ #
110
+ # # ── Phase 2: Analysis (after models are trained) ──────────────────────
111
+ #
112
+ # def compute_cer(pred: str, ref: str) -> float:
113
+ # if not ref:
114
+ # return 1.0
115
+ #
116
+ # def edit_distance(s1, s2):
117
+ # m, n = len(s1), len(s2)
118
+ # dp = list(range(n + 1))
119
+ # for i in range(1, m + 1):
120
+ # prev, dp[0] = dp[0], i
121
+ # for j in range(1, n + 1):
122
+ # temp = dp[j]
123
+ # dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
124
+ # prev = temp
125
+ # return dp[n]
126
+ #
127
+ # return edit_distance(pred, ref) / max(len(ref), 1)
128
+ #
129
+ #
130
+ # def evaluate_model(
131
+ # model,
132
+ # src_list: List[torch.Tensor],
133
+ # ref_list: List[str],
134
+ # tgt_tokenizer,
135
+ # n_samples: int = 200,
136
+ # temperature: float = 0.8,
137
+ # top_k: int = 40,
138
+ # ) -> Dict:
139
+ # """
140
+ # Generate n_samples outputs and compute CER + generation speed.
141
+ #
142
+ # Returns dict with:
143
+ # mean_cer : average CER over samples
144
+ # generation_s : total wall-clock seconds for all generations
145
+ # speed_per_sample: seconds per sample
146
+ # cer_list : per-sample CER values
147
+ # """
148
+ # device = next(model.parameters()).device
149
+ # n = min(n_samples, len(src_list))
150
+ # cer_list = []
151
+ #
152
+ # start = time.perf_counter()
153
+ # for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
154
+ # if src.dim() == 1:
155
+ # src = src.unsqueeze(0)
156
+ #
157
+ # with torch.no_grad():
158
+ # if hasattr(model.model, 'generate_cached'):
159
+ # out = model.model.generate_cached(
160
+ # src.to(device), temperature=temperature, top_k=top_k
161
+ # )
162
+ # else:
163
+ # out = model.generate(
164
+ # src.to(device), temperature=temperature, top_k=top_k
165
+ # )
166
+ #
167
+ # ids = [x for x in out[0].tolist() if x > 4]
168
+ # pred = tgt_tokenizer.decode(ids).strip()
169
+ # cer = compute_cer(pred, ref)
170
+ # cer_list.append(cer)
171
+ #
172
+ # elapsed = time.perf_counter() - start
173
+ #
174
+ # return {
175
+ # "mean_cer": float(np.mean(cer_list)),
176
+ # "std_cer": float(np.std(cer_list)),
177
+ # "generation_s": elapsed,
178
+ # "speed_per_sample": elapsed / max(n, 1),
179
+ # "cer_list": cer_list,
180
+ # "n_samples": n,
181
+ # }
182
+ #
183
+ #
184
+ # def run_ablation_analysis(
185
+ # ablation_dir: str = "ablation_results",
186
+ # base_cfg: dict = None,
187
+ # src_list: List[torch.Tensor] = None,
188
+ # ref_list: List[str] = None,
189
+ # tgt_tokenizer = None,
190
+ # device: torch.device = None,
191
+ # output_dir: str = "analysis/outputs",
192
+ # ) -> Dict:
193
+ # """
194
+ # Load each trained model and evaluate.
195
+ # Produces results dict and 3D plot.
196
+ #
197
+ # Expects ablation_results/T{N}/best_model.pt for each T in T_VALUES.
198
+ # """
199
+ # from inference import load_model
200
+ #
201
+ # results = {}
202
+ # for T in T_VALUES:
203
+ # ckpt = os.path.join(ablation_dir, f"T{T}", "best_model.pt")
204
+ # if not os.path.exists(ckpt):
205
+ # print(f" SKIP T={T}: no checkpoint at {ckpt}")
206
+ # continue
207
+ #
208
+ # print(f"\nEvaluating T={T}...")
209
+ # cfg_T = copy.deepcopy(base_cfg)
210
+ # cfg_T['model']['diffusion_steps'] = T
211
+ # cfg_T['inference']['num_steps'] = T
212
+ #
213
+ # model, cfg_T = load_model(ckpt, cfg_T, device)
214
+ # model.eval()
215
+ #
216
+ # metrics = evaluate_model(
217
+ # model, src_list, ref_list, tgt_tokenizer, n_samples=200
218
+ # )
219
+ # results[T] = metrics
220
+ # print(f" T={T} CER={metrics['mean_cer']:.4f} "
221
+ # f"speed={metrics['speed_per_sample']:.3f}s/sample")
222
+ #
223
+ # del model
224
+ #
225
+ # # Save results
226
+ # os.makedirs(output_dir, exist_ok=True)
227
+ # results_path = os.path.join(output_dir, "ablation_results.json")
228
+ # with open(results_path, "w") as f:
229
+ # json.dump({str(k): {kk: vv for kk, vv in v.items() if kk != 'cer_list'}
230
+ # for k, v in results.items()}, f, indent=2)
231
+ # print(f"\nResults saved: {results_path}")
232
+ #
233
+ # return results
234
+ #
235
+ #
236
+ # def plot_ablation_3d(
237
+ # results: Dict,
238
+ # save_path: Optional[str] = None,
239
+ # ):
240
+ # """
241
+ # 3D plot: X=diffusion_steps, Y=generation_speed(s/sample), Z=CER.
242
+ # Also produces a 2D summary plot.
243
+ # """
244
+ # try:
245
+ # import matplotlib.pyplot as plt
246
+ # from mpl_toolkits.mplot3d import Axes3D
247
+ # except ImportError:
248
+ # print("pip install matplotlib.")
249
+ # return
250
+ #
251
+ # T_list = sorted(results.keys())
252
+ # cers = [results[T]["mean_cer"] for T in T_list]
253
+ # speeds = [results[T]["speed_per_sample"] for T in T_list]
254
+ #
255
+ # # ── 3D plot ───────────────────────────────────────────────────────
256
+ # fig = plt.figure(figsize=(14, 5))
257
+ #
258
+ # ax3d = fig.add_subplot(121, projection='3d')
259
+ # ax3d.scatter(T_list, speeds, cers, c=cers, cmap='RdYlGn_r', s=80)
260
+ # for T, s, c in zip(T_list, speeds, cers):
261
+ # ax3d.text(T, s, c, f"T={T}", fontsize=8)
262
+ # ax3d.set_xlabel("Diffusion steps T", fontsize=9)
263
+ # ax3d.set_ylabel("Speed (s/sample)", fontsize=9)
264
+ # ax3d.set_zlabel("CER (↓ better)", fontsize=9)
265
+ # ax3d.set_title("T vs speed vs CER", fontsize=10)
266
+ #
267
+ # # ── 2D CER vs T (find the knee) ──────────────────────────────────
268
+ # ax2d = fig.add_subplot(122)
269
+ # ax2d.plot(T_list, cers, 'o-', linewidth=1.8, color='coral', markersize=7)
270
+ # for T, c in zip(T_list, cers):
271
+ # ax2d.annotate(f"{c:.3f}", (T, c), textcoords="offset points",
272
+ # xytext=(0, 8), fontsize=8, ha='center')
273
+ #
274
+ # # Find knee: largest CER drop per unit T (elbow method)
275
+ # if len(T_list) >= 3:
276
+ # drops = [cers[i] - cers[i+1] for i in range(len(cers)-1)]
277
+ # knee_i = int(np.argmax(drops))
278
+ # knee_T = T_list[knee_i + 1]
279
+ # ax2d.axvline(knee_T, color='steelblue', linestyle='--', linewidth=1.2,
280
+ # label=f"Knee at T={knee_T}")
281
+ # ax2d.legend(fontsize=9)
282
+ #
283
+ # ax2d.set_xlabel("Diffusion steps T", fontsize=10)
284
+ # ax2d.set_ylabel("CER (lower = better)", fontsize=10)
285
+ # ax2d.set_title("CER vs diffusion steps", fontsize=10)
286
+ # ax2d.set_ylim(0, max(cers) * 1.1)
287
+ #
288
+ # plt.tight_layout()
289
+ # if save_path:
290
+ # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
291
+ # plt.savefig(save_path, dpi=150, bbox_inches='tight')
292
+ # print(f"Saved: {save_path}")
293
+ # else:
294
+ # plt.show()
295
+ # plt.close()
296
+ #
297
+ #
298
+ # # ── Adversarial robustness test (no retraining needed) ───────────────
299
+ #
300
+ # def corrupt_iast(text: str, corruption_rate: float = 0.05) -> str:
301
+ # """
302
+ # Introduce random corruption into IAST text:
303
+ # - Character swap (adjacent chars swapped)
304
+ # - Character deletion
305
+ # - Random character insertion
306
+ #
307
+ # Models rate as 5% to 20% corruption to test robustness.
308
+ # """
309
+ # import random
310
+ # chars = list(text)
311
+ # n_corrupt = max(1, int(len(chars) * corruption_rate))
312
+ #
313
+ # for _ in range(n_corrupt):
314
+ # op = random.choice(['swap', 'delete', 'insert'])
315
+ # pos = random.randint(0, len(chars) - 1)
316
+ #
317
+ # if op == 'swap' and pos < len(chars) - 1:
318
+ # chars[pos], chars[pos+1] = chars[pos+1], chars[pos]
319
+ # elif op == 'delete' and len(chars) > 1:
320
+ # chars.pop(pos)
321
+ # elif op == 'insert':
322
+ # chars.insert(pos, random.choice('abcdeimnostu'))
323
+ #
324
+ # return "".join(chars)
325
+ #
326
+ #
327
+ # @torch.no_grad()
328
+ # def run_adversarial_test(
329
+ # model,
330
+ # src_tokenizer,
331
+ # tgt_tokenizer,
332
+ # test_inputs: List[str],
333
+ # test_refs: List[str],
334
+ # corruption_rates: List[float] = [0.0, 0.05, 0.10, 0.15, 0.20],
335
+ # device: torch.device = None,
336
+ # output_dir: str = "analysis/outputs",
337
+ # ) -> Dict:
338
+ # """
339
+ # Test if CER degrades proportionally with IAST corruption.
340
+ # Uses existing trained model — no retraining.
341
+ # """
342
+ # device = device or next(model.parameters()).device
343
+ # results = {}
344
+ #
345
+ # print("\nAdversarial robustness test...")
346
+ # for rate in corruption_rates:
347
+ # cer_list = []
348
+ # for text, ref in zip(test_inputs, test_refs):
349
+ # corrupted = corrupt_iast(text, rate)
350
+ # ids = src_tokenizer.encode(corrupted)
351
+ # src = torch.tensor([ids], dtype=torch.long, device=device)
352
+ #
353
+ # if hasattr(model.model, 'generate_cached'):
354
+ # out = model.model.generate_cached(src)
355
+ # else:
356
+ # out = model.generate(src)
357
+ #
358
+ # pred_ids = [x for x in out[0].tolist() if x > 4]
359
+ # pred = tgt_tokenizer.decode(pred_ids).strip()
360
+ # cer_list.append(compute_cer(pred, ref))
361
+ #
362
+ # mean_cer = float(np.mean(cer_list))
363
+ # results[rate] = mean_cer
364
+ # print(f" corruption={rate*100:.0f}% → CER={mean_cer:.4f}")
365
+ #
366
+ # # Save + plot
367
+ # os.makedirs(output_dir, exist_ok=True)
368
+ # try:
369
+ # import matplotlib.pyplot as plt
370
+ # fig, ax = plt.subplots(figsize=(8, 4))
371
+ # rates = [r * 100 for r in corruption_rates]
372
+ # cers = [results[r] for r in corruption_rates]
373
+ # ax.plot(rates, cers, 'o-', linewidth=1.8, color='steelblue', markersize=7)
374
+ # ax.set_xlabel("IAST corruption rate (%)", fontsize=11)
375
+ # ax.set_ylabel("CER", fontsize=11)
376
+ # ax.set_title("Model robustness to IAST input corruption", fontsize=11)
377
+ # ax.set_ylim(0, max(cers) * 1.2)
378
+ # plt.tight_layout()
379
+ # plt.savefig(os.path.join(output_dir, "adversarial_robustness.png"),
380
+ # dpi=150, bbox_inches='tight')
381
+ # plt.close()
382
+ # print(f" Saved: {output_dir}/adversarial_robustness.png")
383
+ # except ImportError:
384
+ # pass
385
+ #
386
+ # with open(os.path.join(output_dir, "adversarial_results.json"), "w") as f:
387
+ # json.dump({str(k): v for k, v in results.items()}, f, indent=2)
388
+ #
389
+ # return results
390
+ """
391
+ analysis/task4_pipeline.py
392
+ ================================
393
+ Correct Task 4 Pipeline:
394
+
395
+ PHASE 1 → Evaluate all models
396
+ PHASE 2 → Analyze + detect optimal T
397
+
398
+ NO early decision making.
399
+ """
400
+
401
+ import torch
402
+ import numpy as np
403
+ import time
404
+ import os
405
+ import json
406
+ from typing import Dict, List
407
+
408
+
409
+ # ───────────────────────��─────────────────────
410
+ # Load Metrics
411
+ # ─────────────────────────────────────────────
412
+
413
+ def load_metrics():
414
+ from bert_score import score as bert_score
415
+ from sentence_transformers import SentenceTransformer, util
416
+ from nltk.translate.bleu_score import sentence_bleu
417
+
418
+ st_model = SentenceTransformer('all-MiniLM-L6-v2')
419
+ return bert_score, st_model, util, sentence_bleu
420
+
421
+
422
+ # ─────────────────────────────────────────────
423
+ # PHASE 1 — Evaluate ALL models
424
+ # ─────────────────────────────────────────────
425
+
426
+ def evaluate_all_models(models: Dict[int, object],
427
+ src_list,
428
+ ref_list,
429
+ tgt_tokenizer,
430
+ n_samples=200):
431
+
432
+ bert_score_fn, st_model, util, bleu_fn = load_metrics()
433
+
434
+ results = {}
435
+
436
+ print("\n=== PHASE 1: Evaluating ALL models ===")
437
+
438
+ for T, model in sorted(models.items()):
439
+ print(f"\nEvaluating T={T}...")
440
+
441
+ device = next(model.parameters()).device
442
+ preds, refs = [], []
443
+
444
+ start = time.perf_counter()
445
+
446
+ for src, ref in zip(src_list[:n_samples], ref_list[:n_samples]):
447
+ if src.dim() == 1:
448
+ src = src.unsqueeze(0)
449
+
450
+ with torch.no_grad():
451
+ out = model.model.generate_cached(src.to(device))
452
+
453
+ ids = [x for x in out[0].tolist() if x > 4]
454
+ pred = tgt_tokenizer.decode(ids).strip()
455
+
456
+ preds.append(pred)
457
+ refs.append(ref)
458
+
459
+ elapsed = time.perf_counter() - start
460
+
461
+ # BERTScore
462
+ P, R, F1 = bert_score_fn(preds, refs, lang="hi", verbose=False)
463
+ bert_f1 = float(F1.mean())
464
+
465
+ # Sentence similarity
466
+ emb_p = st_model.encode(preds, convert_to_tensor=True)
467
+ emb_r = st_model.encode(refs, convert_to_tensor=True)
468
+ sim = util.cos_sim(emb_p, emb_r).diagonal().mean().item()
469
+
470
+ # BLEU
471
+ bleu_scores = [
472
+ bleu_fn([r.split()], p.split())
473
+ for p, r in zip(preds, refs)
474
+ ]
475
+
476
+ results[T] = {
477
+ "bertscore_f1": bert_f1,
478
+ "semantic_sim": sim,
479
+ "bleu": float(np.mean(bleu_scores)),
480
+ "speed_per_sample": elapsed / n_samples
481
+ }
482
+
483
+ print(f" BERTScore: {bert_f1:.4f}")
484
+ print(f" Sim: {sim:.4f}")
485
+ print(f" BLEU: {results[T]['bleu']:.4f}")
486
+ print(f" Speed: {results[T]['speed_per_sample']:.4f}s")
487
+
488
+ # Save raw results
489
+ os.makedirs("analysis/outputs", exist_ok=True)
490
+ with open("analysis/outputs/task4_raw_results.json", "w") as f:
491
+ json.dump(results, f, indent=2)
492
+
493
+ return results
494
+
495
+
496
+ # ─────────────────────────────────────────────
497
+ # PHASE 2 — Analyze results (Knee Detection)
498
+ # ─────────────────────────────────────────────
499
+
500
+ def analyze_results(results: Dict):
501
+ print("\n=== PHASE 2: Analysis ===")
502
+
503
+ T_list = sorted(results.keys())
504
+ scores = [results[T]["bertscore_f1"] for T in T_list]
505
+
506
+ gains = [scores[i+1] - scores[i] for i in range(len(scores)-1)]
507
+
508
+ print("\nMarginal Gains:")
509
+ for i, g in enumerate(gains):
510
+ print(f" T{T_list[i]} → T{T_list[i+1]}: +{g:.4f}")
511
+
512
+ # Knee detection
513
+ threshold = 0.02
514
+ knee_T = T_list[-1]
515
+
516
+ for i, g in enumerate(gains):
517
+ if g < threshold:
518
+ knee_T = T_list[i+1]
519
+ break
520
+
521
+ print(f"\n✅ Optimal T (knee detected): {knee_T}")
522
+
523
+ return knee_T, gains
524
+
525
+
526
+ # ─────────────────────────────────────────────
527
+ # 3D Plot (BERTScore)
528
+ # ─────────────────────────────────────────────
529
+
530
+ def plot_3d(results):
531
+ import matplotlib.pyplot as plt
532
+ from mpl_toolkits.mplot3d import Axes3D
533
+
534
+ T_list = sorted(results.keys())
535
+
536
+ X = T_list
537
+ Y = [results[T]["speed_per_sample"] for T in T_list]
538
+ Z = [results[T]["bertscore_f1"] for T in T_list]
539
+
540
+ fig = plt.figure(figsize=(10, 6))
541
+ ax = fig.add_subplot(111, projection='3d')
542
+
543
+ ax.scatter(X, Y, Z)
544
+
545
+ for x, y, z in zip(X, Y, Z):
546
+ ax.text(x, y, z, f"T={x}", fontsize=8)
547
+
548
+ ax.set_xlabel("Diffusion Steps")
549
+ ax.set_ylabel("Speed")
550
+ ax.set_zlabel("BERTScore")
551
+
552
+ plt.title("3D Tradeoff: Steps vs Speed vs Quality")
553
+
554
+ os.makedirs("analysis/outputs", exist_ok=True)
555
+ plt.savefig("analysis/outputs/task4_3d.png")
556
+ plt.close()
557
+
558
+ print("Saved 3D plot")
559
+
560
+
561
+ # ────────────���────────────────────────────────
562
+ # FINAL RUNNER
563
+ # ─────────────────────────────────────────────
564
+
565
+ def run_task4(models, src_list, ref_list, tgt_tokenizer):
566
+
567
+ # Phase 1: Evaluate all
568
+ results = evaluate_all_models(
569
+ models, src_list, ref_list, tgt_tokenizer
570
+ )
571
+
572
+ # Phase 2: Analyze
573
+ knee_T, gains = analyze_results(results)
574
+
575
+ # Plot
576
+ plot_3d(results)
577
+
578
+ # Save report
579
+ with open("analysis/outputs/task4_report.txt", "w") as f:
580
+ f.write(f"Optimal diffusion steps = {knee_T}\n")
581
+
582
+ return knee_T
app.py CHANGED
@@ -1,235 +1,547 @@
1
- """
2
- Hugging Face Space app for Sanskrit D3PM project.
3
-
4
- Deploy on Spaces with:
5
- app_file = app_hf_space.py
6
-
7
- Optional environment variables:
8
- HF_CHECKPOINT_REPO : model repo id (e.g. "username/sanskrit-d3pm")
9
- HF_CHECKPOINT_FILE : checkpoint path in repo (default: "best_model.pt")
10
- HF_CHECKPOINT_LABEL : UI label for remote checkpoint
11
- """
12
-
13
- from __future__ import annotations
14
-
15
  import copy
 
16
  import os
17
- from typing import Dict, Tuple
 
 
18
 
19
  import gradio as gr
20
  import torch
 
21
 
22
  from config import CONFIG
23
  from inference import _build_tokenizers, _resolve_device, load_model, run_inference
24
 
25
 
26
- def _clean_output(text: str, max_repeat: int = 2) -> str:
27
- text = " ".join(text.split())
28
- if not text:
29
- return text
30
- toks = text.split()
31
- out = []
32
- prev = None
33
- run = 0
34
- for t in toks:
35
- if t == prev:
36
- run += 1
37
- else:
38
- prev = t
39
- run = 1
40
- if run <= max_repeat:
41
- out.append(t)
42
- s = " ".join(out)
43
- s = s.replace(" ।", "।").replace(" ॥", "॥")
44
- return " ".join(s.split())
45
 
46
 
47
- def _discover_local_checkpoints() -> Dict[str, str]:
48
- found = {}
49
  for root in ("ablation_results", "results7", "results"):
50
  if not os.path.isdir(root):
51
  continue
52
- for exp in sorted(os.listdir(root)):
53
- ckpt = os.path.join(root, exp, "best_model.pt")
54
- if os.path.exists(ckpt):
55
- found[f"{exp} [{root}]"] = ckpt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  return found
57
 
58
 
59
- def _discover_remote_checkpoint() -> Dict[str, str]:
60
- repo = os.getenv("HF_CHECKPOINT_REPO", "").strip()
61
- if not repo:
62
- return {}
63
-
64
- filename = os.getenv("HF_CHECKPOINT_FILE", "best_model.pt").strip()
65
- label = os.getenv("HF_CHECKPOINT_LABEL", f"remote:{repo}")
66
 
67
- try:
68
- from huggingface_hub import hf_hub_download
69
 
70
- ckpt_path = hf_hub_download(repo_id=repo, filename=filename)
71
- return {label: ckpt_path}
72
- except Exception as e:
73
- print(f"[WARN] remote checkpoint download failed: {e}")
74
- return {}
 
 
 
75
 
76
 
77
- def _infer_model_type(path: str) -> str:
78
- p = path.lower()
79
- if "d3pm_encoder_decoder" in p:
 
 
 
80
  return "d3pm_encoder_decoder"
81
- if "baseline_cross_attention" in p:
82
  return "baseline_cross_attention"
83
- if "baseline_encoder_decoder" in p:
84
  return "baseline_encoder_decoder"
85
- return "d3pm_cross_attention"
86
 
87
 
88
- def _infer_neg(path: str) -> bool:
89
- p = path.lower()
90
- if "_neg_true" in p:
 
91
  return True
92
- if "_neg_false" in p:
93
  return False
94
  return CONFIG["data"]["include_negative_examples"]
95
 
96
 
97
- class RuntimeStore:
98
- def __init__(self):
99
- self.loaded: Dict[str, Dict] = {}
100
-
101
- def get(self, ckpt_label: str, ckpt_path: str) -> Dict:
102
- if ckpt_label in self.loaded:
103
- return self.loaded[ckpt_label]
104
-
105
- cfg = copy.deepcopy(CONFIG)
106
- cfg["model_type"] = _infer_model_type(ckpt_path)
107
- cfg["data"]["include_negative_examples"] = _infer_neg(ckpt_path)
108
- device = _resolve_device(cfg)
109
-
110
- model, cfg = load_model(ckpt_path, cfg, device)
111
- src_tok, tgt_tok = _build_tokenizers(cfg)
112
-
113
- bundle = {
114
- "label": ckpt_label,
115
- "path": ckpt_path,
116
- "cfg": cfg,
117
- "device": str(device),
118
- "model": model,
119
- "src_tok": src_tok,
120
- "tgt_tok": tgt_tok,
121
- }
122
- self.loaded[ckpt_label] = bundle
123
- return bundle
124
-
125
-
126
- RUNTIME = RuntimeStore()
127
- CHECKPOINTS = {}
128
- CHECKPOINTS.update(_discover_local_checkpoints())
129
- CHECKPOINTS.update(_discover_remote_checkpoint())
130
-
131
- if not CHECKPOINTS:
132
- CHECKPOINTS = {"No checkpoint found": ""}
133
-
134
-
135
- def load_checkpoint_ui(label: str) -> Tuple[Dict, str]:
136
- if label not in CHECKPOINTS or not CHECKPOINTS[label]:
137
- raise gr.Error("No valid checkpoint found. Upload/provide best_model.pt first.")
138
- bundle = RUNTIME.get(label, CHECKPOINTS[label])
139
- info = (
140
- f"Loaded `{label}`\n"
141
- f"- path: `{bundle['path']}`\n"
142
- f"- model_type: `{bundle['cfg']['model_type']}`\n"
143
- f"- device: `{bundle['device']}`\n"
144
- f"- max_seq_len: `{bundle['cfg']['model']['max_seq_len']}`"
145
- )
146
- return bundle, info
147
-
148
-
149
- def generate_ui(
150
- bundle: Dict,
151
- text: str,
152
- temperature: float,
153
- top_k: int,
154
- repetition_penalty: float,
155
- diversity_penalty: float,
156
- num_steps: int,
157
- clean_output: bool,
158
- ) -> str:
159
- if not bundle:
160
- raise gr.Error("Load a checkpoint first.")
161
- if not text.strip():
162
- raise gr.Error("Enter input text.")
163
-
164
- cfg = copy.deepcopy(bundle["cfg"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  cfg["inference"]["temperature"] = float(temperature)
166
  cfg["inference"]["top_k"] = int(top_k)
167
  cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
168
  cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
169
  cfg["inference"]["num_steps"] = int(num_steps)
170
 
171
- src_tok = bundle["src_tok"]
172
- tgt_tok = bundle["tgt_tok"]
173
- device = torch.device(bundle["device"])
174
- ids = torch.tensor([src_tok.encode(text.strip())], dtype=torch.long, device=device)
175
-
176
- out = run_inference(bundle["model"], ids, cfg)
177
- token_ids = [x for x in out[0].tolist() if x > 4]
178
- pred = tgt_tok.decode(token_ids).strip()
179
- if clean_output:
180
- pred = _clean_output(pred)
181
- return pred if pred else "(empty output)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
 
184
- with gr.Blocks(title="Sanskrit D3PM Space") as demo:
185
  model_state = gr.State(None)
 
186
  gr.Markdown(
187
  """
188
- ## Sanskrit D3PM Paraphrase (IAST → Devanagari)
189
- Load a trained checkpoint and generate output from Roman/IAST Sanskrit input.
 
 
 
 
190
  """
191
  )
192
 
193
- checkpoint = gr.Dropdown(
194
- choices=list(CHECKPOINTS.keys()),
195
- value=list(CHECKPOINTS.keys())[0],
196
- label="Checkpoint",
197
- )
198
- load_btn = gr.Button("Load Model", variant="primary")
199
- load_info = gr.Markdown("Select a checkpoint and click **Load Model**.")
200
-
201
- text_in = gr.Textbox(label="Input (Roman / IAST)", lines=3, value="dharmo rakṣati rakṣitaḥ")
202
- text_out = gr.Textbox(label="Output (Devanagari)", lines=6)
203
-
204
  with gr.Row():
205
- temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
206
- top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
207
- repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
208
- diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
209
- num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
210
- clean_output = gr.Checkbox(value=True, label="Clean Output")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
- generate_btn = gr.Button("Generate", variant="primary")
 
 
 
 
213
 
214
- load_btn.click(load_checkpoint_ui, inputs=[checkpoint], outputs=[model_state, load_info])
215
  generate_btn.click(
216
- generate_ui,
217
  inputs=[
218
- model_state, text_in, temperature, top_k, repetition_penalty,
219
- diversity_penalty, num_steps, clean_output
 
 
 
 
 
 
220
  ],
221
- outputs=[text_out],
222
  )
223
- text_in.submit(
224
- generate_ui,
225
  inputs=[
226
- model_state, text_in, temperature, top_k, repetition_penalty,
227
- diversity_penalty, num_steps, clean_output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
  ],
229
- outputs=[text_out],
230
  )
231
 
232
 
233
  if __name__ == "__main__":
234
- port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
235
  demo.launch(server_name="0.0.0.0", server_port=port, share=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import copy
2
+ import json
3
  import os
4
+ import subprocess
5
+ import sys
6
+ from datetime import datetime
7
 
8
  import gradio as gr
9
  import torch
10
+ from huggingface_hub import hf_hub_download, list_repo_files
11
 
12
  from config import CONFIG
13
  from inference import _build_tokenizers, _resolve_device, load_model, run_inference
14
 
15
 
16
+ RESULTS_DIR = "generated_results"
17
+ DEFAULT_ANALYSIS_OUT = "analysis/outputs"
18
+ os.makedirs(RESULTS_DIR, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
 
21
+ def discover_checkpoints():
22
+ found = []
23
  for root in ("ablation_results", "results7", "results"):
24
  if not os.path.isdir(root):
25
  continue
26
+ for entry in sorted(os.listdir(root)):
27
+ ckpt = os.path.join(root, entry, "best_model.pt")
28
+ if not os.path.exists(ckpt):
29
+ continue
30
+ found.append(
31
+ {
32
+ "label": f"{entry} [{root}]",
33
+ "path": ckpt,
34
+ "experiment": entry,
35
+ "root": root,
36
+ }
37
+ )
38
+ repo = os.getenv("HF_CHECKPOINT_REPO", "").strip()
39
+ if repo:
40
+ branch = os.getenv("HF_CHECKPOINT_REVISION", "main").strip() or "main"
41
+ try:
42
+ for fname in list_repo_files(repo_id=repo, repo_type="model", revision=branch):
43
+ if not fname.endswith("/best_model.pt") and fname != "best_model.pt":
44
+ continue
45
+ local_path = hf_hub_download(repo_id=repo, filename=fname, revision=branch, repo_type="model")
46
+ parent = os.path.basename(os.path.dirname(fname)) if "/" in fname else "remote"
47
+ root = os.path.dirname(fname).split("/")[0] if "/" in fname else "remote"
48
+ found.append(
49
+ {
50
+ "label": f"{parent} [hf:{repo}]",
51
+ "path": local_path,
52
+ "experiment": parent,
53
+ "root": root,
54
+ }
55
+ )
56
+ except Exception as e:
57
+ print(f"[WARN] Could not discover remote checkpoints from {repo}: {e}")
58
  return found
59
 
60
 
61
+ def checkpoint_map():
62
+ return {item["label"]: item for item in discover_checkpoints()}
 
 
 
 
 
63
 
 
 
64
 
65
+ def default_checkpoint_label():
66
+ cps = discover_checkpoints()
67
+ if not cps:
68
+ return None
69
+ for item in cps:
70
+ if item["path"].endswith("ablation_results/T4/best_model.pt"):
71
+ return item["label"]
72
+ return cps[0]["label"]
73
 
74
 
75
+ def infer_model_type(experiment_name: str, root: str = "") -> str:
76
+ if root == "ablation_results":
77
+ return "d3pm_cross_attention"
78
+ if experiment_name.startswith("d3pm_cross_attention"):
79
+ return "d3pm_cross_attention"
80
+ if experiment_name.startswith("d3pm_encoder_decoder"):
81
  return "d3pm_encoder_decoder"
82
+ if experiment_name.startswith("baseline_cross_attention"):
83
  return "baseline_cross_attention"
84
+ if experiment_name.startswith("baseline_encoder_decoder"):
85
  return "baseline_encoder_decoder"
86
+ return CONFIG["model_type"]
87
 
88
 
89
+ def infer_include_negative(experiment_name: str, root: str = "") -> bool:
90
+ if root == "ablation_results":
91
+ return False
92
+ if "_neg_True" in experiment_name:
93
  return True
94
+ if "_neg_False" in experiment_name:
95
  return False
96
  return CONFIG["data"]["include_negative_examples"]
97
 
98
 
99
+ def build_runtime_cfg(ckpt_path: str):
100
+ experiment = os.path.basename(os.path.dirname(ckpt_path)) or "remote"
101
+ root = os.path.basename(os.path.dirname(os.path.dirname(ckpt_path))) or "remote"
102
+ cfg = copy.deepcopy(CONFIG)
103
+ cfg["model_type"] = infer_model_type(experiment, root=root)
104
+ cfg["data"]["include_negative_examples"] = infer_include_negative(experiment, root=root)
105
+
106
+ if root == "ablation_results" and experiment.startswith("T") and experiment[1:].isdigit():
107
+ t_val = int(experiment[1:])
108
+ cfg["model"]["diffusion_steps"] = t_val
109
+ cfg["inference"]["num_steps"] = t_val
110
+
111
+ device = _resolve_device(cfg)
112
+ return cfg, device, experiment
113
+
114
+
115
+ def load_selected_model(checkpoint_label):
116
+ mapping = checkpoint_map()
117
+ if checkpoint_label not in mapping:
118
+ raise gr.Error("Selected checkpoint not found. Click refresh.")
119
+
120
+ ckpt_path = mapping[checkpoint_label]["path"]
121
+ cfg, device, experiment = build_runtime_cfg(ckpt_path)
122
+ model, cfg = load_model(ckpt_path, cfg, device)
123
+ src_tok, tgt_tok = _build_tokenizers(cfg)
124
+
125
+ bundle = {
126
+ "ckpt_path": ckpt_path,
127
+ "experiment": experiment,
128
+ "device": str(device),
129
+ "cfg": cfg,
130
+ "model": model,
131
+ "src_tok": src_tok,
132
+ "tgt_tok": tgt_tok,
133
+ }
134
+ model_info = {
135
+ "checkpoint": ckpt_path,
136
+ "experiment": experiment,
137
+ "model_type": cfg["model_type"],
138
+ "include_negatives": cfg["data"]["include_negative_examples"],
139
+ "device": str(device),
140
+ "max_seq_len": cfg["model"]["max_seq_len"],
141
+ "diffusion_steps": cfg["model"]["diffusion_steps"],
142
+ "inference_steps": cfg["inference"]["num_steps"],
143
+ "d_model": cfg["model"]["d_model"],
144
+ "n_layers": cfg["model"]["n_layers"],
145
+ "n_heads": cfg["model"]["n_heads"],
146
+ }
147
+ status = f"Loaded `{experiment}` on `{device}` (`{cfg['model_type']}`)"
148
+ suggested_out = os.path.join("analysis", "outputs_ui", experiment)
149
+ return bundle, status, model_info, cfg["inference"]["num_steps"], suggested_out
150
+
151
+
152
+ def apply_preset(preset_name):
153
+ presets = {
154
+ "Manual": (0.70, 40, 1.20, 0.0),
155
+ "Literal": (0.60, 20, 1.25, 0.0),
156
+ "Balanced": (0.70, 40, 1.20, 0.0),
157
+ "Creative": (0.90, 80, 1.05, 0.2),
158
+ }
159
+ return presets.get(preset_name, presets["Balanced"])
160
+
161
+
162
+ def clean_generated_text(text: str, max_consecutive: int = 2) -> str:
163
+ text = " ".join(text.split())
164
+ if not text:
165
+ return text
166
+ tokens = text.split()
167
+ cleaned = []
168
+ prev = None
169
+ run = 0
170
+ for tok in tokens:
171
+ if tok == prev:
172
+ run += 1
173
+ else:
174
+ prev = tok
175
+ run = 1
176
+ if run <= max_consecutive:
177
+ cleaned.append(tok)
178
+ out = " ".join(cleaned).replace(" ।", "।").replace(" ॥", "॥")
179
+ return " ".join(out.split())
180
+
181
+
182
+ def save_generation(experiment, record):
183
+ ts = datetime.now().strftime("%Y%m%d")
184
+ path = os.path.join(RESULTS_DIR, f"{experiment}_ui_{ts}.json")
185
+ existing = []
186
+ if os.path.exists(path):
187
+ with open(path, "r", encoding="utf-8") as f:
188
+ existing = json.load(f)
189
+ existing.append(record)
190
+ with open(path, "w", encoding="utf-8") as f:
191
+ json.dump(existing, f, ensure_ascii=False, indent=2)
192
+ return path
193
+
194
+
195
+ def generate_from_ui(
196
+ model_bundle,
197
+ input_text,
198
+ temperature,
199
+ top_k,
200
+ repetition_penalty,
201
+ diversity_penalty,
202
+ num_steps,
203
+ clean_output,
204
+ ):
205
+ if not model_bundle:
206
+ raise gr.Error("Load a model first.")
207
+ if not input_text.strip():
208
+ raise gr.Error("Enter input text first.")
209
+
210
+ cfg = copy.deepcopy(model_bundle["cfg"])
211
  cfg["inference"]["temperature"] = float(temperature)
212
  cfg["inference"]["top_k"] = int(top_k)
213
  cfg["inference"]["repetition_penalty"] = float(repetition_penalty)
214
  cfg["inference"]["diversity_penalty"] = float(diversity_penalty)
215
  cfg["inference"]["num_steps"] = int(num_steps)
216
 
217
+ src_tok = model_bundle["src_tok"]
218
+ tgt_tok = model_bundle["tgt_tok"]
219
+ device = torch.device(model_bundle["device"])
220
+
221
+ input_ids = torch.tensor([src_tok.encode(input_text.strip())], dtype=torch.long, device=device)
222
+ out = run_inference(model_bundle["model"], input_ids, cfg)
223
+
224
+ # Align decode with validation style: strip only special ids.
225
+ pad_id = 1
226
+ mask_id = cfg["diffusion"]["mask_token_id"]
227
+ decoded_ids = [x for x in out[0].tolist() if x not in (pad_id, mask_id)]
228
+ raw_output_text = tgt_tok.decode(decoded_ids).strip()
229
+ output_text = clean_generated_text(raw_output_text) if clean_output else raw_output_text
230
+ if not output_text:
231
+ output_text = "(empty output)"
232
+
233
+ record = {
234
+ "timestamp": datetime.now().isoformat(timespec="seconds"),
235
+ "experiment": model_bundle["experiment"],
236
+ "checkpoint": model_bundle["ckpt_path"],
237
+ "input_text": input_text,
238
+ "raw_output_text": raw_output_text,
239
+ "output_text": output_text,
240
+ "temperature": float(temperature),
241
+ "top_k": int(top_k),
242
+ "repetition_penalty": float(repetition_penalty),
243
+ "diversity_penalty": float(diversity_penalty),
244
+ "num_steps": int(num_steps),
245
+ "clean_output": bool(clean_output),
246
+ }
247
+ log_path = save_generation(model_bundle["experiment"], record)
248
+ status = f"Inference done. Saved: `{log_path}`"
249
+ return output_text, status, record
250
+
251
+
252
+ def _run_analysis_cmd(task, ckpt_path, output_dir, input_text="dharmo rakṣati rakṣitaḥ", phase="analyze"):
253
+ os.makedirs(output_dir, exist_ok=True)
254
+ cmd = [
255
+ sys.executable,
256
+ "analysis/run_analysis.py",
257
+ "--task",
258
+ str(task),
259
+ "--checkpoint",
260
+ ckpt_path,
261
+ "--output_dir",
262
+ output_dir,
263
+ ]
264
+ if str(task) == "2" or str(task) == "all":
265
+ cmd.extend(["--input", input_text])
266
+ if str(task) == "4":
267
+ cmd.extend(["--phase", phase])
268
+
269
+ env = os.environ.copy()
270
+ env.setdefault("HF_HOME", "/tmp/hf_home")
271
+ env.setdefault("HF_DATASETS_CACHE", "/tmp/hf_datasets")
272
+ env.setdefault("HF_HUB_CACHE", "/tmp/hf_hub")
273
+
274
+ proc = subprocess.run(cmd, capture_output=True, text=True, env=env)
275
+ log = f"$ {' '.join(cmd)}\n\n{proc.stdout}\n{proc.stderr}"
276
+ return proc.returncode, log
277
+
278
+
279
+ def run_single_task(model_bundle, task, output_dir, input_text, task4_phase):
280
+ if not model_bundle:
281
+ raise gr.Error("Load a model first.")
282
+ code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
283
+ status = f"Task {task} {'completed' if code == 0 else 'failed'} (exit={code})."
284
+ return status, log
285
+
286
+
287
+ def run_all_tasks(model_bundle, output_dir, input_text, task4_phase):
288
+ if not model_bundle:
289
+ raise gr.Error("Load a model first.")
290
+ logs = []
291
+ failures = 0
292
+ for task in ["1", "2", "3", "4", "5"]:
293
+ code, log = _run_analysis_cmd(task, model_bundle["ckpt_path"], output_dir, input_text, task4_phase)
294
+ logs.append(f"\n\n{'='*22} TASK {task} {'='*22}\n{log}")
295
+ if code != 0:
296
+ failures += 1
297
+ status = f"Run-all finished with {failures} failed task(s)." if failures else "All 5 tasks completed."
298
+ return status, "".join(logs)
299
+
300
+
301
+ def _read_text(path):
302
+ if not os.path.exists(path):
303
+ return "Not found."
304
+ with open(path, "r", encoding="utf-8", errors="ignore") as f:
305
+ return f.read()
306
+
307
+
308
+ def _img_or_none(path):
309
+ return path if os.path.exists(path) else None
310
+
311
+
312
+ def refresh_task_outputs(output_dir):
313
+ task1_txt = _read_text(os.path.join(output_dir, "task1_kv_cache.txt"))
314
+ task2_txt = _read_text(os.path.join(output_dir, "task2_report.txt"))
315
+ task3_txt = _read_text(os.path.join(output_dir, "task3_report.txt"))
316
+ task5_txt = _read_text(os.path.join(output_dir, "task5_report.txt"))
317
+
318
+ task2_drift = _img_or_none(os.path.join(output_dir, "task2_semantic_drift.png"))
319
+ task2_attn = _img_or_none(os.path.join(output_dir, "task2_attn_t0.png"))
320
+ task3_space = _img_or_none(os.path.join(output_dir, "task3_concept_space.png"))
321
+ task4_plot = _img_or_none(os.path.join(output_dir, "task4_ablation_3d.png"))
322
+ return task1_txt, task2_txt, task2_drift, task2_attn, task3_txt, task3_space, task5_txt, task4_plot
323
+
324
+
325
+ CUSTOM_CSS = """
326
+ :root {
327
+ --bg1: #f5fbff;
328
+ --bg2: #f2f7ef;
329
+ --card: #ffffff;
330
+ --line: #d9e6f2;
331
+ --ink: #163048;
332
+ }
333
+ .gradio-container {
334
+ background: linear-gradient(130deg, var(--bg1), var(--bg2));
335
+ color: var(--ink);
336
+ }
337
+ #hero {
338
+ background: radial-gradient(110% 130% at 0% 0%, #d7ebff 0%, #ecf6ff 55%, #f8fbff 100%);
339
+ border: 1px solid #cfe0f1;
340
+ border-radius: 16px;
341
+ padding: 18px 20px;
342
+ }
343
+ .panel {
344
+ background: var(--card);
345
+ border: 1px solid var(--line);
346
+ border-radius: 14px;
347
+ }
348
+ """
349
 
350
 
351
+ with gr.Blocks(title="Sanskrit Diffusion Client Demo", css=CUSTOM_CSS) as demo:
352
  model_state = gr.State(None)
353
+
354
  gr.Markdown(
355
  """
356
+ <div id="hero">
357
+ <h1 style="margin:0;">Sanskrit Diffusion Client Demo</h1>
358
+ <p style="margin:.5rem 0 0 0;">
359
+ Select any trained model, run all 5 analysis tasks or individual tasks, then test inference with user-controlled parameters.
360
+ </p>
361
+ </div>
362
  """
363
  )
364
 
 
 
 
 
 
 
 
 
 
 
 
365
  with gr.Row():
366
+ with gr.Column(scale=2, elem_classes=["panel"]):
367
+ checkpoint_dropdown = gr.Dropdown(
368
+ label="Model Checkpoint",
369
+ choices=list(checkpoint_map().keys()),
370
+ value=default_checkpoint_label(),
371
+ interactive=True,
372
+ )
373
+ with gr.Column(scale=1, elem_classes=["panel"]):
374
+ refresh_btn = gr.Button("Refresh Models")
375
+ load_btn = gr.Button("Load Selected Model", variant="primary")
376
+
377
+ load_status = gr.Markdown("Select a model and load.")
378
+ model_info = gr.JSON(label="Loaded Model Details")
379
+
380
+ with gr.Tabs():
381
+ with gr.Tab("1) Task Runner"):
382
+ with gr.Row():
383
+ with gr.Column(scale=2):
384
+ analysis_output_dir = gr.Textbox(
385
+ label="Analysis Output Directory",
386
+ value=DEFAULT_ANALYSIS_OUT,
387
+ )
388
+ analysis_input = gr.Textbox(
389
+ label="Task 2 Input Text",
390
+ value="dharmo rakṣati rakṣitaḥ",
391
+ lines=2,
392
+ )
393
+ with gr.Column(scale=1):
394
+ task4_phase = gr.Dropdown(
395
+ choices=["analyze", "generate_configs"],
396
+ value="analyze",
397
+ label="Task 4 Phase",
398
+ )
399
+ run_all_btn = gr.Button("Run All 5 Tasks", variant="primary")
400
+
401
+ with gr.Row():
402
+ task_choice = gr.Dropdown(
403
+ choices=["1", "2", "3", "4", "5"],
404
+ value="1",
405
+ label="Single Task",
406
+ )
407
+ run_single_btn = gr.Button("Run Selected Task")
408
+ refresh_outputs_btn = gr.Button("Refresh Output Viewer")
409
+
410
+ task_run_status = gr.Markdown("")
411
+ task_run_log = gr.Textbox(label="Task Execution Log", lines=18, interactive=False)
412
+
413
+ with gr.Accordion("Task Outputs Viewer", open=True):
414
+ task1_box = gr.Textbox(label="Task 1 Report", lines=10, interactive=False)
415
+ task2_box = gr.Textbox(label="Task 2 Report", lines=10, interactive=False)
416
+ with gr.Row():
417
+ task2_drift_img = gr.Image(label="Task2 Drift", type="filepath")
418
+ task2_attn_img = gr.Image(label="Task2 Attention", type="filepath")
419
+ task3_box = gr.Textbox(label="Task 3 Report", lines=10, interactive=False)
420
+ task3_img = gr.Image(label="Task3 Concept Space", type="filepath")
421
+ task5_box = gr.Textbox(label="Task 5 Report", lines=10, interactive=False)
422
+ task4_img = gr.Image(label="Task4 3D Ablation Plot", type="filepath")
423
+
424
+ with gr.Tab("2) Inference Playground"):
425
+ with gr.Row():
426
+ with gr.Column(scale=2):
427
+ input_text = gr.Textbox(
428
+ label="Input (Roman / IAST)",
429
+ lines=4,
430
+ value="dharmo rakṣati rakṣitaḥ",
431
+ )
432
+ output_text = gr.Textbox(
433
+ label="Output (Devanagari)",
434
+ lines=7,
435
+ interactive=False,
436
+ )
437
+ run_status = gr.Markdown("")
438
+ run_record = gr.JSON(label="Inference Metadata")
439
+ with gr.Column(scale=1, elem_classes=["panel"]):
440
+ preset = gr.Radio(["Manual", "Literal", "Balanced", "Creative"], value="Balanced", label="Preset")
441
+ temperature = gr.Slider(0.4, 1.2, value=0.70, step=0.05, label="Temperature")
442
+ top_k = gr.Slider(5, 100, value=40, step=1, label="Top-K")
443
+ repetition_penalty = gr.Slider(1.0, 3.0, value=1.20, step=0.05, label="Repetition Penalty")
444
+ diversity_penalty = gr.Slider(0.0, 1.0, value=0.0, step=0.05, label="Diversity Penalty")
445
+ num_steps = gr.Slider(1, 128, value=64, step=1, label="Inference Steps")
446
+ clean_output = gr.Checkbox(value=True, label="Clean Output")
447
+ generate_btn = gr.Button("Generate", variant="primary")
448
+
449
+ gr.Examples(
450
+ examples=[
451
+ ["dharmo rakṣati rakṣitaḥ"],
452
+ ["satyameva jayate"],
453
+ ["yadā mano nivarteta viṣayebhyaḥ svabhāvataḥ"],
454
+ ],
455
+ inputs=[input_text],
456
+ )
457
+
458
+ def refresh_checkpoints():
459
+ choices = list(checkpoint_map().keys())
460
+ value = default_checkpoint_label() if choices else None
461
+ return gr.Dropdown(choices=choices, value=value)
462
+
463
+ refresh_btn.click(fn=refresh_checkpoints, outputs=[checkpoint_dropdown])
464
+ load_btn.click(
465
+ fn=load_selected_model,
466
+ inputs=[checkpoint_dropdown],
467
+ outputs=[model_state, load_status, model_info, num_steps, analysis_output_dir],
468
+ )
469
 
470
+ preset.change(
471
+ fn=apply_preset,
472
+ inputs=[preset],
473
+ outputs=[temperature, top_k, repetition_penalty, diversity_penalty],
474
+ )
475
 
 
476
  generate_btn.click(
477
+ fn=generate_from_ui,
478
  inputs=[
479
+ model_state,
480
+ input_text,
481
+ temperature,
482
+ top_k,
483
+ repetition_penalty,
484
+ diversity_penalty,
485
+ num_steps,
486
+ clean_output,
487
  ],
488
+ outputs=[output_text, run_status, run_record],
489
  )
490
+ input_text.submit(
491
+ fn=generate_from_ui,
492
  inputs=[
493
+ model_state,
494
+ input_text,
495
+ temperature,
496
+ top_k,
497
+ repetition_penalty,
498
+ diversity_penalty,
499
+ num_steps,
500
+ clean_output,
501
+ ],
502
+ outputs=[output_text, run_status, run_record],
503
+ )
504
+
505
+ run_single_btn.click(
506
+ fn=run_single_task,
507
+ inputs=[model_state, task_choice, analysis_output_dir, analysis_input, task4_phase],
508
+ outputs=[task_run_status, task_run_log],
509
+ )
510
+ run_all_btn.click(
511
+ fn=run_all_tasks,
512
+ inputs=[model_state, analysis_output_dir, analysis_input, task4_phase],
513
+ outputs=[task_run_status, task_run_log],
514
+ )
515
+ refresh_outputs_btn.click(
516
+ fn=refresh_task_outputs,
517
+ inputs=[analysis_output_dir],
518
+ outputs=[
519
+ task1_box,
520
+ task2_box,
521
+ task2_drift_img,
522
+ task2_attn_img,
523
+ task3_box,
524
+ task3_img,
525
+ task5_box,
526
+ task4_img,
527
+ ],
528
+ )
529
+ demo.load(
530
+ fn=refresh_task_outputs,
531
+ inputs=[analysis_output_dir],
532
+ outputs=[
533
+ task1_box,
534
+ task2_box,
535
+ task2_drift_img,
536
+ task2_attn_img,
537
+ task3_box,
538
+ task3_img,
539
+ task5_box,
540
+ task4_img,
541
  ],
 
542
  )
543
 
544
 
545
  if __name__ == "__main__":
546
+ port = int(os.environ["GRADIO_SERVER_PORT"]) if "GRADIO_SERVER_PORT" in os.environ else None
547
  demo.launch(server_name="0.0.0.0", server_port=port, share=False)
data/__init__.py ADDED
File without changes
data/dataset.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ dataset.py — Cross-Script Translation Fix
3
+ ==========================================
4
+ INPUT : quote_text (Roman/IAST transliteration of Sanskrit)
5
+ TARGET : quote_devanagari (Devanagari script)
6
+
7
+ This is the CORRECT task: the model learns to transliterate / translate
8
+ Roman Sanskrit → Devanagari, which is a meaningful, learnable mapping
9
+ (far better than devanagari→devanagari reconstruction which teaches nothing).
10
+
11
+ KEY CHANGES from original:
12
+ 1. _input_field = 'quote_text' (was 'quote_devanagari')
13
+ 2. _target_field = 'quote_devanagari' (unchanged)
14
+ 3. Separate source/target tokenizers — Roman and Devanagari have
15
+ completely different character sets; a shared BPE vocab forces the
16
+ model to learn both scripts in one embedding table, which wastes
17
+ capacity and confuses the attention mechanism.
18
+ 4. Negative example generation fixed — reversal now operates on
19
+ DEVANAGARI target only (not accidentally on Roman source).
20
+ 5. curriculum_sort uses target length (Devanagari) for difficulty proxy.
21
+ """
22
+
23
+ from datasets import load_dataset
24
+ from torch.utils.data import Dataset
25
+ import torch
26
+ import torch.nn.functional as F
27
+ import random
28
+
29
+
30
+ class OptimizedSanskritDataset(Dataset):
31
+ def __init__(self, split='train', tokenizer=None, max_len=80, cfg=None,
32
+ src_tokenizer=None, tgt_tokenizer=None):
33
+ """
34
+ Args:
35
+ tokenizer : shared tokenizer (legacy — used if src/tgt not provided)
36
+ src_tokenizer : tokenizer for quote_text (Roman script)
37
+ tgt_tokenizer : tokenizer for quote_devanagari (Devanagari script)
38
+ If None, falls back to shared `tokenizer`.
39
+ """
40
+ from config import CONFIG
41
+ self.cfg = cfg or CONFIG
42
+ self.max_len = max_len
43
+ self.pad_id = 1
44
+ self.mask_id = self.cfg['diffusion']['mask_token_id']
45
+ self.include_negatives = self.cfg['data']['include_negative_examples']
46
+
47
+ # ── Tokenizer setup ───────────────────────────────────────────
48
+ # Support both legacy (shared) and new (separate src/tgt) tokenizers
49
+ self.src_tokenizer = src_tokenizer or tokenizer
50
+ self.tgt_tokenizer = tgt_tokenizer or tokenizer
51
+
52
+ if self.src_tokenizer is None:
53
+ raise ValueError("Provide at least one tokenizer.")
54
+
55
+ print(f"📥 Loading '{split}' split …")
56
+ raw = load_dataset("paws/sanskrit-verses-gretil", split=split)
57
+ cols = raw.column_names
58
+
59
+ # ── Field selection ───────────────────────────────────────────
60
+ if 'quote_text' in cols and 'quote_devanagari' in cols:
61
+ # CORRECT setup: Roman input → Devanagari output
62
+ self._input_field = 'quote_text'
63
+ self._target_field = 'quote_devanagari'
64
+ print(" Format: quote_text (Roman) → quote_devanagari (Devanagari) ✓")
65
+ elif 'sentence1' in cols and 'sentence2' in cols:
66
+ # PAWS paraphrase pairs fallback
67
+ self._input_field = 'sentence1'
68
+ self._target_field = 'sentence2'
69
+ print(" Format: PAWS sentence pairs ✓")
70
+ else:
71
+ # Last resort: same field both sides
72
+ self._input_field = 'quote_devanagari'
73
+ self._target_field = 'quote_devanagari'
74
+ print(" ⚠️ Format: Devanagari→Devanagari (suboptimal — no quote_text found)")
75
+
76
+ # ── Filter empty rows ─────────────────────────────────────────
77
+ # Some rows have empty quote_text — skip them
78
+ raw = raw.filter(
79
+ lambda ex: (
80
+ bool(ex[self._input_field].strip()) and
81
+ bool(ex[self._target_field].strip())
82
+ )
83
+ )
84
+ print(f" After empty-filter: {len(raw)} samples")
85
+
86
+ self.dataset = raw
87
+
88
+ if split == 'train':
89
+ self.dataset = self._curriculum_sort()
90
+
91
+ print(f"✅ {len(self.dataset)} samples loaded.")
92
+
93
+ # ── Encoding ──────────────────────────────────────────────────────
94
+
95
+ def _encode_src(self, text):
96
+ """Encode source (Roman) text."""
97
+ ids = self.src_tokenizer.encode(text)[:self.max_len]
98
+ t = torch.tensor(ids, dtype=torch.long)
99
+ t = F.pad(t, (0, max(0, self.max_len - len(t))), value=self.pad_id)
100
+ return t
101
+
102
+ def _encode_tgt(self, text):
103
+ """Encode target (Devanagari) text."""
104
+ ids = self.tgt_tokenizer.encode(text)[:self.max_len]
105
+ t = torch.tensor(ids, dtype=torch.long)
106
+ t = F.pad(t, (0, max(0, self.max_len - len(t))), value=self.pad_id)
107
+ return t
108
+
109
+ # ── Curriculum ───��────────────────────────────────────────────────
110
+
111
+ def _curriculum_sort(self):
112
+ """Short, common Devanagari targets first → long, rare targets last."""
113
+ scores = []
114
+ for s in self.dataset:
115
+ text = s[self._target_field]
116
+ length = len(text.split())
117
+ rarity_score = len(set(text)) / max(1, len(text))
118
+ scores.append(length * (1 - rarity_score))
119
+ order = sorted(range(len(self.dataset)), key=lambda i: scores[i])
120
+ return self.dataset.select(order)
121
+
122
+ # ── Item ──────────────────────────────────────────────────────────
123
+
124
+ def __len__(self):
125
+ return len(self.dataset)
126
+
127
+ def __getitem__(self, idx):
128
+ sample = self.dataset[idx]
129
+
130
+ src_text = sample[self._input_field].strip()
131
+ tgt_text = sample[self._target_field].strip()
132
+
133
+ input_ids = self._encode_src(src_text) # Roman encoded with src_tokenizer
134
+ target_ids = self._encode_tgt(tgt_text) # Devanagari encoded with tgt_tokenizer
135
+
136
+ out = {
137
+ 'input_ids': input_ids,
138
+ 'target_ids': target_ids,
139
+ 'input_text': src_text,
140
+ 'target_text': tgt_text,
141
+ }
142
+
143
+ if self.include_negatives:
144
+ neg_ids = target_ids.clone()
145
+ # Reverse a random chunk of the DEVANAGARI target
146
+ non_pad = (neg_ids != self.pad_id).sum().item()
147
+ if non_pad > 4:
148
+ i1, i2 = sorted(random.sample(range(non_pad), 2))
149
+ neg_ids[i1:i2] = torch.flip(neg_ids[i1:i2], dims=[0])
150
+ out['negative_target_ids'] = neg_ids
151
+
152
+ return out
requirements.txt CHANGED
@@ -4,3 +4,9 @@ numpy>=1.24
4
  tqdm>=4.66
5
  huggingface_hub==0.25.2
6
  tokenizers>=0.15
 
 
 
 
 
 
 
4
  tqdm>=4.66
5
  huggingface_hub==0.25.2
6
  tokenizers>=0.15
7
+ datasets>=2.20
8
+ scikit-learn>=1.4
9
+ matplotlib>=3.8
10
+ bert-score>=0.3.13
11
+ sentence-transformers>=3.0
12
+ nltk>=3.8