# Source: StockSenseSpace / app.py (amitke, commit 457b70c)
import gradio as gr
import spaces
from train import train as train_fn
from testing import evaluate as eval_fn
from inference import predict_next as predict_fn
@spaces.GPU
def train_api(symbol, seq_len=60, epochs=5, batch_size=32, start="", end=""):
    """Train the model for ``symbol`` by delegating to ``train.train``.

    Backs the "Train" tab and the ``train`` API endpoint. Values arriving from
    Gradio ``Number`` components (or from API clients) may be ``None`` or
    floats, so numeric arguments are coerced to ``int`` and fall back to this
    function's defaults when missing. Blank or whitespace-only date strings
    are forwarded as ``None`` ("no bound").

    Args:
        symbol: Ticker symbol, passed through unchanged.
        seq_len: Sequence length forwarded to ``train_fn``.
        epochs: Number of epochs forwarded to ``train_fn``.
        batch_size: Batch size forwarded to ``train_fn``.
        start: Optional start date text (the UI suggests YYYY-MM-DD).
        end: Optional end date text (the UI suggests YYYY-MM-DD).

    Returns:
        Whatever ``train_fn`` returns; the UI renders it in a ``gr.JSON``.
    """

    def _int_or(value, default):
        # None comes from a cleared gr.Number or an API call that sends null;
        # int(None) would raise TypeError, so use the signature default.
        return default if value is None else int(value)

    def _date_or_none(value):
        # Strip text so "  " is treated like "" (no bound); non-string values
        # are passed through exactly as before.
        if isinstance(value, str):
            value = value.strip()
        return value or None

    return train_fn(
        symbol,
        seq_len=_int_or(seq_len, 60),  # defaults mirror the signature above
        epochs=_int_or(epochs, 5),
        batch_size=_int_or(batch_size, 32),
        start=_date_or_none(start),
        end=_date_or_none(end),
    )
def test_api(symbol):
    """Run the project's evaluation routine for ``symbol``.

    Backs the "Test" tab and the ``test`` API endpoint; the result is shown
    in a ``gr.JSON`` output.

    Args:
        symbol: Ticker symbol, forwarded unchanged to ``testing.evaluate``.

    Returns:
        Whatever ``eval_fn`` returns.
    """
    evaluation = eval_fn(symbol)
    return evaluation
@spaces.GPU
def predict_api(symbol, days=1):
    """Predict upcoming values for ``symbol`` via ``inference.predict_next``.

    Backs the "Predict" tab and the ``predict`` API endpoint.

    Args:
        symbol: Ticker symbol, forwarded unchanged.
        days: Number of days to predict. ``None`` (a cleared ``gr.Number`` or
            a null API argument) falls back to the default of 1 instead of
            raising ``TypeError`` from ``int(None)``; floats are truncated
            with ``int``.

    Returns:
        Whatever ``predict_fn`` returns; the UI renders it in a ``gr.JSON``.
    """
    n_days = 1 if days is None else int(days)
    return predict_fn(symbol, n_days=n_days)
def hello_api(name="world"):
    """Return a greeting payload for the "Hello" tab / ``hello`` endpoint.

    Presumably a simple connectivity check for the Space's API.

    Args:
        name: Value interpolated into the greeting.

    Returns:
        A dict with a single ``"message"`` key, e.g. ``{"message": "hello world"}``.
    """
    greeting = f"hello {name}"
    return dict(message=greeting)
# Gradio UI: one tab per operation. Each button's click handler is also
# exposed as a named API endpoint through ``api_name``. Inside a Blocks
# context, components appear in the order they are created, so the
# statement order below defines the layout.
with gr.Blocks() as demo:
    gr.Markdown("## LSTM Stock Predictor (PyTorch • Train / Test / Predict)")

    # Train tab: the input list maps positionally onto
    # train_api(symbol, seq_len, epochs, batch_size, start, end).
    with gr.Tab("Train"):
        sym_t = gr.Textbox(label="Symbol", value="AAPL")
        # precision=0 requests integer input from Gradio; train_api int()-casts anyway.
        seq = gr.Number(label="Seq length", value=60, precision=0)
        ep = gr.Number(label="Epochs", value=5, precision=0)
        bs = gr.Number(label="Batch size", value=32, precision=0)
        # Empty date boxes are forwarded to train_fn as None by train_api.
        start = gr.Textbox(label="Start (YYYY-MM-DD)", placeholder="optional")
        end = gr.Textbox(label="End (YYYY-MM-DD)", placeholder="optional")
        btn_t = gr.Button("Train")
        out_t = gr.JSON()
        btn_t.click(train_api, [sym_t, seq, ep, bs, start, end], out_t, api_name="train")

    # Test tab: evaluates the model for a symbol via test_api.
    with gr.Tab("Test"):
        sym_e = gr.Textbox(label="Symbol", value="AAPL")
        btn_e = gr.Button("Run Test")
        out_e = gr.JSON()
        btn_e.click(test_api, [sym_e], out_e, api_name="test")

    # Predict tab: inputs map onto predict_api(symbol, days).
    with gr.Tab("Predict"):
        sym_p = gr.Textbox(label="Symbol", value="AAPL")
        days = gr.Number(label="Days to predict", value=1, precision=0)
        btn_p = gr.Button("Predict")
        out_p = gr.JSON()
        btn_p.click(predict_api, [sym_p, days], out_p, api_name="predict")

    # Hello tab: returns a fixed-format greeting from hello_api.
    with gr.Tab("Hello"):
        who = gr.Textbox(label="Name", value="world")
        btn_h = gr.Button("Say Hello")
        out_h = gr.JSON()
        btn_h.click(hello_api, [who], out_h, api_name="hello")

# Start the server only when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()