#!/usr/bin/env python3
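"""
AAM Diffusion LLM — Export Script

Export a trained model checkpoint for deployment. The export directory receives
the model weights (model.pt), the model config (config.json) and, when present
in the checkpoint directory's data/ subfolder, the tokenizer (tokenizer.json).

Usage:
    python scripts/export.py --checkpoint output/best.pt --output model_export/
"""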
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from diffusion_llm.config.model_config import AamDiffusionConfig
from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel
from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer
# Configure root logging once for this CLI script, then grab a module logger.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
def _build_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the export script."""
    parser = argparse.ArgumentParser(description="Export AAM Diffusion Model")
    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the trained checkpoint (e.g. output/best.pt).",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./model_export",
        help="Directory that receives model.pt, config.json and tokenizer.json.",
    )
    parser.add_argument(
        "--format",
        type=str,
        default="pt",
        choices=["pt", "onnx"],
        help="Export format. Only 'pt' is currently implemented.",
    )
    return parser


def _copy_tokenizer(checkpoint: Path, output_dir: Path) -> None:
    """Copy ``<checkpoint dir>/data/tokenizer.json`` into *output_dir* if it exists.

    A missing tokenizer is not fatal, but it is logged so the export is not
    silently incomplete.
    """
    import shutil

    tokenizer_path = checkpoint.parent / "data" / "tokenizer.json"
    destination = output_dir / "tokenizer.json"
    if not tokenizer_path.exists():
        logger.warning(
            "No tokenizer found at %s; export will not include one", tokenizer_path
        )
        return
    shutil.copy(tokenizer_path, destination)
    logger.info("Tokenizer copied to %s", destination)


def main(argv: list[str] | None = None) -> None:
    """Export a trained checkpoint to a deployment directory.

    Loads the checkpoint with ``AamDiffusionModel.load`` and re-saves it as
    ``model.pt``. It then writes the model config to ``config.json``, copies the
    tokenizer when one is found, and prints a summary.

    Args:
        argv: Command-line arguments to parse. ``None`` (the default) reads
            ``sys.argv``, so existing ``main()`` callers are unaffected.

    Raises:
        SystemExit: On invalid arguments or when the checkpoint path does not
            exist. This happens before anything is written to disk.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    checkpoint = Path(args.checkpoint)
    # Fail fast with a CLI-style error instead of a traceback from the loader.
    if not checkpoint.exists():
        parser.error(f"checkpoint not found: {checkpoint}")

    if args.format == "onnx":
        # ONNX export is advertised by --format but not implemented. Make the
        # fallback visible instead of silently writing a .pt file.
        logger.warning(
            "ONNX export is not implemented; falling back to PyTorch ('pt') format"
        )

    # Load before creating the output directory so a failed load does not
    # leave an empty export directory behind.
    model = AamDiffusionModel.load(args.checkpoint)
    model.eval()

    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    model_path = output_dir / "model.pt"
    model.save(str(model_path))
    logger.info("Model exported to %s", model_path)

    config_path = output_dir / "config.json"
    model.config.to_json(config_path)
    logger.info("Config saved to %s", config_path)

    _copy_tokenizer(checkpoint, output_dir)

    # NOTE(review): _format_params is a private model helper; consider exposing
    # a public formatter on AamDiffusionModel.
    print("\nExport complete!")
    print(f" Model: {model_path}")
    print(f" Config: {config_path}")
    print(f" Parameters: {model._format_params(model.get_num_params())}")
    print("\n This is AAM's own body — 1 mind + 1 body.")
    print(" Mind = RSVS Knowledge Graph")
    print(f" Body = This Diffusion Model ({model.config.model_name})")
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()