# MDaytek's picture
# V12 Final
# f71d4c6 verified
import math, torch, torch.nn as nn, torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
class ChessConfig(PretrainedConfig):
    """Hyperparameter container for the ChessGPT model."""

    model_type = "chess_gpt"  # internal model identifier for transformers

    def __init__(self, vocab_size=1000, n_positions=256, n_embd=128, n_layer=4, n_head=4, dropout=0.1, **kwargs):
        """Store the model hyperparameters; remaining kwargs go to PretrainedConfig."""
        super().__init__(**kwargs)
        hyperparams = (
            ("vocab_size", vocab_size),
            ("n_positions", n_positions),
            ("n_embd", n_embd),
            ("n_layer", n_layer),
            ("n_head", n_head),
            ("dropout", dropout),
        )
        for attr_name, attr_value in hyperparams:
            setattr(self, attr_name, attr_value)
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention (GPT-style).

    Projects the input to queries/keys/values with one fused linear layer,
    applies scaled dot-product attention under an upper-triangular causal
    mask, and projects the result back to the embedding size.
    """

    def __init__(self, config):
        super().__init__()
        # Fused projection: a single matmul produces q, k and v concatenated.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (batch, seq_len, n_embd).

        Returns:
            Tensor of shape (batch, seq_len, n_embd).
        """
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        head_dim = C // self.n_head
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)
        # Scaled dot-product scores: (B, n_head, T, T).
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        # Build the causal mask as bool directly instead of allocating a
        # float tensor and casting (original used torch.ones(...).bool()).
        mask = torch.ones(T, T, dtype=torch.bool, device=x.device).triu(1)
        att = att.masked_fill(mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        # Re-assemble heads: (B, n_head, T, head_dim) -> (B, T, C).
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
class GPTBlock(nn.Module):
    """Pre-LayerNorm transformer block: attention then MLP, each added to
    the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        # Standard 4x-expansion feed-forward network.
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        # BUG FIX: the original returned
        #   x + mlp(ln_2(x + attn(ln_1(x))))
        # which drops the attention residual from the output path (the
        # attention result only fed the MLP input). A standard pre-LN GPT
        # block accumulates BOTH sublayers on the residual stream.
        # NOTE(review): checkpoints trained with the old forward will
        # produce different outputs under this fix — retrain or verify.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
class ChessGPT(PreTrainedModel):
    """Minimal GPT-style causal language model for chess-move tokens.

    Token + learned positional embeddings, a stack of GPTBlock layers, a
    final LayerNorm and a weight-untied LM head, exposed through the
    Hugging Face PreTrainedModel interface.
    """

    config_class = ChessConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.n_positions, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([GPTBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        # NOTE: lm_head weights are NOT tied to token_emb.
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-2-style init: N(0, 0.02) for linear/embedding weights,
        ones/zeros for LayerNorm. Linear biases keep PyTorch's default init,
        as in the original code.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_ids, labels=None, **kwargs):
        """Run the model and optionally compute the next-token loss.

        Args:
            input_ids: (batch, seq_len) token ids; seq_len must not exceed
                config.n_positions.
            labels: optional (batch, seq_len) target ids; after clamping,
                positions whose target is 0 are ignored by the loss.

        Returns:
            CausalLMOutputWithPast with `logits` (batch, seq_len, vocab_size)
            and `loss` (scalar) when labels are given.

        Raises:
            ValueError: if seq_len exceeds the positional embedding table.
        """
        B, T = input_ids.shape
        if T > self.config.n_positions:
            # Fail with a clear message instead of an opaque embedding index
            # error (or a device-side assert on CUDA).
            raise ValueError(
                f"Sequence length {T} exceeds n_positions={self.config.n_positions}"
            )
        pos = torch.arange(0, T, device=input_ids.device)
        x = self.dropout(self.token_emb(input_ids) + self.pos_emb(pos))
        for block in self.blocks:
            x = block(x)
        logits = self.lm_head(self.ln_f(x))
        loss = None
        if labels is not None:
            # Crash guard: clamp out-of-range label ids into the vocab.
            # NOTE(review): this maps the HF padding value -100 to 0, which
            # ignore_index=0 then skips — but it also silently ignores any
            # legitimate token id 0; confirm id 0 is a pad token in the
            # tokenizer before relying on this.
            labels = labels.clamp(0, self.config.vocab_size - 1)
            # Shift so position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.CrossEntropyLoss(ignore_index=0)(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)