bhsinghgrid commited on
Commit
9d76bba
·
verified ·
1 Parent(s): 9c3986a

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. README.md +5 -0
  3. analysis/quality_classifier.py +885 -0
  4. analysis/run_analysis.py +1245 -0
  5. analysis/step_ablation.py +640 -0
  6. analysis_outputs/outputs_all_models_20260325/T16/task1_encoder_cost.png +0 -0
  7. analysis_outputs/outputs_all_models_20260325/T16/task1_kv_cache.txt +15 -0
  8. analysis_outputs/outputs_all_models_20260325/T16/task1_speedup.png +0 -0
  9. analysis_outputs/outputs_all_models_20260325/T16/task1_time_comparison.png +0 -0
  10. analysis_outputs/outputs_all_models_20260325/T16/task2_all_layers_t0.png +0 -0
  11. analysis_outputs/outputs_all_models_20260325/T16/task2_attn_evolution.png +0 -0
  12. analysis_outputs/outputs_all_models_20260325/T16/task2_attn_t0.png +0 -0
  13. analysis_outputs/outputs_all_models_20260325/T16/task2_attn_t15.png +0 -0
  14. analysis_outputs/outputs_all_models_20260325/T16/task2_report.txt +35 -0
  15. analysis_outputs/outputs_all_models_20260325/T16/task2_semantic_drift.png +0 -0
  16. analysis_outputs/outputs_all_models_20260325/T16/task2_source_alignment.png +0 -0
  17. analysis_outputs/outputs_all_models_20260325/T16/task2_tfidf_vs_attention.png +0 -0
  18. analysis_outputs/outputs_all_models_20260325/T16/task3_concept_space.png +0 -0
  19. analysis_outputs/outputs_all_models_20260325/T16/task3_diversity_curve.png +0 -0
  20. analysis_outputs/outputs_all_models_20260325/T16/task3_diversity_direction.npy +3 -0
  21. analysis_outputs/outputs_all_models_20260325/T16/task3_pca_explained_variance.png +0 -0
  22. analysis_outputs/outputs_all_models_20260325/T16/task3_report.txt +21 -0
  23. analysis_outputs/outputs_all_models_20260325/T16/task4_3d.png +0 -0
  24. analysis_outputs/outputs_all_models_20260325/T16/task4_raw_results.json +8 -0
  25. analysis_outputs/outputs_all_models_20260325/T16/task4_report.txt +14 -0
  26. analysis_outputs/outputs_all_models_20260325/T16/task5_guidance_results.json +44 -0
  27. analysis_outputs/outputs_all_models_20260325/T16/task5_quality_classifier.pt +3 -0
  28. analysis_outputs/outputs_all_models_20260325/T16/task5_quality_data.npz +3 -0
  29. analysis_outputs/outputs_all_models_20260325/T16/task5_quality_diversity_tradeoff.png +0 -0
  30. analysis_outputs/outputs_all_models_20260325/T16/task5_report.txt +15 -0
  31. analysis_outputs/outputs_all_models_20260325/T32/task1_encoder_cost.png +0 -0
  32. analysis_outputs/outputs_all_models_20260325/T32/task1_kv_cache.txt +15 -0
  33. analysis_outputs/outputs_all_models_20260325/T32/task1_speedup.png +0 -0
  34. analysis_outputs/outputs_all_models_20260325/T32/task1_time_comparison.png +0 -0
  35. analysis_outputs/outputs_all_models_20260325/T32/task2_all_layers_t0.png +0 -0
  36. analysis_outputs/outputs_all_models_20260325/T32/task2_attn_evolution.png +0 -0
  37. analysis_outputs/outputs_all_models_20260325/T32/task2_attn_t0.png +0 -0
  38. analysis_outputs/outputs_all_models_20260325/T32/task2_attn_t31.png +0 -0
  39. analysis_outputs/outputs_all_models_20260325/T32/task2_report.txt +35 -0
  40. analysis_outputs/outputs_all_models_20260325/T32/task2_semantic_drift.png +0 -0
  41. analysis_outputs/outputs_all_models_20260325/T32/task2_source_alignment.png +0 -0
  42. analysis_outputs/outputs_all_models_20260325/T32/task2_tfidf_vs_attention.png +0 -0
  43. analysis_outputs/outputs_all_models_20260325/T32/task3_concept_space.png +0 -0
  44. analysis_outputs/outputs_all_models_20260325/T32/task3_diversity_curve.png +0 -0
  45. analysis_outputs/outputs_all_models_20260325/T32/task3_diversity_direction.npy +3 -0
  46. analysis_outputs/outputs_all_models_20260325/T32/task3_pca_explained_variance.png +0 -0
  47. analysis_outputs/outputs_all_models_20260325/T32/task3_report.txt +21 -0
  48. analysis_outputs/outputs_all_models_20260325/T32/task4_3d.png +0 -0
  49. analysis_outputs/outputs_all_models_20260325/T32/task4_raw_results.json +8 -0
  50. analysis_outputs/outputs_all_models_20260325/T32/task4_report.txt +14 -0
.gitattributes CHANGED
@@ -36,3 +36,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
36
  analysis_outputs/T16/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
37
  analysis_outputs/T4/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
38
  analysis_outputs/T8/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
 
 
 
36
  analysis_outputs/T16/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
37
  analysis_outputs/T4/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
38
  analysis_outputs/T8/task3_concept_space.png filter=lfs diff=lfs merge=lfs -text
39
+ analysis_outputs/outputs_all_models_20260325/T32/task5_quality_diversity_tradeoff.png filter=lfs diff=lfs merge=lfs -text
40
+ analysis_outputs/outputs_all_models_20260325/T64/task5_quality_diversity_tradeoff.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -20,8 +20,13 @@ Set these Space variables in **Settings → Variables and secrets**:
20
  - `HF_CHECKPOINT_REPO` = `<your-username>/sanskrit-d3pm`
21
  - `HF_CHECKPOINT_FILE` = `best_model.pt`
22
  - `HF_CHECKPOINT_LABEL` = `main-model` (optional)
 
 
 
23
 
24
  The app will download checkpoint from your model repo and load it at runtime.
 
 
25
 
26
  ### Optional MLflow Tracking in Space
27
 
 
20
  - `HF_CHECKPOINT_REPO` = `<your-username>/sanskrit-d3pm`
21
  - `HF_CHECKPOINT_FILE` = `best_model.pt`
22
  - `HF_CHECKPOINT_LABEL` = `main-model` (optional)
23
+ - `HF_DEFAULT_MODEL_TYPE` = `d3pm_cross_attention` or `d3pm_encoder_decoder`
24
+ - `HF_DEFAULT_INCLUDE_NEG` = `true` or `false`
25
+ - `HF_DEFAULT_NUM_STEPS` = checkpoint diffusion steps, for example `4`, `8`, `16`
26
 
27
  The app will download checkpoint from your model repo and load it at runtime.
28
+ If the model repo contains `model_settings.json`, the Space will use it
29
+ automatically and these variables become optional overrides.
30
 
31
  ### Optional MLflow Tracking in Space
32
 
analysis/quality_classifier.py ADDED
@@ -0,0 +1,885 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # analysis/quality_classifier.py
3
+ # ================================
4
+ # Task 5: Classifier-Free Guidance for Paraphrase Quality Control
5
+ #
6
+ # Two steps — only Step 2 requires training a SMALL model (not the main D3PM):
7
+ #
8
+ # STEP 1 — Collect training data (no training):
9
+ # Run existing model on val set, record (hidden_state, CER) pairs.
10
+ # Hidden states come from model.model._last_hidden after forward_cached().
11
+ # CER score = quality label (lower CER = higher quality).
12
+ #
13
+ # STEP 2 — Train quality classifier:
14
+ # Small 2-layer MLP: d_model → 64 → 1
15
+ # Input: pooled decoder hidden state [B, d_model]
16
+ # Output: predicted quality score in [0, 1] (1 = high quality)
17
+ # Loss: MSE against normalized CER labels
18
+ # Training time: ~5-10 minutes on CPU for 10k examples
19
+ #
20
+ # STEP 3 — Guided inference (no retraining):
21
+ # At each diffusion step, use classifier gradient to shift logits:
22
+ # guided_logits = logits + λ * ∂(quality_score)/∂(logits)
23
+ # Higher λ → model biased toward high-quality outputs
24
+ # λ=0 → standard generation (no guidance)
25
+ #
26
+ # Key: main D3PM model is FROZEN throughout. Only the 10k-param classifier trains.
27
+ # """
28
+ #
29
+ # import torch
30
+ # import torch.nn as nn
31
+ # import torch.nn.functional as F
32
+ # import numpy as np
33
+ # import os
34
+ # import json
35
+ # from typing import List, Dict, Optional, Tuple
36
+ #
37
+ #
38
+ # # ── Quality classifier architecture ──────────────────────────────────
39
+ #
40
+ # class QualityClassifier(nn.Module):
41
+ # """
42
+ # Lightweight MLP that predicts transliteration quality from decoder
43
+ # hidden states.
44
+ #
45
+ # Architecture:
46
+ # d_model → 128 → 64 → 1 → Sigmoid
47
+ #
48
+ # Input: mean-pooled decoder hidden state [B, d_model]
49
+ # Output: quality score [B, 1] ∈ [0, 1] (1 = high quality)
50
+ #
51
+ # ~10k parameters. Trains in minutes on CPU.
52
+ # """
53
+ # def __init__(self, d_model: int):
54
+ # super().__init__()
55
+ # self.net = nn.Sequential(
56
+ # nn.Linear(d_model, 128),
57
+ # nn.ReLU(),
58
+ # nn.Dropout(0.1),
59
+ # nn.Linear(128, 64),
60
+ # nn.ReLU(),
61
+ # nn.Linear(64, 1),
62
+ # nn.Sigmoid(),
63
+ # )
64
+ # self.d_model = d_model
65
+ #
66
+ # def forward(self, hidden: torch.Tensor) -> torch.Tensor:
67
+ # """
68
+ # Args:
69
+ # hidden : [B, tgt_len, d_model] OR [B, d_model] (already pooled)
70
+ #
71
+ # Returns:
72
+ # score : [B, 1] quality score in [0, 1]
73
+ # """
74
+ # if hidden.dim() == 3:
75
+ # # Pool over sequence length
76
+ # hidden = hidden.mean(dim=1) # [B, d_model]
77
+ # return self.net(hidden) # [B, 1]
78
+ #
79
+ #
80
+ # # ── Training data collection ──────────────────────────────────────────
81
+ #
82
+ # @torch.no_grad()
83
+ # def collect_quality_data(
84
+ # model,
85
+ # src_list: List[torch.Tensor],
86
+ # ref_list: List[str],
87
+ # tgt_tokenizer,
88
+ # t_capture: int = 0,
89
+ # temperature: float = 0.8,
90
+ # top_k: int = 40,
91
+ # max_samples: int = 5000,
92
+ # ) -> Tuple[np.ndarray, np.ndarray]:
93
+ # """
94
+ # Collect (hidden_state, quality_score) pairs for classifier training.
95
+ #
96
+ # For each sample:
97
+ # 1. Run generate_cached() on src
98
+ # 2. Capture decoder hidden state at t=t_capture
99
+ # 3. Compute CER between output and reference
100
+ # 4. Quality = 1 - CER (normalize to [0,1])
101
+ #
102
+ # Args:
103
+ # model : SanskritModel
104
+ # src_list : list of [1, src_len] tensors
105
+ # ref_list : list of reference Devanagari strings
106
+ # tgt_tokenizer : SanskritTargetTokenizer
107
+ # t_capture : which step to capture hidden states (0 = final)
108
+ # max_samples : cap number of training examples
109
+ #
110
+ # Returns:
111
+ # hidden_matrix : np.ndarray [N, d_model]
112
+ # quality_scores: np.ndarray [N] values in [0, 1]
113
+ # """
114
+ # inner = model.model
115
+ # T = inner.scheduler.num_timesteps
116
+ # device = next(inner.parameters()).device
117
+ #
118
+ # hidden_list = []
119
+ # quality_list = []
120
+ # n = min(len(src_list), max_samples)
121
+ #
122
+ # def cer(pred, ref):
123
+ # if not ref:
124
+ # return 1.0
125
+ # def ed(s1, s2):
126
+ # m, n = len(s1), len(s2)
127
+ # dp = list(range(n + 1))
128
+ # for i in range(1, m + 1):
129
+ # prev, dp[0] = dp[0], i
130
+ # for j in range(1, n + 1):
131
+ # temp = dp[j]
132
+ # dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
133
+ # prev = temp
134
+ # return dp[n]
135
+ # return ed(pred, ref) / max(len(ref), 1)
136
+ #
137
+ # print(f"Collecting quality data from {n} examples...")
138
+ # for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
139
+ # if i % 200 == 0:
140
+ # print(f" {i}/{n}")
141
+ #
142
+ # if src.dim() == 1:
143
+ # src = src.unsqueeze(0)
144
+ # src = src.to(device)
145
+ #
146
+ # B = src.shape[0]
147
+ # tgt_len = inner.max_seq_len
148
+ # mask_id = inner.mask_token_id
149
+ #
150
+ # memory, src_pad_mask = inner.encode_source(src)
151
+ # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
152
+ # hint = None
153
+ # h_cap = None
154
+ #
155
+ # for t_val in range(T - 1, -1, -1):
156
+ # t = torch.full((B,), t_val, dtype=torch.long, device=device)
157
+ # is_last = (t_val == 0)
158
+ #
159
+ # logits, _ = inner.forward_cached(
160
+ # memory, src_pad_mask, x0_est, t,
161
+ # x0_hint=hint, inference_mode=True,
162
+ # )
163
+ #
164
+ # if t_val == t_capture and hasattr(inner, '_last_hidden'):
165
+ # h_cap = inner._last_hidden[0].mean(dim=0).detach().cpu() # [d_model]
166
+ #
167
+ # logits = logits / max(temperature, 1e-8)
168
+ # if top_k > 0:
169
+ # V = logits.shape[-1]
170
+ # if top_k < V:
171
+ # vals, _ = torch.topk(logits, top_k, dim=-1)
172
+ # logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
173
+ #
174
+ # probs = F.softmax(logits, dim=-1)
175
+ # x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
176
+ # hint = x0_est
177
+ #
178
+ # if h_cap is None:
179
+ # continue
180
+ #
181
+ # ids = [x for x in x0_est[0].tolist() if x > 4]
182
+ # pred = tgt_tokenizer.decode(ids).strip()
183
+ # q = max(0.0, 1.0 - cer(pred, ref)) # quality = 1 - CER
184
+ #
185
+ # hidden_list.append(h_cap.numpy())
186
+ # quality_list.append(q)
187
+ #
188
+ # print(f"Collected {len(hidden_list)} quality examples.")
189
+ # print(f"Quality stats: mean={np.mean(quality_list):.3f} "
190
+ # f"min={np.min(quality_list):.3f} max={np.max(quality_list):.3f}")
191
+ #
192
+ # return np.stack(hidden_list), np.array(quality_list, dtype=np.float32)
193
+ #
194
+ #
195
+ # def _sample(probs):
196
+ # B, L, V = probs.shape
197
+ # flat = probs.view(B * L, V).clamp(min=1e-9)
198
+ # flat = flat / flat.sum(dim=-1, keepdim=True)
199
+ # return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
200
+ #
201
+ #
202
+ # # ── Training ──────────────────────────────────────────────────────────
203
+ #
204
+ # def train_quality_classifier(
205
+ # hidden_matrix: np.ndarray,
206
+ # quality_scores: np.ndarray,
207
+ # d_model: int,
208
+ # epochs: int = 30,
209
+ # batch_size: int = 64,
210
+ # lr: float = 1e-3,
211
+ # val_frac: float = 0.1,
212
+ # save_path: Optional[str] = None,
213
+ # ) -> QualityClassifier:
214
+ # """
215
+ # Train QualityClassifier on collected (hidden, quality) pairs.
216
+ #
217
+ # Args:
218
+ # hidden_matrix : [N, d_model] from collect_quality_data()
219
+ # quality_scores : [N] quality labels in [0, 1]
220
+ # d_model : hidden dimension
221
+ # epochs : training epochs
222
+ # save_path : if given, save trained classifier weights here
223
+ #
224
+ # Returns:
225
+ # trained QualityClassifier
226
+ # """
227
+ # device = torch.device("cpu") # classifier is tiny, CPU is fine
228
+ #
229
+ # X = torch.tensor(hidden_matrix, dtype=torch.float32)
230
+ # y = torch.tensor(quality_scores, dtype=torch.float32).unsqueeze(-1)
231
+ #
232
+ # N = len(X)
233
+ # n_val = max(1, int(N * val_frac))
234
+ # idx = torch.randperm(N)
235
+ # val_idx = idx[:n_val]
236
+ # train_idx = idx[n_val:]
237
+ #
238
+ # X_train, y_train = X[train_idx], y[train_idx]
239
+ # X_val, y_val = X[val_idx], y[val_idx]
240
+ #
241
+ # clf = QualityClassifier(d_model).to(device)
242
+ # optimizer = torch.optim.Adam(clf.parameters(), lr=lr)
243
+ #
244
+ # print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
245
+ # print(f"Train: {len(X_train)} Val: {len(X_val)}")
246
+ #
247
+ # best_val_loss = float('inf')
248
+ # best_state = None
249
+ #
250
+ # for epoch in range(epochs):
251
+ # clf.train()
252
+ # perm = torch.randperm(len(X_train))
253
+ # train_loss = 0.0
254
+ # n_batches = 0
255
+ #
256
+ # for start in range(0, len(X_train), batch_size):
257
+ # batch_idx = perm[start:start + batch_size]
258
+ # xb, yb = X_train[batch_idx], y_train[batch_idx]
259
+ # pred = clf(xb)
260
+ # loss = F.mse_loss(pred, yb)
261
+ # optimizer.zero_grad()
262
+ # loss.backward()
263
+ # optimizer.step()
264
+ # train_loss += loss.item()
265
+ # n_batches += 1
266
+ #
267
+ # clf.eval()
268
+ # with torch.no_grad():
269
+ # val_pred = clf(X_val)
270
+ # val_loss = F.mse_loss(val_pred, y_val).item()
271
+ #
272
+ # if epoch % 5 == 0 or epoch == epochs - 1:
273
+ # print(f" Ep {epoch+1:3d} train={train_loss/n_batches:.4f} val={val_loss:.4f}")
274
+ #
275
+ # if val_loss < best_val_loss:
276
+ # best_val_loss = val_loss
277
+ # best_state = {k: v.clone() for k, v in clf.state_dict().items()}
278
+ #
279
+ # if best_state:
280
+ # clf.load_state_dict(best_state)
281
+ # print(f" Best val loss: {best_val_loss:.4f}")
282
+ #
283
+ # if save_path:
284
+ # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
285
+ # torch.save(clf.state_dict(), save_path)
286
+ # print(f" Classifier saved: {save_path}")
287
+ #
288
+ # return clf
289
+ #
290
+ #
291
+ # # ── Guided inference ──────────────────────────────────────────────────
292
+ #
293
+ # def generate_guided(
294
+ # model,
295
+ # src: torch.Tensor,
296
+ # classifier: QualityClassifier,
297
+ # guidance_scale: float = 1.0,
298
+ # temperature: float = 0.8,
299
+ # top_k: int = 40,
300
+ # ) -> torch.Tensor:
301
+ # """
302
+ # Classifier-guided generation.
303
+ #
304
+ # At each diffusion step:
305
+ # 1. Run forward_cached() → logits, hidden states
306
+ # 2. Compute classifier gradient: ∂(quality_score) / ∂(hidden)
307
+ # 3. Project gradient back to logit space (approximate)
308
+ # 4. guided_logits = logits + λ * gradient_signal
309
+ # 5. Sample from guided_logits
310
+ #
311
+ # guidance_scale λ:
312
+ # 0.0 → no guidance (standard generation)
313
+ # 0.5 → weak guidance
314
+ # 1.0 → moderate guidance (recommended starting point)
315
+ # 2.0 → strong guidance (may reduce diversity)
316
+ # 3.0 → very strong (may collapse to repetitive output)
317
+ #
318
+ # Args:
319
+ # model : SanskritModel (frozen)
320
+ # src : [1, src_len] IAST token ids
321
+ # classifier : trained QualityClassifier
322
+ # guidance_scale : λ — guidance strength
323
+ #
324
+ # Returns:
325
+ # x0_est : [1, tgt_len] generated token ids
326
+ # """
327
+ # inner = model.model
328
+ # T = inner.scheduler.num_timesteps
329
+ # device = next(inner.parameters()).device
330
+ # clf_device = next(classifier.parameters()).device
331
+ #
332
+ # if src.dim() == 1:
333
+ # src = src.unsqueeze(0)
334
+ # src = src.to(device)
335
+ #
336
+ # B = src.shape[0]
337
+ # tgt_len = inner.max_seq_len
338
+ # mask_id = inner.mask_token_id
339
+ #
340
+ # memory, src_pad_mask = inner.encode_source(src)
341
+ # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
342
+ # hint = None
343
+ #
344
+ # inner.eval()
345
+ # classifier.eval()
346
+ #
347
+ # for t_val in range(T - 1, -1, -1):
348
+ # t = torch.full((B,), t_val, dtype=torch.long, device=device)
349
+ # is_last = (t_val == 0)
350
+ #
351
+ # if guidance_scale > 0.0:
352
+ # # Need gradients for classifier guidance
353
+ # with torch.enable_grad():
354
+ # # Run forward_cached and get hidden states
355
+ # PAD = 1
356
+ # if t_val > 0:
357
+ # _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
358
+ # else:
359
+ # x_t_ids = x0_est
360
+ #
361
+ # x = inner.tgt_embed(x_t_ids)
362
+ # t_norm = t.float() / T
363
+ # t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
364
+ # x = x + t_emb.unsqueeze(1)
365
+ #
366
+ # if hint is not None:
367
+ # hint_emb = inner.tgt_embed(hint)
368
+ # gate = inner.hint_gate(x)
369
+ # x = x + gate * hint_emb
370
+ #
371
+ # for block in inner.decoder_blocks:
372
+ # x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)
373
+ #
374
+ # # hidden: [B, tgt_len, d_model] — detach from graph for clf
375
+ # hidden = x.detach().requires_grad_(True).to(clf_device)
376
+ #
377
+ # # Classifier quality score
378
+ # quality = classifier(hidden) # [B, 1]
379
+ # quality.sum().backward()
380
+ #
381
+ # # Gradient of quality w.r.t. hidden: [B, tgt_len, d_model]
382
+ # grad = hidden.grad.to(device) # [B, tgt_len, d_model]
383
+ #
384
+ # # Project gradient to logit space via output head weight
385
+ # # logit_grad ≈ grad @ head.weight [B, tgt_len, tgt_vocab]
386
+ # logit_grad = grad @ inner.head.weight.T
387
+ #
388
+ # # Compute standard logits (no gradient needed)
389
+ # with torch.no_grad():
390
+ # logits = inner.head(x)
391
+ #
392
+ # # Apply guidance
393
+ # logits = logits + guidance_scale * logit_grad
394
+ #
395
+ # else:
396
+ # with torch.no_grad():
397
+ # logits, _ = inner.forward_cached(
398
+ # memory, src_pad_mask, x0_est, t,
399
+ # x0_hint=hint, inference_mode=True,
400
+ # )
401
+ #
402
+ # with torch.no_grad():
403
+ # logits = logits / max(temperature, 1e-8)
404
+ # if top_k > 0:
405
+ # V = logits.shape[-1]
406
+ # if top_k < V:
407
+ # vals, _ = torch.topk(logits, top_k, dim=-1)
408
+ # logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
409
+ #
410
+ # probs = F.softmax(logits, dim=-1)
411
+ # x0_est = torch.argmax(probs, dim=-1) if is_last else _sample_no_grad(probs)
412
+ # hint = x0_est
413
+ #
414
+ # return x0_est
415
+ #
416
+ #
417
+ # def _sample_no_grad(probs):
418
+ # B, L, V = probs.shape
419
+ # flat = probs.view(B * L, V).clamp(min=1e-9)
420
+ # flat = flat / flat.sum(dim=-1, keepdim=True)
421
+ # return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
422
+ #
423
+ #
424
+ # # ── Guidance scale sweep ──────────────────────────────────────────────
425
+ #
426
+ # def sweep_guidance_scales(
427
+ # model,
428
+ # classifier: QualityClassifier,
429
+ # src_list: List[torch.Tensor],
430
+ # ref_list: List[str],
431
+ # tgt_tokenizer,
432
+ # scales: List[float] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
433
+ # n_samples: int = 50,
434
+ # device: torch.device = None,
435
+ # output_dir: str = "analysis/outputs",
436
+ # ) -> Dict:
437
+ # """
438
+ # Evaluate CER at each guidance scale.
439
+ # Produces quality-diversity tradeoff plot.
440
+ # """
441
+ # def cer(pred, ref):
442
+ # if not ref:
443
+ # return 1.0
444
+ # def ed(s1, s2):
445
+ # m, n = len(s1), len(s2)
446
+ # dp = list(range(n + 1))
447
+ # for i in range(1, m + 1):
448
+ # prev, dp[0] = dp[0], i
449
+ # for j in range(1, n + 1):
450
+ # temp = dp[j]
451
+ # dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
452
+ # prev = temp
453
+ # return dp[n]
454
+ # return ed(pred, ref) / max(len(ref), 1)
455
+ #
456
+ # device = device or next(model.parameters()).device
457
+ # results = {}
458
+ # n = min(n_samples, len(src_list))
459
+ #
460
+ # print("\nGuidance scale sweep...")
461
+ # for scale in scales:
462
+ # cer_list = []
463
+ # output_set = []
464
+ # for src, ref in zip(src_list[:n], ref_list[:n]):
465
+ # if src.dim() == 1:
466
+ # src = src.unsqueeze(0)
467
+ # out = generate_guided(model, src.to(device), classifier,
468
+ # guidance_scale=scale)
469
+ # ids = [x for x in out[0].tolist() if x > 4]
470
+ # pred = tgt_tokenizer.decode(ids).strip()
471
+ # cer_list.append(cer(pred, ref))
472
+ # output_set.append(pred)
473
+ #
474
+ # mean_cer = float(np.mean(cer_list))
475
+ #
476
+ # # Self-diversity: unique outputs / total (proxy for diversity)
477
+ # unique_frac = len(set(output_set)) / max(len(output_set), 1)
478
+ #
479
+ # results[scale] = {"mean_cer": mean_cer, "diversity": unique_frac}
480
+ # print(f" λ={scale:.1f} CER={mean_cer:.4f} diversity={unique_frac:.3f}")
481
+ #
482
+ # # Plot
483
+ # os.makedirs(output_dir, exist_ok=True)
484
+ # try:
485
+ # import matplotlib.pyplot as plt
486
+ # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
487
+ #
488
+ # sc_list = sorted(results.keys())
489
+ # cers = [results[s]["mean_cer"] for s in sc_list]
490
+ # diversities = [results[s]["diversity"] for s in sc_list]
491
+ #
492
+ # ax1.plot(sc_list, cers, 'o-', color='coral', linewidth=1.8, markersize=7)
493
+ # ax1.set_xlabel("Guidance scale λ", fontsize=10)
494
+ # ax1.set_ylabel("CER (↓ better)", fontsize=10)
495
+ # ax1.set_title("Quality vs guidance scale", fontsize=10)
496
+ #
497
+ # ax2.plot(sc_list, diversities, 'o-', color='steelblue', linewidth=1.8, markersize=7)
498
+ # ax2.set_xlabel("Guidance scale λ", fontsize=10)
499
+ # ax2.set_ylabel("Output diversity (unique fraction)", fontsize=10)
500
+ # ax2.set_title("Diversity vs guidance scale", fontsize=10)
501
+ #
502
+ # plt.suptitle("Quality-Diversity Tradeoff (Guidance Scale Sweep)", fontsize=11)
503
+ # plt.tight_layout()
504
+ # path = os.path.join(output_dir, "guidance_scale_sweep.png")
505
+ # plt.savefig(path, dpi=150, bbox_inches='tight')
506
+ # plt.close()
507
+ # print(f" Saved: {path}")
508
+ # except ImportError:
509
+ # pass
510
+ #
511
+ # with open(os.path.join(output_dir, "guidance_results.json"), "w") as f:
512
+ # json.dump({str(k): v for k, v in results.items()}, f, indent=2)
513
+ #
514
+ # return results
515
+ import os
516
+ import json
517
+ import torch
518
+ import torch.nn as nn
519
+ import torch.nn.functional as F
520
+ import numpy as np
521
+ from typing import List, Dict
522
+ from itertools import combinations
523
+
524
+
525
class QualityClassifier(nn.Module):
    """Tiny MLP that scores a pooled decoder hidden state in [0, 1].

    Layout: d_model -> 128 -> 64 -> 1 -> Sigmoid. The layers live in an
    ``nn.Sequential`` named ``net`` (same state_dict keys as before, so
    previously saved checkpoints remain loadable).
    """

    def __init__(self, d_model: int):
        super().__init__()
        layers = [
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, hidden):
        # Accept either [B, L, d_model] (mean-pool over positions) or an
        # already-pooled [B, d_model] tensor.
        pooled = hidden.mean(dim=1) if hidden.dim() == 3 else hidden
        return self.net(pooled)
