import math
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutput

from .configuration_binaryllm import BinaryLLMConfig

class PositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding, stored in fp32,
    then cast to the dtype of x on every forward pass.
    """

    def __init__(self, d_model: int, max_len: int) -> None:
        super().__init__()
        pe = torch.zeros(max_len, d_model, dtype=torch.float32)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        # Frequency schedule from "Attention Is All You Need": 10000^(-2i / d_model).
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        t = x.size(1)
        pe = self.pe[:, :t, :]
        pe = pe.to(device=x.device, dtype=x.dtype)
        return x + pe
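
# Illustrative usage of PositionalEncoding (a sketch, not part of the module's
# API; the shapes and dtype below are assumptions for demonstration only):
#
#     pe = PositionalEncoding(d_model=64, max_len=128)
#     x = torch.zeros(2, 10, 64, dtype=torch.float16)
#     y = pe(x)  # (2, 10, 64); the fp32 table is cast to fp16 to match x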

@dataclass
class _InnerCfg:
    block_size: int
    embed_dim: int
    vocab_size: int
    num_heads: int
    num_layers: int
    ff_hidden_dim: int
    dropout: float
    layernorm_dim: Optional[int] = None
    head_dim: Optional[int] = None

class TinyTransformerLM(nn.Module):
    def __init__(self, cfg: _InnerCfg) -> None:
        super().__init__()
        self.cfg = cfg
        vocab_size = cfg.vocab_size
        self.tok_embed = nn.Embedding(vocab_size, cfg.embed_dim)
        self.pos_encoding = PositionalEncoding(cfg.embed_dim, cfg.block_size)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=cfg.embed_dim,
            nhead=cfg.num_heads,
            dim_feedforward=cfg.ff_hidden_dim,
            dropout=cfg.dropout,
            activation="gelu",
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=cfg.num_layers)
        ln_dim = cfg.layernorm_dim or cfg.embed_dim
        head_dim = cfg.head_dim or ln_dim
        self.pre_ln_proj: Optional[nn.Linear] = None
        if ln_dim != cfg.embed_dim:
            self.pre_ln_proj = nn.Linear(cfg.embed_dim, ln_dim)
        self.ln = nn.LayerNorm(ln_dim)
        self.head_pre: Optional[nn.Linear] = None
        if head_dim != ln_dim:
            self.head_pre = nn.Linear(ln_dim, head_dim)
        self.head = nn.Linear(head_dim, vocab_size, bias=False)
        # Tie input and output embeddings only when the dimensions align exactly.
        if self.pre_ln_proj is None and self.head_pre is None and head_dim == cfg.embed_dim:
            self.head.weight = self.tok_embed.weight
        # Upper-triangular boolean mask; True marks future positions attention must not see.
        causal = torch.triu(torch.ones(cfg.block_size, cfg.block_size, dtype=torch.bool), diagonal=1)
        self.register_buffer("causal_mask", causal, persistent=False)

    def forward(self, tokens: torch.Tensor, padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.tok_embed(tokens)
        x = self.pos_encoding(x)
        seq_len = tokens.size(1)
        attn_mask = self.causal_mask[:seq_len, :seq_len].to(device=tokens.device)
        if padding_mask is not None:
            # padding_mask follows the PyTorch convention: True = padded, ignore.
            padding_mask = padding_mask[:, :seq_len].to(device=tokens.device, dtype=torch.bool)
        x = self.encoder(x, mask=attn_mask, src_key_padding_mask=padding_mask)
        if self.pre_ln_proj is not None:
            x = self.pre_ln_proj(x)
        x = self.ln(x)
        if self.head_pre is not None:
            x = self.head_pre(x)
        return self.head(x)
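
# Illustrative standalone use of the inner model (a sketch; the hyperparameter
# values below are arbitrary assumptions, not defaults shipped with this repo):
#
#     cfg = _InnerCfg(block_size=32, embed_dim=64, vocab_size=256,
#                     num_heads=4, num_layers=2, ff_hidden_dim=128, dropout=0.1)
#     lm = TinyTransformerLM(cfg)
#     logits = lm(torch.randint(0, 256, (2, 16)))  # (2, 16, 256)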

class BinaryLLMForCausalLM(PreTrainedModel):
    config_class = BinaryLLMConfig
    main_input_name = "input_ids"

    def __init__(self, config: BinaryLLMConfig):
        super().__init__(config)
        inner = _InnerCfg(
            block_size=int(config.max_position_embeddings),
            embed_dim=int(config.hidden_size),
            vocab_size=int(config.vocab_size),
            num_heads=int(config.num_attention_heads),
            num_layers=int(config.num_hidden_layers),
            ff_hidden_dim=int(config.intermediate_size),
            dropout=float(getattr(config, "dropout", 0.0)),
            layernorm_dim=None,
            head_dim=None,
        )
        self.model = TinyTransformerLM(inner)
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> CausalLMOutput:
        padding_mask = None
        if attention_mask is not None:
            # Invert the HF convention (1 = keep) into PyTorch's (True = ignore).
            padding_mask = ~attention_mask.to(torch.bool)
        logits = self.model(input_ids, padding_mask=padding_mask)
        loss = None
        if labels is not None:
            # Next-token objective: position t predicts token t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return CausalLMOutput(loss=loss, logits=logits)
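

# Minimal smoke test (a sketch, not part of the original file). It assumes
# BinaryLLMConfig accepts these keyword arguments and exposes them as
# attributes, as transformers.PretrainedConfig subclasses typically do.
# Because of the relative import at the top, run it as a module, e.g.
# `python -m <package>.modeling_binaryllm`.
if __name__ == "__main__":
    cfg = BinaryLLMConfig(
        vocab_size=256,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=128,
        max_position_embeddings=32,
        dropout=0.1,
    )
    model = BinaryLLMForCausalLM(cfg)
    ids = torch.randint(0, cfg.vocab_size, (2, 16))
    out = model(input_ids=ids, attention_mask=torch.ones_like(ids), labels=ids)
    print(out.logits.shape)  # expected: torch.Size([2, 16, 256])
    print(out.loss)          # scalar cross-entropy loss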