#!/usr/bin/env python3
"""
MiniMind Export Script
Export models to ONNX and GGUF formats for deployment.
"""
import argparse
import sys
from pathlib import Path
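
# Make the project root (one directory up) importable when run as a script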
sys.path.insert(0, str(Path(__file__).parent.parent))
import torch
from configs.model_config import get_config
from model import Mind2ForCausalLM
from optimization.export import export_to_onnx, export_to_gguf, export_for_android, ExportConfig
from optimization.quantization import quantize_model


def parse_args():
    parser = argparse.ArgumentParser(description="Export MiniMind models")
    parser.add_argument("--model", type=str, default="mind2-lite",
                        choices=["mind2-nano", "mind2-lite", "mind2-pro"])
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to model checkpoint")
    parser.add_argument("--output-dir", type=str, default="./exports")
    parser.add_argument("--format", type=str, nargs="+",
                        default=["onnx", "gguf"],
                        choices=["onnx", "gguf", "android"])
    parser.add_argument("--quantize", type=str, default=None,
                        choices=["int4_awq", "int4_gptq", "int8_dynamic"])
    parser.add_argument("--max-seq-len", type=int, default=2048)
    return parser.parse_args()


def main():
    args = parse_args()
    print("=" * 60)
    print("MiniMind Export")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Formats: {', '.join(args.format)}")
    print(f"Quantization: {args.quantize or 'None'}")

    # Load model
    config = get_config(args.model)
    model = Mind2ForCausalLM(config)
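    # NOTE: assumes the checkpoint file stores a bare state_dict; unwrap it
    # first if your training loop saves e.g. {"model": state_dict, ...}.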
    if args.checkpoint:
        print(f"Loading checkpoint from {args.checkpoint}")
        state_dict = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(state_dict)
    model.eval()

    # Quantize if requested
    if args.quantize:
        print(f"\nQuantizing to {args.quantize}...")
        model = quantize_model(model, args.quantize)
        print("Quantization complete!")

    # Export
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    export_config = ExportConfig(
        max_seq_len=args.max_seq_len,
        optimize_for_mobile=True,
    )
    outputs = {}
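    # NOTE: "android" takes precedence; listing it alongside "onnx"/"gguf"
    # in --format skips the standalone ONNX/GGUF exports below.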
if "android" in args.format:
print(f"\nExporting for Android...")
outputs = export_for_android(model, str(output_dir / "android"), config)
else:
if "onnx" in args.format:
print(f"\nExporting to ONNX...")
onnx_path = output_dir / f"{args.model}.onnx"
outputs["onnx"] = export_to_onnx(model, str(onnx_path), export_config)
if "gguf" in args.format:
print(f"\nExporting to GGUF...")
gguf_path = output_dir / f"{args.model}.gguf"
outputs["gguf"] = export_to_gguf(model, str(gguf_path), config, export_config)
print(f"\n" + "=" * 60)
print("Export complete!")
print("=" * 60)
for fmt, path in outputs.items():
print(f" {fmt}: {path}")
if __name__ == "__main__":
main()