"""
analysis/quality_classifier.py
================================
Task 5: Classifier-Free Guidance for Paraphrase Quality Control
(Note: the mechanism implemented below is classifier guidance — a small
trained classifier steers sampling of the frozen model.)

Three steps — only Step 2 requires training a SMALL model (not the main D3PM):

STEP 1 — Collect training data (no training):
    Run existing model on val set, record (hidden_state, CER) pairs.
    Hidden states come from model.model._last_hidden after forward_cached().
    CER score = quality label (lower CER = higher quality).

STEP 2 — Train quality classifier:
    Small MLP: d_model → 128 → 64 → 1
    Input: pooled decoder hidden state [B, d_model]
    Output: predicted quality score in [0, 1] (1 = high quality)
    Loss: MSE against quality labels (quality = 1 - CER, clipped to [0, 1])
    Training time: ~5-10 minutes on CPU for 10k examples

STEP 3 — Guided inference (no retraining):
    At each diffusion step, use classifier gradient to shift logits:
        guided_logits = logits + λ * ∂(quality_score)/∂(logits)
    Higher λ → model biased toward high-quality outputs
    λ=0 → standard generation (no guidance)

Key: main D3PM model is FROZEN throughout. Only the 10k-param classifier trains.
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import os |
| import json |
| from typing import List, Dict, Optional, Tuple |
|
|
|
|
| |
|
|
class QualityClassifier(nn.Module):
    """
    Small feed-forward regressor that scores transliteration quality from
    decoder hidden states.

    Layer stack (stored in ``self.net``):
        Linear(d_model, 128) -> ReLU -> Dropout(0.1)
        -> Linear(128, 64) -> ReLU -> Linear(64, 1) -> Sigmoid

    Input:  decoder hidden state, either already pooled [B, d_model] or
            per-position [B, tgt_len, d_model] (mean-pooled in forward()).
    Output: quality score [B, 1] in [0, 1]; by this file's convention
            1 means high quality.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        # Layer order (and therefore the state_dict keys net.0 / net.3 /
        # net.5) must stay fixed so saved weights remain loadable.
        stages = [
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Score a batch of hidden states.

        Args:
            hidden : [B, tgt_len, d_model] (averaged over dim 1 here)
                     or [B, d_model] (used as given)

        Returns:
            score : [B, 1] quality score in [0, 1]
        """
        pooled = hidden.mean(dim=1) if hidden.dim() == 3 else hidden
        return self.net(pooled)
|
|
|
|
| |
|
|
@torch.no_grad()
def collect_quality_data(
    model,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    t_capture: int = 0,
    temperature: float = 0.8,
    top_k: int = 40,
    max_samples: int = 5000,
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Collect (hidden_state, quality_score) pairs for classifier training.

    For each (src, ref) pair:
        1. Run the reverse diffusion loop (t = T-1 ... 0) through
           ``inner.forward_cached``, sampling with temperature / top-k and
           taking the argmax on the final step.
        2. At step ``t_capture`` store ``inner._last_hidden[0]`` averaged over
           its first axis (assumed to be the target-length axis, giving a
           [d_model] vector -- TODO confirm against the model).
        3. Decode the final ids and compute CER against the reference.
        4. Quality = max(0, 1 - CER).

    The inner model is switched to eval mode while collecting (so dropout
    cannot perturb the recorded hidden states) and its previous
    train/eval mode is restored afterwards.

    Args:
        model         : object exposing the network as ``model.model``
        src_list      : list of source id tensors, each [src_len] or [1, src_len]
        ref_list      : list of reference strings, aligned with src_list
        tgt_tokenizer : object with ``decode(list_of_ids) -> str``
        t_capture     : diffusion step at which the hidden state is recorded
                        (0 = final step); must satisfy 0 <= t_capture < T
        temperature   : softmax temperature used while sampling
        top_k         : keep only the k largest logits per position (<= 0 disables)
        max_samples   : cap on the number of pairs processed

    Returns:
        hidden_matrix  : np.ndarray [N, d_model]
        quality_scores : np.ndarray [N], float32, values in [0, 1]

    Raises:
        ValueError : if ``t_capture`` is outside [0, T-1], or if no hidden
                     state could be captured for any sample.
    """
    inner = model.model
    T = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device

    # Fail fast: an out-of-range step would otherwise run every generation
    # and only then die on an empty result.
    if not 0 <= t_capture < T:
        raise ValueError(f"t_capture must be in [0, {T - 1}], got {t_capture}")

    def cer(pred, ref):
        """Character error rate: Levenshtein(pred, ref) / len(ref)."""
        if not ref:
            return 1.0
        # Single-row dynamic-programming edit distance.
        row = list(range(len(ref) + 1))
        for i, p_ch in enumerate(pred, start=1):
            diag, row[0] = row[0], i
            for j, r_ch in enumerate(ref, start=1):
                above = row[j]
                if p_ch == r_ch:
                    row[j] = diag
                else:
                    row[j] = 1 + min(diag, above, row[j - 1])
                diag = above
        return row[-1] / len(ref)

    hidden_list = []
    quality_list = []
    # zip() stops at the shorter list, so count both lists for the progress total.
    n = min(len(src_list), len(ref_list), max_samples)

    # Loop-invariant model attributes.
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    was_training = inner.training
    inner.eval()
    try:
        print(f"Collecting quality data from {n} examples...")
        for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
            if i % 200 == 0:
                print(f" {i}/{n}")

            if src.dim() == 1:
                src = src.unsqueeze(0)
            src = src.to(device)
            B = src.shape[0]

            memory, src_pad_mask = inner.encode_source(src)
            # Start from an all-mask target sequence.
            x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
            hint = None
            h_cap = None

            for t_val in range(T - 1, -1, -1):
                t = torch.full((B,), t_val, dtype=torch.long, device=device)
                is_last = t_val == 0

                logits, _ = inner.forward_cached(
                    memory, src_pad_mask, x0_est, t,
                    x0_hint=hint, inference_mode=True,
                )

                if t_val == t_capture:
                    last_hidden = getattr(inner, "_last_hidden", None)
                    if last_hidden is not None:
                        # First batch element only, pooled over its leading axis.
                        h_cap = last_hidden[0].mean(dim=0).detach().cpu()

                logits = logits / max(temperature, 1e-8)
                if 0 < top_k < logits.shape[-1]:
                    vals, _ = torch.topk(logits, top_k, dim=-1)
                    logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

                probs = F.softmax(logits, dim=-1)
                # Greedy on the final step, stochastic before it.
                x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
                hint = x0_est

            if h_cap is None:
                continue

            # ids <= 4 are dropped before decoding -- presumably the reserved
            # special-token range; verify against the tokenizer.
            ids = [x for x in x0_est[0].tolist() if x > 4]
            pred = tgt_tokenizer.decode(ids).strip()
            q = max(0.0, 1.0 - cer(pred, ref))

            hidden_list.append(h_cap.numpy())
            quality_list.append(q)
    finally:
        inner.train(was_training)

    if not hidden_list:
        raise ValueError(
            "No hidden states were captured: the input lists are empty or the "
            "model does not expose `_last_hidden` after forward_cached()."
        )

    print(f"Collected {len(hidden_list)} quality examples.")
    print(f"Quality stats: mean={np.mean(quality_list):.3f} "
          f"min={np.min(quality_list):.3f} max={np.max(quality_list):.3f}")

    return np.stack(hidden_list), np.array(quality_list, dtype=np.float32)
|
|
|
|
| def _sample(probs): |
| B, L, V = probs.shape |
| flat = probs.view(B * L, V).clamp(min=1e-9) |
| flat = flat / flat.sum(dim=-1, keepdim=True) |
| return torch.multinomial(flat, 1).squeeze(-1).view(B, L) |
|
|
|
|
| |
|
|
def train_quality_classifier(
    hidden_matrix: np.ndarray,
    quality_scores: np.ndarray,
    d_model: int,
    epochs: int = 30,
    batch_size: int = 64,
    lr: float = 1e-3,
    val_frac: float = 0.1,
    save_path: Optional[str] = None,
) -> QualityClassifier:
    """
    Train a QualityClassifier on collected (hidden, quality) pairs.

    The data is split randomly into train / validation parts, the classifier
    is fitted with Adam on an MSE loss, and the weights with the lowest
    validation loss are restored before returning. Training runs on CPU.

    Args:
        hidden_matrix  : [N, d_model] features, e.g. from collect_quality_data()
        quality_scores : [N] (or [N, 1]) quality labels in [0, 1]
        d_model        : hidden dimension; must equal hidden_matrix.shape[1]
        epochs         : number of passes over the training split
        batch_size     : mini-batch size
        lr             : Adam learning rate
        val_frac       : fraction of examples held out for validation
                         (at least 1 example, at most N - 1)
        save_path      : if given, the classifier state_dict is saved here

    Returns:
        trained QualityClassifier (best-validation weights loaded)

    Raises:
        ValueError : if the inputs have inconsistent shapes or fewer than two
                     examples (one is needed for each of train and validation).
    """
    device = torch.device("cpu")

    X = torch.tensor(hidden_matrix, dtype=torch.float32).to(device)
    # reshape(-1, 1) accepts both [N] and [N, 1] labels; a blind unsqueeze
    # would turn [N, 1] into [N, 1, 1] and silently broadcast in the loss.
    y = torch.tensor(quality_scores, dtype=torch.float32).reshape(-1, 1).to(device)

    if X.dim() != 2:
        raise ValueError(f"hidden_matrix must be 2-D [N, d_model], got shape {tuple(X.shape)}")
    if X.shape[1] != d_model:
        raise ValueError(f"hidden_matrix has feature size {X.shape[1]} but d_model={d_model}")
    if y.shape[0] != X.shape[0]:
        raise ValueError(
            f"hidden_matrix has {X.shape[0]} rows but quality_scores has {y.shape[0]} labels"
        )

    N = X.shape[0]
    if N < 2:
        raise ValueError(f"Need at least 2 examples to train and validate, got {N}")

    # Hold out at least one example, but always leave one for training so the
    # per-epoch batch count can never be zero.
    n_val = min(max(1, int(N * val_frac)), N - 1)
    idx = torch.randperm(N)
    val_idx = idx[:n_val]
    train_idx = idx[n_val:]

    X_train, y_train = X[train_idx], y[train_idx]
    X_val, y_val = X[val_idx], y[val_idx]

    clf = QualityClassifier(d_model).to(device)
    optimizer = torch.optim.Adam(clf.parameters(), lr=lr)

    print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
    print(f"Train: {len(X_train)} Val: {len(X_val)}")

    best_val_loss = float('inf')
    best_state = None

    for epoch in range(epochs):
        clf.train()
        perm = torch.randperm(len(X_train))
        train_loss = 0.0
        n_batches = 0

        for start in range(0, len(X_train), batch_size):
            batch_idx = perm[start:start + batch_size]
            xb, yb = X_train[batch_idx], y_train[batch_idx]
            pred = clf(xb)
            loss = F.mse_loss(pred, yb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_batches += 1

        clf.eval()
        with torch.no_grad():
            val_pred = clf(X_val)
            val_loss = F.mse_loss(val_pred, y_val).item()

        if epoch % 5 == 0 or epoch == epochs - 1:
            print(f" Ep {epoch+1:3d} train={train_loss/n_batches:.4f} val={val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Clone: state_dict() returns references to the live parameters.
            best_state = {k: v.clone() for k, v in clf.state_dict().items()}

    # best_state stays None when epochs == 0 or every validation loss was NaN;
    # in that case the last weights are returned unchanged.
    if best_state is not None:
        clf.load_state_dict(best_state)
        print(f" Best val loss: {best_val_loss:.4f}")

    if save_path:
        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
        torch.save(clf.state_dict(), save_path)
        print(f" Classifier saved: {save_path}")

    return clf
|
|
|
|
| |
|
|
def generate_guided(
    model,
    src: torch.Tensor,
    classifier: QualityClassifier,
    guidance_scale: float = 1.0,
    temperature: float = 0.8,
    top_k: int = 40,
) -> torch.Tensor:
    """
    Classifier-guided generation.

    At each diffusion step (t = T-1 ... 0) with guidance_scale > 0:
        1. Run the decoder stack to get hidden states and logits.
        2. Compute the classifier gradient d(quality_score) / d(hidden).
        3. Project that gradient into logit space with the output head's
           weight matrix (an approximation).
        4. guided_logits = logits + guidance_scale * projected_gradient
        5. Sample from guided_logits (argmax on the final step).

    With guidance_scale <= 0 each step simply calls ``inner.forward_cached``.

    NOTE(review): the guided branch re-implements the decoder pass by hand
    (q_sample -> tgt_embed -> time_mlp -> hint_gate -> decoder_blocks -> head)
    instead of calling forward_cached; it is assumed to mirror that method --
    confirm against the model whenever forward_cached changes.

    NOTE: the classifier mean-pools over positions, so the gradient (and
    therefore the logit shift) is the same for every target position.

    guidance_scale:
        0.0 -> no guidance (standard generation)
        0.5 -> weak guidance
        1.0 -> moderate guidance (recommended starting point)
        2.0 -> strong guidance (may reduce diversity)
        3.0 -> very strong (may collapse to repetitive output)

    Side effects: puts both ``model.model`` and ``classifier`` in eval mode.

    Args:
        model          : object exposing the (frozen) network as ``model.model``
        src            : source token ids, [src_len] or [B, src_len]
        classifier     : trained QualityClassifier (may live on another device)
        guidance_scale : guidance strength
        temperature    : softmax temperature used while sampling
        top_k          : keep only the k largest logits per position (<= 0 disables)

    Returns:
        x0_est : [B, tgt_len] generated token ids
    """
    inner = model.model
    T = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device
    clf_device = next(classifier.parameters()).device

    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)

    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    # Switch to eval BEFORE encoding so dropout cannot perturb the encoder pass.
    inner.eval()
    classifier.eval()

    # The main model is frozen: no autograd graph is needed for the encoder.
    with torch.no_grad():
        memory, src_pad_mask = inner.encode_source(src)
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    for t_val in range(T - 1, -1, -1):
        t = torch.full((B,), t_val, dtype=torch.long, device=device)
        is_last = t_val == 0

        if guidance_scale > 0.0:
            # Decoder pass through the frozen model; gradients are only
            # needed w.r.t. the classifier input, so no graph is built here.
            with torch.no_grad():
                if t_val > 0:
                    _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
                else:
                    x_t_ids = x0_est

                x = inner.tgt_embed(x_t_ids)
                t_norm = t.float() / T
                t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
                x = x + t_emb.unsqueeze(1)

                if hint is not None:
                    hint_emb = inner.tgt_embed(hint)
                    gate = inner.hint_gate(x)
                    x = x + gate * hint_emb

                for block in inner.decoder_blocks:
                    x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)

                logits = inner.head(x)

            with torch.enable_grad():
                # Move to the classifier device first, then mark as requiring
                # grad: the tensor must be a leaf for a gradient to reach it.
                # (Calling .to() afterwards yields a non-leaf when the devices
                # differ, and its .grad would stay None.)
                hidden = x.detach().to(clf_device).requires_grad_(True)
                quality = classifier(hidden)
                # autograd.grad returns the input gradient directly and does
                # not accumulate .grad on the classifier parameters.
                (grad,) = torch.autograd.grad(quality.sum(), hidden)

            with torch.no_grad():
                # Approximate projection from hidden space to logit space.
                logit_grad = grad.to(device) @ inner.head.weight.T
                logits = logits + guidance_scale * logit_grad
        else:
            with torch.no_grad():
                logits, _ = inner.forward_cached(
                    memory, src_pad_mask, x0_est, t,
                    x0_hint=hint, inference_mode=True,
                )

        with torch.no_grad():
            logits = logits / max(temperature, 1e-8)
            if 0 < top_k < logits.shape[-1]:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

            probs = F.softmax(logits, dim=-1)
            # Greedy on the final step, stochastic before it.
            x0_est = torch.argmax(probs, dim=-1) if is_last else _sample_no_grad(probs)
            hint = x0_est

    return x0_est
|
|
|
|
| def _sample_no_grad(probs): |
| B, L, V = probs.shape |
| flat = probs.view(B * L, V).clamp(min=1e-9) |
| flat = flat / flat.sum(dim=-1, keepdim=True) |
| return torch.multinomial(flat, 1).squeeze(-1).view(B, L) |
|
|
|
|
| |
|
|
def sweep_guidance_scales(
    model,
    classifier: QualityClassifier,
    src_list: List[torch.Tensor],
    ref_list: List[str],
    tgt_tokenizer,
    scales: Optional[List[float]] = None,
    n_samples: int = 50,
    device: Optional[torch.device] = None,
    output_dir: str = "analysis/outputs",
) -> Dict:
    """
    Evaluate CER and output diversity at each guidance scale.

    For every scale, the first ``n_samples`` sources are generated with
    generate_guided(); the mean character error rate against the references
    and the fraction of distinct decoded outputs are recorded. Results are
    written to ``<output_dir>/guidance_results.json`` and, when matplotlib
    is importable, plotted to ``<output_dir>/guidance_scale_sweep.png``.

    Args:
        model         : model passed through to generate_guided()
        classifier    : trained QualityClassifier
        src_list      : list of source id tensors ([src_len] or [1, src_len])
        ref_list      : reference strings aligned with src_list
        tgt_tokenizer : object with ``decode(list_of_ids) -> str``
        scales        : guidance scales to try; None selects
                        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
        n_samples     : maximum number of (src, ref) pairs evaluated per scale
        device        : device for the source tensors; None uses the device of
                        the model's first parameter
        output_dir    : directory for the JSON file and the plot

    Returns:
        dict mapping scale -> {"mean_cer": float, "diversity": float}

    Raises:
        ValueError : if there are no samples to evaluate.
    """
    def cer(pred, ref):
        """Character error rate: Levenshtein(pred, ref) / len(ref)."""
        if not ref:
            return 1.0
        # Single-row dynamic-programming edit distance.
        row = list(range(len(ref) + 1))
        for i, p_ch in enumerate(pred, start=1):
            diag, row[0] = row[0], i
            for j, r_ch in enumerate(ref, start=1):
                above = row[j]
                if p_ch == r_ch:
                    row[j] = diag
                else:
                    row[j] = 1 + min(diag, above, row[j - 1])
                diag = above
        return row[-1] / len(ref)

    # Built per call rather than shared as a mutable default argument.
    if scales is None:
        scales = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0]
    if device is None:
        device = next(model.parameters()).device

    # zip() stops at the shorter list, so count both lists.
    n = min(n_samples, len(src_list), len(ref_list))
    if n <= 0:
        # Averaging zero samples would put NaN into the results and the JSON.
        raise ValueError("sweep_guidance_scales needs at least one (src, ref) pair")

    results = {}

    print("\nGuidance scale sweep...")
    for scale in scales:
        cer_list = []
        output_set = []
        for src, ref in zip(src_list[:n], ref_list[:n]):
            if src.dim() == 1:
                src = src.unsqueeze(0)
            out = generate_guided(model, src.to(device), classifier,
                                  guidance_scale=scale)
            # ids <= 4 are dropped before decoding -- presumably the reserved
            # special-token range; verify against the tokenizer.
            ids = [x for x in out[0].tolist() if x > 4]
            pred = tgt_tokenizer.decode(ids).strip()
            cer_list.append(cer(pred, ref))
            output_set.append(pred)

        mean_cer = float(np.mean(cer_list))

        # Fraction of distinct decoded strings across the evaluated sources.
        unique_frac = len(set(output_set)) / max(len(output_set), 1)

        results[scale] = {"mean_cer": mean_cer, "diversity": unique_frac}
        print(f" λ={scale:.1f} CER={mean_cer:.4f} diversity={unique_frac:.3f}")

    os.makedirs(output_dir, exist_ok=True)

    # Plotting is optional: only the import is guarded, so genuine plotting
    # errors are not hidden.
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        plt = None
        print(" matplotlib not available; skipping guidance sweep plot.")

    if plt is not None:
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

        sc_list = sorted(results.keys())
        cers = [results[s]["mean_cer"] for s in sc_list]
        diversities = [results[s]["diversity"] for s in sc_list]

        ax1.plot(sc_list, cers, 'o-', color='coral', linewidth=1.8, markersize=7)
        ax1.set_xlabel("Guidance scale λ", fontsize=10)
        ax1.set_ylabel("CER (↓ better)", fontsize=10)
        ax1.set_title("Quality vs guidance scale", fontsize=10)

        ax2.plot(sc_list, diversities, 'o-', color='steelblue', linewidth=1.8, markersize=7)
        ax2.set_xlabel("Guidance scale λ", fontsize=10)
        ax2.set_ylabel("Output diversity (unique fraction)", fontsize=10)
        ax2.set_title("Diversity vs guidance scale", fontsize=10)

        fig.suptitle("Quality-Diversity Tradeoff (Guidance Scale Sweep)", fontsize=11)
        fig.tight_layout()
        path = os.path.join(output_dir, "guidance_scale_sweep.png")
        fig.savefig(path, dpi=150, bbox_inches='tight')
        plt.close(fig)
        print(f" Saved: {path}")

    with open(os.path.join(output_dir, "guidance_results.json"), "w", encoding="utf-8") as f:
        json.dump({str(k): v for k, v in results.items()}, f, indent=2)

    return results
|
|