# hackfest_biLSTM / app.py
# Author: Div0013
# Latest commit 0058845: "cache examples set false"
# app.py
import gradio as gr
import json
import traceback
# Import necessary functions and the model class from your original script
# Make sure predict_stock_prices.py is in the same directory
from predict_stock_prices import (
BiLSTMModel, # Need to import the class for joblib/torch to load model correctly
predict_stock_prices,
batch_predict_to_json # Assuming this function takes the list and paths
)
# --- Model Configuration ---
# Bare relative file names (no directory component) of the artifacts passed to
# batch_predict_to_json; they must match the files uploaded to the Space.
MODEL_PATH, SCALER_PATH, METADATA_PATH = (
    "bilstm_stock_model.pth",
    "scaler_diff.pkl",
    "model_metadata.pkl",
)
# --- Gradio Interface Function ---
def run_prediction(ticker_string):
    """
    Run the BiLSTM batch prediction for a comma-separated list of tickers.

    The input is split on commas. Each symbol is stripped of surrounding
    whitespace and upper-cased, and empty entries are dropped. Duplicates are
    removed (first occurrence wins) so the same ticker is not fetched and
    predicted twice.

    Args:
        ticker_string: Raw text from the Gradio textbox, e.g. "AAPL, msft".

    Returns:
        The value returned by ``batch_predict_to_json``, passed straight
        through to the ``gr.JSON`` output. Returns ``{"error": <message>}``
        when the input contains no ticker symbols or the prediction call
        raises.
    """
    if not ticker_string:
        return {"error": "Please enter at least one ticker symbol."}

    # Normalise symbols and de-duplicate while preserving input order
    # (dict keys keep insertion order).
    tickers = list(dict.fromkeys(
        ticker.strip().upper()
        for ticker in ticker_string.split(',')
        if ticker.strip()
    ))
    if not tickers:
        return {"error": "No valid ticker symbols entered."}

    print(f"Received request for tickers: {tickers}")  # Log received tickers

    # Keep the try body to the call that can actually fail, so bugs in the
    # logging below are not misreported to the user as prediction errors.
    try:
        predictions = batch_predict_to_json(
            ticker_symbols=tickers,
            model_path=MODEL_PATH,
            scaler_path=SCALER_PATH,
            metadata_path=METADATA_PATH,
        )
    except FileNotFoundError as e:
        print(f"Error: Model file not found - {e}")
        return {"error": f"Required file not found: {e}. Ensure model, scaler, and metadata files are uploaded correctly."}
    except Exception as e:
        # Top-level UI boundary: log the full traceback, show a short message.
        print(f"An unexpected error occurred: {e}")
        traceback.print_exc()
        return {"error": f"An unexpected error occurred: {e}"}

    # Per-ticker failures are reported inside the result (entries that are
    # dicts containing an 'error' key); only inspect it when it is a dict.
    if isinstance(predictions, dict):
        print(f"Prediction completed for: {list(predictions)}")
        errors = {
            k: v for k, v in predictions.items()
            if isinstance(v, dict) and 'error' in v
        }
        if errors:
            print(f"Errors occurred during prediction: {errors}")  # Log errors
    return predictions
# --- Build Gradio Interface ---
# Markdown text passed to gr.Interface(description=...) below; the content
# is user-facing runtime text, so edit it deliberately.
description = """
## BiLSTM Stock Price Predictor (-15y / +15y)
Enter one or more stock ticker symbols (e.g., `AAPL`, `MSFT`, `GOOGL`), separated by commas.
The model will fetch historical data, predict future prices for the next 15 years using a BiLSTM model combined with Geometric Brownian Motion (GBM),
and return the historical data for the past 15 years (or less if unavailable) combined with the predictions.
**Note:**
* Predictions are based on historical 'Close' prices and involve inherent uncertainty. **This is not financial advice.**
* Fetching data and running predictions might take a moment, especially for multiple tickers.
* Ensure ticker symbols are valid on Yahoo Finance.
"""
# Single-textbox in, JSON out; run_prediction does all parsing and error
# reporting, so the UI itself stays declarative.
iface = gr.Interface(
    run_prediction,
    title="Stock Price Prediction",
    description=description,
    inputs=gr.Textbox(
        label="Ticker Symbols (comma-separated)",
        placeholder="Enter Ticker Symbols (e.g., AAPL, MSFT, GOOGL)",
        lines=1,
    ),
    outputs=gr.JSON(label="Prediction Results (Historical + Future Prices)"),
    examples=[
        ["AAPL"],             # single ticker
        ["MSFT,GOOGL,NVDA"],  # several tickers in one request
    ],
    # Do not pre-compute/cache example outputs.
    cache_examples=False,
    # Flagging disabled. NOTE(review): newer Gradio releases rename this
    # keyword to `flagging_mode`; confirm against the Space's pinned version.
    allow_flagging="never",
)
# --- Launch the App ---
# Plain launch: share=True is not needed when deploying on Hugging Face Spaces.
if __name__ == "__main__":
    iface.launch()