542
+
543
+
544
+ def _cer(pred: str, ref: str) -> float:
545
+ m, n = len(pred), len(ref)
546
+ if m == 0 and n == 0:
547
+ return 0.0
548
+ dp = list(range(n + 1))
549
+ for i in range(1, m + 1):
550
+ prev, dp[0] = dp[0], i
551
+ for j in range(1, n + 1):
552
+ tmp = dp[j]
553
+ dp[j] = prev if pred[i - 1] == ref[j - 1] else 1 + min(prev, dp[j], dp[j - 1])
554
+ prev = tmp
555
+ return float(dp[n]) / max(1, m, n)
556
+
557
+
558
+ def _sample(probs: torch.Tensor) -> torch.Tensor:
559
+ B, L, V = probs.shape
560
+ flat = probs.reshape(B * L, V).clamp(min=1e-9)
561
+ flat = flat / flat.sum(dim=-1, keepdim=True)
562
+ return torch.multinomial(flat, 1).squeeze(-1).reshape(B, L)
563
+
564
+
565
+ @torch.no_grad()
566
+ def _decode_pred(tgt_tokenizer, out_ids: torch.Tensor) -> str:
567
+ ids = [x for x in out_ids[0].tolist() if x > 4]
568
+ return tgt_tokenizer.decode(ids).strip()
569
+
570
+
571
+ def _tokenize_ws(text: str) -> list[str]:
572
+ return [t for t in text.split() if t]
573
+
574
+
575
+ def _distinct_n(outputs: List[str], n: int = 2) -> float:
576
+ ngrams = []
577
+ for s in outputs:
578
+ toks = _tokenize_ws(s)
579
+ if len(toks) < n:
580
+ continue
581
+ ngrams.extend([tuple(toks[i:i+n]) for i in range(len(toks) - n + 1)])
582
+ if not ngrams:
583
+ return 0.0
584
+ return float(len(set(ngrams)) / max(1, len(ngrams)))
585
+
586
+
587
+ def _self_bleu(outputs: List[str], max_pairs: int = 64) -> float:
588
+ try:
589
+ from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
590
+ except Exception:
591
+ return 0.0
592
+ toks = [_tokenize_ws(s) for s in outputs if s.strip()]
593
+ if len(toks) < 2:
594
+ return 0.0
595
+ smooth = SmoothingFunction().method1
596
+ pairs = list(combinations(range(len(toks)), 2))
597
+ if len(pairs) > max_pairs:
598
+ idx = np.linspace(0, len(pairs) - 1, max_pairs, dtype=int)
599
+ pairs = [pairs[i] for i in idx]
600
+ vals = []
601
+ for i, j in pairs:
602
+ ref = [toks[j]]
603
+ hyp = toks[i]
604
+ if not hyp:
605
+ continue
606
+ vals.append(float(sentence_bleu(ref, hyp, smoothing_function=smooth)))
607
+ return float(np.mean(vals)) if vals else 0.0
608
+
609
+
610
@torch.no_grad()
def collect_quality_data(
    model,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    t_capture: int = 0,
    max_samples: int = 1000,
) -> tuple[np.ndarray, np.ndarray]:
    """Build (pooled hidden state, quality target) pairs for classifier training.

    Per example: generate with the frozen model, score the decoded output
    against the reference, then re-run ``forward_cached`` at ``t=t_capture``
    purely to populate ``inner._last_hidden``, which is mean-pooled over
    positions into one feature row.

    Returns:
        (features [N, d_model] float32, targets [N] float32 in [0, 1]).

    Raises:
        RuntimeError: if the model never exposed ``_last_hidden``.
    """
    inner = model.model
    device = next(inner.parameters()).device
    inner.eval()

    features = []
    targets = []

    total = min(max_samples, len(src_list), len(ref_list))
    print(f"Collecting quality data from {total} examples...")
    for idx in range(total):
        src, ref = src_list[idx], ref_list[idx]
        if src.dim() == 1:
            src = src.unsqueeze(0)
        src = src.to(device)

        # Prefer the cached generator when the model provides one.
        if hasattr(inner, "generate_cached"):
            out = inner.generate_cached(src)
        else:
            out = inner.generate(src)
        pred = _decode_pred(tgt_tokenizer, out)

        cer_quality = 1.0 - _cer(pred, ref)
        words = pred.split()
        uniq_frac = len(set(words)) / max(1, len(words))
        length_ratio = min(1.0, len(words) / max(1, len(ref.split())))
        # Blend quality target to avoid all-zero collapse on weak checkpoints.
        blended = 0.70 * cer_quality + 0.20 * uniq_frac + 0.10 * length_ratio

        memory, src_pad = inner.encode_source(src)
        t_tensor = torch.full((1,), int(t_capture), dtype=torch.long, device=device)
        _ = inner.forward_cached(memory, src_pad, out, t_tensor, x0_hint=out, inference_mode=True)
        hidden = getattr(inner, "_last_hidden", None)
        if hidden is None:
            continue
        features.append(hidden[0].mean(dim=0).detach().cpu().numpy())
        targets.append(float(np.clip(blended, 0.0, 1.0)))
        if idx % 200 == 0:
            print(f" {idx}/{total}")

    if not features:
        raise RuntimeError("No hidden states collected for quality classifier.")
    feat_arr = np.asarray(features, dtype=np.float32)
    target_arr = np.asarray(targets, dtype=np.float32)
    print(f"Collected {feat_arr.shape[0]} quality examples.")
    return feat_arr, target_arr
659
+
660
+
661
def train_quality_classifier(
    hidden: np.ndarray,
    quality: np.ndarray,
    d_model: int,
    epochs: int = 30,
    batch_size: int = 64,
    lr: float = 1e-3,
    save_path: str | None = None,
):
    """Train the QualityClassifier on collected (hidden, quality) pairs.

    Targets are standardized (sharpens gradients when the raw spread is
    tiny), clipped to +/-3 sigma, then rescaled into [0, 1]. The rescale
    is the fix: the classifier ends in a Sigmoid, so raw standardized
    targets below 0 (roughly half of them) were unreachable and MSE could
    never fit them; mapping [-3, 3] -> [0, 1] preserves ordering while
    keeping every target inside the Sigmoid's range.

    Args:
        hidden     : [N, d_model] feature rows from collect_quality_data().
        quality    : [N] raw quality labels.
        d_model    : classifier input dimension.
        epochs     : training epochs.
        batch_size : minibatch size.
        lr         : Adam learning rate.
        save_path  : optional path for the trained state_dict.

    Returns:
        The trained QualityClassifier with best-validation weights restored.
    """
    device = torch.device("cpu")  # ~10k params: CPU is plenty
    clf = QualityClassifier(d_model).to(device)

    x = torch.tensor(hidden, dtype=torch.float32, device=device)
    q = quality.astype(np.float32)
    q_mu = float(np.mean(q))
    q_sd = float(np.std(q))
    if q_sd < 1e-4:
        # Degenerate labels (all ~equal): jitter so standardization is defined.
        q = q + np.random.normal(0.0, 1e-3, size=q.shape).astype(np.float32)
        q_mu = float(np.mean(q))
        q_sd = float(np.std(q))
    q = np.clip((q - q_mu) / max(q_sd, 1e-6), -3.0, 3.0)
    # Map clipped z-scores [-3, 3] -> [0, 1] so targets match the Sigmoid head.
    q = (q + 3.0) / 6.0
    y = torch.tensor(q, dtype=torch.float32, device=device).unsqueeze(-1)

    # 90/10 random train/validation split.
    idx = torch.randperm(x.shape[0])
    split = int(0.9 * x.shape[0])
    tr, va = idx[:split], idx[split:]
    x_tr, y_tr = x[tr], y[tr]
    x_va, y_va = x[va], y[va]

    opt = torch.optim.Adam(clf.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    best_val = float("inf")
    best_state = None

    print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
    print(f"Train: {x_tr.shape[0]} Val: {x_va.shape[0]}")
    for ep in range(1, epochs + 1):
        clf.train()
        ep_losses = []
        for i in range(0, x_tr.shape[0], batch_size):
            xb = x_tr[i : i + batch_size]
            yb = y_tr[i : i + batch_size]
            loss = loss_fn(clf(xb), yb)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            ep_losses.append(float(loss.item()))
        tr_loss = float(np.mean(ep_losses)) if ep_losses else 0.0

        clf.eval()
        with torch.no_grad():
            # With an empty validation split, fall back to the train loss.
            va_loss = float(loss_fn(clf(x_va), y_va).item()) if x_va.shape[0] else tr_loss
        if va_loss < best_val:
            best_val = va_loss
            best_state = {k: v.detach().cpu().clone() for k, v in clf.state_dict().items()}
        if ep == 1 or ep % 5 == 0 or ep == epochs:
            print(f" Ep {ep:>3d} train={tr_loss:.4f} val={va_loss:.4f}")

    if best_state is not None:
        clf.load_state_dict(best_state)
    clf.eval()
    print(f" Best val loss: {best_val:.4f}")

    if save_path:
        # Create the parent directory first; the original saved blindly and
        # failed when the output directory did not yet exist.
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        torch.save(clf.state_dict(), save_path)
        print(f" Classifier saved: {save_path}")
    return clf
731
+
732
+
733
def generate_guided(
    model,
    src: torch.Tensor,
    classifier: QualityClassifier,
    guidance_scale: float = 1.0,
    temperature: float = 0.8,
    top_k: int = 40,
):
    """Reverse-diffusion decoding with classifier-gradient quality guidance.

    At each denoising step the classifier's gradient with respect to the
    decoder hidden states is normalized, projected into logit space through
    the output head, clamped, and added to the logits scaled by
    ``guidance_scale``. A scale of 0 reproduces plain sampling.
    """
    inner = model.model
    total_steps = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device
    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)
    batch = src.shape[0]
    seq_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    memory, src_pad_mask = inner.encode_source(src)
    current = torch.full((batch, seq_len), mask_id, dtype=torch.long, device=device)
    hint = None

    inner.eval()
    classifier.eval()

    for step in reversed(range(total_steps)):
        t = torch.full((batch,), step, dtype=torch.long, device=device)

        with torch.no_grad():
            logits, _ = inner.forward_cached(memory, src_pad_mask, current, t, x0_hint=hint, inference_mode=True)
            hidden = getattr(inner, "_last_hidden", None)

        if guidance_scale > 0.0 and hidden is not None:
            # Differentiate predicted quality w.r.t. hidden states only:
            # detach from the frozen forward graph and make a fresh leaf.
            leaf = hidden.detach().requires_grad_(True)
            score = classifier(leaf).sum()
            grad = torch.autograd.grad(score, leaf, retain_graph=False, create_graph=False)[0]
            grad = grad / (grad.norm(dim=-1, keepdim=True) + 1e-6)
            logit_grad = torch.matmul(grad, inner.head.weight.T)
            logits = logits + (1.5 * guidance_scale) * torch.clamp(logit_grad, -6.0, 6.0)

        logits = logits / max(float(temperature), 1e-8)
        if 0 < top_k < logits.shape[-1]:
            kth = torch.topk(logits, int(top_k), dim=-1)[0][..., -1:]
            logits = logits.masked_fill(logits < kth, float("-inf"))

        probs = F.softmax(logits, dim=-1)
        if step == 0:
            # Final step: commit to the argmax decode.
            current = torch.argmax(probs, dim=-1)
        else:
            current = _sample(probs)
        hint = current
    return current
