import torch
import torch.nn as nn

from src.model.config import ModelConfig
from src.model.attention import MultiHeadAttention, AttentionResidual


class FeedForward(nn.Module):
    """Position-wise feed-forward network with GELU activation and dropout."""

    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


class TransformerBlock(nn.Module):
    """Pre-norm transformer block: causal self-attention plus a feed-forward sublayer.

    When cfg.use_attn_res is set, the attention residual path is handled by
    AttentionResidual, which combines the attention output with earlier layer
    outputs; otherwise a standard residual connection with LayerNorm is used.
    """

    def __init__(self, cfg: ModelConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.use_attn_res = cfg.use_attn_res

        self.ln1 = nn.LayerNorm(cfg.d_model)
        self.attn = MultiHeadAttention(cfg.d_model, cfg.n_heads, cfg.dropout)
        self.ln2 = nn.LayerNorm(cfg.d_model)
        self.ff = FeedForward(cfg.d_model, cfg.d_ff, cfg.dropout)

        if self.use_attn_res:
            self.attn_res = AttentionResidual(cfg.d_model, layer_idx)
        else:
            self.ln_res = nn.LayerNorm(cfg.d_model)

    def forward(
        self, x: torch.Tensor, layer_outputs: list[torch.Tensor]
    ) -> torch.Tensor:
        # Causal self-attention over the pre-normed input.
        normed = self.ln1(x)
        attn_out = self.attn(normed, normed, normed, causal=True)

        if self.use_attn_res:
            # Mix the attention output with the outputs of earlier layers.
            x = self.attn_res(attn_out, layer_outputs)
        else:
            # Standard residual connection followed by LayerNorm.
            x = self.ln_res(x + attn_out)

        # Feed-forward sublayer with a plain residual connection.
        x = x + self.ff(self.ln2(x))
        return x


class Backbone(nn.Module):
    """Decoder-only transformer backbone: token and learned positional embeddings,
    a stack of TransformerBlocks, and a final LayerNorm."""

    def __init__(self, cfg: ModelConfig):
        super().__init__()
        self.cfg = cfg
        self.tok_emb = nn.Embedding(cfg.vocab_size, cfg.d_model)
        self.pos_emb = nn.Embedding(cfg.max_seq_len, cfg.d_model)
        self.drop = nn.Dropout(cfg.dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg, i) for i in range(cfg.n_layers)
        ])
        self.ln_final = nn.LayerNorm(cfg.d_model)

    def forward(self, input_ids: torch.Tensor) -> dict:
        B, T = input_ids.shape
        pos = torch.arange(T, device=input_ids.device).unsqueeze(0)  # (1, T)

        x = self.drop(self.tok_emb(input_ids) + self.pos_emb(pos))

        # Collect each layer's output so blocks with use_attn_res enabled can
        # draw on earlier layers.
        layer_outputs = [x]
        for block in self.blocks:
            x = block(x, layer_outputs)
            layer_outputs.append(x)

        hidden = self.ln_final(x)
        return {"hidden": hidden, "layer_outputs": layer_outputs}
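

# Minimal smoke-test sketch (not part of the original module): builds a tiny
# Backbone and runs one forward pass. The ModelConfig field names below are the
# ones referenced above; constructing ModelConfig with keyword arguments is an
# assumption, as is the exact behaviour of MultiHeadAttention / AttentionResidual.
if __name__ == "__main__":
    cfg = ModelConfig(
        vocab_size=1000,
        d_model=64,
        n_heads=4,
        d_ff=256,
        n_layers=2,
        max_seq_len=128,
        dropout=0.1,
        use_attn_res=False,
    )
    model = Backbone(cfg)
    ids = torch.randint(0, cfg.vocab_size, (2, 16))  # (batch, seq_len)
    out = model(ids)
    print(out["hidden"].shape)        # expected: torch.Size([2, 16, 64])
    print(len(out["layer_outputs"]))  # expected: n_layers + 1 = 3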