Spaces:
Runtime error
Runtime error
amitke commited on
Commit ·
7fd066b
1
Parent(s): b53cc5c
fix @spaces.GPU
Browse files- app.py +1 -0
- apptest.py +0 -58
app.py
CHANGED
|
@@ -16,6 +16,7 @@ def train_api(symbol, seq_len=60, epochs=5, batch_size=32, start="", end=""):
|
|
| 16 |
def test_api(symbol):
|
| 17 |
return eval_fn(symbol)
|
| 18 |
|
|
|
|
| 19 |
def predict_api(symbol, days=1):
|
| 20 |
return predict_fn(symbol, n_days=int(days))
|
| 21 |
|
|
|
|
| 16 |
def test_api(symbol):
|
| 17 |
return eval_fn(symbol)
|
| 18 |
|
| 19 |
+
@spaces.GPU
|
| 20 |
def predict_api(symbol, days=1):
|
| 21 |
return predict_fn(symbol, n_days=int(days))
|
| 22 |
|
apptest.py
DELETED
|
@@ -1,58 +0,0 @@
|
|
| 1 |
-
import gradio as gr
|
| 2 |
-
from train import train as train_fn
|
| 3 |
-
from testing import evaluate as eval_fn
|
| 4 |
-
from inference import predict_next as predict_fn
|
| 5 |
-
|
| 6 |
-
def train_api(symbol, seq_len=60, epochs=5, batch_size=32, start="", end=""):
|
| 7 |
-
return train_fn(
|
| 8 |
-
symbol,
|
| 9 |
-
seq_len=int(seq_len),
|
| 10 |
-
epochs=int(epochs),
|
| 11 |
-
batch_size=int(batch_size),
|
| 12 |
-
start=start or None,
|
| 13 |
-
end=end or None,
|
| 14 |
-
)
|
| 15 |
-
|
| 16 |
-
def test_api(symbol):
|
| 17 |
-
return eval_fn(symbol)
|
| 18 |
-
|
| 19 |
-
def predict_api(symbol, days=1):
|
| 20 |
-
return predict_fn(symbol, n_days=int(days))
|
| 21 |
-
|
| 22 |
-
def hello_api(name="world"):
|
| 23 |
-
return {"message": f"hello {name}"}
|
| 24 |
-
|
| 25 |
-
with gr.Blocks() as demo:
|
| 26 |
-
gr.Markdown("## LSTM Stock Predictor (PyTorch • Train / Test / Predict)")
|
| 27 |
-
with gr.Tab("Train"):
|
| 28 |
-
sym_t = gr.Textbox(label="Symbol", value="AAPL")
|
| 29 |
-
seq = gr.Number(label="Seq length", value=60, precision=0)
|
| 30 |
-
ep = gr.Number(label="Epochs", value=5, precision=0)
|
| 31 |
-
bs = gr.Number(label="Batch size", value=32, precision=0)
|
| 32 |
-
start = gr.Textbox(label="Start (YYYY-MM-DD)", placeholder="optional")
|
| 33 |
-
end = gr.Textbox(label="End (YYYY-MM-DD)", placeholder="optional")
|
| 34 |
-
btn_t = gr.Button("Train")
|
| 35 |
-
out_t = gr.JSON()
|
| 36 |
-
btn_t.click(train_api, [sym_t, seq, ep, bs, start, end], out_t, api_name="train")
|
| 37 |
-
|
| 38 |
-
with gr.Tab("Test"):
|
| 39 |
-
sym_e = gr.Textbox(label="Symbol", value="AAPL")
|
| 40 |
-
btn_e = gr.Button("Run Test")
|
| 41 |
-
out_e = gr.JSON()
|
| 42 |
-
btn_e.click(test_api, [sym_e], out_e, api_name="test")
|
| 43 |
-
|
| 44 |
-
with gr.Tab("Predict"):
|
| 45 |
-
sym_p = gr.Textbox(label="Symbol", value="AAPL")
|
| 46 |
-
days = gr.Number(label="Days to predict", value=1, precision=0)
|
| 47 |
-
btn_p = gr.Button("Predict")
|
| 48 |
-
out_p = gr.JSON()
|
| 49 |
-
btn_p.click(predict_api, [sym_p, days], out_p, api_name="predict")
|
| 50 |
-
|
| 51 |
-
with gr.Tab("Hello"):
|
| 52 |
-
who = gr.Textbox(label="Name", value="world")
|
| 53 |
-
btn_h = gr.Button("Say Hello")
|
| 54 |
-
out_h = gr.JSON()
|
| 55 |
-
btn_h.click(hello_api, [who], out_h, api_name="hello")
|
| 56 |
-
|
| 57 |
-
if __name__ == "__main__":
|
| 58 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|