783
+
784
+
785
def sweep_guidance_scales(
    model,
    classifier: QualityClassifier,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    scales: List[float] | None = None,
    n_samples: int = 50,
    device=None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    """Sweep guidance scale λ and measure the quality/diversity trade-off.

    For each λ, decodes up to ``n_samples`` sources with ``generate_guided``
    (temperature and top-k are tightened slightly as λ grows) and records
    mean CER plus diversity statistics (distinct-2, self-BLEU, unique-sentence
    ratio). Saves a 3-panel plot (best-effort) and a JSON dump in
    ``output_dir``.

    Args:
        scales: guidance scales to test; defaults to [0.0, 0.5, 1.0, 1.5, 2.0, 3.0].

    Returns:
        Dict mapping float(λ) -> metrics dict.
    """
    # FIX: the old signature used a mutable list as the default argument
    # (shared across calls); use the None-sentinel idiom instead.
    if scales is None:
        scales = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
    device = device or next(model.parameters()).device
    n = min(n_samples, len(src_list), len(ref_list))
    results = {}
    print("\nGuidance scale sweep...")
    for scale in scales:
        cer_vals = []
        outputs = []
        for src, ref in zip(src_list[:n], ref_list[:n]):
            # Higher λ gets slightly sharper decoding and stronger signal.
            temp = max(0.55, 0.85 - 0.08 * float(scale))
            k = max(12, int(40 - 4 * float(scale)))
            out = generate_guided(
                model, src.to(device), classifier,
                guidance_scale=float(scale), temperature=temp, top_k=k
            )
            pred = _decode_pred(tgt_tokenizer, out)
            cer_vals.append(_cer(pred, ref))
            outputs.append(pred)
        mean_cer = float(np.mean(cer_vals)) if cer_vals else 1.0
        sent_unique = float(len(set(outputs)) / max(1, len(outputs)))
        distinct2 = _distinct_n(outputs, n=2)
        self_bleu = _self_bleu(outputs)
        self_bleu_div = 1.0 - self_bleu
        # Composite diversity: equal weight on distinct-2 and (1 - self-BLEU).
        diversity = float(0.5 * distinct2 + 0.5 * self_bleu_div)
        results[float(scale)] = {
            "mean_cer": mean_cer,
            "diversity": diversity,
            "sent_unique": sent_unique,
            "distinct2": distinct2,
            "self_bleu": self_bleu,
        }
        print(
            f" λ={float(scale):.1f} CER={mean_cer:.4f} "
            f"div={diversity:.3f} d2={distinct2:.3f} sBLEU={self_bleu:.3f}"
        )

    os.makedirs(output_dir, exist_ok=True)
    # Plotting is best-effort: skip silently on headless/missing-matplotlib setups.
    try:
        import matplotlib.pyplot as plt
        xs = sorted(results.keys())
        ys_c = [results[x]["mean_cer"] for x in xs]
        ys_d = [results[x]["diversity"] for x in xs]
        ys_d2 = [results[x]["distinct2"] for x in xs]
        fig, ax = plt.subplots(1, 3, figsize=(13, 4))
        ax[0].plot(xs, ys_c, marker="o")
        ax[0].set_xlabel("Guidance scale λ")
        ax[0].set_ylabel("CER (lower is better)")
        ax[0].set_title("Quality vs Guidance")
        ax[1].plot(xs, ys_d, marker="o")
        ax[1].set_xlabel("Guidance scale λ")
        ax[1].set_ylabel("Composite diversity")
        ax[1].set_title("Diversity vs Guidance")
        ax[2].plot(xs, ys_d2, marker="o")
        ax[2].set_xlabel("Guidance scale λ")
        ax[2].set_ylabel("Distinct-2")
        ax[2].set_title("Distinct-2 vs Guidance")
        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, "task5_quality_diversity_tradeoff.png"), dpi=150, bbox_inches="tight")
        plt.close()
    except Exception:
        pass

    with open(os.path.join(output_dir, "task5_guidance_results.json"), "w", encoding="utf-8") as f:
        json.dump({str(k): v for k, v in results.items()}, f, indent=2)
    return results
861
+
862
+
863
def sweep_guidance(
    model,
    classifier,
    src_list,
    ref_list,
    tgt_tokenizer,
    scales=None,
    n_samples=50,
):
    """Backward-compatible wrapper around :func:`sweep_guidance_scales`.

    Runs the sweep with the default output directory and reshapes the result
    into the legacy ``{λ: {"CER": ..., "diversity": ...}}`` format.
    """
    # FIX: the old signature used a mutable list as the default argument;
    # use the None-sentinel idiom instead.
    if scales is None:
        scales = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
    results = sweep_guidance_scales(
        model=model,
        classifier=classifier,
        src_list=src_list,
        ref_list=ref_list,
        tgt_tokenizer=tgt_tokenizer,
        scales=scales,
        n_samples=n_samples,
        output_dir="analysis/outputs",
    )
    return {
        float(k): {"CER": v["mean_cer"], "diversity": v["diversity"]}
        for k, v in results.items()
    }
