File size: 14,999 Bytes
fccfe9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583c794
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
"""

Modèle Transformer pour le Chess Challenge (1M paramètres).



Ce module fournit une architecture transformer de style GPT conçue

pour respecter la contrainte stricte de moins de 1 million de paramètres.



Composants clés :

- ChessConfig : Configuration des hyperparamètres.

- ChessForCausalLM : Le modèle principal pour la prédiction du prochain coup.

"""
from __future__ import annotations
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PretrainedConfig, PreTrainedModel, AutoConfig, AutoModelForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

# -----------------------------------------------------------------------------
# Configuration
# -----------------------------------------------------------------------------

class ChessConfig(PretrainedConfig):
    """

    Classe de configuration pour le modèle Chess Transformer.

    

    Conçue pour un budget de paramètres très serré (< 1M).

    

    Répartition du budget (avec les valeurs par défaut de ton ami) :

    - Vocabulaire (Embeddings) : 72 * 92 = ~6,6k

    - Embeddings de position : 256 * 92 = ~23,5k

    - Couches Transformer : 11 couches * (~85k par couche) = ~935k

    - Tête LM (liée aux embeddings) : 0 paramètre supplémentaire

    - Total : ~970k paramètres (juste en dessous de 1M).

    """
    
    model_type = "chess_transformer"
    
    def __init__(

        self,

        vocab_size: int = 1200,

        n_embd: int = 128,

        n_layer: int = 6,

        n_head: int = 4,

        n_ctx: int = 256,

        n_inner: Optional[int] = None,

        dropout: float = 0.1,

        layer_norm_epsilon: float = 1e-5,

        tie_weights: bool = True,

        pad_token_id: int = 0,

        bos_token_id: int = 1,

        eos_token_id: int = 2,

        **kwargs,

    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_ctx = n_ctx
        self.n_inner = n_inner if n_inner is not None else 3 * n_embd
        self.dropout = dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.tie_weights = tie_weights
        self.tie_word_embeddings = bool(tie_weights)


# -----------------------------------------------------------------------------
# Modules du Transformer
# -----------------------------------------------------------------------------

class MultiHeadAttention(nn.Module):
    """

    Module d'attention multi-têtes standard.

    Inclut le masquage causal pour empêcher le modèle de "voir le futur".

    """
    
    def __init__(self, config: ChessConfig):
        super().__init__()
        
        assert config.n_embd % config.n_head == 0, \
            f"n_embd ({config.n_embd}) doit être divisible par n_head ({config.n_head})"
        
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.head_dim = config.n_embd // config.n_head
        
        # Projection combinée Q, K, V pour l'efficacité
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        
        self.dropout = nn.Dropout(config.dropout)
        
        # Masque causal (registre persistent=False pour ne pas le sauvegarder dans le checkpoint)
        self.register_buffer(
            "bias",
            torch.tril(torch.ones(config.n_ctx, config.n_ctx)).view(
                1, 1, config.n_ctx, config.n_ctx
            ),
            persistent=False,
        )

    def forward(

        self,

        x: torch.Tensor,

        attention_mask: Optional[torch.Tensor] = None,

    ) -> torch.Tensor:
        batch_size, seq_len, _ = x.size()
        
        # Calcul de Q, K, V
        qkv = self.c_attn(x)
        q, k, v = qkv.split(self.n_embd, dim=2)
        
        # Remodelage pour l'attention multi-têtes
        # (batch, seq_len, n_head, head_dim) -> (batch, n_head, seq_len, head_dim)
        q = q.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
        
        # Attention produit scalaire (Scaled Dot-Product Attention)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        
        # Application du masque causal (le futur est masqué avec -inf)
        causal_mask = self.bias[:, :, :seq_len, :seq_len]
        attn_weights = attn_weights.masked_fill(causal_mask == 0, float("-inf"))
        
        # Application du masque d'attention (pour le padding)
        if attention_mask is not None:
            # attention_mask shape: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            attn_weights = attn_weights.masked_fill(attention_mask == 0, float("-inf"))
        
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Application de l'attention aux valeurs
        attn_output = torch.matmul(attn_weights, v)
        
        # Remise en forme
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, seq_len, self.n_embd
        )
        
        # Projection de sortie
        attn_output = self.c_proj(attn_output)
        
        return attn_output


