DevaFlow-space / analysis /quality_classifier.py
bhsinghgrid's picture
Upgrade UI: model selection + tasks 1-5 + analysis modules
29e5bf8 verified
# """
# analysis/quality_classifier.py
# ================================
# Task 5: Classifier Guidance for Paraphrase Quality Control
#
# Three steps — only Step 2 requires training a SMALL model (not the main D3PM):
#
# STEP 1 — Collect training data (no training):
# Run existing model on val set, record (hidden_state, CER) pairs.
# Hidden states come from model.model._last_hidden after forward_cached().
# CER score = quality label (lower CER = higher quality).
#
# STEP 2 — Train quality classifier:
# Small MLP: d_model → 128 → 64 → 1
# Input: pooled decoder hidden state [B, d_model]
# Output: predicted quality score in [0, 1] (1 = high quality)
# Loss: MSE against normalized CER labels
# Training time: ~5-10 minutes on CPU for 10k examples
#
# STEP 3 — Guided inference (no retraining):
# At each diffusion step, use classifier gradient to shift logits:
# guided_logits = logits + λ * ∂(quality_score)/∂(logits)
# Higher λ → model biased toward high-quality outputs
# λ=0 → standard generation (no guidance)
#
# Key: main D3PM model is FROZEN throughout. Only the 10k-param classifier trains.
# """
#
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# import numpy as np
# import os
# import json
# from typing import List, Dict, Optional, Tuple
#
#
# # ── Quality classifier architecture ──────────────────────────────────
#
# class QualityClassifier(nn.Module):
# """
# Lightweight MLP that predicts transliteration quality from decoder
# hidden states.
#
# Architecture:
# d_model → 128 → 64 → 1 → Sigmoid
#
# Input: mean-pooled decoder hidden state [B, d_model]
# Output: quality score [B, 1] ∈ [0, 1] (1 = high quality)
#
# ~10k parameters. Trains in minutes on CPU.
# """
# def __init__(self, d_model: int):
# super().__init__()
# self.net = nn.Sequential(
# nn.Linear(d_model, 128),
# nn.ReLU(),
# nn.Dropout(0.1),
# nn.Linear(128, 64),
# nn.ReLU(),
# nn.Linear(64, 1),
# nn.Sigmoid(),
# )
# self.d_model = d_model
#
# def forward(self, hidden: torch.Tensor) -> torch.Tensor:
# """
# Args:
# hidden : [B, tgt_len, d_model] OR [B, d_model] (already pooled)
#
# Returns:
# score : [B, 1] quality score in [0, 1]
# """
# if hidden.dim() == 3:
# # Pool over sequence length
# hidden = hidden.mean(dim=1) # [B, d_model]
# return self.net(hidden) # [B, 1]
#
#
# # ── Training data collection ──────────────────────────────────────────
#
# @torch.no_grad()
# def collect_quality_data(
# model,
# src_list: List[torch.Tensor],
# ref_list: List[str],
# tgt_tokenizer,
# t_capture: int = 0,
# temperature: float = 0.8,
# top_k: int = 40,
# max_samples: int = 5000,
# ) -> Tuple[np.ndarray, np.ndarray]:
# """
# Collect (hidden_state, quality_score) pairs for classifier training.
#
# For each sample:
# 1. Run generate_cached() on src
# 2. Capture decoder hidden state at t=t_capture
# 3. Compute CER between output and reference
# 4. Quality = 1 - CER (normalize to [0,1])
#
# Args:
# model : SanskritModel
# src_list : list of [1, src_len] tensors
# ref_list : list of reference Devanagari strings
# tgt_tokenizer : SanskritTargetTokenizer
# t_capture : which step to capture hidden states (0 = final)
# max_samples : cap number of training examples
#
# Returns:
# hidden_matrix : np.ndarray [N, d_model]
# quality_scores: np.ndarray [N] values in [0, 1]
# """
# inner = model.model
# T = inner.scheduler.num_timesteps
# device = next(inner.parameters()).device
#
# hidden_list = []
# quality_list = []
# n = min(len(src_list), max_samples)
#
# def cer(pred, ref):
# if not ref:
# return 1.0
# def ed(s1, s2):
# m, n = len(s1), len(s2)
# dp = list(range(n + 1))
# for i in range(1, m + 1):
# prev, dp[0] = dp[0], i
# for j in range(1, n + 1):
# temp = dp[j]
# dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
# prev = temp
# return dp[n]
# return ed(pred, ref) / max(len(ref), 1)
#
# print(f"Collecting quality data from {n} examples...")
# for i, (src, ref) in enumerate(zip(src_list[:n], ref_list[:n])):
# if i % 200 == 0:
# print(f" {i}/{n}")
#
# if src.dim() == 1:
# src = src.unsqueeze(0)
# src = src.to(device)
#
# B = src.shape[0]
# tgt_len = inner.max_seq_len
# mask_id = inner.mask_token_id
#
# memory, src_pad_mask = inner.encode_source(src)
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
# hint = None
# h_cap = None
#
# for t_val in range(T - 1, -1, -1):
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
# is_last = (t_val == 0)
#
# logits, _ = inner.forward_cached(
# memory, src_pad_mask, x0_est, t,
# x0_hint=hint, inference_mode=True,
# )
#
# if t_val == t_capture and hasattr(inner, '_last_hidden'):
# h_cap = inner._last_hidden[0].mean(dim=0).detach().cpu() # [d_model]
#
# logits = logits / max(temperature, 1e-8)
# if top_k > 0:
# V = logits.shape[-1]
# if top_k < V:
# vals, _ = torch.topk(logits, top_k, dim=-1)
# logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
#
# probs = F.softmax(logits, dim=-1)
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample(probs)
# hint = x0_est
#
# if h_cap is None:
# continue
#
# ids = [x for x in x0_est[0].tolist() if x > 4]
# pred = tgt_tokenizer.decode(ids).strip()
# q = max(0.0, 1.0 - cer(pred, ref)) # quality = 1 - CER
#
# hidden_list.append(h_cap.numpy())
# quality_list.append(q)
#
# print(f"Collected {len(hidden_list)} quality examples.")
# print(f"Quality stats: mean={np.mean(quality_list):.3f} "
# f"min={np.min(quality_list):.3f} max={np.max(quality_list):.3f}")
#
# return np.stack(hidden_list), np.array(quality_list, dtype=np.float32)
#
#
# def _sample(probs):
# B, L, V = probs.shape
# flat = probs.view(B * L, V).clamp(min=1e-9)
# flat = flat / flat.sum(dim=-1, keepdim=True)
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
#
#
# # ── Training ──────────────────────────────────────────────────────────
#
# def train_quality_classifier(
# hidden_matrix: np.ndarray,
# quality_scores: np.ndarray,
# d_model: int,
# epochs: int = 30,
# batch_size: int = 64,
# lr: float = 1e-3,
# val_frac: float = 0.1,
# save_path: Optional[str] = None,
# ) -> QualityClassifier:
# """
# Train QualityClassifier on collected (hidden, quality) pairs.
#
# Args:
# hidden_matrix : [N, d_model] from collect_quality_data()
# quality_scores : [N] quality labels in [0, 1]
# d_model : hidden dimension
# epochs : training epochs
# save_path : if given, save trained classifier weights here
#
# Returns:
# trained QualityClassifier
# """
# device = torch.device("cpu") # classifier is tiny, CPU is fine
#
# X = torch.tensor(hidden_matrix, dtype=torch.float32)
# y = torch.tensor(quality_scores, dtype=torch.float32).unsqueeze(-1)
#
# N = len(X)
# n_val = max(1, int(N * val_frac))
# idx = torch.randperm(N)
# val_idx = idx[:n_val]
# train_idx = idx[n_val:]
#
# X_train, y_train = X[train_idx], y[train_idx]
# X_val, y_val = X[val_idx], y[val_idx]
#
# clf = QualityClassifier(d_model).to(device)
# optimizer = torch.optim.Adam(clf.parameters(), lr=lr)
#
# print(f"\nTraining QualityClassifier: {sum(p.numel() for p in clf.parameters())} params")
# print(f"Train: {len(X_train)} Val: {len(X_val)}")
#
# best_val_loss = float('inf')
# best_state = None
#
# for epoch in range(epochs):
# clf.train()
# perm = torch.randperm(len(X_train))
# train_loss = 0.0
# n_batches = 0
#
# for start in range(0, len(X_train), batch_size):
# batch_idx = perm[start:start + batch_size]
# xb, yb = X_train[batch_idx], y_train[batch_idx]
# pred = clf(xb)
# loss = F.mse_loss(pred, yb)
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()
# train_loss += loss.item()
# n_batches += 1
#
# clf.eval()
# with torch.no_grad():
# val_pred = clf(X_val)
# val_loss = F.mse_loss(val_pred, y_val).item()
#
# if epoch % 5 == 0 or epoch == epochs - 1:
# print(f" Ep {epoch+1:3d} train={train_loss/n_batches:.4f} val={val_loss:.4f}")
#
# if val_loss < best_val_loss:
# best_val_loss = val_loss
# best_state = {k: v.clone() for k, v in clf.state_dict().items()}
#
# if best_state:
# clf.load_state_dict(best_state)
# print(f" Best val loss: {best_val_loss:.4f}")
#
# if save_path:
# os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
# torch.save(clf.state_dict(), save_path)
# print(f" Classifier saved: {save_path}")
#
# return clf
#
#
# # ── Guided inference ──────────────────────────────────────────────────
#
# def generate_guided(
# model,
# src: torch.Tensor,
# classifier: QualityClassifier,
# guidance_scale: float = 1.0,
# temperature: float = 0.8,
# top_k: int = 40,
# ) -> torch.Tensor:
# """
# Classifier-guided generation.
#
# At each diffusion step:
# 1. Run forward_cached() → logits, hidden states
# 2. Compute classifier gradient: ∂(quality_score) / ∂(hidden)
# 3. Project gradient back to logit space (approximate)
# 4. guided_logits = logits + λ * gradient_signal
# 5. Sample from guided_logits
#
# guidance_scale λ:
# 0.0 → no guidance (standard generation)
# 0.5 → weak guidance
# 1.0 → moderate guidance (recommended starting point)
# 2.0 → strong guidance (may reduce diversity)
# 3.0 → very strong (may collapse to repetitive output)
#
# Args:
# model : SanskritModel (frozen)
# src : [1, src_len] IAST token ids
# classifier : trained QualityClassifier
# guidance_scale : λ — guidance strength
#
# Returns:
# x0_est : [1, tgt_len] generated token ids
# """
# inner = model.model
# T = inner.scheduler.num_timesteps
# device = next(inner.parameters()).device
# clf_device = next(classifier.parameters()).device
#
# if src.dim() == 1:
# src = src.unsqueeze(0)
# src = src.to(device)
#
# B = src.shape[0]
# tgt_len = inner.max_seq_len
# mask_id = inner.mask_token_id
#
# memory, src_pad_mask = inner.encode_source(src)
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
# hint = None
#
# inner.eval()
# classifier.eval()
#
# for t_val in range(T - 1, -1, -1):
# t = torch.full((B,), t_val, dtype=torch.long, device=device)
# is_last = (t_val == 0)
#
# if guidance_scale > 0.0:
# # Need gradients for classifier guidance
# with torch.enable_grad():
# # Run forward_cached and get hidden states
# PAD = 1
# if t_val > 0:
# _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
# else:
# x_t_ids = x0_est
#
# x = inner.tgt_embed(x_t_ids)
# t_norm = t.float() / T
# t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
# x = x + t_emb.unsqueeze(1)
#
# if hint is not None:
# hint_emb = inner.tgt_embed(hint)
# gate = inner.hint_gate(x)
# x = x + gate * hint_emb
#
# for block in inner.decoder_blocks:
# x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)
#
# # hidden: [B, tgt_len, d_model] — detach from graph for clf
# hidden = x.detach().requires_grad_(True).to(clf_device)
#
# # Classifier quality score
# quality = classifier(hidden) # [B, 1]
# quality.sum().backward()
#
# # Gradient of quality w.r.t. hidden: [B, tgt_len, d_model]
# grad = hidden.grad.to(device) # [B, tgt_len, d_model]
#
# # Project gradient to logit space via output head weight
# # logit_grad ≈ grad @ head.weight [B, tgt_len, tgt_vocab]
# logit_grad = grad @ inner.head.weight.T
#
# # Compute standard logits (no gradient needed)
# with torch.no_grad():
# logits = inner.head(x)
#
# # Apply guidance
# logits = logits + guidance_scale * logit_grad
#
# else:
# with torch.no_grad():
# logits, _ = inner.forward_cached(
# memory, src_pad_mask, x0_est, t,
# x0_hint=hint, inference_mode=True,
# )
#
# with torch.no_grad():
# logits = logits / max(temperature, 1e-8)
# if top_k > 0:
# V = logits.shape[-1]
# if top_k < V:
# vals, _ = torch.topk(logits, top_k, dim=-1)
# logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
#
# probs = F.softmax(logits, dim=-1)
# x0_est = torch.argmax(probs, dim=-1) if is_last else _sample_no_grad(probs)
# hint = x0_est
#
# return x0_est
#
#
# def _sample_no_grad(probs):
# B, L, V = probs.shape
# flat = probs.view(B * L, V).clamp(min=1e-9)
# flat = flat / flat.sum(dim=-1, keepdim=True)
# return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
#
#
# # ── Guidance scale sweep ──────────────────────────────────────────────
#
# def sweep_guidance_scales(
# model,
# classifier: QualityClassifier,
# src_list: List[torch.Tensor],
# ref_list: List[str],
# tgt_tokenizer,
# scales: List[float] = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0],
# n_samples: int = 50,
# device: torch.device = None,
# output_dir: str = "analysis/outputs",
# ) -> Dict:
# """
# Evaluate CER at each guidance scale.
# Produces quality-diversity tradeoff plot.
# """
# def cer(pred, ref):
# if not ref:
# return 1.0
# def ed(s1, s2):
# m, n = len(s1), len(s2)
# dp = list(range(n + 1))
# for i in range(1, m + 1):
# prev, dp[0] = dp[0], i
# for j in range(1, n + 1):
# temp = dp[j]
# dp[j] = prev if s1[i-1] == s2[j-1] else 1 + min(prev, dp[j], dp[j-1])
# prev = temp
# return dp[n]
# return ed(pred, ref) / max(len(ref), 1)
#
# device = device or next(model.parameters()).device
# results = {}
# n = min(n_samples, len(src_list))
#
# print("\nGuidance scale sweep...")
# for scale in scales:
# cer_list = []
# output_set = []
# for src, ref in zip(src_list[:n], ref_list[:n]):
# if src.dim() == 1:
# src = src.unsqueeze(0)
# out = generate_guided(model, src.to(device), classifier,
# guidance_scale=scale)
# ids = [x for x in out[0].tolist() if x > 4]
# pred = tgt_tokenizer.decode(ids).strip()
# cer_list.append(cer(pred, ref))
# output_set.append(pred)
#
# mean_cer = float(np.mean(cer_list))
#
# # Self-diversity: unique outputs / total (proxy for diversity)
# unique_frac = len(set(output_set)) / max(len(output_set), 1)
#
# results[scale] = {"mean_cer": mean_cer, "diversity": unique_frac}
# print(f" λ={scale:.1f} CER={mean_cer:.4f} diversity={unique_frac:.3f}")
#
# # Plot
# os.makedirs(output_dir, exist_ok=True)
# try:
# import matplotlib.pyplot as plt
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
#
# sc_list = sorted(results.keys())
# cers = [results[s]["mean_cer"] for s in sc_list]
# diversities = [results[s]["diversity"] for s in sc_list]
#
# ax1.plot(sc_list, cers, 'o-', color='coral', linewidth=1.8, markersize=7)
# ax1.set_xlabel("Guidance scale λ", fontsize=10)
# ax1.set_ylabel("CER (↓ better)", fontsize=10)
# ax1.set_title("Quality vs guidance scale", fontsize=10)
#
# ax2.plot(sc_list, diversities, 'o-', color='steelblue', linewidth=1.8, markersize=7)
# ax2.set_xlabel("Guidance scale λ", fontsize=10)
# ax2.set_ylabel("Output diversity (unique fraction)", fontsize=10)
# ax2.set_title("Diversity vs guidance scale", fontsize=10)
#
# plt.suptitle("Quality-Diversity Tradeoff (Guidance Scale Sweep)", fontsize=11)
# plt.tight_layout()
# path = os.path.join(output_dir, "guidance_scale_sweep.png")
# plt.savefig(path, dpi=150, bbox_inches='tight')
# plt.close()
# print(f" Saved: {path}")
# except ImportError:
# pass
#
# with open(os.path.join(output_dir, "guidance_results.json"), "w") as f:
# json.dump({str(k): v for k, v in results.items()}, f, indent=2)
#
# return results
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Dict
# ============================================================
# 1. QUALITY CLASSIFIER
# ============================================================
class QualityClassifier(nn.Module):
    """Small MLP that scores a hidden representation with a value in (0, 1).

    Layer stack, in order: Linear(d_model, 128), ReLU, Dropout(0.1),
    Linear(128, 64), ReLU, Linear(64, 1), Sigmoid.
    """

    def __init__(self, d_model: int):
        """Build the scoring network for inputs whose last dimension is ``d_model``."""
        super().__init__()
        # Layers are created in the same order they are applied so that the
        # parameter names in the state dict stay ``net.0``, ``net.3``, ``net.5``.
        stack = [
            nn.Linear(d_model, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, hidden):
        """Score ``hidden``; a 3-D input is first averaged over dimension 1.

        Any input that is not 3-D is passed to the network unchanged.
        Returns the sigmoid output of the final ``Linear(64, 1)`` layer.
        """
        pooled = hidden.mean(dim=1) if hidden.dim() == 3 else hidden
        return self.net(pooled)
# ============================================================
# 2. GUIDED GENERATION (CORRECTED)
# ============================================================
@torch.no_grad()
def generate_guided(
    model,
    src: torch.Tensor,
    classifier: QualityClassifier,
    guidance_scale: float = 1.0,
    temperature: float = 0.8,
    top_k: int = 40,
):
    """Run the reverse diffusion loop, optionally steering logits with a classifier.

    For every timestep from ``num_timesteps - 1`` down to ``0`` the decoder
    produces logits.  When ``guidance_scale > 0`` the gradient of the
    classifier score with respect to the decoder's final hidden states is
    L2-normalised along the last dimension, projected through
    ``inner.head.weight``, clipped to ``[-5, 5]`` and added to the logits
    scaled by ``guidance_scale``.  Otherwise the logits come straight from
    ``inner.forward_cached``.  Logits are then temperature-scaled, top-k
    filtered and sampled (arg-max on the final step).

    Args:
        model: object exposing the diffusion network as ``model.model``.
        src: source token ids; a 1-D tensor is treated as a batch of one.
        classifier: scorer applied to the decoder hidden states.  It may
            live on a different device / dtype than the model.
        guidance_scale: guidance strength; values ``<= 0`` disable guidance.
        temperature: divisor applied to the logits (floored at ``1e-8``).
        top_k: keep only the ``top_k`` largest logits per position; ``<= 0``
            or ``>= vocab size`` disables the filter.

    Returns:
        Long tensor of shape ``(batch, inner.max_seq_len)`` with the final
        token-id estimate.

    Side effects:
        Puts both ``model.model`` and ``classifier`` into eval mode and does
        not restore their previous mode.  No ``.grad`` attribute of any
        parameter is touched.
    """
    inner = model.model
    num_steps = inner.scheduler.num_timesteps
    device = next(inner.parameters()).device
    clf_param = next(classifier.parameters())

    # Switch to eval mode *before* encoding the source so that the encoder
    # pass is not run with dropout (or other train-only behaviour) active.
    inner.eval()
    classifier.eval()

    if src.dim() == 1:
        src = src.unsqueeze(0)
    src = src.to(device)

    batch = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id

    # The source is encoded once and reused at every diffusion step.
    memory, src_pad_mask = inner.encode_source(src)
    x0_est = torch.full((batch, tgt_len), mask_id, dtype=torch.long, device=device)
    hint = None

    for t_val in range(num_steps - 1, -1, -1):
        t = torch.full((batch,), t_val, dtype=torch.long, device=device)
        is_last = t_val == 0

        if guidance_scale > 0:
            # Manual decoder pass so the final hidden states are available.
            # NOTE(review): presumably mirrors inner.forward_cached(); confirm
            # the two stay in sync.  It runs under the function-level
            # no_grad: only the classifier needs an autograd graph.
            if t_val > 0:
                _, x_t_ids = inner.forward_process.q_sample(x0_est, t)
            else:
                x_t_ids = x0_est

            x = inner.tgt_embed(x_t_ids)

            t_norm = t.float() / num_steps
            t_emb = inner.time_mlp(t_norm.unsqueeze(-1))
            x = x + t_emb.unsqueeze(1)

            if hint is not None:
                hint_emb = inner.tgt_embed(hint)
                gate = inner.hint_gate(x)
                x = x + gate * hint_emb

            for block in inner.decoder_blocks:
                x = block(x, memory, tgt_pad_mask=None, src_pad_mask=src_pad_mask)

            logits = inner.head(x)

            with torch.enable_grad():
                # A detached *leaf* copy is required: on a non-leaf tensor
                # ``.grad`` stays None after backward().  autograd.grad also
                # avoids accumulating gradients on any model parameter.
                hidden = (
                    x.detach()
                    .to(device=clf_param.device, dtype=clf_param.dtype)
                    .requires_grad_(True)
                )
                quality = classifier(hidden)
                (grad,) = torch.autograd.grad(quality.sum(), hidden)

            grad = grad.to(device=device, dtype=logits.dtype)

            # Unit-normalise per position so the guidance magnitude is set by
            # guidance_scale rather than by the raw gradient scale.
            grad = grad / (grad.norm(dim=-1, keepdim=True) + 1e-6)

            # Project the hidden-space direction into logit space through the
            # output head, then clip to keep the shift bounded.
            logit_grad = torch.matmul(grad, inner.head.weight.T)
            logit_grad = torch.clamp(logit_grad, -5.0, 5.0)

            logits = logits + guidance_scale * logit_grad
        else:
            logits, _ = inner.forward_cached(
                memory, src_pad_mask, x0_est, t,
                x0_hint=hint,
                inference_mode=True,
            )

        # ----- sampling -----
        logits = logits / max(temperature, 1e-8)

        if top_k > 0:
            vocab = logits.shape[-1]
            if top_k < vocab:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                # Everything strictly below the k-th largest logit is removed.
                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

        probs = F.softmax(logits, dim=-1)

        if is_last:
            x0_est = torch.argmax(probs, dim=-1)
        else:
            x0_est = _sample(probs)

        # The current estimate conditions the next step.
        hint = x0_est

    return x0_est
def _sample(probs):
    """Draw one token id per position from a ``(B, L, V)`` tensor of weights.

    Entries that are exactly zero (for example tokens removed by a top-k
    filter) are never sampled.  A row whose weights sum to zero falls back to
    a uniform draw over the vocabulary instead of raising.

    Args:
        probs: 3-D tensor of non-negative sampling weights; rows do not have
            to be normalised and the tensor does not have to be contiguous.

    Returns:
        Long tensor of shape ``(B, L)`` with the sampled indices.
    """
    B, L, V = probs.shape
    # reshape (not view) so non-contiguous inputs are accepted.
    flat = probs.reshape(B * L, V).clamp(min=0.0)
    # torch.multinomial rejects rows with zero total mass; sample those
    # uniformly rather than leaking mass into every row via a global floor.
    degenerate = flat.sum(dim=-1, keepdim=True) <= 0
    flat = torch.where(degenerate, torch.ones_like(flat), flat)
    return torch.multinomial(flat, 1).squeeze(-1).view(B, L)
# ============================================================
# 3. GUIDANCE SWEEP (EVALUATION)
# ============================================================
def sweep_guidance(
    model,
    classifier,
    src_list,
    ref_list,
    tgt_tokenizer,
    scales=(0.0, 0.5, 1.0, 1.5, 2.0, 3.0),
    n_samples=50,
):
    """Evaluate guided generation at several guidance scales.

    For each scale, the first ``n_samples`` (source, reference) pairs are
    generated with ``generate_guided``, decoded with ``tgt_tokenizer`` and
    compared against the reference with a character error rate.  One summary
    line per scale is printed.

    Args:
        model: forwarded unchanged to ``generate_guided``.
        classifier: forwarded unchanged to ``generate_guided``.
        src_list: sequence of source tensors; 1-D tensors get a batch dim.
        ref_list: reference strings, paired with ``src_list`` by position
            (pairing stops at the shorter of the two).
        tgt_tokenizer: object with ``decode(list_of_ids) -> str``.
        scales: iterable of guidance scales to evaluate.
        n_samples: maximum number of pairs evaluated per scale.

    Returns:
        Dict mapping each scale to ``{"CER": mean CER, "diversity": fraction
        of distinct decoded outputs}``.  When no pair is evaluated the entry
        is ``{"CER": nan, "diversity": 0.0}``.
    """
    results = {}
    for scale in scales:
        cer_list = []
        outputs = []
        for src, ref in zip(src_list[:n_samples], ref_list[:n_samples]):
            if src.dim() == 1:
                src = src.unsqueeze(0)
            out = generate_guided(model, src, classifier, scale)
            # Ids <= 4 are dropped before decoding.
            # NOTE(review): presumably special tokens — confirm with the tokenizer.
            ids = [tok for tok in out[0].tolist() if tok > 4]
            pred = tgt_tokenizer.decode(ids).strip()
            cer_list.append(_char_error_rate(pred, ref))
            outputs.append(pred)

        if outputs:
            mean_cer = float(np.mean(cer_list))
            diversity = len(set(outputs)) / len(outputs)
        else:
            # Nothing was evaluated: avoid dividing by zero and avoid the
            # NumPy warning for the mean of an empty list.
            mean_cer = float("nan")
            diversity = 0.0

        results[scale] = {"CER": mean_cer, "diversity": diversity}
        print(f"λ={scale:.1f} | CER={mean_cer:.4f} | diversity={diversity:.3f}")
    return results


def _char_error_rate(pred, ref):
    """Return Levenshtein distance between ``pred`` and ``ref`` divided by ``len(ref)``.

    An empty reference yields ``1.0`` regardless of the prediction.
    """
    if not ref:
        return 1.0
    # Single rolling row of the edit-distance table; ``diag`` carries the
    # value from the previous row, one column to the left.
    row = list(range(len(ref) + 1))
    for i, p_ch in enumerate(pred, start=1):
        diag, row[0] = row[0], i
        for j, r_ch in enumerate(ref, start=1):
            above = row[j]
            if p_ch == r_ch:
                row[j] = diag
            else:
                row[j] = 1 + min(diag, above, row[j - 1])
            diag = above
    return row[-1] / len(ref)