Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Export mT5_multilingual_XLSum to ONNX Format | |
| ============================================= | |
| Converts the csebuetnlp/mT5_multilingual_XLSum model from PyTorch to ONNX format | |
| for efficient CPU inference in the Hindi summarization pipeline. | |
| This model is mT5-base fine-tuned on XL-Sum (45 languages including Hindi news). | |
| It produces significantly better Hindi summaries than vanilla mT5-small. | |
| The exported model is saved to: models/mt5_onnx/ | |
| This needs to be run ONCE before using the Hindi pipeline. | |
| Disk space required: ~2.3 GB | |
| Time: ~5 minutes (first run only) | |
| Usage: | |
| python backend/models/export_mt5.py | |
| """ | |
| import sys | |
| from pathlib import Path | |
| from optimum.onnxruntime import ORTModelForSeq2SeqLM | |
| from transformers import AutoTokenizer | |
MODEL_ID = "csebuetnlp/mT5_multilingual_XLSum"

# The export lands in <repo root>/models/mt5_onnx, where the repo root is three
# levels above this file (backend/models/export_mt5.py). resolve() first so the
# location does not depend on the current working directory or on whether
# __file__ happens to be a relative path.
OUTPUT_DIR = Path(__file__).resolve().parent.parent.parent / "models" / "mt5_onnx"
def main():
    """Export MODEL_ID to ONNX and save it, with its tokenizer, to OUTPUT_DIR.

    The export is skipped when OUTPUT_DIR already contains ``.onnx`` files.
    After exporting, every file written to OUTPUT_DIR is listed with its size.
    The process exits with status 1 if no ``.onnx`` file was produced.
    """
    print(f"Exporting {MODEL_ID} to ONNX → {OUTPUT_DIR}")

    # Skip only when a previous run actually produced ONNX files. Checking the
    # directory alone is not enough: it is created before the (long, large)
    # export, so an interrupted or failed run would leave it behind and every
    # later run would wrongly report the export as done.
    # Path.glob on a missing directory simply yields nothing.
    if any(OUTPUT_DIR.glob("*.onnx")):
        print("ONNX model(s) already present in output directory. Skipping export.")
        return

    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    # Single API call — Optimum handles tied weights, ONNX export, configs
    model = ORTModelForSeq2SeqLM.from_pretrained(MODEL_ID, export=True)
    model.save_pretrained(OUTPUT_DIR)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.save_pretrained(OUTPUT_DIR)

    # Verify output: list regular files only (directory sizes are meaningless).
    files = sorted(p for p in OUTPUT_DIR.iterdir() if p.is_file())
    print(f"\nSaved {len(files)} files to {OUTPUT_DIR}:")
    for f in files:
        size_mb = f.stat().st_size / (1024 * 1024)
        print(f"  {f.name:40s} {size_mb:>8.2f} MB")

    onnx_files = list(OUTPUT_DIR.glob("*.onnx"))
    if not onnx_files:
        print("ERROR: No .onnx files were produced!", file=sys.stderr)
        sys.exit(1)

    print(f"\nExport complete. {len(onnx_files)} ONNX model(s) ready.")
if __name__ == "__main__":
    # Run the one-time export only when invoked directly as a script;
    # importing this module has no side effects beyond defining constants.
    main()