ritikraj2425's picture
pushed 17m param model
fb3221e
Raw
History Blame Contribute Delete
23.2 kB
import argparse
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
try:
from dataset import create_dataloader
from tokenizer import MAX_SEQ_LENGTH
except ImportError:
MAX_SEQ_LENGTH = 64
create_dataloader = None
import math
import copy
# ─────────────────────────────────────────────────────────────
# Timestep Embedding — tells the model what diffusion stage it's at
# ─────────────────────────────────────────────────────────────
class TimestepEmbedding(nn.Module):
    """
    Map a scalar timestep t (expected in [0, 1]) to a conditioning vector.

    The timestep is first encoded as ``d_model`` sinusoidal features
    (cosines followed by sines) and then passed through a 2-layer MLP with
    GELU activations, following the DiT / DDPM recipe. The output is
    consumed by the AdaLN modulation of every transformer layer, so the
    network can tell lightly masked inputs from heavily masked ones.

    Args:
        d_model: Width of the sinusoidal feature vector. The MLP output is
            ``4 * d_model`` wide. Odd widths are supported: the feature
            vector is zero-padded by one column.
        max_period: Sets the lowest sinusoid frequency (``1 / max_period``
            for large ``d_model``).
    """

    def __init__(self, d_model, max_period=10000):
        super().__init__()
        self.d_model = d_model
        self.max_period = max_period
        # Sinusoidal features -> hidden -> output. The MLP learns to turn the
        # raw features into something the per-layer AdaLN projections can use.
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Linear(d_model * 4, d_model * 4),
            nn.GELU(),
        )

    def forward(self, t):
        """
        Args:
            t: [batch_size] tensor of timestep values, expected in [0, 1].

        Returns:
            [batch_size, d_model * 4] conditioning vector.
        """
        half = self.d_model // 2
        # freqs[k] = max_period ** (-k / half), all in (0, 1].
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(half, device=t.device, dtype=torch.float32)
            / half
        )
        # t: [B] -> [B, 1] * [1, half] -> [B, half]
        # NOTE(review): since freqs <= 1 and t is in [0, 1], every argument
        # lies in [0, 1] rad, so high-index features stay close to
        # (cos, sin) = (1, 0). DiT-style models usually feed t on a
        # [0, 1000] scale. Left unchanged to stay compatible with existing
        # checkpoints.
        args = t.unsqueeze(-1).float() * freqs.unsqueeze(0)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if self.d_model % 2:
            # cos/sin only produce 2 * (d_model // 2) columns. Pad one zero
            # column so the width matches the first Linear layer.
            embedding = F.pad(embedding, (0, 1))
        return self.mlp(embedding)  # [B, d_model * 4]
# ─────────────────────────────────────────────────────────────
# AdaLN Transformer Layer — timestep-conditioned normalization
# ─────────────────────────────────────────────────────────────
class AdaLNTransformerLayer(nn.Module):
    """
    Pre-norm transformer encoder layer conditioned on a timestep embedding.

    Each sub-layer (self-attention, then feed-forward) normalises its input
    with LayerNorm and modulates it with a per-sample scale and shift that
    are predicted from the timestep embedding:

        AdaLN(x, t) = (1 + scale_t) * LayerNorm(x) + shift_t

    The projection predicting scale/shift is zero-initialised here, so the
    layer starts out as a plain pre-norm transformer layer (the DiT
    initialisation trick). Any parameter-wide re-initialisation done by an
    owning module must preserve this zero init for that to hold.

    Args:
        d_model: Model width.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward sub-layer.
        dropout: Dropout used in attention, feed-forward and the attention
            residual branch.
    """

    def __init__(self, d_model, nhead, dim_feedforward, dropout=0.2):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=nhead,
            dropout=dropout,
            batch_first=True,
        )
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim_feedforward, d_model),
            nn.Dropout(dropout),
        )
        # LayerNorms whose output is modulated by the timestep.
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Dropout applied to the attention output before the residual add.
        self.dropout = nn.Dropout(dropout)
        # t_emb (4 * d_model) -> scale1 | shift1 | scale2 | shift2
        self.adaLN_modulation = nn.Sequential(
            nn.GELU(),
            nn.Linear(d_model * 4, d_model * 4),
        )
        # A zero projection gives (1 + 0) * LayerNorm(x) + 0 at init.
        final_proj = self.adaLN_modulation[-1]
        nn.init.zeros_(final_proj.weight)
        nn.init.zeros_(final_proj.bias)

    @staticmethod
    def _modulate(normed, scale, shift):
        """Return (1 + scale) * normed + shift, broadcasting [B, D] over the sequence axis."""
        return (1 + scale.unsqueeze(1)) * normed + shift.unsqueeze(1)

    def forward(self, x, t_emb, src_key_padding_mask=None):
        """
        Args:
            x: [B, seq_len, d_model] hidden states.
            t_emb: [B, d_model * 4] timestep conditioning vector.
            src_key_padding_mask: optional [B, seq_len] bool mask, True at
                padding positions (passed as ``key_padding_mask``).

        Returns:
            [B, seq_len, d_model] updated hidden states.
        """
        scale_attn, shift_attn, scale_ff, shift_ff = (
            self.adaLN_modulation(t_emb).chunk(4, dim=-1)
        )

        # Attention sub-layer.
        h = self._modulate(self.norm1(x), scale_attn, shift_attn)
        attended = self.self_attn(h, h, h, key_padding_mask=src_key_padding_mask)[0]
        x = x + self.dropout(attended)

        # Feed-forward sub-layer.
        h = self._modulate(self.norm2(x), scale_ff, shift_ff)
        return x + self.ff(h)
# ─────────────────────────────────────────────────────────────
# Positional Encoding — fixed sinusoidal table (not learned)
# ─────────────────────────────────────────────────────────────
class PositionalEncoding(nn.Module):
    """
    Add fixed sinusoidal position encodings to a batch of embeddings.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / d_model)``. The table is
    registered as the ``pe`` buffer of shape [1, max_len, d_model], so it
    follows the module across devices and is stored in its state_dict.

    Args:
        d_model: Embedding width (odd widths are supported).
        max_len: Longest sequence the table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer cos column than sin column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """
        Args:
            x: [B, seq_len, d_model] embeddings.

        Returns:
            ``x`` plus the encodings for positions ``0 .. seq_len - 1``.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``. Without this check
                the slice would either fail with a broadcast error or, when
                ``max_len == 1``, silently broadcast position 0 everywhere.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {max_len}"
            )
        return x + self.pe[:, :seq_len, :]
# ─────────────────────────────────────────────────────────────
# Main Model — Timestep-Conditioned Masked Diffusion
# ─────────────────────────────────────────────────────────────
class MaskedDiffusionModel(nn.Module):
    """
    Timestep-conditioned transformer that predicts tokens for a partially
    masked sequence.

    Inputs are token ids plus a per-sample masking rate ``t``. ``t`` is
    embedded by ``TimestepEmbedding`` and modulates every
    ``AdaLNTransformerLayer``. Pipeline: token embedding + sinusoidal
    positions -> dropout -> ``num_layers`` AdaLN layers -> final LayerNorm ->
    linear projection to vocabulary logits.

    Args:
        vocab_size: Number of token ids (embedding rows and output classes).
        d_model: Model width. The feed-forward width is ``4 * d_model``.
        nhead: Attention heads per layer (must divide ``d_model``).
        num_layers: Number of AdaLN transformer layers.
        max_seq_len: Longest supported input sequence.
        dropout: Dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model=192, nhead=6, num_layers=4,
                 max_seq_len=MAX_SEQ_LENGTH, dropout=0.2):
        super().__init__()
        self.d_model = d_model
        # Token embedding + fixed sinusoidal positions.
        # NOTE(review): padding_idx is hard-coded to 0, while the training
        # script reads the pad id from the tokenizer vocab. Confirm it is 0.
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoder = PositionalEncoding(d_model, max_len=max_seq_len)
        self.embed_dropout = nn.Dropout(dropout)
        # Timestep conditioning, shared by all layers.
        self.time_embed = TimestepEmbedding(d_model)
        self.layers = nn.ModuleList([
            AdaLNTransformerLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_model * 4,
                dropout=dropout,
            )
            for _ in range(num_layers)
        ])
        # Final norm and projection to vocabulary logits.
        self.final_norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size)
        self._init_weights()

    def _init_weights(self):
        """
        Apply Xavier-uniform init to every weight matrix, then restore the
        zero inits that this blanket pass would otherwise overwrite.

        Restored zero inits:
          * the final AdaLN modulation projection of each layer, so the model
            starts with no timestep modulation, as the layer intends;
          * the padding row of the token embedding, which ``nn.Embedding``
            zeroes for ``padding_idx`` and excludes from gradient updates.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        with torch.no_grad():
            for layer in self.layers:
                final_proj = layer.adaLN_modulation[-1]
                nn.init.zeros_(final_proj.weight)
                nn.init.zeros_(final_proj.bias)
            pad_row = self.embedding.padding_idx
            if pad_row is not None:
                self.embedding.weight[pad_row].zero_()

    def forward(self, x, t, src_key_padding_mask=None):
        """
        Args:
            x: [B, seq_len] token ids.
            t: [B] masking rates in [0, 1], used as the diffusion timestep.
            src_key_padding_mask: optional [B, seq_len] bool mask, True at
                padding positions.

        Returns:
            [B, seq_len, vocab_size] logits.
        """
        h = self.embedding(x)
        h = self.pos_encoder(h)
        h = self.embed_dropout(h)
        t_emb = self.time_embed(t)  # [B, d_model * 4]
        for layer in self.layers:
            h = layer(h, t_emb, src_key_padding_mask=src_key_padding_mask)
        h = self.final_norm(h)
        return self.fc_out(h)
