"""Arkadiko V4 decoder.

Dense transformer decoder with optional LASER2 cross-attention.
Three variants (set via config.cross_attention_mode):
  - "per_layer":  cross-attention at every decoder layer (Variant A)
  - "none":       pure decoder, no LASER2 (Variant B)
  - "input_only": LASER2 added to token embeddings at input only (Variant C)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from arkadiko.embedding.rope import precompute_rotary_embeddings
from arkadiko.embedding.mlp import SwiGLU
from arkadiko.llm.attention import CausalSelfAttention, CrossAttention
from arkadiko.llm.config import V4Config


def norm(x):
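    """Parameter-free RMSNorm (the learnable scale lives in V4Decoder.final_norm_gamma)."""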
    return F.rms_norm(x, (x.size(-1),))


class V4Block(nn.Module):
    """Decoder block: self-attn → (optional cross-attn) → FFN."""

    def __init__(self, config: V4Config):
        super().__init__()
        self.config = config
        self.use_cross_attn = (config.cross_attention_mode == "per_layer")

        self.self_attn = CausalSelfAttention(config)
        if self.use_cross_attn:
            self.cross_attn = CrossAttention(config)
        self.mlp = SwiGLU(config.n_embd, config.ffn_mult, hidden=config.ffn_hidden)

    def forward(self, x, cos, sin, encoder_hidden=None, encoder_pad_mask=None):
        # Self-attention (pre-norm)
        x = x + self.self_attn(norm(x), cos, sin)

        # Cross-attention (if enabled and encoder output provided)
        if self.use_cross_attn and encoder_hidden is not None:
            x = x + self.cross_attn(norm(x), encoder_hidden, encoder_pad_mask)

        # FFN
        x = x + self.mlp(norm(x))
        return x


class V4Decoder(nn.Module):
    """Arkadiko V4 decoder."""

    def __init__(self, config: V4Config):
        super().__init__()
        self.config = config

        # Token embedding
        self.wte = nn.Embedding(config.vocab_size, config.n_embd, padding_idx=config.pad_token_id)

        # Projects the mean-pooled LASER2 output into the embedding stream (input_only mode)
        if config.cross_attention_mode == "input_only":
            self.laser_input_proj = nn.Linear(config.laser_dim, config.n_embd, bias=False)

        # Decoder blocks
        self.blocks = nn.ModuleList([V4Block(config) for _ in range(config.n_layer)])
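        # Learnable scale for the final RMSNorm (the per-block norms are parameter-free)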
        self.final_norm_gamma = nn.Parameter(torch.ones(config.n_embd))

        # LM head (weight-tied to wte below when config.tied_embeddings is set)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # RoPE buffers
        cos, sin = precompute_rotary_embeddings(
            config.max_seq_len, config.head_dim, config.rope_theta
        )
        self.register_buffer("cos", cos)
        self.register_buffer("sin", sin)

        self.init_weights()

        if config.tied_embeddings:
            self.lm_head.weight = self.wte.weight

    def init_weights(self):
        std = self.config.init_std
        n_embd = self.config.n_embd
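        # uniform(-s, s) with s = sqrt(3)/sqrt(n_embd) has std exactly 1/sqrt(n_embd)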
        s = 3**0.5 * n_embd**-0.5

        nn.init.normal_(self.wte.weight, mean=0.0, std=std)

        for block in self.blocks:
            nn.init.uniform_(block.self_attn.c_q.weight, -s, s)
            nn.init.uniform_(block.self_attn.c_k.weight, -s, s)
            nn.init.uniform_(block.self_attn.c_v.weight, -s, s)
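            # Zero-init output projections so every residual branch starts as the identity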
            nn.init.zeros_(block.self_attn.c_proj.weight)

            if block.use_cross_attn:
                # Query projects from the decoder dim (n_embd), so reuse the same bound
                nn.init.uniform_(block.cross_attn.c_q.weight, -s, s)
                # K/V project from laser_dim (1024) to decoder dim
                s_laser = 3**0.5 * self.config.laser_dim**-0.5
                nn.init.uniform_(block.cross_attn.c_k.weight, -s_laser, s_laser)
                nn.init.uniform_(block.cross_attn.c_v.weight, -s_laser, s_laser)
                nn.init.zeros_(block.cross_attn.c_proj.weight)  # start as no-op

            nn.init.uniform_(block.mlp.c_gate.weight, -s * 0.5, s * 0.5)
            nn.init.uniform_(block.mlp.c_up.weight, -s * 0.5, s * 0.5)
            nn.init.zeros_(block.mlp.c_proj.weight)

        if hasattr(self, "laser_input_proj"):
            nn.init.zeros_(self.laser_input_proj.weight)

    def forward(
        self,
        input_ids: torch.Tensor,
        encoder_hidden: torch.Tensor | None = None,
        encoder_pad_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
    ):
        """
        Args:
            input_ids: [B, T] decoder tokens (causal LM targets)
            encoder_hidden: [B, T_enc, laser_dim] LASER2 output (per_layer or input_only)
            encoder_pad_mask: [B, T_enc] bool, True = pad
            labels: [B, T] targets for cross-entropy loss (shifted by caller)

        Returns:
            dict with 'logits' [B, T, V] and optionally 'loss'
        """
        B, T = input_ids.shape

        # Embeddings
        x = self.wte(input_ids)

        # Input-only LASER2 injection
        if self.config.cross_attention_mode == "input_only" and encoder_hidden is not None:
            # Mean-pool encoder output across time, broadcast to all positions
            if encoder_pad_mask is not None:
                mask = (~encoder_pad_mask).to(encoder_hidden.dtype).unsqueeze(-1)
                laser_pool = (encoder_hidden * mask).sum(1) / mask.sum(1).clamp(min=1)
            else:
                laser_pool = encoder_hidden.mean(1)
            laser_proj = self.laser_input_proj(laser_pool.to(x.dtype))  # [B, C]
            x = x + laser_proj.unsqueeze(1)  # broadcast to all decoder positions

        # Decoder blocks
        for block in self.blocks:
            x = block(x, self.cos, self.sin,
                      encoder_hidden=encoder_hidden,
                      encoder_pad_mask=encoder_pad_mask)

        # Final norm + LM head
        x = norm(x) * self.final_norm_gamma
        logits = self.lm_head(x)

        out = {"logits": logits}

        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, self.config.vocab_size),
                labels.view(-1),
                ignore_index=self.config.pad_token_id,
            )
            out["loss"] = loss

        return out

    def num_parameters(self, exclude_embedding: bool = False) -> int:
        n = sum(p.numel() for p in self.parameters())
        if exclude_embedding:
            n -= self.wte.weight.numel()
        return n
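

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of training). It
    # assumes V4Config can be constructed with project defaults and that
    # cross_attention_mode is an overridable keyword field; adapt to the real
    # config API if it differs. Labels are passed unshifted here, which is fine
    # for a shape/loss sanity check but not for real training.
    cfg = V4Config(cross_attention_mode="per_layer")
    model = V4Decoder(cfg)

    B, T, T_enc = 2, 16, 8
    input_ids = torch.randint(0, cfg.vocab_size, (B, T))
    encoder_hidden = torch.randn(B, T_enc, cfg.laser_dim)       # stand-in for LASER2 output
    encoder_pad_mask = torch.zeros(B, T_enc, dtype=torch.bool)  # True = pad

    out = model(input_ids, encoder_hidden, encoder_pad_mask, labels=input_ids)
    print("logits:", tuple(out["logits"].shape), "loss:", out["loss"].item())
    print(f"{model.num_parameters():,} parameters")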