# TaxiLM / model.py
# Uploaded by AlihIng with huggingface_hub (commit 7a7d768, 4.26 kB).
"""
TaxiLM — Vanilla transformer for Hassaniya.
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from config import TaxiConfig
class Attention(nn.Module):
    """Multi-head self-attention with a fused Q/K/V projection.

    Args:
        config: object exposing ``d_model``, ``n_heads`` and ``dropout``.

    Raises:
        ValueError: if ``n_heads`` is not positive or ``d_model`` is not
            divisible by ``n_heads`` (otherwise the per-head reshape in
            ``forward`` would fail with an opaque shape error).
    """

    def __init__(self, config):
        super().__init__()
        if config.n_heads <= 0:
            raise ValueError(f"n_heads must be positive, got {config.n_heads}")
        if config.d_model % config.n_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by "
                f"n_heads ({config.n_heads})"
            )
        self.n_heads = config.n_heads
        self.head_dim = config.d_model // config.n_heads
        self.qkv = nn.Linear(config.d_model, 3 * config.d_model)
        self.out = nn.Linear(config.d_model, config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: tensor unpacked as ``(B, T, C)``; ``C`` must equal ``d_model``.
            mask: optional tensor broadcastable to the score shape
                ``(B, n_heads, T, T)``; positions where ``mask == 0`` are
                blocked (their score is set to ``-inf`` before softmax).

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.shape
        # (B, T, 3*C) -> (3, B, n_heads, T, head_dim)
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))
        weights = self.dropout(F.softmax(scores, dim=-1))
        # Merge heads back: (B, n_heads, T, head_dim) -> (B, T, C)
        merged = (weights @ v).transpose(1, 2).reshape(B, T, C)
        return self.out(merged)
class FFN(nn.Module):
    """Position-wise feed-forward sublayer.

    Computes ``dropout(down(relu(up(x))))``: expand from ``d_model`` to
    ``ffn_hidden`` features, apply ReLU, project back to ``d_model``.
    """

    def __init__(self, config):
        super().__init__()
        width, hidden = config.d_model, config.ffn_hidden
        # Expansion is built before the projection so initialization order is stable.
        self.up = nn.Linear(width, hidden)
        self.down = nn.Linear(hidden, width)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply the feed-forward transform to the last dimension of ``x``."""
        activated = F.relu(self.up(x))
        projected = self.down(activated)
        return self.dropout(projected)
class Block(nn.Module):
    """One pre-norm transformer layer.

    The attention sublayer and then the feed-forward sublayer each read a
    LayerNorm-ed copy of the stream, and their outputs are added back to it
    residually.
    """

    def __init__(self, config):
        super().__init__()
        width = config.d_model
        # Submodules are created in the same order as before so that parameter
        # initialization and state_dict key order are unchanged.
        self.norm1 = nn.LayerNorm(width)
        self.attn = Attention(config)
        self.norm2 = nn.LayerNorm(width)
        self.ffn = FFN(config)

    def forward(self, x, mask=None):
        """Apply attention then FFN residually; ``mask`` is passed to ``Attention``."""
        stream = x + self.attn(self.norm1(x), mask)
        return stream + self.ffn(self.norm2(stream))
