import torch
import torch.nn as nn

from diffusion.scheduler import OptimizedCosineScheduler
from diffusion.forward_process import AbsorbingForwardProcess
# Import shared classes to guarantee identical architectures
from model.d3pm_model_cross_attention import SanskritEmbeddings, EncoderBlock, MultiHeadAttention
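
# For readers without access to diffusion.forward_process, this is a minimal
# *sketch* of the interface the model below relies on: q_sample(x0, t) must
# return a (probabilities, corrupted_ids) pair. The scheduler method name
# (mask_prob) and the default mask id are assumptions; the real
# AbsorbingForwardProcess may differ.
class _ReferenceAbsorbingForwardProcess:
    def __init__(self, scheduler, mask_token_id=4):
        self.scheduler = scheduler
        self.mask_token_id = mask_token_id

    def q_sample(self, x0, t):
        # Per-sample probability of absorbing each token into [MASK] at step t.
        mask_prob = self.scheduler.mask_prob(t).view(-1, 1)  # (B, 1)
        absorbed = torch.rand(x0.shape, device=x0.device) < mask_prob
        x_t = torch.where(absorbed, torch.full_like(x0, self.mask_token_id), x0)
        return mask_prob, x_t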
class DecoderBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)  # ← restored
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)  # ← restored (for cross-attn residual)

    def forward(self, x, memory, tgt_pad_mask=None):
        # 1. Self-attention on the target (pad mask only — no causal mask,
        #    since the diffusion decoder attends bidirectionally)
        x = self.norm1(x + self.self_attn(x, x, x, mask=tgt_pad_mask))
        # 2. Cross-attention: queries from decoder, keys/values from encoder memory
        x = self.norm2(x + self.cross_attn(x, memory, memory))
        # 3. Feed-forward
        return self.norm3(x + self.ff(x))
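
# Illustrative smoke test for DecoderBlock — a sketch, not part of the model.
# Dimensions are arbitrary; MultiHeadAttention is assumed to return a tensor
# shaped like its query input.
def _decoder_block_shape_check():
    block = DecoderBlock(d_model=256, n_heads=8, d_ff=1024)
    x = torch.randn(2, 16, 256)       # decoder states  (B, L_tgt, d_model)
    memory = torch.randn(2, 24, 256)  # encoder memory  (B, L_src, d_model)
    out = block(x, memory)            # residual blocks preserve the query shape
    assert out.shape == (2, 16, 256)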
class DecoderBlockNoCrossAttn(nn.Module):
    """Kept for reference — NOT used by D3PMEncoderDecoder."""
    def __init__(self, d_model, n_heads, d_ff, dropout=0.15):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = nn.Sequential(
            nn.Linear(d_model, d_ff), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(d_ff, d_model), nn.Dropout(dropout),
        )
        self.norm1, self.norm2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)

    def forward(self, x, tgt_pad_mask=None, causal_mask=None):
        combined_mask = None
        if tgt_pad_mask is not None and causal_mask is not None:
            combined_mask = tgt_pad_mask | causal_mask
        elif causal_mask is not None:
            combined_mask = causal_mask
        elif tgt_pad_mask is not None:
            combined_mask = tgt_pad_mask
        x = self.norm1(x + self.self_attn(x, x, x, mask=combined_mask))
        return self.norm2(x + self.ff(x))
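
# A sketch (an assumption, not the project's real helper) of how the masks for
# DecoderBlockNoCrossAttn might be built. True marks positions to suppress,
# matching the `tgt_pad_mask | causal_mask` combination above; the exact
# broadcasting shape depends on MultiHeadAttention's mask convention.
def _build_decoder_masks(tgt, pad_id=1):
    B, L = tgt.shape
    # (B, 1, 1, L): hide padding keys from every query position.
    pad_mask = (tgt == pad_id).view(B, 1, 1, L)
    # (1, 1, L, L): hide future positions (strict upper triangle).
    causal = torch.triu(
        torch.ones(L, L, dtype=torch.bool, device=tgt.device), diagonal=1
    ).view(1, 1, L, L)
    return pad_mask, causal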
# ============================================================
# 1. D3PM Encoder-Decoder Model
# ============================================================
class D3PMEncoderDecoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.mask_token_id = cfg['diffusion']['mask_token_id']

        src_vocab = cfg['model'].get('src_vocab_size', cfg['model']['vocab_size'])
        tgt_vocab = cfg['model'].get('tgt_vocab_size', cfg['model']['vocab_size'])
        d_model = cfg['model']['d_model']
        n_heads = cfg['model']['n_heads']
        d_ff = cfg['model']['d_ff']
        dropout = cfg['model']['dropout']
        n_layers = cfg['model']['n_layers']
        max_len = cfg['model']['max_seq_len']

        self.src_embed = SanskritEmbeddings(src_vocab, d_model, max_len)
        self.tgt_embed = SanskritEmbeddings(tgt_vocab, d_model, max_len)
        self.scheduler = OptimizedCosineScheduler(cfg)
        self.forward_process = AbsorbingForwardProcess(self.scheduler)

        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        # DecoderBlock now has cross-attention — matches saved checkpoint
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])

        # Maps the scalar timestep t to a d_model-sized embedding.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, d_model // 4), nn.SiLU(),
            nn.Linear(d_model // 4, d_model),
        )
        self.head = nn.Linear(d_model, tgt_vocab)
        # Tie the output projection to the target token embeddings.
        self.head.weight = self.tgt_embed.token_embedding.weight
    def forward(self, src, tgt, t, x0_hint=None):
        # Token id 1 is assumed to be the padding token throughout this file.
        src_pad_mask = (src == 1)
        tgt_pad_mask = (tgt == 1)
        # Encode source (Roman IAST)
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        # Corrupt target with forward diffusion
        _, x_t_ids = self.forward_process.q_sample(tgt, t)
        # Optionally blend in x0_hint (self-conditioning): each still-masked
        # position is replaced by the previous x0 estimate with probability hint_prob.
        if x0_hint is not None:
            hint_prob = 0.5
            blend_mask = (torch.rand(x_t_ids.shape, device=x_t_ids.device) < hint_prob)
            still_mask = (x_t_ids == self.mask_token_id)
            x_t_ids = torch.where(blend_mask & still_mask, x0_hint, x_t_ids)
        x = self.tgt_embed(x_t_ids)
        # Add the timestep embedding to every target position.
        t_emb = self.time_mlp(t.float().unsqueeze(-1)).unsqueeze(1)
        x = x + t_emb.expand(-1, tgt.shape[1], -1)
        # Decode with cross-attention over encoder memory
        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
        return self.head(x), None
    def generate(
        self,
        src,
        num_steps=None,
        temperature=0.75,
        top_k=50,
        repetition_penalty=1.15,
        diversity_penalty=0.0,
    ):
        """
        Iterative D3PM reverse diffusion — same signature as
        D3PMCrossAttention.generate() so SanskritModel.generate() works
        identically for both model types.
        """
        device = src.device
        B, L = src.shape[0], self.cfg['model']['max_seq_len']
        T = num_steps or self.scheduler.num_timesteps
        mask_id = self.mask_token_id
        pad_id = 1

        # Start from an all-[MASK] canvas and iteratively denoise.
        x0_est = torch.full((B, L), mask_id, dtype=torch.long, device=device)
        for step in range(T - 1, -1, -1):
            t_tensor = torch.full((B,), step, dtype=torch.long, device=device)
            hint = x0_est.clone()
            logits, _ = self.forward(src, x0_est, t_tensor, x0_hint=hint)

            # Repetition penalty (CTRL-style: shrink positive logits, grow
            # negative ones, so emitted tokens are discouraged in both cases;
            # a plain division would *boost* negative logits).
            if repetition_penalty != 1.0:
                for b in range(B):
                    for tok in set(x0_est[b].tolist()):
                        if tok > pad_id:
                            tok_logits = logits[b, :, tok]
                            logits[b, :, tok] = torch.where(
                                tok_logits > 0,
                                tok_logits / repetition_penalty,
                                tok_logits * repetition_penalty,
                            )

            # Diversity penalty (suppress tokens that dominate across positions)
            if diversity_penalty > 0.0:
                logits = logits - diversity_penalty * logits.mean(dim=1, keepdim=True)

            # Temperature + top-k sampling
            logits = logits / max(temperature, 1e-8)
            if top_k > 0:
                vals, _ = torch.topk(logits, top_k, dim=-1)
                logits = logits.masked_fill(logits < vals[..., -1:], float('-inf'))
            probs = torch.softmax(logits, dim=-1)

            # Only update positions that are still masked
            still = (x0_est == mask_id)
            sample = torch.multinomial(probs.view(-1, probs.size(-1)), 1).view(B, L)
            x0_est = torch.where(still, sample, x0_est)
        return x0_est
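
# A minimal usage sketch for the diffusion model. The cfg layout mirrors the
# keys read in __init__, but every concrete value below is a placeholder, and
# the scheduler-specific 'diffusion' settings are omitted because they live in
# the project's real config:
#
#   cfg = {
#       'model': {'vocab_size': 96, 'd_model': 256, 'n_heads': 8,
#                 'd_ff': 1024, 'dropout': 0.15, 'n_layers': 4,
#                 'max_seq_len': 80},
#       'diffusion': {'mask_token_id': 4},  # plus scheduler settings
#   }
#   model = D3PMEncoderDecoder(cfg).eval()
#   src = torch.randint(5, 96, (1, 40))     # tokenised IAST source
#   with torch.no_grad():
#       out_ids = model.generate(src, num_steps=20)  # (1, max_seq_len)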
# ============================================================
# 2. Baseline Encoder-Decoder Model (unchanged)
# ============================================================
class BaselineEncoderDecoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.src_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
                                            cfg['model']['max_seq_len'])
        self.tgt_embed = SanskritEmbeddings(cfg['model']['vocab_size'], cfg['model']['d_model'],
                                            cfg['model']['max_seq_len'])
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
                         cfg['model']['d_ff'], cfg['model']['dropout'])
            for _ in range(cfg['model']['n_layers'])
        ])
        self.decoder_blocks = nn.ModuleList([
            DecoderBlock(cfg['model']['d_model'], cfg['model']['n_heads'],
                         cfg['model']['d_ff'], cfg['model']['dropout'])
            for _ in range(cfg['model']['n_layers'])
        ])
        self.head = nn.Linear(cfg['model']['d_model'], cfg['model']['vocab_size'])
        self.head.weight = self.tgt_embed.token_embedding.weight

    def forward(self, src, tgt):
        src_pad_mask, tgt_pad_mask = (src == 1), (tgt == 1)
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        x = self.tgt_embed(tgt)
        for block in self.decoder_blocks:
            x = block(x, memory, tgt_pad_mask=tgt_pad_mask)
        return self.head(x)
    def generate(self, src, max_len=80, start_token_id=2):
        # Greedy autoregressive decoding: re-run the decoder over the growing
        # prefix and append the argmax of the final position at each step.
        batch_size, device = src.size(0), src.device
        src_pad_mask = (src == 1)
        memory = self.src_embed(src)
        for block in self.encoder_blocks:
            memory = block(memory, pad_mask=src_pad_mask)
        ys = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
        for _ in range(max_len):
            x = self.tgt_embed(ys)
            for block in self.decoder_blocks:
                x = block(x, memory, tgt_pad_mask=None)
            logits = self.head(x)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            ys = torch.cat([ys, next_token], dim=1)
        # Drop the start token; decoding always runs to max_len (no EOS stop).
        return ys[:, 1:]
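
# A minimal usage sketch for the baseline (cfg as sketched above; values are
# placeholders). Greedy decoding always produces max_len tokens, so callers
# are expected to truncate at their EOS token themselves:
#
#   baseline = BaselineEncoderDecoder(cfg).eval()
#   src = torch.randint(5, cfg['model']['vocab_size'], (1, 40))
#   with torch.no_grad():
#       out = baseline.generate(src, max_len=80, start_token_id=2)  # (1, 80)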