File size: 4,529 Bytes
3437bd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Tiny Decoder module for CAT V3."""

from __future__ import annotations

from typing import Optional
import torch
import torch.nn as nn


class TinyDecoder(nn.Module):
    """A lightweight causal transformer decoder for converting concept embeddings to text.

    Token embeddings plus a learned positional table are decoded by a stack of
    pre-norm (``norm_first=True``) ``nn.TransformerDecoderLayer`` blocks that
    cross-attend to the linearly projected concept embeddings. A final linear
    head maps every hidden state to vocabulary logits.

    The positional table has exactly ``max_seq_len`` rows, so the decoder can
    never be fed more than ``max_seq_len`` target tokens at once.
    """

    def __init__(
        self,
        vocab_size: int,
        concept_dim: int,
        hidden_size: int = 128,
        nhead: int = 4,
        num_layers: int = 2,
        max_seq_len: int = 128,
    ) -> None:
        """Build the decoder.

        Args:
            vocab_size: Number of tokens in the output vocabulary.
            concept_dim: Feature size of the incoming concept embeddings.
            hidden_size: Model width shared by embeddings and decoder layers.
            nhead: Number of attention heads per decoder layer.
            num_layers: Number of stacked decoder layers.
            max_seq_len: Number of rows in the learned positional table, i.e.
                the longest target sequence the decoder accepts.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.concept_proj = nn.Linear(concept_dim, hidden_size)
        self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_len, hidden_size))

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=nhead,
            dim_feedforward=hidden_size * 2,
            batch_first=True,
            norm_first=True,
        )
        self.decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output_head = nn.Linear(hidden_size, vocab_size)
        self.vocab_size = vocab_size
        # Kept as a plain attribute so length checks do not have to inspect
        # the parameter's shape.
        self.max_seq_len = max_seq_len

    def _generate_causal_mask(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Construct standard upper-triangular causal mask for transformer self-attention.

        Returns:
            A ``[seq_len, seq_len]`` float tensor with ``-inf`` strictly above
            the diagonal (future positions) and ``0`` elsewhere.
        """
        mask = torch.triu(torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1)
        return mask

    def forward(
        self,
        organized_embeddings: torch.Tensor,
        target_ids: torch.Tensor,
        target_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Autoregressive training forward pass (teacher forcing).

        Args:
            organized_embeddings: [batch_size, top_m, concept_dim]
            target_ids: [batch_size, seq_len]
            target_mask: [batch_size, seq_len] (1 for active, 0 for pad)

        Returns:
            [batch_size, seq_len, vocab_size] token logits

        Raises:
            ValueError: If ``seq_len`` exceeds ``max_seq_len``; the positional
                table has no entries for the extra positions.
        """
        batch_size, seq_len = target_ids.size()
        device = target_ids.device

        # Fail early with a readable message; otherwise the positional slice
        # below is silently shorter than the sequence and the addition dies
        # with an opaque broadcasting error.
        if seq_len > self.max_seq_len:
            raise ValueError(
                f"target sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}"
            )

        # Map concept states to decoder hidden size (cross-attention memory)
        memory = self.concept_proj(organized_embeddings)

        # Target embeddings + positional encodings
        x = self.embedding(target_ids) + self.pos_embedding[:, :seq_len]

        # Masks: causal mask blocks attention to future positions; the key
        # padding mask is True exactly where target_mask marks padding (0).
        tgt_mask = self._generate_causal_mask(seq_len, device)
        tgt_key_padding_mask = (target_mask == 0) if target_mask is not None else None

        # Transformer decoding
        out = self.decoder(
            tgt=x,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
        )

        return self.output_head(out)

    @torch.no_grad()
    def generate(
        self,
        organized_embeddings: torch.Tensor,
        max_length: int = 32,
        start_id: int = 0,
        eos_id: int = 1,
    ) -> torch.Tensor:
        """Autoregressively generate English tokens from concept embeddings.

        Decoding is greedy (argmax). Once a sequence has emitted ``eos_id`` it
        keeps being padded with ``eos_id`` until every sequence in the batch
        has finished or the length limit is reached.

        The whole prefix is re-decoded at every step, and the prefix can hold
        at most ``max_seq_len`` tokens, so the number of decoding steps is
        capped at ``max_seq_len`` even when ``max_length`` asks for more.

        NOTE(review): this method does not switch the module to eval mode; if
        the decoder layers use dropout, call ``.eval()`` first for
        deterministic output.

        Args:
            organized_embeddings: [batch_size, top_m, concept_dim]
            max_length: Maximum sequence length to generate (start token included)
            start_id: ID of the PAD/start token that seeds every sequence
            eos_id: ID of the EOS token

        Returns:
            [batch_size, gen_seq_len] generated token IDs, where
            ``gen_seq_len <= min(max_length, max_seq_len + 1)`` (and at least 1).
        """
        batch_size = organized_embeddings.size(0)
        device = organized_embeddings.device

        # Initialize with start token
        generated = torch.full((batch_size, 1), start_id, dtype=torch.long, device=device)
        memory = self.concept_proj(organized_embeddings)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        # Each step feeds the current prefix to the decoder; prefixes longer
        # than the positional table cannot be embedded, so stop before that.
        num_steps = min(max_length - 1, self.max_seq_len)

        for _ in range(num_steps):
            seq_len = generated.size(1)
            x = self.embedding(generated) + self.pos_embedding[:, :seq_len]
            tgt_mask = self._generate_causal_mask(seq_len, device)

            out = self.decoder(tgt=x, memory=memory, tgt_mask=tgt_mask)
            logits = self.output_head(out[:, -1, :])  # Take logits of last step
            next_tokens = logits.argmax(dim=-1)       # [batch_size]

            # Enforce EOS for sequences that already finished (no per-step
            # tensor allocation needed for the fill value).
            next_tokens = next_tokens.masked_fill(finished, eos_id)

            generated = torch.cat([generated, next_tokens.unsqueeze(1)], dim=1)
            finished = finished | (next_tokens == eos_id)
            if finished.all():
                break

        return generated