analysis/run_analysis.py ADDED
@@ -0,0 +1,1245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ analysis/run_analysis.py
3
+ =========================
4
+ Entry point for all 5 tasks.
5
+
6
+ Tasks:
7
+ Task 1 — KV Cache benchmark (no retraining)
8
+ Task 2 — Attention viz + drift (no retraining)
9
+ Task 3 — Concept vectors + PCA steer (no retraining)
10
+ Task 4 — Step ablation (REQUIRES retraining for each T)
11
+ Task 5 — Classifier-free guidance (trains small 10k-param classifier)
12
+
13
+ Usage:
14
+ python analysis/run_analysis.py --task 1
15
+ python analysis/run_analysis.py --task 2 --input "dharmo rakṣati rakṣitaḥ"
16
+ python analysis/run_analysis.py --task 3
17
+ python analysis/run_analysis.py --task 4 --phase generate_configs
18
+ python analysis/run_analysis.py --task 4 --phase analyze
19
+ python analysis/run_analysis.py --task 5
20
+ python analysis/run_analysis.py --task all --input "satyameva jayate"
21
+
22
+ Output files: analysis/outputs/
23
+ """
24
+
25
+ import copy
26
+ import torch
27
+ import os, sys, argparse, json
28
+ import numpy as np
29
+ import time
30
+ import gc
31
+ import tracemalloc
32
+ import threading
33
+ import resource
34
+ from difflib import SequenceMatcher
35
+ import matplotlib
36
+ matplotlib.use("Agg")
37
+ import matplotlib.pyplot as plt
38
+ try:
39
+ import psutil
40
+ except Exception:
41
+ psutil = None
42
+
43
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
44
+ from config import CONFIG
45
+ from inference import load_model, _decode_with_cleanup, _iast_to_deva
46
+ from model.tokenizer import SanskritSourceTokenizer, SanskritTargetTokenizer
47
+
48
+ OUTPUT_DIR = "analysis/outputs"
49
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
50
+
51
+ # Keep caches writable/project-local for laptops and sandboxed runners.
52
+ _ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
53
+ os.environ.setdefault("HF_HOME", os.path.join(_ROOT, ".hf_cache"))
54
+ os.environ.setdefault("HF_DATASETS_CACHE", os.path.join(_ROOT, ".hf_cache", "datasets"))
55
+ os.environ.setdefault("HF_HUB_CACHE", os.path.join(_ROOT, ".hf_cache", "hub"))
56
+ os.environ.setdefault("MPLCONFIGDIR", os.path.join(_ROOT, ".mplconfig"))
57
+ for _p in [
58
+ os.environ["HF_HOME"],
59
+ os.environ["HF_DATASETS_CACHE"],
60
+ os.environ["HF_HUB_CACHE"],
61
+ os.environ["MPLCONFIGDIR"],
62
+ ]:
63
+ os.makedirs(_p, exist_ok=True)
64
+
65
+
66
def _process_mem_mb() -> float:
    """Best-effort resident memory of this process in MB (0.0 when unknown).

    Preference order: psutil RSS, then /proc/self/statm (Linux), then
    resource.getrusage max-RSS with a unit heuristic.
    """
    if psutil is not None:
        try:
            rss_bytes = psutil.Process(os.getpid()).memory_info().rss
            return float(rss_bytes / (1024 * 1024))
        except Exception:
            pass
    # Linux fallback: /proc/self/statm current RSS pages.
    try:
        with open("/proc/self/statm", "r", encoding="utf-8") as statm:
            fields = statm.read().strip().split()
        if len(fields) >= 2:
            return float(int(fields[1]) * os.sysconf("SC_PAGE_SIZE") / (1024 * 1024))
    except Exception:
        pass
    # Unix fallback: max RSS from resource (platform-dependent units).
    try:
        peak = float(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)
    except Exception:
        return 0.0
    # Heuristic: macOS tends to return bytes, Linux tends KB.
    if peak > 10_000_000:
        return peak / (1024 * 1024)
    return peak / 1024.0
91
+
92
+
93
+ # ── Shared loader ─────────────────────────────────────────────────────
94
+
95
def infer_model_type_from_checkpoint(ckpt_path: str) -> str:
    """Infer the architecture name from a checkpoint path.

    Ablation checkpoints (``ablation_results/T*``) map to the cross-attention
    D3PM variant; otherwise the path is scanned for known architecture
    markers, falling back to the configured default.
    """
    lowered = ckpt_path.lower()
    if "ablation_results/t" in lowered or "d3pm_cross_attention" in lowered:
        return "d3pm_cross_attention"
    for marker in ("d3pm_encoder_decoder", "baseline_cross_attention", "baseline_encoder_decoder"):
        if marker in lowered:
            return marker
    return CONFIG["model_type"]
106
+
107
+
108
def infer_include_negative_from_checkpoint(ckpt_path: str) -> bool:
    """Infer the negative-examples training flag from a checkpoint path.

    Explicit ``_neg_True``/``_neg_False`` markers win; ablation checkpoints
    were all trained without negatives; otherwise fall back to CONFIG.
    """
    lowered = ckpt_path.lower()
    for marker, flag in (("_neg_true", True), ("_neg_false", False), ("ablation_results/t", False)):
        if marker in lowered:
            return flag
    return CONFIG["data"]["include_negative_examples"]
117
+
118
+
119
def load_everything(cfg, device, ckpt_override=None):
    """Locate a checkpoint, load the model, and build both tokenizers.

    Args:
        cfg: CONFIG-style dict (reads ``model_type``, ``data``, ``model``).
        device: torch device to load the model onto.
        ckpt_override: explicit checkpoint path; skips candidate search.

    Returns:
        (model, src_tokenizer, tgt_tokenizer, cfg) with the model in eval mode.

    Raises:
        FileNotFoundError: when no candidate checkpoint exists on disk.
    """
    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    candidates = [
        f"results7/{model_name}_neg_{has_neg}/best_model.pt",
        f"results/{model_name}_neg_{has_neg}/best_model.pt",
        f"results7/{model_name}_neg_True/best_model.pt",
        f"results/{model_name}_neg_True/best_model.pt",
        f"results7/{model_name}_neg_False/best_model.pt",
        f"results/{model_name}_neg_False/best_model.pt",
        "ablation_results/T4/best_model.pt",
        "ablation_results/T8/best_model.pt",
    ]
    ckpt = ckpt_override if ckpt_override else next((p for p in candidates if os.path.exists(p)), None)
    # BUG FIX: when no candidate exists `ckpt` is None, and os.path.exists(None)
    # raised TypeError instead of the intended FileNotFoundError.
    if ckpt is None or not os.path.exists(ckpt):
        raise FileNotFoundError(f"No checkpoint found. Checked: {candidates}")
    model, cfg = load_model(ckpt, cfg, device)
    model.eval()
    src_tok = SanskritSourceTokenizer(
        vocab_size=cfg['model'].get('src_vocab_size', 500),
        max_len=cfg['model']['max_seq_len'])
    tgt_tok = SanskritTargetTokenizer(
        vocab_size=cfg['model'].get('tgt_vocab_size', 500),
        max_len=cfg['model']['max_seq_len'])
    return model, src_tok, tgt_tok, cfg
144
+
145
+
146
def load_val_data(cfg, src_tok, tgt_tok, n=500):
    """Load validation set as (src_tensors, ref_strings, input_strings)."""
    from data.dataset import OptimizedSanskritDataset
    from torch.utils.data import Subset
    from sklearn.model_selection import train_test_split

    dataset = OptimizedSanskritDataset(
        'train', max_len=cfg['model']['max_seq_len'],
        cfg=cfg, src_tokenizer=src_tok, tgt_tokenizer=tgt_tok)
    total = min(cfg['data']['dataset_size'], len(dataset))
    # Same 80/20 split + seed as training, so this is the true held-out set.
    _, val_idx = train_test_split(list(range(total)), train_size=0.8, random_state=42)
    val_idx = val_idx[:n]

    items = [dataset[i] for i in val_idx]
    src_list = [item['input_ids'].unsqueeze(0) for item in items]
    ref_list = [item['target_text'] for item in items]
    inp_list = [item['input_text'] for item in items]
    return src_list, ref_list, inp_list
166
+
167
+
168
+ def _generate_ids_compat(model, src, num_steps=None, temperature=0.8, top_k=40,
169
+ repetition_penalty=1.2, diversity_penalty=0.0):
170
+ kwargs = dict(temperature=temperature, top_k=top_k)
171
+ if num_steps is not None:
172
+ kwargs["num_steps"] = int(num_steps)
173
+ if repetition_penalty is not None:
174
+ kwargs["repetition_penalty"] = float(repetition_penalty)
175
+ if diversity_penalty is not None:
176
+ kwargs["diversity_penalty"] = float(diversity_penalty)
177
+ try:
178
+ return model.generate(src, **kwargs)
179
+ except TypeError:
180
+ # Some model variants expose reduced generate() kwargs.
181
+ slim = {k: kwargs[k] for k in ["temperature", "top_k", "num_steps"] if k in kwargs}
182
+ try:
183
+ return model.generate(src, **slim)
184
+ except TypeError:
185
+ return model.generate(src)
186
+
187
+
188
+ def _decode_ids(tgt_tok, out_ids, src_text=None, inf_cfg=None):
189
+ ids = []
190
+ for x in out_ids[0].tolist():
191
+ # stop at PAD/SEP once decoding started
192
+ if x in (1, 4) and ids:
193
+ break
194
+ if x > 4:
195
+ ids.append(x)
196
+ if src_text is not None and inf_cfg is not None:
197
+ txt = _decode_with_cleanup(tgt_tok, ids, src_text, inf_cfg)
198
+ else:
199
+ txt = tgt_tok.decode(ids).strip()
200
+ return txt, ids
201
+
202
+
203
+ def _cer(a: str, b: str) -> float:
204
+ m, n = len(a), len(b)
205
+ if m == 0 and n == 0:
206
+ return 0.0
207
+ dp = list(range(n + 1))
208
+ for i in range(1, m + 1):
209
+ prev, dp[0] = dp[0], i
210
+ for j in range(1, n + 1):
211
+ tmp = dp[j]
212
+ dp[j] = prev if a[i-1] == b[j-1] else 1 + min(prev, dp[j], dp[j-1])
213
+ prev = tmp
214
+ return float(dp[n]) / max(1, m, n)
215
+
216
+
217
+ # ── Task 1 ────────────────────────────────────────────────────────────
218
+
219
def run_task1(model, src_tok, device):
    """Task 1: benchmark standard vs KV-cached generation.

    Times generation at several source lengths, estimates the encoder's share
    of per-step cost, and attempts a memory-usage comparison via a chain of
    fallbacks (CUDA peak stats -> RSS polling -> tracemalloc -> torch CPU
    profiler). Writes three PNG plots and a text report to OUTPUT_DIR.
    """
    print("\n" + "="*65)
    print(" TASK 1 — KV Cache Benchmark")
    print("="*65)
    src_vocab = model.model.src_embed.token_emb.weight.shape[0]
    src_lens = [16, 32, 64]
    n_runs = 3
    # Cached path is optional on some checkpoints; degrade gracefully.
    has_cached = hasattr(model, "generate_cached")
    if not has_cached:
        print(" Compatibility mode: generate_cached() unavailable; running standard benchmark only.")

    def _timeit(fn, runs=n_runs):
        # Mean wall-clock time of `fn` over `runs` invocations.
        vals = []
        for _ in range(runs):
            t0 = time.perf_counter()
            fn()
            vals.append(time.perf_counter() - t0)
        return float(np.mean(vals))

    def _trace_peak_bytes(fn, repeat=8):
        # Peak Python-allocator bytes while running `fn` repeatedly
        # (does not see native tensor allocations).
        gc.collect()
        tracemalloc.start()
        for _ in range(max(1, int(repeat))):
            fn()
        _, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        return int(peak)

    def _torch_cpu_mem_bytes(fn):
        # Sum of self CPU memory events from the torch profiler; 0 on any failure.
        try:
            from torch.profiler import profile, ProfilerActivity
            with profile(activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=False) as prof:
                fn()
            mem = 0
            for ev in prof.key_averages():
                try:
                    mem += max(0, int(getattr(ev, "self_cpu_memory_usage", 0)))
                except Exception:
                    pass
            return int(mem)
        except Exception:
            return 0

    results = {}
    for L in src_lens:
        # Random token ids in [5, vocab): ids 0-4 are reserved specials.
        src = torch.randint(5, src_vocab, (1, L), device=device)
        t_std = _timeit(lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40))

        if has_cached:
            t_cache = _timeit(
                lambda: model.generate_cached(
                    src, num_steps=64, temperature=0.8, top_k=40,
                    repetition_penalty=1.2, diversity_penalty=0.0
                )
            )
            speedup = t_std / max(t_cache, 1e-9)
        else:
            t_cache = t_std
            speedup = 1.0

        # Encoder cost estimate: one encode_source pass vs one cached step.
        if hasattr(model.model, "encode_source") and hasattr(model.model, "forward_cached"):
            memory, src_pad = model.model.encode_source(src)
            x = torch.full((1, L), model.model.mask_token_id, dtype=torch.long, device=device)
            t = torch.full((1,), max(0, model.model.scheduler.num_timesteps - 1), dtype=torch.long, device=device)
            t_enc = _timeit(lambda: model.model.encode_source(src))
            t_step = _timeit(lambda: model.model.forward_cached(memory, src_pad, x, t, x0_hint=None, inference_mode=True))
            encoder_pct = (t_enc / max(t_enc + t_step, 1e-9)) * 100.0
        else:
            encoder_pct = 0.0

        results[L] = dict(
            standard_s=t_std,
            cached_s=t_cache,
            speedup=speedup,
            encoder_pct=encoder_pct,
        )
        print(f" src_len={L:>3d} standard={t_std:.3f}s cached={t_cache:.3f}s speedup={speedup:.2f}x encoder%={encoder_pct:.1f}")

    # Memory profiling (GPU preferred, CPU/MPS fallback via process RSS delta).
    mem_note = "N/A"
    mem_red = None
    if torch.cuda.is_available() and str(device).startswith("cuda"):
        # CUDA branch: compare peak allocated bytes between the two paths.
        # NOTE(review): assumes generate_cached exists whenever CUDA is used —
        # would raise AttributeError otherwise; confirm with callers.
        L = 64
        src = torch.randint(5, src_vocab, (1, L), device=device)
        torch.cuda.reset_peak_memory_stats(device)
        _ = _generate_ids_compat(model, src, temperature=0.8, top_k=40)
        m_std = torch.cuda.max_memory_allocated(device)
        torch.cuda.reset_peak_memory_stats(device)
        _ = model.generate_cached(src, num_steps=64, temperature=0.8, top_k=40,
                                  repetition_penalty=1.2, diversity_penalty=0.0)
        m_cache = torch.cuda.max_memory_allocated(device)
        mem_red = 100.0 * (m_std - m_cache) / max(m_std, 1)
        mem_note = f"GPU peak alloc reduction: {mem_red:.1f}% @ src_len=64"
        print(f" Memory reduction: {mem_note}")
    elif has_cached and _process_mem_mb() > 0.0:
        # CPU/MPS branch: poll process RSS on a daemon thread while generating.
        L = 64
        src = torch.randint(5, src_vocab, (1, L), device=device)

        def _peak_rss_while(fn, poll_s=0.01):
            # Returns (baseline_MB, peak_MB, delta_MB) observed while fn runs.
            done = {"v": False}
            peak = {"v": _process_mem_mb()}

            def _poll():
                while not done["v"]:
                    peak["v"] = max(peak["v"], _process_mem_mb())
                    time.sleep(poll_s)
            th = threading.Thread(target=_poll, daemon=True)
            gc.collect()
            base = _process_mem_mb()
            th.start()
            try:
                fn()
            finally:
                done["v"] = True
                th.join(timeout=0.1)
            gc.collect()
            return base, peak["v"], max(0.0, peak["v"] - base)

        b_std, p_std, d_std = _peak_rss_while(
            lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40)
        )
        b_c, p_c, d_c = _peak_rss_while(
            lambda: model.generate_cached(
                src, num_steps=64, temperature=0.8, top_k=40,
                repetition_penalty=1.2, diversity_penalty=0.0
            )
        )
        if d_std > 0.0:
            mem_red = 100.0 * (d_std - d_c) / d_std
            mem_note = (
                f"RSS peak reduction: {mem_red:.1f}% @ src_len=64 "
                f"(std_peak={p_std:.1f}MB, cache_peak={p_c:.1f}MB)"
            )
        else:
            # Secondary fallback: Python allocator peak (always available).
            peak_std = _trace_peak_bytes(
                lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40), repeat=10
            )
            peak_cache = _trace_peak_bytes(
                lambda: model.generate_cached(src, num_steps=64, temperature=0.8, top_k=40,
                                              repetition_penalty=1.2, diversity_penalty=0.0),
                repeat=10
            )
            if peak_std >= 256 * 1024:
                mem_red = 100.0 * (peak_std - peak_cache) / peak_std
                mem_note = (
                    f"Py alloc peak reduction: {mem_red:.1f}% @ src_len=64 "
                    f"(std={peak_std/1024**2:.1f}MB, cache={peak_cache/1024**2:.1f}MB)"
                )
            else:
                # Tertiary fallback: torch CPU profiler memory events.
                cpu_std = _torch_cpu_mem_bytes(
                    lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40)
                )
                cpu_cache = _torch_cpu_mem_bytes(
                    lambda: model.generate_cached(src, num_steps=64, temperature=0.8, top_k=40,
                                                  repetition_penalty=1.2, diversity_penalty=0.0)
                )
                if cpu_std > 0:
                    mem_red = 100.0 * (cpu_std - cpu_cache) / max(cpu_std, 1)
                    mem_note = (
                        f"Torch CPU mem-event reduction: {mem_red:.1f}% @ src_len=64 "
                        f"(std={cpu_std/1024**2:.1f}MB, cache={cpu_cache/1024**2:.1f}MB)"
                    )
                else:
                    mem_note = "Memory estimate unavailable (RSS/tracemalloc/torch-profiler flat)"
        print(f" Memory reduction: {mem_note}")
    elif has_cached:
        # Final fallback (CPU-safe): Python allocation peak via tracemalloc.
        # This does not include all native tensor allocator memory, but still
        # gives a consistent relative signal when psutil/CUDA stats are absent.
        L = 64
        src = torch.randint(5, src_vocab, (1, L), device=device)
        peak_std = _trace_peak_bytes(
            lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40), repeat=10
        )
        peak_cache = _trace_peak_bytes(
            lambda: model.generate_cached(src, num_steps=64, temperature=0.8, top_k=40,
                                          repetition_penalty=1.2, diversity_penalty=0.0),
            repeat=10
        )
        # Ignore extremely small peaks; they are noise for tensor-heavy paths.
        if peak_std >= 256 * 1024:
            mem_red = 100.0 * (peak_std - peak_cache) / peak_std
            mem_note = (
                f"Py alloc peak reduction: {mem_red:.1f}% @ src_len=64 "
                f"(std={peak_std/1024**2:.1f}MB, cache={peak_cache/1024**2:.1f}MB)"
            )
        else:
            cpu_std = _torch_cpu_mem_bytes(
                lambda: _generate_ids_compat(model, src, temperature=0.8, top_k=40)
            )
            cpu_cache = _torch_cpu_mem_bytes(
                lambda: model.generate_cached(src, num_steps=64, temperature=0.8, top_k=40,
                                              repetition_penalty=1.2, diversity_penalty=0.0)
            )
            if cpu_std > 0:
                mem_red = 100.0 * (cpu_std - cpu_cache) / max(cpu_std, 1)
                mem_note = (
                    f"Torch CPU mem-event reduction: {mem_red:.1f}% @ src_len=64 "
                    f"(std={cpu_std/1024**2:.1f}MB, cache={cpu_cache/1024**2:.1f}MB)"
                )
            else:
                mem_note = "Py alloc peak too small/noisy to estimate (no psutil/CUDA profiler)"
        print(f" Memory reduction: {mem_note}")
    else:
        mem_note = "Profiler unavailable (cached path missing)"

    # Subtask graphs
    lens = sorted(results.keys())
    std_vals = [results[L]["standard_s"] for L in lens]
    cache_vals = [results[L]["cached_s"] for L in lens]
    speed_vals = [results[L]["speedup"] for L in lens]
    enc_vals = [results[L]["encoder_pct"] for L in lens]

    plt.figure(figsize=(7, 4))
    plt.plot(lens, std_vals, marker="o", label="standard")
    plt.plot(lens, cache_vals, marker="o", label="cached")
    plt.xlabel("Source length")
    plt.ylabel("Time (s)")
    plt.title("Task1: Generation Time (Standard vs Cached)")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task1_time_comparison.png"), dpi=150, bbox_inches="tight")
    plt.close()

    plt.figure(figsize=(7, 4))
    plt.plot(lens, speed_vals, marker="o")
    plt.xlabel("Source length")
    plt.ylabel("Speedup (x)")
    plt.title("Task1: KV-Cache Speedup")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task1_speedup.png"), dpi=150, bbox_inches="tight")
    plt.close()

    plt.figure(figsize=(7, 4))
    plt.plot(lens, enc_vals, marker="o")
    plt.xlabel("Source length")
    plt.ylabel("Encoder cost (%)")
    plt.title("Task1: Encoder Cost Share")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task1_encoder_cost.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Plain-text summary report.
    path = os.path.join(OUTPUT_DIR, "task1_kv_cache.txt")
    with open(path, "w") as f:
        f.write("TASK 1 — KV CACHE BENCHMARK\n" + "="*40 + "\n\n")
        f.write(f"has_generate_cached={has_cached}\n")
        f.write(f"memory_profile={mem_note}\n\n")
        f.write(f"{'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
                f"{'speedup':>8} {'encoder%':>9}\n")
        for src_len, r in results.items():
            f.write(f"{src_len:>8} {r['standard_s']:>12.3f} {r['cached_s']:>10.3f} "
                    f"{r['speedup']:>7.2f}x {r['encoder_pct']:>8.1f}%\n")
        f.write("\nSaved graphs:\n")
        f.write(" - task1_time_comparison.png\n")
        f.write(" - task1_speedup.png\n")
        f.write(" - task1_encoder_cost.png\n")
    print(f" Saved: {path}")
478
+
479
+
480
+ # ── Task 2 ────────────────────────────────────────────────────────────
481
+
482
+ def run_task2(model, src_tok, tgt_tok, device, input_text, cfg, corpus_inputs=None):
483
+ print("\n" + "="*65)
484
+ print(" TASK 2 — Attention Visualization + Semantic Drift")
485
+ print("="*65)
486
+ print(f" Input: {input_text}")
487
+ if not hasattr(model.model, 'encode_source'):
488
+ print(" Compatibility mode: attention hooks unavailable; running semantic-drift-only analysis.")
489
+ src_ids = src_tok.encode(input_text)
490
+ src = torch.tensor([src_ids], dtype=torch.long, device=device)
491
+ # Keep steps <= scheduler horizon for this checkpoint to avoid backend aborts.
492
+ t_sched = int(getattr(getattr(model.model, "scheduler", object()), "num_timesteps", 64))
493
+ # Stability guard for some checkpoints/backends: keep sweep moderate.
494
+ t_max = min(t_sched, 64)
495
+ candidates = [t_max, 48, 32, 24, 16, 8, 4, 1]
496
+ step_list = []
497
+ seen = set()
498
+ for s in candidates:
499
+ s = max(1, min(int(s), t_max))
500
+ if s not in seen:
501
+ step_list.append(s)
502
+ seen.add(s)
503
+ outs = {}
504
+ for s in step_list:
505
+ out = _generate_ids_compat(model, src, num_steps=s, temperature=0.8, top_k=40)
506
+ txt, _ = _decode_ids(
507
+ tgt_tok, out,
508
+ src_text=input_text,
509
+ inf_cfg=cfg.get("inference", {"temperature": 0.8, "top_k": 40})
510
+ )
511
+ outs[s] = txt
512
+ final = outs[1]
513
+ drift = [(_cer(outs[s], final), s) for s in step_list]
514
+ # Plot drift
515
+ xs = [s for _, s in drift]
516
+ ys = [c for c, _ in drift]
517
+ plt.figure(figsize=(8, 4))
518
+ plt.plot(xs, ys, marker='o')
519
+ plt.gca().invert_xaxis()
520
+ plt.xlabel("Generation steps")
521
+ plt.ylabel("CER to 1-step output")
522
+ plt.title("Task2 Semantic Drift (Compatibility Mode)")
523
+ plt.tight_layout()
524
+ plot_path = os.path.join(OUTPUT_DIR, "task2_semantic_drift.png")
525
+ plt.savefig(plot_path, dpi=150, bbox_inches="tight")
526
+ plt.close()
527
+ report = os.path.join(OUTPUT_DIR, "task2_report.txt")
528
+ with open(report, "w", encoding="utf-8") as f:
529
+ f.write("TASK 2 — COMPATIBILITY REPORT\n")
530
+ f.write("="*40 + "\n")
531
+ f.write("Cross-attention capture unavailable for this checkpoint.\n")
532
+ f.write(f"Input: {input_text}\n")
533
+ f.write(f"Reference final (1 step): {final}\n\n")
534
+ for cer_v, s in drift:
535
+ f.write(f"steps={s:>3d} CER_to_final={cer_v:.4f} output={outs[s][:120]}\n")
536
+ print(f" Output(final@1): {final}")
537
+ print(f" Report: {report}")
538
+ print(f" Saved: {plot_path}")
539
+ return
540
+
541
+ src_ids = src_tok.encode(input_text)
542
+ src_tensor = torch.tensor([src_ids], dtype=torch.long, device=device)
543
+
544
+ from analysis.attention_viz import (
545
+ AttentionCapture,
546
+ compute_trajectory_metrics,
547
+ analyze_token_stability,
548
+ tfidf_attention_correlation,
549
+ )
550
+
551
+ # Attention capture
552
+ print(" Capturing attention weights...")
553
+ capturer = AttentionCapture(model)
554
+ step_weights, step_outputs_ids = capturer.run(src_tensor)
555
+
556
+ def _decode_tensor_ids(t):
557
+ out = []
558
+ for x in t[0].tolist():
559
+ if x in (1, 4) and out:
560
+ break
561
+ if x > 4:
562
+ out.append(x)
563
+ raw_txt = tgt_tok.decode(out).strip()
564
+ clean_txt = _decode_with_cleanup(
565
+ tgt_tok, out, input_text, cfg.get("inference", {"temperature": 0.8, "top_k": 40})
566
+ )
567
+ return raw_txt, clean_txt, out
568
+
569
+ decoded = {}
570
+ decoded_raw = {}
571
+ for t_val, ids_t in step_outputs_ids.items():
572
+ raw_txt, clean_txt, ids = _decode_tensor_ids(ids_t)
573
+ decoded_raw[t_val] = (raw_txt, ids)
574
+ decoded[t_val] = (clean_txt, ids)
575
+ final_step = min(decoded.keys())
576
+ final_out, final_ids = decoded[final_step]
577
+ final_out_raw = decoded_raw[final_step][0]
578
+ src_labels = []
579
+ for sid in src_ids[:20]:
580
+ tok = src_tok.decode([sid]).strip()
581
+ src_labels.append(tok if tok else f"id{sid}")
582
+ tgt_labels = [f"y{i}" for i in range(min(20, len(final_ids)))]
583
+ print(f" Output: {final_out}")
584
+
585
+ # Heatmap t=max, layer 0
586
+ first_t = max(step_weights.keys())
587
+ w_first = step_weights[first_t][0][0]
588
+ w0 = step_weights[0][0][0]
589
+ n_src = min(len(src_labels), w_first.shape[1], 20)
590
+ n_tgt = min(len(tgt_labels), w_first.shape[0], 20)
591
+ plt.figure(figsize=(max(8, n_src * 0.35), max(6, n_tgt * 0.3)))
592
+ plt.imshow(w_first[:n_tgt, :n_src], aspect="auto", cmap="YlOrRd")
593
+ plt.xticks(range(n_src), src_labels[:n_src], rotation=45, ha="right", fontsize=8)
594
+ plt.yticks(range(n_tgt), tgt_labels[:n_tgt], fontsize=8)
595
+ plt.title(f"Attention t={first_t} Layer 0")
596
+ plt.tight_layout()
597
+ plt.savefig(os.path.join(OUTPUT_DIR, f"task2_attn_t{first_t}.png"), dpi=150, bbox_inches="tight")
598
+ plt.close()
599
+
600
+ plt.figure(figsize=(max(8, n_src * 0.35), max(6, n_tgt * 0.3)))
601
+ plt.imshow(w0[:n_tgt, :n_src], aspect="auto", cmap="YlOrRd")
602
+ plt.xticks(range(n_src), src_labels[:n_src], rotation=45, ha="right", fontsize=8)
603
+ plt.yticks(range(n_tgt), tgt_labels[:n_tgt], fontsize=8)
604
+ plt.title("Attention t=0 Layer 0")
605
+ plt.tight_layout()
606
+ plt.savefig(os.path.join(OUTPUT_DIR, "task2_attn_t0.png"), dpi=150, bbox_inches="tight")
607
+ plt.close()
608
+
609
+ # All layers at t=0
610
+ layers = step_weights[0]
611
+ n_layers = len(layers)
612
+ n_cols = min(4, n_layers)
613
+ n_rows = (n_layers + n_cols - 1) // n_cols
614
+ fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3.2))
615
+ axes = np.array(axes).reshape(-1)
616
+ for i, layer_w in enumerate(layers):
617
+ ax = axes[i]
618
+ w = layer_w[0][:n_tgt, :n_src]
619
+ ax.imshow(w, aspect="auto", cmap="YlOrRd")
620
+ ax.set_title(f"Layer {i}", fontsize=9)
621
+ ax.set_xticks([])
622
+ ax.set_yticks([])
623
+ for i in range(n_layers, len(axes)):
624
+ axes[i].axis("off")
625
+ plt.tight_layout()
626
+ plt.savefig(os.path.join(OUTPUT_DIR, "task2_all_layers_t0.png"), dpi=150, bbox_inches="tight")
627
+ plt.close()
628
+
629
+ # Attention evolution for src[0] -> tgt[0]
630
+ t_vals_desc = sorted(step_weights.keys(), reverse=True)
631
+ evo = []
632
+ for t_val in t_vals_desc:
633
+ w = step_weights[t_val][0][0]
634
+ evo.append(float(w[0, 0]) if w.shape[0] > 0 and w.shape[1] > 0 else 0.0)
635
+ plt.figure(figsize=(10, 3.5))
636
+ plt.plot(range(len(t_vals_desc)), evo, marker="o")
637
+ plt.xlabel("Captured step index (T→0)")
638
+ plt.ylabel("Attention weight")
639
+ plt.title("Attention Evolution (src0→tgt0)")
640
+ plt.tight_layout()
641
+ plt.savefig(os.path.join(OUTPUT_DIR, "task2_attn_evolution.png"), dpi=150, bbox_inches="tight")
642
+ plt.close()
643
+
644
+ # Drift (CER to final across steps) on RAW decoded trajectory to expose true diffusion.
645
+ t_vals = sorted(decoded.keys(), reverse=True)
646
+ cer_vals = [_cer(decoded_raw[t][0], final_out_raw) for t in t_vals]
647
+ plt.figure(figsize=(8, 4))
648
+ plt.plot(t_vals, cer_vals, marker="o")
649
+ plt.gca().invert_xaxis()
650
+ plt.xlabel("Diffusion step")
651
+ plt.ylabel("CER to final")
652
+ plt.title("Task2 Semantic Drift")
653
+ plt.tight_layout()
654
+ plt.savefig(os.path.join(OUTPUT_DIR, "task2_semantic_drift.png"), dpi=150, bbox_inches="tight")
655
+ plt.close()
656
+
657
+ # Source alignment proxy (avg attention on source positions at t=0, last layer)
658
+ last_layer_t0 = step_weights[0][-1][0]
659
+ src_align = last_layer_t0.mean(axis=0)[:n_src]
660
+ plt.figure(figsize=(8, 3))
661
+ plt.bar(np.arange(len(src_align)), src_align)
662
+ plt.xticks(range(n_src), src_labels[:n_src], rotation=45, ha="right", fontsize=8)
663
+ plt.title("Source Alignment Importance (t=0, last layer)")
664
+ plt.tight_layout()
665
+ plt.savefig(os.path.join(OUTPUT_DIR, "task2_source_alignment.png"), dpi=150, bbox_inches="tight")
666
+ plt.close()
667
+
668
+ stability = analyze_token_stability(step_weights)
669
+ n_locked = sum(1 for v in stability.values() if v == "LOCKED")
670
+ n_flex = sum(1 for v in stability.values() if v == "FLEXIBLE")
671
+ tfidf_info = tfidf_attention_correlation(input_text, step_weights, corpus_texts=corpus_inputs)
672
+ tfidf_corr = tfidf_info.get("corr")
673
+ tfidf_status = tfidf_info.get("status", "UNKNOWN")
674
+ traj = compute_trajectory_metrics(
675
+ step_outputs_ids,
676
+ tgt_tok,
677
+ reference_text=_iast_to_deva(input_text),
678
+ )
679
+ # Keep trajectory semantic scoring on raw decoded text to avoid masking drift.
680
+ ref_text = _iast_to_deva(input_text)
681
+ for row in traj:
682
+ t_cur = row["step"]
683
+ raw_txt = decoded_raw.get(t_cur, ("", []))[0]
684
+ if raw_txt:
685
+ sim = max(0.0, 1.0 - _cer(raw_txt, ref_text))
686
+ row["text"] = raw_txt
687
+ row["bert"] = sim
688
+ row["drift"] = 1.0 - sim
689
+
690
+ # TF-IDF vs attention graph (subtask visualization)
691
+ tfidf_vec = np.asarray(tfidf_info.get("tfidf_scores", []), dtype=np.float32)
692
+ attn_vec = np.asarray(tfidf_info.get("attn_scores", []), dtype=np.float32)
693
+ labels = list(tfidf_info.get("tokens", []))
694
+ m = min(len(tfidf_vec), len(attn_vec), len(labels), 20)
695
+ if m > 0:
696
+ x = np.arange(m)
697
+ plt.figure(figsize=(8, 3.5))
698
+ tf_part = tfidf_vec[:m]
699
+ at_part = attn_vec[:m]
700
+ tf_norm = tf_part / (np.max(np.abs(tf_part)) + 1e-9)
701
+ at_norm = at_part / (np.max(np.abs(at_part)) + 1e-9)
702
+ w = 0.4
703
+ plt.bar(x - w/2, tf_norm, width=w, label="tfidf(norm)")
704
+ plt.bar(x + w/2, at_norm, width=w, label="attn(norm)")
705
+ plt.xlabel("Source token")
706
+ plt.ylabel("Normalized score")
707
+ plt.title("Task2: TF-IDF vs Attention Stability")
708
+ plt.xticks(x, labels[:m], rotation=45, ha="right", fontsize=8)
709
+ plt.legend()
710
+ plt.tight_layout()
711
+ plt.savefig(os.path.join(OUTPUT_DIR, "task2_tfidf_vs_attention.png"), dpi=150, bbox_inches="tight")
712
+ plt.close()
713
+
714
+ lock_in_t = next((t for t, c in zip(t_vals[::-1], cer_vals[::-1]) if c <= 0.05), t_vals[-1])
715
+ if tfidf_corr is not None and abs(float(tfidf_corr)) < 0.10:
716
+ tfidf_status = "WEAK"
717
+ has_semantic = any(float(r.get("bert", 0.0)) > 0.05 for r in traj)
718
+ # Degeneracy score on final output
719
+ toks = [t for t in final_out.split() if t]
720
+ uniq_ratio = len(set(toks)) / max(1, len(toks))
721
+ degenerate = (len(toks) >= 8 and uniq_ratio < 0.35)
722
+
723
+ # Small multi-sample stability check (prevents overclaim from one example)
724
+ multi_scores = []
725
+ if corpus_inputs:
726
+ sample_texts = [s for s in corpus_inputs[:8] if isinstance(s, str) and s.strip()]
727
+ for txt in sample_texts:
728
+ src_i = torch.tensor([src_tok.encode(txt)], dtype=torch.long, device=device)
729
+ out_i = _generate_ids_compat(model, src_i, num_steps=min(16, cfg.get("inference", {}).get("num_steps", 16)),
730
+ temperature=0.8, top_k=40)
731
+ pred_i, _ = _decode_ids(tgt_tok, out_i)
732
+ multi_scores.append(max(0.0, 1.0 - _cer(pred_i, _iast_to_deva(txt))))
733
+ multi_sem = float(np.mean(multi_scores)) if multi_scores else 0.0
734
+
735
+ quality_status = (
736
+ "VALID"
737
+ if len(final_out.strip()) > 0 and n_flex + n_locked > 0 and has_semantic and not degenerate and multi_sem >= 0.05
738
+ else "WEAK"
739
+ )
740
+ report = os.path.join(OUTPUT_DIR, "task2_report.txt")
741
+ with open(report, "w", encoding="utf-8") as f:
742
+ f.write("TASK 2 — ATTENTION + DRIFT REPORT\n" + "=" * 50 + "\n\n")
743
+ f.write(f"Input : {input_text}\n")
744
+ f.write(f"Output: {final_out}\n\n")
745
+ f.write(f"Captured steps: {len(t_vals)}\n")
746
+ f.write(f"Analysis quality: {quality_status}\n")
747
+ f.write(f"Final output uniq-ratio: {uniq_ratio:.3f}\n")
748
+ f.write(f"Degenerate output: {'YES' if degenerate else 'NO'}\n")
749
+ f.write(f"Multi-sample semantic score (n<={len(multi_scores)}): {multi_sem:.4f}\n")
750
+ f.write(f"Lock-in step (CER<=0.05): t={lock_in_t}\n")
751
+ f.write(f"Locked tokens: {n_locked} Flexible tokens: {n_flex}\n")
752
+ corr_txt = f"{tfidf_corr:.4f}" if tfidf_corr is not None else "N/A"
753
+ f.write(f"TF-IDF vs attention stability corr: {corr_txt}\n")
754
+ f.write(f"TF-IDF status: {tfidf_status}\n\n")
755
+ f.write("Saved graphs:\n")
756
+ f.write(" - task2_attn_t*.png / task2_all_layers_t0.png\n")
757
+ f.write(" - task2_attn_evolution.png\n")
758
+ f.write(" - task2_semantic_drift.png\n")
759
+ f.write(" - task2_source_alignment.png\n")
760
+ f.write(" - task2_tfidf_vs_attention.png\n\n")
761
+ f.write("Step trajectory (first 10 rows)\n")
762
+ f.write("-" * 60 + "\n")
763
+ for row in traj[:10]:
764
+ f.write(f"t={row['step']:>3d} bert={row['bert']:.4f} drift={row['drift']:.4f} text={row['text'][:60]}\n")
765
+
766
+ print(f" Lock-in timestep: t={lock_in_t}")
767
+ print(f" Locked/Flexible: {n_locked}/{n_flex}")
768
+ corr_txt = f"{tfidf_corr:.4f}" if tfidf_corr is not None else "N/A"
769
+ print(f" TF-IDF corr: {corr_txt} ({tfidf_status})")
770
+ print(f" Report: {report}")
771
+
772
+
773
+ # ── Task 3 ────────────────────────────────────────────────────────────
774
+
775
def run_task3(model, src_tok, tgt_tok, device, src_list, ref_list, n_samples=500):
    """Task 3 — concept vectors + PCA steering analysis.

    Projects captured hidden states into a PCA "concept space", derives a
    diversity steering direction, sweeps steering strengths (alphas) over a
    few seed inputs, and writes plots plus ``task3_report.txt`` to OUTPUT_DIR.

    Falls back to a lightweight output-token proxy when the wrapped model
    does not expose ``encode_source`` (compatibility mode).

    Args:
        model: wrapped diffusion model; ``model.model.encode_source`` gates
            which analysis path is taken.
        src_tok, tgt_tok: source/target tokenizers.
        device: torch device for generation.
        src_list: pre-encoded source tensors (one per example — assumed
            batch-first; TODO confirm against load_val_data).
        ref_list: reference texts. NOTE(review): not used anywhere in this
            function — kept only for signature symmetry with other tasks.
        n_samples: cap on examples used for hidden-state collection
            (full path only; compatibility mode uses its own cap of 60).
    """
    print("\n" + "="*65)
    print(" TASK 3 — Concept Vectors + PCA Steering")
    print("="*65)
    if not hasattr(model.model, 'encode_source'):
        print(" Compatibility mode: using output-token statistics for PCA concept proxy.")
        # Keep compatibility run lightweight/stable on constrained backends.
        n = min(60, len(src_list))
        feats, lens = [], []
        for i, src in enumerate(src_list[:n]):
            out = _generate_ids_compat(model, src.to(device), num_steps=8, temperature=0.8, top_k=40)
            txt, ids = _decode_ids(tgt_tok, out)
            # Fixed-width feature vector: first 64 token ids, zero-padded.
            arr = np.array(ids[:64] + [0] * max(0, 64 - len(ids[:64])), dtype=np.float32)
            feats.append(arr)
            lens.append(len(txt))
        from sklearn.decomposition import PCA
        X = np.stack(feats)
        # n_components bounded by both sample count and feature width.
        pca = PCA(n_components=min(10, X.shape[0]-1, X.shape[1]))
        Z = pca.fit_transform(X)
        plt.figure(figsize=(6, 5))
        # Degenerate 1-component case: plot against a flat y axis.
        sc = plt.scatter(Z[:, 0], Z[:, 1] if Z.shape[1] > 1 else np.zeros_like(Z[:, 0]),
                         c=lens, cmap="viridis", s=14)
        plt.colorbar(sc, label="Output length")
        plt.title("Task3 Concept Proxy Space (Compatibility Mode)")
        plt.tight_layout()
        img = os.path.join(OUTPUT_DIR, "task3_concept_space.png")
        plt.savefig(img, dpi=150, bbox_inches="tight")
        plt.close()
        rep = os.path.join(OUTPUT_DIR, "task3_report.txt")
        # Correlation of PC1 with output length; needs >2 points to be defined.
        corr = float(np.corrcoef(Z[:, 0], np.array(lens))[0, 1]) if len(lens) > 2 else 0.0
        with open(rep, "w", encoding="utf-8") as f:
            f.write("TASK 3 — COMPATIBILITY REPORT\n")
            f.write("="*40 + "\n")
            f.write("Hidden-state capture unavailable; used output-token vector proxy.\n")
            f.write(f"Samples: {n}\n")
            f.write(f"PC1-length correlation: {corr:.4f}\n")
        print(f" Saved: {img}")
        print(f" Report: {rep}")
        return

    from analysis.concept_vectors import (
        collect_hidden_states, fit_pca, find_diversity_direction, generate_diversity_spectrum
    )

    # Collect hidden states from val set
    n = min(max(1, int(n_samples)), len(src_list))
    print(f" Collecting hidden states from {n} examples...")
    hidden, texts, lengths = collect_hidden_states(
        model, src_list[:n], tgt_tok, t_capture=0, max_samples=n
    )

    # Fit PCA + find diversity direction
    pca = fit_pca(hidden, n_components=min(50, n-1))
    direction = find_diversity_direction(hidden, lengths, pca)
    proj = pca.transform(hidden)
    # Guard both the small-sample case and NaN from zero-variance inputs.
    corr = float(np.corrcoef(proj[:, 0], np.array(lengths))[0, 1]) if len(lengths) > 2 else 0.0
    if not np.isfinite(corr):
        corr = 0.0
    # NOTE(review): best_pc is hard-coded to 0; the report always labels PC0
    # as the diversity PC even though only its correlation is computed.
    best_pc = 0

    # Plot concept space
    plt.figure(figsize=(8, 6))
    sc = plt.scatter(proj[:, 0], proj[:, 1] if proj.shape[1] > 1 else np.zeros_like(proj[:, 0]),
                     c=lengths, cmap="viridis", s=14)
    plt.colorbar(sc, label="Output diversity proxy")
    plt.title("Task3 Concept Space")
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task3_concept_space.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Subtask graph: explained variance by PCA components
    ev = pca.explained_variance_ratio_
    k = min(20, len(ev))
    plt.figure(figsize=(8, 3.5))
    plt.bar(np.arange(k), ev[:k])
    plt.xlabel("PC index")
    plt.ylabel("Explained variance ratio")
    plt.title("Task3: PCA Explained Variance (Top Components)")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task3_pca_explained_variance.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Generate diversity spectrum on multiple seeds for more stable conclusions
    seed_k = min(5, len(src_list))
    uniq_list = []
    sem_list = []
    all_spectra = []
    for i in range(seed_k):
        src_i = src_list[i]
        spec_i = generate_diversity_spectrum(
            model, src_i.to(device), direction, tgt_tok,
            alphas=[-2.0, -1.0, 0.0, 1.0, 2.0]
        )
        all_spectra.append(spec_i)
        spec_items = sorted(spec_i.items())
        spec_texts = [t for _, t in spec_items]
        # Fraction of distinct outputs across the alpha sweep.
        uniq_list.append(len(set(spec_texts)) / max(1, len(spec_texts)))
        # Pivot = alpha=0.0 output (middle of the 5-alpha sweep) when available.
        pivot = spec_texts[2] if len(spec_texts) >= 3 else (spec_texts[0] if spec_texts else "")
        sims = [SequenceMatcher(None, txt, pivot).ratio() for txt in spec_texts if txt]
        sem_list.append(float(np.mean(sims)) if sims else 0.0)
    uniq_ratio = float(np.mean(uniq_list)) if uniq_list else 0.0
    semantic_stability = float(np.mean(sem_list)) if sem_list else 0.0
    # Steering deemed valid only if PC correlates, outputs actually vary,
    # and variants stay similar enough to the unsteered output.
    steering_valid = (abs(corr) >= 0.20) and (uniq_ratio >= 0.55) and (semantic_stability >= 0.40)
    # use first seed spectrum for visualization table
    spectrum = all_spectra[0] if all_spectra else {}

    # Subtask graph: alpha vs decoded length
    a_vals = sorted(spectrum.keys())
    l_vals = [len(spectrum[a]) for a in a_vals] if spectrum else []
    plt.figure(figsize=(7, 3.5))
    plt.plot(a_vals, l_vals, marker="o")
    plt.xlabel("Steering alpha")
    plt.ylabel("Output length")
    plt.title("Task3: Diversity Steering Curve")
    plt.tight_layout()
    plt.savefig(os.path.join(OUTPUT_DIR, "task3_diversity_curve.png"), dpi=150, bbox_inches="tight")
    plt.close()

    # Save diversity direction + results
    np.save(os.path.join(OUTPUT_DIR, "task3_diversity_direction.npy"), direction)

    report = os.path.join(OUTPUT_DIR, "task3_report.txt")
    with open(report, "w", encoding="utf-8") as f:
        f.write("TASK 3 — CONCEPT VECTORS + PCA STEERING\n" + "="*50 + "\n\n")
        f.write(f"PCA: {pca.n_components_} components, "
                f"{pca.explained_variance_ratio_.sum()*100:.1f}% variance\n")
        f.write(f"Diversity PC: {best_pc} (|r|={corr:.3f} with diversity proxy)\n\n")
        f.write(f"Direction validity: {'VALID' if steering_valid else 'WEAK'}\n")
        f.write(f"Spectrum unique ratio (mean over {seed_k} seeds): {uniq_ratio:.3f}\n")
        f.write(f"Spectrum semantic stability (mean over {seed_k} seeds): {semantic_stability:.3f}\n\n")
        f.write("Saved graphs:\n")
        f.write(" - task3_concept_space.png\n")
        f.write(" - task3_pca_explained_variance.png\n")
        f.write(" - task3_diversity_curve.png\n\n")
        f.write("Diversity spectrum:\n")
        for alpha, text in sorted(spectrum.items()):
            f.write(f" alpha={alpha:+.1f} → {text}\n")
    print(f" Report: {report}")
915
+
916
+
917
+ # ── Task 4 ────────────────────────────────────────────────────────────
918
+
919
def run_task4(phase, model, src_tok, tgt_tok, device, cfg,
              src_list, ref_list, n_samples=200):
    """Task 4 — diffusion-step ablation driver.

    Dispatches to whichever API ``analysis.step_ablation`` exposes:
    a legacy trio (``generate_ablation_configs`` / ``run_ablation_analysis``
    / ``plot_ablation_3d``) or a newer single ``run_task4`` entry point.

    Args:
        phase: "generate_configs" (writes training configs, legacy API only)
            or "analyze" (evaluates trained checkpoints under
            ``ablation_results/T*/best_model.pt``).
        model: the currently loaded model (used only for the optional
            adversarial test at the end).
        src_tok, tgt_tok: source/target tokenizers.
        device: torch device.
        cfg: base config dict; deep-copied and patched per T for the new API.
        src_list, ref_list: validation sources/references.
        n_samples: evaluation sample cap for the new API path.
            NOTE(review): the legacy analyze path hard-caps at 200 and
            ignores this parameter.
    """
    print("\n" + "="*65)
    print(f" TASK 4 — Step Ablation (phase={phase})")
    print("="*65)

    import analysis.step_ablation as step_ablation

    # Legacy API
    has_legacy = all(hasattr(step_ablation, fn) for fn in [
        "generate_ablation_configs", "run_ablation_analysis", "plot_ablation_3d"
    ])

    # New API
    has_new = hasattr(step_ablation, "run_task4")

    if phase == "generate_configs":
        if has_legacy:
            print(" Generating ablation configs...")
            step_ablation.generate_ablation_configs(output_dir="ablation_configs")
            print("\n NEXT STEPS:")
            print(" 1. bash ablation_configs/train_all.sh")
            print(" 2. python analysis/run_analysis.py --task 4 --phase analyze")
            return
        print(" This step_ablation version does not expose config generation helpers.")
        print(" Use your latest ablation training script/config pipeline directly.")
        return

    if phase == "analyze":
        # Only T values whose trained checkpoint actually exists on disk.
        existing = [T for T in [4, 8, 16, 32, 64]
                    if os.path.exists(f"ablation_results/T{T}/best_model.pt")]
        # Optional env override to analyze a single T value.
        only_t = os.environ.get("TASK4_ONLY_T")
        if only_t and only_t.isdigit():
            t_req = int(only_t)
            existing = [T for T in existing if T == t_req]
        if not existing:
            print(" No ablation models found at ablation_results/T*/best_model.pt")
            return
        print(f" Found models for T={existing}")

        if has_legacy:
            results = step_ablation.run_ablation_analysis(
                ablation_dir="ablation_results", base_cfg=cfg,
                src_list=src_list[:200], ref_list=ref_list[:200],
                tgt_tokenizer=tgt_tok, device=device,
                output_dir=OUTPUT_DIR)
            step_ablation.plot_ablation_3d(
                results, save_path=os.path.join(OUTPUT_DIR, "task4_ablation_3d.png"))
        elif has_new:
            from inference import load_model as _load_model
            # Load one model per T, patching both training-time and
            # inference-time step counts so the checkpoint matches its config.
            models = {}
            for T in existing:
                ckpt = f"ablation_results/T{T}/best_model.pt"
                cfg_t = copy.deepcopy(cfg)
                cfg_t["model"]["diffusion_steps"] = T
                cfg_t["inference"]["num_steps"] = T
                m_t, _ = _load_model(ckpt, cfg_t, device)
                m_t.eval()
                models[T] = m_t
            knee_t = step_ablation.run_task4(
                models, src_list[:n_samples], ref_list[:n_samples], tgt_tok,
                output_dir=OUTPUT_DIR, n_samples=n_samples)
            print(f" New pipeline suggested optimal T={knee_t}")
        else:
            print(" Unsupported step_ablation API; please sync analysis/step_ablation.py")
            return

        # Optional adversarial robustness (legacy helper only)
        if hasattr(step_ablation, "run_adversarial_test"):
            print("\n Running adversarial robustness test...")
            # Re-decode raw source text, dropping special token ids (<=4).
            inp_texts = [src_tok.decode([x for x in s[0].tolist() if x > 4])
                         for s in src_list[:50]]
            step_ablation.run_adversarial_test(
                model, src_tok, tgt_tok,
                test_inputs=inp_texts, test_refs=ref_list[:50],
                device=device, output_dir=OUTPUT_DIR)
995
+
996
+
997
+ # ── Task 5 ────────────────────────────────────────────────────────────
998
+
999
def run_task5(model, src_tok, tgt_tok, device, cfg, src_list, ref_list, task5_samples=500):
    """Task 5 — classifier-free guidance analysis.

    Three paths, tried in order:
      1. Compatibility sweep when ``model.model.encode_source`` is missing:
         λ is mapped onto a repetition penalty and a CER/diversity tradeoff
         curve is produced with plain decoding.
      2. The same style of sweep when ``analysis.quality_classifier`` fails
         to import (API mismatch), with slightly different caps/penalty.
      3. Full path: collect (hidden, quality) pairs, train/load a quality
         classifier, sweep guidance scales, and pick the best λ.

    Outputs (plots, cached data/classifier, ``task5_report.txt``) go to
    OUTPUT_DIR.

    Args:
        model: wrapped diffusion model.
        src_tok, tgt_tok: source/target tokenizers.
        device: torch device.
        cfg: config dict; ``cfg['model']['d_model']`` sizes the classifier.
        src_list, ref_list: validation sources / reference texts.
        task5_samples: upper bound on samples for data collection and sweeps
            (each path applies its own additional cap).
    """
    print("\n" + "="*65)
    print(" TASK 5 — Classifier-Free Guidance")
    print("="*65)
    if not hasattr(model.model, 'encode_source'):
        print(" Compatibility mode: classifier-guidance unavailable; sweeping decoding controls.")
        n = min(100, int(task5_samples), len(src_list), len(ref_list))
        lambdas = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
        results = []
        for lam in lambdas:
            # λ is re-purposed as a repetition-penalty knob in this mode.
            rep_pen = 1.0 + 0.15 * lam
            cer_vals, uniq_vals = [], []
            for src, ref in zip(src_list[:n], ref_list[:n]):
                out = _generate_ids_compat(
                    model, src.to(device), num_steps=8, temperature=0.8, top_k=40,
                    repetition_penalty=rep_pen, diversity_penalty=0.0
                )
                txt, ids = _decode_ids(tgt_tok, out)
                cer_vals.append(_cer(txt, ref))
                # Diversity proxy: unique-token ratio of the decoded ids.
                uniq_vals.append(len(set(ids)) / max(1, len(ids)))
            results.append((lam, float(np.mean(cer_vals)), float(np.mean(uniq_vals))))
            print(f" λ={lam:.1f} CER={results[-1][1]:.4f} diversity={results[-1][2]:.3f}")
        # Subtask graph: quality-diversity tradeoff
        x = [r[1] for r in results]
        y = [r[2] for r in results]
        labels = [r[0] for r in results]
        plt.figure(figsize=(6, 4))
        plt.plot(x, y, marker="o")
        for xi, yi, la in zip(x, y, labels):
            plt.text(xi, yi, f"λ={la:.1f}", fontsize=8)
        plt.xlabel("CER (lower is better)")
        plt.ylabel("Diversity")
        plt.title("Task5: Quality-Diversity Tradeoff")
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, "task5_quality_diversity_tradeoff.png"), dpi=150, bbox_inches="tight")
        plt.close()
        rep = os.path.join(OUTPUT_DIR, "task5_report.txt")
        with open(rep, "w", encoding="utf-8") as f:
            f.write("TASK 5 — COMPATIBILITY REPORT\n")
            f.write("="*40 + "\n")
            f.write("Guidance classifier path unavailable; λ mapped to repetition penalty.\n\n")
            for lam, cer_v, div_v in results:
                f.write(f"lambda={lam:.1f} CER={cer_v:.4f} diversity={div_v:.3f}\n")
            f.write("\nSaved graphs:\n")
            f.write(" - task5_quality_diversity_tradeoff.png\n")
        print(f" Report: {rep}")
        return

    try:
        from analysis.quality_classifier import (
            QualityClassifier, collect_quality_data,
            train_quality_classifier, sweep_guidance_scales)
    except Exception:
        # Deliberate broad catch: any import/API mismatch falls back to the
        # decoding-control sweep below instead of aborting the analysis run.
        print(" Quality-classifier API mismatch; using compatibility sweep.")
        n = min(50, int(task5_samples), len(src_list))
        scales = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
        results = []
        for lam in scales:
            # Slightly stronger λ→penalty mapping than the other compat path.
            rep_pen = 1.0 + 0.2 * lam
            cer_vals, uniq_vals = [], []
            for src, ref in zip(src_list[:n], ref_list[:n]):
                out = _generate_ids_compat(
                    model, src.to(device), num_steps=8, temperature=0.8, top_k=40,
                    repetition_penalty=rep_pen, diversity_penalty=0.0
                )
                txt, ids = _decode_ids(tgt_tok, out)
                cer_vals.append(_cer(txt, ref))
                uniq_vals.append(len(set(ids)) / max(1, len(ids)))
            results.append((lam, float(np.mean(cer_vals)), float(np.mean(uniq_vals))))
            print(f" λ={lam:.1f} CER={results[-1][1]:.4f} diversity={results[-1][2]:.3f}")
        # Subtask graph: quality-diversity tradeoff
        x = [r[1] for r in results]
        y = [r[2] for r in results]
        labels = [r[0] for r in results]
        plt.figure(figsize=(6, 4))
        plt.plot(x, y, marker="o")
        for xi, yi, la in zip(x, y, labels):
            plt.text(xi, yi, f"λ={la:.1f}", fontsize=8)
        plt.xlabel("CER (lower is better)")
        plt.ylabel("Diversity")
        plt.title("Task5: Quality-Diversity Tradeoff")
        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, "task5_quality_diversity_tradeoff.png"), dpi=150, bbox_inches="tight")
        plt.close()
        rep = os.path.join(OUTPUT_DIR, "task5_report.txt")
        with open(rep, "w", encoding="utf-8") as f:
            f.write("TASK 5 — COMPATIBILITY REPORT\n")
            f.write("="*40 + "\n")
            f.write("Guidance classifier path unavailable; λ mapped to repetition penalty.\n\n")
            for lam, cer_v, div_v in results:
                f.write(f"lambda={lam:.1f} CER={cer_v:.4f} diversity={div_v:.3f}\n")
            f.write("\nSaved graphs:\n")
            f.write(" - task5_quality_diversity_tradeoff.png\n")
        print(f" Report: {rep}")
        return

    clf_path = os.path.join(OUTPUT_DIR, "task5_quality_classifier.pt")
    d_model = cfg['model']['d_model']

    # Step 1: collect or load training data
    data_path = os.path.join(OUTPUT_DIR, "task5_quality_data.npz")
    if os.path.exists(data_path):
        print(" Loading cached quality data...")
        data = np.load(data_path)
        hidden = data["hidden"]
        quality = data["quality"]
    else:
        print(" Collecting quality data (this takes a few minutes)...")
        n = min(int(task5_samples), len(src_list))
        hidden, quality = collect_quality_data(
            model, src_list[:n], ref_list[:n], tgt_tok,
            t_capture=0, max_samples=n)
        np.savez(data_path, hidden=hidden, quality=quality)
        print(f" Saved quality data: {data_path}")

    # Step 2: train or load classifier
    if os.path.exists(clf_path):
        print(f" Loading cached classifier: {clf_path}")
        clf = QualityClassifier(d_model)
        clf.load_state_dict(torch.load(clf_path, map_location='cpu'))
        clf.eval()
    else:
        print(" Training quality classifier...")
        clf = train_quality_classifier(
            hidden, quality, d_model=d_model,
            epochs=30, batch_size=64, lr=1e-3,
            save_path=clf_path)
        clf.eval()

    # Step 3: guidance scale sweep
    print("\n Guidance scale sweep (λ ∈ {0.0, 0.5, 1.0, 1.5, 2.0, 3.0})...")
    n_sweep = min(80, int(task5_samples), len(src_list))
    results = sweep_guidance_scales(
        model, clf, src_list[:n_sweep], ref_list[:n_sweep],
        tgt_tok, scales=[0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
        n_samples=n_sweep, device=device, output_dir=OUTPUT_DIR)

    # Find optimal scale (quality + anti-collapse diversity)
    def _score(s):
        # Lower is better: CER minus a small diversity bonus.
        r = results[s]
        return (r["mean_cer"] - 0.05 * r.get("diversity", 0.0))
    best_scale = min(results, key=_score)
    print(f"\n Optimal guidance scale: λ={best_scale:.1f} "
          f"CER={results[best_scale]['mean_cer']:.4f}")

    report = os.path.join(OUTPUT_DIR, "task5_report.txt")
    with open(report, "w") as f:
        f.write("TASK 5 — CLASSIFIER-FREE GUIDANCE\n" + "="*50 + "\n\n")
        f.write(f"Classifier params: {sum(p.numel() for p in clf.parameters())}\n")
        f.write(f"Training samples : {len(hidden)}\n\n")
        f.write("Guidance scale sweep:\n")
        f.write(f" {'λ':>6} {'CER':>8} {'diversity':>10} {'d2':>6} {'sBLEU':>8}\n")
        f.write(" " + "-"*52 + "\n")
        for s in sorted(results.keys()):
            r = results[s]
            marker = " ← optimal" if s == best_scale else ""
            f.write(
                f" {s:>6.1f} {r['mean_cer']:>8.4f} {r['diversity']:>10.3f} "
                f"{r.get('distinct2', 0.0):>6.3f} {r.get('self_bleu', 0.0):>8.3f}{marker}\n"
            )
    print(f" Report: {report}")
1160
+
1161
+
1162
+ # ── Main ──────────────────────────────────────────────────────────────
1163
+
1164
def main():
    """CLI entry point: parse args, load model/data, and run selected tasks.

    Side effects: rebinds the module-level OUTPUT_DIR global to
    ``--output_dir`` (all ``run_task*`` helpers write there) and creates
    that directory.
    """
    global OUTPUT_DIR

    parser = argparse.ArgumentParser()
    parser.add_argument("--task",
                        choices=["1","2","3","4","5","all"], default="all")
    parser.add_argument("--input",
                        default="dharmo rakṣati rakṣitaḥ",
                        help="IAST input text for Task 2")
    parser.add_argument("--phase",
                        choices=["generate_configs", "analyze"], default="analyze",
                        help="Task 4 phase: generate_configs (before training) or analyze (after)")
    parser.add_argument("--checkpoint", default=None,
                        help="Optional explicit checkpoint path")
    parser.add_argument("--output_dir", default="analysis/outputs",
                        help="Output directory for reports/figures")
    parser.add_argument("--task4_samples", type=int, default=50,
                        help="Samples for Task 4 dry/full evaluation")
    parser.add_argument("--task3_samples", type=int, default=500,
                        help="Samples for Task 3 hidden-state collection")
    parser.add_argument("--task5_samples", type=int, default=500,
                        help="Samples for Task 5 classifier data + sweep")
    args = parser.parse_args()

    OUTPUT_DIR = args.output_dir
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    cfg = copy.deepcopy(CONFIG)
    if args.checkpoint:
        # Infer model variant and training flags from the checkpoint path.
        cfg["model_type"] = infer_model_type_from_checkpoint(args.checkpoint)
        cfg["data"]["include_negative_examples"] = infer_include_negative_from_checkpoint(args.checkpoint)
        # Checkpoints stored under a "T<N>" directory carry their own step
        # count; patch both training and inference step settings to match.
        ckpt_name = os.path.basename(os.path.dirname(args.checkpoint))
        if ckpt_name.startswith("T") and ckpt_name[1:].isdigit():
            t_val = int(ckpt_name[1:])
            cfg["model"]["diffusion_steps"] = t_val
            cfg["inference"]["num_steps"] = t_val

    # Fall back to CPU when the configured accelerator is unavailable.
    requested = cfg["training"]["device"]
    if requested == "mps" and not torch.backends.mps.is_available():
        requested = "cpu"
    elif requested == "cuda" and not torch.cuda.is_available():
        requested = "cpu"
    cfg["training"]["device"] = requested
    device = torch.device(requested)

    print("Loading model and tokenizers...")
    model, src_tok, tgt_tok, cfg = load_everything(cfg, device, ckpt_override=args.checkpoint)

    # Load val data for tasks that need corpus/context (Tasks 2, 3, 4, 5)
    needs_data = args.task in ("2", "3", "4", "5", "all")
    if needs_data:
        print("Loading validation data...")
        src_list, ref_list, inp_list = load_val_data(cfg, src_tok, tgt_tok, n=500)
    else:
        src_list, ref_list, inp_list = [], [], []

    tasks = (["1","2","3","4","5"] if args.task == "all"
             else [args.task])

    for task in tasks:
        if task == "1":
            run_task1(model, src_tok, device)
        elif task == "2":
            run_task2(model, src_tok, tgt_tok, device, args.input, cfg, corpus_inputs=inp_list)
        elif task == "3":
            run_task3(model, src_tok, tgt_tok, device, src_list, ref_list, n_samples=args.task3_samples)
        elif task == "4":
            run_task4(args.phase, model, src_tok, tgt_tok, device, cfg,
                      src_list, ref_list, n_samples=args.task4_samples)
        elif task == "5":
            run_task5(
                model, src_tok, tgt_tok, device, cfg, src_list, ref_list,
                task5_samples=args.task5_samples
            )

    print(f"\n{'='*65}")
    print(f" All outputs saved to: {OUTPUT_DIR}/")
    print("="*65)
1243
+
1244
# Script entry point: only run the analysis CLI when executed directly.
if __name__ == "__main__":
    main()
analysis/step_ablation.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # """
2
+ # analysis/step_ablation.py
3
+ # ==========================
4
+ # Task 4: Semantic Robustness — Ablation of Diffusion Steps vs Meaning Preservation
5
+ #
6
+ # Two-phase workflow (retraining IS required for different T values):
7
+ #
8
+ # PHASE 1 — Generate configs + train (run once per T value):
9
+ # python analysis/step_ablation.py --phase generate_configs
10
+ # # Creates configs: ablation_configs/T4.py, T8.py, T16.py, T32.py, T64.py
11
+ # # Then train each: MODEL_TYPE=d3pm_cross_attention python train.py (for each config)
12
+ #
13
+ # PHASE 2 — Analyze trained models (no retraining needed):
14
+ # python analysis/step_ablation.py --phase analyze
15
+ # # Loads each trained model, generates 200 paraphrases, computes CER
16
+ # # Produces 3D plot: X=steps, Y=generation_speed, Z=CER
17
+ #
18
+ # Why retraining is needed:
19
+ # A model trained with T=128 learns to denoise from x_t~Uniform[0,128].
20
+ # Running it with T=4 means the model only sees t∈{0,1,2,3} — which it
21
+ # was never trained on at those scales. Outputs are meaningless.
22
+ # You must train a separate model for each T value.
23
+ #
24
+ # Also implements adversarial robustness test (no retraining):
25
+ # Takes your existing T=128 model and tests whether corrupted IAST
26
+ # inputs (typos, character swaps) cause proportional output degradation.
27
+ # """
28
+ #
29
+ # import torch
30
+ # import torch.nn.functional as F
31
+ # import numpy as np
32
+ # import os
33
+ # import sys
34
+ # import time
35
+ # import json
36
+ # import copy
37
+ # from typing import List, Dict, Optional
38
+ #
39
+ # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
40
+ #
41
+ #
42
+ # # ── Phase 1: Config generation ────────────────────────────────────────
43
+ #
44
+ # T_VALUES = [4, 8, 16, 32, 64]
45
+ #
46
+ # def generate_ablation_configs(base_config_path: str = "config.py",
47
+ # output_dir: str = "ablation_configs"):
48
+ # """
49
+ # Generate one config file per T value.
50
+ # Each config is a copy of the base config with diffusion_steps changed.
51
+ #
52
+ # After running this, train each model:
53
+ # for T in 4 8 16 32 64; do
54
+ # cp ablation_configs/config_T${T}.py config.py
55
+ # python train.py
56
+ # mv results7/d3pm_cross_attention_neg_False \
57
+ # ablation_results/T${T}
58
+ # done
59
+ # """
60
+ # os.makedirs(output_dir, exist_ok=True)
61
+ #
62
+ # # Read base config
63
+ # with open(base_config_path, "r") as f:
64
+ # base_src = f.read()
65
+ #
66
+ # for T in T_VALUES:
67
+ # # Replace diffusion_steps and num_steps
68
+ # cfg_src = base_src
69
+ # cfg_src = cfg_src.replace(
70
+ # '"diffusion_steps": 128',
71
+ # f'"diffusion_steps": {T}'
72
+ # )
73
+ # cfg_src = cfg_src.replace(
74
+ # "'diffusion_steps': 128",
75
+ # f"'diffusion_steps': {T}"
76
+ # )
77
+ # cfg_src = cfg_src.replace(
78
+ # '"num_steps": 128',
79
+ # f'"num_steps": {T}'
80
+ # )
81
+ # cfg_src = cfg_src.replace(
82
+ # "'num_steps': 128",
83
+ # f"'num_steps': {T}"
84
+ # )
85
+ # out_path = os.path.join(output_dir, f"config_T{T}.py")
86
+ # with open(out_path, "w") as f:
87
+ # f.write(f"# Ablation config: T={T} diffusion steps\n")
88
+ # f.write(cfg_src)
89
+ # print(f" Wrote: {out_path}")
90
+ #
91
+ # # Write a shell script to train all
92
+ # shell_script = os.path.join(output_dir, "train_all.sh")
93
+ # with open(shell_script, "w") as f:
94
+ # f.write("#!/bin/bash\n")
95
+ # f.write("# Run this script to train all ablation models\n\n")
96
+ # for T in T_VALUES:
97
+ # f.write(f"echo '=== Training T={T} ==='\n")
98
+ # f.write(f"cp {output_dir}/config_T{T}.py config.py\n")
99
+ # f.write(f"python train.py\n")
100
+ # f.write(f"mkdir -p ablation_results/T{T}\n")
101
+ # f.write(f"cp -r results7/d3pm_cross_attention_neg_False/best_model.pt "
102
+ # f"ablation_results/T{T}/best_model.pt\n")
103
+ # f.write(f"cp -r results7/d3pm_cross_attention_neg_False/train.log "
104
+ # f"ablation_results/T{T}/train.log\n\n")
105
+ # os.chmod(shell_script, 0o755)
106
+ # print(f"\nTraining script: {shell_script}")
107
+ # print(f"Run: bash {shell_script}")
108
+ #
109
+ #
110
+ # # ── Phase 2: Analysis (after models are trained) ──────────────────────
111
+ #
112
+ # def compute_cer(pred: str, ref: str) -> float:
113
+ # if not ref:
114
+ # return 1.0
115
+ #
116
+ # def edit_distance(s1, s2):
117
+ # m, n = len(s1), len(s2)
118
+ # dp = list(range(n + 1))
119
+ # for i in range(1, m + 1):
120
+ # prev, dp[0] = dp[0], i
121
+ # for j in range(1, n + 1):
122
+ # temp = dp[j]
123
+ # dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
124
+ # prev = temp
125
+ # return dp[n]
126
+ #
127
+ # return edit_distance(pred, ref) / max(len(ref), 1)
128
+ #
129
+ #
130
+ # def evaluate_model(
131
+ # model,
132
+ # src_list: List[torch.Tensor],
133
+ # ref_list: List[str],
134
+ # tgt_tokenizer,
135
+ # n_samples: int = 200,
136
+ # temperature: float = 0.8,
137
+ # top_k: int = 40,
138
+ # ) -> Dict:
139
+ # """
140
+ # Generate n_samples outputs and compute CER + generation speed.
141
+ #
142
+ # Returns dict with:
143
+ # mean_cer : average CER over samples
144
+ # generation_s : total wall-clock seconds for all generations
145
+ # speed_per_sample: seconds per sample
146
+ # cer_list : per-sample CER values
147
+ # """
148
+ # device = next(model.parameters()).device
149
+ # n = min(n_samples, len(src_list))
150
+ # cer_list = []
151
+ #
152
+ # start = time.perf_counter()
153
+ # for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
154
+ # if src.dim() == 1:
155
+ # src = src.unsqueeze(0)
156
+ #
157
+ # with torch.no_grad():
158
+ # if hasattr(model.model, 'generate_cached'):
159
+ # out = model.model.generate_cached(
160
+ # src.to(device), temperature=temperature, top_k=top_k
161
+ # )
162
+ # else:
163
+ # out = model.generate(
164
+ # src.to(device), temperature=temperature, top_k=top_k
165
+ # )
166
+ #
167
+ # ids = [x for x in out[0].tolist() if x > 4]
168
+ # pred = tgt_tokenizer.decode(ids).strip()
169
+ # cer = compute_cer(pred, ref)
170
+ # cer_list.append(cer)
171
+ #
172
+ # elapsed = time.perf_counter() - start
173
+ #
174
+ # return {
175
+ # "mean_cer": float(np.mean(cer_list)),
176
+ # "std_cer": float(np.std(cer_list)),
177
+ # "generation_s": elapsed,
178
+ # "speed_per_sample": elapsed / max(n, 1),
179
+ # "cer_list": cer_list,
180
+ # "n_samples": n,
181
+ # }
182
+ #
183
+ #
184
+ # def run_ablation_analysis(
185
+ # ablation_dir: str = "ablation_results",
186
+ # base_cfg: dict = None,
187
+ # src_list: List[torch.Tensor] = None,
188
+ # ref_list: List[str] = None,
189
+ # tgt_tokenizer = None,
190
+ # device: torch.device = None,
191
+ # output_dir: str = "analysis/outputs",
192
+ # ) -> Dict:
193
+ # """
194
+ # Load each trained model and evaluate.
195
+ # Produces results dict and 3D plot.
196
+ #
197
+ # Expects ablation_results/T{N}/best_model.pt for each T in T_VALUES.
198
+ # """
199
+ # from inference import load_model
200
+ #
201
+ # results = {}
202
+ # for T in T_VALUES:
203
+ # ckpt = os.path.join(ablation_dir, f"T{T}", "best_model.pt")
204
+ # if not os.path.exists(ckpt):
205
+ # print(f" SKIP T={T}: no checkpoint at {ckpt}")
206
+ # continue
207
+ #
208
+ # print(f"\nEvaluating T={T}...")
209
+ # cfg_T = copy.deepcopy(base_cfg)
210
+ # cfg_T['model']['diffusion_steps'] = T
211
+ # cfg_T['inference']['num_steps'] = T
212
+ #
213
+ # model, cfg_T = load_model(ckpt, cfg_T, device)
214
+ # model.eval()
215
+ #
216
+ # metrics = evaluate_model(
217
+ # model, src_list, ref_list, tgt_tokenizer, n_samples=200
218
+ # )
219
+ # results[T] = metrics
220
+ # print(f" T={T} CER={metrics['mean_cer']:.4f} "
221
+ # f"speed={metrics['speed_per_sample']:.3f}s/sample")
222
+ #
223
+ # del model
224
+ #
225
+ # # Save results
226
+ # os.makedirs(output_dir, exist_ok=True)
227
+ # results_path = os.path.join(output_dir, "ablation_results.json")
228
+ # with open(results_path, "w") as f:
229
+ # json.dump({str(k): {kk: vv for kk, vv in v.items() if kk != 'cer_list'}
230
+ # for k, v in results.items()}, f, indent=2)
231
+ # print(f"\nResults saved: {results_path}")
232
+ #
233
+ # return results
234
+ #
235
+ #
236
+ # def plot_ablation_3d(
237
+ # results: Dict,
238
+ # save_path: Optional[str] = None,
239
+ # ):
240
+ # """
241
+ # 3D plot: X=diffusion_steps, Y=generation_speed(s/sample), Z=CER.
242
+ # Also produces a 2D summary plot.
243
+ # """
244
+ # try:
245
+ # import matplotlib.pyplot as plt
246
+ # from mpl_toolkits.mplot3d import Axes3D
247
+ # except ImportError:
248
+ # print("pip install matplotlib.")
249
+ # return
250
+ #
251
+ # T_list = sorted(results.keys())
252
+ # cers = [results[T]["mean_cer"] for T in T_list]
253
+ # speeds = [results[T]["speed_per_sample"] for T in T_list]
254
+ #
255
+ # # ── 3D plot ───────────────────────────────────────────────────────
256
+ # fig = plt.figure(figsize=(14, 5))
257
+ #
258
+ # ax3d = fig.add_subplot(121, projection='3d')
259
+ # ax3d.scatter(T_list, speeds, cers, c=cers, cmap='RdYlGn_r', s=80)
260
+ # for T, s, c in zip(T_list, speeds, cers):
261
+ # ax3d.text(T, s, c, f"T={T}", fontsize=8)
262
+ # ax3d.set_xlabel("Diffusion steps T", fontsize=9)
263
+ # ax3d.set_ylabel("Speed (s/sample)", fontsize=9)
264
+ # ax3d.set_zlabel("CER (↓ better)", fontsize=9)
265
+ # ax3d.set_title("T vs speed vs CER", fontsize=10)
266
+ #
267
+ # # ── 2D CER vs T (find the knee) ──────────────────────────────────
268
+ # ax2d = fig.add_subplot(122)
269
+ # ax2d.plot(T_list, cers, 'o-', linewidth=1.8, color='coral', markersize=7)
270
+ # for T, c in zip(T_list, cers):
271
+ # ax2d.annotate(f"{c:.3f}", (T, c), textcoords="offset points",
272
+ # xytext=(0, 8), fontsize=8, ha='center')
273
+ #
274
+ # # Find knee: largest CER drop per unit T (elbow method)
275
+ # if len(T_list) >= 3:
276
+ # drops = [cers[i] - cers[i+1] for i in range(len(cers)-1)]
277
+ # knee_i = int(np.argmax(drops))
278
+ # knee_T = T_list[knee_i + 1]
279
+ # ax2d.axvline(knee_T, color='steelblue', linestyle='--', linewidth=1.2,
280
+ # label=f"Knee at T={knee_T}")
281
+ # ax2d.legend(fontsize=9)
282
+ #
283
+ # ax2d.set_xlabel("Diffusion steps T", fontsize=10)
284
+ # ax2d.set_ylabel("CER (lower = better)", fontsize=10)
285
+ # ax2d.set_title("CER vs diffusion steps", fontsize=10)
286
+ # ax2d.set_ylim(0, max(cers) * 1.1)
287
+ #
288
+ # plt.tight_layout()
289
+ # if save_path:
290
+ # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
291
+ # plt.savefig(save_path, dpi=150, bbox_inches='tight')
292
+ # print(f"Saved: {save_path}")
293
+ # else:
294
+ # plt.show()
295
+ # plt.close()
296
+ #
297
+ #
298
+ # # ── Adversarial robustness test (no retraining needed) ───────────────
299
+ #
300
+ # def corrupt_iast(text: str, corruption_rate: float = 0.05) -> str:
301
+ # """
302
+ # Introduce random corruption into IAST text:
303
+ # - Character swap (adjacent chars swapped)
304
+ # - Character deletion
305
+ # - Random character insertion
306
+ #
307
+ # Corruption rates range from 5% to 20% to test robustness.
308
+ # """
309
+ # import random
310
+ # chars = list(text)
311
+ # n_corrupt = max(1, int(len(chars) * corruption_rate))
312
+ #
313
+ # for _ in range(n_corrupt):
314
+ # op = random.choice(['swap', 'delete', 'insert'])
315
+ # pos = random.randint(0, len(chars) - 1)
316
+ #
317
+ # if op == 'swap' and pos < len(chars) - 1:
318
+ # chars[pos], chars[pos+1] = chars[pos+1], chars[pos]
319
+ # elif op == 'delete' and len(chars) > 1:
320
+ # chars.pop(pos)
321
+ # elif op == 'insert':
322
+ # chars.insert(pos, random.choice('abcdeimnostu'))
323
+ #
324
+ # return "".join(chars)
325
+ #
326
+ #
327
+ # @torch.no_grad()
328
+ # def run_adversarial_test(
329
+ # model,
330
+ # src_tokenizer,
331
+ # tgt_tokenizer,
332
+ # test_inputs: List[str],
333
+ # test_refs: List[str],
334
+ # corruption_rates: List[float] = [0.0, 0.05, 0.10, 0.15, 0.20],
335
+ # device: torch.device = None,
336
+ # output_dir: str = "analysis/outputs",
337
+ # ) -> Dict:
338
+ # """
339
+ # Test if CER degrades proportionally with IAST corruption.
340
+ # Uses existing trained model — no retraining.
341
+ # """
342
+ # device = device or next(model.parameters()).device
343
+ # results = {}
344
+ #
345
+ # print("\nAdversarial robustness test...")
346
+ # for rate in corruption_rates:
347
+ # cer_list = []
348
+ # for text, ref in zip(test_inputs, test_refs):
349
+ # corrupted = corrupt_iast(text, rate)
350
+ # ids = src_tokenizer.encode(corrupted)
351
+ # src = torch.tensor([ids], dtype=torch.long, device=device)
352
+ #
353
+ # if hasattr(model.model, 'generate_cached'):
354
+ # out = model.model.generate_cached(src)
355
+ # else:
356
+ # out = model.generate(src)
357
+ #
358
+ # pred_ids = [x for x in out[0].tolist() if x > 4]
359
+ # pred = tgt_tokenizer.decode(pred_ids).strip()
360
+ # cer_list.append(compute_cer(pred, ref))
361
+ #
362
+ # mean_cer = float(np.mean(cer_list))
363
+ # results[rate] = mean_cer
364
+ # print(f" corruption={rate*100:.0f}% → CER={mean_cer:.4f}")
365
+ #
366
+ # # Save + plot
367
+ # os.makedirs(output_dir, exist_ok=True)
368
+ # try:
369
+ # import matplotlib.pyplot as plt
370
+ # fig, ax = plt.subplots(figsize=(8, 4))
371
+ # rates = [r * 100 for r in corruption_rates]
372
+ # cers = [results[r] for r in corruption_rates]
373
+ # ax.plot(rates, cers, 'o-', linewidth=1.8, color='steelblue', markersize=7)
374
+ # ax.set_xlabel("IAST corruption rate (%)", fontsize=11)
375
+ # ax.set_ylabel("CER", fontsize=11)
376
+ # ax.set_title("Model robustness to IAST input corruption", fontsize=11)
377
+ # ax.set_ylim(0, max(cers) * 1.2)
378
+ # plt.tight_layout()
379
+ # plt.savefig(os.path.join(output_dir, "adversarial_robustness.png"),
380
+ # dpi=150, bbox_inches='tight')
381
+ # plt.close()
382
+ # print(f" Saved: {output_dir}/adversarial_robustness.png")
383
+ # except ImportError:
384
+ # pass
385
+ #
386
+ # with open(os.path.join(output_dir, "adversarial_results.json"), "w") as f:
387
+ # json.dump({str(k): v for k, v in results.items()}, f, indent=2)
388
+ #
389
+ # return results
390
+ """
391
+ analysis/task4_pipeline.py
392
+ ================================
393
+ Correct Task 4 Pipeline:
394
+
395
+ PHASE 1 → Evaluate all models
396
+ PHASE 2 → Analyze + detect optimal T
397
+
398
+ NO early decision making.
399
+ """
400
+
401
+ import torch
402
+ import numpy as np
403
+ import time
404
+ import os
405
+ import json
406
+ from typing import Dict, List
407
+ from difflib import SequenceMatcher
408
+ from collections import Counter
409
+
410
+
411
+ # ─────────────────────────────────────────────
412
+ # Load Metrics
413
+ # ─────────────────────────────────────────────
414
+
415
def load_metrics():
    """Load optional metric backends, degrading gracefully when offline.

    Returns
    -------
    tuple: (bert_score_fn, st_model, st_util, bleu_fn)
        bert_score_fn : ``bert_score.score`` or None if the package is missing.
        st_model      : a SentenceTransformer instance or None if unavailable.
        st_util       : ``sentence_transformers.util`` or None (paired with st_model).
        bleu_fn       : always callable — nltk's ``sentence_bleu`` when nltk is
                        installed, otherwise a unigram-precision fallback.

    BUGFIX: the nltk import was previously unguarded, so a missing nltk
    crashed the whole pipeline even though bert_score and
    sentence_transformers were treated as optional. All three backends are
    now optional.
    """
    try:
        from bert_score import score as bert_score
    except Exception:
        bert_score = None

    try:
        from nltk.translate.bleu_score import sentence_bleu
    except Exception:
        def sentence_bleu(references, hypothesis, *args, **kwargs):
            # Unigram-precision proxy: best clipped unigram overlap over all
            # references, divided by hypothesis length. Crude, but keeps the
            # pipeline running (and comparable across T values) without nltk.
            from collections import Counter
            hyp_counts = Counter(hypothesis)
            best = 0.0
            for ref in references:
                overlap = sum((hyp_counts & Counter(ref)).values())
                best = max(best, overlap / max(1, len(hypothesis)))
            return best

    try:
        from sentence_transformers import SentenceTransformer, util
        st_model = SentenceTransformer('all-MiniLM-L6-v2')
        return bert_score, st_model, util, sentence_bleu
    except Exception:
        # Offline-safe fallback: skip sentence-transformer similarity.
        return bert_score, None, None, sentence_bleu
