iliasslasri committed on
Commit
c4df49c
·
verified ·
1 Parent(s): 535d8f0

Chess Challenge submission by iliasslasri

Browse files
Files changed (4) hide show
  1. README.md +3 -3
  2. config.json +8 -6
  3. model.py +86 -4
  4. model.safetensors +2 -2
README.md CHANGED
@@ -14,13 +14,13 @@ Chess model submitted to the LLM Course Chess Challenge.
14
  ## Submission Info
15
 
16
  - **Submitted by**: [iliasslasri](https://huggingface.co/iliasslasri)
17
- - **Parameters**: 980,720
18
  - **Organization**: LLM-course
19
 
20
  ## Model Details
21
 
22
  - **Architecture**: Chess Transformer (GPT-style)
23
  - **Vocab size**: 75
24
- - **Embedding dim**: 92
25
  - **Layers**: 11
26
- - **Heads**: 4
 
14
  ## Submission Info
15
 
16
  - **Submitted by**: [iliasslasri](https://huggingface.co/iliasslasri)
17
+ - **Parameters**: 998,036
18
  - **Organization**: LLM-course
19
 
20
  ## Model Details
21
 
22
  - **Architecture**: Chess Transformer (GPT-style)
23
  - **Vocab size**: 75
24
+ - **Embedding dim**: 96
25
  - **Layers**: 11
26
+ - **Heads**: 8
config.json CHANGED
@@ -1,9 +1,9 @@
1
  {
2
- "_name_or_path": "./11_4_92_ft_ft_ft/checkpoint-475008/",
3
  "architectures": [
4
  "ChessForCausalLM"
5
  ],
6
- "attn": "MHA",
7
  "auto_map": {
8
  "AutoConfig": "model.ChessConfig",
9
  "AutoModelForCausalLM": "model.ChessForCausalLM"
@@ -14,12 +14,14 @@
14
  "layer_norm_epsilon": 1e-05,
15
  "model_type": "chess_transformer",
16
  "n_ctx": 256,
17
- "n_embd": 92,
18
- "n_head": 4,
19
- "n_inner": 276,
20
  "n_layer": 11,
21
- "num_groups": 2,
22
  "pad_token_id": 0,
 
 
23
  "tie_weights": false,
24
  "tie_word_embeddings": false,
25
  "torch_dtype": "float32",
 
1
  {
2
+ "_name_or_path": "./gqa_rpe/checkpoint-311724/",
3
  "architectures": [
4
  "ChessForCausalLM"
5
  ],
6
+ "attn": "GQA",
7
  "auto_map": {
8
  "AutoConfig": "model.ChessConfig",
9
  "AutoModelForCausalLM": "model.ChessForCausalLM"
 
14
  "layer_norm_epsilon": 1e-05,
15
  "model_type": "chess_transformer",
16
  "n_ctx": 256,
17
+ "n_embd": 96,
18
+ "n_head": 8,
19
+ "n_inner": 316,
20
  "n_layer": 11,
21
+ "num_groups": 4,
22
  "pad_token_id": 0,
23
+ "rot_pos_emb": true,
24
+ "rotary_base": 10000,
25
  "tie_weights": false,
26
  "tie_word_embeddings": false,
27
  "torch_dtype": "float32",
model.py CHANGED
@@ -66,6 +66,8 @@ class ChessConfig(PretrainedConfig):
66
  eos_token_id: int = 2,
67
  attn: str = "MHA",
68
  num_groups: int = 2,
 
 
69
  **kwargs,
70
  ):
71
  super().__init__(
@@ -91,6 +93,11 @@ class ChessConfig(PretrainedConfig):
91
  self.attn = attn
92
  self.num_groups = num_groups
93
 
 
 
 
 
 
94
 
95
  class MultiHeadAttention(nn.Module):
96
  """
@@ -110,6 +117,14 @@ class MultiHeadAttention(nn.Module):
110
  self.n_embd = config.n_embd
111
  self.head_dim = config.n_embd // config.n_head
112
 
 
 
 
 
 
 
 
 
113
  # Combined QKV projection for efficiency
114
  self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
115
  self.c_proj = nn.Linear(config.n_embd, config.n_embd)
@@ -141,6 +156,10 @@ class MultiHeadAttention(nn.Module):
141
  k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
142
  v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
143
 
 
 
 
 
144
  # Scaled dot-product attention
145
  attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
146
 
@@ -328,8 +347,10 @@ class ChessForCausalLM(PreTrainedModel):
328
 
329
  # Token and position embeddings
330
  self.wte = nn.Embedding(config.vocab_size, config.n_embd)
331
- self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
332
-
 
 
333
  self.drop = nn.Dropout(config.dropout)
334
 
335
  # Transformer blocks
@@ -418,8 +439,11 @@ class ChessForCausalLM(PreTrainedModel):
418
 
419
  # Get embeddings
420
  token_embeds = self.wte(input_ids)
421
- position_embeds = self.wpe(position_ids)
422
- hidden_states = self.drop(token_embeds + position_embeds)
 
 
 
423
 
424
  # Pass through transformer blocks
425
  for block in self.h:
@@ -510,6 +534,64 @@ class ChessForCausalLM(PreTrainedModel):
510
 
511
  return next_token.item()
512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
 
514
  # Register the model with Auto classes for easy loading
515
  from transformers import AutoConfig, AutoModelForCausalLM
 
66
  eos_token_id: int = 2,
67
  attn: str = "MHA",
68
  num_groups: int = 2,
69
+ rot_pos_emb=False,
70
+ rotary_base=10000,
71
  **kwargs,
72
  ):
73
  super().__init__(
 
93
  self.attn = attn
94
  self.num_groups = num_groups
95
 
96
+ # rot_pos_emb
97
+ self.rot_pos_emb = rot_pos_emb
98
+ self.rotary_base = rotary_base
99
+
100
+
101
 
102
  class MultiHeadAttention(nn.Module):
103
  """
 
117
  self.n_embd = config.n_embd
118
  self.head_dim = config.n_embd // config.n_head
119
 
120
+ self.rot_pos_emb = config.rot_pos_emb
121
+ if self.rot_pos_emb:
122
+ self.rotary_emb = RotaryEmbedding(
123
+ self.head_dim,
124
+ max_position_embeddings=config.n_ctx,
125
+ base=getattr(config, 'rotary_base', 10000)
126
+ )
127
+
128
  # Combined QKV projection for efficiency
129
  self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
130
  self.c_proj = nn.Linear(config.n_embd, config.n_embd)
 
156
  k = k.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
157
  v = v.view(batch_size, seq_len, self.n_head, self.head_dim).transpose(1, 2)
158
 
159
+ if self.rot_pos_emb:
160
+ cos, sin = self.rotary_emb(v, seq_len=seq_len)
161
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
162
+
163
  # Scaled dot-product attention
164
  attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
165
 
 
347
 
348
  # Token and position embeddings
349
  self.wte = nn.Embedding(config.vocab_size, config.n_embd)
350
+ if not config.rot_pos_emb:
351
+ self.wpe = nn.Embedding(config.n_ctx, config.n_embd)
352
+ self.rot_pos_emb = config.rot_pos_emb
353
+
354
  self.drop = nn.Dropout(config.dropout)
355
 
356
  # Transformer blocks
 
439
 
440
  # Get embeddings
441
  token_embeds = self.wte(input_ids)
442
+ if not self.rot_pos_emb:
443
+ position_embeds = self.wpe(position_ids)
444
+ hidden_states = self.drop(token_embeds + position_embeds)
445
+ else:
446
+ hidden_states = self.drop(token_embeds)
447
 
448
  # Pass through transformer blocks
449
  for block in self.h:
 
534
 
535
  return next_token.item()
536
 
537
class RotaryEmbedding(nn.Module):
    """Precomputed rotary position embedding (RoPE) cos/sin tables, LLaMA-style.

    Builds a [max_position_embeddings, dim] cos/sin cache from inverse
    frequencies at construction time and lazily extends it if a longer
    sequence is ever requested.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        # Standard RoPE frequencies: base^(-2i/dim) for each even channel index.
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build the cache up front so `forward` is a plain slice in the common case.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build cos/sin tables covering positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.outer(t, self.inv_freq)
        # LLaMA-style expansion: concatenate freqs with itself so the cos/sin
        # tables span the full head_dim (pairs with `rotate_half`).
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x, seq_len=None):
        """Return (cos, sin), each of shape [seq_len, dim], cast to x's dtype.

        `x` is consulted only for device/dtype and (when `seq_len` is None)
        its sequence length; its values are never read.

        Fix: the original compared `seq_len > self.max_seq_len_cached`
        unconditionally, raising TypeError when `seq_len` was omitted.
        We now default to x's sequence axis.
        """
        if seq_len is None:
            # NOTE(review): the call site passes a [batch, heads, seq, head_dim]
            # tensor, so dim -2 is the sequence axis — confirm if reused elsewhere.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
570
+
571
def rotate_half(x):
    """Swap the two halves of the last dim, negating the (new) first half.

    Companion to the RoPE cos/sin tables: [a, b] -> [-b, a].
    """
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)
576
+
577
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Rotate query/key tensors by the given cos/sin position tables.

    Expects q and k shaped [batch, heads, seq_len, head_dim] (i.e. already
    transposed, as in the attention forward) and cos/sin shaped
    [seq_len, head_dim]. `position_ids` is accepted for signature
    compatibility but unused: positions are assumed to be 0..seq_len-1.

    Returns the rotated (q, k) pair.
    """
    # Broadcast the tables over batch and head axes.
    cos = cos[None, None, :, :]
    sin = sin[None, None, :, :]

    rotated_q = (q * cos) + (rotate_half(q) * sin)
    rotated_k = (k * cos) + (rotate_half(k) * sin)
    return rotated_q, rotated_k
595
 
596
  # Register the model with Auto classes for easy loading
597
  from transformers import AutoConfig, AutoModelForCausalLM
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:0f7244a5c854e9c9684f98b1b63970ad82899c0545f1fb1b105ce1ae2e8f76a8
3
- size 3934384
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:105d6544cc86f6cad342e7ed9602f9628e3399e230a2048d1149d8dc09d35aa6
3
+ size 4007408