Spaces:
Runtime error
Runtime error
File size: 7,603 Bytes
bd21ba5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
"""
MiniMind Max2 Main Model
Complete implementation of the Max2 language model.
"""
from typing import List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from configs.model_config import Max2Config, get_config
from .components import Max2DecoderLayer, Max2RMSNorm
class Max2Model(nn.Module):
    """Max2 Transformer decoder stack - outputs raw hidden states.

    Embeds ``input_ids``, runs them through ``config.num_hidden_layers``
    ``Max2DecoderLayer`` modules with an additive causal (plus optional
    padding) mask, and applies a final ``Max2RMSNorm``. No LM head is attached.
    """

    def __init__(self, config: Max2Config):
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=self.padding_idx)
        self.layers = nn.ModuleList(
            [Max2DecoderLayer(config, i) for i in range(config.num_hidden_layers)]
        )
        self.norm = Max2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # NOTE(review): stored but never read in this class; presumably consulted
        # by training code elsewhere - confirm before relying on it.
        self.gradient_checkpointing = False
        self._init_weights()

    def _init_weights(self):
        """Initialise every Linear/Embedding weight in the stack from N(0, initializer_range).

        Linear biases are zeroed. For embeddings with a ``padding_idx`` the
        padding row is zeroed again afterwards, because ``normal_`` overwrites
        the all-zero row that ``nn.Embedding`` reserves for padding.
        """
        std = self.config.initializer_range
        for module in self.modules():
            if isinstance(module, nn.Linear):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.bias is not None:
                    module.bias.data.zero_()
            elif isinstance(module, nn.Embedding):
                module.weight.data.normal_(mean=0.0, std=std)
                if module.padding_idx is not None:
                    module.weight.data[module.padding_idx].zero_()

    def _make_causal_mask(self, seq_len: int, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
        """Return a (1, 1, seq_len, seq_len) additive mask: 0 on/below the diagonal, -inf above it."""
        mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
        mask = torch.triu(mask, diagonal=1)
        return mask.unsqueeze(0).unsqueeze(0)

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]], torch.Tensor]:
        """Run the decoder stack.

        Args:
            input_ids: token ids of shape (batch, seq_len).
            attention_mask: optional (batch, seq_len) mask; entries equal to 0
                mark padding keys that must not be attended to, any other
                value means "attend".
            past_key_values: optional per-layer cache entries, as returned in
                ``next_cache`` by a previous call; entry ``idx`` is handed to
                layer ``idx``.
            use_cache: when True, collect each layer's ``present_kv``.

        Returns:
            ``(hidden_states, next_cache, total_aux_loss)``. ``hidden_states``
            is the final-normalised output. ``next_cache`` is a per-layer list,
            or None when ``use_cache`` is False. ``total_aux_loss`` is the sum
            of the per-layer auxiliary losses.
        """
        _, seq_len = input_ids.shape
        hidden_states = self.embed_tokens(input_ids)
        # NOTE(review): the mask is (seq_len x seq_len) and ignores any cached
        # keys. This looks fine for single-token decoding (a 1x1 zero mask), but
        # a multi-token call with a non-empty cache, or an attention_mask that
        # spans cached positions, will not line up. Confirm against
        # Max2DecoderLayer before using those paths.
        causal_mask = self._make_causal_mask(seq_len, hidden_states.dtype, hidden_states.device)
        if attention_mask is not None:
            # Fill -inf only where keys are padding. The former
            # `(1 - mask) * -inf` computed 0 * -inf = NaN for every real token.
            key_is_padding = (attention_mask[:, None, None, :] == 0).to(hidden_states.device)
            padding_mask = torch.zeros_like(key_is_padding, dtype=hidden_states.dtype)
            padding_mask.masked_fill_(key_is_padding, float("-inf"))
            # NOTE(review): a query row whose visible keys are all padding (e.g.
            # leading positions under left padding) becomes all -inf. A plain
            # softmax then yields NaN for that row, unless the layer guards
            # against it - confirm.
            causal_mask = causal_mask + padding_mask

        next_cache = [] if use_cache else None
        total_aux_loss = torch.tensor(0.0, device=hidden_states.device)
        for idx, layer in enumerate(self.layers):
            past_kv = past_key_values[idx] if past_key_values else None
            hidden_states, present_kv, aux_loss = layer(hidden_states, causal_mask, past_kv, use_cache)
            if use_cache:
                next_cache.append(present_kv)
            total_aux_loss = total_aux_loss + aux_loss

        hidden_states = self.norm(hidden_states)
        return hidden_states, next_cache, total_aux_loss
