# Binary-Addition-LLM-POC / modeling_binaryllm.py
# PhysiQuanty — "export inference-ready" (commit a6c916c, verified)
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput
from .configuration_binaryllm import BinaryLLMConfig
class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding.

    The table is built once in fp32 and registered as a non-persistent buffer,
    so it is not written to the state dict. On every forward it is cast to the
    device and dtype of the input before being added.

    Args:
        d_model: feature width of the inputs. Odd widths are supported: even
            feature indices receive the sine terms and odd indices the cosine
            terms.
        max_len: number of positions in the table. Inputs whose dim 1 is
            longer than this cannot be encoded.
    """

    def __init__(self, d_model: int, max_len: int) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float32)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        # Frequencies 1 / 10000^(2i / d_model), one per sin/cos pair. The exact
        # float expression is kept unchanged so the recomputed (non-persistent)
        # table matches the one the weights were trained with.
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # An odd d_model has ceil(d/2) sine slots but only floor(d/2) cosine
        # slots. Trim the frequencies to fit; this is a no-op for even d_model.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus the encodings of its first ``x.size(1)`` positions.

        ``x`` must broadcast against ``(1, x.size(1), d_model)``, i.e. in
        practice ``(batch, seq, d_model)``. ``seq`` must not exceed ``max_len``.
        """
        t = x.size(1)
        pe = self.pe[:, :t, :].to(device=x.device, dtype=x.dtype)
        return x + pe
@dataclass
class _InnerCfg:
    """Hyper-parameter bundle consumed by ``TinyTransformerLM``.

    ``BinaryLLMForCausalLM`` builds one from its Hugging Face config so the
    inner model does not depend on the config class directly.
    """

    block_size: int  # max sequence length; sizes the positional table and the causal mask
    embed_dim: int  # token-embedding width and transformer d_model
    vocab_size: int  # rows of the token embedding and outputs of the vocab head
    num_heads: int  # attention heads per encoder layer
    num_layers: int  # number of stacked encoder layers
    ff_hidden_dim: int  # feed-forward hidden width inside each encoder layer
    dropout: float  # dropout probability passed to each encoder layer
    layernorm_dim: Optional[int] = None  # final LayerNorm width; falsy -> embed_dim
    head_dim: Optional[int] = None  # vocab-head input width; falsy -> LayerNorm width
