# DevaFlow/model/d3pm_model_encoder_decoder.py
# Uploaded by bhsinghgrid — commit 7d6a683 (verified):
# "Add files using upload-large-folder tool"
import torch
import torch.nn as nn
from diffusion.scheduler import OptimizedCosineScheduler
from diffusion.forward_process import AbsorbingForwardProcess
# Import shared classes to guarantee identical architectures
from model.d3pm_model_cross_attention import SanskritEmbeddings, EncoderBlock, MultiHeadAttention
class DecoderBlock(nn.Module):
    """Post-norm decoder layer: self-attention, cross-attention, feed-forward.

    Each of the three sub-layers is wrapped as ``LayerNorm(x + sublayer(x))``.
    ``norm1`` follows self-attention, ``norm2`` follows cross-attention and
    ``norm3`` follows the feed-forward network.

    Args:
        d_model: Feature width of the residual stream.
        n_heads: Head count forwarded to both ``MultiHeadAttention`` modules.
        d_ff: Hidden width of the feed-forward network.
        dropout: Dropout rate used in attention and in the feed-forward network.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        # Creation order is kept fixed so parameter initialisation and
        # state_dict key order stay compatible with existing checkpoints.
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)

        ff_layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        ]
        self.ff = nn.Sequential(*ff_layers)

        self.norm1 = nn.LayerNorm(d_model)  # after self-attention
        self.norm2 = nn.LayerNorm(d_model)  # after cross-attention
        self.norm3 = nn.LayerNorm(d_model)  # after feed-forward

    def forward(self, x, memory, tgt_pad_mask=None):
        """Run one decoder layer.

        Args:
            x: Decoder hidden states.
            memory: Encoder output used as keys and values in cross-attention.
            tgt_pad_mask: Optional mask handed to self-attention only.

        Returns:
            The updated decoder hidden states.
        """
        # Self-attention over the target; only the padding mask is applied
        # (no causal mask is built here).
        self_out = self.self_attn(x, x, x, mask=tgt_pad_mask)
        x = self.norm1(x + self_out)

        # Cross-attention: queries from the decoder, keys/values from memory.
        # NOTE(review): no mask is passed, so source padding positions are
        # presumably attended to — confirm this is intended.
        cross_out = self.cross_attn(x, memory, memory)
        x = self.norm2(x + cross_out)

        # Position-wise feed-forward.
        ff_out = self.ff(x)
        return self.norm3(x + ff_out)
class DecoderBlockNoCrossAttn(nn.Module):
    """Kept for reference — NOT used by D3PMEncoderDecoder.

    Post-norm decoder layer with self-attention and a feed-forward network
    only; it never looks at encoder memory.
    """

    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, tgt_pad_mask=None, causal_mask=None):
        """Apply self-attention and feed-forward with residual + LayerNorm.

        Args:
            x: Decoder hidden states.
            tgt_pad_mask: Optional padding mask.
            causal_mask: Optional causal mask.

        Returns:
            The updated hidden states.
        """
        # Merge whichever masks were supplied. When both are present they
        # are OR-ed together; with neither, attention runs unmasked.
        if tgt_pad_mask is None:
            attn_mask = causal_mask
        elif causal_mask is None:
            attn_mask = tgt_pad_mask
        else:
            attn_mask = tgt_pad_mask | causal_mask

        attended = self.self_attn(x, x, x, mask=attn_mask)
        x = self.norm1(x + attended)
        return self.norm2(x + self.ff(x))
# ============================================================
# 1. D3PM Encoder-Decoder Model
# ============================================================
class D3PMEncoderDecoder(nn.Module):
    """Encoder-decoder denoiser for absorbing-state (mask) discrete diffusion.

    The source sequence is encoded once; the target is corrupted by the
    forward diffusion process and the decoder predicts per-position logits
    over the target vocabulary while cross-attending to the encoder memory.
    The output projection shares its weight with the target token embedding.

    Args:
        cfg: Nested config dict. Reads ``cfg['diffusion']['mask_token_id']``
            and, under ``cfg['model']``: ``vocab_size`` (optionally overridden
            by ``src_vocab_size`` / ``tgt_vocab_size``), ``d_model``,
            ``n_heads``, ``d_ff``, ``dropout``, ``n_layers``, ``max_seq_len``.
            The whole dict is also handed to ``OptimizedCosineScheduler``.
    """

    # Token id treated as padding when building attention masks.
    PAD_TOKEN_ID = 1

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.mask_token_id = cfg['diffusion']['mask_token_id']

        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
        d_model = cfg['model']['d_model']
        n_heads = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']
        dropout = cfg['model']['dropout']
        n_layers = cfg['model']['n_layers']
        max_len = cfg['model']['max_seq_len']

        self.src_embed = SanskritEmbeddings(src_vocab, d_model, max_len)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d_model, max_len)

        self.scheduler = OptimizedCosineScheduler(cfg)
        self.forward_process = AbsorbingForwardProcess(self.scheduler)

        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        # DecoderBlock includes cross-attention so the parameter layout
        # matches the saved checkpoint.
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])

        # Maps the scalar timestep to a d_model-wide conditioning vector.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model // 4), nn.SiLU(),
            nn.Linear(d_model // 4, d_model),
        )

        self.head = nn.Linear(d_model, tgt_vocab)
        # Weight tying: output projection reuses the target embedding matrix.
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt, t, x0_hint=None):
        """Corrupt ``tgt`` at timestep ``t`` and predict target-token logits.

        Args:
            src: Source token ids (assumed ``(batch, src_len)``).
            tgt: Target token ids (assumed ``(batch, tgt_len)``).
            t: Per-example diffusion timestep, one value per batch row.
            x0_hint: Optional token ids shaped like ``tgt``. Where the
                corrupted target still holds the mask token, the hint token
                is substituted with probability 0.5 (self-conditioning).

        Returns:
            ``(logits, None)`` where ``logits`` is the output of the tied
            projection head for every target position.
        """
        src_pad_mask = (src == self.PAD_TOKEN_ID)
        tgt_pad_mask = (tgt == self.PAD_TOKEN_ID)

        # Encode the source sequence.
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)

        # Corrupt the target with the forward diffusion process.
        _, x_t_ids = self.forward_process.q_sample(tgt, t)

        # Optionally blend in x0_hint, only at positions that are still masked.
        if x0_hint is not None:
            hint_prob = 0.5
            blend_mask = (torch.rand(x_t_ids.shape, device=x_t_ids.device) < hint_prob)
            still_mask = (x_t_ids == self.mask_token_id)
            x_t_ids = torch.where(blend_mask & still_mask, x0_hint, x_t_ids)

        x = self.tgt_embed(x_t_ids)
        # Timestep embedding, broadcast over every target position.
        t_emb = self.time_mlp(t.float().unsqueeze(-1)).unsqueeze(1)
        x = x + t_emb.expand(-1, tgt.shape[1], -1)

        # Decode with cross-attention over the encoder memory.
        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)

        return self.head(x), None

    @staticmethod
    def _apply_repetition_penalty(logits, tokens, penalty, pad_id):
        """Make every id > ``pad_id`` that occurs in ``tokens`` less likely.

        Sign-aware: positive logits are divided by ``penalty`` and negative
        logits are multiplied by it. Plain division would move a negative
        logit toward zero and so *raise* that token's probability.
        """
        vocab = logits.size(-1)
        # seen[b, v] is True when token v appears anywhere in row b.
        seen = torch.zeros(tokens.size(0), vocab, dtype=torch.bool, device=tokens.device)
        seen.scatter_(1, tokens, True)
        seen[:, :pad_id + 1] = False  # ids <= pad_id are never penalised
        penalised = torch.where(logits < 0, logits * penalty, logits / penalty)
        return torch.where(seen.to(logits.device).unsqueeze(1), penalised, logits)

    @staticmethod
    def _top_k_filter(logits, top_k):
        """Set every logit below the k-th largest (per position) to -inf."""
        k = min(top_k, logits.size(-1))  # torch.topk raises if k > vocab size
        kth = torch.topk(logits, k, dim=-1).values[..., -1:]
        return logits.masked_fill(logits < kth, float('-inf'))

    @torch.no_grad()
    def generate(
        self,
        src,
        num_steps=None,
        temperature=0.75,
        top_k=50,
        repetition_penalty=1.15,
        diversity_penalty=0.0,
    ):
        """
        Iterative D3PM reverse diffusion — same signature as
        D3PMCrossAttention.generate() so SanskritModel.generate() works
        identically for both model types.

        Args:
            src: Source token ids; the batch size is taken from dim 0.
            num_steps: Number of reverse steps; a falsy value falls back to
                ``self.scheduler.num_timesteps``.
            temperature: Logit divisor (floored at 1e-8) applied before sampling.
            top_k: Keep only the k highest logits per position; ``<= 0``
                disables the filter. Values above the vocabulary size are
                clamped.
            repetition_penalty: Penalty for ids already present in the
                current estimate; ``1.0`` disables it.
            diversity_penalty: If > 0, subtracts this multiple of the
                per-token logit averaged over sequence positions.

        Returns:
            Long tensor of shape ``(batch, cfg['model']['max_seq_len'])``.
        """
        device = src.device
        B, L = src.shape[0], self.cfg['model']['max_seq_len']
        T = num_steps or self.scheduler.num_timesteps
        mask_id = self.mask_token_id
        pad_id = self.PAD_TOKEN_ID

        # Start from a fully masked target.
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)

        for step in range(T - 1, -1, -1):
            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
            hint = x0_est.clone()

            logits, _ = self.forward(src, x0_est, t_tensor, x0_hint=hint)

            # Repetition penalty
            if repetition_penalty != 1.0:
                logits = self._apply_repetition_penalty(
                    logits, x0_est, repetition_penalty, pad_id
                )

            # Diversity penalty (suppress tokens favoured at every position)
            if diversity_penalty > 0.0:
                logits = logits - diversity_penalty * logits.mean(dim=1, keepdim=True)

            # Temperature + top-k sampling
            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                logits = self._top_k_filter(logits, top_k)

            probs = torch.softmax(logits, dim=-1)

            # Only update positions that are still masked.
            still = (x0_est == mask_id)
            sample = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(B, L)
            x0_est = torch.where(still, sample, x0_est)

            # Updates are gated on `still`, so once no mask token is left
            # further steps cannot change the result; skip them.
            if not (x0_est == mask_id).any():
                break

        return x0_est
# ============================================================
# 2. Baseline Encoder-Decoder Model (unchanged)
# ============================================================
class BaselineEncoderDecoder(nn.Module):
    """Plain (non-diffusion) encoder-decoder with greedy decoding.

    Source and target share ``cfg['model']['vocab_size']``; the output
    projection reuses the target token-embedding weight.

    Args:
        cfg: Nested config dict; reads ``vocab_size``, ``d_model``,
            ``max_seq_len``, ``n_heads``, ``d_ff``, ``dropout`` and
            ``n_layers`` under ``cfg['model']``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        model_cfg = cfg['model']

        self.src_embed = SanskritEmbeddings(
            model_cfg['vocab_size'], model_cfg['d_model'], model_cfg['max_seq_len']
        )
        self.tgt_embed = SanskritEmbeddings(
            model_cfg['vocab_size'], model_cfg['d_model'], model_cfg['max_seq_len']
        )

        encoder_layers = []
        for _ in range(model_cfg['n_layers']):
            encoder_layers.append(
                EncoderBlock(model_cfg['d_model'], model_cfg['n_heads'],
                             model_cfg['d_ff'], model_cfg['dropout'])
            )
        self.encoder_blocks = nn.ModuleList(encoder_layers)

        decoder_layers = []
        for _ in range(model_cfg['n_layers']):
            decoder_layers.append(
                DecoderBlock(model_cfg['d_model'], model_cfg['n_heads'],
                             model_cfg['d_ff'], model_cfg['dropout'])
            )
        self.decoder_blocks = nn.ModuleList(decoder_layers)

        self.head = nn.Linear(model_cfg['d_model'], model_cfg['vocab_size'])
        # Weight tying with the target token embedding.
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt):
        """Return vocabulary logits for every position of ``tgt``.

        Token id 1 is treated as padding in both ``src`` and ``tgt``.
        NOTE(review): only a padding mask reaches the decoder; no causal
        mask is built here — confirm that is intended for this baseline.
        """
        src_pad_mask = (src == 1)
        tgt_pad_mask = (tgt == 1)

        memory = self.src_embed(src)
        for enc in self.encoder_blocks:
            memory = enc(memory, pad_mask=src_pad_mask)

        hidden = self.tgt_embed(tgt)
        for dec in self.decoder_blocks:
            hidden = dec(hidden, memory, tgt_pad_mask=tgt_pad_mask)

        return self.head(hidden)

    @torch.no_grad()
    def generate(self, src, max_len=80, start_token_id=2):
        """Greedy decoding for exactly ``max_len`` steps.

        The source is encoded once; at every step the whole prefix is
        re-decoded and the arg-max token at the last position is appended.
        There is no early stop on an end-of-sequence token.

        Returns:
            Generated ids of shape ``(batch, max_len)``, start token removed.
        """
        device = src.device
        batch_size = src.size(0)

        src_pad_mask = (src == 1)
        memory = self.src_embed(src)
        for enc in self.encoder_blocks:
            memory = enc(memory, pad_mask=src_pad_mask)

        # Every row starts with the start token.
        ys = torch.ones(batch_size, 1, dtype=torch.long, device=device) * start_token_id

        step = 0
        while step < max_len:
            hidden = self.tgt_embed(ys)
            for dec in self.decoder_blocks:
                hidden = dec(hidden, memory, tgt_pad_mask=None)
            logits = self.head(hidden)
            last_logits = logits[:, -1, :]
            next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            ys = torch.cat([ys, next_token], dim=1)
            step += 1

        return ys[:, 1:]