Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,26 @@
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
import onnx
|
| 3 |
from onnx import external_data_helper
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
|
|
|
|
|
|
|
|
|
|
| 8 |
dummy = torch.zeros(1, n_features, dtype=torch.float32)
|
| 9 |
|
| 10 |
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".onnx", prefix="model_")
|
|
@@ -21,13 +37,12 @@ def export_onnx_model(trained_model: LitClassifier, mu: np.ndarray, sd: np.ndarr
|
|
| 21 |
do_constant_folding=True,
|
| 22 |
)
|
| 23 |
|
| 24 |
-
#
|
| 25 |
data_path = onnx_path + ".data"
|
| 26 |
if os.path.exists(data_path):
|
| 27 |
m = onnx.load_model(onnx_path, load_external_data=True)
|
| 28 |
external_data_helper.convert_model_from_external_data(m)
|
| 29 |
onnx.save_model(m, onnx_path)
|
| 30 |
-
# remove external blob now that it's embedded
|
| 31 |
try:
|
| 32 |
os.remove(data_path)
|
| 33 |
except OSError:
|
|
|
|
| 1 |
import os
|
| 2 |
+
import tempfile
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
import onnx
|
| 7 |
from onnx import external_data_helper
|
| 8 |
|
| 9 |
+
class OnnxWrapper(nn.Module):
|
| 10 |
+
def __init__(self, net: nn.Module, mu: np.ndarray, sd: np.ndarray):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.net = net
|
| 13 |
+
self.register_buffer("mu", torch.tensor(mu, dtype=torch.float32))
|
| 14 |
+
self.register_buffer("sd", torch.tensor(sd, dtype=torch.float32))
|
| 15 |
+
|
| 16 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 17 |
+
x = (x - self.mu) / self.sd
|
| 18 |
+
logits = self.net(x).squeeze(-1)
|
| 19 |
+
return torch.sigmoid(logits)
|
| 20 |
|
| 21 |
+
def export_onnx_model(trained_model, mu: np.ndarray, sd: np.ndarray, n_features: int) -> str:
|
| 22 |
+
# trained_model is your LightningModule; we export its .net
|
| 23 |
+
wrapper = OnnxWrapper(trained_model.net.cpu().eval(), mu=mu, sd=sd).eval()
|
| 24 |
dummy = torch.zeros(1, n_features, dtype=torch.float32)
|
| 25 |
|
| 26 |
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".onnx", prefix="model_")
|
|
|
|
| 37 |
do_constant_folding=True,
|
| 38 |
)
|
| 39 |
|
| 40 |
+
# Merge external data back into one .onnx if needed
|
| 41 |
data_path = onnx_path + ".data"
|
| 42 |
if os.path.exists(data_path):
|
| 43 |
m = onnx.load_model(onnx_path, load_external_data=True)
|
| 44 |
external_data_helper.convert_model_from_external_data(m)
|
| 45 |
onnx.save_model(m, onnx_path)
|
|
|
|
| 46 |
try:
|
| 47 |
os.remove(data_path)
|
| 48 |
except OSError:
|