# Shivik-M4-Pretrained / modeling_shivik_m4.py
"""
SHIVIK-M4 Model Architecture (SmolLM2-Compatible)
==================================================
Matched to SmolLM2-1.7B for weight loading:
- 24 layers, 2048 hidden, 32 heads (MHA - all heads are KV heads)
- Full RoPE, SwiGLU MLP, RMSNorm
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
class ShivikM4Config(PretrainedConfig):
model_type = "shivik_m4"
def __init__(
self,
vocab_size=49152,
hidden_size=2048,
intermediate_size=8192,
num_hidden_layers=24,
num_attention_heads=32,
num_key_value_heads=32, # MHA for SmolLM2 compatibility
head_dim=64,
rms_norm_eps=1e-5,
max_position_embeddings=4096,
rope_theta=100000.0,
tie_word_embeddings=True,
**kwargs,
):
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.head_dim = head_dim
self.rms_norm_eps = rms_norm_eps
self.max_position_embeddings = max_position_embeddings
self.rope_theta = rope_theta
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
class ShivikM4RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
def forward(self, x):
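        # RMSNorm: upcast to fp32, divide by the root-mean-square of the features,
        # scale by the learned weight, then cast back to the input dtype.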
dtype = x.dtype
x = x.float()
norm = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(norm + self.eps)
return (self.weight * x).to(dtype)
class ShivikM4RotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings, base=10000.0):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.max_seq_len_cached = max_position_embeddings
self._set_cos_sin_cache(max_position_embeddings)
def _set_cos_sin_cache(self, seq_len):
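        # Precompute cos/sin tables of shape (1, 1, seq_len, head_dim) so forward() can slice them cheaply.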
self.max_seq_len_cached = seq_len
t = torch.arange(seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat([freqs, freqs], dim=-1)
self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0), persistent=False)
self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0), persistent=False)
def forward(self, x, seq_len):
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len)
return (
self.cos_cached[:, :, :seq_len, :].to(x.dtype),
self.sin_cached[:, :, :seq_len, :].to(x.dtype),
)
def rotate_half(x):
x1, x2 = x.chunk(2, dim=-1)
return torch.cat((-x2, x1), dim=-1)
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
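    # Gather the cos/sin rows for the absolute positions of this step, add a head
    # dimension for broadcasting, and rotate queries and keys in place.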
cos = cos.squeeze(0).squeeze(0)
sin = sin.squeeze(0).squeeze(0)
cos = cos[position_ids].unsqueeze(1)
sin = sin[position_ids].unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
class ShivikM4Attention(nn.Module):
def __init__(self, config: ShivikM4Config):
super().__init__()
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_kv_heads = config.num_key_value_heads
self.num_kv_groups = self.num_heads // self.num_kv_heads
self.scale = 1.0 / math.sqrt(self.head_dim)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = ShivikM4RotaryEmbedding(
self.head_dim, config.max_position_embeddings, config.rope_theta
)
def forward(
self,
hidden_states,
attention_mask=None,
position_ids=None,
past_key_value=None,
use_cache=False,
):
bsz, q_len, _ = hidden_states.size()
q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
k = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
v = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
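        # The rotary tables must cover past + current positions so cached and new
        # tokens share one absolute-position coordinate system.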
past_kv_len = 0
if past_key_value is not None and past_key_value[0] is not None:
past_kv_len = past_key_value[0].shape[2]
cos, sin = self.rotary_emb(v, seq_len=past_kv_len + q_len)
q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
if past_key_value is not None and past_key_value[0] is not None:
k = torch.cat([past_key_value[0], k], dim=2)
v = torch.cat([past_key_value[1], v], dim=2)
present_kv = (k, v) if use_cache else None
# GQA expansion (for MHA, num_kv_groups=1, so this is a no-op)
if self.num_kv_groups > 1:
k_expanded = k.repeat_interleave(self.num_kv_groups, dim=1)
v_expanded = v.repeat_interleave(self.num_kv_groups, dim=1)
else:
k_expanded = k
v_expanded = v
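        # Scaled dot-product attention with an additive mask; softmax runs in fp32 for stability.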
attn_weights = torch.matmul(q, k_expanded.transpose(2, 3)) * self.scale
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
attn_output = torch.matmul(attn_weights, v_expanded)
attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
return self.o_proj(attn_output), present_kv
class ShivikM4MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
def forward(self, x):
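        # SwiGLU: down_proj(silu(gate_proj(x)) * up_proj(x))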
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
class ShivikM4DecoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
self.input_layernorm = ShivikM4RMSNorm(config.hidden_size, config.rms_norm_eps)
self.self_attn = ShivikM4Attention(config)
self.post_attention_layernorm = ShivikM4RMSNorm(config.hidden_size, config.rms_norm_eps)
self.mlp = ShivikM4MLP(config)
def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
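        # Pre-norm transformer block: RMSNorm -> attention -> residual, then RMSNorm -> SwiGLU MLP -> residual.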
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, present_kv = self.self_attn(
hidden_states, attention_mask, position_ids, past_key_value, use_cache
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states, present_kv
class ShivikM4Model(PreTrainedModel):
config_class = ShivikM4Config
def __init__(self, config):
super().__init__(config)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([ShivikM4DecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = ShivikM4RMSNorm(config.hidden_size, config.rms_norm_eps)
    def _make_causal_mask(self, q_len, kv_len, dtype, device):
        # Additive mask: 0 where attention is allowed, dtype-min where it is blocked.
        # Queries may attend to every cached position and to the causal prefix of the new block.
        mask = torch.zeros((q_len, kv_len), dtype=dtype, device=device)
        if q_len > 1:
            block = torch.full((q_len, q_len), torch.finfo(dtype).min, dtype=dtype, device=device)
            mask[:, kv_len - q_len:] = torch.triu(block, diagonal=1)
        return mask[None, None, :, :]
def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, use_cache=None):
bsz, seq_len = input_ids.shape
past_len = 0
if past_key_values is not None and past_key_values[0] is not None and past_key_values[0][0] is not None:
past_len = past_key_values[0][0].shape[2]
if position_ids is None:
position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device).unsqueeze(0)
hidden_states = self.embed_tokens(input_ids)
kv_len = past_len + seq_len
causal_mask = self._make_causal_mask(seq_len, kv_len, hidden_states.dtype, hidden_states.device)
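        # Fold the (batch, kv_len) padding mask into the additive attention mask: padded positions get dtype-min.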
if attention_mask is not None:
padding_mask = (1.0 - attention_mask[:, None, None, :].to(hidden_states.dtype)) * torch.finfo(hidden_states.dtype).min
causal_mask = causal_mask + padding_mask
next_cache = () if use_cache else None
for i, layer in enumerate(self.layers):
past_kv = past_key_values[i] if past_key_values is not None else None
hidden_states, present_kv = layer(hidden_states, causal_mask, position_ids, past_kv, use_cache)
if use_cache:
next_cache += (present_kv,)
hidden_states = self.norm(hidden_states)
return hidden_states, next_cache
class ShivikM4ForCausalLM(PreTrainedModel, GenerationMixin):
config_class = ShivikM4Config
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = ShivikM4Model(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward(
self,
input_ids,
attention_mask=None,
position_ids=None,
past_key_values=None,
use_cache=None,
labels=None,
**kwargs,
):
outputs = self.model(input_ids, attention_mask, position_ids, past_key_values, use_cache)
hidden_states, past_key_values = outputs
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
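            # Next-token prediction: logits at position t are scored against the label at t+1
            # (F.cross_entropy ignores positions labelled -100 by default).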
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, self.config.vocab_size),
shift_labels.view(-1),
)
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=past_key_values,
)
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
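        # When a KV cache is present, only the newest token is fed; its position id continues from the cached length.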
past_len = 0
if past_key_values is not None and past_key_values[0] is not None and past_key_values[0][0] is not None:
past_len = past_key_values[0][0].shape[2]
input_ids = input_ids[:, -1:]
position_ids = torch.arange(
past_len, past_len + input_ids.shape[1],
dtype=torch.long, device=input_ids.device
).unsqueeze(0)
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache", True),
"position_ids": position_ids,
"attention_mask": attention_mask,
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
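        # Beam search reshuffles hypotheses, so every cached (k, v) tensor must be reindexed along the batch dimension.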
reordered = ()
for layer_past in past_key_values:
reordered += (
tuple(state.index_select(0, beam_idx) for state in layer_past),
)
return reordered
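

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original upload). It instantiates
# a deliberately tiny, randomly initialised configuration -- the sizes below
# are illustrative assumptions, not SmolLM2 values -- and exercises the loss
# path plus one step of manual KV-cache decoding. Real checkpoints should be
# loaded with the default ShivikM4Config via from_pretrained().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    tiny_config = ShivikM4Config(
        vocab_size=1000,
        hidden_size=128,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        head_dim=32,
        max_position_embeddings=128,
    )
    model = ShivikM4ForCausalLM(tiny_config).eval()
    input_ids = torch.randint(0, tiny_config.vocab_size, (1, 8))

    with torch.no_grad():
        # Training-style call: labels trigger the shifted cross-entropy loss.
        out = model(input_ids, labels=input_ids, use_cache=True)
        print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))

        # One step of incremental decoding through the tuple KV cache.
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        step = model(next_token, past_key_values=out.past_key_values, use_cache=True)
        print("cached step logits:", tuple(step.logits.shape))
    # model.generate(...) is also available through GenerationMixin.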