File size: 2,767 Bytes
aed1d05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
#!/usr/bin/env python3
"""
Export MODEL-W to ONNX format for deployment.

Useful for:
- Running inference without PyTorch
- Integration with DAW plugins (via ONNX runtime)
- Edge deployment
"""

import argparse
from pathlib import Path

import torch
import torch.onnx

from modelw.api import ModelW


def export_to_onnx(
    checkpoint_path: str,
    output_path: str,
    max_seq_len: int = 512,
    opset_version: int = 14,
) -> None:
    """Export a MODEL-W checkpoint to an ONNX file and optionally verify it.

    The checkpoint is loaded on CPU, traced with a random batch of token ids
    of shape ``(1, max_seq_len)``, and written to ``output_path``. If both
    ``onnx`` and ``onnxruntime`` are installed, the exported file is checked
    with the ONNX checker and run once through ONNX Runtime on CPU.

    Args:
        checkpoint_path: Path handed to ``ModelW.load``.
        output_path: Destination ``.onnx`` file; missing parent directories
            are created.
        max_seq_len: Sequence length of the dummy input used for tracing.
            Must be at least 1.
        opset_version: ONNX opset version passed to ``torch.onnx.export``.

    Raises:
        ValueError: If ``max_seq_len`` is less than 1.
    """
    # Fail fast on a bad length before loading the checkpoint; an empty or
    # negative-sized dummy input cannot produce a meaningful trace.
    if max_seq_len < 1:
        raise ValueError(
            f"max_seq_len must be a positive integer, got {max_seq_len}"
        )

    print(f"Loading model from {checkpoint_path}...")
    model_api = ModelW.load(checkpoint_path, device="cpu")
    model = model_api.model
    model.eval()

    print(f"Model loaded: {model_api}")

    # Create dummy input: random token ids drawn from the tokenizer's vocab.
    batch_size = 1
    dummy_input = torch.randint(
        0, model_api.tokenizer.vocab_size, (batch_size, max_seq_len)
    )

    # Export
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    print(f"Exporting to {output_path}...")

    torch.onnx.export(
        model,
        (dummy_input,),
        str(output_path),
        export_params=True,
        opset_version=opset_version,
        do_constant_folding=True,
        input_names=["input_ids"],
        output_names=["logits"],
        # Axes 0 and 1 of both tensors are marked dynamic so the exported
        # graph is not tied to the dummy input's batch size or length.
        dynamic_axes={
            "input_ids": {0: "batch_size", 1: "sequence"},
            "logits": {0: "batch_size", 1: "sequence"},
        },
    )

    print(f"✓ Exported to {output_path}")

    # Verification is optional: only the imports are guarded, so a genuine
    # verification failure is not mistaken for a missing package.
    try:
        import onnx
        import onnxruntime
    except ImportError:
        print("\n(Install onnx and onnxruntime to verify the export)")
        return

    print("\nVerifying ONNX model...")
    onnx_model = onnx.load(str(output_path))
    onnx.checker.check_model(onnx_model)
    print("✓ ONNX model is valid")

    # Test inference. The provider is named explicitly because ONNX Runtime
    # builds that include non-CPU providers require an explicit choice; CPU
    # matches the device the model was loaded on above.
    session = onnxruntime.InferenceSession(
        str(output_path),
        providers=["CPUExecutionProvider"],
    )
    test_input = dummy_input.numpy()
    outputs = session.run(None, {"input_ids": test_input})
    print(f"✓ ONNX inference works, output shape: {outputs[0].shape}")


def main() -> None:
    """Parse command-line arguments and run the ONNX export.

    Positional ``checkpoint`` plus the optional ``--output``,
    ``--max-seq-len`` and ``--opset-version`` flags are forwarded to
    ``export_to_onnx``.
    """
    parser = argparse.ArgumentParser(description="Export MODEL-W to ONNX")
    parser.add_argument(
        "checkpoint",
        type=str,
        help="Path to checkpoint directory or .pt file",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./exports/model_w.onnx",
        help="Output ONNX file path",
    )
    parser.add_argument(
        "--max-seq-len",
        type=int,
        default=512,
        help="Maximum sequence length for export",
    )
    # Exposes export_to_onnx's opset_version parameter; the default mirrors
    # that function's own default so existing invocations behave the same.
    parser.add_argument(
        "--opset-version",
        type=int,
        default=14,
        help="ONNX opset version to export with",
    )
    args = parser.parse_args()

    export_to_onnx(
        args.checkpoint,
        args.output,
        args.max_seq_len,
        opset_version=args.opset_version,
    )


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()