# app.py
"""Gradio front-end for the BiLSTM + GBM stock price predictor."""
import inspect
import json
import traceback

import gradio as gr

# Import necessary functions and the model class from your original script.
# Make sure predict_stock_prices.py is in the same directory.
from predict_stock_prices import (  # noqa: F401 (some names are imported only for their side effects)
    BiLSTMModel,  # Need to import the class for joblib/torch to load model correctly
    predict_stock_prices,
    batch_predict_to_json,  # Assuming this function takes the list and paths
)

# --- Model Configuration ---
# These paths should correspond to the files uploaded to your Hugging Face Space
MODEL_PATH = "bilstm_stock_model.pth"
SCALER_PATH = "scaler_diff.pkl"
METADATA_PATH = "model_metadata.pkl"


# --- Gradio Interface Function ---
def _parse_tickers(ticker_string):
    """Split a comma-separated string into unique, upper-cased ticker symbols.

    Surrounding whitespace and blank entries are dropped. Duplicates (after
    upper-casing) are removed while keeping first-seen order, so the same
    ticker is not fetched and predicted twice.
    """
    cleaned = (part.strip().upper() for part in ticker_string.split(","))
    return list(dict.fromkeys(symbol for symbol in cleaned if symbol))


def run_prediction(ticker_string):
    """Run the batch predictor for a comma-separated list of ticker symbols.

    Args:
        ticker_string: Raw textbox value, e.g. ``"AAPL, MSFT"``.

    Returns:
        The mapping produced by ``batch_predict_to_json``, returned unchanged
        so ``gr.JSON`` can render it. On invalid input or on any exception,
        a dict of the form ``{"error": <message>}`` is returned instead.
    """
    if not ticker_string:
        return {"error": "Please enter at least one ticker symbol."}

    tickers = _parse_tickers(ticker_string)
    if not tickers:
        return {"error": "No valid ticker symbols entered."}

    print(f"Received request for tickers: {tickers}")  # Log received tickers

    try:
        # Call your existing batch prediction function.
        predictions = batch_predict_to_json(
            ticker_symbols=tickers,
            model_path=MODEL_PATH,
            scaler_path=SCALER_PATH,
            metadata_path=METADATA_PATH,
        )
        # NOTE(review): expected to be a dict keyed by ticker; the function
        # name suggests it might return a JSON string instead, so decode that
        # case rather than failing on `.items()` below.
        if isinstance(predictions, str):
            predictions = json.loads(predictions)

        # Per-ticker failures are reported inside the results as {"error": ...}.
        errors = {
            k: v
            for k, v in predictions.items()
            if isinstance(v, dict) and "error" in v
        }
        succeeded = [k for k in predictions if k not in errors]
        print(f"Prediction successful for: {succeeded}")  # Log success
        if errors:
            print(f"Errors occurred during prediction: {errors}")  # Log errors

        return predictions  # Return the entire dictionary
    except FileNotFoundError as e:
        print(f"Error: Model file not found - {e}")
        return {"error": f"Required file not found: {e}. Ensure model, scaler, and metadata files are uploaded correctly."}
    except Exception as e:
        # Top-level UI boundary: report the failure to the user instead of
        # crashing the request, and keep the full traceback in the logs.
        print(f"An unexpected error occurred: {e}")
        traceback.print_exc()  # Print detailed traceback to logs
        return {"error": f"An unexpected error occurred: {str(e)}"}


def _flagging_disabled_kwargs():
    """Return the keyword argument that disables flagging in ``gr.Interface``.

    Newer Gradio releases deprecate ``allow_flagging`` in favour of
    ``flagging_mode``. Use whichever one the installed ``gr.Interface``
    declares so the app runs without warnings or TypeErrors across versions.
    """
    params = inspect.signature(gr.Interface.__init__).parameters
    if "flagging_mode" in params:
        return {"flagging_mode": "never"}
    return {"allow_flagging": "never"}


# --- Build Gradio Interface ---
# Use Markdown for a richer description
description = """
## BiLSTM Stock Price Predictor (-15y / +15y)

Enter one or more stock ticker symbols (e.g., `AAPL`, `MSFT`, `GOOGL`), separated by commas.
The model will fetch historical data, predict future prices for the next 15 years using a BiLSTM model combined with Geometric Brownian Motion (GBM), and return the historical data for the past 15 years (or less if unavailable) combined with the predictions.

**Note:**
* Predictions are based on historical 'Close' prices and involve inherent uncertainty. **This is not financial advice.**
* Fetching data and running predictions might take a moment, especially for multiple tickers.
* Ensure ticker symbols are valid on Yahoo Finance.
"""

iface = gr.Interface(
    fn=run_prediction,
    inputs=gr.Textbox(
        lines=1,
        placeholder="Enter Ticker Symbols (e.g., AAPL, MSFT, GOOGL)",
        label="Ticker Symbols (comma-separated)",
    ),
    outputs=gr.JSON(label="Prediction Results (Historical + Future Prices)"),
    title="Stock Price Prediction",
    description=description,
    examples=[["AAPL"], ["MSFT,GOOGL,NVDA"]],
    cache_examples=False,
    **_flagging_disabled_kwargs(),  # Disable flagging (keyword depends on Gradio version)
)

# --- Launch the App ---
if __name__ == "__main__":
    iface.launch()  # Share=True is not needed when deploying on Spaces