"""
MiniMind Max2 Model Components
Core building blocks: RMSNorm, RoPE, GQA Attention, MoE
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from configs.model_config import Max2Config


class Max2RMSNorm(nn.Module):
"""Root Mean Square Layer Normalization (faster than LayerNorm)."""
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
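        # Upcast to float32 for a stable variance computation, then cast back to the input dtype.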
input_dtype = x.dtype
x = x.to(torch.float32)
variance = x.pow(2).mean(-1, keepdim=True)
x = x * torch.rsqrt(variance + self.eps)
return self.weight * x.to(input_dtype)


class Max2RotaryEmbedding(nn.Module):
"""Rotary Position Embedding (RoPE) for efficient position encoding."""
def __init__(self, dim: int, max_position_embeddings: int = 8192, base: float = 10000.0):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
self._set_cos_sin_cache(max_position_embeddings)
def _set_cos_sin_cache(self, seq_len: int):
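        # Precompute cos/sin tables up to seq_len; extended lazily if a longer sequence appears.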
self.max_seq_len_cached = seq_len
        t = torch.arange(seq_len, dtype=torch.float32, device=self.inv_freq.device)
freqs = torch.outer(t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos(), persistent=False)
self.register_buffer("sin_cached", emb.sin(), persistent=False)
def forward(self, x: torch.Tensor, seq_len: int) -> Tuple[torch.Tensor, torch.Tensor]:
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len)
return self.cos_cached[:seq_len].to(x.dtype), self.sin_cached[:seq_len].to(x.dtype)


def rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate half the hidden dims of the input."""
x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply rotary position embeddings to query and key tensors."""
cos = cos.unsqueeze(0).unsqueeze(0)
sin = sin.unsqueeze(0).unsqueeze(0)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed


class Max2Attention(nn.Module):
"""Grouped Query Attention (GQA) - fewer KV heads than Q heads for memory efficiency."""
def __init__(self, config: Max2Config, layer_idx: int):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.num_kv_heads = config.num_key_value_heads
self.head_dim = self.hidden_size // self.num_heads
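        # With GQA, each KV head is shared across (num_heads // num_kv_heads) query heads.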
self.num_key_value_groups = self.num_heads // self.num_kv_heads
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
self.rotary_emb = Max2RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)
self.attention_dropout = config.attention_dropout
def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
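        """Repeat each KV head n_rep times so K/V match the number of query heads."""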
if n_rep == 1:
return hidden_states
bs, num_kv_heads, seq_len, head_dim = hidden_states.shape
hidden_states = hidden_states[:, :, None, :, :].expand(bs, num_kv_heads, n_rep, seq_len, head_dim)
return hidden_states.reshape(bs, num_kv_heads * n_rep, seq_len, head_dim)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
batch_size, seq_len, _ = hidden_states.shape
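        # Project and reshape: Q -> (batch, num_heads, seq, head_dim), K/V -> (batch, num_kv_heads, seq, head_dim).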
query_states = self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        # Offset positions by the cached prefix length so incremental decoding rotates new tokens at their absolute positions.
        past_len = past_key_value[0].shape[2] if past_key_value is not None else 0
        cos, sin = self.rotary_emb(value_states, past_len + seq_len)
        cos, sin = cos[past_len:], sin[past_len:]
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
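        # Expand KV heads to match the number of query heads for the attention matmul.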
key_states = self._repeat_kv(key_states, self.num_key_value_groups)
value_states = self._repeat_kv(value_states, self.num_key_value_groups)
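        # Scaled dot-product attention; the mask (if given) is added to the logits.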
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
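        # Softmax in float32 for numerical stability, then cast back to the query dtype.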
attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, past_key_value


class Max2MLP(nn.Module):
"""SwiGLU Feed-Forward Network."""
def __init__(self, hidden_size: int, intermediate_size: int):
super().__init__()
self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
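        # SwiGLU: silu(gate(x)) * up(x), then project back to hidden_size.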
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class Max2Expert(nn.Module):
"""Single expert in the Mixture of Experts layer."""
def __init__(self, hidden_size: int, expert_hidden_size: int):
super().__init__()
self.mlp = Max2MLP(hidden_size, expert_hidden_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.mlp(x)


class Max2MoE(nn.Module):
"""
Mixture of Experts (MoE) layer.
Efficient parameter activation - only top-k experts are used per token.
Inspired by MiniMax M2's efficient activated parameters design.
"""
def __init__(self, config: Max2Config):
super().__init__()
self.hidden_size = config.hidden_size
self.num_experts = config.num_experts
self.num_experts_per_tok = config.num_experts_per_tok
self.expert_hidden_size = config.expert_hidden_size
self.gate = nn.Linear(self.hidden_size, self.num_experts, bias=False)
self.experts = nn.ModuleList([
Max2Expert(self.hidden_size, self.expert_hidden_size)
for _ in range(self.num_experts)
])
self.router_aux_loss_coef = config.router_aux_loss_coef
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, seq_len, hidden_dim = hidden_states.shape
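        # Flatten tokens so routing operates on shape (batch * seq_len, hidden_dim).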
hidden_states_flat = hidden_states.view(-1, hidden_dim)
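        # Router: softmax over experts, keep the top-k per token, and renormalize the kept weights.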
router_logits = self.gate(hidden_states_flat)
router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)
router_weights, selected_experts = torch.topk(router_probs, self.num_experts_per_tok, dim=-1)
router_weights = router_weights.to(hidden_states.dtype)
router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
final_hidden_states = torch.zeros_like(hidden_states_flat)
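        # expert_mask[expert, k, token] == 1 when `expert` is the k-th choice for `token`.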
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
for expert_idx in range(self.num_experts):
expert = self.experts[expert_idx]
for top_k_idx in range(self.num_experts_per_tok):
token_indices = expert_mask[expert_idx, top_k_idx].nonzero(as_tuple=True)[0]
if token_indices.numel() > 0:
expert_input = hidden_states_flat[token_indices]
expert_output = expert(expert_input)
weights = router_weights[token_indices, top_k_idx].unsqueeze(-1)
final_hidden_states[token_indices] += weights * expert_output
final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
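        # Switch-Transformer-style load-balancing loss: penalizes uneven token-to-expert assignment.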
num_tokens = router_probs.shape[0]
expert_mask_float = F.one_hot(selected_experts, num_classes=self.num_experts).float()
tokens_per_expert = expert_mask_float.sum(dim=(0, 1)) / num_tokens
router_prob_per_expert = router_probs.mean(dim=0)
aux_loss = self.num_experts * (tokens_per_expert * router_prob_per_expert).sum() * self.router_aux_loss_coef
return final_hidden_states, aux_loss


class Max2DecoderLayer(nn.Module):
"""Single transformer decoder layer with GQA attention and MoE FFN."""
def __init__(self, config: Max2Config, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = Max2Attention(config, layer_idx)
if config.use_moe:
self.mlp = Max2MoE(config)
self.use_moe = True
else:
self.mlp = Max2MLP(config.hidden_size, config.intermediate_size)
self.use_moe = False
self.input_layernorm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], torch.Tensor]:
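        # Pre-norm residual block: RMSNorm -> self-attention -> add, then RMSNorm -> FFN/MoE -> add.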
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states, present_key_value = self.self_attn(hidden_states, attention_mask, past_key_value, use_cache)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
if self.use_moe:
hidden_states, aux_loss = self.mlp(hidden_states)
else:
hidden_states = self.mlp(hidden_states)
aux_loss = torch.tensor(0.0, device=hidden_states.device)
hidden_states = residual + hidden_states
return hidden_states, present_key_value, aux_loss


# Backward compatibility aliases
Mind2RMSNorm = Max2RMSNorm
Mind2RotaryEmbedding = Max2RotaryEmbedding
Mind2Attention = Max2Attention
Mind2MLP = Max2MLP
Mind2Expert = Max2Expert
Mind2MoE = Max2MoE
Mind2DecoderLayer = Max2DecoderLayer