#!/usr/bin/env python3
"""
MiniMind Max2 Quick Start Example

Demonstrates basic usage of the Max2 model.
"""

import sys
from pathlib import Path
from typing import Tuple

# Add parent directory so the project-local `configs` and `model` packages
# can be imported when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch


def select_device() -> Tuple[str, torch.dtype]:
    """Pick the compute device and the dtype the demo model is cast to.

    Returns:
        ``("cuda", torch.float16)`` when ``torch.cuda.is_available()`` is
        true, otherwise ``("cpu", torch.float32)``.
    """
    if torch.cuda.is_available():
        return "cuda", torch.float16
    return "cpu", torch.float32


def main(model_name: str = "max2-nano") -> None:
    """Run an end-to-end smoke test of a Max2 model variant.

    Builds the model from its named config, prints the parameter estimates
    from ``estimate_params``, runs one forward pass with labels, samples a
    short continuation with ``model.generate``, and on CUDA reports peak GPU
    memory.

    Args:
        model_name: Config name passed to ``get_config``. The script lists
            ``max2-nano``, ``max2-lite`` and ``max2-pro`` as options.
    """
    print("=" * 60)
    print("MiniMind Max2 Quick Start")
    print("=" * 60)

    # Import model components (project-local; found via the sys.path entry
    # added at module import time).
    from configs.model_config import get_config, estimate_params
    from model import Max2ForCausalLM

    # Select model variant (options: max2-nano, max2-lite, max2-pro)
    print(f"\n1. Creating {model_name} model...")
    config = get_config(model_name)
    model = Max2ForCausalLM(config)

    # Show model info
    params = estimate_params(config)
    print(f" Total parameters: {params['total_params_b']:.3f}B")
    print(f" Active parameters: {params['active_params_b']:.3f}B")
    print(f" Activation ratio: {params['activation_ratio']:.1%}")
    print(f" Estimated size (INT4): {params['estimated_size_int4_gb']:.2f}GB")

    # Move to device
    device, dtype = select_device()
    model = model.to(device=device, dtype=dtype)
    print(f"\n2. Model loaded on {device} with {dtype}")

    # Test forward pass on random token ids, using the inputs as labels.
    print("\n3. Testing forward pass...")
    batch_size, seq_len = 2, 64
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device)
    model.eval()
    with torch.no_grad():
        loss, logits, _, aux_loss = model(input_ids, labels=input_ids)
    print(f" Input shape: {input_ids.shape}")
    print(f" Output logits shape: {logits.shape}")
    # float() accepts both scalar tensors and plain Python numbers.
    print(f" Loss: {float(loss):.4f}")
    print(f" MoE auxiliary loss: {float(aux_loss):.6f}")

    # Test generation from a random 10-token prompt.
    print("\n4. Testing generation...")
    prompt = torch.randint(0, config.vocab_size, (1, 10), device=device)
    with torch.no_grad():
        generated = model.generate(
            prompt,
            max_new_tokens=20,
            temperature=0.8,
            top_k=50,
            top_p=0.9,
            do_sample=True,
        )
    print(f" Prompt length: {prompt.shape[1]}")
    print(f" Generated length: {generated.shape[1]}")
    print(f" New tokens: {generated.shape[1] - prompt.shape[1]}")

    # Memory usage (peak since process start, reported in GiB)
    if device == "cuda":
        memory_used = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n5. Peak GPU memory: {memory_used:.2f}GB")

    print("\n" + "=" * 60)
    print("Quick start complete!")
    print("=" * 60)

    # Usage hints
    print("\nNext steps:")
    print(" - Train: python scripts/train.py --model max2-lite --train-data your_data.jsonl")
    print(" - Export: python scripts/export.py --model max2-nano --format onnx gguf")
    print(" - See README.md for full documentation")


if __name__ == "__main__":
    main()