Ti-sha commited on
Commit
7ebf996
·
verified ·
1 Parent(s): e2680d8

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import pandas as pd
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from inference import infer # make sure inference.py is in the same folder
7
+
8
+ def parse_prices(text, csv_file):
9
+ # Priority: uploaded CSV > pasted text
10
+ if csv_file is not None:
11
+ try:
12
+ df = pd.read_csv(csv_file.name)
13
+ if "Close" in df.columns:
14
+ prices = df["Close"].dropna().astype(float).tolist()
15
+ return prices
16
+ # if no Close column fall back to first numeric column
17
+ for c in df.columns:
18
+ if pd.api.types.is_numeric_dtype(df[c]):
19
+ return df[c].dropna().astype(float).tolist()
20
+ return []
21
+ except Exception as e:
22
+ return []
23
+ if text:
24
+ # accept comma or newline separated floats
25
+ tokens = [t.strip() for t in text.replace("\n", ",").split(",") if t.strip() != ""]
26
+ try:
27
+ return [float(t) for t in tokens]
28
+ except:
29
+ return []
30
+ return []
31
+
32
+ def run_forecast(model_type, prices_text, csv_file, steps, epochs, plot_history_len):
33
+ prices = parse_prices(prices_text, csv_file)
34
+ if not prices:
35
+ return "ERROR: No valid input prices found. Upload a CSV with a Close column or paste comma-separated prices.", None
36
+
37
+ # ensure list length is reasonable
38
+ if len(prices) < 2 and model_type.lower() == "arima":
39
+ return "ERROR: Need at least 2 prices for ARIMA.", None
40
+
41
+ # Call infer (inference.infer should accept epochs param)
42
+ try:
43
+ preds = infer(model_type, prices, steps=steps, epochs=epochs)
44
+ except Exception as e:
45
+ return f"ERROR during inference: {e}", None
46
+
47
+ # Build a simple plot: last N history points + forecast points
48
+ hist_len = min(plot_history_len, len(prices))
49
+ hist_x = list(range(-hist_len, 0))
50
+ hist_y = prices[-hist_len:]
51
+
52
+ forecast_x = list(range(0, len(preds)))
53
+ forecast_y = preds
54
+
55
+ fig, ax = plt.subplots(figsize=(8,4))
56
+ ax.plot(hist_x, hist_y, marker="o", label="History (last {})".format(hist_len))
57
+ # plot forecast continuing after history
58
+ ax.plot([hist_x[-1]] + [hist_x[-1] + i + 1 for i in forecast_x],
59
+ [hist_y[-1]] + forecast_y, marker="o", linestyle="--", label="Forecast")
60
+ ax.set_xlabel("Time (relative)")
61
+ ax.set_ylabel("Price")
62
+ ax.legend()
63
+ ax.grid(True)
64
+ plt.tight_layout()
65
+
66
+ # return predictions (list) and the matplotlib figure
67
+ preds_text = {"model": model_type, "predictions": preds}
68
+ return str(preds_text), fig
69
+
70
+ with gr.Blocks() as demo:
71
+ gr.Markdown("# Stock Forecast — ARIMA vs LSTM\nUpload a CSV with a `Close` column or paste comma-separated closing prices.")
72
+
73
+ with gr.Row():
74
+ with gr.Column(scale=1):
75
+ model_type = gr.Radio(choices=["arima", "lstm"], value="arima", label="Model")
76
+ steps = gr.Slider(minimum=1, maximum=30, step=1, value=5, label="Forecast steps")
77
+ epochs = gr.Slider(minimum=1, maximum=100, step=1, value=5, label="LSTM training epochs (only used for LSTM)")
78
+ plot_history_len = gr.Slider(minimum=10, maximum=500, step=10, value=100, label="History length to plot")
79
+
80
+ csv_file = gr.File(label="Upload CSV (optional, must include Close column)")
81
+ prices_text = gr.Textbox(lines=4, placeholder="Or paste comma-separated prices (e.g. 100,101.5,102)", label="Paste prices (optional)")
82
+
83
+ run_btn = gr.Button("Run forecast")
84
+ with gr.Column(scale=1):
85
+ output_text = gr.Textbox(label="Predictions (JSON string)")
86
+ output_plot = gr.Plot(label="History + Forecast Plot")
87
+
88
+ run_btn.click(fn=run_forecast,
89
+ inputs=[model_type, prices_text, csv_file, steps, epochs, plot_history_len],
90
+ outputs=[output_text, output_plot])
91
+
92
+ if __name__ == "__main__":
93
+ demo.launch(server_name="0.0.0.0", server_port=7860)