# Ti-sha's picture
# Create app.py
# 7ebf996 verified
# app.py
import json
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure

from inference import infer  # make sure inference.py is in the same folder
def parse_prices(text, csv_file):
    """Extract a list of float prices from an uploaded CSV or pasted text.

    An uploaded CSV takes priority over pasted text: when ``csv_file`` is
    given, ``text`` is ignored even if the CSV yields no prices.

    Args:
        text: Prices separated by commas and/or newlines, or ``None``/``""``.
        csv_file: ``None``, a filesystem path (``str`` / path-like), or an
            object exposing the file path as ``.name`` (e.g. a tempfile
            wrapper).

    Returns:
        A list of floats. For a CSV the values come from the ``Close`` column
        when present, otherwise from the first numeric column; missing values
        are dropped. Returns ``[]`` when no usable input is found or it cannot
        be parsed.
    """
    if csv_file is not None:
        try:
            # Depending on the Gradio version the upload arrives either as a
            # plain path string or as a tempfile-like object with `.name`;
            # accept both instead of assuming `.name` exists on a str.
            if isinstance(csv_file, (str, os.PathLike)):
                path = csv_file
            else:
                path = csv_file.name
            df = pd.read_csv(path)
            if "Close" in df.columns:
                return df["Close"].dropna().astype(float).tolist()
            # No Close column: fall back to the first numeric column.
            for col in df.columns:
                if pd.api.types.is_numeric_dtype(df[col]):
                    return df[col].dropna().astype(float).tolist()
        except Exception:
            # Deliberate best-effort at the UI boundary: any unreadable or
            # unparseable upload (I/O, encoding, parser, conversion error) is
            # reported to the caller as "no prices" instead of crashing.
            return []
        return []
    if text:
        # Accept comma- or newline-separated values; skip empty tokens.
        tokens = [t.strip() for t in text.replace("\n", ",").split(",") if t.strip()]
        try:
            return [float(t) for t in tokens]
        except ValueError:
            # At least one token is not a number.
            return []
    return []
def _plot_history_and_forecast(prices, preds, history_len):
    """Return a Figure showing the last ``history_len`` prices and the forecast.

    Args:
        prices: Non-empty list of historical prices.
        preds: List of forecast values (may be empty).
        history_len: Maximum number of trailing history points to draw.
    """
    # At least one history point is needed to anchor the forecast line.
    hist_len = max(1, min(int(history_len), len(prices)))
    hist_x = list(range(-hist_len, 0))
    hist_y = prices[-hist_len:]
    # The forecast continues from the last history point (x = -1) onward.
    forecast_x = [hist_x[-1]] + [hist_x[-1] + i + 1 for i in range(len(preds))]
    forecast_y = [hist_y[-1]] + preds

    # Build the figure without pyplot: pyplot keeps global "current figure"
    # state and retains every figure until it is closed, which leaks memory
    # and can race between concurrent requests in a web server.
    fig = Figure(figsize=(8, 4))
    ax = fig.subplots()
    ax.plot(hist_x, hist_y, marker="o", label="History (last {})".format(hist_len))
    ax.plot(forecast_x, forecast_y, marker="o", linestyle="--", label="Forecast")
    ax.set_xlabel("Time (relative)")
    ax.set_ylabel("Price")
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    return fig


def run_forecast(model_type, prices_text, csv_file, steps, epochs, plot_history_len):
    """Parse input prices, run the selected model and plot the forecast.

    Gradio callback for the "Run forecast" button.

    Args:
        model_type: Model name forwarded to ``infer`` (the UI offers
            "arima" and "lstm").
        prices_text: Pasted prices (comma/newline separated); may be empty.
        csv_file: Uploaded CSV or ``None``; takes priority over ``prices_text``.
        steps: Number of future points to forecast.
        epochs: Training epochs forwarded to ``infer``.
        plot_history_len: Maximum number of trailing history points to plot.

    Returns:
        ``(text, figure)``. On success ``text`` is a JSON object
        ``{"model": ..., "predictions": [...]}`` and ``figure`` is a
        matplotlib Figure; on failure ``text`` starts with "ERROR" and
        ``figure`` is ``None``.
    """
    prices = parse_prices(prices_text, csv_file)
    if not prices:
        return "ERROR: No valid input prices found. Upload a CSV with a Close column or paste comma-separated prices.", None
    # ensure list length is reasonable
    if len(prices) < 2 and model_type.lower() == "arima":
        return "ERROR: Need at least 2 prices for ARIMA.", None

    try:
        # Slider values may arrive as floats; pass whole numbers to the model.
        raw_preds = infer(model_type, prices, steps=int(steps), epochs=int(epochs))
        # Normalise whatever infer returns (list, ndarray, pandas Series) to a
        # flat list of Python floats. With an ndarray, `[x] + preds` would
        # broadcast an addition instead of concatenating, and numpy scalars
        # are not JSON-serialisable.
        preds = np.asarray(raw_preds, dtype=float).ravel().tolist()
    except Exception as e:
        return f"ERROR during inference: {e}", None

    fig = _plot_history_and_forecast(prices, preds, plot_history_len)
    # The output box is labelled "JSON string", so emit real JSON.
    return json.dumps({"model": model_type, "predictions": preds}), fig
# Build the Gradio UI: a row with two columns, with input controls on the
# left and the forecast outputs on the right. Components are rendered in the
# order they are created inside each `with` context.
with gr.Blocks() as demo:
    gr.Markdown("# Stock Forecast — ARIMA vs LSTM\nUpload a CSV with a `Close` column or paste comma-separated closing prices.")
    with gr.Row():
        # Left column: model choice, forecast settings and price inputs.
        with gr.Column(scale=1):
            model_type = gr.Radio(choices=["arima", "lstm"], value="arima", label="Model")
            steps = gr.Slider(minimum=1, maximum=30, step=1, value=5, label="Forecast steps")
            epochs = gr.Slider(minimum=1, maximum=100, step=1, value=5, label="LSTM training epochs (only used for LSTM)")
            plot_history_len = gr.Slider(minimum=10, maximum=500, step=10, value=100, label="History length to plot")
            # parse_prices gives an uploaded CSV priority over pasted text.
            csv_file = gr.File(label="Upload CSV (optional, must include Close column)")
            prices_text = gr.Textbox(lines=4, placeholder="Or paste comma-separated prices (e.g. 100,101.5,102)", label="Paste prices (optional)")
            run_btn = gr.Button("Run forecast")
        # Right column: textual predictions and the history/forecast plot.
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Predictions (JSON string)")
            output_plot = gr.Plot(label="History + Forecast Plot")
    # The order of `inputs` must match run_forecast's positional parameters;
    # its two return values fill `outputs` in the same order.
    run_btn.click(fn=run_forecast,
                  inputs=[model_type, prices_text, csv_file, steps, epochs, plot_history_len],
                  outputs=[output_text, output_plot])

if __name__ == "__main__":
    # Serve on all network interfaces (0.0.0.0), port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)