| """ |
| Export the M3 U-Net predictor to ONNX, with a PyTorch-vs-ONNX parity check. |
| """ |
|
|
| import torch |
| import numpy as np |
| import onnxruntime as ort |
| from src.models.predictor import UNetPredictor |
|
|
|
|
def load_model(ckpt_path: str, device: str = "cpu") -> torch.nn.Module:
    """Build a UNetPredictor, load its weights from a checkpoint, and put it in eval mode.

    Args:
        ckpt_path: Path to a checkpoint dict containing a ``"model_state"`` entry
            (a state dict for ``UNetPredictor``).
        device: Device to place the model and the loaded checkpoint tensors on.

    Returns:
        The model, in eval mode.
    """
    model = UNetPredictor().to(device)
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(ckpt["model_state"])
    model.eval()
    return model


def export_onnx(
    model: torch.nn.Module,
    example: torch.Tensor,
    onnx_path: str,
    opset_version: int = 18,
) -> None:
    """Export ``model`` to ``onnx_path``, traced with ``example``.

    The graph has a single input named ``"input"`` and a single output named
    ``"output"``. Only axis 0 (batch) is declared dynamic; every other dimension
    is taken from ``example``.
    """
    torch.onnx.export(
        model,
        example,
        onnx_path,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch"},
            "output": {0: "batch"},
        },
        opset_version=opset_version,
    )


def run_onnx(sess, x: torch.Tensor) -> np.ndarray:
    """Feed ``x`` to the session's first input and return its first output as a numpy array."""
    input_name = sess.get_inputs()[0].name
    # detach/cpu so this also works when x lives on a non-CPU device.
    return sess.run(None, {input_name: x.detach().cpu().numpy()})[0]


def check_parity(torch_out, onnx_out, atol: float = 1e-4) -> float:
    """Check that two outputs have the same shape and agree element-wise within ``atol``.

    Args:
        torch_out: Reference (PyTorch) output, array-like.
        onnx_out: ONNX Runtime output, array-like.
        atol: Strict upper bound on the maximum absolute element-wise difference.

    Returns:
        The maximum absolute difference (0.0 for empty outputs).

    Raises:
        AssertionError: If the shapes differ, or the max difference is not below
            ``atol`` (including when it is NaN).
    """
    torch_out = np.asarray(torch_out)
    onnx_out = np.asarray(onnx_out)
    # Compare shapes first: subtraction would silently broadcast mismatched outputs.
    if torch_out.shape != onnx_out.shape:
        raise AssertionError(
            f"output shape mismatch: torch {torch_out.shape} vs onnx {onnx_out.shape}"
        )
    max_diff = float(np.abs(torch_out - onnx_out).max()) if torch_out.size else 0.0
    # `not <` rather than `>=` so a NaN difference fails instead of slipping through.
    if not max_diff < atol:
        raise AssertionError(
            f"max |torch - onnx| = {max_diff:.3e} is not below atol={atol:g}"
        )
    return max_diff


def main(
    device: str = "cpu",
    ckpt_path: str = "checkpoints/pred_best.pt",
    onnx_path: str = "checkpoints/model.onnx",
    input_shape=(1, 15, 1, 128, 128),
    check_batch_sizes=(1, 2),
    atol: float = 1e-4,
    seed: int = 0,
) -> None:
    """Export the predictor to ONNX and verify PyTorch/ONNX Runtime parity.

    Parity is checked for each batch size in ``check_batch_sizes``. Sizes other
    than ``input_shape[0]`` exercise the declared dynamic batch axis.

    Raises:
        AssertionError: If any parity check fails (see ``check_parity``).
    """
    torch.manual_seed(seed)  # reproducible dummy inputs across runs

    model = load_model(ckpt_path, device)

    # Create the example on the model's device so a non-CPU `device` also works.
    dummy = torch.randn(*input_shape, device=device)
    export_onnx(model, dummy, onnx_path)

    # Pin the CPU provider explicitly. ORT builds with additional execution
    # providers (e.g. onnxruntime-gpu) require `providers` to be passed.
    sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    for batch in check_batch_sizes:
        x = torch.randn(batch, *input_shape[1:], device=device)
        with torch.no_grad():
            torch_out = model(x).cpu().numpy()
        max_diff = check_parity(torch_out, run_onnx(sess, x), atol)
        print(f"batch={batch}: max |torch - onnx| = {max_diff:.3e}")


if __name__ == "__main__":
    main()