Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import spaces | |
| from train import train as train_fn | |
| from testing import evaluate as eval_fn | |
| from inference import predict_next as predict_fn | |
def _int_or_default(value, default):
    """Coerce a numeric UI/API input to ``int``, using *default* when it is blank.

    A cleared ``gr.Number`` box is delivered to the callback as ``None`` and a
    raw API client may send ``""``; both would otherwise crash inside ``int()``
    with an error that says nothing about which field was empty.
    """
    if value is None or (isinstance(value, str) and not value.strip()):
        return default
    return int(value)


def train_api(symbol, seq_len=60, epochs=5, batch_size=32, start="", end=""):
    """Run ``train_fn`` for *symbol* and return its result to the Train tab.

    Args:
        symbol: Ticker text from the Symbol box, forwarded unchanged.
        seq_len: Sequence length; blank/``None`` falls back to 60.
        epochs: Number of epochs; blank/``None`` falls back to 5.
        batch_size: Batch size; blank/``None`` falls back to 32.
        start: Optional start date text (the UI suggests ``YYYY-MM-DD``).
        end: Optional end date text (the UI suggests ``YYYY-MM-DD``).

    Returns:
        Whatever ``train_fn`` returns; the UI renders it in a ``gr.JSON``.
    """
    return train_fn(
        symbol,
        seq_len=_int_or_default(seq_len, 60),
        epochs=_int_or_default(epochs, 5),
        batch_size=_int_or_default(batch_size, 32),
        # Empty, whitespace-only, or missing dates all mean "not provided";
        # surrounding spaces on a real date are trimmed before forwarding.
        start=(start or "").strip() or None,
        end=(end or "").strip() or None,
    )
def test_api(symbol):
    """Evaluate *symbol* with ``eval_fn`` and hand back its result.

    The return value is shown unchanged in the Test tab's JSON output
    and is also served by the ``test`` API endpoint.
    """
    evaluation = eval_fn(symbol)
    return evaluation
def predict_api(symbol, days=1):
    """Run ``predict_fn`` for *symbol* over *days* and return its result.

    Args:
        symbol: Ticker text from the Symbol box, forwarded unchanged.
        days: Value of the "Days to predict" box, passed as ``n_days``.
            A cleared ``gr.Number`` arrives as ``None`` (or ``""`` from a raw
            API call); treat that as the default of 1 instead of crashing
            inside ``int()``.

    Returns:
        Whatever ``predict_fn`` returns; the UI renders it in a ``gr.JSON``.
    """
    if days is None or (isinstance(days, str) and not days.strip()):
        days = 1
    return predict_fn(symbol, n_days=int(days))
def hello_api(name="world"):
    """Build the greeting payload rendered by the Hello tab.

    Returns a single-key dict whose ``message`` is ``"hello "`` followed by
    *name* formatted as text.
    """
    greeting = "hello {}".format(name)
    return dict(message=greeting)
# Build the UI: one tab per operation. Each button is wired to its handler
# and exposed as a named API endpoint via ``api_name``.
with gr.Blocks() as demo:
    gr.Markdown("## LSTM Stock Predictor (PyTorch • Train / Test / Predict)")

    # Train tab -> train_api, endpoint "train".
    with gr.Tab("Train"):
        sym_t = gr.Textbox(label="Symbol", value="AAPL")
        seq = gr.Number(label="Seq length", value=60, precision=0)
        ep = gr.Number(label="Epochs", value=5, precision=0)
        bs = gr.Number(label="Batch size", value=32, precision=0)
        start = gr.Textbox(label="Start (YYYY-MM-DD)", placeholder="optional")
        end = gr.Textbox(label="End (YYYY-MM-DD)", placeholder="optional")
        btn_t = gr.Button("Train")
        out_t = gr.JSON()
        btn_t.click(
            fn=train_api,
            inputs=[sym_t, seq, ep, bs, start, end],
            outputs=out_t,
            api_name="train",
        )

    # Test tab -> test_api, endpoint "test".
    with gr.Tab("Test"):
        sym_e = gr.Textbox(label="Symbol", value="AAPL")
        btn_e = gr.Button("Run Test")
        out_e = gr.JSON()
        btn_e.click(fn=test_api, inputs=[sym_e], outputs=out_e, api_name="test")

    # Predict tab -> predict_api, endpoint "predict".
    with gr.Tab("Predict"):
        sym_p = gr.Textbox(label="Symbol", value="AAPL")
        days = gr.Number(label="Days to predict", value=1, precision=0)
        btn_p = gr.Button("Predict")
        out_p = gr.JSON()
        btn_p.click(
            fn=predict_api,
            inputs=[sym_p, days],
            outputs=out_p,
            api_name="predict",
        )

    # Hello tab -> hello_api, endpoint "hello".
    with gr.Tab("Hello"):
        who = gr.Textbox(label="Name", value="world")
        btn_h = gr.Button("Say Hello")
        out_h = gr.JSON()
        btn_h.click(fn=hello_api, inputs=[who], outputs=out_h, api_name="hello")

# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()