Spaces:
Sleeping
Sleeping
| # """ | |
| # analysis/quality_classifier.py | |
| # ================================ | |
| # Task 5: Classifier-Free Guidance for Paraphrase Quality Control | |
| # | |
| # Three steps — only Step 2 requires training a SMALL model (not the main D3PM): | |
| # | |
| # STEP 1 — Collect training data (no training): | |
| # Run existing model on val set, record (hidden_state, CER) pairs. | |
| # Hidden states come from model.model._last_hidden after forward_cached(). | |
| # CER score = quality label (lower CER = higher quality). | |
| # | |
| # STEP 2 — Train quality classifier: | |
| # Small 3-layer MLP: d_model → 128 → 64 → 1 | |
| # Input: pooled decoder hidden state [B, d_model] | |
| # Output: predicted quality score in [0, 1] (1 = high quality) | |
| # Loss: MSE against normalized CER labels | |
| # Training time: ~5-10 minutes on CPU for 10k examples | |
| # | |
| # STEP 3 — Guided inference (no retraining): | |
| # At each diffusion step, use classifier gradient to shift logits: | |
| # guided_logits = logits + λ * ∂(quality_score)/∂(logits) | |
| # Higher λ → model biased toward high-quality outputs | |
| # λ=0 → standard generation (no guidance) | |
| # | |
| # Key: main D3PM model is FROZEN throughout. Only the 10k-param classifier trains. | |
| # """ | |
| # | |
| # import torch | |
| # import torch.nn as nn | |
| # import torch.nn.functional as F | |
| # import numpy as np | |
| # import os | |
| # import json | |
| # from typing import List, Dict, Optional, Tuple | |
| # | |
| # | |
| # ── Quality classifier architecture ────────────────────────────────── | |
| # | |
| # class QualityClassifier(nn.Module): | |
| # """ | |
| # Lightweight MLP that predicts transliteration quality from decoder | |
| # hidden states. | |
| # | |
| # Architecture: | |
| # d_model → 128 → 64 → 1 → Sigmoid | |
| # | |
| # Input: mean-pooled decoder hidden state [B, d_model] | |
| # Output: quality score [B, 1] ∈ [0, 1] (1 = high quality) | |
| # | |
| # ~10k parameters. Trains in minutes on CPU. | |
| # """ | |
| # def __init__(self, d_model: int): | |
| # super().__init__() | |
| # self.net = nn.Sequential( | |
| # nn.Linear(d_model, 128), | |
| # nn.ReLU(), | |
| # nn.Dropout(0.1), | |
| # nn.Linear(128, 64), | |
| # nn.ReLU(), | |
| # nn.Linear(64, 1), | |
| # nn.Sigmoid(), | |
| # ) | |
| # self.d_model = d_model | |
| # | |
| # def forward(self, hidden: torch.Tensor) -> torch.Tensor: | |
| # """ | |
| # Args: | |
| # hidden : [B, tgt_len, d_model] OR [B, d_model] (already pooled) | |
| # | |
| # Returns: | |
| # score : [B, 1] quality score in [0, 1] | |
| # """ | |
| # if hidden.dim() == 3: | |
| # # Pool over sequence length | |
| # hidden = hidden.mean(dim=1) # [B, d_model] | |
| # return self.net(hidden) # [B, 1] | |
| # | |
| # | |
| # ── Training data collection ────────────────────────────────────────── | |
| # | |
| # @torch.no_grad() | |
| # def collect_quality_data( | |
| # model, | |
| # src_list: List[torch.Tensor], | |
| # ref_list: List[str], | |
| # tgt_tokenizer, | |
| # t_capture: int = 0, | |
| # temperature: float = 0.8, | |
| # top_k: int = 40, | |
| # max_samples: int = 5000, | |
| # ) -> Tuple[np.ndarray, np.ndarray]: | |
| # """ | |
| # Collect (hidden_state, quality_score) pairs for classifier training. | |
| # | |
| # For each sample: | |
| # 1. Run generate_cached() on src | |
| # 2. Capture decoder hidden state at t=t_capture | |
| # 3. Compute CER between output and reference | |
| # 4. Quality = 1 - CER (normalize to [0,1]) | |
| # | |
| # Args: | |
| # model : SanskritModel | |
| # src_list : list of [1, src_len] tensors | |
| # ref_list : list of reference Devanagari strings | |
| # tgt_tokenizer : SanskritTargetTokenizer | |
| # t_capture : which step to capture hidden states (0 = final) | |
| # max_samples : cap number of training examples | |
| # | |
| # Returns: | |
| # hidden_matrix : np.ndarray [N, d_model] | |
| # quality_scores: np.ndarray [N] values in [0, 1] | |
| # """ | |
| # inner = model.model | |
| # T = inner.scheduler.num_timesteps | |
| # device = next(inner.parameters()).device | |
| # | |
| # hidden_list = [] | |
| # quality_list = [] | |
| # n = min(len(src_list), max_samples) | |
| # | |
| # def cer(pred, ref): | |
| # if not ref: | |
| # return 1.0 | |
| # def ed(s1, s2): | |
| # m, n = len(s1), len(s2) | |
| # dp = list(range(n + 1)) | |
| # for i in range(1, m + 1): | |
| # prev, dp[0] = dp[0], i | |
| # for j in range(1, n + 1): | |
| # temp = dp[j] | |
| # dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1]) | |
| # prev = temp | |
| # return dp[n] | |
| # return ed(pred, ref) / max(len(ref), 1) | |
| # | |
| # print(f"Collecting quality data from {n} examples...") | |
| # for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])): | |
| # if i % 200 == 0: | |
| # print(f" {i}/{n}") | |
| # | |
| # if src.dim() == 1: | |
| # src = src.unsqueeze(0) | |
| # src = src.to(device) | |
| # | |
| # B = src.shape[0] | |
| # tgt_len = inner.max_seq_len | |
| # mask_id = inner.mask_token_id | |
| # | |
| # memory, src_pad_mask = inner.encode_source(src) | |
| # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device) | |
| # hint = None | |
| # h_cap = None | |
| # | |
| # for t_val in range(T - 1, -1, -1): | |
| # t = torch.full((B,), t_val, dtype=torch.long, device=device) | |
| # is_last = (t_val == 0) | |
| # | |
| # logits, _ = inner.forward_cached( | |
| # memory, src_pad_mask, x0_est, t, | |
| # x0_hint=hint, inference_mode=True, | |
| # ) | |
| # | |
| # if t_val == t_capture and hasattr(inner, '_last_hidden'): | |
| # h_cap = inner._last_hidden[0].mean(dim=0).detach().cpu() # [d_model] | |
| # | |
| # logits = logits / max(temperature, 1e-8) | |
| # if top_k > 0: | |
| # V = logits.shape[-1] | |
| # if top_k < V: | |
| # vals, _ = torch.topk(logits, top_k, dim=-1) | |
| # logits = logits.masked_fill(logits < vals[..., -1:], float('-inf')) | |
| # | |
| # probs = F.softmax(logits, dim=-1) | |
| # x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs) | |
| # hint = x0_est | |
| # | |
| # if h_cap is None: | |
| # continue | |
| # | |
| # ids = [x for x in x0_est[0].tolist() if x > 4] | |
| # pred = tgt_tokenizer.decode(ids).strip() | |
| # q = max(0.0, 1.0 - cer(pred, ref)) # quality = 1 - CER | |
| # | |
| # hidden_list.append(h_cap.numpy()) | |
| # quality_list.append(q) | |
| # | |
| # print(f"Collected {len(hidden_list)} quality examples.") | |
| # print(f"Quality stats: mean={np.mean(quality_list):.3f} " | |
| # f"min={np.min(quality_list):.3f} max={np.max(quality_list):.3f}") | |
| # | |
| # return np.stack(hidden_list), np.array(quality_list, dtype=np.float32) | |
| # | |
| # | |
| # def _sample(probs): | |
| # B, L, V = probs.shape | |
| # flat = probs.view(B * L, V).clamp(min=1e-9) | |
| # flat = flat / flat.sum(dim=-1, keepdim=True) | |
| # return torch.multinomial(flat, 1).squeeze(-1).view(B, L) | |
| # | |
| # | |
| # ── Training ────────────────────────────────────────────────────────── | |
| # | |
| # def train_quality_classifier( | |
| # hidden_matrix: np.ndarray, | |
| # quality_scores: np.ndarray, | |
| # d_model: int, | |
| # epochs: int = 30, | |
| # batch_size: int = 64, | |
| # lr: float = 1e-3, | |
| # val_frac: float = 0.1, | |
| # save_path: Optional[str] = None, | |
| # ) -> QualityClassifier: | |
| # """ | |
| # Train QualityClassifier on collected (hidden, quality) pairs. | |
| # | |
| # Args: | |
| # hidden_matrix : [N, d_model] from collect_quality_data() | |
| # quality_scores : [N] quality labels in [0, 1] | |
| # d_model : hidden dimension | |
| # epochs : training epochs | |
| # save_path : if given, save trained classifier weights here | |
| # | |
| # Returns: | |
| # trained QualityClassifier | |
| # """ | |
| # device = torch.device("cpu") # classifier is tiny, CPU is fine | |
| # | |
| # X = torch.tensor(hidden_matrix, dtype=torch.float32) | |
| # y = torch.tensor(quality_scores, dtype=torch.float32).unsqueeze(-1) | |
| # | |
| # N = len(X) | |
| # n_val = max(1, int(N * val_frac)) | |
| # idx = torch.randperm(N) | |
| # val_idx = idx[:n_val] | |
| # train_idx = idx[n_val:] | |
| # | |
| # X_train, y_train = X[train_idx], y[train_idx] | |
| # X_val, y_val = X[val_idx], y[val_idx] | |
| # | |
| # clf = QualityClassifier(d_model).to(device) | |
| # optimizer = torch.optim.Adam(clf.parameters(), lr=lr) | |
| # | |
| # print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params") | |
| # print(f"Train: {len(X_train)} Val: {len(X_val)}") | |
| # | |
| # best_val_loss = float('inf') | |
| # best_state = None | |
| # | |
| # for epoch in range(epochs): | |
| # clf.train() | |
| # perm = torch.randperm(len(X_train)) | |
| # train_loss = 0.0 | |
| # n_batches = 0 | |
| # | |
| # for start in range(0, len(X_train), batch_size): | |
| # batch_idx = perm[start:start + batch_size] | |
| # xb, yb = X_train[batch_idx], y_train[batch_idx] | |
| # pred = clf(xb) | |
| # loss = F.mse_loss(pred, yb) | |
| # optimizer.zero_grad() | |
| # loss.backward() | |
| # optimizer.step() | |
| # train_loss += loss.item() | |
| # n_batches += 1 | |
| # | |
| # clf.eval() | |
| # with torch.no_grad(): | |
| # val_pred = clf(X_val) | |
| # val_loss = F.mse_loss(val_pred, y_val).item() | |
| # | |
| # if epoch % 5 == 0 or epoch == epochs - 1: | |
| # print(f" Ep {epoch+1:3d} train={train_loss/n_batches:.4f} val={val_loss:.4f}") | |
| # | |
| # if val_loss < best_val_loss: | |
| # best_val_loss = val_loss | |
| # best_state = {k: v.clone() for k, v in clf.state_dict().items()} | |
| # | |
| # if best_state: | |
| # clf.load_state_dict(best_state) | |
| # print(f" Best val loss: {best_val_loss:.4f}") | |
| # | |
| # if save_path: | |
| # os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True) | |
| # torch.save(clf.state_dict(), save_path) | |
| # print(f" Classifier saved: {save_path}") | |
| # | |
| # return clf | |
| # | |
| # | |
| # ── Guided inference ────────────────────────────────────────────────── | |
| # | |
| # def generate_guided( | |
| # model, | |
| # src: torch.Tensor, | |
| # classifier: QualityClassifier, | |
| # guidance_scale: float = 1.0, | |
| # temperature: float = 0.8, | |
| # top_k: int = 40, | |
| # ) -> torch.Tensor: | |
| # """ | |
| # Classifier-guided generation. | |
| # | |
| # At each diffusion step: | |
| # 1. Run forward_cached() → logits, hidden states | |
| # 2. Compute classifier gradient: ∂(quality_score) / ∂(hidden) | |
| # 3. Project gradient back to logit space (approximate) | |
| # 4. guided_logits = logits + λ * gradient_signal | |
| # 5. Sample from guided_logits | |
| # | |
| # guidance_scale λ: | |
| # 0.0 → no guidance (standard generation) | |
| # 0.5 → weak guidance | |
| # 1.0 → moderate guidance (recommended starting point) | |
| # 2.0 → strong guidance (may reduce diversity) | |
| # 3.0 → very strong (may collapse to repetitive output) | |
| # | |
| # Args: | |
| # model : SanskritModel (frozen) | |
| # src : [1, src_len] IAST token ids | |
| # classifier : trained QualityClassifier | |
| # guidance_scale : λ — guidance strength | |
| # | |
| # Returns: | |
| # x0_est : [1, tgt_len] generated token ids | |
| # """ | |
| # inner = model.model | |
| # T = inner.scheduler.num_timesteps | |
| # device = next(inner.parameters()).device | |
| # clf_device = next(classifier.parameters()).device | |
| # | |
| # if src.dim() == 1: | |
| # src = src.unsqueeze(0) | |
| # src = src.to(device) | |
| # | |
| # B = src.shape[0] | |
| # tgt_len = inner.max_seq_len | |
| # mask_id = inner.mask_token_id | |
| # | |
| # memory, src_pad_mask = inner.encode_source(src) | |
| # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device) | |
| # hint = None | |
| # | |
| # inner.eval() | |
| # classifier.eval() | |
| # | |
| # for t_val in range(T - 1, -1, -1): | |
| # t = torch.full((B,), t_val, dtype=torch.long, device=device) | |
| # is_last = (t_val == 0) | |
| # | |
| # if guidance_scale > 0.0: | |
| # # Need gradients for classifier guidance | |
| # with torch.enable_grad(): | |
| # # Run forward_cached and get hidden states | |
| # PAD = 1 | |
| # if t_val > 0: | |
| # _, x_t_ids = inner.forward_process.q_sample(x0_est, t) | |
| # else: | |
| # x_t_ids = x0_est | |
| # | |
| # x = inner.tgt_embed(x_t_ids) | |
| # t_norm = t.float() / T | |
| # t_emb = inner.time_mlp(t_norm.unsqueeze(-1)) | |
| # x = x + t_emb.unsqueeze(1) | |
| # | |
| # if hint is not None: | |
| # hint_emb = inner.tgt_embed(hint) | |
| # gate = inner.hint_gate(x) | |
| # x = x + gate * hint_emb | |
| # | |
| # for block in inner.decoder_blocks: | |
| # x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask) | |
| # | |
| # # hidden: [B, tgt_len, d_model] β detach from graph for clf | |
| # hidden = x.detach().requires_grad_(True).to(clf_device) | |
| # | |
| # # Classifier quality score | |
| # quality = classifier(hidden) # [B, 1] | |
| # quality.sum().backward() | |
| # | |
| # # Gradient of quality w.r.t. hidden: [B, tgt_len, d_model] | |
| # grad = hidden.grad.to(device) # [B, tgt_len, d_model] | |
| # | |
| # # Project gradient to logit space via output head weight | |
| # # logit_grad β grad @ head.weight [B, tgt_len, tgt_vocab] | |
| # logit_grad = grad @ inner.head.weight.T | |
| # | |
| # # Compute standard logits (no gradient needed) | |
| # with torch.no_grad(): | |
| # logits = inner.head(x) | |
| # | |
| # # Apply guidance | |
| # logits = logits + guidance_scale * logit_grad | |
| # | |
| # else: | |
| # with torch.no_grad(): | |
| # logits, _ = inner.forward_cached( | |
| # memory, src_pad_mask, x0_est, t, | |
| # x0_hint=hint, inference_mode=True, | |
| # ) | |
| # | |
| # with torch.no_grad(): | |
| # logits = logits / max(temperature, 1e-8) | |
| # if top_k > 0: | |
| # V = logits.shape[-1] | |
| # if top_k < V: | |
| # vals, _ = torch.topk(logits, top_k, dim=-1) | |
| # logits = logits.masked_fill(logits < vals[..., -1:], float('-inf')) | |
| # | |
| # probs = F.softmax(logits, dim=-1) | |
| # x0_est = torch.argmax(probs, dim=-1) if is_last else _sample_no_grad(probs) | |
| # hint = x0_est | |
| # | |
| # return x0_est | |
| # | |
| # | |
| # def _sample_no_grad(probs): | |
| # B, L, V = probs.shape | |
| # flat = probs.view(B * L, V).clamp(min=1e-9) | |
| # flat = flat / flat.sum(dim=-1, keepdim=True) | |
| # return torch.multinomial(flat, 1).squeeze(-1).view(B, L) | |
| # | |
| # | |
| # ── Guidance scale sweep ────────────────────────────────────────────── | |
| # | |
| # def sweep_guidance_scales( | |
| # model, | |
| # classifier: QualityClassifier, | |
| # src_list: List[torch.Tensor], | |
| # ref_list: List[str], | |
| # tgt_tokenizer, | |
| # scales: List[float] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0], | |
| # n_samples: int = 50, | |
| # device: torch.device = None, | |
| # output_dir: str = "analysis/outputs", | |
| # ) -> Dict: | |
| # """ | |
| # Evaluate CER at each guidance scale. | |
| # Produces quality-diversity tradeoff plot. | |
| # """ | |
| # def cer(pred, ref): | |
| # if not ref: | |
| # return 1.0 | |
| # def ed(s1, s2): | |
| # m, n = len(s1), len(s2) | |
| # dp = list(range(n + 1)) | |
| # for i in range(1, m + 1): | |
| # prev, dp[0] = dp[0], i | |
| # for j in range(1, n + 1): | |
| # temp = dp[j] | |
| # dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1]) | |
| # prev = temp | |
| # return dp[n] | |
| # return ed(pred, ref) / max(len(ref), 1) | |
| # | |
| # device = device or next(model.parameters()).device | |
| # results = {} | |
| # n = min(n_samples, len(src_list)) | |
| # | |
| # print("\nGuidance scale sweep...") | |
| # for scale in scales: | |
| # cer_list = [] | |
| # output_set = [] | |
| # for src, ref in zip(src_list[:n], ref_list[:n]): | |
| # if src.dim() == 1: | |
| # src = src.unsqueeze(0) | |
| # out = generate_guided(model, src.to(device), classifier, | |
| # guidance_scale=scale) | |
| # ids = [x for x in out[0].tolist() if x > 4] | |
| # pred = tgt_tokenizer.decode(ids).strip() | |
| # cer_list.append(cer(pred, ref)) | |
| # output_set.append(pred) | |
| # | |
| # mean_cer = float(np.mean(cer_list)) | |
| # | |
| # # Self-diversity: unique outputs / total (proxy for diversity) | |
| # unique_frac = len(set(output_set)) / max(len(output_set), 1) | |
| # | |
| # results[scale] = {"mean_cer": mean_cer, "diversity": unique_frac} | |
| # print(f" Ξ»={scale:.1f} CER={mean_cer:.4f} diversity={unique_frac:.3f}") | |
| # | |
| # # Plot | |
| # os.makedirs(output_dir, exist_ok=True) | |
| # try: | |
| # import matplotlib.pyplot as plt | |
| # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4)) | |
| # | |
| # sc_list = sorted(results.keys()) | |
| # cers = [results[s]["mean_cer"] for s in sc_list] | |
| # diversities = [results[s]["diversity"] for s in sc_list] | |
| # | |
| # ax1.plot(sc_list, cers, 'o-', color='coral', linewidth=1.8, markersize=7) | |
| # ax1.set_xlabel("Guidance scale Ξ»", fontsize=10) | |
| # ax1.set_ylabel("CER (β better)", fontsize=10) | |
| # ax1.set_title("Quality vs guidance scale", fontsize=10) | |
| # | |
| # ax2.plot(sc_list, diversities, 'o-', color='steelblue', linewidth=1.8, markersize=7) | |
| # ax2.set_xlabel("Guidance scale Ξ»", fontsize=10) | |
| # ax2.set_ylabel("Output diversity (unique fraction)", fontsize=10) | |
| # ax2.set_title("Diversity vs guidance scale", fontsize=10) | |
| # | |
| # plt.suptitle("Quality-Diversity Tradeoff (Guidance Scale Sweep)", fontsize=11) | |
| # plt.tight_layout() | |
| # path = os.path.join(output_dir, "guidance_scale_sweep.png") | |
| # plt.savefig(path, dpi=150, bbox_inches='tight') | |
| # plt.close() | |
| # print(f" Saved: {path}") | |
| # except ImportError: | |
| # pass | |
| # | |
| # with open(os.path.join(output_dir, "guidance_results.json"), "w") as f: | |
| # json.dump({str(k): v for k, v in results.items()}, f, indent=2) | |
| # | |
| # return results | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from typing import List, Dict | |
| # ============================================================ | |
| # 1. QUALITY CLASSIFIER | |
| # ============================================================ | |
class QualityClassifier(nn.Module):
    """Small MLP that scores decoder hidden states with a value in [0, 1].

    Architecture: ``d_model -> 128 -> 64 -> 1`` with ReLU activations,
    dropout after the first activation, and a final sigmoid.

    Args:
        d_model: Size of the last dimension of the hidden states fed to
            :meth:`forward`.
        dropout: Dropout probability applied after the first ReLU.
    """

    def __init__(self, d_model: int, dropout: float = 0.1):
        super().__init__()
        # Kept as an attribute so callers can check compatibility with a
        # model's hidden size without inspecting the first layer.
        self.d_model = d_model
        # The layer order is fixed: the parameter keys (net.0, net.3, net.5)
        # must stay stable so previously saved state_dicts still load.
        self.net = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Return the quality score for ``hidden``.

        Args:
            hidden: Either a 3-D tensor, which is mean-pooled over dim 1
                before scoring, or a tensor that is already pooled and is
                passed to the MLP unchanged.

        Returns:
            Sigmoid output of the MLP; for 2-D or 3-D input with a leading
            batch dimension ``B`` this has shape ``[B, 1]``.
        """
        if hidden.dim() == 3:
            # Collapse the sequence axis so the MLP sees one vector per item.
            hidden = hidden.mean(dim=1)
        return self.net(hidden)
| # ============================================================ | |
| # 2. GUIDED GENERATION (CORRECTED) | |
| # ============================================================ | |
def generate_guided(
    model,
    src: torch.Tensor,
    classifier: "QualityClassifier",
    guidance_scale: float = 1.0,
    temperature: float = 0.8,
    top_k: int = 40,
) -> torch.Tensor:
    """Generate target token ids, optionally steered by a quality classifier.

    The loop walks the timesteps from ``num_timesteps - 1`` down to 0. At each
    step it obtains logits, applies temperature and top-k filtering, and then
    samples a new estimate (arg-max on the final step). The estimate is fed
    back as both the next input and the next hint.

    When ``guidance_scale > 0`` the decoder is unrolled by hand so the final
    hidden state is available, and the logits are shifted by::

        guidance_scale * clamp(normalize(d quality / d hidden) @ head.weight.T, -5, 5)

    Otherwise ``inner.forward_cached`` is used unchanged.

    Args:
        model: Wrapper exposing the diffusion model as ``model.model``. The
            inner model must provide ``scheduler.num_timesteps``,
            ``max_seq_len``, ``mask_token_id``, ``encode_source``,
            ``forward_cached``, ``forward_process.q_sample``, ``tgt_embed``,
            ``time_mlp``, ``hint_gate``, ``decoder_blocks`` and ``head``.
        src: Source token ids, 1-D or ``[B, src_len]``; a 1-D tensor gets a
            leading batch dimension.
        classifier: Module mapping hidden states to a quality score. It may
            live on a different device than the model.
        guidance_scale: Strength of the classifier shift; values ``<= 0``
            disable guidance.
        temperature: Logit temperature (floored at 1e-8).
        top_k: Keep only the ``top_k`` largest logits per position; ``<= 0``
            or ``>= vocab`` disables filtering.

    Returns:
        Long tensor ``[B, inner.max_seq_len]`` of generated token ids.

    Note:
        Both ``model.model`` and ``classifier`` are switched to eval mode and
        left that way.
    """
    inner = model.model
    num_steps = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device
    # The classifier may have been trained (and left) on another device.
    clf_device = next(classifier.parameters()).device

    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)

    batch = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    # Switch to eval *before* the first forward pass so dropout is already
    # disabled while the source is encoded.
    inner.eval()
    classifier.eval()

    # Encode the source once; the result is reused at every step. No graph is
    # needed because guidance only differentiates through the classifier.
    with torch.no_grad():
        memory, src_pad_mask = inner.encode_source(src)

    x0_est = torch.full((batch, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None
    use_guidance = guidance_scale > 0

    for t_val in range(num_steps - 1, -1, -1):
        t = torch.full((batch,), t_val, dtype=torch.long, device=device)
        is_last = t_val == 0

        if use_guidance:
            # --- Manual decoder pass (mirrors the unguided path by hand so the
            # hidden state is accessible). Run without autograd: the model is
            # frozen and only d quality / d hidden is required.
            with torch.no_grad():
                if t_val > 0:
                    _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
                else:
                    x_t_ids = x0_est

                x = inner.tgt_embed(x_t_ids)

                # Time embedding, broadcast over the sequence axis.
                t_norm = t.float() / num_steps
                x = x + inner.time_mlp(t_norm.unsqueeze(-1)).unsqueeze(1)

                # Gated conditioning on the previous estimate.
                if hint is not None:
                    x = x + inner.hint_gate(x) * inner.tgt_embed(hint)

                for block in inner.decoder_blocks:
                    x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)

                logits = inner.head(x)

            # --- Classifier gradient w.r.t. the hidden state.
            # `hidden` is a fresh leaf, so the gradient is well defined, and
            # torch.autograd.grad returns it directly without accumulating
            # into any parameter's .grad (neither model nor classifier).
            with torch.enable_grad():
                hidden = x.detach().to(clf_device).requires_grad_(True)
                quality = classifier(hidden)
                (grad,) = torch.autograd.grad(quality.sum(), hidden)

            with torch.no_grad():
                grad = grad.to(device)
                # Unit-normalize per position so the shift magnitude is set by
                # guidance_scale rather than by the classifier's raw gradient.
                grad = grad / (grad.norm(dim=-1, keepdim=True) + 1e-6)
                # Project into logit space through the output head's weight
                # (assumes inner.head.weight is [vocab, d_model] -- TODO confirm),
                # then clip to keep the shift bounded.
                logit_grad = torch.matmul(grad, inner.head.weight.T).clamp(-5.0, 5.0)
                # NOTE(review): a classifier that mean-pools over the sequence
                # (as QualityClassifier does) yields the same gradient at every
                # position, so every position receives the same logit shift.
                logits = logits + guidance_scale * logit_grad
        else:
            with torch.no_grad():
                logits, _ = inner.forward_cached(
                    memory, src_pad_mask, x0_est, t,
                    x0_hint=hint,
                    inference_mode=True,
                )

        # --- Sampling (never needs a graph).
        with torch.no_grad():
            logits = logits / max(temperature, 1e-8)

            if 0 < top_k < logits.shape[-1]:
                kth_vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < kth_vals[..., -1:], float('-inf'))

            probs = F.softmax(logits, dim=-1)
            # Final step is deterministic; earlier steps stay stochastic.
            x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
            hint = x0_est

    return x0_est
| def _sample(probs): | |
| B, L, V = probs.shape | |
| flat = probs.view(B * L, V).clamp(min=1e-9) | |
| flat = flat / flat.sum(dim=-1, keepdim=True) | |
| return torch.multinomial(flat, 1).squeeze(-1).view(B, L) | |
| # ============================================================ | |
| # 3. GUIDANCE SWEEP (EVALUATION) | |
| # ============================================================ | |
def sweep_guidance(
    model,
    classifier,
    src_list,
    ref_list,
    tgt_tokenizer,
    scales=(0.0, 0.5, 1.0, 1.5, 2.0, 3.0),
    n_samples=50,
    generate_fn=None,
):
    """Evaluate generation quality and diversity across guidance scales.

    For every scale, the first ``n_samples`` (source, reference) pairs are
    generated, decoded, and scored with character error rate (CER) against
    the reference. One summary line per scale is printed.

    Args:
        model: Passed through to the generator.
        classifier: Passed through to the generator.
        src_list: Sliceable sequence of source id tensors; 1-D tensors get a
            leading batch dimension before generation.
        ref_list: Sliceable sequence of reference strings, aligned with
            ``src_list`` (extra items on either side are ignored).
        tgt_tokenizer: Object with ``decode(list_of_ids) -> str``.
        scales: Iterable of guidance scales to evaluate.
        n_samples: Maximum number of pairs evaluated per scale.
        generate_fn: Optional callable ``(model, src, classifier, scale)``
            returning a ``[B, L]`` id tensor. Defaults to ``generate_guided``.

    Returns:
        Dict mapping each scale to ``{"CER": mean CER, "diversity": fraction
        of distinct decoded outputs}``. When no pair is evaluated, ``"CER"``
        is NaN and ``"diversity"`` is 0.0.
    """

    def cer(pred, ref):
        # Character error rate: Levenshtein distance / reference length.
        # An empty reference counts as a full error.
        if not ref:
            return 1.0
        row = list(range(len(ref) + 1))
        for i, pred_ch in enumerate(pred, start=1):
            # `diag` carries the previous row's value at column j-1.
            diag, row[0] = row[0], i
            for j, ref_ch in enumerate(ref, start=1):
                above = row[j]
                if pred_ch == ref_ch:
                    row[j] = diag
                else:
                    row[j] = 1 + min(diag, above, row[j - 1])
                diag = above
        return row[-1] / len(ref)

    # Resolved at call time so the default tracks the module-level function.
    generate = generate_guided if generate_fn is None else generate_fn
    pairs = list(zip(src_list[:n_samples], ref_list[:n_samples]))

    results = {}
    for scale in scales:
        cer_list = []
        outputs = []
        for src, ref in pairs:
            if src.dim() == 1:
                src = src.unsqueeze(0)
            out = generate(model, src, classifier, scale)
            # Ids 0-4 are skipped before decoding (presumably reserved
            # special tokens -- verify against the tokenizer).
            ids = [tok for tok in out[0].tolist() if tok > 4]
            pred = tgt_tokenizer.decode(ids).strip()
            cer_list.append(cer(pred, ref))
            outputs.append(pred)

        # Guard the empty case: no division by zero and no NumPy warning.
        if outputs:
            mean_cer = float(np.mean(cer_list))
            diversity = len(set(outputs)) / len(outputs)
        else:
            mean_cer = float("nan")
            diversity = 0.0

        results[scale] = {"CER": mean_cer, "diversity": diversity}
        print(f"λ={scale:.1f} | CER={mean_cer:.4f} | diversity={diversity:.3f}")

    return results