"""
Sheikh-2.5-Coder Model Implementation
=====================================

This module implements the Sheikh-2.5-Coder model architecture, a roughly
3B-parameter decoder-only transformer optimized for code generation and
on-device deployment.

NOTE: key/value caching (``past_key_values`` / ``use_cache``) is not
implemented yet; every forward pass attends over the full input sequence.
"""

import json
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    PretrainedConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    PreTrainedTokenizerBase,
    TrainingArguments,
)
from transformers.modeling_outputs import BaseModelOutputWithPast


class SheikhConfig(PretrainedConfig):
    """Configuration class for Sheikh-2.5-Coder model.

    This subclasses ``PretrainedConfig`` because ``PreTrainedModel.__init__``
    refuses any other config type. The constructor parameters (names, order
    and defaults) are the same as the fields of the former dataclass, so
    ``SheikhConfig(hidden_size=...)`` keeps working.
    """

    model_type = "sheikh"

    def __init__(
        self,
        # Model architecture
        num_attention_heads: int = 16,
        num_key_value_heads: int = 2,
        hidden_size: int = 3072,
        intermediate_size: int = 8192,
        num_hidden_layers: int = 36,
        vocab_size: int = 50257,
        # Position embeddings
        max_position_embeddings: int = 32768,
        rope_theta: float = 10000.0,
        # Attention / dropout
        attention_dropout: float = 0.1,
        hidden_dropout: float = 0.1,
        # Normalization
        layer_norm_epsilon: float = 1e-6,
        rms_norm_eps: float = 1e-6,
        # Activation
        activation_function: str = "swiglu",
        # Precision
        torch_dtype: str = "bfloat16",
        # Cache
        use_cache: bool = True,
        # Tie word embeddings
        tie_word_embeddings: bool = True,
        **kwargs,
    ):
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.hidden_dropout = hidden_dropout
        self.layer_norm_epsilon = layer_norm_epsilon
        self.rms_norm_eps = rms_norm_eps
        self.activation_function = activation_function
        self.use_cache = use_cache
        # torch_dtype and tie_word_embeddings are standard PretrainedConfig
        # arguments, so let the base class store (and serialize) them.
        super().__init__(
            torch_dtype=torch_dtype,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


class SheikhRMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Normalizes over the last dimension in float32 and casts the result back
    to the input dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)


