amitke commited on
Commit
7fd066b
·
1 Parent(s): b53cc5c

fix @spaces.GPU

Browse files
Files changed (2) hide show
  1. app.py +1 -0
  2. apptest.py +0 -58
app.py CHANGED
@@ -16,6 +16,7 @@ def train_api(symbol, seq_len=60, epochs=5, batch_size=32, start="", end=""):
16
  def test_api(symbol):
17
  return eval_fn(symbol)
18
 
 
19
  def predict_api(symbol, days=1):
20
  return predict_fn(symbol, n_days=int(days))
21
 
 
16
  def test_api(symbol):
17
  return eval_fn(symbol)
18
 
19
+ @spaces.GPU
20
  def predict_api(symbol, days=1):
21
  return predict_fn(symbol, n_days=int(days))
22
 
apptest.py DELETED
@@ -1,58 +0,0 @@
1
- import gradio as gr
2
- from train import train as train_fn
3
- from testing import evaluate as eval_fn
4
- from inference import predict_next as predict_fn
5
-
6
- def train_api(symbol, seq_len=60, epochs=5, batch_size=32, start="", end=""):
7
- return train_fn(
8
- symbol,
9
- seq_len=int(seq_len),
10
- epochs=int(epochs),
11
- batch_size=int(batch_size),
12
- start=start or None,
13
- end=end or None,
14
- )
15
-
16
- def test_api(symbol):
17
- return eval_fn(symbol)
18
-
19
- def predict_api(symbol, days=1):
20
- return predict_fn(symbol, n_days=int(days))
21
-
22
- def hello_api(name="world"):
23
- return {"message": f"hello {name}"}
24
-
25
- with gr.Blocks() as demo:
26
- gr.Markdown("## LSTM Stock Predictor (PyTorch • Train / Test / Predict)")
27
- with gr.Tab("Train"):
28
- sym_t = gr.Textbox(label="Symbol", value="AAPL")
29
- seq = gr.Number(label="Seq length", value=60, precision=0)
30
- ep = gr.Number(label="Epochs", value=5, precision=0)
31
- bs = gr.Number(label="Batch size", value=32, precision=0)
32
- start = gr.Textbox(label="Start (YYYY-MM-DD)", placeholder="optional")
33
- end = gr.Textbox(label="End (YYYY-MM-DD)", placeholder="optional")
34
- btn_t = gr.Button("Train")
35
- out_t = gr.JSON()
36
- btn_t.click(train_api, [sym_t, seq, ep, bs, start, end], out_t, api_name="train")
37
-
38
- with gr.Tab("Test"):
39
- sym_e = gr.Textbox(label="Symbol", value="AAPL")
40
- btn_e = gr.Button("Run Test")
41
- out_e = gr.JSON()
42
- btn_e.click(test_api, [sym_e], out_e, api_name="test")
43
-
44
- with gr.Tab("Predict"):
45
- sym_p = gr.Textbox(label="Symbol", value="AAPL")
46
- days = gr.Number(label="Days to predict", value=1, precision=0)
47
- btn_p = gr.Button("Predict")
48
- out_p = gr.JSON()
49
- btn_p.click(predict_api, [sym_p, days], out_p, api_name="predict")
50
-
51
- with gr.Tab("Hello"):
52
- who = gr.Textbox(label="Name", value="world")
53
- btn_h = gr.Button("Say Hello")
54
- out_h = gr.JSON()
55
- btn_h.click(hello_api, [who], out_h, api_name="hello")
56
-
57
- if __name__ == "__main__":
58
- demo.launch()