# app.py
"""Gradio demo: forecast closing prices with an ARIMA or LSTM model.

Prices come from an uploaded CSV (preferred) or from pasted text; the
forecast itself is delegated to ``inference.infer``.
"""
import json
import logging
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from inference import infer  # make sure inference.py is in the same folder

logger = logging.getLogger(__name__)


def _uploaded_path(csv_file):
    """Return the filesystem path for an uploaded-file value.

    Accepts either a plain path (``str`` / ``os.PathLike``) or an object that
    exposes the path as ``.name`` (e.g. a temp-file wrapper).
    """
    if isinstance(csv_file, (str, os.PathLike)):
        return os.fspath(csv_file)
    return csv_file.name


def _prices_from_csv(path):
    """Read prices from the CSV at *path*.

    Uses the ``Close`` column when present, otherwise the first column with a
    numeric dtype. Missing values are dropped. Returns ``[]`` when no usable
    column exists; read/convert errors propagate to the caller.
    """
    df = pd.read_csv(path)
    if "Close" in df.columns:
        return df["Close"].dropna().astype(float).tolist()
    # No Close column: fall back to the first numeric column.
    for col in df.columns:
        if pd.api.types.is_numeric_dtype(df[col]):
            return df[col].dropna().astype(float).tolist()
    return []


def parse_prices(text, csv_file):
    """Extract a list of float prices from the user's input.

    Priority: uploaded CSV > pasted text. When a CSV is supplied the pasted
    text is ignored, even if the CSV turns out to be unusable.

    Args:
        text: Comma- and/or newline-separated numbers, or an empty value.
        csv_file: Uploaded file (path or object with ``.name``), or ``None``.

    Returns:
        A list of floats; ``[]`` when nothing valid could be parsed.
    """
    if csv_file is not None:
        try:
            return _prices_from_csv(_uploaded_path(csv_file))
        except Exception:
            # Best-effort on user-supplied files: report "no prices" to the
            # caller, but leave a trace of why instead of failing silently.
            logger.exception("Could not read prices from uploaded CSV")
            return []

    if text:
        # Accept comma or newline separated floats; skip empty tokens.
        tokens = (t.strip() for t in text.replace("\n", ",").split(","))
        try:
            return [float(t) for t in tokens if t]
        except ValueError:
            return []

    return []


def _build_plot(prices, preds, hist_len):
    """Draw the last *hist_len* prices followed by the forecast.

    History occupies x = -hist_len .. -1; the forecast line starts at the
    last history point (x = -1) and continues at x = 0, 1, ...
    """
    hist_x = list(range(-hist_len, 0))
    hist_y = prices[-hist_len:]

    fig, ax = plt.subplots(figsize=(8, 4))
    ax.plot(hist_x, hist_y, marker="o", label="History (last {})".format(hist_len))
    # Prepend the last history point so the dashed line connects to history.
    ax.plot(
        list(range(hist_x[-1], len(preds))),
        [hist_y[-1]] + preds,
        marker="o",
        linestyle="--",
        label="Forecast",
    )
    ax.set_xlabel("Time (relative)")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)
    # Operate on this figure explicitly rather than pyplot's global "current
    # figure", which may belong to another in-flight request.
    fig.tight_layout()
    # Drop the figure from pyplot's registry so repeated requests do not
    # accumulate open figures; the Figure object itself stays usable.
    plt.close(fig)
    return fig


def run_forecast(model_type, prices_text, csv_file, steps, epochs, plot_history_len):
    """Parse the inputs, run the selected model and build the outputs.

    Args:
        model_type: Model name passed through to ``infer`` (UI offers
            ``"arima"`` and ``"lstm"``).
        prices_text: Pasted prices (used only when no CSV is uploaded).
        csv_file: Uploaded CSV, or ``None``.
        steps: Number of forecast steps.
        epochs: Training epochs forwarded to ``infer``.
        plot_history_len: Maximum number of trailing history points to plot.

    Returns:
        ``(text, figure)``: on success a JSON string with the model name and
        predictions plus a matplotlib figure; on failure an ``"ERROR: ..."``
        message and ``None``.
    """
    prices = parse_prices(prices_text, csv_file)
    if not prices:
        return "ERROR: No valid input prices found. Upload a CSV with a Close column or paste comma-separated prices.", None

    # Ensure list length is reasonable.
    if len(prices) < 2 and model_type.lower() == "arima":
        return "ERROR: Need at least 2 prices for ARIMA.", None

    # UI widgets may deliver numeric values as floats; range() and step
    # counts need ints.
    steps = int(steps)
    epochs = int(epochs)
    # Clamp to >= 1: a zero length would make prices[-0:] the whole list and
    # leave no anchor point for the forecast line.
    hist_len = max(1, min(int(plot_history_len), len(prices)))

    # Call infer (inference.infer should accept epochs param).
    try:
        raw_preds = infer(model_type, prices, steps=steps, epochs=epochs)
        # Normalise to a flat list of floats so list concatenation in the
        # plot and JSON serialisation work for array-like results too.
        preds = np.asarray(raw_preds, dtype=float).ravel().tolist()
    except Exception as e:
        logger.exception("Inference failed for model %r", model_type)
        return f"ERROR during inference: {e}", None

    fig = _build_plot(prices, preds, hist_len)

    # Return predictions as real JSON (the output box is labelled as such)
    # together with the matplotlib figure.
    preds_text = json.dumps({"model": model_type, "predictions": preds})
    return preds_text, fig


with gr.Blocks() as demo:
    gr.Markdown("# Stock Forecast — ARIMA vs LSTM\nUpload a CSV with a `Close` column or paste comma-separated closing prices.")
    with gr.Row():
        with gr.Column(scale=1):
            model_type = gr.Radio(choices=["arima", "lstm"], value="arima", label="Model")
            steps = gr.Slider(minimum=1, maximum=30, step=1, value=5, label="Forecast steps")
            epochs = gr.Slider(minimum=1, maximum=100, step=1, value=5, label="LSTM training epochs (only used for LSTM)")
            plot_history_len = gr.Slider(minimum=10, maximum=500, step=10, value=100, label="History length to plot")
            csv_file = gr.File(label="Upload CSV (optional, must include Close column)")
            prices_text = gr.Textbox(lines=4, placeholder="Or paste comma-separated prices (e.g. 100,101.5,102)", label="Paste prices (optional)")
            run_btn = gr.Button("Run forecast")
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Predictions (JSON string)")
            output_plot = gr.Plot(label="History + Forecast Plot")

    run_btn.click(
        fn=run_forecast,
        inputs=[model_type, prices_text, csv_file, steps, epochs, plot_history_len],
        outputs=[output_text, output_plot],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)