| """ |
| SHIVIK-M4 Model Architecture (SmolLM2-Compatible) |
| ================================================== |
| Matched to SmolLM2-1.7B for weight loading: |
| - 24 layers, 2048 hidden, 32 heads (MHA - all heads are KV heads) |
| - Full RoPE, SwiGLU MLP, RMSNorm |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.generation import GenerationMixin |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
class ShivikM4Config(PretrainedConfig):
    """Configuration container for SHIVIK-M4 models.

    Default values describe a 24-layer decoder of width 2048 with 32 query heads
    and 32 key/value heads (so every query head has its own KV head), 64-dim
    heads, an 8192-wide MLP, a 49152-entry vocabulary and tied input/output
    embeddings. Any extra keyword arguments are forwarded to
    ``PretrainedConfig``.
    """

    model_type = "shivik_m4"

    def __init__(
        self,
        vocab_size=49152,
        hidden_size=2048,
        intermediate_size=8192,
        num_hidden_layers=24,
        num_attention_heads=32,
        num_key_value_heads=32,
        head_dim=64,
        rms_norm_eps=1e-5,
        max_position_embeddings=4096,
        rope_theta=100000.0,
        tie_word_embeddings=True,
        **kwargs,
    ):
        # Embedding / trunk sizes.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers

        # Attention geometry (num_key_value_heads < num_attention_heads => grouped KV).
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = head_dim

        # Normalisation and rotary position settings.
        self.rms_norm_eps = rms_norm_eps
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta

        # Base class handles tying flags and all generic HF config fields.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
|
|
|
class ShivikM4RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    The mean of squares is computed in float32 whatever the input dtype; the
    scaled result is cast back to the input dtype.

    Args:
        dim: Size of the last dimension (length of the learned scale vector).
        eps: Constant added to the mean of squares before the inverse sqrt.
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        original_dtype = x.dtype
        x_fp32 = x.float()
        inv_rms = x_fp32.pow(2).mean(-1, keepdim=True).add(self.eps).rsqrt()
        scaled = self.weight * (x_fp32 * inv_rms)
        return scaled.to(original_dtype)
|
|
|
|
class ShivikM4RotaryEmbedding(nn.Module):
    """Cache of rotary position embedding (RoPE) cos/sin tables.

    Args:
        dim: Rotary dimension (the attention head dimension in this file).
            Expected to be even: the tables have ``2 * ceil(dim / 2)`` columns.
        max_position_embeddings: Number of positions precomputed at construction.
        base: RoPE frequency base (``rope_theta``).
    """

    def __init__(self, dim, max_position_embeddings, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.max_seq_len_cached = max_position_embeddings
        self._set_cos_sin_cache(max_position_embeddings)

    def _set_cos_sin_cache(self, seq_len):
        """(Re)build the cos/sin tables for positions ``0 .. seq_len - 1``.

        The angles are always computed in float32. ``nn.Module.to(dtype)`` also
        casts floating-point buffers, so after e.g. ``model.to(torch.bfloat16)``
        ``inv_freq`` is bf16. A bf16 position grid cannot represent large
        integers exactly (599 rounds to 600), which would silently corrupt the
        rotation angles whenever the cache is regrown for long sequences.
        """
        self.max_seq_len_cached = seq_len
        inv_freq = self.inv_freq.float()
        t = torch.arange(seq_len, device=inv_freq.device, dtype=torch.float32)
        freqs = torch.outer(t, inv_freq)
        # Duplicate the frequencies so the table width matches rotate_half's
        # two-halves layout.
        emb = torch.cat([freqs, freqs], dim=-1)
        self.register_buffer("cos_cached", emb.cos().unsqueeze(0).unsqueeze(0), persistent=False)
        self.register_buffer("sin_cached", emb.sin().unsqueeze(0).unsqueeze(0), persistent=False)

    def forward(self, x, seq_len):
        """Return ``(cos, sin)`` tables covering the first ``seq_len`` positions.

        Args:
            x: Any tensor; only its dtype is used, to cast the returned tables.
            seq_len: Number of positions needed; the cache grows if necessary.

        Returns:
            Two tensors of shape ``(1, 1, seq_len, 2 * len(inv_freq))`` in
            ``x.dtype``.
        """
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len)
        return (
            self.cos_cached[:, :, :seq_len, :].to(x.dtype),
            self.sin_cached[:, :, :seq_len, :].to(x.dtype),
        )
|
|
|
|
def rotate_half(x):
    """Split the last dimension into two chunks ``(a, b)`` and return ``cat(-b, a)``.

    Chunking follows ``torch.chunk(..., 2)`` semantics, so for an odd size the
    first chunk is the larger one.
    """
    first, second = torch.chunk(x, 2, dim=-1)
    return torch.cat([second.neg(), first], dim=-1)
|
|
|
|
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    """Rotate query and key tensors by the RoPE angles of ``position_ids``.

    Args:
        q, k: Query/key tensors. In ShivikM4Attention these are shaped
            ``(batch, heads, seq, head_dim)``.
        cos, sin: Tables with two leading singleton dims, as returned by
            ShivikM4RotaryEmbedding (``(1, 1, max_seq, head_dim)``).
        position_ids: Integer tensor used to gather rows of the tables
            (``(batch, seq)`` in this file).

    Returns:
        ``(q_rotated, k_rotated)`` with the same shapes as ``q`` and ``k``.
    """

    def gather_rows(table):
        # Drop the two leading singleton dims, pick the rows for each position,
        # then insert a heads axis so the result broadcasts over all heads.
        return table.squeeze(0).squeeze(0)[position_ids].unsqueeze(1)

    cos_pos = gather_rows(cos)
    sin_pos = gather_rows(sin)
    q_rotated = (q * cos_pos) + (rotate_half(q) * sin_pos)
    k_rotated = (k * cos_pos) + (rotate_half(k) * sin_pos)
    return q_rotated, k_rotated
|
|
|
|
class ShivikM4Attention(nn.Module):
    """Causal self-attention with rotary embeddings and optional grouped KV heads.

    When ``num_key_value_heads == num_attention_heads`` this is standard
    multi-head attention. Otherwise each KV head is shared by
    ``num_attention_heads // num_key_value_heads`` query heads.

    Raises:
        ValueError: If ``num_attention_heads`` is not a multiple of
            ``num_key_value_heads``.
    """

    def __init__(self, config: ShivikM4Config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_kv_heads = config.num_key_value_heads
        if self.num_heads % self.num_kv_heads != 0:
            # Fail at construction instead of with an opaque matmul shape error
            # after repeat_interleave produces the wrong number of heads.
            raise ValueError(
                f"num_attention_heads ({self.num_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_kv_heads})"
            )
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = ShivikM4RotaryEmbedding(
            self.head_dim, config.max_position_embeddings, config.rope_theta
        )

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        use_cache=False,
    ):
        """Attend over cached plus current keys/values.

        Args:
            hidden_states: ``(batch, q_len, hidden_size)`` input.
            attention_mask: Optional additive mask that broadcasts to
                ``(batch, heads, q_len, kv_len)``.
            position_ids: Positions of the ``q_len`` new tokens. If ``None``,
                they default to ``past_len .. past_len + q_len - 1``.
            past_key_value: Optional ``(k, v)`` cache with the sequence on dim 2.
            use_cache: If true, return the updated ``(k, v)`` cache.

        Returns:
            ``(output, present_kv)``. ``output`` has shape
            ``(batch, q_len, hidden_size)``; ``present_kv`` is ``None`` unless
            ``use_cache``.
        """
        bsz, q_len, _ = hidden_states.size()

        # (batch, heads, q_len, head_dim)
        q = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, q_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        past_kv_len = 0
        if past_key_value is not None and past_key_value[0] is not None:
            past_kv_len = past_key_value[0].shape[2]

        if position_ids is None:
            # New tokens continue directly after the cached prefix.
            position_ids = torch.arange(
                past_kv_len, past_kv_len + q_len, device=hidden_states.device
            ).unsqueeze(0)

        # Rotate only the new q/k; cached keys were already rotated when stored.
        cos, sin = self.rotary_emb(v, seq_len=past_kv_len + q_len)
        q, k = apply_rotary_pos_emb(q, k, cos, sin, position_ids)

        if past_key_value is not None and past_key_value[0] is not None:
            k = torch.cat([past_key_value[0], k], dim=2)
            v = torch.cat([past_key_value[1], v], dim=2)

        # Cache the un-expanded KV heads to keep the cache small under GQA.
        present_kv = (k, v) if use_cache else None

        if self.num_kv_groups > 1:
            k_expanded = k.repeat_interleave(self.num_kv_groups, dim=1)
            v_expanded = v.repeat_interleave(self.num_kv_groups, dim=1)
        else:
            k_expanded = k
            v_expanded = v

        attn_weights = torch.matmul(q, k_expanded.transpose(2, 3)) * self.scale

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax in float32 for stability, then back to the activation dtype.
        attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype)
        attn_output = torch.matmul(attn_weights, v_expanded)

        # o_proj consumes num_heads * head_dim features, which need not equal
        # hidden_size.
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            bsz, q_len, self.num_heads * self.head_dim
        )
        return self.o_proj(attn_output), present_kv
|
|
|
|
class ShivikM4MLP(nn.Module):
    """Gated SiLU feed-forward block computing ``down(silu(gate(x)) * up(x))``.

    ``config`` must provide ``hidden_size`` and ``intermediate_size``. All
    projections are bias-free.
    """

    def __init__(self, config):
        super().__init__()
        d_model = config.hidden_size
        d_inner = config.intermediate_size
        self.gate_proj = nn.Linear(d_model, d_inner, bias=False)
        self.up_proj = nn.Linear(d_model, d_inner, bias=False)
        self.down_proj = nn.Linear(d_inner, d_model, bias=False)

    def forward(self, x):
        gate = F.silu(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)
|
|
|
|
class ShivikM4DecoderLayer(nn.Module):
    """Pre-norm decoder block: RMSNorm -> self-attention -> residual, then
    RMSNorm -> MLP -> residual.

    ``forward`` returns ``(hidden_states, present_kv)``, where ``present_kv`` is
    whatever the attention sub-module returned for the cache.
    """

    def __init__(self, config):
        super().__init__()
        self.input_layernorm = ShivikM4RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.self_attn = ShivikM4Attention(config)
        self.post_attention_layernorm = ShivikM4RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.mlp = ShivikM4MLP(config)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, past_key_value=None, use_cache=False):
        attn_out, present_kv = self.self_attn(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        after_attn = hidden_states + attn_out

        mlp_out = self.mlp(self.post_attention_layernorm(after_attn))
        return after_attn + mlp_out, present_kv
|
|
|
|
class ShivikM4Model(PreTrainedModel):
    """Decoder trunk: token embedding, stack of decoder layers, final RMSNorm."""

    config_class = ShivikM4Config

    def __init__(self, config):
        super().__init__(config)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([ShivikM4DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = ShivikM4RMSNorm(config.hidden_size, config.rms_norm_eps)

    def _make_causal_mask(self, q_len, kv_len, dtype, device):
        """Build an additive causal mask of shape ``(1, 1, q_len, kv_len)``.

        The ``q_len`` queries are the last ``q_len`` of the ``kv_len`` positions,
        and the first ``kv_len - q_len`` keys come from the KV cache. Query ``i``
        may attend to keys ``0 .. kv_len - q_len + i``; later keys get
        ``finfo(dtype).min``.

        This single rule covers full prefill (``q_len == kv_len``), one-token
        decoding (``q_len == 1``, nothing masked) and multi-token steps on top
        of a cache, which previously got no causal masking at all.
        """
        past_len = kv_len - q_len
        mask = torch.full((q_len, kv_len), torch.finfo(dtype).min, dtype=dtype, device=device)
        mask = torch.triu(mask, diagonal=past_len + 1)
        return mask[None, None, :, :]

    def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, use_cache=None):
        """Run the decoder stack.

        Args:
            input_ids: ``(batch, seq_len)`` token ids.
            attention_mask: Optional ``(batch, past_len + seq_len)`` mask in
                which ``0`` marks padded keys.
            position_ids: Optional positions. If ``None``, they default to
                ``past_len .. past_len + seq_len - 1``.
            past_key_values: Optional per-layer ``(k, v)`` tuples with the
                sequence on dim 2.
            use_cache: If true, collect and return each layer's updated cache.

        Returns:
            ``(hidden_states, next_cache)``. ``next_cache`` is a tuple of
            per-layer caches, or ``None`` when ``use_cache`` is falsy.
        """
        bsz, seq_len = input_ids.shape

        past_len = 0
        if past_key_values is not None and past_key_values[0] is not None and past_key_values[0][0] is not None:
            past_len = past_key_values[0][0].shape[2]

        if position_ids is None:
            position_ids = torch.arange(past_len, past_len + seq_len, device=input_ids.device).unsqueeze(0)

        hidden_states = self.embed_tokens(input_ids)

        kv_len = past_len + seq_len
        causal_mask = self._make_causal_mask(seq_len, kv_len, hidden_states.dtype, hidden_states.device)

        if attention_mask is not None:
            # Mask padded keys with masked_fill rather than adding a second
            # finfo.min term. Adding two finfo.min values overflows to -inf.
            is_pad = attention_mask[:, None, None, :] == 0
            causal_mask = causal_mask.expand(bsz, 1, seq_len, kv_len).masked_fill(
                is_pad, torch.finfo(hidden_states.dtype).min
            )

        next_cache = () if use_cache else None
        for i, layer in enumerate(self.layers):
            past_kv = past_key_values[i] if past_key_values is not None else None
            hidden_states, present_kv = layer(hidden_states, causal_mask, position_ids, past_kv, use_cache)
            if use_cache:
                next_cache += (present_kv,)

        hidden_states = self.norm(hidden_states)
        return hidden_states, next_cache
|
|
|
|
class ShivikM4ForCausalLM(PreTrainedModel, GenerationMixin):
    """SHIVIK-M4 decoder with a language-modelling head.

    When ``config.tie_word_embeddings`` is set, ``lm_head`` shares its weight
    with the input embedding.
    """

    config_class = ShivikM4Config
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = ShivikM4Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=None,
        labels=None,
        **kwargs,
    ):
        """Compute next-token logits and, optionally, the LM loss.

        Args:
            labels: Optional targets aligned with ``input_ids``. They are shifted
                internally, so position ``t`` is scored against ``labels[t + 1]``.
                ``-100`` is ignored (``F.cross_entropy`` default ``ignore_index``).
            **kwargs: Accepted for ``generate`` compatibility and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss`` (or ``None``), ``logits``
            and ``past_key_values``.
        """
        outputs = self.model(input_ids, attention_mask, position_ids, past_key_values, use_cache)
        hidden_states, past_key_values = outputs

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast so the loss is stable under bf16/fp16 activations. Flatten
            # by the real logits width: if it ever differed from
            # config.vocab_size, view() would otherwise silently mis-shape rows.
            shift_logits = logits[..., :-1, :].contiguous().float()
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=past_key_values,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
        """Trim inputs to the uncached suffix and compute matching position ids.

        When a cache is present, only the last token is fed. Positions continue
        from the cached length; left padding in ``attention_mask`` is not taken
        into account.
        """
        past_len = 0
        if past_key_values is not None and past_key_values[0] is not None and past_key_values[0][0] is not None:
            past_len = past_key_values[0][0].shape[2]
            input_ids = input_ids[:, -1:]

        position_ids = torch.arange(
            past_len, past_len + input_ids.shape[1],
            dtype=torch.long, device=input_ids.device
        ).unsqueeze(0)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Select cache rows (batch dim 0) according to ``beam_idx`` for beam search.

        ``beam_idx`` is moved to each state's device, because ``index_select``
        requires the index and the source tensor to share a device.
        """
        reordered = ()
        for layer_past in past_key_values:
            reordered += (
                tuple(state.index_select(0, beam_idx.to(state.device)) for state in layer_past),
            )
        return reordered
|
|