eaglelandsonce commited on
Commit
eef69b7
·
verified ·
1 Parent(s): 915c6cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -4
app.py CHANGED
@@ -1,10 +1,26 @@
1
  import os
 
 
 
 
2
  import onnx
3
  from onnx import external_data_helper
4
 
5
- def export_onnx_model(trained_model: LitClassifier, mu: np.ndarray, sd: np.ndarray, n_features: int) -> str:
6
- wrapper = OnnxWrapper(trained_model.net.cpu().eval(), mu=mu, sd=sd).eval()
 
 
 
 
 
 
 
 
 
7
 
 
 
 
8
  dummy = torch.zeros(1, n_features, dtype=torch.float32)
9
 
10
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".onnx", prefix="model_")
@@ -21,13 +37,12 @@ def export_onnx_model(trained_model: LitClassifier, mu: np.ndarray, sd: np.ndarr
21
  do_constant_folding=True,
22
  )
23
 
24
- # If exporter wrote external data, merge it back into a single .onnx
25
  data_path = onnx_path + ".data"
26
  if os.path.exists(data_path):
27
  m = onnx.load_model(onnx_path, load_external_data=True)
28
  external_data_helper.convert_model_from_external_data(m)
29
  onnx.save_model(m, onnx_path)
30
- # remove external blob now that it's embedded
31
  try:
32
  os.remove(data_path)
33
  except OSError:
 
1
  import os
2
+ import tempfile
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
  import onnx
7
  from onnx import external_data_helper
8
 
9
+ class OnnxWrapper(nn.Module):
10
+ def __init__(self, net: nn.Module, mu: np.ndarray, sd: np.ndarray):
11
+ super().__init__()
12
+ self.net = net
13
+ self.register_buffer("mu", torch.tensor(mu, dtype=torch.float32))
14
+ self.register_buffer("sd", torch.tensor(sd, dtype=torch.float32))
15
+
16
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
17
+ x = (x - self.mu) / self.sd
18
+ logits = self.net(x).squeeze(-1)
19
+ return torch.sigmoid(logits)
20
 
21
+ def export_onnx_model(trained_model, mu: np.ndarray, sd: np.ndarray, n_features: int) -> str:
22
+ # trained_model is your LightningModule; we export its .net
23
+ wrapper = OnnxWrapper(trained_model.net.cpu().eval(), mu=mu, sd=sd).eval()
24
  dummy = torch.zeros(1, n_features, dtype=torch.float32)
25
 
26
  tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".onnx", prefix="model_")
 
37
  do_constant_folding=True,
38
  )
39
 
40
+ # Merge external data back into one .onnx if needed
41
  data_path = onnx_path + ".data"
42
  if os.path.exists(data_path):
43
  m = onnx.load_model(onnx_path, load_external_data=True)
44
  external_data_helper.convert_model_from_external_data(m)
45
  onnx.save_model(m, onnx_path)
 
46
  try:
47
  os.remove(data_path)
48
  except OSError: