"""Encoder-decoder D3PM (discrete denoising diffusion) for Sanskrit
sequence-to-sequence modeling, plus an autoregressive baseline that
reuses the same building blocks."""

import torch
import torch.nn as nn

from diffusion.scheduler import OptimizedCosineScheduler
from diffusion.forward_process import AbsorbingForwardProcess
from model.d3pm_model_cross_attention import SanskritEmbeddings, EncoderBlock, MultiHeadAttention


class DecoderBlock(nn.Module):
    """Transformer decoder block: self-attention, cross-attention into the
    encoder memory, then a position-wise feed-forward network (post-norm)."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, memory, tgt_pad_mask=None, causal_mask=None):
        # Combine padding and (optional) causal masks as in
        # DecoderBlockNoCrossAttn; the caller must pass boolean masks
        # (True = masked) with shapes that broadcast under `|` and that
        # MultiHeadAttention accepts.
        mask = tgt_pad_mask
        if causal_mask is not None:
            mask = causal_mask if mask is None else (mask | causal_mask)
        x = self.norm1(x + self.self_attn(x, x, x, mask=mask))
        # NOTE: cross-attention does not mask source padding here.
        x = self.norm2(x + self.cross_attn(x, memory, memory))
        return self.norm3(x + self.ff(x))
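

# Usage sketch for DecoderBlock (illustrative shapes; assumes the imported
# MultiHeadAttention preserves d_model):
#
#   block = DecoderBlock(d_model=512, n_heads=8, d_ff=2048)
#   x      = torch.randn(2, 20, 512)   # (B, L_tgt, d_model) decoder states
#   memory = torch.randn(2, 17, 512)   # (B, L_src, d_model) encoder output
#   out = block(x, memory)             # -> (2, 20, 512)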
|
|
|
|
class DecoderBlockNoCrossAttn(nn.Module):
    """Decoder-only variant. Kept for reference; NOT used by D3PMEncoderDecoder."""

    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout),
        )
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    def forward(self, x, tgt_pad_mask=None, causal_mask=None):
        # Masks are boolean with True = masked; the caller must pass shapes
        # that broadcast under `|` (e.g. pad (B, 1, L) with causal (L, L)).
        combined_mask = None
        if tgt_pad_mask is not None and causal_mask is not None:
            combined_mask = tgt_pad_mask | causal_mask
        elif causal_mask is not None:
            combined_mask = causal_mask
        elif tgt_pad_mask is not None:
            combined_mask = tgt_pad_mask
        x = self.norm1(x + self.self_attn(x, x, x, mask=combined_mask))
        return self.norm2(x + self.ff(x))
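

# A causal mask for these blocks can be built with torch.triu (sketch;
# True hides a position, matching the boolean `|` combine above):
#
#   L = x.size(1)
#   causal = torch.triu(torch.ones(L, L, dtype=torch.bool, device=x.device),
#                       diagonal=1)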
|
|
|
|
class D3PMEncoderDecoder(nn.Module):
    """D3PM with full encoder-decoder cross-attention over the source."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.mask_token_id = cfg['diffusion']['mask_token_id']

        # Separate source/target vocabularies fall back to a shared one.
        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
        d_model = cfg['model']['d_model']
        n_heads = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']
        dropout = cfg['model']['dropout']
        n_layers = cfg['model']['n_layers']
        max_len = cfg['model']['max_seq_len']

        self.src_embed = SanskritEmbeddings(src_vocab, d_model, max_len)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d_model, max_len)

        # Noise schedule and the absorbing (mask-token) forward process.
        self.scheduler = OptimizedCosineScheduler(cfg)
        self.forward_process = AbsorbingForwardProcess(self.scheduler)

        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])

        # Scalar timestep t -> d_model-dim embedding.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model // 4), nn.SiLU(),
            nn.Linear(d_model // 4, d_model),
        )
        self.head = nn.Linear(d_model, tgt_vocab)
        # Tie the output projection to the target embedding weights.
        self.head.weight = self.tgt_embed.token_embedding.weight
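
    # Minimal cfg consumed by __init__ (sketch; values are illustrative and
    # OptimizedCosineScheduler / SanskritEmbeddings may read further keys):
    #
    #   cfg = {'model': {'vocab_size': 8000, 'd_model': 512, 'n_heads': 8,
    #                    'd_ff': 2048, 'dropout': 0.1, 'n_layers': 6,
    #                    'max_seq_len': 128},
    #          'diffusion': {'mask_token_id': 3}}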
|
|
    def forward(self, src, tgt, t, x0_hint=None):
        src_pad_mask = (src == 1)   # pad token id is 1 in this codebase
        tgt_pad_mask = (tgt == 1)

        # Encode the (clean) source once.
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)

        # Corrupt the target to noise level t with the absorbing process.
        _, x_t_ids = self.forward_process.q_sample(tgt, t)

        # Optionally reveal a hint (e.g. the previous estimate during
        # sampling) at roughly half of the positions that are still masked.
        if x0_hint is not None:
            hint_prob = 0.5
            blend_mask = (torch.rand(x_t_ids.shape, device=x_t_ids.device) < hint_prob)
            still_mask = (x_t_ids == self.mask_token_id)
            x_t_ids = torch.where(blend_mask & still_mask, x0_hint, x_t_ids)

        # Embed the noised target and add a broadcast timestep embedding.
        x = self.tgt_embed(x_t_ids)
        t_emb = self.time_mlp(t.float().unsqueeze(-1)).unsqueeze(1)
        x = x + t_emb.expand(-1, tgt.shape[1], -1)

        # Decode with cross-attention into the source memory (no causal
        # mask: the diffusion decoder is bidirectional).
        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)

        # Second return value is kept for interface parity with
        # D3PMCrossAttention (see generate() below).
        return self.head(x), None
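
    # forward() I/O sketch (illustrative): src (B, L_src) and tgt (B, L_tgt)
    # are int64 token ids, t is (B,) int64 timesteps; returns logits of
    # shape (B, L_tgt, tgt_vocab) and None.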
|
|
    @torch.no_grad()
    def generate(
        self,
        src,
        num_steps=None,
        temperature=0.75,
        top_k=50,
        repetition_penalty=1.15,
        diversity_penalty=0.0,
    ):
        """
        Iterative D3PM reverse diffusion. Same signature as
        D3PMCrossAttention.generate() so SanskritModel.generate() works
        identically for both model types.
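
        Example (illustrative):
            ids = model.generate(src_ids, num_steps=25)   # -> (B, max_seq_len)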
| """ |
        device = src.device
        B, L = src.shape[0], self.cfg['model']['max_seq_len']
        T = num_steps or self.scheduler.num_timesteps
        mask_id = self.mask_token_id
        pad_id = 1

        # Start from the fully absorbed state: every position is MASK.
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)

        for step in range(T - 1, -1, -1):
            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
            hint = x0_est.clone()

            # forward() re-noises the current estimate to level `step` and
            # blends the hint back into about half of the masked positions.
            logits, _ = self.forward(src, x0_est, t_tensor, x0_hint=hint)

            # Repetition penalty. Plain division would *raise* the
            # probability of tokens with negative logits, so divide positive
            # logits and multiply negative ones (CTRL-style); skip the pad
            # and MASK ids.
            if repetition_penalty != 1.0:
                for b in range(B):
                    for tok in set(x0_est[b].tolist()):
                        if tok > pad_id and tok != mask_id:
                            col = logits[b, :, tok]
                            logits[b, :, tok] = torch.where(
                                col > 0,
                                col / repetition_penalty,
                                col * repetition_penalty,
                            )

            # Diversity penalty: push logits away from the per-sequence mean.
            if diversity_penalty > 0.0:
                logits = logits - diversity_penalty * logits.mean(dim=1, keepdim=True)

            # Temperature scaling and top-k filtering.
            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))

            probs = torch.softmax(logits, dim=-1)
            sample = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(B, L)
            # Re-estimate every position each step. (Filling only positions
            # still equal to mask_id, as before, froze x0_est after the
            # first iteration, since no MASK tokens survive it.)
            x0_est = sample

        return x0_est
|
|
|
|
class BaselineEncoderDecoder(nn.Module):
    """Autoregressive transformer baseline sharing the diffusion model's blocks."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        m = cfg['model']
        self.src_embed = SanskritEmbeddings(m['vocab_size'], m['d_model'], m['max_seq_len'])
        self.tgt_embed = SanskritEmbeddings(m['vocab_size'], m['d_model'], m['max_seq_len'])
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(m['d_model'], m['n_heads'], m['d_ff'], m['dropout'])
            for _ in range(m['n_layers'])
        ])
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(m['d_model'], m['n_heads'], m['d_ff'], m['dropout'])
            for _ in range(m['n_layers'])
        ])
        self.head = nn.Linear(m['d_model'], m['vocab_size'])
        # Tie the output projection to the target embedding weights.
        self.head.weight = self.tgt_embed.token_embedding.weight
|
|
    def forward(self, src, tgt):
        src_pad_mask, tgt_pad_mask = (src == 1), (tgt == 1)   # pad id = 1
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        # Causal mask so teacher forcing cannot peek at future target
        # tokens; without it the decoder can simply copy the next token
        # during training. Shapes assume MultiHeadAttention broadcasts
        # boolean masks (True = masked), as the `|` combine implies.
        L = tgt.shape[1]
        causal = torch.triu(
            torch.ones(L, L, dtype=torch.bool, device=tgt.device), diagonal=1
        )
        x = self.tgt_embed(tgt)
        for block in self.decoder_blocks:
            x = block(x, memory,
                      tgt_pad_mask=tgt_pad_mask.unsqueeze(1),   # (B, 1, L)
                      causal_mask=causal)                       # (L, L)
        return self.head(x)
|
|
    @torch.no_grad()
    def generate(self, src, max_len=80, start_token_id=2):
        batch_size, device = src.size(0), src.device
        src_pad_mask = (src == 1)
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        # Greedy left-to-right decoding, re-running the decoder on the
        # growing prefix each step (no KV cache).
        ys = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
        for _ in range(max_len):
            cur_len = ys.shape[1]
            # Match training-time masking (padding never occurs in `ys`).
            causal = torch.triu(
                torch.ones(cur_len, cur_len, dtype=torch.bool, device=device),
                diagonal=1,
            )
            x = self.tgt_embed(ys)
            for block in self.decoder_blocks:
                x = block(x, memory, causal_mask=causal)
            logits = self.head(x)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            ys = torch.cat([ys, next_token], dim=1)
        return ys[:, 1:]   # drop the start token
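

if __name__ == "__main__":
    # Smoke test (sketch). The cfg keys below are exactly the ones read in
    # this file; the imported scheduler / embedding classes may expect more,
    # and all values are illustrative.
    cfg = {
        'model': {'vocab_size': 100, 'd_model': 64, 'n_heads': 4,
                  'd_ff': 128, 'dropout': 0.1, 'n_layers': 2,
                  'max_seq_len': 16},
        'diffusion': {'mask_token_id': 3},
    }
    src = torch.randint(4, 100, (2, 16))
    tgt = torch.randint(4, 100, (2, 16))

    baseline = BaselineEncoderDecoder(cfg)
    print(baseline(src, tgt).shape)                 # -> (2, 16, 100)
    print(baseline.generate(src, max_len=8).shape)  # -> (2, 8)

    d3pm = D3PMEncoderDecoder(cfg)
    t = torch.randint(0, 4, (2,))   # timestep range depends on the scheduler
    logits, _ = d3pm(src, tgt, t)
    print(logits.shape)                             # -> (2, 16, 100)
    print(d3pm.generate(src, num_steps=4).shape)    # -> (2, 16)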