"""
Sheikh-2.5-Coder Model Implementation
====================================
This module implements the Sheikh-2.5-Coder model architecture, a 3B parameter
transformer model optimized for code generation and on-device deployment.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
from dataclasses import dataclass
from transformers import (
    PreTrainedModel,
    PreTrainedTokenizer,
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
)
import json
@dataclass
class SheikhConfig:
"""Configuration class for Sheikh-2.5-Coder model."""
# Model architecture
num_attention_heads: int = 16
num_key_value_heads: int = 2
hidden_size: int = 3072
intermediate_size: int = 8192
num_hidden_layers: int = 36
vocab_size: int = 50257
# Position embeddings
max_position_embeddings: int = 32768
rope_theta: float = 10000.0
# Attention
attention_dropout: float = 0.1
hidden_dropout: float = 0.1
# Normalization
layer_norm_epsilon: float = 1e-6
rms_norm_eps: float = 1e-6
# Activation
activation_function: str = "swiglu"
# Precision
torch_dtype: str = "bfloat16"
# Cache
use_cache: bool = True
# Tie word embeddings
tie_word_embeddings: bool = True
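# Illustrative note (not executed at runtime): the defaults above imply the
# grouped-query-attention geometry below. With 16 query heads sharing 2 KV
# heads, each KV head serves 8 query heads, shrinking the KV cache roughly
# 8x relative to full multi-head attention.
#
#   config = SheikhConfig()
#   head_dim = config.hidden_size // config.num_attention_heads           # 3072 // 16 = 192
#   kv_groups = config.num_attention_heads // config.num_key_value_heads  # 16 // 2 = 8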
class SheikhRMSNorm(nn.Module):
"""Root Mean Square Layer Normalization."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(hidden_size))
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Compute the norm in float32 for numerical stability, then cast back.
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)
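# Illustrative sanity check: with the default unit weight, each row of the
# output has a root-mean-square of approximately 1.
#
#   norm = SheikhRMSNorm(hidden_size=3072)
#   y = norm(torch.randn(2, 8, 3072))
#   print(y.pow(2).mean(-1).sqrt())  # ~1.0 everywhere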
class SheikhRotaryEmbedding(nn.Module):
"""Rotary Positional Embedding."""
def __init__(self, dim: int, max_position_embeddings: int = 32768, base: int = 10000):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
)
def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
    def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
        if seq_len is None:
            seq_len = x.shape[-2]
        # Grow the cached tables lazily if the request exceeds what was precomputed.
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )
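# Illustrative shape check (assumes head_dim=192, as implied by the default
# config): the cached cos/sin tables broadcast over the batch and head dims.
#
#   rope = SheikhRotaryEmbedding(dim=192)
#   cos, sin = rope(torch.randn(1, 16, 128, 192), seq_len=128)
#   print(cos.shape)  # torch.Size([1, 1, 128, 192])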
class SheikhAttention(nn.Module):
"""Multi-head attention with Grouped Query Attention."""
def __init__(self, config: SheikhConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.rotary_emb = SheikhRotaryEmbedding(
self.head_dim, max_position_embeddings=config.max_position_embeddings
)
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        bsz, q_len, _ = hidden_states.size()
        # Query, key, value projections
        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)
        # Reshape to (batch, heads, seq, head_dim)
        q = q.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        # Apply rotary embeddings; default to sequential positions when none are given
        if position_ids is None:
            position_ids = torch.arange(q_len, device=hidden_states.device).unsqueeze(0)
        cos, sin = self.rotary_emb(v, seq_len=q_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
        # Expand KV heads so every query head has a matching key/value head
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)
        # Scaled dot-product attention. SDPA rejects an explicit mask combined
        # with is_causal=True, so causal masking is used only when no mask is
        # supplied; an explicit mask must already encode causality and be
        # broadcastable to (batch, heads, q_len, kv_len).
        attn_output = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_mask,
            dropout_p=self.config.attention_dropout if self.training else 0.0,
            is_causal=attention_mask is None,
        )
        # Reshape and project output
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        # SDPA does not expose attention weights, so None is returned even when
        # output_attentions is requested.
        attn_weights = None
        outputs = (attn_output,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat key/value heads for grouped query attention.

    Expects (batch, num_key_value_heads, seq_len, head_dim), matching the
    layout produced in SheikhAttention after the transpose.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
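# Illustrative shape check: 2 KV heads repeated 8x match the 16 query heads.
#
#   kv = torch.randn(1, 2, 128, 192)   # (batch, kv_heads, seq, head_dim)
#   print(repeat_kv(kv, 8).shape)      # torch.Size([1, 16, 128, 192])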
def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor):
    """Apply rotary positional embeddings to query and key tensors."""
    def rotate_half(x):
        # Rotate the two halves of the last dimension: (x1, x2) -> (-x2, x1).
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)
    # Drop the broadcast dims from the cached tables, gather the rows for the
    # requested positions, then re-insert a head dim so cos/sin broadcast
    # across all attention heads.
    cos = cos.squeeze(1).squeeze(0)
    sin = sin.squeeze(1).squeeze(0)
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
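# Illustrative property check: RoPE rotates feature pairs, so it preserves
# the norm of every head vector.
#
#   rope = SheikhRotaryEmbedding(dim=192)
#   q = torch.randn(1, 16, 8, 192)
#   cos, sin = rope(q, seq_len=8)
#   pos = torch.arange(8).unsqueeze(0)
#   q_rot, _ = apply_rotary_pos_emb(q, q, cos, sin, pos)
#   print(torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-5))  # True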
class SheikhMLP(nn.Module):
"""SwiGLU MLP."""
def __init__(self, config: SheikhConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
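# Illustrative note: SwiGLU computes down(SiLU(gate(x)) * up(x)) and therefore
# uses three projections instead of the usual two, giving this MLP
# 3 * hidden_size * intermediate_size = 3 * 3072 * 8192 ≈ 75.5M parameters
# per layer under the default config.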
class SheikhTransformerBlock(nn.Module):
"""Transformer block for Sheikh-2.5-Coder."""
def __init__(self, config: SheikhConfig):
super().__init__()
self.self_attn = SheikhAttention(config)
self.mlp = SheikhMLP(config)
self.input_layernorm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
):
        # Self-attention with a pre-norm residual connection. self_attn
        # returns a tuple whose first element is the attention output.
        attn_output = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )[0]
hidden_states = hidden_states + attn_output
# MLP
mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
hidden_states = hidden_states + mlp_output
return hidden_states
class SheikhModel(nn.Module):
    """Sheikh-2.5-Coder base model.

    SheikhConfig is a plain dataclass rather than a transformers
    PretrainedConfig, so this class subclasses nn.Module directly; full
    Hugging Face integration would require a PretrainedConfig-based config.
    """
    def __init__(self, config: SheikhConfig):
        super().__init__()
        self.config = config
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [SheikhTransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Initialize weights
        self.apply(self._init_weights)
def _init_weights(self, module):
"""Initialize model weights."""
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
    def forward(
        self,
        input_ids: torch.Tensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # Minimal forward pass: embed the tokens, run every transformer block,
        # and apply the final norm. KV caching, hidden-state collection, and
        # dict-style returns are not implemented here.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds
        for layer in self.layers:
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
            )
        return self.norm(hidden_states)
# Model loading utilities
def load_sheikh_model(
    model_name_or_path: str,
    device_map: Optional[str] = "auto",
    torch_dtype: torch.dtype = torch.bfloat16,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """Load the Sheikh-2.5-Coder model and tokenizer, with optional quantization."""
# Setup quantization config
quantization_config = None
if load_in_8bit:
quantization_config = BitsAndBytesConfig(load_in_8bit=True)
elif load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
model_name_or_path,
device_map=device_map,
torch_dtype=torch_dtype,
quantization_config=quantization_config,
)
return model, tokenizer
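# Example usage (illustrative; the checkpoint path is a placeholder, and
# load_in_4bit additionally requires the bitsandbytes package and a CUDA device):
#
#   model, tokenizer = load_sheikh_model("path/to/sheikh-2.5-coder", load_in_4bit=True)
#   inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to(model.device)
#   output = model.generate(**inputs, max_new_tokens=64)
#   print(tokenizer.decode(output[0], skip_special_tokens=True))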
# Model training utilities
def setup_training_args(output_dir: str, learning_rate: float = 1e-4) -> TrainingArguments:
    """Set up training arguments for Sheikh-2.5-Coder."""
    return TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        max_steps=100000,  # when positive, max_steps takes precedence over num_train_epochs
        logging_steps=100,
        save_steps=2000,
        eval_steps=1000,
        warmup_steps=2000,
        bf16=True,  # fp16 and bf16 are mutually exclusive; bf16 matches the model's dtype
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        report_to="wandb",
        run_name="sheikh-2.5-coder",
    )
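# Illustrative wiring with the Hugging Face Trainer; train_ds and eval_ds are
# placeholder datasets, not artifacts of this repository.
#
#   from transformers import Trainer
#   args = setup_training_args("./checkpoints", learning_rate=2e-4)
#   trainer = Trainer(model=model, args=args, train_dataset=train_ds, eval_dataset=eval_ds)
#   trainer.train()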
if __name__ == "__main__":
# Example usage
config = SheikhConfig()
model = SheikhModel(config)
# Save configuration
with open("config.json", "w") as f:
json.dump(config.__dict__, f, indent=2)
print("Sheikh-2.5-Coder model configuration created successfully!")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")