class SheikhRotaryEmbedding(nn.Module):
    """Rotary Positional Embedding with a precomputed cos/sin cache."""

    def __init__(self, dim: int, max_position_embeddings: int = 32768, base: float = 10000):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.float32
        )

    def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """(Re)build cos/sin tables of shape (1, 1, seq_len, dim) for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        inv_freq = self.inv_freq.to(device)
        t = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: Optional[int] = None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions, cast to ``x.dtype``.

        Args:
            x: Tensor used for dtype/device; when ``seq_len`` is None its
                second-to-last dimension is taken as the sequence length
                (assumes a ``(..., seq_len, head_dim)`` layout).
            seq_len: Number of positions to return; the cache grows if needed.
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            # Keep the cache in float32; a low-precision cache loses accuracy
            # at large positions. The slices are cast to x.dtype below.
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=torch.float32)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )


class SheikhAttention(nn.Module):
    """Multi-head attention with Grouped Query Attention (GQA) and RoPE."""

    def __init__(self, config: SheikhConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_heads})"
            )
        if self.num_key_value_groups * self.num_key_value_heads != self.num_heads:
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )
        self.attention_dropout = config.attention_dropout

        self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)

        self.rotary_emb = SheikhRotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        """Apply causal self-attention.

        Args:
            hidden_states: ``(bsz, q_len, hidden_size)`` input.
            attention_mask: Optional mask passed to SDPA as ``attn_mask``; it
                must be broadcastable to ``(bsz, num_heads, q_len, q_len)``
                and must already encode causality (boolean True = attend, or
                an additive float mask). When None, plain causal masking is
                used.
            position_ids: Optional ``(bsz or 1, q_len)`` RoPE positions;
                defaults to ``0..q_len-1``.
            past_key_value / use_cache: accepted for interface compatibility
                but currently ignored (no KV cache is implemented).
            output_attentions: if True a second element is returned; it is
                always None because SDPA does not materialize the weights.

        Returns:
            ``(attn_output,)`` or ``(attn_output, None)``.
        """
        bsz, q_len, _ = hidden_states.size()

        # Project and reshape to (bsz, heads, q_len, head_dim).
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # The cos/sin tables must cover every position index that is looked
        # up, not just q_len (position_ids may carry an offset).
        rope_len = q_len
        if position_ids is not None and position_ids.numel() > 0:
            rope_len = max(q_len, int(position_ids.max()) + 1)
        cos, sin = self.rotary_emb(v, seq_len=rope_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        # Expand the KV heads so each query head has a matching key/value head.
        k = repeat_kv(k, self.num_key_value_groups)
        v = repeat_kv(v, self.num_key_value_groups)

        dropout_p = self.attention_dropout if self.training else 0.0
        # SDPA raises if attn_mask and is_causal are both set, so use exactly one.
        if attention_mask is None:
            attn_output = F.scaled_dot_product_attention(
                q, k, v, dropout_p=dropout_p, is_causal=True
            )
        else:
            attn_output = F.scaled_dot_product_attention(
                q, k, v, attn_mask=attention_mask, dropout_p=dropout_p
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        # SDPA never returns attention probabilities.
        attn_weights = None
        outputs = (attn_output,)
        if output_attentions:
            outputs += (attn_weights,)
        return outputs


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat key/value heads for grouped query attention.

    Maps ``(batch, num_key_value_heads, seq_len, head_dim)`` to
    ``(batch, num_key_value_heads * n_rep, seq_len, head_dim)`` so that query
    head ``i`` pairs with key/value head ``i // n_rep``.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
):
    """Apply rotary positional embeddings to ``q`` and ``k``.

    Args:
        q, k: ``(bsz, heads, seq_len, head_dim)`` tensors.
        cos, sin: ``(1, 1, cache_len, head_dim)`` tables from
            ``SheikhRotaryEmbedding``.
        position_ids: ``(bsz or 1, seq_len)`` indices into the tables;
            defaults to ``0..seq_len-1``.
    """

    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    cos = cos.squeeze(1).squeeze(0)  # (cache_len, head_dim)
    sin = sin.squeeze(1).squeeze(0)
    if position_ids is None:
        position_ids = torch.arange(q.shape[-2], device=cos.device).unsqueeze(0)
    # (bsz or 1, 1, seq_len, head_dim): broadcasts over the head dimension.
    cos = cos[position_ids].unsqueeze(1)
    sin = sin[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


class SheikhMLP(nn.Module):
    """SwiGLU MLP: ``down(silu(gate(x)) * up(x))``."""

    def __init__(self, config: SheikhConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class SheikhTransformerBlock(nn.Module):
    """Pre-norm transformer block: RMSNorm → attention → residual, RMSNorm → MLP → residual."""

    def __init__(self, config: SheikhConfig):
        super().__init__()
        self.self_attn = SheikhAttention(config)
        self.mlp = SheikhMLP(config)
        self.input_layernorm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ):
        """Return the block's output hidden states (same shape as the input)."""
        # Self-attention. The attention module returns a 1- or 2-tuple
        # depending on output_attentions, so take the first element only.
        attn_output = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )[0]
        hidden_states = hidden_states + attn_output

        # MLP
        mlp_output = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_output

        return hidden_states


