| """Tiny Decoder module for CAT V3.""" |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional |
| import torch |
| import torch.nn as nn |
|
|
|
|
class TinyDecoder(nn.Module):
    """A lightweight causal transformer decoder for converting concept embeddings to text.

    Concept embeddings are linearly projected to ``hidden_size`` and used as the
    cross-attention memory of a stack of pre-norm ``nn.TransformerDecoderLayer``s.
    Target tokens receive a learned absolute positional embedding, so target
    sequences are limited to ``max_seq_len`` positions.
    """

    def __init__(
        self,
        vocab_size: int,
        concept_dim: int,
        hidden_size: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        max_seq_len: int = 128,
    ) -> None:
        """Build the decoder.

        Args:
            vocab_size: Number of tokens (token-embedding rows and logit width).
            concept_dim: Feature size of the incoming concept embeddings.
            hidden_size: Model width (``d_model``); must be divisible by ``nhead``.
            nhead: Attention heads per decoder layer.
            num_layers: Number of stacked decoder layers.
            max_seq_len: Size of the learned positional table, i.e. the longest
                target sequence ``forward`` / ``generate`` can handle.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.concept_proj = nn.Linear(concept_dim, hidden_size)
        # Learned absolute positions, shaped (1, max_seq_len, hidden_size) so the
        # table broadcasts over the batch dimension.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_len, hidden_size))

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=nhead,
            dim_feedforward=hidden_size * 2,
            batch_first=True,
            norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(hidden_size, vocab_size)
        self.vocab_size = vocab_size

    def _generate_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return a (seq_len, seq_len) additive causal mask for self-attention.

        Entries strictly above the diagonal are ``-inf`` (future positions are
        blocked); all other entries are 0.
        """
        return torch.triu(
            torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1
        )

    def forward(
        self,
        organized_embeddings: torch.Tensor,
        target_ids: torch.Tensor,
        target_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Autoregressive training forward pass (teacher forcing).

        Args:
            organized_embeddings: [batch_size, top_m, concept_dim]
            target_ids: [batch_size, seq_len]
            target_mask: [batch_size, seq_len] (1 for active, 0 for pad).
                NOTE(review): position 0 is assumed to be active. Under the
                causal mask, query 0 can only see key 0, so padding it leaves
                that attention row with no valid key (NaN on some PyTorch
                versions).

        Returns:
            [batch_size, seq_len, vocab_size] token logits

        Raises:
            ValueError: If ``seq_len`` exceeds the positional-embedding capacity.
        """
        # Tuple-unpacking keeps the original requirement that target_ids is 2-D.
        _, seq_len = target_ids.size()
        max_positions = self.pos_embedding.size(1)
        if seq_len > max_positions:
            raise ValueError(
                f"target sequence length {seq_len} exceeds max_seq_len={max_positions}"
            )
        device = target_ids.device

        # Concepts become the cross-attention memory.
        memory = self.concept_proj(organized_embeddings)
        x = self.embedding(target_ids) + self.pos_embedding[:, :seq_len]

        tgt_mask = self._generate_causal_mask(seq_len, device)
        tgt_key_padding_mask = None
        if target_mask is not None:
            # Use the same float convention as the causal mask (-inf = ignore).
            # nn.MultiheadAttention deprecates mixing a bool padding mask with a
            # float attn mask and warns on every call.
            tgt_key_padding_mask = torch.zeros_like(
                target_mask, dtype=tgt_mask.dtype
            ).masked_fill_(target_mask == 0, float("-inf"))

        out = self.decoder(
            tgt=x,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )
        return self.output_head(out)

    @torch.no_grad()
    def generate(
        self,
        organized_embeddings: torch.Tensor,
        max_length: int = 32,
        start_id: int = 0,
        eos_id: int = 1,
    ) -> torch.Tensor:
        """Greedily and autoregressively generate tokens from concept embeddings.

        Args:
            organized_embeddings: [batch_size, top_m, concept_dim]
            max_length: Maximum sequence length to generate, including the start
                token. Values above ``max_seq_len`` are clamped to it, because no
                positional embedding exists beyond that.
            start_id: ID of the PAD/start token
            eos_id: ID of the EOS token. Once a row emits it, that row keeps
                emitting ``eos_id`` for the remaining steps.

        Returns:
            [batch_size, gen_seq_len] generated token IDs starting with
            ``start_id`` (gen_seq_len <= max_length). Generation stops early once
            every row has emitted ``eos_id``.

        Note:
            Gradients are disabled here, but dropout is not. The decoder layers
            use PyTorch's default dropout, so call ``.eval()`` first for
            deterministic output.
        """
        batch_size = organized_embeddings.size(0)
        device = organized_embeddings.device
        # Clamp so the prefix never outgrows the positional table; beyond it the
        # broadcast add below would fail.
        max_length = min(max_length, self.pos_embedding.size(1))

        generated = torch.full((batch_size, 1), start_id, dtype=torch.long, device=device)
        # The memory does not depend on the generated prefix, so project it once.
        memory = self.concept_proj(organized_embeddings)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        for _ in range(max_length - 1):
            seq_len = generated.size(1)
            x = self.embedding(generated) + self.pos_embedding[:, :seq_len]
            tgt_mask = self._generate_causal_mask(seq_len, device)

            # No KV cache: the whole prefix is re-decoded at every step.
            out = self.decoder(tgt=x, memory=memory, tgt_mask=tgt_mask)
            next_tokens = self.output_head(out[:, -1, :]).argmax(dim=-1)
            # Rows that already emitted EOS are padded with EOS.
            next_tokens = next_tokens.masked_fill(finished, eos_id)

            generated = torch.cat([generated, next_tokens.unsqueeze(1)], dim=1)
            finished |= next_tokens == eos_id
            if finished.all():
                break

        return generated
|
|