# Zephyr_Coder / main.py
# Hugging Face Hub page header: tgregrg's picture
# Commit message: "Create main.py"
# Commit 71289bc (verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig, AutoTokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
from huggingface_hub import HfApi, create_repo
import math
import os
class ZephyrCoderConfig(PretrainedConfig):
    """Configuration object holding the Zephyr Coder hyper-parameters.

    The three special-token ids and any extra keyword arguments are forwarded
    to ``PretrainedConfig``; every other argument is stored verbatim as an
    attribute of the same name.
    """

    model_type = "zephyr_coder"

    def __init__(
        self,
        vocab_size=128000,
        hidden_size=2560,
        intermediate_size=10240,
        num_hidden_layers=36,
        num_attention_heads=32,
        num_key_value_heads=8,
        max_position_embeddings=8192,
        rope_theta=1000000.0,
        attention_dropout=0.0,
        hidden_dropout=0.0,
        use_flash_attention=True,
        use_moe=True,
        num_experts=24,
        num_experts_per_tok=6,
        sliding_window_size=4096,
        pad_token_id=0,
        bos_token_id=1,
        eos_token_id=2,
        **kwargs
    ):
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )
        # Stored in declaration order, after the base class has consumed the
        # special-token ids and the remaining keyword arguments.
        stored_fields = (
            ("vocab_size", vocab_size),
            ("hidden_size", hidden_size),
            ("intermediate_size", intermediate_size),
            ("num_hidden_layers", num_hidden_layers),
            ("num_attention_heads", num_attention_heads),
            ("num_key_value_heads", num_key_value_heads),
            ("max_position_embeddings", max_position_embeddings),
            ("rope_theta", rope_theta),
            ("attention_dropout", attention_dropout),
            ("hidden_dropout", hidden_dropout),
            ("use_flash_attention", use_flash_attention),
            ("use_moe", use_moe),
            ("num_experts", num_experts),
            ("num_experts_per_tok", num_experts_per_tok),
            ("sliding_window_size", sliding_window_size),
        )
        for field_name, field_value in stored_fields:
            setattr(self, field_name, field_value)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        hidden_size: Size of the last dimension; one gain value per channel.
        eps: Constant added to the mean square before the reciprocal square
            root, guarding against division by zero.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        """Return ``weight * x / sqrt(mean(x**2, dim=-1) + eps)``."""
        input_dtype = x.dtype
        # Squaring in float16/bfloat16 overflows (float16 saturates above
        # ~65504, i.e. |x| > ~256), which would turn the variance into inf
        # and the output into zeros. Compute the statistics in float32 and
        # cast the normalised activations back afterwards.
        if input_dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(input_dtype)
class RotaryEmbedding(nn.Module):
    """Cached cosine/sine tables for rotary position embeddings.

    Args:
        dim: Size of the rotated (per-head) dimension; ``dim // 2`` inverse
            frequencies are created and duplicated to span ``dim``.
        max_position_embeddings: Number of positions pre-computed at
            construction time.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, max_position_embeddings=8192, base=1000000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self._build_cache(max_position_embeddings)

    def _build_cache(self, seq_len):
        """(Re)build the ``cos_cached`` / ``sin_cached`` buffers for ``seq_len`` positions."""
        inv_freq = self.inv_freq
        # Positions are created as floats: multiplying an integer tensor with
        # the float frequencies through einsum is dtype-fragile. The angles
        # are computed in float32 so large positions stay exact even when the
        # module itself has been cast to a lower precision.
        positions = torch.arange(seq_len, device=inv_freq.device, dtype=torch.float32)
        freqs = positions[:, None] * inv_freq.to(torch.float32)[None, :]
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(inv_freq.dtype))
        self.register_buffer("sin_cached", emb.sin().to(inv_freq.dtype))

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` tables, each of shape ``(seq_len, dim)``.

        Args:
            x: Tensor whose second-to-last dimension is used as the sequence
                length when ``seq_len`` is not given.
            seq_len: Number of positions to return.
        """
        if seq_len is None:
            # The default used to crash on the ``None > int`` comparison.
            seq_len = x.shape[-2]
        # Compare against the table that is actually cached, so one long
        # sequence does not trigger a rebuild on every subsequent call.
        if seq_len > self.cos_cached.shape[0]:
            self._build_cache(seq_len)
        return self.cos_cached[:seq_len], self.sin_cached[:seq_len]
