File size: 1,939 Bytes
2cb327c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/env python3
"""
Export mT5_multilingual_XLSum to ONNX Format
=============================================
Converts the csebuetnlp/mT5_multilingual_XLSum model from PyTorch to ONNX format
for efficient CPU inference in the Hindi summarization pipeline.

This model is mT5-base fine-tuned on XL-Sum (45 languages including Hindi news).
It produces significantly better Hindi summaries than vanilla mT5-small.

The exported model is saved to: models/mt5_onnx/
This needs to be run ONCE before using the Hindi pipeline.

Disk space required: ~2.3 GB
Time: ~5 minutes (first run only)

Usage:
    python backend/models/export_mt5.py
"""

import shutil
import sys
from pathlib import Path

from optimum.onnxruntime import ORTModelForSeq2SeqLM
from transformers import AutoTokenizer

# Hugging Face Hub id of the model to export.
MODEL_ID = "csebuetnlp/mT5_multilingual_XLSum"
# <repo root>/models/mt5_onnx. This script lives at backend/models/export_mt5.py,
# so the repo root is three directories up. resolve() first: if __file__ is
# relative (e.g. older Pythons run with a relative script path), walking
# .parent would otherwise collapse to "." and misplace the output under the CWD.
OUTPUT_DIR = Path(__file__).resolve().parent.parent.parent / "models" / "mt5_onnx"


def _has_onnx_files(directory: Path) -> bool:
    """Return True if *directory* is a directory holding at least one top-level ``*.onnx`` file."""
    return directory.is_dir() and any(directory.glob("*.onnx"))


def _print_listing(directory: Path) -> None:
    """Print every entry of *directory*, sorted by path, with its size in MiB."""
    entries = sorted(directory.iterdir())
    print(f"\nSaved {len(entries)} files to {directory}:")
    for entry in entries:
        size_mb = entry.stat().st_size / (1024 * 1024)
        print(f"  {entry.name:40s} {size_mb:>8.2f} MB")


def main():
    """Export MODEL_ID to ONNX and save the model and tokenizer under OUTPUT_DIR.

    The export is written to a sibling staging directory
    (``<OUTPUT_DIR>.partial``). It is moved to OUTPUT_DIR only after at least
    one ``.onnx`` file has been produced. An interrupted or failed run
    therefore never leaves behind a directory that later runs would mistake
    for a finished export.

    Exits with status 1 (via ``sys.exit``) if the export produces no ``.onnx``
    files; the staging directory is then left in place for inspection.
    """
    print(f"Exporting {MODEL_ID} to ONNX → {OUTPUT_DIR}")

    # Skip only when a previous run actually produced ONNX files: a bare or
    # partially-filled directory is not proof of a completed export.
    if _has_onnx_files(OUTPUT_DIR):
        print("Output directory already contains an ONNX export. Skipping export.")
        return

    staging_dir = OUTPUT_DIR.with_name(OUTPUT_DIR.name + ".partial")
    if staging_dir.exists():
        # Leftover from an interrupted run; start from a clean slate.
        shutil.rmtree(staging_dir)
    # parents=True also creates OUTPUT_DIR's parent, which the final rename needs.
    staging_dir.mkdir(parents=True)

    # Single API call — Optimum handles tied weights, ONNX export, configs
    model = ORTModelForSeq2SeqLM.from_pretrained(MODEL_ID, export=True)
    model.save_pretrained(staging_dir)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    tokenizer.save_pretrained(staging_dir)

    # Verify before promoting the staging directory to the real output path.
    onnx_count = len(list(staging_dir.glob("*.onnx")))
    if onnx_count == 0:
        print(
            f"ERROR: No .onnx files were produced! Partial output left in {staging_dir}",
            file=sys.stderr,
        )
        sys.exit(1)

    if OUTPUT_DIR.exists():
        # The check above found no .onnx files here, so this is an incomplete
        # leftover from an earlier failed run; replace it.
        print(f"Removing incomplete previous export at {OUTPUT_DIR}")
        shutil.rmtree(OUTPUT_DIR)
    staging_dir.rename(OUTPUT_DIR)

    _print_listing(OUTPUT_DIR)
    print(f"\nExport complete. {onnx_count} ONNX model(s) ready.")


if __name__ == "__main__":
    # Run the export only when this file is executed directly, not when imported.
    main()