import argparse
import copy
import json
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW

try:
    from dataset import create_dataloader
    from tokenizer import MAX_SEQ_LENGTH
except ImportError:
    # Lets the model / masking code be imported without the project's data
    # pipeline; train_model() refuses to run when create_dataloader is missing.
    MAX_SEQ_LENGTH = 64
    create_dataloader = None


# ─────────────────────────────────────────────────────────────
# Timestep Embedding — tells the model what diffusion stage it's at
# ─────────────────────────────────────────────────────────────
class TimestepEmbedding(nn.Module):
    """
    Map a scalar timestep t in [0, 1] to a conditioning vector.

    Sinusoidal features of t are passed through a 2-layer MLP (DiT / DDPM
    style). Without this signal the model cannot distinguish a heavily masked
    input from an almost fully revealed one, which iterative refinement needs.

    NOTE: t is used unscaled, so t * freq stays in [0, 1]. With
    max_period=10000 the low frequencies make the high-index cos/sin features
    nearly constant. This is kept as-is so existing checkpoints stay valid.
    """

    def __init__(self, d_model, max_period=10000):
        super().__init__()
        self.d_model = d_model
        self.max_period = max_period

        # sinusoidal features (d_model) -> hidden (4*d_model) -> output (4*d_model).
        # The output feeds every layer's AdaLN modulation head.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model * 4),
            nn.GELU(),
        )

    def forward(self, t):
        """
        Args:
            t: [batch_size] tensor of timestep values in [0, 1]
        Returns:
            [batch_size, d_model * 4] conditioning vector
        """
        half = self.d_model // 2
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(half, device=t.device, dtype=torch.float32)
            / half
        )
        # t: [B] -> [B, 1] * [1, half] -> [B, half]
        args = t.unsqueeze(-1).float() * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.d_model % 2:
            # cos/sin yield 2 * (d_model // 2) features; pad one zero column so
            # an odd d_model still matches the first Linear layer.
            embedding = F.pad(embedding, (0, 1))
        return self.mlp(embedding)  # [B, d_model * 4]


# ─────────────────────────────────────────────────────────────
# AdaLN Transformer Layer — timestep-conditioned normalization
# ─────────────────────────────────────────────────────────────
class AdaLNTransformerLayer(nn.Module):
    """
    Pre-norm transformer encoder layer with Adaptive Layer Normalization.

    Each sub-layer input is normalized as
        AdaLN(x, t) = (1 + scale_t) * LayerNorm(x) + shift_t
    where scale_t / shift_t are predicted from the timestep embedding
    (the DiT conditioning scheme).
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.2):
        super().__init__()

        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=nhead, dropout=dropout, batch_first=True
        )

        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout),
        )

        # Layer norms whose output is modulated by the timestep.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout on the attention residual branch.
        self.dropout = nn.Dropout(dropout)

        # timestep_emb -> (scale1, shift1, scale2, shift2), each d_model wide.
        self.adaLN_modulation = nn.Sequential(
            nn.GELU(),
            nn.Linear(d_model * 4, d_model * 4),
        )

        # Zero-init the modulation output so the layer starts as a plain
        # pre-norm transformer: (1 + 0) * LayerNorm(x) + 0 = LayerNorm(x).
        # NOTE: a parent module that re-initializes all weights must redo this.
        nn.init.zeros_(self.adaLN_modulation[-1].weight)
        nn.init.zeros_(self.adaLN_modulation[-1].bias)

    def forward(self, x, t_emb, src_key_padding_mask=None):
        """
        Args:
            x: [B, seq_len, d_model]
            t_emb: [B, d_model * 4] from TimestepEmbedding
            src_key_padding_mask: [B, seq_len] True for padding positions
        """
        mod = self.adaLN_modulation(t_emb)  # [B, d_model * 4]
        # Each chunk is [B, d_model] -> [B, 1, d_model] to broadcast over positions.
        scale1, shift1, scale2, shift2 = (m.unsqueeze(1) for m in mod.chunk(4, dim=-1))

        # Pre-norm self-attention with AdaLN.
        normed = (1 + scale1) * self.norm1(x) + shift1
        attn_out, _ = self.self_attn(
            normed, normed, normed,
            key_padding_mask=src_key_padding_mask,
        )
        x = x + self.dropout(attn_out)

        # Pre-norm feed-forward with AdaLN (dropout lives inside self.ff).
        normed = (1 + scale2) * self.norm2(x) + shift2
        return x + self.ff(normed)


# ─────────────────────────────────────────────────────────────
# Positional Encoding
# ─────────────────────────────────────────────────────────────
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a [B, seq_len, d_model] input."""

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Raise ValueError if x is longer than the ``max_len`` the buffer was built for."""
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds positional encoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len, :]


# ─────────────────────────────────────────────────────────────
# Main Model — Timestep-Conditioned Masked Diffusion
# ─────────────────────────────────────────────────────────────
class MaskedDiffusionModel(nn.Module):
    """
    Timestep-conditioned masked-diffusion denoiser.

    Token embedding + sinusoidal positions -> stack of AdaLN transformer
    layers conditioned on the masking rate t -> final LayerNorm -> vocab logits.
    """

    def __init__(self, vocab_size, d_model=192, nhead=6, num_layers=4,
                 max_seq_len=MAX_SEQ_LENGTH, dropout=0.2):
        super().__init__()
        self.d_model = d_model

        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)
        self.embed_dropout = nn.Dropout(dropout)

        self.time_embed = TimestepEmbedding(d_model)

        self.layers = nn.ModuleList([
            AdaLNTransformerLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

        self.final_norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for matrices, preserving deliberate zero inits."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        # The loop above also hit each layer's AdaLN output projection, undoing
        # the zero-init that makes every layer start as an identity-modulated
        # transformer. Restore it.
        for layer in self.layers:
            nn.init.zeros_(layer.adaLN_modulation[-1].weight)
            nn.init.zeros_(layer.adaLN_modulation[-1].bias)

        # It also overwrote the padding row. nn.Embedding keeps that row at
        # zero and never sends it gradient, so re-zero it here.
        if self.embedding.padding_idx is not None:
            with torch.no_grad():
                self.embedding.weight[self.embedding.padding_idx].zero_()

    def forward(self, x, t, src_key_padding_mask=None):
        """
        Args:
            x: [B, seq_len] token IDs
            t: [B] masking rates in [0, 1]
            src_key_padding_mask: [B, seq_len] True for padding
        Returns:
            [B, seq_len, vocab_size] logits
        """
        h = self.embed_dropout(self.pos_encoder(self.embedding(x)))
        t_emb = self.time_embed(t)  # [B, d_model * 4]
        for layer in self.layers:
            h = layer(h, t_emb, src_key_padding_mask=src_key_padding_mask)
        return self.fc_out(self.final_norm(h))


# ─────────────────────────────────────────────────────────────
# EMA (Exponential Moving Average) for stable inference
# ─────────────────────────────────────────────────────────────
class EMA:
    """
    Exponential moving average of a model's state_dict.

    After every update:
        shadow = decay * shadow + (1 - decay) * current
    Non-floating-point entries (if any) are copied instead of averaged.
    """

    def __init__(self, model, decay=0.999):
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"decay must be in [0, 1], got {decay}")
        self.decay = decay
        self.shadow = copy.deepcopy(model.state_dict())

    def update(self, model):
        """Fold the model's current weights into the running average (in place)."""
        with torch.no_grad():
            model_state = model.state_dict()
            for key, shadow_value in self.shadow.items():
                current = model_state[key]
                if shadow_value.is_floating_point():
                    # lerp_: shadow + (1 - decay) * (current - shadow)
                    shadow_value.lerp_(current, 1.0 - self.decay)
                else:
                    shadow_value.copy_(current)

    def apply(self, model):
        """Load EMA weights into model (for validation/inference)."""
        model.load_state_dict(self.shadow)

    def state_dict(self):
        return self.shadow


def parse_args():
    parser = argparse.ArgumentParser(description='Train the masked diffusion dialogue model.')
    parser.add_argument('--tokenized-file', type=str, default='tokenized_data.json',
                        help='Tokenized JSON dataset file')
    parser.add_argument('--tokenizer-file', type=str, default='subword_tokenizer.json',
                        help='Tokenizer JSON file containing model.vocab')
    parser.add_argument('--batch-size', type=int, default=32, help='Training batch size')
    parser.add_argument('--epochs', type=int, default=150, help='Number of training epochs')
    parser.add_argument('--base-lr', type=float, default=1e-4, help='Base learning rate')
    parser.add_argument('--warmup-epochs', type=int, default=5, help='Number of warmup epochs')
    parser.add_argument('--label-smoothing', type=float, default=0.1, help='Label smoothing value')
    parser.add_argument('--dropout', type=float, default=0.2, help='Dropout probability')
    parser.add_argument('--weight-decay', type=float, default=0.01, help='AdamW weight decay')
    parser.add_argument('--patience', type=int, default=20,
                        help='Early stopping patience (epochs without val improvement)')
    parser.add_argument('--val-split', type=float, default=0.1, help='Validation split fraction')
    parser.add_argument('--val-every', type=int, default=5, help='Validate every N epochs')
    parser.add_argument('--ema-decay', type=float, default=0.999, help='EMA decay')
    parser.add_argument('--val-seed', type=int, default=0,
                        help='Seed for validation masks (negative: random masks each time)')
    return parser.parse_args()