class SheikhModel(PreTrainedModel):
    """Sheikh-2.5-Coder base model (embeddings + transformer stack + final norm)."""

    config_class = SheikhConfig
    base_model_prefix = "model"

    def __init__(self, config: SheikhConfig):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            [SheikhTransformerBlock(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = SheikhRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        # Initialize weights (post_init applies _init_weights to every module).
        self.post_init()

    def _init_weights(self, module):
        """Initialize model weights."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @staticmethod
    def _prepare_attention_mask(
        attention_mask: Optional[torch.Tensor], seq_len: int, device: torch.device
    ) -> Optional[torch.Tensor]:
        """Turn a 2-D ``(bsz, seq_len)`` padding mask (1 = keep) into a 4-D causal bool mask.

        Returns None when there is nothing to pad, so attention can use the
        fast ``is_causal`` path. Masks that are not 2-D are passed through
        unchanged and must already encode causality.
        """
        if attention_mask is None or attention_mask.dim() != 2:
            return attention_mask
        attention_mask = attention_mask.to(device)
        if bool(attention_mask.all()):
            return None
        causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()
        mask = causal[None, None, :, :] & attention_mask[:, None, None, :].bool()
        # Always let a position attend to itself: a fully masked row (e.g. a
        # left-padding query) would make SDPA produce NaNs that then leak
        # into other positions through the values.
        mask = mask | torch.eye(seq_len, dtype=torch.bool, device=device)
        return mask

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.Tensor]] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the transformer stack.

        Exactly one of ``input_ids`` (``(bsz, seq_len)``) or ``inputs_embeds``
        (``(bsz, seq_len, hidden_size)``) must be given. ``attention_mask`` may
        be a 2-D padding mask or a ready-made mask for SDPA.

        Returns a ``BaseModelOutputWithPast`` (or a tuple when
        ``return_dict`` is False). ``past_key_values`` and ``attentions`` are
        always None: no KV cache is implemented and SDPA does not expose
        attention weights.

        Raises:
            ValueError: if both or neither of ``input_ids``/``inputs_embeds``
                are given.
            NotImplementedError: if ``past_key_values`` is provided.
        """
        if (input_ids is None) == (inputs_embeds is None):
            raise ValueError("Specify exactly one of input_ids or inputs_embeds")
        if past_key_values is not None:
            raise NotImplementedError("SheikhModel does not implement a key/value cache yet")

        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.return_dict

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        seq_len = inputs_embeds.shape[1]
        device = inputs_embeds.device

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
        layer_mask = self._prepare_attention_mask(attention_mask, seq_len, device)

        hidden_states = inputs_embeds
        all_hidden_states = () if output_hidden_states else None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            hidden_states = layer(
                hidden_states,
                attention_mask=layer_mask,
                position_ids=position_ids,
                output_attentions=bool(output_attentions),
                use_cache=False,
            )

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in (hidden_states, all_hidden_states) if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=None,
            hidden_states=all_hidden_states,
            attentions=None,
        )


# Model loading utilities
def load_sheikh_model(
    model_name_or_path: str,
    device_map: Optional[str] = "auto",
    torch_dtype: torch.dtype = torch.bfloat16,
    load_in_8bit: bool = False,
    load_in_4bit: bool = False,
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
    """Load Sheikh-2.5-Coder model and tokenizer with optional quantization.

    Args:
        model_name_or_path: Hub id or local path.
        device_map: Forwarded to ``from_pretrained``.
        torch_dtype: Weight dtype; also used as the 4-bit compute dtype.
        load_in_8bit / load_in_4bit: Enable bitsandbytes quantization
            (mutually exclusive).

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if both quantization modes are requested.
    """
    if load_in_8bit and load_in_4bit:
        raise ValueError("load_in_8bit and load_in_4bit are mutually exclusive")

    # Setup quantization config
    quantization_config = None
    if load_in_8bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
    elif load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map=device_map,
        torch_dtype=torch_dtype,
        quantization_config=quantization_config,
    )

    return model, tokenizer


# Model training utilities
def setup_training_args(output_dir: str, learning_rate: float = 1e-4) -> TrainingArguments:
    """Setup training arguments for Sheikh-2.5-Coder.

    Uses bf16 mixed precision only: TrainingArguments rejects fp16 and bf16
    being enabled together, and the model's configured dtype is bfloat16.
    Note that ``max_steps`` takes precedence over ``num_train_epochs``.
    """
    return TrainingArguments(
        output_dir=output_dir,
        learning_rate=learning_rate,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        max_steps=100000,
        logging_steps=100,
        save_steps=2000,
        eval_steps=1000,
        warmup_steps=2000,
        fp16=False,
        bf16=True,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        remove_unused_columns=False,
        dataloader_pin_memory=True,
        report_to="wandb",
        run_name="sheikh-2.5-coder",
    )


if __name__ == "__main__":
    # Example usage
    config = SheikhConfig()
    model = SheikhModel(config)

    # Save configuration (PretrainedConfig serializes torch dtypes itself;
    # json.dump on __dict__ would fail on them).
    config.to_json_file("config.json", use_diff=False)

    print("Sheikh-2.5-Coder model configuration created successfully!")
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")