"""
Optimized Chess Transformer.

Strategy: deep and narrow.
Most parameters are dedicated to reasoning (the transformer layers); the vocabulary gets as few as possible.
"""
from __future__ import annotations
import math
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
class ChessConfig(PretrainedConfig):
    """Hyper-parameters for the chess transformer language model.

    Args:
        vocab_size: Number of rows in the token embedding table.
        n_embd: Width of the residual stream (embedding dimension).
        n_layer: Number of transformer blocks.
        n_head: Number of attention heads per block.
        n_ctx: Maximum sequence length; sizes the position table and the
            causal mask.
        n_inner: Hidden width of the feed-forward layer. ``None`` selects
            ``4 * n_embd``.
        dropout: Dropout probability used by the model's dropout layers.
        bias: Whether the attention / feed-forward linear layers have a bias.
        tie_weights: Whether the output projection shares its weight with the
            token embedding.
        pad_token_id: Padding token id, forwarded to ``PretrainedConfig``.
        bos_token_id: Beginning-of-sequence token id, forwarded likewise.
        eos_token_id: End-of-sequence token id, forwarded likewise.
        unk_token_id: Unknown token id, forwarded likewise.
        **kwargs: Any further ``PretrainedConfig`` options.
    """

    model_type = "chess_transformer"

    def __init__(
        self,
        vocab_size=80,
        n_embd=128,
        n_layer=10,
        n_head=8,
        n_ctx=256,
        n_inner=None,
        dropout=0.0,
        bias=False,
        tie_weights=True,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        unk_token_id=3,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        self.n_inner = n_inner if n_inner is not None else 4 * n_embd
        self.dropout = dropout
        self.bias = bias
        self.tie_weights = tie_weights
        # The base-class weight-tying logic is driven by `tie_word_embeddings`
        # (default True), not by our `tie_weights` flag. Mirror the flag so that
        # `tie_weights=False` is actually honoured; an explicitly supplied
        # `tie_word_embeddings` still takes precedence.
        kwargs.setdefault("tie_word_embeddings", tie_weights)
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            unk_token_id=unk_token_id,
            **kwargs,
        )
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with a fused query/key/value projection.

    Reads ``n_head``, ``n_embd``, ``n_ctx``, ``bias`` and ``dropout`` from
    ``config``.

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config: ChessConfig):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        # One projection produces queries, keys and values side by side.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        # Lower-triangular mask: position i may attend to positions <= i.
        # Non-persistent, so it is not written to checkpoints.
        causal = torch.tril(torch.ones(config.n_ctx, config.n_ctx))
        self.register_buffer(
            "bias_mask",
            causal.view(1, 1, config.n_ctx, config.n_ctx),
            persistent=False,
        )

    def forward(self, x, attention_mask=None):
        """Apply causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``C == n_embd``.
            attention_mask: Optional tensor with ``B * T`` elements (viewed as
                ``(B, 1, 1, T)``); key positions whose entry equals 0 are
                masked out.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
        # Use the most negative finite value rather than -inf: a row whose keys
        # are all masked (e.g. the leading pad positions of a left-padded
        # sequence) would otherwise softmax to NaN, and that NaN then spreads
        # to every position through the value tensors of later layers.
        mask_value = torch.finfo(att.dtype).min
        att = att.masked_fill(self.bias_mask[:, :, :T, :T] == 0, mask_value)
        if attention_mask is not None:
            att = att.masked_fill(attention_mask.view(B, 1, 1, T) == 0, mask_value)
        att = F.softmax(att, dim=-1)
        att = self.dropout(att)

        y = att @ v
        # (B, n_head, T, head_dim) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.c_proj(y)
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``n_inner``, apply GELU, project back, then dropout."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        model_width, hidden_width = config.n_embd, config.n_inner
        use_bias = config.bias
        self.c_fc = nn.Linear(model_width, hidden_width, bias=use_bias)
        self.c_proj = nn.Linear(hidden_width, model_width, bias=use_bias)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Return the MLP output for ``x``; the last dimension stays ``n_embd``."""
        expanded = self.c_fc(x)
        activated = F.gelu(expanded)
        projected = self.c_proj(activated)
        return self.dropout(projected)