# ─────────────────────────────────────────────────────────────
# EMA (Exponential Moving Average) for stable inference
# ─────────────────────────────────────────────────────────────
class EMA:
    """
    Exponential moving average of a model's ``state_dict``.

    After each call to ``update``, every floating-point entry becomes

        shadow = decay * shadow + (1 - decay) * current

    Non-floating entries (e.g. integer counters such as BatchNorm's
    ``num_batches_tracked``) cannot be meaningfully averaged. They are copied
    as-is, which also keeps their dtype intact. In this script the averaged
    weights are used for validation and for the saved checkpoints.

    Args:
        model: Module whose state is tracked. A deep copy of its current
            ``state_dict`` seeds the average.
        decay: Weight on the running average (closer to 1 = smoother).
    """

    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = copy.deepcopy(model.state_dict())

    def update(self, model):
        """Fold the model's current state into the running average, in place."""
        with torch.no_grad():
            model_state = model.state_dict()
            for key, avg in self.shadow.items():
                current = model_state[key]
                if avg.is_floating_point():
                    # avg = decay * avg + (1 - decay) * current, without
                    # allocating a new tensor each step.
                    avg.mul_(self.decay).add_(current, alpha=1 - self.decay)
                else:
                    avg.copy_(current)

    def apply(self, model):
        """Copy the averaged weights into ``model`` (e.g. before validation)."""
        model.load_state_dict(self.shadow)

    def state_dict(self):
        """Return the averaged state (the live dict, not a copy)."""
        return self.shadow
