import math
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import (
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PretrainedConfig,
    PreTrainedModel,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo


class ZephyrCoderConfig(PretrainedConfig):
    """Hyper-parameters for the ZephyrCoder decoder-only language model.

    Every argument is stored verbatim as an attribute of the same name.

    Notable fields:
        sliding_window_size: maximum query-key distance (in tokens, counting the
            query itself) that attention may span; ``None`` or ``0`` disables it.
        use_flash_attention: use ``torch.nn.functional.scaled_dot_product_attention``
            when the installed torch provides it.
        initializer_range: std-dev of the normal init applied to Linear and
            Embedding weights by ``ZephyrCoderPreTrainedModel._init_weights``.
    """

    model_type = "zephyr_coder"

    def __init__(
        self,
        vocab_size=128000,
        hidden_size=2560,
        intermediate_size=10240,
        num_hidden_layers=36,
        num_attention_heads=32,
        num_key_value_heads=8,
        max_position_embeddings=8192,
        rope_theta=1000000.0,
        attention_dropout=0.0,
        hidden_dropout=0.0,
        use_flash_attention=True,
        use_moe=True,
        num_experts=24,
        num_experts_per_tok=6,
        sliding_window_size=4096,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        initializer_range=0.02,
        **kwargs,
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.use_flash_attention = use_flash_attention
        self.use_moe = use_moe
        self.num_experts = num_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.sliding_window_size = sliding_window_size
        self.initializer_range = initializer_range


class RMSNorm(nn.Module):
    """Root-mean-square layer norm with a learned per-channel scale."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        input_dtype = x.dtype
        # Compute the statistic in fp32: squaring fp16 activations overflows easily.
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)


class RotaryEmbedding(nn.Module):
    """Precomputed cos/sin tables for rotary position embeddings (RoPE).

    The tables cover ``max_position_embeddings`` positions and are rebuilt
    (grown) on demand when a longer sequence is requested.
    """

    def __init__(self, dim, max_position_embeddings=8192, base=1000000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_position_embeddings)

    def _build_cache(self, seq_len):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``."""
        self.max_seq_len_cached = seq_len
        # Positions must share inv_freq's float dtype; an int64 arange cannot be
        # multiplied against a float tensor in einsum/outer.
        t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)``, each ``(seq_len, dim)``, cast to ``x.dtype``.

        ``seq_len`` defaults to ``x.shape[-2]`` (the sequence axis of a
        ``(batch, heads, seq, head_dim)`` tensor).
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._build_cache(seq_len)
        # Cast so q/k keep their dtype under AMP (fp16 * fp32 would promote to fp32).
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


def rotate_half(x):
    """Rotate the last dimension: ``[x1, x2] -> [-x2, x1]``."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None):
    """Apply RoPE to ``q`` and ``k`` of shape ``(batch, heads, seq, head_dim)``.

    Without ``position_ids`` the rows of ``cos``/``sin`` are used in order
    (positions ``0 .. seq-1``). With ``position_ids`` (``(batch, seq)`` or
    ``(seq,)``) the rows are gathered per token.
    """
    if position_ids is None:
        cos = cos[None, None, :, :]
        sin = sin[None, None, :, :]
    else:
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)
        cos = cos[position_ids].unsqueeze(1)  # (batch, 1, seq, head_dim)
        sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class GroupedQueryAttention(nn.Module):
    """Multi-head attention with fewer key/value heads than query heads (GQA).

    ``attention_mask`` is an additive float mask broadcastable to
    ``(batch, heads, seq, seq)``. Causality is NOT applied here: callers must
    include it in the mask, as ``ZephyrCoderModel`` does. ``past_key_value`` and
    ``use_cache`` are accepted for interface compatibility but unused (no KV
    cache is implemented).
    """

    def __init__(self, config):
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError("hidden_size must be divisible by num_attention_heads")
        if config.num_attention_heads % config.num_key_value_heads != 0:
            raise ValueError("num_attention_heads must be divisible by num_key_value_heads")
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_groups = self.num_heads // self.num_kv_heads
        self.attention_dropout = config.attention_dropout
        self.use_sdpa = bool(getattr(config, "use_flash_attention", False)) and hasattr(
            F, "scaled_dot_product_attention"
        )
        self.q_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.dropout = nn.Dropout(config.attention_dropout)
        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, output_attentions=False):
        """Return ``(attn_output, attn_weights)``.

        ``attn_weights`` is ``None`` when the fused SDPA kernel is used.
        """
        batch_size, seq_len, _ = hidden_states.shape
        q = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # The table must cover the largest position referenced.
        rope_len = seq_len if position_ids is None else int(position_ids.max()) + 1
        cos, sin = self.rotary_emb(q, seq_len=rope_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        # Share each key/value head across its group of query heads.
        k = k.repeat_interleave(self.num_groups, dim=1)
        v = v.repeat_interleave(self.num_groups, dim=1)

        if self.use_sdpa and not output_attentions:
            attn_mask = None if attention_mask is None else attention_mask.to(q.dtype)
            attn_output = F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=self.attention_dropout if self.training else 0.0,
            )
            attn_weights = None
        else:
            attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                # Add in fp32 so large negative mask values cannot overflow fp16.
                attn_weights = attn_weights.float() + attention_mask.float()
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
            attn_weights = self.dropout(attn_weights)
            attn_output = torch.matmul(attn_weights, v)

        attn_output = attn_output.transpose(1, 2).contiguous().reshape(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights


class MoE(nn.Module):
    """Top-k routed mixture of GELU feed-forward experts.

    Each token is sent to its ``num_experts_per_tok`` highest-scoring experts.
    Their outputs are combined with the renormalised softmax gate weights.

    NOTE(review): no load-balancing auxiliary loss is computed, so routing may
    collapse onto a few experts during training.
    """

    def __init__(self, config):
        super().__init__()
        if not 0 < config.num_experts_per_tok <= config.num_experts:
            raise ValueError("num_experts_per_tok must be in [1, num_experts]")
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                nn.GELU(),
                nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
            )
            for _ in range(config.num_experts)
        ])

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        x_flat = x.reshape(-1, hidden_size)
        gate_weights = F.softmax(self.gate(x_flat), dim=-1)
        top_weights, top_indices = torch.topk(gate_weights, self.num_experts_per_tok, dim=-1)
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)

        final_output = torch.zeros_like(x_flat)
        for expert_idx, expert in enumerate(self.experts):
            # (token, slot) pairs routed to this expert; topk never picks the
            # same expert twice for one token, so token_idx has no duplicates.
            token_idx, slot_idx = torch.where(top_indices == expert_idx)
            if token_idx.numel() == 0:
                continue
            expert_out = expert(x_flat[token_idx]) * top_weights[token_idx, slot_idx].unsqueeze(-1)
            final_output.index_add_(0, token_idx, expert_out.to(final_output.dtype))
        return final_output.view(batch_size, seq_len, hidden_size)


class ZephyrCoderBlock(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> GQA -> residual, then RMSNorm -> MLP/MoE -> residual."""

    def __init__(self, config):
        super().__init__()
        self.self_attn = GroupedQueryAttention(config)
        self.input_layernorm = RMSNorm(config.hidden_size)
        if config.use_moe:
            self.mlp = MoE(config)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                nn.GELU(),
                nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
            )
        self.post_attention_layernorm = RMSNorm(config.hidden_size)
        # Applies config.hidden_dropout to each residual branch (no parameters).
        self.dropout = nn.Dropout(getattr(config, "hidden_dropout", 0.0))

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(hidden_states, attention_mask, position_ids)
        hidden_states = residual + self.dropout(attn_output)

        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        return residual + self.dropout(hidden_states)