class Max2ForCausalLM(nn.Module):
    """Max2 Model with Language Modeling head for text generation.

    The LM head shares (ties) its weight matrix with the input embedding
    table of the wrapped ``Max2Model``.
    """

    def __init__(self, config: Max2Config):
        super().__init__()
        self.config = config
        self.model = Max2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Weight tying: output projection reuses the embedding matrix.
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor, Optional[List], torch.Tensor]:
        """Compute logits and, if ``labels`` is given, the training loss.

        Returns:
            ``(loss, logits, next_cache, aux_loss)``.
            - ``logits`` is cast to float32.
            - ``loss`` is None without labels. Otherwise it is the next-token
              cross-entropy (labels shifted left by one; CrossEntropyLoss's
              default ignore_index of -100 applies) plus ``aux_loss``.
        """
        hidden_states, next_cache, aux_loss = self.model(input_ids, attention_mask, past_key_values, use_cache)
        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            # Position t predicts token t+1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = CrossEntropyLoss()(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
            loss = loss + aux_loss
        return loss, logits, next_cache, aux_loss

    @staticmethod
    def _filter_logits(logits: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
        """Return a copy of 2-D ``logits`` with tokens outside top-k / nucleus(top-p) set to -inf.

        ``top_k <= 0`` disables top-k; ``top_k`` larger than the vocabulary is
        clamped. ``top_p >= 1.0`` disables nucleus filtering. The most likely
        token is never removed.
        """
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            kth_largest = torch.topk(logits, k, dim=-1).values[..., -1:]
            logits = logits.masked_fill(logits < kth_largest, float("-inf"))
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            sorted_to_remove = cumulative_probs > top_p
            # Shift right so the token that crosses the threshold is kept,
            # and always keep the single most likely token.
            sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
            sorted_to_remove[..., 0] = False
            to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
            logits = logits.masked_fill(to_remove, float("-inf"))
        return logits

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        top_p: float = 0.95,
        do_sample: bool = True,
    ) -> torch.LongTensor:
        """Simple autoregressive generation with top-k/top-p sampling or greedy decoding.

        The first step feeds the whole prompt; later steps feed only the last
        token together with the returned KV cache.

        Args:
            input_ids: prompt token ids of shape (batch, prompt_len).
            max_new_tokens: upper bound on the number of appended tokens.
            temperature: logit divisor used only when sampling; must be > 0 then.
            top_k: keep only the k most likely tokens when sampling (<= 0 disables).
            top_p: nucleus threshold when sampling (>= 1.0 disables).
            do_sample: sample from the filtered distribution if True, else greedy argmax.

        Returns:
            ``input_ids`` with the generated tokens appended along dim 1.
            Once a row emits ``config.eos_token_id``, its later positions are
            filled with ``config.pad_token_id`` (or the EOS id when no pad id
            is set). Generation stops early when every row has finished.

        Raises:
            ValueError: if ``do_sample`` is True and ``temperature <= 0``.
        """
        if do_sample and temperature <= 0:
            raise ValueError(f"temperature must be > 0 when do_sample=True, got {temperature}")

        eos_token_id = getattr(self.config, "eos_token_id", None)
        pad_token_id = getattr(self.config, "pad_token_id", None)
        fill_token_id = pad_token_id if pad_token_id is not None else eos_token_id

        generated = input_ids
        past_key_values = None
        unfinished = torch.ones(input_ids.shape[0], 1, dtype=torch.bool, device=input_ids.device)
        for _ in range(max_new_tokens):
            model_input = generated if past_key_values is None else generated[:, -1:]
            _, logits, past_key_values, _ = self(model_input, past_key_values=past_key_values, use_cache=True)
            next_token_logits = logits[:, -1, :]

            if do_sample:
                filtered = self._filter_logits(next_token_logits / temperature, top_k, top_p)
                probs = F.softmax(filtered, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            if eos_token_id is not None:
                # Rows that already finished keep emitting filler instead of new text.
                unfinished = unfinished.to(next_token.device)
                next_token = torch.where(unfinished, next_token, torch.full_like(next_token, fill_token_id))
                unfinished = unfinished & (next_token != eos_token_id)

            generated = torch.cat([generated, next_token], dim=1)
            if eos_token_id is not None and not unfinished.any():
                break
        return generated
# Legacy "Mind2*" names kept as aliases so existing imports keep working.
Mind2Model, Mind2ForCausalLM = Max2Model, Max2ForCausalLM
def create_model(model_name: str = "max2-lite", device: str = "cuda", dtype: torch.dtype = torch.float16) -> Max2ForCausalLM:
    """Factory function to create a Max2 model from a named preset config.

    Args:
        model_name: preset name passed to ``get_config``.
        device: target device (e.g. "cuda", "cuda:1", "cpu").
        dtype: parameter dtype to cast the model to.

    Returns:
        A ``Max2ForCausalLM`` on ``device`` in ``dtype``. If a CUDA device is
        requested but CUDA is unavailable, the model is returned unmoved and in
        the dtype it was constructed with, as before. Non-CUDA devices are
        always honoured together with the requested dtype.
    """
    config = get_config(model_name)
    model = Max2ForCausalLM(config)
    target = torch.device(device)
    if target.type == "cuda" and not torch.cuda.is_available():
        return model
    return model.to(device=target, dtype=dtype)
if __name__ == "__main__":
    # Smoke test: build each preset, report its parameter count, and run one
    # forward pass with the random input doubling as its own labels.
    presets = ("max2-nano", "max2-lite", "max2-pro")
    for model_name in presets:
        print(f"\n{'='*50}\nTesting {model_name}\n{'='*50}")
        cfg = get_config(model_name)
        lm = Max2ForCausalLM(cfg)
        n_params = sum(param.numel() for param in lm.parameters())
        print(f"Total Parameters: {n_params / 1e9:.3f}B")

        tokens = torch.randint(0, cfg.vocab_size, (2, 128))
        lm.eval()
        with torch.no_grad():
            loss, logits, _, aux_loss = lm(tokens, labels=tokens)

        print(f"Logits shape: {logits.shape}")
        print(f"Loss: {loss:.4f}, Aux loss: {aux_loss:.6f}")
        print("Forward pass successful!")
|