# File size: 3,925 Bytes
# 7ebf996
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# app.py
import json

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure

from inference import infer  # make sure inference.py is in the same folder

def parse_prices(text, csv_file):
    """Return closing prices from an uploaded CSV or from pasted text.

    An uploaded CSV takes priority: when ``csv_file`` is given, ``text`` is
    ignored even if the CSV yields no prices.

    CSV input uses the ``Close`` column (exact match first, then a case- and
    whitespace-insensitive match such as ``close``). If no such column exists,
    the first numeric column is used. Missing values are dropped.

    Text input is a list of numbers separated by commas and/or any
    whitespace (spaces, tabs, newlines).

    Args:
        text: Pasted prices, or ``None`` / ``""``.
        csv_file: ``None`` when nothing was uploaded. Otherwise either a
            path (``str`` or ``os.PathLike``) or an object exposing the path
            as ``.name``. Which form arrives depends on the Gradio version /
            ``gr.File`` ``type`` setting.

    Returns:
        A list of floats, or ``[]`` if there is no usable input.
    """
    if csv_file is not None:
        return _prices_from_csv(csv_file)
    if text:
        return _prices_from_text(text)
    return []


def _prices_from_csv(csv_file):
    """Read the price column of an uploaded CSV; ``[]`` if unreadable or none found."""
    # Accept a plain path (str / PathLike) as well as an upload object carrying
    # its path in ``.name``. Reading ``.name`` unconditionally breaks on str
    # paths, and on pathlib.Path it would return only the basename.
    if isinstance(csv_file, str) or hasattr(csv_file, "__fspath__"):
        path = csv_file
    else:
        path = getattr(csv_file, "name", csv_file)
    try:
        df = pd.read_csv(path)
        column = _find_price_column(df)
        if column is None:
            return []
        return df[column].dropna().astype(float).tolist()
    except Exception:
        # Deliberate best-effort boundary: any unreadable or unparseable upload
        # is reported by the caller as "no valid input prices" instead of
        # crashing the UI.
        return []


def _find_price_column(df):
    """Return the Close column label, else the first numeric column label, else None."""
    if "Close" in df.columns:
        return "Close"
    for col in df.columns:
        if str(col).strip().lower() == "close":
            return col
    for col in df.columns:
        if pd.api.types.is_numeric_dtype(df[col]):
            return col
    return None


def _prices_from_text(text):
    """Parse comma- and/or whitespace-separated floats; ``[]`` if any token is invalid."""
    # Treat commas as whitespace, then split on runs of whitespace (empty tokens vanish).
    tokens = text.replace(",", " ").split()
    try:
        return [float(tok) for tok in tokens]
    except ValueError:
        return []

def _plot_forecast(history, preds):
    """Build a line chart of ``history`` followed by the dashed ``preds`` forecast.

    History points sit at x = -len(history) .. -1 and forecast points at
    x = 0 .. len(preds) - 1. The forecast line is prefixed with the last
    history point so the two series connect.

    A standalone ``Figure`` is used instead of ``plt.subplots``. This keeps
    pyplot's global figure registry from growing by one figure per request,
    and stops concurrent requests from sharing pyplot's "current figure".

    Args:
        history: Non-empty list of the most recent prices to draw.
        preds: List of forecast values (may be empty).

    Returns:
        The ``matplotlib.figure.Figure`` containing the chart.
    """
    hist_x = list(range(-len(history), 0))
    # Start the forecast at the last history point (x = -1) so the lines join.
    forecast_x = list(range(-1, len(preds)))
    forecast_y = [history[-1]] + list(preds)

    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    ax.plot(hist_x, history, marker="o", label=f"History (last {len(history)})")
    ax.plot(forecast_x, forecast_y, marker="o", linestyle="--", label="Forecast")
    ax.set_xlabel("Time (relative)")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    return fig

