# MiniMind / examples / quickstart.py
# MiniMind Max2 - Efficient MoE Language Model
#!/usr/bin/env python3
"""
MiniMind Max2 Quick Start Example
Demonstrates basic usage of the Max2 model.
"""
import sys
from pathlib import Path
# Make the repo root importable so `configs` and `model` resolve when this
# script is run directly from the examples/ directory.
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
def main():
    """Run the Max2 quick-start demo.

    Builds a small Max2 model, reports its parameter counts, runs a sanity
    forward pass and a short generation, and (on CUDA) prints peak memory.
    Side effects: prints progress to stdout; no return value.
    """
    print("=" * 60)
    print("MiniMind Max2 Quick Start")
    print("=" * 60)

    # Import model components here, after sys.path has been patched at module
    # top level, so the project-local packages resolve.
    from configs.model_config import get_config, estimate_params
    from model import Max2ForCausalLM

    # Select model variant
    model_name = "max2-nano"  # Options: max2-nano, max2-lite, max2-pro
    print(f"\n1. Creating {model_name} model...")
    config = get_config(model_name)
    model = Max2ForCausalLM(config)

    # Show model info (MoE: only a fraction of parameters are active per token)
    params = estimate_params(config)
    print(f" Total parameters: {params['total_params_b']:.3f}B")
    print(f" Active parameters: {params['active_params_b']:.3f}B")
    print(f" Activation ratio: {params['activation_ratio']:.1%}")
    print(f" Estimated size (INT4): {params['estimated_size_int4_gb']:.2f}GB")

    # Move to device; use half precision only on GPU (fp16 is slow/unstable on CPU).
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    model = model.to(device=device, dtype=dtype)
    print(f"\n2. Model loaded on {device} with {dtype}")

    # Test forward pass with random token ids; labels=input_ids yields an LM loss.
    print("\n3. Testing forward pass...")
    batch_size, seq_len = 2, 64
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device)
    model.eval()
    with torch.no_grad():
        # NOTE(review): assumes the model returns (loss, logits, <ignored>, aux_loss)
        # when labels are provided — matches the project's Max2ForCausalLM API.
        loss, logits, _, aux_loss = model(input_ids, labels=input_ids)
    print(f" Input shape: {input_ids.shape}")
    print(f" Output logits shape: {logits.shape}")
    print(f" Loss: {loss:.4f}")
    print(f" MoE auxiliary loss: {aux_loss:.6f}")

    # Test generation from a short random prompt
    print("\n4. Testing generation...")
    prompt = torch.randint(0, config.vocab_size, (1, 10), device=device)
    with torch.no_grad():
        generated = model.generate(
            prompt,
            max_new_tokens=20,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            do_sample=True,
        )
    print(f" Prompt length: {prompt.shape[1]}")
    print(f" Generated length: {generated.shape[1]}")
    print(f" New tokens: {generated.shape[1] - prompt.shape[1]}")

    # Memory usage (only meaningful on CUDA)
    if device == "cuda":
        memory_used = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n5. Peak GPU memory: {memory_used:.2f}GB")

    print("\n" + "=" * 60)
    print("Quick start complete!")
    print("=" * 60)

    # Usage hints
    print("\nNext steps:")
    print(" - Train: python scripts/train.py --model max2-lite --train-data your_data.jsonl")
    print(" - Export: python scripts/export.py --model max2-nano --format onnx gguf")
    print(" - See README.md for full documentation")
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()