Spaces:
Build error
Build error
| #!/usr/bin/env python3 | |
| """ | |
| Export MODEL-W to ONNX format for deployment. | |
| Useful for: | |
| - Running inference without PyTorch | |
| - Integration with DAW plugins (via ONNX runtime) | |
| - Edge deployment | |
| """ | |
| import argparse | |
| from pathlib import Path | |
| import torch | |
| import torch.onnx | |
| from modelw.api import ModelW | |
| def export_to_onnx( | |
| checkpoint_path: str, | |
| output_path: str, | |
| max_seq_len: int = 512, | |
| opset_version: int = 14, | |
| ): | |
| """Export model to ONNX format.""" | |
| print(f"Loading model from {checkpoint_path}...") | |
| model_api = ModelW.load(checkpoint_path, device="cpu") | |
| model = model_api.model | |
| model.eval() | |
| print(f"Model loaded: {model_api}") | |
| # Create dummy input | |
| batch_size = 1 | |
| seq_len = max_seq_len | |
| dummy_input = torch.randint(0, model_api.tokenizer.vocab_size, (batch_size, seq_len)) | |
| # Export | |
| output_path = Path(output_path) | |
| output_path.parent.mkdir(parents=True, exist_ok=True) | |
| print(f"Exporting to {output_path}...") | |
| torch.onnx.export( | |
| model, | |
| (dummy_input,), | |
| str(output_path), | |
| export_params=True, | |
| opset_version=opset_version, | |
| do_constant_folding=True, | |
| input_names=["input_ids"], | |
| output_names=["logits"], | |
| dynamic_axes={ | |
| "input_ids": {0: "batch_size", 1: "sequence"}, | |
| "logits": {0: "batch_size", 1: "sequence"}, | |
| }, | |
| ) | |
| print(f"✓ Exported to {output_path}") | |
| # Verify | |
| try: | |
| import onnx | |
| import onnxruntime | |
| print("\nVerifying ONNX model...") | |
| onnx_model = onnx.load(str(output_path)) | |
| onnx.checker.check_model(onnx_model) | |
| print("✓ ONNX model is valid") | |
| # Test inference | |
| session = onnxruntime.InferenceSession(str(output_path)) | |
| test_input = dummy_input.numpy() | |
| outputs = session.run(None, {"input_ids": test_input}) | |
| print(f"✓ ONNX inference works, output shape: {outputs[0].shape}") | |
| except ImportError: | |
| print("\n(Install onnx and onnxruntime to verify the export)") | |
| def main(): | |
| parser = argparse.ArgumentParser(description="Export MODEL-W to ONNX") | |
| parser.add_argument( | |
| "checkpoint", | |
| type=str, | |
| help="Path to checkpoint directory or .pt file", | |
| ) | |
| parser.add_argument( | |
| "--output", | |
| type=str, | |
| default="./exports/model_w.onnx", | |
| help="Output ONNX file path", | |
| ) | |
| parser.add_argument( | |
| "--max-seq-len", | |
| type=int, | |
| default=512, | |
| help="Maximum sequence length for export", | |
| ) | |
| args = parser.parse_args() | |
| export_to_onnx(args.checkpoint, args.output, args.max_seq_len) | |
| if __name__ == "__main__": | |
| main() | |