| import argparse |
| import json |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim import AdamW |
| try: |
| from dataset import create_dataloader |
| from tokenizer import MAX_SEQ_LENGTH |
| except ImportError: |
| MAX_SEQ_LENGTH = 64 |
| create_dataloader = None |
| import math |
| import copy |
|
|
| |
| |
| |
|
|
class TimestepEmbedding(nn.Module):
    """
    Map a scalar timestep t (expected in [0, 1]) to a conditioning vector of
    width 4 * d_model.

    t is first encoded as d_model sinusoidal features (cosines followed by
    sines over geometrically spaced frequencies, as in DiT / DDPM). That
    encoding then goes through a 2-layer GELU MLP
    (d_model -> 4*d_model -> 4*d_model).

    Conditioning on t lets the denoiser tell heavily masked inputs apart from
    almost fully revealed ones.

    NOTE(review): with t in [0, 1] and max_period=10000, t * freq stays in
    [0, 1] for every frequency, so most sinusoid channels are nearly constant
    across t. DiT-style code usually feeds t scaled to roughly [0, 1000].
    Changing the scale would alter the behaviour of already-trained
    checkpoints, so it is left as is.
    """

    def __init__(self, d_model, max_period=10000):
        super().__init__()
        self.d_model = d_model
        self.max_period = max_period

        # Sinusoidal features -> conditioning vector of width 4 * d_model.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model * 4),
            nn.GELU(),
        )

    def forward(self, t):
        """
        Args:
            t: [batch_size] tensor of timestep values (cast to float32).
        Returns:
            [batch_size, d_model * 4] conditioning vector.
        """
        half = self.d_model // 2
        # Frequencies from 1 down to 1 / max_period, spaced geometrically.
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(half, device=t.device, dtype=torch.float32)
            / half
        )
        args = t.unsqueeze(-1).float() * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

        if self.d_model % 2:
            # cos/sin only produce 2 * (d_model // 2) features; pad one zero
            # column so the width matches the MLP input when d_model is odd.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[..., :1])], dim=-1
            )

        return self.mlp(embedding)
