|
|
|
|
|
""" |
|
|
MiniMind Max2 Quick Start Example |
|
|
Demonstrates basic usage of the Max2 model. |
|
|
""" |
|
|
|
|
|
import sys


from pathlib import Path




# Make the repository root importable so the project-local packages
# (`configs`, `model`) resolve when this example is run directly,
# e.g. `python examples/quick_start.py`.



sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
|
import torch |
|
|
|
|
|
|
|
|
def main():
    """Run the MiniMind Max2 quick-start demo end to end.

    Builds the smallest ("max2-nano") configuration, prints parameter
    statistics, runs one teacher-forced forward pass and one sampling
    generation on random token ids, and reports peak GPU memory when
    CUDA is available.
    """
    print("=" * 60)
    print("MiniMind Max2 Quick Start")
    print("=" * 60)

    model, config, device = _create_model("max2-nano")
    _demo_forward_pass(model, config, device)
    _demo_generation(model, config, device)
    _report_gpu_memory(device)

    print("\n" + "=" * 60)
    print("Quick start complete!")
    print("=" * 60)

    print("\nNext steps:")
    print(" - Train: python scripts/train.py --model max2-lite --train-data your_data.jsonl")
    print(" - Export: python scripts/export.py --model max2-nano --format onnx gguf")
    print(" - See README.md for full documentation")


def _create_model(model_name):
    """Build *model_name*, print its size stats, and move it to the best device.

    Uses CUDA with float16 when available, otherwise CPU with float32.
    Returns ``(model, config, device)``.
    """
    # Imported lazily so the banner prints before any project-module
    # import error surfaces.
    from configs.model_config import get_config, estimate_params
    from model import Max2ForCausalLM

    print(f"\n1. Creating {model_name} model...")

    config = get_config(model_name)
    model = Max2ForCausalLM(config)

    params = estimate_params(config)
    print(f" Total parameters: {params['total_params_b']:.3f}B")
    print(f" Active parameters: {params['active_params_b']:.3f}B")
    print(f" Activation ratio: {params['activation_ratio']:.1%}")
    print(f" Estimated size (INT4): {params['estimated_size_int4_gb']:.2f}GB")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    model = model.to(device=device, dtype=dtype)
    print(f"\n2. Model loaded on {device} with {dtype}")

    return model, config, device


def _demo_forward_pass(model, config, device):
    """Run one forward pass on random ids and print loss / aux-loss figures."""
    print("\n3. Testing forward pass...")
    batch_size, seq_len = 2, 64
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device)

    model.eval()
    with torch.no_grad():
        # With labels supplied, the model returns
        # (loss, logits, <third value, unused here>, aux_loss).
        loss, logits, _, aux_loss = model(input_ids, labels=input_ids)

    print(f" Input shape: {input_ids.shape}")
    print(f" Output logits shape: {logits.shape}")
    print(f" Loss: {loss:.4f}")
    print(f" MoE auxiliary loss: {aux_loss:.6f}")


def _demo_generation(model, config, device):
    """Sample new tokens from a random 10-token prompt and print lengths."""
    print("\n4. Testing generation...")
    prompt = torch.randint(0, config.vocab_size, (1, 10), device=device)

    with torch.no_grad():
        generated = model.generate(
            prompt,
            max_new_tokens=20,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            do_sample=True,
        )

    print(f" Prompt length: {prompt.shape[1]}")
    print(f" Generated length: {generated.shape[1]}")
    print(f" New tokens: {generated.shape[1] - prompt.shape[1]}")


def _report_gpu_memory(device):
    """Print peak CUDA memory in GiB; no-op on CPU."""
    if device == "cuda":
        # Peak allocated bytes since process start, converted to GiB.
        memory_used = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n5. Peak GPU memory: {memory_used:.2f}GB")
|
|
|
|
|
|
|
|
# Standard script entry guard: run the demo only when executed directly,
# not when imported as a module.
if __name__ == "__main__":


    main()
|
|
|