# ─────────────────────────────────────────────────────────────
# Guided Masking — returns per-sample masking rates
# ─────────────────────────────────────────────────────────────
def find_bot_start(sequence, bot_token_ids):
    """
    Return the index just after the first 'bot' marker token in ``sequence``.

    Candidates are tried in ``bot_token_ids`` order: the first candidate that
    occurs anywhere wins. If no candidate occurs, return len(sequence) // 2.
    Only positions from the returned index onward are eligible for masking,
    so the model learns user prompt -> bot reply.
    """
    seq = sequence.tolist()
    for bid in bot_token_ids:
        if bid in seq:
            return seq.index(bid) + 1
    return len(seq) // 2


def find_bot_starts(x_0, bot_token_ids):
    """
    Batched ``find_bot_start``: one start index per row of ``x_0``.

    Args:
        x_0: [B, seq_len] token IDs
        bot_token_ids: candidate marker IDs, in priority order
    Returns:
        [B] long tensor on x_0's device
    """
    batch_size, seq_len = x_0.shape
    device = x_0.device
    starts = torch.full((batch_size,), seq_len // 2, dtype=torch.long, device=device)
    if batch_size == 0 or seq_len == 0:
        return starts

    positions = torch.arange(seq_len, device=device).unsqueeze(0)
    no_hit = torch.full((batch_size, seq_len), seq_len, dtype=torch.long, device=device)
    found = torch.zeros(batch_size, dtype=torch.bool, device=device)
    for bid in bot_token_ids:
        hits = x_0 == bid
        has_hit = hits.any(dim=1)
        first_hit = torch.where(hits, positions, no_hit).min(dim=1).values
        # Earlier candidates keep priority, as in find_bot_start.
        starts = torch.where(has_hit & ~found, first_hit + 1, starts)
        found |= has_hit
    return starts


def apply_guided_masking(x_0, mask_id, pad_id, special_ids, bot_token_ids,
                         min_rate=0.05, generator=None):
    """
    Mask each sample's reply tokens at a per-sample random rate.

    For each row a rate t ~ U[0, 1) is drawn and clamped to ``min_rate``.
    Every position at or after the row's bot start (see ``find_bot_start``)
    whose token is not special and not padding is replaced by ``mask_id``
    with probability t.

    Args:
        x_0: [B, seq_len] clean token IDs
        mask_id, pad_id: IDs of the mask and padding tokens
        special_ids: token IDs that must never be masked
        bot_token_ids: candidate 'bot' marker IDs
        min_rate: lower bound on the per-sample masking rate
        generator: optional CPU ``torch.Generator`` for reproducible masks
    Returns:
        x_t: masked sequences
        is_mask: boolean mask of which positions were masked
        t_values: [batch_size] float tensor of masking rates (for conditioning)
    """
    batch_size, seq_len = x_0.shape
    device = x_0.device

    # Random draws happen on CPU (so a CPU generator works for any device).
    t_values = torch.rand(batch_size, generator=generator).clamp_(min=min_rate).to(device)

    bot_starts = find_bot_starts(x_0, bot_token_ids)
    positions = torch.arange(seq_len, device=device).unsqueeze(0)
    eligible = positions >= bot_starts.unsqueeze(1)

    excluded = torch.tensor(sorted(set(special_ids) | {pad_id}), dtype=x_0.dtype, device=device)
    eligible &= ~torch.isin(x_0, excluded)

    draws = torch.rand(batch_size, seq_len, generator=generator).to(device)
    is_mask = eligible & (draws < t_values.unsqueeze(1))
    x_t = x_0.masked_fill(is_mask, mask_id)
    return x_t, is_mask, t_values


def get_bot_token_ids(vocab):
    """
    Return the IDs of the vocabulary entries that may spell 'bot'.

    BPE may represent it as 'bot', 'bot:', or with a space prefix.
    """
    candidates = ['bot', 'bot:', 'Ġbot', ' bot']
    return [vocab[t] for t in candidates if t in vocab]


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


# ─────────────────────────────────────────────────────────────
# Validation Loop
# ─────────────────────────────────────────────────────────────
def validate(model, val_loader, vocab_size, mask_id, pad_id, special_ids, bot_token_ids,
             device, label_smoothing=0.1, seed=None):
    """
    Return the mean (label-smoothed) cross-entropy over masked positions.

    The mean is weighted by token: the total loss over all masked tokens divided
    by their count. When ``seed`` is given, masks come from a freshly seeded
    generator, so repeated calls see identical masks and losses are comparable
    across epochs. The model's train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    generator = torch.Generator().manual_seed(seed) if seed is not None else None

    criterion = nn.CrossEntropyLoss(
        reduction='none', ignore_index=pad_id, label_smoothing=label_smoothing
    )
    loss_sum = 0.0
    masked_count = 0

    with torch.no_grad():
        for x_0 in val_loader:
            x_0 = x_0.to(device)
            pad_mask = (x_0 == pad_id)

            x_t, is_mask, t_values = apply_guided_masking(
                x_0, mask_id, pad_id, special_ids, bot_token_ids, generator=generator
            )
            num_masked = int(is_mask.sum())
            if num_masked == 0:
                continue

            logits = model(x_t, t_values, src_key_padding_mask=pad_mask)
            loss_per_token = criterion(
                logits.reshape(-1, vocab_size), x_0.reshape(-1)
            ).view_as(x_0)
            loss_sum += (loss_per_token * is_mask.float()).sum().item()
            masked_count += num_masked

    model.train(was_training)
    return loss_sum / max(masked_count, 1)


# ─────────────────────────────────────────────────────────────
# Training Loop
# ─────────────────────────────────────────────────────────────
def train_model(
    tokenized_file='tokenized_data.json',
    batch_size=32,
    epochs=150,
    base_lr=1e-4,
    warmup_epochs=5,
    label_smoothing=0.1,
    dropout=0.2,
    weight_decay=0.01,
    max_patience=20,
    val_split=0.1,
    tokenizer_file='subword_tokenizer.json',
    val_every=5,
    ema_decay=0.999,
    val_seed=0,
):
    """
    Train MaskedDiffusionModel and write EMA checkpoints.

    Writes ``diffusion_model_best.pth`` (EMA weights at the best validation
    loss) and ``diffusion_model.pth`` (EMA weights at the end of training).

    Args:
        max_patience: stop once this many epochs have passed since the best
            validation loss (checked on validation epochs).
        tokenizer_file: JSON file whose ``["model"]["vocab"]`` maps token -> ID.
        val_every: validate every N epochs (also on the first and last epoch).
        ema_decay: decay of the weight EMA used for validation and checkpoints.
        val_seed: seed for validation masks; None re-draws random masks.
    Raises:
        RuntimeError: if ``create_dataloader`` could not be imported.
        ValueError: if ``val_every`` < 1.
    """
    if create_dataloader is None:
        raise RuntimeError(
            "create_dataloader could not be imported from 'dataset'; "
            "training data cannot be loaded."
        )
    if val_every < 1:
        raise ValueError(f"val_every must be >= 1, got {val_every}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    with open(tokenizer_file, "r", encoding="utf-8") as f:
        vocab_data = json.load(f)
    vocab = vocab_data["model"]["vocab"]
    vocab_size = len(vocab)
    mask_id = vocab["[MASK]"]
    pad_id = vocab["[PAD]"]
    special_ids = {vocab["[PAD]"], vocab["[BOS]"], vocab["[EOS]"], vocab["[UNK]"], vocab["[MASK]"]}

    bot_token_ids = get_bot_token_ids(vocab)
    print(f"'bot' token IDs: {bot_token_ids}")
    if not bot_token_ids:
        print("WARNING: 'bot' not found in vocab — falling back to half-masking.")

    # ── Model ────────────────────────────────────────────────
    model = MaskedDiffusionModel(
        vocab_size=vocab_size,
        d_model=256,
        nhead=8,
        num_layers=6,
        max_seq_len=MAX_SEQ_LENGTH,
        dropout=dropout,
    ).to(device)
    print(f"Parameters: {count_parameters(model):,}")

    # ── EMA ──────────────────────────────────────────────────
    ema = EMA(model, decay=ema_decay)

    # ── Data ─────────────────────────────────────────────────
    train_loader, val_loader, _ = create_dataloader(
        tokenized_file, batch_size=batch_size, val_split=val_split
    )

    # ── Optimizer + schedule: per-step linear warmup, then per-epoch cosine ──
    warmup_steps = len(train_loader) * warmup_epochs
    global_step = 0
    optimizer = AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=1e-5
    )

    criterion = nn.CrossEntropyLoss(
        reduction='none', ignore_index=pad_id, label_smoothing=label_smoothing
    )

    best_val_loss = float('inf')
    best_train_loss = float('inf')
    best_epoch = -1  # epoch index of the best validation loss so far

    print(f"\nTraining for {epochs} epochs (early stop after {max_patience} epochs no improvement)")
    print(f"Label smoothing: {label_smoothing} | Weight decay: {weight_decay} | Dropout: {dropout}")
    print(f"EMA decay: {ema_decay}")
    print(f"{'='*70}\n")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_masked = 0
        valid_batches = 0

        for x_0 in train_loader:
            x_0 = x_0.to(device)

            # Linear warmup over the first warmup_epochs worth of steps; it ends
            # exactly at base_lr, where the cosine schedule takes over.
            if global_step < warmup_steps:
                lr = base_lr * (global_step + 1) / warmup_steps
                for g in optimizer.param_groups:
                    g['lr'] = lr

            pad_mask = (x_0 == pad_id)
            optimizer.zero_grad()

            x_t, is_mask, t_values = apply_guided_masking(
                x_0, mask_id, pad_id, special_ids, bot_token_ids
            )
            num_masked = int(is_mask.sum())
            if num_masked == 0:
                continue  # nothing to learn from this batch

            logits = model(x_t, t_values, src_key_padding_mask=pad_mask)
            loss_per_token = criterion(
                logits.view(-1, vocab_size), x_0.view(-1)
            ).view_as(x_0)
            # Loss only on masked positions.
            masked_loss = (loss_per_token * is_mask.float()).sum() / num_masked

            masked_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            ema.update(model)

            total_loss += masked_loss.item()
            total_masked += num_masked
            valid_batches += 1
            global_step += 1

        if epoch >= warmup_epochs:
            scheduler.step()

        avg_train_loss = total_loss / max(valid_batches, 1)
        if 0 < avg_train_loss < best_train_loss:
            best_train_loss = avg_train_loss

        is_val_epoch = epoch == 0 or (epoch + 1) % val_every == 0 or epoch == epochs - 1
        if not is_val_epoch:
            continue

        # Validate with EMA weights, then restore the training weights.
        training_state = copy.deepcopy(model.state_dict())
        ema.apply(model)
        val_loss = validate(
            model, val_loader, vocab_size, mask_id, pad_id,
            special_ids, bot_token_ids, device, label_smoothing, seed=val_seed,
        )
        model.load_state_dict(training_state)

        lr_now = optimizer.param_groups[0]['lr']
        print(
            f"Epoch {epoch+1:>4}/{epochs} | "
            f"Train Loss: {avg_train_loss:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Masked: {int(total_masked):>8} | "
            f"LR: {lr_now:.6f}"
        )

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            torch.save(ema.state_dict(), "diffusion_model_best.pth")
            print(f" → New best val loss! Saved EMA checkpoint.")
        elif epoch - best_epoch >= max_patience:
            # Patience is measured in epochs (not validations), matching the
            # messages printed to the user.
            print(f"\nEarly stopping at epoch {epoch+1} (no improvement for {max_patience} epochs)")
            break

    torch.save(ema.state_dict(), "diffusion_model.pth")
    print(f"\n{'='*70}")
    print(f"Done! Best train loss: {best_train_loss:.4f} | Best val loss: {best_val_loss:.4f}")
    print("Use diffusion_model_best.pth for inference (EMA weights).")


if __name__ == "__main__":
    args = parse_args()
    train_model(
        tokenized_file=args.tokenized_file,
        batch_size=args.batch_size,
        epochs=args.epochs,
        base_lr=args.base_lr,
        warmup_epochs=args.warmup_epochs,
        label_smoothing=args.label_smoothing,
        dropout=args.dropout,
        weight_decay=args.weight_decay,
        max_patience=args.patience,
        val_split=args.val_split,
        tokenizer_file=args.tokenizer_file,
        val_every=args.val_every,
        ema_decay=args.ema_decay,
        val_seed=args.val_seed if args.val_seed >= 0 else None,
    )