Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from inference import infer # make sure inference.py is in the same folder | |
def parse_prices(text, csv_file):
    """Return the user-supplied price series as a list of floats.

    An uploaded CSV takes priority over pasted text: the text is only
    consulted when no file was uploaded.

    Args:
        text: Pasted prices separated by commas and/or newlines; may be empty
            or None.
        csv_file: None, a path to the uploaded CSV, or an object exposing that
            path via ``.name`` (the tempfile wrapper older Gradio versions pass).

    Returns:
        The parsed prices. An empty list means nothing usable was found: no
        input, an unreadable/malformed CSV, a CSV with no usable numeric
        column, or a non-numeric token in the pasted text.
    """
    if csv_file is not None:
        # Newer Gradio versions hand the callback a plain filepath string;
        # older ones pass a tempfile-like object whose .name is the path.
        path = getattr(csv_file, "name", csv_file)
        try:
            df = pd.read_csv(path)
            if "Close" in df.columns:
                column = "Close"
            else:
                # Accept "close", "CLOSE", " Close " etc. as the price column.
                column = next(
                    (c for c in df.columns if str(c).strip().lower() == "close"),
                    None,
                )
            if column is None:
                # No close column: fall back to the first numeric column,
                # skipping boolean flags (pandas counts bool as numeric).
                column = next(
                    (
                        c
                        for c in df.columns
                        if pd.api.types.is_numeric_dtype(df[c])
                        and not pd.api.types.is_bool_dtype(df[c])
                    ),
                    None,
                )
            if column is None:
                return []
            return df[column].dropna().astype(float).tolist()
        except (OSError, ValueError, TypeError):
            # Missing/unreadable file, malformed CSV (pandas parser errors are
            # ValueError subclasses) or non-numeric values in the column.
            return []
    if text:
        # Accept comma- and/or newline-separated floats; blank tokens are skipped.
        tokens = [t.strip() for t in text.replace("\n", ",").split(",")]
        try:
            return [float(t) for t in tokens if t]
        except ValueError:
            return []
    return []
def run_forecast(model_type, prices_text, csv_file, steps, epochs, plot_history_len):
    """Forecast the user's price series and build a history + forecast plot.

    Args:
        model_type: Model name forwarded to ``inference.infer`` (the UI offers
            "arima" and "lstm").
        prices_text: Pasted prices, parsed by ``parse_prices``.
        csv_file: Uploaded CSV, parsed by ``parse_prices``; wins over the text.
        steps: Number of future points to forecast.
        epochs: Training epochs forwarded to ``infer``.
        plot_history_len: How many of the most recent input prices to plot.

    Returns:
        ``(message, figure)``. On success ``message`` is a JSON object string
        ``{"model": ..., "predictions": [...]}`` and ``figure`` is a matplotlib
        Figure. On failure ``message`` starts with ``"ERROR"`` and ``figure``
        is None.
    """
    import json

    from matplotlib.figure import Figure

    prices = parse_prices(prices_text, csv_file)
    if not prices:
        return "ERROR: No valid input prices found. Upload a CSV with a Close column or paste comma-separated prices.", None
    if len(prices) < 2 and model_type.lower() == "arima":
        return "ERROR: Need at least 2 prices for ARIMA.", None

    # Slider values may arrive as floats; range() and the models need ints.
    steps = int(steps)
    epochs = int(epochs)
    plot_history_len = int(plot_history_len)

    try:
        # infer may return a list or a numpy array. Flatten to plain floats so
        # list concatenation (below) and JSON serialisation behave.
        preds = [float(p) for p in np.ravel(infer(model_type, prices, steps=steps, epochs=epochs))]
    except Exception as e:  # UI boundary: surface any model failure as text
        return f"ERROR during inference: {e}", None

    # Plot at least the last observed price so the forecast line has an anchor.
    hist_len = max(1, min(plot_history_len, len(prices)))
    hist_x = list(range(-hist_len, 0))
    hist_y = prices[-hist_len:]

    # Build the Figure directly rather than through pyplot. pyplot keeps every
    # figure alive in a global registry (leaking memory across requests), and
    # its "current figure" state is shared between concurrent requests.
    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    ax.plot(hist_x, hist_y, marker="o", label=f"History (last {hist_len})")
    # The forecast starts at the last observed point (x = -1) and continues at
    # x = 0, 1, ..., len(preds) - 1.
    ax.plot(
        list(range(-1, len(preds))),
        [hist_y[-1]] + preds,
        marker="o",
        linestyle="--",
        label="Forecast",
    )
    ax.set_xlabel("Time (relative)")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()

    # Emit real JSON, matching the "Predictions (JSON string)" output label.
    return json.dumps({"model": model_type, "predictions": preds}), fig
# Gradio UI: inputs on the left, forecast text and plot on the right.
with gr.Blocks() as demo:
    gr.Markdown("# Stock Forecast — ARIMA vs LSTM\nUpload a CSV with a `Close` column or paste comma-separated closing prices.")
    with gr.Row():
        with gr.Column(scale=1):
            model_type = gr.Radio(choices=["arima", "lstm"], value="arima", label="Model")
            steps = gr.Slider(minimum=1, maximum=30, step=1, value=5, label="Forecast steps")
            epochs = gr.Slider(minimum=1, maximum=100, step=1, value=5, label="LSTM training epochs (only used for LSTM)")
            plot_history_len = gr.Slider(minimum=10, maximum=500, step=10, value=100, label="History length to plot")
            # Restrict the picker to .csv, the only format parse_prices reads.
            csv_file = gr.File(
                label="Upload CSV (optional, must include Close column)",
                file_types=[".csv"],
            )
            prices_text = gr.Textbox(lines=4, placeholder="Or paste comma-separated prices (e.g. 100,101.5,102)", label="Paste prices (optional)")
            run_btn = gr.Button("Run forecast")
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Predictions (JSON string)")
            output_plot = gr.Plot(label="History + Forecast Plot")
    # Input order must match run_forecast's positional parameters.
    run_btn.click(
        fn=run_forecast,
        inputs=[model_type, prices_text, csv_file, steps, epochs, plot_history_len],
        outputs=[output_text, output_plot],
    )
if __name__ == "__main__":
    # Listen on all interfaces (reachable from outside a container) on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)