class TransformerBlock(nn.Module):
    """Pre-LayerNorm block: attention and MLP sub-layers, each wrapped in a residual connection."""

    def __init__(self, config: ChessConfig):
        super().__init__()
        width = config.n_embd
        self.ln_1 = nn.LayerNorm(width)
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(width)
        self.mlp = FeedForward(config)

    def forward(self, x, attention_mask=None):
        """Run both sub-layers on ``x``; ``attention_mask`` is passed through to the attention layer."""
        attn_out = self.attn(self.ln_1(x), attention_mask)
        after_attn = x + attn_out
        mlp_out = self.mlp(self.ln_2(after_attn))
        return after_attn + mlp_out
class ChessForCausalLM(PreTrainedModel):
    """Decoder-only transformer with a language-model head over the chess vocabulary.

    Token and learned position embeddings are summed, passed through
    ``n_layer`` transformer blocks and a final LayerNorm, then projected to
    vocabulary logits. When ``config.tie_weights`` is set, the output
    projection shares its weight with the token embedding.
    """

    config_class = ChessConfig
    base_model_prefix = "transformer"

    def __init__(self, config: ChessConfig):
        super().__init__(config)
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.h = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        if config.tie_weights:
            self.lm_head.weight = self.wte.weight
        self.post_init()

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.wte

    def set_input_embeddings(self, new):
        """Replace the token embedding module."""
        self.wte = new

    def get_output_embeddings(self):
        """Return the output projection (LM head)."""
        return self.lm_head

    def set_output_embeddings(self, new):
        """Replace the output projection (LM head)."""
        self.lm_head = new

    def forward(self, input_ids, attention_mask=None, position_ids=None, labels=None, return_dict=None, **kwargs):
        """Compute next-token logits and, when ``labels`` is given, the LM loss.

        Args:
            input_ids: Integer tensor of shape ``(batch, seq_len)``.
            attention_mask: Optional mask passed to every block; entries equal
                to 0 mark key positions that must not be attended to.
            position_ids: Optional position indices; defaults to
                ``0 .. seq_len - 1`` for every row.
            labels: Optional target ids aligned with ``input_ids``. They are
                shifted by one internally. Positions equal to ``pad_token_id``
                or to ``-100`` do not contribute to the loss.
            return_dict: Return a ``CausalLMOutputWithPast`` when true,
                otherwise a tuple. Defaults to ``config.use_return_dict``.
            **kwargs: Accepted and ignored.

        Returns:
            ``CausalLMOutputWithPast(loss, logits)``, or ``(loss, logits)`` /
            ``(logits,)`` when ``return_dict`` is false. Without ``labels`` the
            logits of the pad, bos and unk tokens are set to ``-inf``.

        Raises:
            ValueError: If ``seq_len`` exceeds ``config.n_ctx``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if return_dict is None:
            return_dict = True

        device = input_ids.device
        _, t = input_ids.size()
        # Fail early with a clear message: the position table and the causal
        # mask are both sized by n_ctx, so longer inputs cannot be processed.
        if t > self.config.n_ctx:
            raise ValueError(
                f"Sequence length {t} exceeds the model context size n_ctx={self.config.n_ctx}"
            )
        if position_ids is None:
            position_ids = torch.arange(t, device=device).unsqueeze(0)

        x = self.wte(input_ids) + self.wpe(position_ids)
        x = self.drop(x)
        for block in self.h:
            x = block(x, attention_mask)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        # Generation guard: when no labels are supplied, make the special
        # tokens impossible to select. Ids configured as None are skipped.
        if labels is None:
            blocked = [
                token_id
                for token_id in (
                    self.config.pad_token_id,
                    self.config.bos_token_id,
                    self.config.unk_token_id,
                )
                if token_id is not None
            ]
            if blocked:
                logits[:, :, blocked] = float("-inf")

        loss = None
        if labels is not None:
            # Position i predicts token i + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Ignore padded targets whether they are marked with the pad id or
            # with the conventional -100 sentinel.
            pad_id = self.config.pad_token_id
            if pad_id is not None:
                shift_labels = shift_labels.masked_fill(shift_labels == pad_id, -100)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)
        return CausalLMOutputWithPast(loss=loss, logits=logits)
# Register the custom classes with the Auto* factories so that the
# "chess_transformer" model type resolves to ChessConfig / ChessForCausalLM.
from transformers import AutoConfig, AutoModelForCausalLM

AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)