"""MiniMind Max2 Model for Transformers"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List, Union
from transformers import PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithPast
from .configuration_minimind import MiniMindConfig

class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm: no mean subtraction, no bias, learned scale only."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        # Normalize by the RMS over the last dimension, then rescale.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

class RotaryEmbedding(nn.Module):
    """Precomputes rotary inverse frequencies and returns cos/sin tables for given positions."""
    def __init__(self, dim, max_pos=32768, base=10000.0):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, pos_ids):
        # (L,) x (dim/2,) -> (L, dim/2), then duplicated to cover the full head dimension.
        freqs = torch.outer(pos_ids.float(), self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()

def rotate_half(x):
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q, k, cos, sin):
    # Broadcast cos/sin over batch and head dimensions: (L, d) -> (1, 1, L, d).
    cos, sin = cos.unsqueeze(0).unsqueeze(0), sin.unsqueeze(0).unsqueeze(0)
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)

class Attention(nn.Module):
    """Multi-head attention with grouped-query KV heads and rotary position embeddings."""
    def __init__(self, config, layer_idx):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.hidden_size // self.num_heads
        self.kv_groups = self.num_heads // self.num_kv_heads
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
        self.rotary = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, x, mask=None, pos_ids=None, past_kv=None, use_cache=False):
        B, L, _ = x.shape
        q = self.q_proj(x).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        if pos_ids is None:
            pos_ids = torch.arange(L, device=x.device)
        cos, sin = self.rotary(v, pos_ids)
        q, k = apply_rope(q, k, cos, sin)
        if past_kv is not None:
            # Append cached keys/values along the sequence dimension.
            k, v = torch.cat([past_kv[0], k], dim=2), torch.cat([past_kv[1], v], dim=2)
        new_kv = (k, v) if use_cache else None
        # Grouped-query attention: repeat each KV head to match the number of query heads.
        k = k.repeat_interleave(self.kv_groups, dim=1)
        v = v.repeat_interleave(self.kv_groups, dim=1)
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            attn = attn + mask
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, L, -1)
        return self.o_proj(out), new_kv

class Expert(nn.Module):
    """SwiGLU feed-forward expert; the intermediate width is split evenly across experts."""
    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.hidden_size, config.intermediate_size // config.num_experts, bias=False)
        self.up = nn.Linear(config.hidden_size, config.intermediate_size // config.num_experts, bias=False)
        self.down = nn.Linear(config.intermediate_size // config.num_experts, config.hidden_size, bias=False)

    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

class MoE(nn.Module):
    """Sparse mixture-of-experts MLP with per-token top-k routing."""
    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_token
        self.router = nn.Linear(config.hidden_size, self.num_experts, bias=False)
        self.experts = nn.ModuleList([Expert(config) for _ in range(self.num_experts)])

    def forward(self, x):
        B, L, D = x.shape
        x_flat = x.view(-1, D)
        logits = self.router(x_flat)
        weights = F.softmax(logits, dim=-1)
        # Keep the top-k experts per token and renormalize their routing weights.
        top_w, top_i = torch.topk(weights, self.top_k, dim=-1)
        top_w = top_w / top_w.sum(-1, keepdim=True)
        out = torch.zeros_like(x_flat)
        for i, exp in enumerate(self.experts):
            # Gather the tokens routed to expert i and accumulate its weighted output.
            mask = (top_i == i).any(-1)
            if mask.any():
                w = (top_w * (top_i == i).float()).sum(-1, keepdim=True)[mask]
                out[mask] += w * exp(x_flat[mask])
        # The auxiliary loss is a placeholder zero; no load-balancing term is computed here.
        return out.view(B, L, D), torch.tensor(0.0, device=x.device)

class DecoderLayer(nn.Module):
    """Pre-norm transformer block: attention followed by a mixture-of-experts MLP."""
    def __init__(self, config, idx):
        super().__init__()
        self.attn = Attention(config, idx)
        self.moe = MoE(config)
        self.norm1 = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.norm2 = RMSNorm(config.hidden_size, config.rms_norm_eps)

    def forward(self, x, mask=None, pos_ids=None, past_kv=None, use_cache=False):
        h, kv = self.attn(self.norm1(x), mask, pos_ids, past_kv, use_cache)
        x = x + h
        m, aux = self.moe(self.norm2(x))
        return x + m, kv, aux

class MiniMindPreTrainedModel(PreTrainedModel):
    config_class = MiniMindConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True

class MiniMindModel(MiniMindPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(config, i) for i in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_init()

    def forward(self, input_ids, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, **kwargs):
        B, L = input_ids.shape
        h = self.embed(input_ids)
        # Offset positions and the causal mask by the length of any cached prefix so that
        # incremental decoding uses the correct rotary positions.
        past_len = past_key_values[0][0].shape[2] if past_key_values else 0
        if position_ids is None:
            position_ids = torch.arange(past_len, past_len + L, device=h.device)
        mask = torch.triu(
            torch.full((L, past_len + L), float("-inf"), device=h.device), diagonal=past_len + 1
        ).unsqueeze(0).unsqueeze(0)
        cache = [] if use_cache else None
        aux = 0.0
        for i, layer in enumerate(self.layers):
            pkv = past_key_values[i] if past_key_values else None
            h, kv, a = layer(h, mask, position_ids, pkv, use_cache)
            if use_cache:
                cache.append(kv)
            aux += a
        return self.norm(h), cache, aux

class MiniMindForCausalLM(MiniMindPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = MiniMindModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed

    def get_output_embeddings(self):
        return self.lm_head

    def forward(self, input_ids=None, attention_mask=None, position_ids=None, past_key_values=None,
                labels=None, use_cache=None, return_dict=True, **kwargs):
        h, cache, aux = self.model(input_ids, attention_mask, position_ids, past_key_values, use_cache or False)
        logits = self.lm_head(h)
        loss = None
        if labels is not None:
            # Next-token prediction: drop the last logit and the first label, then flatten.
            loss = F.cross_entropy(logits[..., :-1, :].reshape(-1, logits.size(-1)), labels[..., 1:].reshape(-1))
        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=cache)

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
        # With a populated KV cache, only the most recent token needs to be fed to the model.
        if past_key_values:
            input_ids = input_ids[:, -1:]
        return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": True}