class ZephyrCoderPreTrainedModel(PreTrainedModel):
    """Shared Hugging Face plumbing: config class, weight init, gradient checkpointing."""

    config_class = ZephyrCoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        std = getattr(self.config, "initializer_range", 0.02)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, RMSNorm):
            module.weight.data.fill_(1.0)


class ZephyrCoderModel(ZephyrCoderPreTrainedModel):
    """Token embedding + decoder stack + final norm; returns final hidden states."""

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([ZephyrCoderBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size)
        # Toggled by PreTrainedModel.gradient_checkpointing_enable()/disable().
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def _build_attention_mask(self, attention_mask, seq_len, dtype, device):
        """Build an additive ``(batch|1, 1, seq, seq)`` mask.

        The mask combines causality, the optional sliding window, and key
        padding from a ``(batch, seq)`` 0/1 ``attention_mask``.
        """
        positions = torch.arange(seq_len, device=device)
        offset = positions[:, None] - positions[None, :]  # query position - key position
        allowed = offset >= 0
        window = getattr(self.config, "sliding_window_size", None)
        if window:
            allowed = allowed & (offset < window)
        allowed = allowed[None, None, :, :]
        if attention_mask is not None:
            allowed = allowed & attention_mask[:, None, None, :].bool()
        # A token may always see itself, so no row is fully masked (which would
        # make softmax NaN for e.g. left-padded queries).
        allowed = allowed | torch.eye(seq_len, dtype=torch.bool, device=device)
        mask = torch.zeros(allowed.shape, dtype=dtype, device=device)
        return mask.masked_fill(~allowed, torch.finfo(dtype).min)

    def forward(self, input_ids=None, attention_mask=None, position_ids=None):
        hidden_states = self.embed_tokens(input_ids)
        mask = self._build_attention_mask(
            attention_mask, input_ids.shape[1], hidden_states.dtype, hidden_states.device
        )
        ckpt_func = getattr(self, "_gradient_checkpointing_func", None)
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                if ckpt_func is not None:
                    hidden_states = ckpt_func(layer, hidden_states, mask, position_ids)
                else:
                    hidden_states = checkpoint(layer, hidden_states, mask, position_ids, use_reentrant=False)
            else:
                hidden_states = layer(hidden_states, mask, position_ids)
        return self.norm(hidden_states)


class ZephyrCoderForCausalLM(ZephyrCoderPreTrainedModel):
    """ZephyrCoderModel with a linear LM head; ``forward`` returns ``(loss, logits)``."""

    def __init__(self, config):
        super().__init__(config)
        self.model = ZephyrCoderModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids=None, attention_mask=None, labels=None, position_ids=None):
        """Compute logits and, when ``labels`` is given, the next-token loss.

        Labels equal to -100 (as produced by DataCollatorForLanguageModeling for
        padding) are ignored. ``loss`` is ``None`` without labels.
        """
        hidden_states = self.model(input_ids, attention_mask, position_ids)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )
        return loss, logits

    @staticmethod
    def _sample_next_token(logits, temperature, top_p):
        """Pick one token id per row from ``(batch, vocab)`` logits.

        Uses nucleus (top-p) sampling, or argmax when ``temperature <= 0``.
        """
        logits = logits.float()
        if temperature <= 0:
            return logits.argmax(dim=-1, keepdim=True)
        logits = logits / temperature
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is kept; always keep the top token.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False
        to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
        logits = logits.masked_fill(to_remove, float("-inf"))
        return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

    def generate(self, input_ids, max_length=2048, temperature=0.7, top_p=0.9):
        """Autoregressively extend ``input_ids`` (``(batch, seq)``) up to ``max_length`` tokens.

        A row stops when it emits ``config.eos_token_id``. Finished rows are
        padded with ``pad_token_id`` (or EOS) until every row has finished. The
        full prefix is re-encoded each step (no KV cache). The model's
        train/eval mode is restored on exit.
        """
        was_training = self.training
        self.eval()
        input_ids = input_ids.to(self.device)
        eos_token_id = self.config.eos_token_id
        filler_id = self.config.pad_token_id if self.config.pad_token_id is not None else eos_token_id
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
        try:
            with torch.no_grad():
                for _ in range(max_length - input_ids.shape[1]):
                    _, logits = self(input_ids=input_ids)
                    next_token = self._sample_next_token(logits[:, -1, :], temperature, top_p)
                    if eos_token_id is not None:
                        next_token = torch.where(
                            finished.unsqueeze(-1), torch.full_like(next_token, filler_id), next_token
                        )
                        finished |= next_token.squeeze(-1) == eos_token_id
                    input_ids = torch.cat([input_ids, next_token], dim=-1)
                    if eos_token_id is not None and bool(finished.all()):
                        break
        finally:
            self.train(was_training)
        return input_ids