|
|
|
|
| |
| |
| |
|
|
class AdaLNTransformerLayer(nn.Module):
    """
    Pre-norm transformer encoder layer whose LayerNorms are modulated by a
    timestep embedding (adaptive LayerNorm, as used in DiT).

    Each of the two sub-blocks (self-attention, then feed-forward) computes
        h = (1 + scale_t) * LayerNorm(x) + shift_t
        x = x + SubBlock(h)
    The four vectors (scale/shift for each sub-block) come from a small
    GELU -> Linear head applied to the timestep embedding. That head's output
    layer is zero-initialised in __init__, so a freshly constructed layer
    applies no modulation.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.2):
        super().__init__()

        # Multi-head self-attention over [B, seq_len, d_model] inputs.
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True,
        )

        # Position-wise MLP: d_model -> dim_feedforward -> d_model.
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout),
        )

        # Norms in front of the attention and feed-forward sub-blocks.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Dropout applied to the attention residual branch.
        self.dropout = nn.Dropout(dropout)

        # Timestep embedding (width 4 * d_model) -> 4 chunks of d_model:
        # scale/shift for norm1, then scale/shift for norm2.
        self.adaLN_modulation = nn.Sequential(
            nn.GELU(),
            nn.Linear(d_model * 4, d_model * 4),
        )
        projection = self.adaLN_modulation[1]
        nn.init.zeros_(projection.weight)
        nn.init.zeros_(projection.bias)

    def forward(self, x, t_emb, src_key_padding_mask=None):
        """
        Args:
            x: [B, seq_len, d_model] hidden states.
            t_emb: [B, d_model * 4] timestep conditioning vector.
            src_key_padding_mask: optional [B, seq_len] mask, True at key
                positions attention must ignore.

        Returns:
            [B, seq_len, d_model] updated hidden states.
        """
        # Split the modulation and broadcast each chunk over the sequence axis.
        attn_scale, attn_shift, ff_scale, ff_shift = (
            part.unsqueeze(1)
            for part in self.adaLN_modulation(t_emb).chunk(4, dim=-1)
        )

        h = self.norm1(x) * (1 + attn_scale) + attn_shift
        attended = self.self_attn(
            h, h, h, key_padding_mask=src_key_padding_mask
        )[0]
        x = x + self.dropout(attended)

        h = self.norm2(x) * (1 + ff_scale) + ff_shift
        return x + self.ff(h)
|
|
|
|
| |
| |
| |
|
|
class PositionalEncoding(nn.Module):
    """
    Fixed sinusoidal positional encoding (Vaswani et al., 2017), added to the input.

        pe[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        pe[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    The table is registered as the buffer ``pe`` with shape [1, max_len, d_model].
    Inputs longer than ``max_len`` positions are not supported.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd (cos) column than even (sin)
        # column, so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Add the encodings of the first x.size(1) positions to x ([B, seq_len, d_model])."""
        return x + self.pe[:, :x.size(1), :]
|
|
|
|
| |
| |
| |
|
|
class MaskedDiffusionModel(nn.Module):
    """
    Timestep-conditioned transformer denoiser for masked diffusion over tokens.

    forward(x, t) does the following:
      1. embeds the (partially masked) token IDs and adds sinusoidal positions;
      2. runs a stack of AdaLNTransformerLayer blocks conditioned on the
         masking rate t;
      3. returns per-position vocabulary logits.

    Design notes:
      * timestep conditioning through adaptive LayerNorm in every layer;
      * dropout (default 0.2) on embeddings and inside every layer;
      * a final LayerNorm before the output projection.

    Defaults: d_model=192, nhead=6, num_layers=4.
    """

    def __init__(self, vocab_size, d_model=192, nhead=6, num_layers=4,
                 max_seq_len=MAX_SEQ_LENGTH, dropout=0.2):
        super().__init__()
        self.d_model = d_model

        # NOTE(review): padding_idx is fixed at 0 here, while the training code
        # reads the pad id from vocab["[PAD]"]; confirm the two agree.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)
        self.embed_dropout = nn.Dropout(dropout)

        # t -> [B, 4 * d_model] conditioning vector shared by all layers.
        self.time_embed = TimestepEmbedding(d_model)

        self.layers = nn.ModuleList([
            AdaLNTransformerLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])

        self.final_norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)

        self._init_weights()

    def _init_weights(self):
        """
        Apply Xavier-uniform init to every weight matrix, then restore the
        deliberate initialisations that this blanket pass would overwrite:

        * each layer's adaLN modulation output Linear is re-zeroed, so every
          block starts without timestep modulation (the sub-module zero-inits
          it on purpose);
        * the embedding row at ``padding_idx`` is re-zeroed, matching what
          ``nn.Embedding(padding_idx=...)`` does at construction.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        for layer in self.layers:
            modulation_out = layer.adaLN_modulation[-1]
            nn.init.zeros_(modulation_out.weight)
            nn.init.zeros_(modulation_out.bias)

        if self.embedding.padding_idx is not None:
            with torch.no_grad():
                self.embedding.weight[self.embedding.padding_idx].zero_()

    def forward(self, x, t, src_key_padding_mask=None):
        """
        Args:
            x: [B, seq_len] token IDs (masked positions hold the mask ID).
            t: [B] masking rates, expected in [0, 1]; conditions every layer.
            src_key_padding_mask: optional [B, seq_len] mask, True at padding.

        Returns:
            [B, seq_len, vocab_size] logits.
        """
        h = self.embedding(x)
        h = self.pos_encoder(h)
        h = self.embed_dropout(h)

        t_emb = self.time_embed(t)

        for layer in self.layers:
            h = layer(h, t_emb, src_key_padding_mask=src_key_padding_mask)

        h = self.final_norm(h)
        return self.fc_out(h)
|
|
|
|
| |
| |
| |
|
|
class EMA:
    """
    Exponential moving average (EMA) of a model's state_dict.

    On every ``update``, floating-point entries are blended in place as
        shadow = decay * shadow + (1 - decay) * current
    Non-floating entries (e.g. integer counters such as BatchNorm's
    ``num_batches_tracked``) are copied verbatim. Averaging them would
    silently turn them into floats.

    ``apply`` loads the shadow weights into a model (e.g. for validation or
    inference), and ``state_dict`` exposes them for saving.
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        # Deep copy so the shadow tensors never share storage with live parameters.
        self.shadow = copy.deepcopy(model.state_dict())

    def update(self, model):
        """Blend ``model``'s current state into the shadow copy in place."""
        with torch.no_grad():
            current = model.state_dict()
            for key, shadow_value in self.shadow.items():
                value = current[key]
                if shadow_value.is_floating_point():
                    shadow_value.mul_(self.decay).add_(value, alpha=1.0 - self.decay)
                else:
                    shadow_value.copy_(value)

    def apply(self, model):
        """Load the EMA weights into ``model`` (copies; no storage is shared)."""
        model.load_state_dict(self.shadow)

    def state_dict(self):
        """Return the shadow state dict (the live object, not a copy)."""
        return self.shadow
|
|
|
|
def parse_args():
    """Build the training command-line interface and return the parsed namespace."""
    cli = argparse.ArgumentParser(description='Train the masked diffusion dialogue model.')
    # (flag, type, default, help); tuple order is the order shown in --help.
    option_table = (
        ('--tokenized-file', str, 'tokenized_data.json', 'Tokenized JSON dataset file'),
        ('--batch-size', int, 32, 'Training batch size'),
        ('--epochs', int, 150, 'Number of training epochs'),
        ('--base-lr', float, 1e-4, 'Base learning rate'),
        ('--warmup-epochs', int, 5, 'Number of warmup epochs'),
        ('--label-smoothing', float, 0.1, 'Label smoothing value'),
        ('--dropout', float, 0.2, 'Dropout probability'),
        ('--weight-decay', float, 0.01, 'AdamW weight decay'),
        ('--patience', int, 20, 'Early stopping patience'),
        ('--val-split', float, 0.1, 'Validation split fraction'),
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()
|
|
|
|
| |
| |
| |
|
|
def find_bot_start(sequence, bot_token_ids):
    """
    Return the index just after the first 'bot' marker token in ``sequence``.

    ``bot_token_ids`` is tried in order. The first ID that occurs anywhere in
    the sequence wins, and its earliest position + 1 is returned. Tokens from
    that index onward are the bot reply that may be masked; the user prompt
    before it stays visible. If none of the IDs occur, the midpoint
    ``len(sequence) // 2`` is returned as a fallback.

    ``sequence`` must support ``.tolist()`` (e.g. a 1-D tensor).
    """
    tokens = sequence.tolist()
    for marker in bot_token_ids:
        if marker in tokens:
            return tokens.index(marker) + 1
    return len(tokens) // 2
|
|
|
|
def apply_guided_masking(x_0, mask_id, pad_id, special_ids, bot_token_ids):
    """
    Randomly mask the bot-reply part of each sequence, with a per-sample rate.

    For each row b of ``x_0``:
      * a masking rate t_b is drawn uniformly and floored at 0.05;
      * every position at or after ``find_bot_start(row, bot_token_ids)``
        whose token is neither padding nor in ``special_ids`` is replaced by
        ``mask_id`` independently with probability t_b.

    Masking is vectorised over the whole batch. Random numbers come from
    torch's default (CPU) generator, matching the previous per-token loop's
    source of randomness but not its exact draw order.

    Args:
        x_0: [batch_size, seq_len] clean token IDs.
        mask_id: ID written into masked positions.
        pad_id: padding ID; never masked.
        special_ids: collection of IDs that are never masked.
        bot_token_ids: marker IDs passed to find_bot_start.

    Returns:
        x_t: masked sequences (same shape/dtype as ``x_0``).
        is_mask: bool tensor, True where a position was masked.
        t_values: [batch_size] float tensor of masking rates (for timestep
            conditioning), on ``x_0``'s device.
    """
    batch_size, seq_len = x_0.shape
    device = x_0.device

    # One masking rate per sample, floored at 0.05.
    t_values = torch.rand(batch_size).clamp_(min=0.05).to(device)

    # Reply start per row; one host transfer instead of a sync per token.
    rows_cpu = x_0.cpu()
    bot_starts = torch.tensor(
        [find_bot_start(row, bot_token_ids) for row in rows_cpu],
        dtype=torch.long,
        device=device,
    )
    positions = torch.arange(seq_len, device=device)
    eligible = positions.unsqueeze(0) >= bot_starts.unsqueeze(1)

    # Padding and special tokens are never masked.
    for protected_id in set(special_ids) | {pad_id}:
        eligible &= x_0 != protected_id

    draws = torch.rand(batch_size, seq_len).to(device)
    is_mask = eligible & (draws < t_values.unsqueeze(1))

    x_t = x_0.masked_fill(is_mask, mask_id)
    return x_t, is_mask, t_values
|
|
|
|
def get_bot_token_ids(vocab):
    """
    Return the vocabulary IDs of the spellings a subword vocabulary may use for 'bot'.

    Candidates are listed in priority order, which matters because
    find_bot_start tries the IDs in this order:
      * 'bot'
      * 'bot:'
      * 'Ġbot' (byte-level BPE space prefix, U+0120)
      * ' bot' (literal space)
      * '▁bot' (SentencePiece / Metaspace prefix, U+2581)

    Spellings absent from ``vocab`` are skipped, so the result may be empty.
    """
    # Prefix characters are written as escapes so they survive any
    # re-encoding of this source file.
    candidates = ['bot', 'bot:', '\u0120bot', ' bot', '\u2581bot']
    return [vocab[token] for token in candidates if token in vocab]
|
|
|
|
def count_parameters(model):
    """Return the total number of elements across the model's trainable parameters."""
    trainable_total = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        trainable_total += param.numel()
    return trainable_total
|
|
|
|
| |
| |
| |
|
|
def validate(model, val_loader, vocab_size, mask_id, pad_id, special_ids,
             bot_token_ids, device, label_smoothing=0.1):
    """
    Compute the mean masked cross-entropy of ``model`` over ``val_loader``.

    How batches are scored:
      * each batch is corrupted with apply_guided_masking, so masking rates
        and masks are freshly random on every call;
      * the loss is averaged over that batch's masked positions only;
      * batches in which nothing was masked are skipped.

    The return value is the unweighted mean of these per-batch losses, or
    0.0 when no batch contributed.

    The model runs in eval mode under ``torch.no_grad()``. Afterwards it is
    restored to whichever mode (train/eval) it was in on entry.
    """
    was_training = model.training
    model.eval()

    criterion = nn.CrossEntropyLoss(
        reduction='none', ignore_index=pad_id,
        label_smoothing=label_smoothing
    )

    total_loss = 0.0
    valid_batches = 0
    try:
        with torch.no_grad():
            for x_0 in val_loader:
                x_0 = x_0.to(device)
                pad_mask = (x_0 == pad_id)

                x_t, is_mask, t_values = apply_guided_masking(
                    x_0, mask_id, pad_id, special_ids, bot_token_ids
                )
                num_masked = is_mask.sum()
                if num_masked == 0:
                    continue

                logits = model(x_t, t_values, src_key_padding_mask=pad_mask)
                loss_per_token = criterion(
                    logits.view(-1, vocab_size),
                    x_0.view(-1),
                ).view_as(x_0)

                # Only masked positions are scored; visible tokens are given.
                masked_loss = (loss_per_token * is_mask.float()).sum() / (num_masked + 1e-8)

                total_loss += masked_loss.item()
                valid_batches += 1
    finally:
        # Restore the caller's train/eval mode instead of forcing train mode.
        model.train(was_training)

    return total_loss / max(valid_batches, 1)
|
|
|
|
| |
| |
| |
|
|
def train_model(
    tokenized_file='tokenized_data.json',
    batch_size=32,
    epochs=150,
    base_lr=1e-4,
    warmup_epochs=5,
    label_smoothing=0.1,
    dropout=0.2,
    weight_decay=0.01,
    max_patience=20,
    val_split=0.1,
):
    """
    Train a MaskedDiffusionModel on a tokenized dialogue dataset.

    Setup:
      * the vocabulary comes from ``subword_tokenizer.json``
        (``["model"]["vocab"]``);
      * train/validation loaders come from ``create_dataloader``;
      * the objective is label-smoothed cross-entropy on the positions masked
        by ``apply_guided_masking``.

    LR schedule: linear per-step warmup to ``base_lr`` over ``warmup_epochs``
    epochs, then per-epoch cosine annealing down to 1e-5.

    Validation and checkpoints:
      * after the first epoch and every 5th epoch, the EMA weights are
        evaluated on the validation set;
      * on each improvement, the EMA weights are saved to
        ``diffusion_model_best.pth``;
      * training stops early once ``max_patience`` epochs have passed since
        the last improvement (checked at validation epochs);
      * the final EMA weights are always saved to ``diffusion_model.pth``.

    Raises:
        ImportError: if the ``dataset`` module failed to import, so
            ``create_dataloader`` is unavailable.
    """
    if create_dataloader is None:
        raise ImportError(
            "create_dataloader is unavailable: the 'dataset' module could not be imported"
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    # Vocabulary and special token IDs.
    with open("subword_tokenizer.json", "r", encoding="utf-8") as f:
        vocab_data = json.load(f)
    vocab = vocab_data["model"]["vocab"]

    vocab_size = len(vocab)
    mask_id = vocab["[MASK]"]
    pad_id = vocab["[PAD]"]
    special_ids = {vocab["[PAD]"], vocab["[BOS]"], vocab["[EOS]"], vocab["[UNK]"], vocab["[MASK]"]}

    bot_token_ids = get_bot_token_ids(vocab)
    print(f"'bot' token IDs: {bot_token_ids}")
    if not bot_token_ids:
        print("WARNING: 'bot' not found in vocab -- falling back to half-masking.")

    # Model, EMA shadow and data.
    model = MaskedDiffusionModel(
        vocab_size=vocab_size,
        d_model=256, nhead=8, num_layers=6,
        max_seq_len=MAX_SEQ_LENGTH,
        dropout=dropout,
    ).to(device)

    total_params = count_parameters(model)
    print(f"Parameters: {total_params:,}")

    ema_decay = 0.999
    ema = EMA(model, decay=ema_decay)

    # The third value returned by create_dataloader is not needed here.
    train_loader, val_loader, _ = create_dataloader(
        tokenized_file, batch_size=batch_size, val_split=val_split
    )

    # Optimiser and LR schedule: the linear warmup is applied per step by
    # hand below; the cosine scheduler is stepped once per epoch after warmup.
    warmup_steps = len(train_loader) * warmup_epochs
    global_step = 0

    optimizer = AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=1e-5
    )

    criterion = nn.CrossEntropyLoss(
        reduction='none', ignore_index=pad_id,
        label_smoothing=label_smoothing
    )

    best_val_loss = float('inf')
    best_train_loss = float('inf')
    # Epoch index of the most recent validation improvement. Patience is
    # measured in epochs from here; validation only runs every 5 epochs, so
    # counting validation rounds would stretch patience roughly 5x.
    best_val_epoch = 0

    print(f"\nTraining for {epochs} epochs (early stop after {max_patience} epochs no improvement)")
    print(f"Label smoothing: {label_smoothing} | Weight decay: {weight_decay} | Dropout: {dropout}")
    print(f"EMA decay: {ema_decay}")
    print(f"{'='*70}\n")

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        total_masked = 0
        valid_batches = 0

        for x_0 in train_loader:
            x_0 = x_0.to(device)

            if global_step < warmup_steps:
                lr = base_lr * (global_step + 1) / warmup_steps
                for g in optimizer.param_groups:
                    g['lr'] = lr

            pad_mask = (x_0 == pad_id)
            optimizer.zero_grad()

            x_t, is_mask, t_values = apply_guided_masking(
                x_0, mask_id, pad_id, special_ids, bot_token_ids
            )
            num_masked = is_mask.sum()
            if num_masked == 0:
                continue

            logits = model(x_t, t_values, src_key_padding_mask=pad_mask)
            loss_per_token = criterion(
                logits.view(-1, vocab_size),
                x_0.view(-1),
            ).view_as(x_0)

            # Average only over masked positions; visible tokens are given.
            masked_loss = (loss_per_token * is_mask.float()).sum() / (num_masked + 1e-8)

            masked_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            ema.update(model)

            total_loss += masked_loss.item()
            total_masked += num_masked.item()
            valid_batches += 1
            global_step += 1

        if epoch >= warmup_epochs:
            scheduler.step()

        avg_train_loss = total_loss / max(valid_batches, 1)

        if (epoch + 1) % 5 == 0 or epoch == 0:
            # Evaluate with the EMA weights, then restore the live training weights.
            original_state = copy.deepcopy(model.state_dict())
            ema.apply(model)

            val_loss = validate(
                model, val_loader, vocab_size,
                mask_id, pad_id, special_ids, bot_token_ids,
                device, label_smoothing
            )

            model.load_state_dict(original_state)

            lr_now = optimizer.param_groups[0]['lr']
            print(
                f"Epoch {epoch+1:>4}/{epochs} | "
                f"Train Loss: {avg_train_loss:.4f} | "
                f"Val Loss: {val_loss:.4f} | "
                f"Masked: {int(total_masked):>8} | "
                f"LR: {lr_now:.6f}"
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_val_epoch = epoch
                torch.save(ema.state_dict(), "diffusion_model_best.pth")
                print(" * New best val loss! Saved EMA checkpoint.")
            elif epoch - best_val_epoch >= max_patience:
                print(f"\nEarly stopping at epoch {epoch+1} (no improvement for {max_patience} epochs)")
                break

        if 0 < avg_train_loss < best_train_loss:
            best_train_loss = avg_train_loss

    torch.save(ema.state_dict(), "diffusion_model.pth")
    print(f"\n{'='*70}")
    print(f"Done! Best train loss: {best_train_loss:.4f} | Best val loss: {best_val_loss:.4f}")
    print("Use diffusion_model_best.pth for inference (EMA weights).")
| print("Use diffusion_model_best.pth for inference (EMA weights).") |
|
|
|
|
if __name__ == "__main__":
    cli_args = parse_args()
    # CLI flags map 1:1 onto train_model's keywords, except --patience -> max_patience.
    train_model(
        tokenized_file=cli_args.tokenized_file,
        val_split=cli_args.val_split,
        batch_size=cli_args.batch_size,
        epochs=cli_args.epochs,
        warmup_epochs=cli_args.warmup_epochs,
        base_lr=cli_args.base_lr,
        weight_decay=cli_args.weight_decay,
        dropout=cli_args.dropout,
        label_smoothing=cli_args.label_smoothing,
        max_patience=cli_args.patience,
    )