devoppro commited on
Commit
eec13c0
·
verified ·
1 Parent(s): 051a4f4

Create model/architecture.py

Browse files
Files changed (1) hide show
  1. model/architecture.py +237 -0
model/architecture.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CodeLLM - Custom Decoder-only Transformer Architecture
3
+ Built from scratch for code generation.
4
+ Architecture: GPT-style, 125M parameters
5
+ """
6
+
7
+ import math
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from dataclasses import dataclass
12
+ from typing import Optional, Tuple
13
+
14
+
15
+ @dataclass
16
+ class CodeLLMConfig:
17
+ vocab_size: int = 50257
18
+ n_positions: int = 2048
19
+ n_embd: int = 768
20
+ n_layer: int = 12
21
+ n_head: int = 12
22
+ n_inner: int = 3072
23
+ dropout: float = 0.1
24
+ layer_norm_epsilon: float = 1e-5
25
+ initializer_range: float = 0.02
26
+ use_cache: bool = True
27
+ pad_token_id: int = 50256
28
+ bos_token_id: int = 50256
29
+ eos_token_id: int = 50256
30
+ tie_word_embeddings: bool = True
31
+
32
+ @property
33
+ def num_parameters(self):
34
+ embed = self.vocab_size * self.n_embd
35
+ attn = self.n_layer * (4 * self.n_embd * self.n_embd)
36
+ ffn = self.n_layer * (2 * self.n_embd * self.n_inner)
37
+ return embed + attn + ffn
38
+
39
+
40
+ class RotaryEmbedding(nn.Module):
41
+ def __init__(self, dim: int, max_seq_len: int = 2048, base: int = 10000):
42
+ super().__init__()
43
+ self.dim = dim
44
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
45
+ self.register_buffer("inv_freq", inv_freq)
46
+ self._build_cache(max_seq_len)
47
+
48
+ def _build_cache(self, seq_len: int):
49
+ t = torch.arange(seq_len, device=self.inv_freq.device).float()
50
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
51
+ emb = torch.cat([freqs, freqs], dim=-1)
52
+ self.register_buffer("cos_cache", emb.cos()[None, None, :, :])
53
+ self.register_buffer("sin_cache", emb.sin()[None, None, :, :])
54
+
55
+ def forward(self, q: torch.Tensor, k: torch.Tensor, seq_len: int):
56
+ if seq_len > self.cos_cache.shape[2]:
57
+ self._build_cache(seq_len)
58
+ cos = self.cos_cache[:, :, :seq_len, :]
59
+ sin = self.sin_cache[:, :, :seq_len, :]
60
+ return apply_rotary(q, cos, sin), apply_rotary(k, cos, sin)
61
+
62
+
63
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
64
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
65
+ return torch.cat([-x2, x1], dim=-1)
66
+
67
+
68
+ def apply_rotary(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
69
+ return (x * cos) + (rotate_half(x) * sin)
70
+
71
+
72
+ class CausalSelfAttention(nn.Module):
73
+ def __init__(self, config: CodeLLMConfig):
74
+ super().__init__()
75
+ assert config.n_embd % config.n_head == 0
76
+ self.n_head = config.n_head
77
+ self.n_embd = config.n_embd
78
+ self.head_dim = config.n_embd // config.n_head
79
+ self.dropout = config.dropout
80
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
81
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
82
+ self.attn_drop = nn.Dropout(config.dropout)
83
+ self.resid_drop = nn.Dropout(config.dropout)
84
+ self.rotary = RotaryEmbedding(self.head_dim, max_seq_len=config.n_positions)
85
+ self.register_buffer(
86
+ "bias",
87
+ torch.tril(torch.ones(config.n_positions, config.n_positions))
88
+ .view(1, 1, config.n_positions, config.n_positions),
89
+ )
90
+
91
+ def forward(self, x, attention_mask=None, past_key_value=None, use_cache=False):
92
+ B, T, C = x.size()
93
+ qkv = self.c_attn(x)
94
+ q, k, v = qkv.split(self.n_embd, dim=2)
95
+ q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
96
+ k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
97
+ v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
98
+ q, k = self.rotary(q, k, seq_len=T)
99
+ if past_key_value is not None:
100
+ k = torch.cat([past_key_value[0], k], dim=2)
101
+ v = torch.cat([past_key_value[1], v], dim=2)
102
+ present = (k, v) if use_cache else None
103
+ if hasattr(F, "scaled_dot_product_attention"):
104
+ y = F.scaled_dot_product_attention(
105
+ q, k, v,
106
+ attn_mask=attention_mask,
107
+ dropout_p=self.dropout if self.training else 0.0,
108
+ is_causal=(past_key_value is None),
109
+ )
110
+ else:
111
+ scale = 1.0 / math.sqrt(self.head_dim)
112
+ attn = (q @ k.transpose(-2, -1)) * scale
113
+ kT = k.size(2)
114
+ causal_mask = self.bias[:, :, kT - T : kT, :kT]
115
+ attn = attn.masked_fill(causal_mask == 0, float("-inf"))
116
+ if attention_mask is not None:
117
+ attn = attn + attention_mask
118
+ attn = F.softmax(attn, dim=-1)
119
+ attn = self.attn_drop(attn)
120
+ y = attn @ v
121
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
122
+ y = self.resid_drop(self.c_proj(y))
123
+ return y, present
124
+
125
+
126
+ class SwiGLUFFN(nn.Module):
127
+ def __init__(self, config: CodeLLMConfig):
128
+ super().__init__()
129
+ hidden = config.n_inner
130
+ self.w1 = nn.Linear(config.n_embd, hidden, bias=False)
131
+ self.w2 = nn.Linear(config.n_embd, hidden, bias=False)
132
+ self.w3 = nn.Linear(hidden, config.n_embd, bias=False)
133
+ self.drop = nn.Dropout(config.dropout)
134
+
135
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
136
+ return self.drop(self.w3(F.silu(self.w1(x)) * self.w2(x)))
137
+
138
+
139
+ class TransformerBlock(nn.Module):
140
+ def __init__(self, config: CodeLLMConfig):
141
+ super().__init__()
142
+ self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
143
+ self.attn = CausalSelfAttention(config)
144
+ self.ln_2 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
145
+ self.ffn = SwiGLUFFN(config)
146
+
147
+ def forward(self, x, attention_mask=None, past_key_value=None, use_cache=False):
148
+ attn_out, present = self.attn(
149
+ self.ln_1(x),
150
+ attention_mask=attention_mask,
151
+ past_key_value=past_key_value,
152
+ use_cache=use_cache,
153
+ )
154
+ x = x + attn_out
155
+ x = x + self.ffn(self.ln_2(x))
156
+ return x, present
157
+
158
+
159
+ class CodeLLM(nn.Module):
160
+ def __init__(self, config: CodeLLMConfig):
161
+ super().__init__()
162
+ self.config = config
163
+ self.transformer = nn.ModuleDict(dict(
164
+ wte = nn.Embedding(config.vocab_size, config.n_embd),
165
+ drop = nn.Dropout(config.dropout),
166
+ h = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layer)]),
167
+ ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
168
+ ))
169
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
170
+ if config.tie_word_embeddings:
171
+ self.lm_head.weight = self.transformer.wte.weight
172
+ self.apply(self._init_weights)
173
+ for name, p in self.named_parameters():
174
+ if name.endswith("c_proj.weight"):
175
+ nn.init.normal_(p, mean=0.0, std=config.initializer_range / math.sqrt(2 * config.n_layer))
176
+ print(f"CodeLLM initialized | params: {self.num_parameters:,}")
177
+
178
+ def _init_weights(self, module):
179
+ if isinstance(module, nn.Linear):
180
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
181
+ if module.bias is not None:
182
+ nn.init.zeros_(module.bias)
183
+ elif isinstance(module, nn.Embedding):
184
+ nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
185
+
186
+ @property
187
+ def num_parameters(self):
188
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
189
+
190
+ def forward(self, input_ids, attention_mask=None, labels=None, past_key_values=None, use_cache=False):
191
+ B, T = input_ids.size()
192
+ x = self.transformer.wte(input_ids)
193
+ x = self.transformer.drop(x)
194
+ presents = []
195
+ for i, block in enumerate(self.transformer.h):
196
+ past_kv = past_key_values[i] if past_key_values else None
197
+ x, present = block(x, attention_mask=attention_mask, past_key_value=past_kv, use_cache=use_cache)
198
+ if use_cache:
199
+ presents.append(present)
200
+ x = self.transformer.ln_f(x)
201
+ logits = self.lm_head(x)
202
+ loss = None
203
+ if labels is not None:
204
+ shift_logits = logits[..., :-1, :].contiguous()
205
+ shift_labels = labels[..., 1:].contiguous()
206
+ loss = F.cross_entropy(
207
+ shift_logits.view(-1, shift_logits.size(-1)),
208
+ shift_labels.view(-1),
209
+ ignore_index=-100,
210
+ )
211
+ return {"loss": loss, "logits": logits, "past_key_values": presents if use_cache else None}
212
+
213
+ @torch.no_grad()
214
+ def generate(self, input_ids, max_new_tokens=256, temperature=0.8, top_k=50, top_p=0.95, eos_token_id=None):
215
+ self.eval()
216
+ past_key_values = None
217
+ eos = eos_token_id or self.config.eos_token_id
218
+ for _ in range(max_new_tokens):
219
+ input_slice = input_ids if past_key_values is None else input_ids[:, -1:]
220
+ out = self.forward(input_slice, past_key_values=past_key_values, use_cache=True)
221
+ past_key_values = out["past_key_values"]
222
+ logits = out["logits"][:, -1, :] / temperature
223
+ if top_k > 0:
224
+ v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
225
+ logits[logits < v[:, [-1]]] = float("-inf")
226
+ if top_p < 1.0:
227
+ sorted_logits, sorted_idx = torch.sort(logits, descending=True)
228
+ cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
229
+ remove = cumprobs - F.softmax(sorted_logits, dim=-1) > top_p
230
+ sorted_logits[remove] = float("-inf")
231
+ logits.scatter_(1, sorted_idx, sorted_logits)
232
+ probs = F.softmax(logits, dim=-1)
233
+ next_tok = torch.multinomial(probs, num_samples=1)
234
+ input_ids = torch.cat([input_ids, next_tok], dim=1)
235
+ if (next_tok == eos).all():
236
+ break
237
+ return input_ids