eaglelandsonce committed on
Commit
6dad776
·
verified ·
1 Parent(s): d777abf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -47
app.py CHANGED
@@ -1,51 +1,128 @@
1
  import os
2
- import tempfile
 
 
 
3
  import numpy as np
4
- import torch
5
- import torch.nn as nn
6
- import onnx
7
- from onnx import external_data_helper
8
-
9
class OnnxWrapper(nn.Module):
    """Wrap a trained network with input standardization and a sigmoid head.

    The exported ONNX graph therefore accepts raw (unscaled) features and
    emits probabilities directly, so consumers of the .onnx file need no
    separate preprocessing step.
    """

    def __init__(self, net: nn.Module, mu: np.ndarray, sd: np.ndarray):
        super().__init__()
        self.net = net
        # Buffers (not parameters): folded into the ONNX graph as constants.
        for name, values in (("mu", mu), ("sd", sd)):
            self.register_buffer(name, torch.tensor(values, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Standardize ``x``, run the wrapped net, and return sigmoid probs."""
        standardized = (x - self.mu) / self.sd
        logits = self.net(standardized).squeeze(-1)
        return torch.sigmoid(logits)
20
-
21
def export_onnx_model(trained_model, mu: np.ndarray, sd: np.ndarray, n_features: int) -> str:
    """Export ``trained_model.net`` (plus standardization) to a single ONNX file.

    Args:
        trained_model: LightningModule whose ``.net`` is the torch network to export.
        mu: Per-feature means baked into the graph for input standardization.
        sd: Per-feature standard deviations (same shape as ``mu``).
        n_features: Input feature count, used to build the dummy tracing tensor.

    Returns:
        Path to the exported .onnx temp file; the caller owns its cleanup.
    """
    # trained_model is your LightningModule; we export its .net
    wrapper = OnnxWrapper(trained_model.net.cpu().eval(), mu=mu, sd=sd).eval()
    dummy = torch.zeros(1, n_features, dtype=torch.float32)

    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".onnx", prefix="model_")
    onnx_path = tmp.name
    # BUGFIX: close the handle before exporting to the path. The original
    # leaked the open file descriptor, and on Windows torch.onnx.export
    # cannot write to a file another handle holds open.
    tmp.close()

    torch.onnx.export(
        wrapper,
        dummy,
        onnx_path,
        input_names=["features"],
        output_names=["p_up"],
        dynamic_axes={"features": {0: "batch"}, "p_up": {0: "batch"}},
        opset_version=17,
        do_constant_folding=True,
    )

    # Merge external data back into one .onnx if the exporter spilled weights.
    data_path = onnx_path + ".data"
    if os.path.exists(data_path):
        m = onnx.load_model(onnx_path, load_external_data=True)
        external_data_helper.convert_model_from_external_data(m)
        onnx.save_model(m, onnx_path)
        try:
            os.remove(data_path)
        except OSError:
            pass  # best-effort cleanup; a stale .data file is harmless

    return onnx_path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ import traceback
3
+ from pathlib import Path
4
+
5
+ import gradio as gr
6
  import numpy as np
7
+ import onnxruntime as ort
8
+
9
MODEL_PATH = Path("model.onnx")

# Lazy-loaded ORT session cache: created on first inference so a missing
# model file shows up as an in-app error instead of a startup crash.
_SESSION = None
_INPUT_NAME = None


def _load_session():
    """Load ONNX Runtime session only when needed (prevents startup crash).

    Returns:
        (session, input_name) — cached after the first successful load.

    Raises:
        FileNotFoundError: if model.onnx is absent from the Space root.
    """
    global _SESSION, _INPUT_NAME

    if _SESSION is None:
        if not MODEL_PATH.exists():
            raise FileNotFoundError(
                "model.onnx not found in the Space root. "
                "Upload your ONNX file and name it exactly: model.onnx"
            )

        # CPU provider is the most compatible on Spaces
        session = ort.InferenceSession(str(MODEL_PATH), providers=["CPUExecutionProvider"])
        _SESSION = session
        _INPUT_NAME = session.get_inputs()[0].name

    return _SESSION, _INPUT_NAME
36
+
37
+
38
+ def _parse_vector(text: str) -> np.ndarray:
39
+ """
40
+ Parse a comma/space separated vector like:
41
+ "0.1, 0.2, 0.3"
42
+ "0.1 0.2 0.3"
43
+ Returns shape (1, n_features)
44
+ """
45
+ if not text or not text.strip():
46
+ raise ValueError("Vector input is empty.")
47
+ parts = [p for p in text.replace(",", " ").split() if p.strip()]
48
+ vals = [float(p) for p in parts]
49
+ x = np.array([vals], dtype=np.float32)
50
+ return x
51
+
52
+
53
def predict_5(ret_1, ret_5, sma_ratio, rsi, vol):
    """Run the 5-feature stock model on [ret_1, ret_5, sma_ratio, rsi, vol].

    Returns:
        (probability, "OK") on success, or (None, traceback_text) on failure —
        errors are surfaced in the UI status box rather than raised.
    """
    try:
        session, input_name = _load_session()
        features = np.array([[ret_1, ret_5, sma_ratio, rsi, vol]], dtype=np.float32)
        outputs = session.run(None, {input_name: features})
        value = np.array(outputs[0]).reshape(-1)[0]
        return float(value), "OK"
    except Exception:
        return None, traceback.format_exc()
66
+
67
+
68
def predict_vector(vec_text: str):
    """Generic vector inference for any ONNX model expecting [batch, features].

    Returns:
        (first_value, full_output_list, "OK") on success, or
        (None, None, traceback_text) on failure.
    """
    try:
        session, input_name = _load_session()
        batch = _parse_vector(vec_text)
        raw = session.run(None, {input_name: batch})[0]
        flat = np.array(raw).reshape(-1)
        # show first value for convenience, but also return full output
        first = float(flat[0]) if flat.size else None
        return first, flat.tolist(), "OK"
    except Exception:
        return None, None, traceback.format_exc()
80
+
81
+
82
# Gradio UI: two tabs over the same ONNX session — a fixed 5-field form for
# the stock model, and a free-form vector box for arbitrary feature counts.
with gr.Blocks(title="ONNX Inference Only") as demo:
    gr.Markdown(
        """
        # ONNX Inference Only (No training / no data)

        Place your model in the Space root as **`model.onnx`**.

        ⚠️ If your ONNX was exported with external weights, you must also upload the referenced
        `*.onnx.data` file into the same folder — OR re-export as a single-file ONNX.
        """
    )

    # Tab 1: field order must match the exported model's feature order.
    with gr.Tab("5-Feature Input (recommended for your stock model)"):
        with gr.Row():
            ret_1 = gr.Number(label="ret_1", value=0.001)
            ret_5 = gr.Number(label="ret_5", value=0.01)
            sma_ratio = gr.Number(label="sma_ratio", value=0.02)
            rsi = gr.Number(label="rsi", value=55.0)
            vol = gr.Number(label="vol", value=0.012)
        run_fixed = gr.Button("Run ONNX", variant="primary")
        fixed_out = gr.Number(label="Model output (e.g., p_up)")
        fixed_status = gr.Textbox(label="Status / Error", lines=10)

        run_fixed.click(
            fn=predict_5,
            inputs=[ret_1, ret_5, sma_ratio, rsi, vol],
            outputs=[fixed_out, fixed_status],
        )

    # Tab 2: accepts any feature count; parsing errors land in the status box.
    with gr.Tab("Vector Input (any feature size)"):
        vec = gr.Textbox(
            label="Input vector (comma or space separated)",
            value="0.001, 0.01, 0.02, 55.0, 0.012",
        )
        run_vec = gr.Button("Run ONNX (vector)", variant="primary")
        vec_first = gr.Number(label="First output value")
        vec_full = gr.JSON(label="Full output")
        vec_status = gr.Textbox(label="Status / Error", lines=10)

        run_vec.click(
            fn=predict_vector,
            inputs=[vec],
            outputs=[vec_first, vec_full, vec_status],
        )

if __name__ == "__main__":
    demo.launch()