Chess Challenge submission by iliasslasri
Browse files- README.md +3 -3
- config.json +8 -6
- model.py +86 -4
- model.safetensors +2 -2
README.md
CHANGED
|
@@ -14,13 +14,13 @@ Chess model submitted to the LLM Course Chess Challenge.
|
|
| 14 |
## Submission Info
|
| 15 |
|
| 16 |
- **Submitted by**: [iliasslasri](https://huggingface.co/iliasslasri)
|
| 17 |
-
- **Parameters**:
|
| 18 |
- **Organization**: LLM-course
|
| 19 |
|
| 20 |
## Model Details
|
| 21 |
|
| 22 |
- **Architecture**: Chess Transformer (GPT-style)
|
| 23 |
- **Vocab size**: 75
|
| 24 |
-
- **Embedding dim**:
|
| 25 |
- **Layers**: 11
|
| 26 |
-
- **Heads**:
|
|
|
|
| 14 |
## Submission Info
|
| 15 |
|
| 16 |
- **Submitted by**: [iliasslasri](https://huggingface.co/iliasslasri)
|
| 17 |
+
- **Parameters**: 998,036
|
| 18 |
- **Organization**: LLM-course
|
| 19 |
|
| 20 |
## Model Details
|
| 21 |
|
| 22 |
- **Architecture**: Chess Transformer (GPT-style)
|
| 23 |
- **Vocab size**: 75
|
| 24 |
+
- **Embedding dim**: 96
|
| 25 |
- **Layers**: 11
|
| 26 |
+
- **Heads**: 8
|
config.json
CHANGED
|
@@ -1,9 +1,9 @@
|
|
| 1 |
{
|
| 2 |
-
"_name_or_path": "./
|
| 3 |
"architectures": [
|
| 4 |
"ChessForCausalLM"
|
| 5 |
],
|
| 6 |
-
"attn": "
|
| 7 |
"auto_map": {
|
| 8 |
"AutoConfig": "model.ChessConfig",
|
| 9 |
"AutoModelForCausalLM": "model.ChessForCausalLM"
|
|
@@ -14,12 +14,14 @@
|
|
| 14 |
"layer_norm_epsilon": 1e-05,
|
| 15 |
"model_type": "chess_transformer",
|
| 16 |
"n_ctx": 256,
|
| 17 |
-
"n_embd":
|
| 18 |
-
"n_head":
|
| 19 |
-
"n_inner":
|
| 20 |
"n_layer": 11,
|
| 21 |
-
"num_groups":
|
| 22 |
"pad_token_id": 0,
|
|
|
|
|
|
|
| 23 |
"tie_weights": false,
|
| 24 |
"tie_word_embeddings": false,
|
| 25 |
"torch_dtype": "float32",
|
|
|
|
| 1 |
{
|
| 2 |
+
"_name_or_path": "./gqa_rpe/checkpoint-311724/",
|
| 3 |
"architectures": [
|
| 4 |
"ChessForCausalLM"
|
| 5 |
],
|
| 6 |
+
"attn": "GQA",
|
| 7 |
"auto_map": {
|
| 8 |
"AutoConfig": "model.ChessConfig",
|
| 9 |
"AutoModelForCausalLM": "model.ChessForCausalLM"
|
|
|
|
| 14 |
"layer_norm_epsilon": 1e-05,
|
| 15 |
"model_type": "chess_transformer",
|
| 16 |
"n_ctx": 256,
|
| 17 |
+
"n_embd": 96,
|
| 18 |
+
"n_head": 8,
|
| 19 |
+
"n_inner": 316,
|
| 20 |
"n_layer": 11,
|
| 21 |
+
"num_groups": 4,
|
| 22 |
"pad_token_id": 0,
|
| 23 |
+
"rot_pos_emb": true,
|
| 24 |
+
"rotary_base": 10000,
|
| 25 |
"tie_weights": false,
|
| 26 |
"tie_word_embeddings": false,
|
| 27 |
"torch_dtype": "float32",
|
model.py
CHANGED
|
@@ -66,6 +66,8 @@ class ChessConfig(PretrainedConfig):
|
|
| 66 |
eos_token_id: int = 2,
|
| 67 |
attn: str = "MHA",
|
| 68 |
num_groups: int = 2,
|
|
|
|
|
|
|
| 69 |
**kwargs,
|
| 70 |
):
|
| 71 |
super().__init__(
|
|
@@ -91,6 +93,11 @@ class ChessConfig(PretrainedConfig):
|
|
| 91 |
self.attn = attn
|
| 92 |
self.num_groups = num_groups
|
| 93 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
class MultiHeadAttention(nn.Module):
|
| 96 |
"""
|
|
@@ -110,6 +117,14 @@ class MultiHeadAttention(nn.Module):
|
|
| 110 |
self.n_embd = config.n_embd
|
| 111 |
self.head_dim = config.n_embd // config.n_head
|
| 112 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
# Combined QKV projection for efficiency
|
| 114 |
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
|
| 115 |
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
|
|
@@ -141,6 +156,10 @@ class MultiHeadAttention(nn.Module):
|
|
| 141 |
k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
|
| 142 |
v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
# Scaled dot-product attention
|
| 145 |
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
| 146 |
|
|
@@ -328,8 +347,10 @@ class ChessForCausalLM(PreTrainedModel):
|
|
| 328 |
|
| 329 |
# Token and position embeddings
|
| 330 |
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
| 331 |
-
|
| 332 |
-
|
|
|
|
|
|
|
| 333 |
self.drop = nn.Dropout(config.dropout)
|
| 334 |
|
| 335 |
# Transformer blocks
|
|
@@ -418,8 +439,11 @@ class ChessForCausalLM(PreTrainedModel):
|
|
| 418 |
|
| 419 |
# Get embeddings
|
| 420 |
token_embeds = self.wte(input_ids)
|
| 421 |
-
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
# Pass through transformer blocks
|
| 425 |
for block in self.h:
|
|
@@ -510,6 +534,64 @@ class ChessForCausalLM(PreTrainedModel):
|
|
| 510 |
|
| 511 |
return next_token.item()
|
| 512 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
|
| 514 |
# Register the model with Auto classes for easy loading
|
| 515 |
from transformers import AutoConfig, AutoModelForCausalLM
|
|
|
|
| 66 |
eos_token_id: int = 2,
|
| 67 |
attn: str = "MHA",
|
| 68 |
num_groups: int = 2,
|
| 69 |
+
rot_pos_emb=False,
|
| 70 |
+
rotary_base=10000,
|
| 71 |
**kwargs,
|
| 72 |
):
|
| 73 |
super().__init__(
|
|
|
|
| 93 |
self.attn = attn
|
| 94 |
self.num_groups = num_groups
|
| 95 |
|
| 96 |
+
# rot_pos_emb
|
| 97 |
+
self.rot_pos_emb = rot_pos_emb
|
| 98 |
+
self.rotary_base = rotary_base
|
| 99 |
+
|
| 100 |
+
|
| 101 |
|
| 102 |
class MultiHeadAttention(nn.Module):
|
| 103 |
"""
|
|
|
|
| 117 |
self.n_embd = config.n_embd
|
| 118 |
self.head_dim = config.n_embd // config.n_head
|
| 119 |
|
| 120 |
+
self.rot_pos_emb = config.rot_pos_emb
|
| 121 |
+
if self.rot_pos_emb:
|
| 122 |
+
self.rotary_emb = RotaryEmbedding(
|
| 123 |
+
self.head_dim,
|
| 124 |
+
max_position_embeddings=config.n_ctx,
|
| 125 |
+
base=getattr(config, 'rotary_base', 10000)
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
# Combined QKV projection for efficiency
|
| 129 |
self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
|
| 130 |
self.c_proj = nn.Linear(config.n_embd, config.n_embd)
|
|
|
|
| 156 |
k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
|
| 157 |
v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
|
| 158 |
|
| 159 |
+
if self.rot_pos_emb:
|
| 160 |
+
cos, sin = self.rotary_emb(v, seq_len=seq_len)
|
| 161 |
+
q, k = apply_rotary_pos_emb(q, k, cos, sin)
|
| 162 |
+
|
| 163 |
# Scaled dot-product attention
|
| 164 |
attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
|
| 165 |
|
|
|
|
| 347 |
|
| 348 |
# Token and position embeddings
|
| 349 |
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
| 350 |
+
if not config.rot_pos_emb:
|
| 351 |
+
self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
|
| 352 |
+
self.rot_pos_emb = config.rot_pos_emb
|
| 353 |
+
|
| 354 |
self.drop = nn.Dropout(config.dropout)
|
| 355 |
|
| 356 |
# Transformer blocks
|
|
|
|
| 439 |
|
| 440 |
# Get embeddings
|
| 441 |
token_embeds = self.wte(input_ids)
|
| 442 |
+
if not self.rot_pos_emb:
|
| 443 |
+
position_embeds = self.wpe(position_ids)
|
| 444 |
+
hidden_states = self.drop(token_embeds + position_embeds)
|
| 445 |
+
else:
|
| 446 |
+
hidden_states = self.drop(token_embeds)
|
| 447 |
|
| 448 |
# Pass through transformer blocks
|
| 449 |
for block in self.h:
|
|
|
|
| 534 |
|
| 535 |
return next_token.item()
|
| 536 |
|
| 537 |
+
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE), LLaMA-style duplicated-frequency layout.

    Precomputes cos/sin tables for positions up to ``max_position_embeddings``
    and regrows the cache on demand when a longer sequence is requested.

    Args:
        dim: per-head dimension the rotation spans; must be even.
        max_position_embeddings: initial length of the cos/sin cache.
        base: frequency base (10000, as in the original RoPE formulation).
        device: optional device for the inverse-frequency buffer.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # inv_freq[i] = base^(-2i/dim) for i in [0, dim/2)
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build the cache up front so `forward` is normally just a slice.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cached cos/sin tables, each of shape [seq_len, dim]."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # LLaMA-style expansion: duplicate the frequencies so the table
        # covers the full head dimension (pairs with rotate_half).
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return (cos, sin) tables for the first ``seq_len`` positions.

        Args:
            x: tensor on the target device/dtype whose second-to-last
               dimension is the sequence length (used as the fallback when
               ``seq_len`` is omitted).
            seq_len: number of positions needed; defaults to ``x.shape[-2]``.

        Returns:
            Tuple ``(cos, sin)``, each of shape [seq_len, dim], cast to x.dtype.
        """
        # BUGFIX: the original compared `None > int` (TypeError) whenever the
        # declared default seq_len=None was actually used; infer it from x.
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
|
| 570 |
+
|
| 571 |
+
def rotate_half(x):
    """Map the two halves (a, b) of the last dimension to (-b, a)."""
    half = x.shape[-1] // 2
    front = x[..., :half]
    back = x[..., half:]
    return torch.cat((-back, front), dim=-1)
|
| 576 |
+
|
| 577 |
+
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Rotate query/key tensors using precomputed cos/sin tables.

    Expects ``q`` and ``k`` shaped [batch, heads, seq_len, head_dim]
    (i.e. after the attention module's view + transpose) and ``cos``/``sin``
    shaped [seq_len, head_dim]. Positions are assumed to be the standard
    0..seq_len-1 range; ``position_ids`` is accepted for interface
    compatibility but is not used.
    """
    # Insert singleton batch and head dims so the tables broadcast:
    # [seq_len, head_dim] -> [1, 1, seq_len, head_dim]
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)

    rotated_q = q * cos + rotate_half(q) * sin
    rotated_k = k * cos + rotate_half(k) * sin
    return rotated_q, rotated_k
|
| 595 |
|
| 596 |
# Register the model with Auto classes for easy loading
|
| 597 |
from transformers import AutoConfig, AutoModelForCausalLM
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:105d6544cc86f6cad342e7ed9602f9628e3399e230a2048d1149d8dc09d35aa6
|
| 3 |
+
size 4007408
|