428
+
429
+
430
+ # ─────────────────────────────────────────────
431
+ # PHASE 1 — Evaluate ALL models
432
+ # ─────────────────────────────────────────────
433
+
434
def evaluate_all_models(models: Dict[int, object],
                        src_list,
                        ref_list,
                        tgt_tokenizer,
                        n_samples=200,
                        output_dir: str = "analysis/outputs"):
    """Phase 1: generate translations with every model and score them.

    Parameters
    ----------
    models : dict mapping diffusion-step count T -> trained model.
    src_list : sequence of source tensors (1-D or already batched 2-D).
    ref_list : sequence of reference strings, aligned with src_list.
    tgt_tokenizer : target-side tokenizer; only ``decode`` is used here.
    n_samples : max number of (src, ref) pairs evaluated per model.
    output_dir : directory where ``task4_raw_results.json`` is written.

    Returns
    -------
    dict mapping T -> {"bertscore_f1", "semantic_sim", "bleu",
    "speed_per_sample"}. The same dict (minus nothing) is dumped to JSON.

    NOTE(review): assumes each src tensor is token ids shaped (seq,) or
    (1, seq) — confirm against the caller's data pipeline.
    """
    bert_score_fn, st_model, util, bleu_fn = load_metrics()

    results = {}

    print("\n=== PHASE 1: Evaluating ALL models ===")

    # Iterate models in ascending T so progress output is ordered.
    for T, model in sorted(models.items()):
        print(f"\nEvaluating T={T}...")

        device = next(model.parameters()).device
        preds, refs = [], []

        start = time.perf_counter()

        for src, ref in zip(src_list[:n_samples], ref_list[:n_samples]):
            # Promote a single sequence to a batch of one.
            if src.dim() == 1:
                src = src.unsqueeze(0)

            with torch.no_grad():
                if hasattr(model, "model") and hasattr(model.model, "generate_cached"):
                    out = model.model.generate_cached(src.to(device))
                else:
                    # Fallback for wrappers that only expose top-level generate.
                    out = model.generate(src.to(device))

            # Drop ids <= 4 — presumably special tokens (pad/bos/eos/etc.);
            # verify against the tokenizer's vocabulary layout.
            ids = [x for x in out[0].tolist() if x > 4]
            pred = tgt_tokenizer.decode(ids).strip()

            preds.append(pred)
            refs.append(ref)

        # Wall-clock time for the whole generation loop (all samples).
        elapsed = time.perf_counter() - start

        # BERTScore (fallback to lexical similarity if unavailable/offline)
        try:
            if bert_score_fn is not None:
                _, _, F1 = bert_score_fn(preds, refs, lang="hi", verbose=False)
                bert_f1 = float(F1.mean())
            else:
                # Raising routes both "package missing" and "scoring failed"
                # through the same SequenceMatcher fallback below.
                raise RuntimeError("bertscore unavailable")
        except Exception:
            bert_f1 = float(np.mean([SequenceMatcher(None, p, r).ratio() for p, r in zip(preds, refs)]))

        # Sentence similarity (distinct from BERT fallback)
        if st_model is not None:
            emb_p = st_model.encode(preds, convert_to_tensor=True)
            emb_r = st_model.encode(refs, convert_to_tensor=True)
            sim = util.cos_sim(emb_p, emb_r).diagonal().mean().item()
        else:
            # token-overlap F1 proxy (different behavior from char-level similarity)
            f1s = []
            for p, r in zip(preds, refs):
                pt = [t for t in p.split() if t]
                rt = [t for t in r.split() if t]
                if not pt or not rt:
                    f1s.append(0.0)
                    continue
                # Clipped token overlap via multiset intersection.
                cp, cr = Counter(pt), Counter(rt)
                inter = sum((cp & cr).values())
                prec = inter / max(1, len(pt))
                rec = inter / max(1, len(rt))
                f1s.append((2 * prec * rec / max(1e-9, prec + rec)))
            sim = float(np.mean(f1s)) if f1s else 0.0
            # Guard against NaN/inf (e.g. empty prediction lists).
            if not np.isfinite(sim):
                sim = float(np.mean([SequenceMatcher(None, p, r).ratio() for p, r in zip(preds, refs)]))

        # BLEU (sentence-level, whitespace tokenization)
        bleu_scores = [
            bleu_fn([r.split()], p.split())
            for p, r in zip(preds, refs)
        ]

        results[T] = {
            "bertscore_f1": bert_f1,
            "semantic_sim": sim,
            "bleu": float(np.mean(bleu_scores)),
            "speed_per_sample": elapsed / max(1, len(preds))
        }

        print(f" BERTScore: {bert_f1:.4f}")
        print(f" Sim: {sim:.4f}")
        print(f" BLEU: {results[T]['bleu']:.4f}")
        print(f" Speed: {results[T]['speed_per_sample']:.4f}s")

    # Save raw results
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "task4_raw_results.json"), "w") as f:
        json.dump(results, f, indent=2)

    return results