def rotate_half(x):
    """Split the last dimension in two and return ``concat(-second, first)``."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([torch.neg(second_half), first_half], dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the rotary transform to the query and key tensors.

    ``cos`` and ``sin`` get two leading singleton dimensions so they broadcast
    against ``q`` and ``k``; each tensor is then mapped to
    ``t * cos + rotate_half(t) * sin``.

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """
    cos = cos[None, None]
    sin = sin[None, None]

    def _rotate(states):
        return (states * cos) + (rotate_half(states) * sin)

    return _rotate(q), _rotate(k)
class GroupedQueryAttention(nn.Module):
    """Multi-head attention in which several query heads share one key/value head.

    The module adds ``attention_mask`` to the raw scores when one is given and
    applies no other masking itself. ``position_ids``, ``past_key_value``,
    ``use_cache`` and ``output_attentions`` are accepted by ``forward`` but
    not read.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        self.hidden_size = hidden
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = hidden // self.num_heads
        # Number of query heads served by each key/value head.
        self.num_groups = self.num_heads // self.num_kv_heads
        kv_width = self.num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(hidden, hidden, bias=False)
        self.k_proj = nn.Linear(hidden, kv_width, bias=False)
        self.v_proj = nn.Linear(hidden, kv_width, bias=False)
        self.o_proj = nn.Linear(hidden, hidden, bias=False)
        self.dropout = nn.Dropout(config.attention_dropout)
        self.rotary_emb = RotaryEmbedding(
            self.head_dim, config.max_position_embeddings, config.rope_theta
        )

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False, output_attentions=False):
        """Return ``(attn_output, attn_weights)`` for ``hidden_states``.

        ``hidden_states`` is unpacked as ``(batch, seq_len, hidden)``. The
        returned weights are the post-softmax, post-dropout probabilities.
        """
        batch_size, seq_len, _ = hidden_states.shape

        def split_heads(projected, head_count):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return projected.view(batch_size, seq_len, head_count, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(hidden_states), self.num_heads)
        k = split_heads(self.k_proj(hidden_states), self.num_kv_heads)
        v = split_heads(self.v_proj(hidden_states), self.num_kv_heads)

        cos, sin = self.rotary_emb(q, seq_len=seq_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Duplicate each key/value head so it lines up with its query group.
        k = torch.repeat_interleave(k, self.num_groups, dim=1)
        v = torch.repeat_interleave(v, self.num_groups, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if attention_mask is not None:
            scores = scores + attention_mask
        # Softmax runs in float32 and is cast back to the query dtype.
        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        probs = self.dropout(probs)

        context = torch.matmul(probs, v)
        context = context.transpose(1, 2).contiguous().reshape(batch_size, seq_len, self.hidden_size)
        return self.o_proj(context), probs
class MoE(nn.Module):
    """Token-level mixture of experts with top-k softmax routing.

    A linear gate scores every expert for every token. The
    ``num_experts_per_tok`` highest-probability experts are evaluated for that
    token and their outputs combined using the selected probabilities
    renormalised to sum to one.

    Args:
        config: Object exposing ``hidden_size``, ``intermediate_size``,
            ``num_experts`` and ``num_experts_per_tok``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.num_experts_per_tok = config.num_experts_per_tok
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList([nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
            nn.GELU(),
            nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
        ) for _ in range(config.num_experts)])

    def forward(self, x):
        """Route ``x`` of shape ``(batch, seq_len, hidden)`` and return the same shape."""
        batch_size, seq_len, hidden_size = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. transposed
        # activations, are accepted instead of raising.
        x_flat = x.reshape(-1, hidden_size)

        gate_weights = F.softmax(self.gate(x_flat), dim=-1)
        top_weights, top_indices = torch.topk(gate_weights, self.num_experts_per_tok, dim=-1)
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)

        final_output = torch.zeros_like(x_flat)
        for expert_id, expert in enumerate(self.experts):
            # topk never returns the same expert twice for one token, so each
            # routed token appears exactly once in ``token_idx``.
            token_idx, slot_idx = torch.where(top_indices == expert_id)
            if token_idx.numel() == 0:
                continue
            expert_output = expert(x_flat[token_idx])
            weights = top_weights[token_idx, slot_idx].unsqueeze(-1)
            contribution = (expert_output * weights).to(final_output.dtype)
            final_output.index_add_(0, token_idx, contribution)
        return final_output.view(batch_size, seq_len, hidden_size)
