tgregrg commited on
Commit
71289bc
·
verified ·
1 Parent(s): ac5b52d

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +292 -0
main.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
5
+ from datasets import load_dataset
6
+ from huggingface_hub import HfApi, create_repo
7
+ import math
8
+ import os
9
+
10
+ class ZephyrCoderConfig(PretrainedConfig):
11
+ model_type = "zephyr_coder"
12
+ def __init__(
13
+ self,
14
+ vocab_size=128000,
15
+ hidden_size=2560,
16
+ intermediate_size=10240,
17
+ num_hidden_layers=36,
18
+ num_attention_heads=32,
19
+ num_key_value_heads=8,
20
+ max_position_embeddings=8192,
21
+ rope_theta=1000000.0,
22
+ attention_dropout=0.0,
23
+ hidden_dropout=0.0,
24
+ use_flash_attention=True,
25
+ use_moe=True,
26
+ num_experts=24,
27
+ num_experts_per_tok=6,
28
+ sliding_window_size=4096,
29
+ pad_token_id=0,
30
+ bos_token_id=1,
31
+ eos_token_id=2,
32
+ **kwargs
33
+ ):
34
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
35
+ self.vocab_size = vocab_size
36
+ self.hidden_size = hidden_size
37
+ self.intermediate_size = intermediate_size
38
+ self.num_hidden_layers = num_hidden_layers
39
+ self.num_attention_heads = num_attention_heads
40
+ self.num_key_value_heads = num_key_value_heads
41
+ self.max_position_embeddings = max_position_embeddings
42
+ self.rope_theta = rope_theta
43
+ self.attention_dropout = attention_dropout
44
+ self.hidden_dropout = hidden_dropout
45
+ self.use_flash_attention = use_flash_attention
46
+ self.use_moe = use_moe
47
+ self.num_experts = num_experts
48
+ self.num_experts_per_tok = num_experts_per_tok
49
+ self.sliding_window_size = sliding_window_size
50
+
51
+ class RMSNorm(nn.Module):
52
+ def __init__(self, hidden_size, eps=1e-6):
53
+ super().__init__()
54
+ self.weight = nn.Parameter(torch.ones(hidden_size))
55
+ self.eps = eps
56
+ def forward(self, x):
57
+ variance = x.pow(2).mean(-1, keepdim=True)
58
+ x = x * torch.rsqrt(variance + self.eps)
59
+ return self.weight * x
60
+
61
+ class RotaryEmbedding(nn.Module):
62
+ def __init__(self, dim, max_position_embeddings=8192, base=1000000.0):
63
+ super().__init__()
64
+ self.dim = dim
65
+ self.max_position_embeddings = max_position_embeddings
66
+ self.base = base
67
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
68
+ self.register_buffer("inv_freq", inv_freq)
69
+ self._build_cache(max_position_embeddings)
70
+ def _build_cache(self, seq_len):
71
+ t = torch.arange(seq_len, device=self.inv_freq.device)
72
+ freqs = torch.einsum("i,j->ij", t, self.inv_freq)
73
+ emb = torch.cat((freqs, freqs), dim=-1)
74
+ self.register_buffer("cos_cached", emb.cos())
75
+ self.register_buffer("sin_cached", emb.sin())
76
+ def forward(self, x, seq_len=None):
77
+ if seq_len > self.max_position_embeddings:
78
+ self._build_cache(seq_len)
79
+ return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
80
+
81
+ def rotate_half(x):
82
+ x1, x2 = x.chunk(2, dim=-1)
83
+ return torch.cat((-x2, x1), dim=-1)
84
+
85
+ def apply_rotary_pos_emb(q, k, cos, sin):
86
+ cos = cos.unsqueeze(0).unsqueeze(0)
87
+ sin = sin.unsqueeze(0).unsqueeze(0)
88
+ q_embed = (q * cos) + (rotate_half(q) * sin)
89
+ k_embed = (k * cos) + (rotate_half(k) * sin)
90
+ return q_embed, k_embed
91
+
92
+ class GroupedQueryAttention(nn.Module):
93
+ def __init__(self, config):
94
+ super().__init__()
95
+ self.hidden_size = config.hidden_size
96
+ self.num_heads = config.num_attention_heads
97
+ self.num_kv_heads = config.num_key_value_heads
98
+ self.head_dim = config.hidden_size // config.num_attention_heads
99
+ self.num_groups = self.num_heads // self.num_kv_heads
100
+ self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
101
+ self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
102
+ self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
103
+ self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
104
+ self.dropout = nn.Dropout(config.attention_dropout)
105
+ self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)
106
+ def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, output_attentions=False):
107
+ batch_size, seq_len, _ = hidden_states.shape
108
+ q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
109
+ k = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
110
+ v = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
111
+ cos, sin = self.rotary_emb(q, seq_len=seq_len)
112
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
113
+ k = k.repeat_interleave(self.num_groups, dim=1)
114
+ v = v.repeat_interleave(self.num_groups, dim=1)
115
+ attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
116
+ if attention_mask is not None:
117
+ attn_weights = attn_weights + attention_mask
118
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
119
+ attn_weights = self.dropout(attn_weights)
120
+ attn_output = torch.matmul(attn_weights, v)
121
+ attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, seq_len, self.hidden_size)
122
+ attn_output = self.o_proj(attn_output)
123
+ return attn_output, attn_weights
124
+
125
+ class MoE(nn.Module):
126
+ def __init__(self, config):
127
+ super().__init__()
128
+ self.num_experts = config.num_experts
129
+ self.num_experts_per_tok = config.num_experts_per_tok
130
+ self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
131
+ self.experts = nn.ModuleList([nn.Sequential(
132
+ nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
133
+ nn.GELU(),
134
+ nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
135
+ ) for _ in range(config.num_experts)])
136
+ def forward(self, x):
137
+ batch_size, seq_len, hidden_size = x.shape
138
+ x_flat = x.view(-1, hidden_size)
139
+ gate_logits = self.gate(x_flat)
140
+ gate_weights = F.softmax(gate_logits, dim=-1)
141
+ top_weights, top_indices = torch.topk(gate_weights, self.num_experts_per_tok, dim=-1)
142
+ top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)
143
+ final_output = torch.zeros_like(x_flat)
144
+ for i in range(self.num_experts):
145
+ mask = (top_indices == i).any(dim=-1)
146
+ if mask.any():
147
+ expert_output = self.experts[i](x_flat[mask])
148
+ weight_mask = (top_indices == i).float()
149
+ weights = (top_weights * weight_mask).sum(dim=-1)
150
+ final_output[mask] += expert_output * weights[mask].unsqueeze(-1)
151
+ return final_output.view(batch_size, seq_len, hidden_size)
152
+
153
+ class ZephyrCoderBlock(nn.Module):
154
+ def __init__(self, config):
155
+ super().__init__()
156
+ self.self_attn = GroupedQueryAttention(config)
157
+ self.input_layernorm = RMSNorm(config.hidden_size)
158
+ self.mlp = MoE(config) if config.use_moe else nn.Sequential(
159
+ nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
160
+ nn.GELU(),
161
+ nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
162
+ )
163
+ self.post_attention_layernorm = RMSNorm(config.hidden_size)
164
+ def forward(self, hidden_states, attention_mask=None, position_ids=None):
165
+ residual = hidden_states
166
+ hidden_states = self.input_layernorm(hidden_states)
167
+ attn_output, _ = self.self_attn(hidden_states, attention_mask, position_ids)
168
+ hidden_states = residual + attn_output
169
+ residual = hidden_states
170
+ hidden_states = self.post_attention_layernorm(hidden_states)
171
+ hidden_states = self.mlp(hidden_states)
172
+ hidden_states = residual + hidden_states
173
+ return hidden_states
174
+
175
+ class ZephyrCoderModel(PreTrainedModel):
176
+ config_class = ZephyrCoderConfig
177
+ def __init__(self, config):
178
+ super().__init__(config)
179
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
180
+ self.layers = nn.ModuleList([ZephyrCoderBlock(config) for _ in range(config.num_hidden_layers)])
181
+ self.norm = RMSNorm(config.hidden_size)
182
+ def forward(self, input_ids=None, attention_mask=None, position_ids=None):
183
+ hidden_states = self.embed_tokens(input_ids)
184
+ if attention_mask is not None:
185
+ attention_mask = attention_mask[:, None, None, :]
186
+ attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min
187
+ for layer in self.layers:
188
+ hidden_states = layer(hidden_states, attention_mask, position_ids)
189
+ hidden_states = self.norm(hidden_states)
190
+ return hidden_states
191
+
192
+ class ZephyrCoderForCausalLM(PreTrainedModel):
193
+ config_class = ZephyrCoderConfig
194
+ def __init__(self, config):
195
+ super().__init__(config)
196
+ self.model = ZephyrCoderModel(config)
197
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
198
+ def forward(self, input_ids=None, attention_mask=None, labels=None):
199
+ hidden_states = self.model(input_ids, attention_mask)
200
+ logits = self.lm_head(hidden_states)
201
+ loss = None
202
+ if labels is not None:
203
+ shift_logits = logits[..., :-1, :].contiguous()
204
+ shift_labels = labels[..., 1:].contiguous()
205
+ loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
206
+ return loss, logits
207
+ def generate(self, input_ids, max_length=2048, temperature=0.7, top_p=0.9):
208
+ self.eval()
209
+ with torch.no_grad():
210
+ for _ in range(max_length - input_ids.shape[1]):
211
+ _, logits = self.forward(input_ids=input_ids)
212
+ logits = logits[:, -1, :] / temperature
213
+ sorted_logits, sorted_indices = torch.sort(logits, descending=True)
214
+ cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
215
+ sorted_indices_to_remove = cumulative_probs > top_p
216
+ sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
217
+ sorted_indices_to_remove[..., 0] = 0
218
+ indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
219
+ logits[indices_to_remove] = float('-inf')
220
+ probs = F.softmax(logits, dim=-1)
221
+ next_token = torch.multinomial(probs, num_samples=1)
222
+ input_ids = torch.cat([input_ids, next_token], dim=-1)
223
+ if next_token.item() == self.config.eos_token_id:
224
+ break
225
+ return input_ids
226
+
227
+ def train_zephyr_coder():
228
+ config = ZephyrCoderConfig()
229
+ model = ZephyrCoderForCausalLM(config)
230
+ tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b")
231
+ tokenizer.add_special_tokens({'pad_token': '[PAD]'})
232
+
233
+ dataset = load_dataset("bigcode/the-stack-dedup", data_dir="data/python", split="train", streaming=True)
234
+ def tokenize_function(examples):
235
+ return tokenizer(examples['content'], truncation=True, max_length=2048, padding=False)
236
+
237
+ tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)
238
+
239
+ training_args = TrainingArguments(
240
+ output_dir="./zephyr-coder-final",
241
+ num_train_epochs=3,
242
+ per_device_train_batch_size=2,
243
+ gradient_accumulation_steps=16,
244
+ learning_rate=3e-4,
245
+ warmup_steps=2000,
246
+ logging_steps=10,
247
+ save_steps=1000,
248
+ fp16=True,
249
+ gradient_checkpointing=True,
250
+ optim="adamw_8bit",
251
+ )
252
+
253
+ trainer = Trainer(
254
+ model=model,
255
+ args=training_args,
256
+ train_dataset=tokenized_dataset,
257
+ data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
258
+ )
259
+
260
+ trainer.train()
261
+ trainer.save_model("./zephyr-coder-final")
262
+ tokenizer.save_pretrained("./zephyr-coder-final")
263
+ return model, tokenizer
264
+
265
+ def upload_to_huggingface(model_dir="./zephyr-coder-final", repo_name="zephyr-coder-15b"):
266
+ create_repo(repo_name, exist_ok=True)
267
+ api = HfApi()
268
+ api.upload_folder(folder_path=model_dir, repo_id=repo_name)
269
+ print(f"Uploaded to https://huggingface.co/{repo_name}")
270
+
271
+ def demo():
272
+ tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b")
273
+ config = ZephyrCoderConfig()
274
+ model = ZephyrCoderForCausalLM(config)
275
+
276
+ prompts = [
277
+ "def quicksort(arr):",
278
+ "class TransformerBlock:",
279
+ "def train_neural_network():",
280
+ "async def process_api_request():",
281
+ "def optimize_python_code():",
282
+ ]
283
+
284
+ for prompt in prompts:
285
+ inputs = tokenizer(prompt, return_tensors="pt")
286
+ outputs = model.generate(inputs.input_ids, max_length=500, temperature=0.7, top_p=0.95)
287
+ print(f"\nPrompt: {prompt}\nGenerated:\n{tokenizer.decode(outputs[0])}\n{'-'*80}")
288
+
289
+ if __name__ == "__main__":
290
+ model, tokenizer = train_zephyr_coder()
291
+ upload_to_huggingface()
292
+ demo()