def parse_args():
    """
    Build the training CLI and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes tokenized_file, batch_size,
        epochs, base_lr, warmup_epochs, label_smoothing, dropout,
        weight_decay, patience and val_split.
    """
    # (flag, type, default, help) for every option, in registration order.
    option_table = (
        ('--tokenized-file', str, 'tokenized_data.json', 'Tokenized JSON dataset file'),
        ('--batch-size', int, 32, 'Training batch size'),
        ('--epochs', int, 150, 'Number of training epochs'),
        ('--base-lr', float, 1e-4, 'Base learning rate'),
        ('--warmup-epochs', int, 5, 'Number of warmup epochs'),
        ('--label-smoothing', float, 0.1, 'Label smoothing value'),
        ('--dropout', float, 0.2, 'Dropout probability'),
        ('--weight-decay', float, 0.01, 'AdamW weight decay'),
        ('--patience', int, 20, 'Early stopping patience'),
        ('--val-split', float, 0.1, 'Validation split fraction'),
    )
    cli = argparse.ArgumentParser(
        description='Train the masked diffusion dialogue model.'
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()
# ─────────────────────────────────────────────────────────────
# Guided Masking — returns per-sample masking rates
# ─────────────────────────────────────────────────────────────
def find_bot_start(sequence, bot_token_ids):
    """
    Return the index just after the 'bot' marker token in ``sequence``.

    Masking starts at this index, so the model learns to predict the bot
    reply given the user prompt. Candidate ids are tried in the order given:
    the first candidate that occurs anywhere in the sequence wins, even if a
    later candidate occurs earlier. If no candidate occurs, the midpoint
    ``len(sequence) // 2`` is returned so that the second half is masked.

    Args:
        sequence: 1-D tensor of token ids.
        bot_token_ids: candidate ids for the 'bot' marker, in priority order.
    """
    tokens = sequence.tolist()
    for candidate in bot_token_ids:
        if candidate in tokens:
            return tokens.index(candidate) + 1
    return len(tokens) // 2
def apply_guided_masking(x_0, mask_id, pad_id, special_ids, bot_token_ids):
    """
    Randomly mask the reply part of each sequence and return the mask rates.

    For every sample b, a masking rate t_b is drawn uniformly from [0, 1)
    and floored at 0.05. Every position at or after the sample's reply start
    (``find_bot_start``) whose token is neither ``pad_id`` nor in
    ``special_ids`` is replaced by ``mask_id``, independently, with
    probability t_b.

    The work is vectorised. One host transfer locates the reply starts and
    one random draw covers the whole batch. This replaces a Python loop that
    called ``.item()`` for every token, which forces a device sync per token
    on CUDA.

    Args:
        x_0: [batch_size, seq_len] clean token ids.
        mask_id: id written at masked positions.
        pad_id: padding id (never masked).
        special_ids: ids that are never masked.
        bot_token_ids: candidate ids of the 'bot' marker, in priority order.

    Returns:
        x_t: [batch_size, seq_len] copy of ``x_0`` with masked positions set
            to ``mask_id``.
        is_mask: [batch_size, seq_len] bool tensor, True where masked.
        t_values: [batch_size] float tensor of per-sample masking rates,
            used as the model's timestep input.
    """
    batch_size, seq_len = x_0.shape
    device = x_0.device

    # Per-sample masking rate, at least 5%.
    t_values = torch.rand(batch_size, device=device).clamp_(min=0.05)

    # Only positions from each sample's reply start onwards are eligible.
    reply_starts = torch.tensor(
        [find_bot_start(row, bot_token_ids) for row in x_0.cpu()],
        dtype=torch.long,
        device=device,
    )
    positions = torch.arange(seq_len, device=device).unsqueeze(0)
    in_reply = positions >= reply_starts.unsqueeze(1)

    # Padding and special tokens are never masked.
    protected = x_0 == pad_id
    for special_id in special_ids:
        protected |= x_0 == special_id

    draws = torch.rand(batch_size, seq_len, device=device)
    is_mask = in_reply & ~protected & (draws < t_values.unsqueeze(1))
    x_t = x_0.masked_fill(is_mask, mask_id)
    return x_t, is_mask, t_values
def get_bot_token_ids(vocab):
    """
    Return the vocabulary ids that can represent the 'bot' speaker marker.

    A BPE vocabulary may contain the marker as 'bot', 'bot:', or with a
    leading-space prefix. Byte-level BPE tokenizers (GPT-2 style) spell that
    prefix as 'Ġ' (U+0120). It is written as an escape here so that
    re-encoding the file cannot corrupt it, as happened with the earlier
    mojibake literal 'Δ bot'.

    Args:
        vocab: mapping from token string to id.

    Returns:
        List of ids for the candidates present in ``vocab``, in candidate
        order (possibly empty).
    """
    candidates = ['bot', 'bot:', '\u0120bot', ' bot']
    return [vocab[token] for token in candidates if token in vocab]
def count_parameters(model):
    """Return the total number of scalar parameters in ``model`` that require gradients."""
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable
# ─────────────────────────────────────────────────────────────
# Validation Loop — run by the training loop with EMA weights loaded
# ─────────────────────────────────────────────────────────────
def validate(model, val_loader, vocab_size, mask_id, pad_id, special_ids,
             bot_token_ids, device, label_smoothing=0.1):
    """
    Compute the average masked cross-entropy of ``model`` over ``val_loader``.

    Each batch is masked with ``apply_guided_masking``. The label-smoothed
    cross-entropy (pad targets ignored) is averaged over the masked
    positions, and batches in which nothing was masked are skipped.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, even if an exception is raised.

    NOTE(review): the masks are random on every call, so the returned loss is
    noisy between calls. Keep this in mind when it drives checkpoint
    selection or early stopping.

    Args:
        model: module called as ``model(x_t, t, src_key_padding_mask=...)``.
        val_loader: iterable of [B, seq_len] token-id tensors.
        vocab_size: number of output classes.
        mask_id, pad_id, special_ids, bot_token_ids: see
            ``apply_guided_masking``.
        device: device that batches are moved to.
        label_smoothing: label smoothing for the cross-entropy.

    Returns:
        float: mean of the per-batch masked losses, or 0.0 if no batch had a
        masked position.
    """
    was_training = model.training
    model.eval()
    criterion = nn.CrossEntropyLoss(
        reduction='none', ignore_index=pad_id,
        label_smoothing=label_smoothing,
    )
    total_loss = 0.0
    valid_batches = 0
    try:
        with torch.no_grad():
            for x_0 in val_loader:
                x_0 = x_0.to(device)
                pad_mask = x_0 == pad_id
                x_t, is_mask, t_values = apply_guided_masking(
                    x_0, mask_id, pad_id, special_ids, bot_token_ids
                )
                num_masked = is_mask.sum()
                if num_masked == 0:
                    continue
                logits = model(x_t, t_values, src_key_padding_mask=pad_mask)
                loss_per_token = criterion(
                    logits.reshape(-1, vocab_size),
                    x_0.reshape(-1),
                ).view_as(x_0)
                # Average only over masked positions.
                masked_loss = (loss_per_token * is_mask.float()).sum() / (num_masked + 1e-8)
                total_loss += masked_loss.item()
                valid_batches += 1
    finally:
        model.train(was_training)
    return total_loss / max(valid_batches, 1)
# ─────────────────────────────────────────────────────────────
# Training Loop
# ─────────────────────────────────────────────────────────────
def train_model(
    tokenized_file='tokenized_data.json',
    batch_size=32,
    epochs=150,
    base_lr=1e-4,
    warmup_epochs=5,
    label_smoothing=0.1,
    dropout=0.2,
    weight_decay=0.01,
    max_patience=20,
    val_split=0.1,
):
    """
    Train ``MaskedDiffusionModel`` on a tokenized dialogue dataset.

    Steps:
      * read the tokenizer vocab from ``subword_tokenizer.json`` to get the
        [MASK]/[PAD]/special ids and the 'bot' marker ids;
      * build the model, an EMA of its weights and the train/val loaders;
      * train with AdamW, using a per-step linear warmup over
        ``warmup_epochs`` epochs, then per-epoch cosine annealing down to
        1e-5. Gradients are clipped at norm 1.0, and the loss is a
        label-smoothed cross-entropy restricted to masked positions;
      * after the first epoch and every 5th epoch, validate with the EMA
        weights and save them to ``diffusion_model_best.pth`` whenever the
        validation loss improves;
      * stop early once ``max_patience`` epochs have passed since the last
        improvement, then save the final EMA weights to
        ``diffusion_model.pth``.

    Raises:
        RuntimeError: if ``create_dataloader`` is unavailable (the local
            ``dataset`` module failed to import).
    """
    if create_dataloader is None:
        raise RuntimeError(
            "create_dataloader is unavailable: the local 'dataset' module "
            "could not be imported."
        )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Training on: {device}")

    # ── Vocabulary / special ids ───────────────────────────
    with open("subword_tokenizer.json", "r", encoding="utf-8") as f:
        vocab_data = json.load(f)
    vocab = vocab_data["model"]["vocab"]
    vocab_size = len(vocab)
    mask_id = vocab["[MASK]"]
    pad_id = vocab["[PAD]"]
    special_ids = {vocab["[PAD]"], vocab["[BOS]"], vocab["[EOS]"], vocab["[UNK]"], vocab["[MASK]"]}
    bot_token_ids = get_bot_token_ids(vocab)
    print(f"'bot' token IDs: {bot_token_ids}")
    if not bot_token_ids:
        print("WARNING: 'bot' not found in vocab — falling back to half-masking.")

    # ── Model (size fixed here; the parameter count is printed) ──
    model = MaskedDiffusionModel(
        vocab_size=vocab_size,
        d_model=256, nhead=8, num_layers=6,
        max_seq_len=MAX_SEQ_LENGTH,
        dropout=dropout,
    ).to(device)
    print(f"Parameters: {count_parameters(model):,}")

    # ── EMA ────────────────────────────────────────────────
    ema_decay = 0.999
    ema = EMA(model, decay=ema_decay)

    # ── Data ───────────────────────────────────────────────
    train_loader, val_loader, _ = create_dataloader(
        tokenized_file, batch_size=batch_size, val_split=val_split
    )

    # ── Optimizer + schedule ───────────────────────────────
    # Linear warmup is applied per step by hand below. After warmup, the
    # cosine schedule advances once per epoch, so T_max is in epochs.
    warmup_steps = len(train_loader) * warmup_epochs
    global_step = 0
    optimizer = AdamW(model.parameters(), lr=base_lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=max(1, epochs - warmup_epochs), eta_min=1e-5
    )
    criterion = nn.CrossEntropyLoss(
        reduction='none', ignore_index=pad_id,
        label_smoothing=label_smoothing,
    )

    best_val_loss = float('inf')
    best_train_loss = float('inf')
    # Epoch index of the last validation improvement (-1 = before training).
    # Patience is measured in epochs from here, as the messages state. It
    # used to count validation checks, which happen only every 5 epochs.
    best_epoch = -1

    print(f"\nTraining for {epochs} epochs (early stop after {max_patience} epochs no improvement)")
    print(f"Label smoothing: {label_smoothing} | Weight decay: {weight_decay} | Dropout: {dropout}")
    print(f"EMA decay: {ema_decay}")
    print(f"{'='*70}\n")

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        total_masked = 0
        valid_batches = 0

        for x_0 in train_loader:
            x_0 = x_0.to(device)

            # Linear warmup over the first warmup_epochs worth of steps.
            if global_step < warmup_steps:
                lr = base_lr * (global_step + 1) / warmup_steps
                for group in optimizer.param_groups:
                    group['lr'] = lr

            pad_mask = x_0 == pad_id
            optimizer.zero_grad()

            # Masked input plus per-sample masking rates (the timesteps).
            x_t, is_mask, t_values = apply_guided_masking(
                x_0, mask_id, pad_id, special_ids, bot_token_ids
            )
            num_masked = is_mask.sum()
            if num_masked == 0:
                continue  # nothing was masked, so there is no training signal

            logits = model(x_t, t_values, src_key_padding_mask=pad_mask)
            loss_per_token = criterion(
                logits.reshape(-1, vocab_size),
                x_0.reshape(-1),
            ).view_as(x_0)
            # Loss only on masked positions.
            masked_loss = (loss_per_token * is_mask.float()).sum() / (num_masked + 1e-8)
            masked_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            ema.update(model)

            total_loss += masked_loss.item()
            total_masked += num_masked.item()
            valid_batches += 1
            global_step += 1

        if epoch >= warmup_epochs:
            scheduler.step()

        avg_train_loss = total_loss / max(valid_batches, 1)
        if 0 < avg_train_loss < best_train_loss:
            best_train_loss = avg_train_loss

        # ── Validation on the first epoch and every 5 epochs (EMA weights) ──
        if (epoch + 1) % 5 == 0 or epoch == 0:
            training_state = copy.deepcopy(model.state_dict())
            ema.apply(model)
            val_loss = validate(
                model, val_loader, vocab_size,
                mask_id, pad_id, special_ids, bot_token_ids,
                device, label_smoothing,
            )
            model.load_state_dict(training_state)  # back to training weights

            lr_now = optimizer.param_groups[0]['lr']
            print(
                f"Epoch {epoch+1:>4}/{epochs} | "
                f"Train Loss: {avg_train_loss:.4f} | "
                f"Val Loss: {val_loss:.4f} | "
                f"Masked: {int(total_masked):>8} | "
                f"LR: {lr_now:.6f}"
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_epoch = epoch
                # Save EMA weights for inference.
                torch.save(ema.state_dict(), "diffusion_model_best.pth")
                print(" → New best val loss! Saved EMA checkpoint.")
            elif epoch - best_epoch >= max_patience:
                print(f"\nEarly stopping at epoch {epoch+1} (no improvement for {max_patience} epochs)")
                break

    # Save the final EMA checkpoint.
    torch.save(ema.state_dict(), "diffusion_model.pth")
    print(f"\n{'='*70}")
    print(f"Done! Best train loss: {best_train_loss:.4f} | Best val loss: {best_val_loss:.4f}")
    print("Use diffusion_model_best.pth for inference (EMA weights).")
if __name__ == "__main__":
    # Parse the command line and forward every option to train_model.
    cli = parse_args()
    train_model(
        tokenized_file=cli.tokenized_file,
        batch_size=cli.batch_size,
        epochs=cli.epochs,
        base_lr=cli.base_lr,
        warmup_epochs=cli.warmup_epochs,
        label_smoothing=cli.label_smoothing,
        dropout=cli.dropout,
        weight_decay=cli.weight_decay,
        # The CLI flag --patience maps onto train_model's max_patience.
        max_patience=cli.patience,
        val_split=cli.val_split,
    )