File size: 2,109 Bytes
514c4c0
2e0442e
514c4c0
 
 
 
457b70c
514c4c0
 
 
 
 
 
 
 
 
 
 
 
 
7fd066b
514c4c0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import gradio as gr
import spaces
from train import train as train_fn
from testing import evaluate as eval_fn
from inference import predict_next as predict_fn

@spaces.GPU
def train_api(symbol, seq_len=60, epochs=5, batch_size=32, start="", end=""):
    """Train a model for ``symbol`` by delegating to ``train.train``.

    Backs the "Train" button and the ``train`` API endpoint. Inputs are
    normalized before being forwarded:

    * Numeric arguments are cast to ``int``. A cleared Gradio ``Number``
      field is delivered as ``None``; rather than crashing with
      ``int(None)`` -> ``TypeError``, the signature's default is used.
    * ``start``/``end`` that are ``None``, empty, or whitespace-only are
      forwarded as ``None`` (i.e. "not provided"); other values are
      forwarded with surrounding whitespace stripped.

    Args:
        symbol: Ticker symbol, forwarded unchanged.
        seq_len: Value forwarded as ``seq_len`` (default 60).
        epochs: Value forwarded as ``epochs`` (default 5).
        batch_size: Value forwarded as ``batch_size`` (default 32).
        start: Optional start date string (UI label says ``YYYY-MM-DD``).
        end: Optional end date string (UI label says ``YYYY-MM-DD``).

    Returns:
        Whatever ``train_fn`` returns; the UI renders it in a ``gr.JSON``.
    """
    # Fall back to the same defaults as the signature when a field is empty.
    seq_len = 60 if seq_len is None else int(seq_len)
    epochs = 5 if epochs is None else int(epochs)
    batch_size = 32 if batch_size is None else int(batch_size)
    # Whitespace-only dates mean "not provided", just like empty strings.
    start = (start or "").strip() or None
    end = (end or "").strip() or None
    return train_fn(
        symbol,
        seq_len=seq_len,
        epochs=epochs,
        batch_size=batch_size,
        start=start,
        end=end,
    )

def test_api(symbol):
    """Run ``testing.evaluate`` for ``symbol`` and return its result.

    Backs the "Run Test" button and the ``test`` API endpoint; the returned
    value is displayed in a ``gr.JSON`` output.

    NOTE(review): unlike ``train_api``/``predict_api`` this handler is not
    wrapped in ``@spaces.GPU`` -- confirm ``evaluate`` does not need a GPU.
    """
    report = eval_fn(symbol)
    return report

@spaces.GPU
def predict_api(symbol, days=1):
    """Predict upcoming values for ``symbol`` via ``inference.predict_next``.

    Backs the "Predict" button and the ``predict`` API endpoint.

    Args:
        symbol: Ticker symbol, forwarded unchanged.
        days: Number of days to predict, forwarded as ``n_days``. A cleared
            Gradio ``Number`` field arrives as ``None``; instead of crashing
            with ``int(None)`` -> ``TypeError`` the default of 1 is used.

    Returns:
        Whatever ``predict_fn`` returns; the UI renders it in a ``gr.JSON``.
    """
    n_days = 1 if days is None else int(days)
    return predict_fn(symbol, n_days=n_days)

def hello_api(name="world"):
    """Return a greeting payload of the form ``{"message": "hello <name>"}``.

    Minimal handler exposed as the ``hello`` API endpoint.
    """
    message = f"hello {name}"
    return {"message": message}

# Build the Gradio UI: one tab per operation. Each button's click handler is
# also published as a named API endpoint through ``api_name``
# (train / test / predict / hello). Components render in creation order.
with gr.Blocks() as demo:
    gr.Markdown("## LSTM Stock Predictor (PyTorch • Train / Test / Predict)")
    with gr.Tab("Train"):
        sym_t = gr.Textbox(label="Symbol", value="AAPL")
        # precision=0 asks Gradio for whole numbers; train_api still casts
        # each value with int() before calling train_fn.
        seq = gr.Number(label="Seq length", value=60, precision=0)
        ep = gr.Number(label="Epochs", value=5, precision=0)
        bs = gr.Number(label="Batch size", value=32, precision=0)
        # Left blank -> train_api forwards None for the date bound.
        start = gr.Textbox(label="Start (YYYY-MM-DD)", placeholder="optional")
        end = gr.Textbox(label="End (YYYY-MM-DD)", placeholder="optional")
        btn_t = gr.Button("Train")
        out_t = gr.JSON()
        # Inputs are passed positionally, in this order, to
        # train_api(symbol, seq_len, epochs, batch_size, start, end).
        btn_t.click(train_api, [sym_t, seq, ep, bs, start, end], out_t, api_name="train")

    with gr.Tab("Test"):
        sym_e = gr.Textbox(label="Symbol", value="AAPL")
        btn_e = gr.Button("Run Test")
        out_e = gr.JSON()
        btn_e.click(test_api, [sym_e], out_e, api_name="test")

    with gr.Tab("Predict"):
        sym_p = gr.Textbox(label="Symbol", value="AAPL")
        days = gr.Number(label="Days to predict", value=1, precision=0)
        btn_p = gr.Button("Predict")
        out_p = gr.JSON()
        # Maps to predict_api(symbol, days).
        btn_p.click(predict_api, [sym_p, days], out_p, api_name="predict")
    
    with gr.Tab("Hello"):
        who = gr.Textbox(label="Name", value="world")
        btn_h = gr.Button("Say Hello")
        out_h = gr.JSON()
        btn_h.click(hello_api, [who], out_h, api_name="hello")

# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()