#!/usr/bin/env python3
"""
MiniMind Export Script
Export models to ONNX and GGUF formats for deployment.
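
Usage (flags match parse_args below; the path assumes this script sits one
level below the repo root, as the sys.path insert implies):

    python scripts/export.py --model mind2-lite --format onnx gguf
    python scripts/export.py --model mind2-nano --format android --quantize int8_dynamic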
"""

import argparse
import sys
from pathlib import Path

# Make the repo-root packages (configs, model, optimization) importable when run from a subdirectory.
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch

from configs.model_config import get_config
from model import Mind2ForCausalLM
from optimization.export import export_to_onnx, export_to_gguf, export_for_android, ExportConfig
from optimization.quantization import quantize_model


def parse_args():
    parser = argparse.ArgumentParser(description="Export MiniMind models")

    parser.add_argument("--model", type=str, default="mind2-lite",
                        choices=["mind2-nano", "mind2-lite", "mind2-pro"])
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to model checkpoint")
    parser.add_argument("--output-dir", type=str, default="./exports")

    parser.add_argument("--format", type=str, nargs="+",
                        default=["onnx", "gguf"],
                        choices=["onnx", "gguf", "android"])

    parser.add_argument("--quantize", type=str, default=None,
                        choices=["int4_awq", "int4_gptq", "int8_dynamic"])
    parser.add_argument("--max-seq-len", type=int, default=2048)

    return parser.parse_args()


def main():
    args = parse_args()

    print("=" * 60)
    print("MiniMind Export")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Formats: {', '.join(args.format)}")
    print(f"Quantization: {args.quantize or 'None'}")

    # Load model
    config = get_config(args.model)
    model = Mind2ForCausalLM(config)

    if args.checkpoint:
        print(f"Loading checkpoint from {args.checkpoint}")
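        # Load on CPU so export does not require a GPU; the file is assumed to hold a plain state dict.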
        state_dict = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(state_dict)

    # Inference mode: disables dropout and other training-only behavior before export.
    model.eval()

    # Quantize if requested
    if args.quantize:
        print(f"\nQuantizing to {args.quantize}...")
        model = quantize_model(model, args.quantize)
        print("Quantization complete!")

    # Export
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    export_config = ExportConfig(
        max_seq_len=args.max_seq_len,
        optimize_for_mobile=True,
    )

    outputs = {}

    # Android export produces its own bundle; when it is requested, the standalone
    # ONNX/GGUF exports below are skipped.
    if "android" in args.format:
        print("\nExporting for Android...")
        outputs = export_for_android(model, str(output_dir / "android"), config)
    else:
        if "onnx" in args.format:
            print("\nExporting to ONNX...")
            onnx_path = output_dir / f"{args.model}.onnx"
            outputs["onnx"] = export_to_onnx(model, str(onnx_path), export_config)

        if "gguf" in args.format:
            print("\nExporting to GGUF...")
            gguf_path = output_dir / f"{args.model}.gguf"
            outputs["gguf"] = export_to_gguf(model, str(gguf_path), config, export_config)

    print("\n" + "=" * 60)
    print("Export complete!")
    print("=" * 60)
    for fmt, path in outputs.items():
        print(f"  {fmt}: {path}")


if __name__ == "__main__":
    main()