531
+
532
+
533
+ # ─────────────────────────────────────────────
534
+ # PHASE 2 — Analyze results (Knee Detection)
535
+ # ─────────────────────────────────────────────
536
+
537
def analyze_results(results: Dict):
    """Phase 2: pick the optimal diffusion-step count T from Phase-1 metrics.

    Parameters
    ----------
    results : dict mapping T -> metric dict containing at least
        "bertscore_f1", "semantic_sim" and "speed_per_sample".

    Returns
    -------
    (knee_T, gains)
        knee_T : the T maximizing a normalized utility
                 0.50*quality + 0.30*semantics - 0.20*latency.
        gains  : marginal BERT-F1 improvements between consecutive T values
                 (empty when only one T was evaluated).
    """
    print("\n=== PHASE 2: Analysis ===")

    T_list = sorted(results.keys())
    scores = [results[T]["bertscore_f1"] for T in T_list]

    gains = [scores[i+1] - scores[i] for i in range(len(scores)-1)]

    print("\nMarginal Gains:")
    for i, g in enumerate(gains):
        # BUGFIX: use explicit sign formatting — the old literal "+{g:.4f}"
        # printed "+-0.0123" for negative (regressing) gains.
        print(f" T{T_list[i]} → T{T_list[i+1]}: {g:+.4f}")

    # Robust utility selection (quality + semantics + speed regularizer).
    # Each metric is min-max normalized; max(1e-9, ...) avoids division by
    # zero when all models tie (or only one T is present).
    bvals = np.array([results[T]["bertscore_f1"] for T in T_list], dtype=np.float32)
    svals = np.array([results[T]["semantic_sim"] for T in T_list], dtype=np.float32)
    tvals = np.array([results[T]["speed_per_sample"] for T in T_list], dtype=np.float32)
    b_norm = (bvals - bvals.min()) / max(1e-9, (bvals.max() - bvals.min()))
    s_norm = (svals - svals.min()) / max(1e-9, (svals.max() - svals.min()))
    t_norm = (tvals - tvals.min()) / max(1e-9, (tvals.max() - tvals.min()))
    utility = 0.50 * b_norm + 0.30 * s_norm - 0.20 * t_norm
    knee_T = T_list[int(np.argmax(utility))]

    print(f"\n✅ Optimal T (semantic-speed tradeoff): {knee_T}")

    return knee_T, gains
