#!/usr/bin/env python3
"""
MiniMind Max2 Quick Start Example
Demonstrates basic usage of the Max2 model.
"""
import sys
from pathlib import Path
# Add parent directory
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
def main():
    """Run an end-to-end smoke test of the Max2 model.

    Builds a small model variant, reports its estimated parameter counts,
    exercises one forward pass and a short sampled generation, and — on
    CUDA — prints the peak GPU memory used.
    """
    banner = "=" * 60
    print(banner)
    print("MiniMind Max2 Quick Start")
    print(banner)

    # Project-local imports; resolvable because the parent directory was
    # prepended to sys.path at module import time.
    from configs.model_config import get_config, estimate_params
    from model import Max2ForCausalLM

    # --- 1. Build the model -------------------------------------------
    model_name = "max2-nano"  # Options: max2-nano, max2-lite, max2-pro
    print(f"\n1. Creating {model_name} model...")
    config = get_config(model_name)
    model = Max2ForCausalLM(config)

    # Report size/activation estimates for the chosen config.
    params = estimate_params(config)
    print(f" Total parameters: {params['total_params_b']:.3f}B")
    print(f" Active parameters: {params['active_params_b']:.3f}B")
    print(f" Activation ratio: {params['activation_ratio']:.1%}")
    print(f" Estimated size (INT4): {params['estimated_size_int4_gb']:.2f}GB")

    # --- 2. Place on device -------------------------------------------
    # fp16 only on GPU; CPU stays in fp32.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32
    model = model.to(device=device, dtype=dtype)
    print(f"\n2. Model loaded on {device} with {dtype}")

    # --- 3. Forward pass on random token ids --------------------------
    print("\n3. Testing forward pass...")
    model.eval()
    batch_size = 2
    seq_len = 64
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len), device=device)
    with torch.no_grad():
        # Model returns (loss, logits, <unused>, aux_loss) when labels given.
        loss, logits, _, aux_loss = model(input_ids, labels=input_ids)
    print(f" Input shape: {input_ids.shape}")
    print(f" Output logits shape: {logits.shape}")
    print(f" Loss: {loss:.4f}")
    print(f" MoE auxiliary loss: {aux_loss:.6f}")

    # --- 4. Short sampled generation ----------------------------------
    print("\n4. Testing generation...")
    prompt = torch.randint(0, config.vocab_size, (1, 10), device=device)
    sampling_kwargs = dict(
        max_new_tokens=20,
        temperature=0.8,
        top_k=50,
        top_p=0.9,
        do_sample=True,
    )
    with torch.no_grad():
        generated = model.generate(prompt, **sampling_kwargs)
    print(f" Prompt length: {prompt.shape[1]}")
    print(f" Generated length: {generated.shape[1]}")
    print(f" New tokens: {generated.shape[1] - prompt.shape[1]}")

    # --- 5. Peak memory (GPU only) ------------------------------------
    if device == "cuda":
        memory_used = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n5. Peak GPU memory: {memory_used:.2f}GB")

    print("\n" + banner)
    print("Quick start complete!")
    print(banner)

    # Pointers for what to try next.
    print("\nNext steps:")
    print(" - Train: python scripts/train.py --model max2-lite --train-data your_data.jsonl")
    print(" - Export: python scripts/export.py --model max2-nano --format onnx gguf")
    print(" - See README.md for full documentation")
# Script entry point: run the quick-start demo only when executed directly.
if __name__ == "__main__":
    main()