def train_zephyr_coder(
    tokenizer_name="bigcode/starcoder2-15b",
    output_dir="./zephyr-coder-final",
    max_steps=100_000,
    max_length=2048,
):
    """Pre-train ZephyrCoder on the Python split of The Stack (streamed).

    The model's vocabulary size and special-token ids are taken from the
    tokenizer so that padding/EOS ids agree between data, model and ``generate``.
    ``max_steps`` is required because a streamed dataset has no length; the
    default is a placeholder to tune. Returns ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    config = ZephyrCoderConfig(
        vocab_size=len(tokenizer),
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    model = ZephyrCoderForCausalLM(config)

    dataset = load_dataset("bigcode/the-stack-dedup", data_dir="data/python", split="train", streaming=True)
    # Streaming datasets may not know their columns up front; peek at one example.
    columns = dataset.column_names
    if columns is None:
        columns = list(next(iter(dataset)).keys())

    eos_id = tokenizer.eos_token_id

    def tokenize_function(examples):
        # Reserve one slot so every document ends with EOS; otherwise the model
        # never learns where a file ends and generate() cannot stop on EOS.
        budget = max_length - 1 if eos_id is not None else max_length
        encoded = tokenizer(examples["content"], truncation=True, max_length=budget, padding=False)
        if eos_id is not None:
            for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"]):
                ids.append(eos_id)
                mask.append(1)
        return encoded

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=columns)

    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=3,  # ignored by Trainer when max_steps > 0
        max_steps=max_steps,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=16,
        learning_rate=3e-4,
        warmup_steps=2000,
        logging_steps=10,
        save_steps=1000,
        fp16=True,
        gradient_checkpointing=True,
        optim="adamw_8bit",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    trainer.train()
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return model, tokenizer


def upload_to_huggingface(model_dir="./zephyr-coder-final", repo_name="zephyr-coder-15b"):
    """Create (if needed) a Hub model repo and upload ``model_dir`` to it.

    ``create_repo`` resolves a bare name into the authenticated user's
    namespace; that full ``namespace/name`` id is used for the upload and
    returned.
    """
    repo_url = create_repo(repo_name, exist_ok=True)
    repo_id = repo_url.repo_id
    api = HfApi()
    api.upload_folder(folder_path=model_dir, repo_id=repo_id)
    print(f"Uploaded to https://huggingface.co/{repo_id}")
    return repo_id


def demo(model_dir="./zephyr-coder-final", model=None, tokenizer=None):
    """Print completions for a few code prompts.

    Uses the given ``model``/``tokenizer``, else loads them from ``model_dir``.
    If ``model_dir`` holds no saved model, it falls back to a freshly
    initialised (untrained) model and the StarCoder2 tokenizer.
    """
    has_checkpoint = os.path.isfile(os.path.join(model_dir, "config.json"))
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_dir if has_checkpoint else "bigcode/starcoder2-15b")
    if model is None:
        if has_checkpoint:
            model = ZephyrCoderForCausalLM.from_pretrained(model_dir)
        else:
            print(f"No trained model found in {model_dir}; using random weights.")
            model = ZephyrCoderForCausalLM(ZephyrCoderConfig())

    prompts = [
        "def quicksort(arr):",
        "class TransformerBlock:",
        "def train_neural_network():",
        "async def process_api_request():",
        "def optimize_python_code():",
    ]
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(inputs.input_ids, max_length=500, temperature=0.7, top_p=0.95)
        print(f"\nPrompt: {prompt}\nGenerated:\n{tokenizer.decode(outputs[0])}\n{'-'*80}")


if __name__ == "__main__":
    model, tokenizer = train_zephyr_coder()
    upload_to_huggingface()
    demo(model=model, tokenizer=tokenizer)