562
+
563
+
564
+ # ─────────────────────────────────────────────
565
+ # 3D Plot (BERTScore)
566
+ # ─────────────────────────────────────────────
567
+
568
def plot_3d(results, output_dir: str = "analysis/outputs"):
    """Scatter-plot the T / speed / BERTScore tradeoff in 3D.

    Saves the figure to ``<output_dir>/task4_3d.png``; shows nothing
    interactively. ``results`` maps T -> metric dict with
    "speed_per_sample" and "bertscore_f1" keys.
    """
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D

    steps = sorted(results.keys())
    latencies = [results[t]["speed_per_sample"] for t in steps]
    qualities = [results[t]["bertscore_f1"] for t in steps]

    fig = plt.figure(figsize=(10, 6))
    axes = fig.add_subplot(111, projection='3d')
    axes.scatter(steps, latencies, qualities)

    # Label every point with its step count.
    for px, py, pz in zip(steps, latencies, qualities):
        axes.text(px, py, pz, f"T={px}", fontsize=8)

    axes.set_xlabel("Diffusion Steps")
    axes.set_ylabel("Speed")
    axes.set_zlabel("BERTScore")
    plt.title("3D Tradeoff: Steps vs Speed vs Quality")

    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, "task4_3d.png"))
    plt.close()
    print("Saved 3D plot")