class TinyTransformerLM(nn.Module):
    """Decoder-only language model built from a causally masked encoder stack.

    Pipeline: token embedding -> sinusoidal positions -> ``nn.TransformerEncoder``
    (GELU feed-forward, ``batch_first=True``) -> optional width projection ->
    LayerNorm -> optional width projection -> bias-free linear vocabulary head.
    """

    def __init__(self, cfg: _InnerCfg) -> None:
        super().__init__()
        self.cfg = cfg
        n_vocab = cfg.vocab_size
        d_model = cfg.embed_dim

        self.tok_embed = nn.Embedding(n_vocab, d_model)
        self.pos_encoding = PositionalEncoding(d_model, cfg.block_size)

        # A single layer template; nn.TransformerEncoder stacks num_layers copies.
        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=cfg.num_heads,
            dim_feedforward=cfg.ff_hidden_dim,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer_template, num_layers=cfg.num_layers)

        norm_width = cfg.layernorm_dim or d_model
        head_in = cfg.head_dim or norm_width

        # Width adapters exist only when the configured widths actually differ.
        self.pre_ln_proj: Optional[nn.Linear] = (
            nn.Linear(d_model, norm_width) if norm_width != d_model else None
        )
        self.ln = nn.LayerNorm(norm_width)
        self.head_pre: Optional[nn.Linear] = (
            nn.Linear(norm_width, head_in) if head_in != norm_width else None
        )
        self.head = nn.Linear(head_in, n_vocab, bias=False)

        # Share the embedding and output matrices only when the shapes line up exactly.
        shapes_align = (
            self.pre_ln_proj is None
            and self.head_pre is None
            and head_in == d_model
        )
        if shapes_align:
            self.head.weight = self.tok_embed.weight

        # True strictly above the diagonal: query i may not attend to key j > i.
        future = torch.ones(cfg.block_size, cfg.block_size, dtype=torch.bool).triu(diagonal=1)
        self.register_buffer("causal_mask", future, persistent=False)

    def forward(self, tokens: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return vocabulary logits for every position of ``tokens``.

        Args:
            tokens: integer token ids; dim 1 is the sequence axis (assumed
                ``(batch, seq)`` with ``seq <= block_size`` -- TODO confirm).
            padding_mask: optional key-padding mask where True marks keys to
                ignore. It is trimmed to the sequence length and cast to bool.

        Returns:
            ``self.head`` applied to the final hidden states; the last dim has
            size ``vocab_size``.
        """
        n = tokens.size(1)
        dev = tokens.device

        hidden = self.pos_encoding(self.tok_embed(tokens))

        # NOTE(review): with left padding, a padded query at the start can have
        # every permitted key masked; confirm this cannot produce NaNs downstream.
        key_mask = (
            None
            if padding_mask is None
            else padding_mask[:, :n].to(device=dev, dtype=torch.bool)
        )
        hidden = self.encoder(
            hidden,
            mask=self.causal_mask[:n, :n].to(device=dev),
            src_key_padding_mask=key_mask,
        )

        # Optional projection, final norm, optional projection -- in that order.
        for stage in (self.pre_ln_proj, self.ln, self.head_pre):
            if stage is not None:
                hidden = stage(hidden)
        return self.head(hidden)
class BinaryLLMForCausalLM(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around ``TinyTransformerLM``.

    Maps the standard config fields onto the inner model's hyper-parameters.
    ``forward`` returns logits and, when ``labels`` are given, the next-token
    cross-entropy loss.
    """

    config_class = BinaryLLMConfig
    main_input_name = "input_ids"

    def __init__(self, config: BinaryLLMConfig):
        super().__init__(config)
        inner = _InnerCfg(
            block_size=int(config.max_position_embeddings),
            embed_dim=int(config.hidden_size),
            vocab_size=int(config.vocab_size),
            num_heads=int(config.num_attention_heads),
            num_layers=int(config.num_hidden_layers),
            ff_hidden_dim=int(config.intermediate_size),
            dropout=float(getattr(config, "dropout", 0.0)),
            # No width adapters: the head width equals the embedding width, so
            # TinyTransformerLM ties head.weight to tok_embed.weight.
            layernorm_dim=None,
            head_dim=None,
        )
        self.model = TinyTransformerLM(inner)
        self.post_init()

    # The embedding accessors let transformers' tie_weights() (run by post_init
    # and after from_pretrained) and resize_token_embeddings() locate the tied
    # input/output matrices. Without them the tie made inside TinyTransformerLM
    # could not be restored by the library.
    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token-embedding module."""
        return self.model.tok_embed

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token-embedding module."""
        self.model.tok_embed = value

    def get_output_embeddings(self) -> nn.Linear:
        """Return the vocabulary projection (LM head)."""
        return self.model.head

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        """Replace the vocabulary projection (LM head)."""
        self.model.head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutput:
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: token ids for the inner model (assumed ``(batch, seq)``).
            attention_mask: optional HF-style mask where 1/True marks real
                tokens. It is inverted into the inner key-padding mask, where
                True means ignore.
            labels: optional targets aligned with ``input_ids``. They are
                shifted by one position inside this method; ``-100`` entries
                are ignored.
            **kwargs: accepted for caller compatibility and ignored.

        Returns:
            ``CausalLMOutput`` with the inner model's ``logits``. ``loss`` is
            ``None`` without ``labels``; otherwise it is computed in fp32 even
            for half-precision logits.
        """
        padding_mask = None
        if attention_mask is not None:
            padding_mask = ~attention_mask.to(torch.bool)  # True = ignore
        logits = self.model(input_ids, padding_mask=padding_mask)

        loss = None
        if labels is not None:
            # Logit at position t predicts token t + 1: drop the last logit and
            # the first label. Upcast so the softmax/NLL is not done in fp16/bf16.
            shift_logits = logits[:, :-1, :].float()
            # Labels may live on another device (e.g. with a device map).
            shift_labels = labels[:, 1:].to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)