File size: 2,891 Bytes
8b187bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#!/usr/bin/env python3
"""
MiniMind Max2 Quick Start Example
Demonstrates basic usage of the Max2 model.
"""

import sys
from pathlib import Path

# Put the parent of this script's directory on sys.path. That is presumably the
# project root holding the `configs` and `model` packages imported in main().
# resolve() makes __file__ absolute first. If __file__ were relative (e.g.
# `python quick_start.py` on Python < 3.9), `.parent.parent` of a bare file name
# would collapse to "." -- the script's own directory, not its parent.
# resolve() also follows symlinks, so a symlinked copy still finds the real root.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

import torch


def _print_rule(leading_newline=False):
    """Print a line of 60 '=' characters.

    When ``leading_newline`` is true, a blank line is emitted first, as part of
    the same ``print`` call.
    """
    rule = "=" * 60
    print("\n" + rule if leading_newline else rule)


def _print_param_summary(stats):
    """Print the parameter estimates returned by ``estimate_params``.

    Reads the 'total_params_b', 'active_params_b', 'activation_ratio' and
    'estimated_size_int4_gb' entries of ``stats`` and prints one line for each,
    in that order.
    """
    rows = (
        ("Total parameters", "total_params_b", ".3f", "B"),
        ("Active parameters", "active_params_b", ".3f", "B"),
        ("Activation ratio", "activation_ratio", ".1%", ""),
        ("Estimated size (INT4)", "estimated_size_int4_gb", ".2f", "GB"),
    )
    for label, key, spec, unit in rows:
        print(f"   {label}: {format(stats[key], spec)}{unit}")


def _smoke_test_forward(model, config, device):
    """Run one no-grad forward pass on random token ids and print the results.

    The random ids are also passed as ``labels``. The model is switched to eval
    mode before the call. The model is expected to return a 4-tuple of
    (loss, logits, <unused>, aux_loss).
    """
    print("\n3. Testing forward pass...")
    shape = (2, 64)  # (batch_size, seq_len)
    input_ids = torch.randint(0, config.vocab_size, shape, device=device)

    model.eval()
    with torch.no_grad():
        loss, logits, _, aux_loss = model(input_ids, labels=input_ids)

    print(f"   Input shape: {input_ids.shape}")
    print(f"   Output logits shape: {logits.shape}")
    print(f"   Loss: {loss:.4f}")
    print(f"   MoE auxiliary loss: {aux_loss:.6f}")


def _smoke_test_generate(model, config, device):
    """Call ``generate`` on a random 10-token prompt and print sequence lengths.

    Sampling settings: max_new_tokens=20, temperature=0.8, top_k=50,
    top_p=0.9, do_sample=True.
    """
    print("\n4. Testing generation...")
    prompt = torch.randint(0, config.vocab_size, (1, 10), device=device)

    sampling = {
        "max_new_tokens": 20,
        "temperature": 0.8,
        "top_k": 50,
        "top_p": 0.9,
        "do_sample": True,
    }
    with torch.no_grad():
        generated = model.generate(prompt, **sampling)

    print(f"   Prompt length: {prompt.shape[1]}")
    print(f"   Generated length: {generated.shape[1]}")
    print(f"   New tokens: {generated.shape[1] - prompt.shape[1]}")


def main():
    """Smoke-test the MiniMind Max2 model end to end.

    Each step is announced on stdout:
      1. build the ``max2-nano`` variant and print its estimated parameters;
      2. move it to CUDA in float16 when available, otherwise CPU in float32;
      3. run a forward pass with labels on random token ids;
      4. call ``generate`` on a random prompt;
      5. (CUDA only) print peak allocated GPU memory.
    Finally, it prints suggested follow-up commands.
    """
    _print_rule()
    print("MiniMind Max2 Quick Start")
    _print_rule()

    # Project modules are imported lazily, after the banner. They are presumably
    # resolved via the sys.path entry inserted near the top of this file.
    from configs.model_config import get_config, estimate_params
    from model import Max2ForCausalLM

    variant = "max2-nano"  # Options: max2-nano, max2-lite, max2-pro
    print(f"\n1. Creating {variant} model...")
    cfg = get_config(variant)
    net = Max2ForCausalLM(cfg)
    _print_param_summary(estimate_params(cfg))

    on_gpu = torch.cuda.is_available()
    device = "cuda" if on_gpu else "cpu"
    dtype = torch.float16 if on_gpu else torch.float32
    net = net.to(device=device, dtype=dtype)
    print(f"\n2. Model loaded on {device} with {dtype}")

    _smoke_test_forward(net, cfg, device)
    _smoke_test_generate(net, cfg, device)

    if on_gpu:
        peak_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"\n5. Peak GPU memory: {peak_gb:.2f}GB")

    _print_rule(leading_newline=True)
    print("Quick start complete!")
    _print_rule()

    print("\nNext steps:")
    for hint in (
        "  - Train: python scripts/train.py --model max2-lite --train-data your_data.jsonl",
        "  - Export: python scripts/export.py --model max2-nano --format onnx gguf",
        "  - See README.md for full documentation",
    ):
        print(hint)


if __name__ == "__main__":
    main()