| import math, torch, torch.nn as nn, torch.nn.functional as F |
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
class ChessConfig(PretrainedConfig):
    """Hyper-parameters for the ``ChessGPT`` model.

    Args:
        vocab_size: number of token ids covered by the token embedding and
            the output projection.
        n_positions: size of the learned positional-embedding table, i.e. the
            longest sequence the model can embed.
        n_embd: width of the hidden (residual) representation.
        n_layer: number of transformer blocks.
        n_head: number of attention heads per block.
        dropout: dropout probability used by the model's dropout layers.
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "chess_gpt"

    def __init__(
        self,
        vocab_size=1000,
        n_positions=256,
        n_embd=128,
        n_layer=4,
        n_head=4,
        dropout=0.1,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Assigned after the base initialiser, in a fixed order.
        hyperparams = {
            "vocab_size": vocab_size,
            "n_positions": n_positions,
            "n_embd": n_embd,
            "n_layer": n_layer,
            "n_head": n_head,
            "dropout": dropout,
        }
        for attr_name, attr_value in hyperparams.items():
            setattr(self, attr_name, attr_value)
|
|
class CausalSelfAttention(nn.Module):
    """Multi-head causal (masked) self-attention.

    Reads ``n_embd``, ``n_head`` and ``dropout`` from ``config``. One linear
    layer (``c_attn``) produces queries, keys and values together; they are
    split into ``n_head`` heads of size ``n_embd // n_head``, attended with a
    causal mask, merged back and projected by ``c_proj``.

    Raises:
        ValueError: if ``config.n_embd`` is not divisible by ``config.n_head``.
    """

    def __init__(self, config):
        super().__init__()
        if config.n_embd % config.n_head != 0:
            # Without this check the per-head ``view`` in forward() fails later
            # with an opaque shape error.
            raise ValueError(
                f"n_embd ({config.n_embd}) must be divisible by n_head ({config.n_head})"
            )
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape ``(B, T, C)`` with ``C == n_embd``.

        Returns:
            Tensor of shape ``(B, T, C)``. The output at position ``t`` depends
            only on inputs at positions ``<= t``.
        """
        B, T, C = x.size()
        head_dim = C // self.n_head
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        # (B, T, C) -> (B, n_head, T, head_dim)
        q = q.view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_dim).transpose(1, 2)

        # Scaled dot-product scores: (B, n_head, T, T).
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        # True strictly above the diagonal, i.e. for future positions. The
        # diagonal is always kept, so no row is fully masked (no NaN softmax).
        mask = torch.triu(
            torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1
        )
        att = att.masked_fill(mask, float("-inf"))
        att = self.attn_dropout(F.softmax(att, dim=-1))

        # Merge heads back: (B, n_head, T, head_dim) -> (B, T, C).
        y = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.resid_dropout(self.c_proj(y))
|
|
class GPTBlock(nn.Module):
    """Pre-LayerNorm transformer block: causal self-attention, then an MLP.

    Each sub-layer reads a LayerNorm-ed copy of the hidden state and adds its
    output back onto that hidden state (a residual connection). The MLP expands
    to ``4 * n_embd``, applies GELU, projects back to ``n_embd`` and applies
    dropout.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embd, 4 * config.n_embd),
            nn.GELU(),
            nn.Linear(4 * config.n_embd, config.n_embd),
            nn.Dropout(config.dropout),
        )

    def forward(self, x):
        """Apply attention and MLP sub-layers, each with a residual connection.

        Args:
            x: hidden states; the attention sub-layer expects ``(B, T, n_embd)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # The attention output must stay on the residual path: the MLP sees it,
        # and it is also part of the value this block returns.
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x
|
|
class ChessGPT(PreTrainedModel):
    """Decoder-only GPT language model built from ``GPTBlock`` layers.

    Token embeddings and learned positional embeddings are summed and passed
    through dropout. The result goes through ``config.n_layer`` blocks and a
    final LayerNorm, then a bias-free linear ``lm_head`` produces
    ``config.vocab_size`` logits per position.
    """

    config_class = ChessConfig

    def __init__(self, config):
        super().__init__(config)
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.n_positions, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([GPTBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module in place.

        The scheme is:

        - Linear and Embedding weights are drawn from N(0, 0.02).
        - Linear biases (when present) are zeroed.
        - LayerNorm biases are zeroed and LayerNorm weights are set to one.
        """
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            # Previously left at PyTorch's default uniform init; ``lm_head`` is
            # created with bias=False, hence the None guard.
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def forward(self, input_ids, labels=None, **kwargs):
        """Compute next-token logits and, optionally, the language-modelling loss.

        Args:
            input_ids: 2-D tensor of token ids, shape ``(B, T)``, with
                ``T <= config.n_positions``.
            labels: optional target ids with the same shape as ``input_ids``,
                NOT pre-shifted. The shift is done here: logits at position
                ``t`` are scored against ``labels[..., t + 1]``.
            **kwargs: accepted and ignored (presumably extra keys such as an
                attention mask passed by a trainer; confirm with callers).

        Returns:
            ``CausalLMOutputWithPast`` whose ``logits`` have shape
            ``(B, T, vocab_size)``. ``loss`` is ``None`` when ``labels`` is
            ``None``.

        Raises:
            ValueError: if ``T`` exceeds ``config.n_positions``.
        """
        _, seq_len = input_ids.shape
        if seq_len > self.config.n_positions:
            # Otherwise pos_emb fails with an opaque index error (or a CUDA assert).
            raise ValueError(
                f"sequence length {seq_len} exceeds n_positions={self.config.n_positions}"
            )
        positions = torch.arange(0, seq_len, device=input_ids.device)
        x = self.dropout(self.token_emb(input_ids) + self.pos_emb(positions))
        for block in self.blocks:
            x = block(x)
        logits = self.lm_head(self.ln_f(x))

        loss = None
        if labels is not None:
            # Ids are clamped into [0, vocab_size - 1]. Negative ids (e.g. -100)
            # therefore become 0, which is ignored by the loss below.
            # NOTE(review): ids >= vocab_size are silently mapped to the last
            # token instead of being rejected; confirm this is intended.
            labels = labels.clamp(0, self.config.vocab_size - 1)
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Token id 0 is excluded from the loss (presumably the padding id;
            # confirm against the tokenizer).
            loss = F.cross_entropy(
                shift_logits.view(-1, logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=0,
            )
        return CausalLMOutputWithPast(loss=loss, logits=logits)
|
|