"""TaxiLM — a vanilla decoder-only transformer language model for Hassaniya."""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from config import TaxiConfig


class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention with a fused QKV projection.

    Reads ``config.d_model``, ``config.n_heads`` and ``config.dropout``.
    """

    def __init__(self, config):
        super().__init__()
        # Fail fast: otherwise the reshape in ``forward`` fails with an opaque
        # shape error the first time the module is called.
        if config.d_model % config.n_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by "
                f"n_heads ({config.n_heads})"
            )
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.out = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """Attend over ``x`` of shape ``(B, T, d_model)``.

        Args:
            x: input activations, shape ``(B, T, d_model)``.
            mask: optional tensor broadcastable to ``(B, n_heads, T, T)``;
                positions where it equals 0 are excluded from attention.

        Returns:
            Tensor of shape ``(B, T, d_model)``.
        """
        B, T, C = x.shape
        # (B, T, 3*C) -> (3, B, n_heads, T, head_dim)
        qkv = (
            self.qkv(x)
            .reshape(B, T, 3, self.n_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attn = attn.masked_fill(mask == 0, float("-inf"))
        attn = self.dropout(F.softmax(attn, dim=-1))
        # (B, n_heads, T, head_dim) -> (B, T, n_heads, head_dim) -> (B, T, C)
        return self.out((attn @ v).transpose(1, 2).contiguous().view(B, T, C))


class FFN(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        self.up = nn.Linear(config.d_model, config.ffn_hidden)
        self.down = nn.Linear(config.ffn_hidden, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        return self.dropout(self.down(F.relu(self.up(x))))


class Block(nn.Module):
    """Pre-norm transformer block: residual attention, then residual FFN."""

    def __init__(self, config):
        super().__init__()
        self.norm1 = nn.LayerNorm(config.d_model)
        self.attn = Attention(config)
        self.norm2 = nn.LayerNorm(config.d_model)
        self.ffn = FFN(config)

    def forward(self, x, mask=None):
        x = x + self.attn(self.norm1(x), mask)
        x = x + self.ffn(self.norm2(x))
        return x


class TaxiLM(nn.Module):
    """Decoder-only transformer LM with learned positions and tied embeddings."""

    def __init__(self, config: TaxiConfig):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection shares the token-embedding matrix.
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Init Linear/Embedding weights from N(0, 0.02); zero Linear biases."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: integer token ids, shape ``(B, T)`` with ``T <= max_seq_len``.
            targets: optional integer ids, same shape as ``idx``. Target id 0
                is ignored by the loss (presumably the padding id — confirm
                against the tokenizer).

        Returns:
            ``(logits, loss)`` where ``logits`` is ``(B, T, vocab_size)`` and
            ``loss`` is a scalar tensor, or ``None`` when ``targets`` is None.

        Raises:
            ValueError: if ``T`` exceeds ``config.max_seq_len`` (the learned
                position table has no entry for those positions).
        """
        B, T = idx.shape
        if T > self.config.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len {self.config.max_seq_len}"
            )
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        # Lower-triangular causal mask, shape (1, 1, T, T), broadcast over
        # batch and heads: position t may only attend to positions <= t.
        mask = torch.tril(torch.ones(T, T, device=idx.device)).unsqueeze(0).unsqueeze(0)
        for block in self.blocks:
            x = block(x, mask)
        logits = self.lm_head(self.norm(x))
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted too.
            loss = F.cross_entropy(
                logits.reshape(-1, self.config.vocab_size),
                targets.reshape(-1),
                ignore_index=0,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=64, temperature=0.7, top_k=50, **kwargs):
        """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

        Args:
            idx: integer token ids, shape ``(B, T)``.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; values ``<= 0`` select greedy
                (argmax) decoding instead of sampling.
            top_k: if a positive int, sample only among the ``top_k`` most
                likely tokens; ``0`` or ``None`` disables the filter.
            **kwargs: accepted and ignored.

        Returns:
            ``(idx, [])``: the extended ids of shape ``(B, T + n_generated)``
            and an always-empty list. The list's purpose is not visible here;
            presumably it is kept for caller compatibility.

        Generation stops early once every row has produced ``config.eos_id``.
        Rows that finish earlier keep receiving ``eos_id`` until then. The
        module's train/eval mode is restored on exit.
        """
        was_training = self.training
        self.eval()
        eos_id = self.config.eos_id
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        try:
            for _ in range(max_new_tokens):
                # Crop the context to the learned positional range.
                idx_cond = idx[:, -self.config.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature <= 0:
                    next_id = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        # Mask everything below the k-th largest logit per row.
                        logits[logits < v[:, [-1]]] = float("-inf")
                    probs = F.softmax(logits, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)
                # Rows that already emitted EOS keep emitting EOS.
                next_id = next_id.masked_fill(finished.unsqueeze(-1), eos_id)
                idx = torch.cat([idx, next_id], dim=1)
                finished = finished | (next_id.squeeze(-1) == eos_id)
                if bool(finished.all()):
                    break
        finally:
            # Do not leak eval mode into a surrounding training loop.
            self.train(was_training)
        return idx, []

    def param_count(self):
        """Return ``(total, 0)``.

        ``total`` is the number of parameter elements. ``Module.parameters()``
        yields each shared Parameter once, so the tied embedding/LM-head matrix
        is counted once. The second element is always 0; its meaning is not
        visible here.
        """
        total = sum(p.numel() for p in self.parameters())
        return total, 0

    def param_summary(self):
        """Return a one-line human-readable parameter count."""
        total, _ = self.param_count()
        return f"TaxiLM: {total:,} params ({total/1e6:.1f}M)"