class ZephyrCoderBlock(nn.Module):
    """Pre-norm decoder layer: attention and feed-forward, each wrapped in a residual.

    The feed-forward part is a ``MoE`` when ``config.use_moe`` is truthy and a
    plain Linear-GELU-Linear stack otherwise.
    """

    def __init__(self, config):
        super().__init__()
        self.self_attn = GroupedQueryAttention(config)
        self.input_layernorm = RMSNorm(config.hidden_size)
        if config.use_moe:
            self.mlp = MoE(config)
        else:
            self.mlp = nn.Sequential(
                nn.Linear(config.hidden_size, config.intermediate_size, bias=False),
                nn.GELU(),
                nn.Linear(config.intermediate_size, config.hidden_size, bias=False),
            )
        self.post_attention_layernorm = RMSNorm(config.hidden_size)

    def forward(self, hidden_states, attention_mask=None, position_ids=None):
        """Return the layer output; the attention weights are discarded."""
        # Attention sub-layer with residual connection.
        normed = self.input_layernorm(hidden_states)
        attn_output, _ = self.self_attn(normed, attention_mask, position_ids)
        hidden_states = hidden_states + attn_output

        # Feed-forward sub-layer with residual connection.
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_output
class ZephyrCoderModel(PreTrainedModel):
    """Decoder stack: token embedding, ``num_hidden_layers`` blocks, final RMSNorm."""

    config_class = ZephyrCoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([ZephyrCoderBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size)

    @staticmethod
    def _build_attention_mask(attention_mask, hidden_states):
        """Build the additive ``(batch or 1, 1, seq, seq)`` mask handed to every layer.

        A query position may attend to itself and to earlier positions only;
        when a padding mask is given, keys whose mask entry is zero are
        blocked as well. Blocked entries hold the most negative finite value
        of the activation dtype and allowed entries hold zero.
        """
        seq_len = hidden_states.shape[1]
        device = hidden_states.device
        dtype = hidden_states.dtype

        positions = torch.arange(seq_len, device=device)
        # allowed[q, k] is True when key k is not in the future of query q.
        allowed = (positions[None, :] <= positions[:, None])[None, None, :, :]
        if attention_mask is not None:
            keep_key = attention_mask[:, None, None, :].to(device=device, dtype=torch.bool)
            allowed = allowed & keep_key
        # Merged as booleans first: summing two additive masks would overflow
        # the finite minimum to -inf wherever both block the same entry.
        return (1.0 - allowed.to(dtype)) * torch.finfo(dtype).min

    def forward(self, input_ids=None, attention_mask=None, position_ids=None):
        """Return the final hidden states for ``input_ids``.

        Args:
            input_ids: Token ids, indexed as ``(batch, seq_len)``.
            attention_mask: Optional ``(batch, seq_len)`` mask; non-zero marks
                tokens that may be attended to.
            position_ids: Passed through to every layer unchanged.
        """
        hidden_states = self.embed_tokens(input_ids)
        # The causal mask is built even when no padding mask is supplied;
        # without it every token could attend to the tokens it is trained to
        # predict.
        layer_mask = self._build_attention_mask(attention_mask, hidden_states)
        for layer in self.layers:
            hidden_states = layer(hidden_states, layer_mask, position_ids)
        hidden_states = self.norm(hidden_states)
        return hidden_states
class ZephyrCoderForCausalLM(PreTrainedModel):
    """``ZephyrCoderModel`` with a linear language-model head and a sampling loop."""

    config_class = ZephyrCoderConfig

    def __init__(self, config):
        super().__init__(config)
        self.model = ZephyrCoderModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Return ``(loss, logits)``; ``loss`` is ``None`` when no labels are given.

        With labels, the logits at position ``t`` are scored against the
        label at position ``t + 1`` using ``F.cross_entropy`` with its
        default settings.
        """
        hidden_states = self.model(input_ids, attention_mask)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
        return loss, logits

    @staticmethod
    def _sample_next_token(logits, temperature, top_p):
        """Pick one token id per row from last-position ``logits``.

        Returns a ``(batch, 1)`` tensor. A non-positive ``temperature``
        selects the arg-max token; otherwise the logits are scaled, limited
        to the nucleus of cumulative probability ``top_p`` and sampled.
        """
        if temperature <= 0:
            # Dividing by zero would fill the logits with inf/nan.
            return logits.argmax(dim=-1, keepdim=True)
        logits = logits / temperature
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_remove = cumulative_probs > top_p
        # Shift right so the first token crossing the threshold is kept, and
        # never remove the single most likely token.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        remove = sorted_remove.scatter(-1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, float("-inf"))
        return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)

    def generate(self, input_ids, max_length=2048, temperature=0.7, top_p=0.9):
        """Extend ``input_ids`` token by token up to ``max_length`` total tokens.

        Works for any batch size: each row stops at its own EOS token, rows
        that have finished are filled with the pad token (EOS when no pad id
        is configured), and the loop ends once every row has finished. The
        model is switched to eval mode and left there.

        Returns:
            The prompt with the generated tokens appended, ``(batch, length)``.
        """
        self.eval()
        eos_token_id = self.config.eos_token_id
        fill_token_id = self.config.pad_token_id
        if fill_token_id is None:
            fill_token_id = eos_token_id
        finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)
        with torch.no_grad():
            for _ in range(max_length - input_ids.shape[1]):
                _, logits = self.forward(input_ids=input_ids)
                next_token = self._sample_next_token(logits[:, -1, :], temperature, top_p)
                if fill_token_id is not None:
                    filler = torch.full_like(next_token, fill_token_id)
                    next_token = torch.where(finished.unsqueeze(-1), filler, next_token)
                input_ids = torch.cat([input_ids, next_token], dim=-1)
                if eos_token_id is not None:
                    # Per-row bookkeeping replaces ``next_token.item()``,
                    # which raises for batches larger than one.
                    finished = finished | (next_token.squeeze(-1) == eos_token_id)
                    if bool(finished.all()):
                        break
        return input_ids
def train_zephyr_coder(max_steps=100_000):
    """Train a freshly initialised model on streamed Python code and save it.

    Args:
        max_steps: Number of optimiser steps. The dataset is streamed and has
            no length, so the ``Trainer`` needs an explicit positive step
            budget; a positive value takes precedence over
            ``num_train_epochs``.

    Returns:
        Tuple ``(model, tokenizer)`` after training; both are also written to
        ``./zephyr-coder-final``.
    """
    tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b")
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    # Take the special-token ids from the tokenizer that produces the data,
    # so the saved config (and the EOS check in generate) match the text the
    # model actually sees, and make sure every token id has an embedding row.
    config = ZephyrCoderConfig(
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    config.vocab_size = max(config.vocab_size, len(tokenizer))
    model = ZephyrCoderForCausalLM(config)

    dataset = load_dataset("bigcode/the-stack-dedup", data_dir="data/python", split="train", streaming=True)

    def tokenize_function(examples):
        return tokenizer(examples['content'], truncation=True, max_length=2048, padding=False)

    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset.column_names)

    training_args = TrainingArguments(
        output_dir="./zephyr-coder-final",
        num_train_epochs=3,
        max_steps=max_steps,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=16,
        learning_rate=3e-4,
        warmup_steps=2000,
        logging_steps=10,
        save_steps=1000,
        fp16=True,
        # Requesting checkpointing from a model that does not declare support
        # makes the Trainer raise, so only enable it when the model opts in.
        gradient_checkpointing=bool(getattr(model, "supports_gradient_checkpointing", False)),
        optim="adamw_8bit",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
    )
    trainer.train()
    trainer.save_model("./zephyr-coder-final")
    tokenizer.save_pretrained("./zephyr-coder-final")
    return model, tokenizer
def upload_to_huggingface(model_dir="./zephyr-coder-final", repo_name="zephyr-coder-15b"):
    """Create the Hub repository if needed and upload ``model_dir`` into it.

    Args:
        model_dir: Local folder to upload.
        repo_name: Repository name, with or without a ``namespace/`` prefix.
    """
    repo_url = create_repo(repo_name, exist_ok=True)
    # create_repo resolves a bare name to "<namespace>/<name>"; the upload
    # and the printed link must use that resolved id, not the bare name.
    repo_id = getattr(repo_url, "repo_id", None) or repo_name
    api = HfApi()
    api.upload_folder(folder_path=model_dir, repo_id=repo_id)
    print(f"Uploaded to https://huggingface.co/{repo_id}")
def demo(model_dir="./zephyr-coder-final"):
    """Print sampled completions for a handful of Python prompts.

    Args:
        model_dir: Folder holding a saved model and tokenizer. When it
            contains a ``config.json`` the weights and tokenizer are loaded
            from it; otherwise the demo falls back to a freshly initialised
            model with the base tokenizer.
    """
    if os.path.isfile(os.path.join(model_dir, "config.json")):
        # Use the checkpoint written by training rather than random weights.
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        model = ZephyrCoderForCausalLM.from_pretrained(model_dir)
    else:
        tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b")
        config = ZephyrCoderConfig()
        model = ZephyrCoderForCausalLM(config)
    prompts = [
        "def quicksort(arr):",
        "class TransformerBlock:",
        "def train_neural_network():",
        "async def process_api_request():",
        "def optimize_python_code():",
    ]
    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(inputs.input_ids, max_length=500, temperature=0.7, top_p=0.95)
        print(f"\nPrompt: {prompt}\nGenerated:\n{tokenizer.decode(outputs[0])}\n{'-'*80}")
if __name__ == "__main__":
    # Script entry point: train, publish the saved folder, then print sample
    # completions. The three stages run strictly in this order.
    model, tokenizer = train_zephyr_coder()
    upload_to_huggingface()
    demo()