class TaxiLM(nn.Module):
    """Decoder-only transformer language model for Hassaniya.

    Token and learned absolute position embeddings feed a stack of ``Block``
    layers, followed by a final LayerNorm and an output head whose weight is
    tied to the token embedding.

    Args:
        config: a ``TaxiConfig``; this class reads ``vocab_size``, ``d_model``,
            ``max_seq_len``, ``dropout``, ``n_layers`` and ``eos_id``
            (``Block`` reads further fields).
    """

    def __init__(self, config: "TaxiConfig"):
        super().__init__()
        self.config = config
        self.tok_emb = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_emb = nn.Embedding(config.max_seq_len, config.d_model)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.n_layers)])
        self.norm = nn.LayerNorm(config.d_model)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # Weight tying: the output projection shares its matrix with the token embedding.
        self.lm_head.weight = self.tok_emb.weight
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Draw Linear/Embedding weights from N(0, 0.02) and zero Linear biases."""
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the training loss.

        Args:
            idx: integer token ids unpacked as ``(B, T)`` with
                ``T <= config.max_seq_len``.
            targets: optional token ids with ``B * T`` elements (normally shape
                ``(B, T)``); positions whose target id is 0 are ignored.

        Returns:
            ``(logits, loss)``: logits of shape ``(B, T, vocab_size)`` and the
            mean cross-entropy, or ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: if ``T`` exceeds ``config.max_seq_len`` (there is no
                position embedding for later positions).
        """
        _, T = idx.shape
        if T > self.config.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.config.max_seq_len}"
            )
        pos = torch.arange(T, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        # Causal mask: query position i may only attend to key positions <= i.
        mask = torch.tril(torch.ones(T, T, device=idx.device)).unsqueeze(0).unsqueeze(0)
        for block in self.blocks:
            x = block(x, mask)
        logits = self.lm_head(self.norm(x))
        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted.
            # Token id 0 is excluded from the loss (presumably padding; confirm
            # against the tokenizer).
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=0,
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens=64, temperature=0.7, top_k=50, **kwargs):
        """Autoregressively extend ``idx`` by up to ``max_new_tokens`` tokens.

        The model is switched to eval mode while sampling, and its previous
        training/eval mode is restored afterwards.

        Args:
            idx: integer token ids of shape ``(B, T)``; only the last
                ``config.max_seq_len`` tokens are fed to the model each step.
            max_new_tokens: maximum number of tokens to append.
            temperature: softmax temperature; ``<= 0`` selects greedy
                (argmax) decoding.
            top_k: if a positive int, sample only among the ``top_k`` most
                likely tokens; ``0`` or ``None`` disables the filter.
            **kwargs: accepted and ignored.

        Returns:
            ``(idx, [])``: the extended ids and an empty list. If
            ``config.eos_id`` is set, a row that emits it is padded with
            ``eos_id`` afterwards, and generation stops once every row has
            emitted it.
        """
        was_training = self.training
        self.eval()
        eos_id = getattr(self.config, "eos_id", None)
        finished = torch.zeros(idx.size(0), dtype=torch.bool, device=idx.device)
        try:
            for _ in range(max_new_tokens):
                idx_cond = idx[:, -self.config.max_seq_len:]
                logits, _ = self(idx_cond)
                logits = logits[:, -1, :]
                if temperature <= 0:
                    next_id = logits.argmax(dim=-1, keepdim=True)
                else:
                    logits = logits / temperature
                    if top_k is not None and top_k > 0:
                        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                        logits[logits < v[:, [-1]]] = float("-inf")
                    probs = F.softmax(logits, dim=-1)
                    next_id = torch.multinomial(probs, num_samples=1)
                if eos_id is not None:
                    # Rows that already produced EOS keep emitting EOS.
                    next_id = next_id.masked_fill(finished.unsqueeze(1), eos_id)
                    finished |= next_id.squeeze(1) == eos_id
                idx = torch.cat([idx, next_id], dim=1)
                if eos_id is not None and bool(finished.all()):
                    break
        finally:
            # Restore the caller's mode so sampling mid-training does not leave
            # dropout disabled.
            self.train(was_training)
        return idx, []

    def param_count(self):
        """Return ``(total, 0)``.

        ``total`` sums ``numel`` over ``self.parameters()``, which yields a
        shared tensor only once, so the tied embedding/output matrix is counted
        a single time. The second element is always 0, kept for interface
        compatibility.
        """
        total = sum(p.numel() for p in self.parameters())
        return total, 0

    def param_summary(self):
        """Return a one-line parameter summary such as ``TaxiLM: 1,234 params (0.0M)``."""
        total, _ = self.param_count()
        return f"TaxiLM: {total:,} params ({total/1e6:.1f}M)"