File size: 1,908 Bytes
9f3aa4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import argparse
from pathlib import Path
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer


def export_to_onnx(
    model_dir: str,
    output_path: str | None = None,
    max_len: int = 256,
    opset_version: int = 17,
) -> Path:
    model_dir = Path(model_dir)
    if output_path is None:
        output_path = model_dir / "model.onnx"
    else:
        output_path = Path(output_path)

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()

    dummy = tokenizer(
        ["dummy input", "another dummy"],
        max_length=max_len,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

    batch = torch.export.Dim("batch", min=1, max=4096)
    torch.onnx.export(
        model,
        (dummy["input_ids"], dummy["attention_mask"]),
        str(output_path),
        input_names=["input_ids", "attention_mask"],
        output_names=["logits"],
        dynamic_shapes={
            "input_ids": {0: batch},
            "attention_mask": {0: batch},
        },
        opset_version=opset_version,
        dynamo=True,
        external_data=False,
    )

    print(f"Exported ONNX model to '{output_path}' ({output_path.stat().st_size / 1024 / 1024:.1f} MB)")
    return output_path


def main():
    """Parse the command-line options and run the ONNX export with them."""
    cli = argparse.ArgumentParser(description="Export model to ONNX")
    cli.add_argument("--model-dir", required=True, help="Path to saved model directory")
    cli.add_argument("--output", default=None, help="Output ONNX path (default: model_dir/model.onnx)")
    cli.add_argument("--max-len", type=int, default=256)
    cli.add_argument("--opset", type=int, default=17)
    opts = cli.parse_args()

    # Keyword arguments make the flag -> parameter mapping explicit.
    export_to_onnx(
        model_dir=opts.model_dir,
        output_path=opts.output,
        max_len=opts.max_len,
        opset_version=opts.opset,
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()