def run_forecast(model_type, prices_text, csv_file, steps, epochs, plot_history_len):
    """Run a forecast for the UI and return the predictions plus a plot.

    Args:
        model_type: Model name forwarded to ``infer``. The UI offers
            ``"arima"`` and ``"lstm"``.
        prices_text: Pasted prices (see ``parse_prices``).
        csv_file: Uploaded CSV value from ``gr.File`` (see ``parse_prices``).
        steps: Number of future points to forecast.
        epochs: Training epochs, forwarded to ``infer`` for every model.
        plot_history_len: How many of the most recent prices to draw.

    Returns:
        On success, ``(json_text, figure)``, where ``json_text`` is
        ``{"model": ..., "predictions": [...]}`` serialized as JSON.
        On failure, ``(error_message, None)``.
    """
    prices = parse_prices(prices_text, csv_file)
    if not prices:
        return "ERROR: No valid input prices found. Upload a CSV with a Close column or paste comma-separated prices.", None

    if len(prices) < 2 and model_type.lower() == "arima":
        return "ERROR: Need at least 2 prices for ARIMA.", None

    try:
        # Slider values may arrive as floats; infer() gets integer counts.
        raw_preds = infer(model_type, prices, steps=int(steps), epochs=int(epochs))
        # The return type of infer() is not visible here. It may be a list,
        # NumPy array or pandas Series. Flatten it to plain floats so that
        # list concatenation in the plot and JSON serialization behave the
        # same for all of these (list + ndarray would broadcast-add instead).
        preds = np.asarray(raw_preds, dtype=float).ravel().tolist()
    except Exception as e:
        return f"ERROR during inference: {e}", None

    # At least one history point is needed to anchor the forecast line.
    hist_len = max(1, min(int(plot_history_len), len(prices)))
    fig = _plot_forecast(prices[-hist_len:], preds)

    payload = {"model": model_type, "predictions": preds}
    return json.dumps(payload), fig

# Gradio UI: a title, then a two-column row (inputs/controls left, results right).
with gr.Blocks() as demo:
    gr.Markdown("# Stock Forecast — ARIMA vs LSTM\nUpload a CSV with a `Close` column or paste comma-separated closing prices.")

    with gr.Row():
        # Left column: model choice, numeric settings, the two price sources, and the run button.
        with gr.Column(scale=1):
            model_type = gr.Radio(choices=["arima", "lstm"], value="arima", label="Model")
            steps = gr.Slider(minimum=1, maximum=30, step=1, value=5, label="Forecast steps")
            # run_forecast forwards this to infer() for every model; the label
            # documents that only the LSTM path is expected to use it.
            epochs = gr.Slider(minimum=1, maximum=100, step=1, value=5, label="LSTM training epochs (only used for LSTM)")
            plot_history_len = gr.Slider(minimum=10, maximum=500, step=10, value=100, label="History length to plot")

            # When a CSV is uploaded, parse_prices reads it and ignores the pasted text.
            csv_file = gr.File(label="Upload CSV (optional, must include Close column)")
            prices_text = gr.Textbox(lines=4, placeholder="Or paste comma-separated prices (e.g. 100,101.5,102)", label="Paste prices (optional)")

            run_btn = gr.Button("Run forecast")
        # Right column: the predictions text and the history + forecast figure.
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Predictions (JSON string)")
            output_plot = gr.Plot(label="History + Forecast Plot")

    # `inputs` must stay in the same order as run_forecast's parameters
    # (model_type, prices_text, csv_file, steps, epochs, plot_history_len);
    # `outputs` receive run_forecast's (text, figure) return pair.
    run_btn.click(fn=run_forecast,
                  inputs=[model_type, prices_text, csv_file, steps, epochs, plot_history_len],
                  outputs=[output_text, output_plot])

if __name__ == "__main__":
    # Start the web server: "0.0.0.0" listens on all network interfaces
    # (not only localhost), on port 7860.
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )