Spaces:
Sleeping
Sleeping
| # app.py | |
| import gradio as gr | |
| import json | |
| import traceback | |
| # Import necessary functions and the model class from your original script | |
| # Make sure predict_stock_prices.py is in the same directory | |
| from predict_stock_prices import ( | |
| BiLSTMModel, # Need to import the class for joblib/torch to load model correctly | |
| predict_stock_prices, | |
| batch_predict_to_json # Assuming this function takes the list and paths | |
| ) | |
# --- Model Configuration ---
# Locations of the model, scaler, and metadata artifacts uploaded to the
# Hugging Face Space. They are relative paths, so they are resolved against
# the current working directory by whatever code opens them.
MODEL_PATH, SCALER_PATH, METADATA_PATH = (
    "bilstm_stock_model.pth",  # model file
    "scaler_diff.pkl",         # scaler file
    "model_metadata.pkl",      # metadata file
)
# --- Gradio Interface Function ---
def run_prediction(ticker_string):
    """Run the batch stock-price prediction for user-supplied ticker symbols.

    Args:
        ticker_string: Raw text from the Gradio textbox. Symbols may be
            separated by commas and/or whitespace. They are upper-cased and
            de-duplicated (first occurrence wins, order preserved) before
            prediction.

    Returns:
        The dictionary produced by ``batch_predict_to_json`` on success, or a
        dict of the form ``{"error": "<message>"}`` when the input is empty,
        a required file is missing, prediction fails, or the prediction
        helper returns something that is not a dictionary. Errors are
        returned rather than raised so the JSON output always has something
        to display.
    """
    if not ticker_string:
        return {"error": "Please enter at least one ticker symbol."}

    # Treat commas and any whitespace as separators, so "AAPL MSFT" and
    # "AAPL,\nMSFT" work as well as "AAPL, MSFT". dict.fromkeys drops
    # duplicates such as "aapl, AAPL" while keeping the user's order, so the
    # same ticker is not fetched and predicted twice.
    tickers = list(
        dict.fromkeys(
            token.upper() for token in ticker_string.replace(",", " ").split()
        )
    )
    if not tickers:
        return {"error": "No valid ticker symbols entered."}

    print(f"Received request for tickers: {tickers}")  # Log received tickers
    try:
        predictions = batch_predict_to_json(
            ticker_symbols=tickers,
            model_path=MODEL_PATH,
            scaler_path=SCALER_PATH,
            metadata_path=METADATA_PATH,
        )
        # NOTE(review): the helper is expected to return a dict, but its name
        # suggests it might return serialized JSON instead; accept either.
        if isinstance(predictions, (str, bytes, bytearray)):
            predictions = json.loads(predictions)
    except FileNotFoundError as e:
        print(f"Error: Model file not found - {e}")
        return {
            "error": f"Required file not found: {e}. "
            "Ensure model, scaler, and metadata files are uploaded correctly."
        }
    except Exception as e:
        # Top-level UI boundary: print the full traceback to the logs and
        # report a short message to the user instead of failing the request.
        print(f"An unexpected error occurred: {e}")
        traceback.print_exc()
        return {"error": f"An unexpected error occurred: {str(e)}"}

    # The code below treats the result as a mapping; reject anything else
    # explicitly rather than failing later with an opaque AttributeError.
    if not isinstance(predictions, dict):
        type_name = type(predictions).__name__
        print(f"Unexpected prediction result type: {type_name}")
        return {"error": f"Unexpected prediction result type: {type_name}"}

    print(f"Prediction successful for: {list(predictions)}")  # Log success
    # Individual entries may themselves be {"error": ...} dicts; log them but
    # still return the whole dictionary so successful tickers are shown.
    errors = {
        k: v for k, v in predictions.items() if isinstance(v, dict) and "error" in v
    }
    if errors:
        print(f"Errors occurred during prediction: {errors}")  # Log errors
    return predictions
# --- Build Gradio Interface ---
# Markdown text rendered above the input form.
description = """
## BiLSTM Stock Price Predictor (-15y / +15y)
Enter one or more stock ticker symbols (e.g., `AAPL`, `MSFT`, `GOOGL`), separated by commas.
The model will fetch historical data, predict future prices for the next 15 years using a BiLSTM model combined with Geometric Brownian Motion (GBM),
and return the historical data for the past 15 years (or less if unavailable) combined with the predictions.
**Note:**
* Predictions are based on historical 'Close' prices and involve inherent uncertainty. **This is not financial advice.**
* Fetching data and running predictions might take a moment, especially for multiple tickers.
* Ensure ticker symbols are valid on Yahoo Finance.
"""

# Single-line text input; run_prediction does the splitting and normalisation.
ticker_input = gr.Textbox(
    label="Ticker Symbols (comma-separated)",
    placeholder="Enter Ticker Symbols (e.g., AAPL, MSFT, GOOGL)",
    lines=1,
)

# run_prediction returns a dictionary (results or {"error": ...}), shown as JSON.
results_output = gr.JSON(label="Prediction Results (Historical + Future Prices)")

iface = gr.Interface(
    fn=run_prediction,
    inputs=ticker_input,
    outputs=results_output,
    title="Stock Price Prediction",
    description=description,
    # Example inputs are listed for one-click use but not pre-computed.
    examples=[["AAPL"], ["MSFT,GOOGL,NVDA"]],
    cache_examples=False,
    # NOTE(review): newer Gradio releases replace ``allow_flagging`` with
    # ``flagging_mode``; confirm against the Gradio version this Space pins.
    allow_flagging="never",
)

# --- Launch the App ---
if __name__ == "__main__":
    # share=True is not needed when deploying on Spaces.
    iface.launch()