597
+
598
+
599
+ # ─────────────────────────────────────────────
600
+ # FINAL RUNNER
601
+ # ─────────────────────────────────────────────
602
+
603
def run_task4(models, src_list, ref_list, tgt_tokenizer,
              output_dir: str = "analysis/outputs", n_samples: int = 200):
    """Full Task 4 pipeline: evaluate, select optimal T, plot, write report.

    Parameters
    ----------
    models : dict mapping diffusion-step count T -> trained model.
    src_list, ref_list : aligned evaluation pairs (tensors / strings).
    tgt_tokenizer : target-side tokenizer passed through to evaluation.
    output_dir : where all artifacts (json, png, txt report) are written.
    n_samples : samples evaluated per model.

    Returns
    -------
    The selected optimal T (knee of the quality/speed tradeoff).

    Improvements over the previous version:
    - ``sorted(results.keys())`` was recomputed twice per loop iteration
      when writing marginal gains; it is now hoisted once.
    - The report file is opened with an explicit UTF-8 encoding — it
      contains non-ASCII characters (em dash), which would crash on
      platforms whose default encoding cannot represent them.
    """
    # Phase 1: Evaluate all
    results = evaluate_all_models(
        models, src_list, ref_list, tgt_tokenizer, n_samples=n_samples, output_dir=output_dir
    )

    # Phase 2: Analyze
    knee_T, gains = analyze_results(results)

    # Plot
    plot_3d(results, output_dir=output_dir)

    # Save detailed report
    T_sorted = sorted(results.keys())  # hoisted loop-invariant
    report_path = os.path.join(output_dir, "task4_report.txt")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write("TASK 4 — SEMANTIC ROBUSTNESS ABLATION\n")
        f.write("=" * 50 + "\n\n")
        f.write(f"Optimal diffusion steps = {knee_T}\n\n")
        f.write(f"{'T':>6} {'BERT-F1':>10} {'SEM_SIM':>10} {'BLEU':>8} {'sec/sample':>12}\n")
        f.write(" " + "-" * 56 + "\n")
        for T in T_sorted:
            r = results[T]
            f.write(
                f"{T:>6} {r['bertscore_f1']:>10.4f} {r['semantic_sim']:>10.4f} "
                f"{r['bleu']:>8.4f} {r['speed_per_sample']:>12.4f}\n"
            )
        f.write("\nMarginal gains (BERT-F1):\n")
        for i, g in enumerate(gains):
            f.write(f" T{T_sorted[i]} -> T{T_sorted[i + 1]}: {g:+.4f}\n")
        f.write("\nSaved plots/files:\n")
        f.write(" - task4_3d.png\n")
        f.write(" - task4_raw_results.json\n")

    return knee_T
analysis_outputs/outputs_all_models_20260325/T16/task1_encoder_cost.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task1_kv_cache.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 1 — KV CACHE BENCHMARK
2
+ ========================================
3
+
4
+ has_generate_cached=True
5
+ memory_profile=Torch CPU mem-event reduction: 30.4% @ src_len=64 (std=2143.0MB, cache=1492.1MB)
6
+
7
+ src_len standard(s) cached(s) speedup encoder%
8
+ 16 0.893 0.571 1.56x 40.0%
9
+ 32 0.751 0.509 1.48x 42.3%
10
+ 64 1.141 0.822 1.39x 40.7%
11
+
12
+ Saved graphs:
13
+ - task1_time_comparison.png
14
+ - task1_speedup.png
15
+ - task1_encoder_cost.png
analysis_outputs/outputs_all_models_20260325/T16/task1_speedup.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task1_time_comparison.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_all_layers_t0.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_attn_evolution.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_attn_t0.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_attn_t15.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_report.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 2 — ATTENTION + DRIFT REPORT
2
+ ==================================================
3
+
4
+ Input : dharmo rakṣati rakṣitaḥ
5
+ Output: धर्मो रक्षति रक्षितः
6
+
7
+ Captured steps: 16
8
+ Analysis quality: WEAK
9
+ Final output uniq-ratio: 1.000
10
+ Degenerate output: NO
11
+ Multi-sample semantic score (n<=8): 0.1471
12
+ Lock-in step (CER<=0.05): t=0
13
+ Locked tokens: 38 Flexible tokens: 42
14
+ TF-IDF vs attention stability corr: 0.9294
15
+ TF-IDF status: OK
16
+
17
+ Saved graphs:
18
+ - task2_attn_t*.png / task2_all_layers_t0.png
19
+ - task2_attn_evolution.png
20
+ - task2_semantic_drift.png
21
+ - task2_source_alignment.png
22
+ - task2_tfidf_vs_attention.png
23
+
24
+ Step trajectory (first 10 rows)
25
+ ------------------------------------------------------------
26
+ t= 15 bert=0.0475 drift=0.9525 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
27
+ t= 14 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
28
+ t= 13 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
29
+ t= 12 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
30
+ t= 11 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
31
+ t= 10 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
32
+ t= 9 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
33
+ t= 8 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
34
+ t= 7 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
35
+ t= 6 bert=0.0478 drift=0.9522 text=धर्मो ति रक्ष रक्षि तः तः तः तः ितः तः धर्मो धर्मो धर्मो धर्
analysis_outputs/outputs_all_models_20260325/T16/task2_semantic_drift.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_source_alignment.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task2_tfidf_vs_attention.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task3_concept_space.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task3_diversity_curve.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task3_diversity_direction.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:250b81c1f8cc9537240873d00539df1e8a30e6c07b260d4c05df23fb32c704d6
3
+ size 4224
analysis_outputs/outputs_all_models_20260325/T16/task3_pca_explained_variance.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task3_report.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 3 — CONCEPT VECTORS + PCA STEERING
2
+ ==================================================
3
+
4
+ PCA: 50 components, 74.8% variance
5
+ Diversity PC: 0 (|r|=0.325 with diversity proxy)
6
+
7
+ Direction validity: WEAK
8
+ Spectrum unique ratio (mean over 5 seeds): 1.000
9
+ Spectrum semantic stability (mean over 5 seeds): 0.312
10
+
11
+ Saved graphs:
12
+ - task3_concept_space.png
13
+ - task3_pca_explained_variance.png
14
+ - task3_diversity_curve.png
15
+
16
+ Diversity spectrum:
17
+ alpha=-2.0 → बले वेध विवर् धान वीर्य वीर्य धिं सिंहा भि̱ सन वस्तु वेध वै वेध वस्तु सन सन सिंहा सिंहा वीर्य वीर्य वस्तु सन रुते प्रभवति मन वेध बले बले र्वृ प्रपूजयेत् युगा मलि धान तुल वीर्य वीर्य वीर्य वीर्य वीर्य वीर्य धान तुल कालेन युगा वेध बले वेध वेध च्छे ष्मस् यस्या काष्ठा ज्ञप्त अर्णव धिं धिं वस्तु धिं सन तया सन सन देवाः देवाः स्वातन्त्र अर्णव मह वस्तु मुष् सन धिं धिं धिं विक्र त्र मह हस्ते च्छे मह
18
+ alpha=-1.0 → बले र् अ तुल वीर्य वीर्य गुरु सिंहा सन सन विलेप वै वै वै गतस्य वेध सन सिंहा सिंहा स्य स्य । सन वै वै वै बले बले बले बले र् अ अ तुल तुल वीर्य वीर्य वीर्य वीर्य वीर्य वीर्य तुल तुल तुल ् बले वेध दिव्यां मान वै अप्सु सन ॥ ॥ वस्तु सिंहा सन सन विक्र सन स काष्ठा सन सन सन कार सन सन सन सन भ बल ु सिंहा सन सिंहा सन म् म् सन
19
+ alpha=+0.0 → बले र् अ तुल वीर्य वीर्य स्य सिंहा सन सन पितो वै वै वै दक्षिणां सन सन सिंहा सिंहा स्य स्य स्य सन गतस्य वै वै ॥ बले बले र् र् अ अ । तुल वीर्य वीर्य वीर्य वीर्य वीर्य तुल तुल तुल तुल अ स बले बले वै वै ॥ ॥ ॥ सन सन सिंहा स सन सन सन सन सन सन सन सन सन सन ॥ ॥ सन सन शतैः ॥ सिंहा सिंहा द सिंहा सन त् सन
20
+ alpha=+1.0 → बले र् अ अ विशुद्धं स्य स्य सिंहा सिंहा सन गतस्य वै वै वै वेत्ति सन सन सिंहा स्य स्य स्य स्य सन वै वै स मल बले बले र् र् व अ अ तुल वीर्य वीर्य वीर्य स्य वीर्य स्य तुल ानु अ अ । र् व ॥ वै वै सन द ॥ ॥ सिंहा सिंहा ॥ सं सन ॥ ॥ व ॥ ॥ हेम सन सन व ॥ ै ॥ वै भ न न ॥ मित्रो सिंहा सन
21
+ alpha=+2.0 → आविश र् अ किंचिद् वर स्य स्य सिंहा सं निमे ञ् सं वै वै ञ् सन कृपा सिंहा स्य स्य स्य स्य फणा ञ् वै ौ जिह्व बले मानाः र् र् वराय अ माने वर विशुद्धं स्य स्य स्य – वर विशुद्धं व वर अ कृपा ॥ परम् ॥ कश्चि वै ॥ ञ् ञ् सं स्य स्य तम् व प्रवर्तन्ते कर्मसु परम् वर ते ॥ व ञ् ॥ ॥ सं द ॥ ॥ वर न्द ̱व ॥ व व ै
analysis_outputs/outputs_all_models_20260325/T16/task4_3d.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task4_raw_results.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16": {
3
+ "bertscore_f1": 0.25743605086845023,
4
+ "semantic_sim": 0.05798209163692987,
5
+ "bleu": 0.0007454091523007641,
6
+ "speed_per_sample": 0.9068318999983603
7
+ }
8
+ }
analysis_outputs/outputs_all_models_20260325/T16/task4_report.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 4 — SEMANTIC ROBUSTNESS ABLATION
2
+ ==================================================
3
+
4
+ Optimal diffusion steps = 16
5
+
6
+ T BERT-F1 SEM_SIM BLEU sec/sample
7
+ --------------------------------------------------------
8
+ 16 0.2574 0.0580 0.0007 0.9068
9
+
10
+ Marginal gains (BERT-F1):
11
+
12
+ Saved plots/files:
13
+ - task4_3d.png
14
+ - task4_raw_results.json
analysis_outputs/outputs_all_models_20260325/T16/task5_guidance_results.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "0.0": {
3
+ "mean_cer": 0.8335914296177765,
4
+ "diversity": 0.8084225118773136,
5
+ "sent_unique": 1.0,
6
+ "distinct2": 0.6240506329113924,
7
+ "self_bleu": 0.00720560915676511
8
+ },
9
+ "0.5": {
10
+ "mean_cer": 0.8361858372849987,
11
+ "diversity": 0.7997218378718688,
12
+ "sent_unique": 1.0,
13
+ "distinct2": 0.6060126582278481,
14
+ "self_bleu": 0.0065689824841105166
15
+ },
16
+ "1.0": {
17
+ "mean_cer": 0.8390361847911715,
18
+ "diversity": 0.7978319711295725,
19
+ "sent_unique": 1.0,
20
+ "distinct2": 0.6009493670886076,
21
+ "self_bleu": 0.005285424829462745
22
+ },
23
+ "1.5": {
24
+ "mean_cer": 0.8457771777829102,
25
+ "diversity": 0.8134699633307632,
26
+ "sent_unique": 1.0,
27
+ "distinct2": 0.6306962025316456,
28
+ "self_bleu": 0.0037562758701191663
29
+ },
30
+ "2.0": {
31
+ "mean_cer": 0.8530737908495466,
32
+ "diversity": 0.828318481566094,
33
+ "sent_unique": 1.0,
34
+ "distinct2": 0.6604430379746835,
35
+ "self_bleu": 0.003806074842495409
36
+ },
37
+ "3.0": {
38
+ "mean_cer": 0.8772574230238586,
39
+ "diversity": 0.829961794478179,
40
+ "sent_unique": 1.0,
41
+ "distinct2": 0.6686708860759494,
42
+ "self_bleu": 0.008747297119591432
43
+ }
44
+ }
analysis_outputs/outputs_all_models_20260325/T16/task5_quality_classifier.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4053d24514f08c475662f69a5b01d1577cd1f79837df69ac2175705310e9a23
3
+ size 561505
analysis_outputs/outputs_all_models_20260325/T16/task5_quality_data.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:840494704113872c8e53e3627e5666c8af940ed683515ec37894dd3091a14684
3
+ size 164512
analysis_outputs/outputs_all_models_20260325/T16/task5_quality_diversity_tradeoff.png ADDED
analysis_outputs/outputs_all_models_20260325/T16/task5_report.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 5 — CLASSIFIER-FREE GUIDANCE
2
+ ==================================================
3
+
4
+ Classifier params: 139521
5
+ Training samples : 40
6
+
7
+ Guidance scale sweep:
8
+ λ CER diversity d2 sBLEU
9
+ ----------------------------------------------------
10
+ 0.0 0.8336 0.808 0.624 0.007 ← optimal
11
+ 0.5 0.8362 0.800 0.606 0.007
12
+ 1.0 0.8390 0.798 0.601 0.005
13
+ 1.5 0.8458 0.813 0.631 0.004
14
+ 2.0 0.8531 0.828 0.660 0.004
15
+ 3.0 0.8773 0.830 0.669 0.009
analysis_outputs/outputs_all_models_20260325/T32/task1_encoder_cost.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task1_kv_cache.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 1 — KV CACHE BENCHMARK
2
+ ========================================
3
+
4
+ has_generate_cached=True
5
+ memory_profile=Torch CPU mem-event reduction: 31.1% @ src_len=64 (std=4287.2MB, cache=2953.9MB)
6
+
7
+ src_len standard(s) cached(s) speedup encoder%
8
+ 16 1.914 1.165 1.64x 39.6%
9
+ 32 1.542 0.891 1.73x 42.1%
10
+ 64 2.096 1.475 1.42x 42.7%
11
+
12
+ Saved graphs:
13
+ - task1_time_comparison.png
14
+ - task1_speedup.png
15
+ - task1_encoder_cost.png
analysis_outputs/outputs_all_models_20260325/T32/task1_speedup.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task1_time_comparison.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_all_layers_t0.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_attn_evolution.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_attn_t0.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_attn_t31.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_report.txt ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 2 — ATTENTION + DRIFT REPORT
2
+ ==================================================
3
+
4
+ Input : dharmo rakṣati rakṣitaḥ
5
+ Output: धर्मो रक्षति रक्षितः
6
+
7
+ Captured steps: 32
8
+ Analysis quality: WEAK
9
+ Final output uniq-ratio: 1.000
10
+ Degenerate output: NO
11
+ Multi-sample semantic score (n<=8): 0.0627
12
+ Lock-in step (CER<=0.05): t=0
13
+ Locked tokens: 75 Flexible tokens: 5
14
+ TF-IDF vs attention stability corr: -0.0869
15
+ TF-IDF status: WEAK
16
+
17
+ Saved graphs:
18
+ - task2_attn_t*.png / task2_all_layers_t0.png
19
+ - task2_attn_evolution.png
20
+ - task2_semantic_drift.png
21
+ - task2_source_alignment.png
22
+ - task2_tfidf_vs_attention.png
23
+
24
+ Step trajectory (first 10 rows)
25
+ ------------------------------------------------------------
26
+ t= 31 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
27
+ t= 30 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
28
+ t= 29 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
29
+ t= 28 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
30
+ t= 27 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
31
+ t= 26 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
32
+ t= 25 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
33
+ t= 24 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
34
+ t= 23 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
35
+ t= 22 bert=0.0167 drift=0.9833 text=तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ तृ
analysis_outputs/outputs_all_models_20260325/T32/task2_semantic_drift.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_source_alignment.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task2_tfidf_vs_attention.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task3_concept_space.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task3_diversity_curve.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task3_diversity_direction.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e547306c9469858deaa9985c30d31639c8a9f8104e8addd83afa88fa0264831
3
+ size 4224
analysis_outputs/outputs_all_models_20260325/T32/task3_pca_explained_variance.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task3_report.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 3 — CONCEPT VECTORS + PCA STEERING
2
+ ==================================================
3
+
4
+ PCA: 50 components, 94.6% variance
5
+ Diversity PC: 0 (|r|=-0.530 with diversity proxy)
6
+
7
+ Direction validity: WEAK
8
+ Spectrum unique ratio (mean over 5 seeds): 0.840
9
+ Spectrum semantic stability (mean over 5 seeds): 0.234
10
+
11
+ Saved graphs:
12
+ - task3_concept_space.png
13
+ - task3_pca_explained_variance.png
14
+ - task3_diversity_curve.png
15
+
16
+ Diversity spectrum:
17
+ alpha=-2.0 → ेन श्रे श्रे ेन श्रे अण्ड व्याः श्रे तन्त्रा ॥ ॥ ॥ व्याः व्याः व्याः तद्वद् तद्वद् तद्वद् ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ तद्वद् ॥ ॥ ॥ ॥ ॥ व्याः व्याः व्याः ॥ ॥ राजन्य व्याः व्याः व्याः ॥ व्याः व्याः ॥ ॥ काम्य ॥ ॥ ॥ व्याः ॥ तद्वद् ॥ ॥ ॥ ॥ ॥ तन्त्रा तन्त्रा ॥ ॥ ॥ ॥ व्याः ॥ ॥ ॥ ॥ ॥ युधम् तद्वद् युधम् ॥
18
+ alpha=-1.0 → श्रे श्रे श्रे ेन श्रे श्रे श्रे श्रे अण्ड तन्त्रा व्याः ॥ अण्ड अण्ड तन्त्रा व्याः तद्वद् ॥ व्याः ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ व्याः ॥ व्याः नो̍ ॥ ॥ ॥ ॥ ॥ व्याः व्याः अण्ड ॥ ॥ तन्त्रा ॥ ॥ तद्वद् युधम् रोमा शम्भु ॥ धूमं तन्त्रा ॥ तन्त्रा ॥ व्याः ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥ ॥
19
+ alpha=+0.0 → अण्ड श्रे करः श्रे तन्त्रा करः करः तन्त्रा श्रे अण्ड अण्ड अण्ड ॥ श्रे तद्वद् अण्ड ॥ ॥ अण्ड ॥ ॥ ॥ ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ ॥ अण्ड ॥ ॥ ॥ ॥ ॥ ॥ राजन्य तन्त्रा नो̍ ॥ ॥ ॥ ॥ ॥ व्याः ॥ अण्ड ॥ काम्य ॥ ॥ ॥ ॥ ॥ शम्भु धूमं तन्त्रा तन्त्रा ेन ॥ काम्य ॥ ॥ करः तन्त्रा ॥ अण्ड ॥ अण्ड ॥ विनिर्जित्य ॥ ॥ ॥ तन्त्रा अण्ड तद्वद् करः
20
+ alpha=+1.0 → माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण
21
+ alpha=+2.0 → माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण माण
analysis_outputs/outputs_all_models_20260325/T32/task4_3d.png ADDED
analysis_outputs/outputs_all_models_20260325/T32/task4_raw_results.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "32": {
3
+ "bertscore_f1": 0.04221478336089375,
4
+ "semantic_sim": 0.0011696306429548563,
5
+ "bleu": 3.0458312005937454e-233,
6
+ "speed_per_sample": 1.8451481468771818
7
+ }
8
+ }
analysis_outputs/outputs_all_models_20260325/T32/task4_report.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ TASK 4 — SEMANTIC ROBUSTNESS ABLATION
2
+ ==================================================
3
+
4
+ Optimal diffusion steps = 32
5
+
6
+ T BERT-F1 SEM_SIM BLEU sec/sample
7
+ --------------------------------------------------------
8
+ 32 0.0422 0.0012 0.0000 1.8451
9
+
10
+ Marginal gains (BERT-F1):
11
+
12
+ Saved plots/files:
13
+ - task4_3d.png
14
+ - task4_raw_results.json