class FeedForward(nn.Module):
    """

    Réseau de neurones Feed-Forward (MLP).

    Deux couches linéaires avec une activation GELU entre les deux.

    """
    
    def __init__(self, config: ChessConfig):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, config.n_inner)
        self.c_proj = nn.Linear(config.n_inner, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.c_fc(x)
        x = F.gelu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x


class TransformerBlock(nn.Module):
    """

    Un bloc transformer unique contenant Attention et Feed-Forward.

    Utilise la "Pre-normalization" (LayerNorm avant l'attention/FFN) pour la stabilité.

    """
    
    def __init__(self, config: ChessConfig):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.attn = MultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.mlp = FeedForward(config)

    def forward(

        self,

        x: torch.Tensor,

        attention_mask: Optional[torch.Tensor] = None,

    ) -> torch.Tensor:
        # Connexion résiduelle + Pre-norm Attention
        x = x + self.attn(self.ln_1(x), attention_mask=attention_mask)
        # Connexion résiduelle + Pre-norm FFN
        x = x + self.mlp(self.ln_2(x))
        return x


# -----------------------------------------------------------------------------
# Modèle Principal
# -----------------------------------------------------------------------------

class ChessForCausalLM(PreTrainedModel):
    """

    Modèle final pour la prédiction de coups (Causal Language Modeling).

    

    Architecture :

    1. Embeddings (Tokens + Position)

    2. Empilement de blocs Transformer

    3. Tête linéaire finale (Projection vers le vocabulaire)

    """
    
    config_class = ChessConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    
    # Ignore l'avertissement de clé manquante car lm_head partage les poids avec wte
    keys_to_ignore_on_load_missing = ["lm_head.weight"]
    
    def __init__(self, config: ChessConfig):
        super().__init__(config)
        
        # Embeddings
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
        
        self.drop = nn.Dropout(config.dropout)
        
        # Blocs Transformer
        self.h = nn.ModuleList([
            TransformerBlock(config) for _ in range(config.n_layer)
        ])
        
        # LayerNorm final
        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        
        # Tête de sortie (sans biais pour économiser des paramètres)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
        # Gestion du partage de poids (Weight Tying)
        if config.tie_weights:
            self._tied_weights_keys = ["lm_head.weight"]
        
        # Initialisation des poids
        self.post_init()
        
        # Forcer le lien des poids si configuré
        if config.tie_weights:
            self.tie_weights()

    def get_input_embeddings(self) -> nn.Module:
        return self.wte

    def set_input_embeddings(self, new_embeddings: nn.Module):
        self.wte = new_embeddings
        if getattr(self.config, "tie_weights", False):
            self.tie_weights()

    def get_output_embeddings(self) -> nn.Module:
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Module):
        self.lm_head = new_embeddings

    def tie_weights(self):
        """Lie les poids de l'embedding d'entrée et de la tête de sortie."""
        if getattr(self.config, "tie_weights", False) or getattr(self.config, "tie_word_embeddings", False):
            self._tie_or_clone_weights(self.lm_head, self.wte)

    def _init_weights(self, module: nn.Module):
        """Initialisation des poids style GPT-2."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.ones_(module.weight)
            torch.nn.init.zeros_(module.bias)

    def forward(

        self,

        input_ids: torch.LongTensor,

        attention_mask: Optional[torch.Tensor] = None,

        position_ids: Optional[torch.LongTensor] = None,

        labels: Optional[torch.LongTensor] = None,

        return_dict: Optional[bool] = None,

        **kwargs,

    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """

        Passe avant (Forward pass).

        Calcule les logits et, si des étiquettes (labels) sont fournies, la perte (loss).

        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        batch_size, seq_len = input_ids.size()
        device = input_ids.device
        
        # Création des IDs de position s'ils ne sont pas fournis
        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0).expand(batch_size, -1)
        
        # Calcul des embeddings (Token + Position)
        token_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        hidden_states = self.drop(token_embeds + position_embeds)
        
        # Passage dans les blocs Transformer
        for block in self.h:
            hidden_states = block(hidden_states, attention_mask=attention_mask)
        
        # Normalisation finale
        hidden_states = self.ln_f(hidden_states)
        
        # Calcul des logits (prédiction du prochain token)
        logits = self.lm_head(hidden_states)
        
        # Calcul de la perte (Training Loss)
        loss = None
        if labels is not None:
            # On décale les logits et les labels d'un cran
            # (le modèle doit prédire le token t+1 à partir du token t)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            
            # Perte CrossEntropy standard
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def generate_move(

        self,

        input_ids: torch.LongTensor,

        temperature: float = 1.0,

        top_k: Optional[int] = None,

        top_p: Optional[float] = None,

    ) -> int:
        """

        Génère le prochain coup pour une séquence donnée.

        Utilisé pour l'inférence en jeu réel.

        """
        self.eval()
        
        # Récupère les logits pour la dernière position uniquement
        outputs = self(input_ids)
        logits = outputs.logits[:, -1, :] / temperature
        
        # Filtrage Top-K
        if top_k is not None:
            indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
            logits[indices_to_remove] = float("-inf")
        
        # Filtrage Top-P (Nucleus Sampling)
        if top_p is not None:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            
            # On retire les tokens qui sont au-dessus du seuil cumulatif
            sorted_indices_to_remove = cumulative_probs > top_p
            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
            sorted_indices_to_remove[..., 0] = 0
            
            indices_to_remove = sorted_indices_to_remove.scatter(
                dim=-1, index=sorted_indices, src=sorted_indices_to_remove
            )
            logits[indices_to_remove] = float("-inf")
        
        # Échantillonnage final
        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        
        return next_token.item()

# Enregistrement pour chargement automatique via AutoModel
AutoConfig.register("chess_transformer", ChessConfig)
AutoModelForCausalLM.register(ChessConfig, ChessForCausalLM)