"""
MiniMind Export Script

Export models to ONNX and GGUF formats (or an Android bundle) for deployment.
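
Example invocation (the script's location and name are assumptions; run
from the repository root):

    python scripts/export.py --model mind2-lite \
        --format onnx gguf --quantize int8_dynamic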
"""

import argparse
import sys
from pathlib import Path

# Make the repository root importable so the project-local packages below
# resolve when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

import torch

from configs.model_config import get_config
from model import Mind2ForCausalLM
from optimization.export import export_to_onnx, export_to_gguf, export_for_android, ExportConfig
from optimization.quantization import quantize_model


def parse_args():
    parser = argparse.ArgumentParser(description="Export MiniMind models")

    parser.add_argument("--model", type=str, default="mind2-lite",
                        choices=["mind2-nano", "mind2-lite", "mind2-pro"])
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to model checkpoint")
    parser.add_argument("--output-dir", type=str, default="./exports")
    parser.add_argument("--format", type=str, nargs="+",
                        default=["onnx", "gguf"],
                        choices=["onnx", "gguf", "android"])
    parser.add_argument("--quantize", type=str, default=None,
                        choices=["int4_awq", "int4_gptq", "int8_dynamic"])
    parser.add_argument("--max-seq-len", type=int, default=2048)

    return parser.parse_args()


def main():
    args = parse_args()

    print("=" * 60)
    print("MiniMind Export")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Formats: {args.format}")
    print(f"Quantization: {args.quantize or 'None'}")

    # Build the model from its config, then load trained weights if provided.
    config = get_config(args.model)
    model = Mind2ForCausalLM(config)

    if args.checkpoint:
        print(f"Loading checkpoint from {args.checkpoint}")
        state_dict = torch.load(args.checkpoint, map_location="cpu")
        model.load_state_dict(state_dict)

    # Inference mode: disables dropout and other training-only behavior.
    model.eval()

    if args.quantize:
        print(f"\nQuantizing to {args.quantize}...")
        model = quantize_model(model, args.quantize)
        print("Quantization complete!")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    export_config = ExportConfig(
        max_seq_len=args.max_seq_len,
        optimize_for_mobile=True,
    )

    outputs = {}

    # The Android path produces its own bundle; otherwise export each
    # requested format individually.
    if "android" in args.format:
        print("\nExporting for Android...")
        outputs = export_for_android(model, str(output_dir / "android"), config)
    else:
        if "onnx" in args.format:
            print("\nExporting to ONNX...")
            onnx_path = output_dir / f"{args.model}.onnx"
            outputs["onnx"] = export_to_onnx(model, str(onnx_path), export_config)

        if "gguf" in args.format:
            print("\nExporting to GGUF...")
            gguf_path = output_dir / f"{args.model}.gguf"
            outputs["gguf"] = export_to_gguf(model, str(gguf_path), config, export_config)

    print("\n" + "=" * 60)
    print("Export complete!")
    print("=" * 60)
    for fmt, path in outputs.items():
        print(f"  {fmt}: {path}")


if __name__ == "__main__":
    main()