Div0013 commited on
Commit
9f578b5
·
1 Parent(s): 555067f

Add application file

Browse files
Files changed (1) hide show
  1. app.py +92 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import json
4
+ import traceback
5
+
6
+ # Import necessary functions and the model class from your original script
7
+ # Make sure predict_stock_prices.py is in the same directory
8
+ from predict_stock_prices import (
9
+ BiLSTMModel, # Need to import the class for joblib/torch to load model correctly
10
+ predict_stock_prices,
11
+ batch_predict_to_json # Assuming this function takes the list and paths
12
+ )
13
+
14
+ # --- Model Configuration ---
15
+ # These paths should correspond to the files uploaded to your Hugging Face Space
16
+ MODEL_PATH = "bilstm_stock_model.pth"
17
+ SCALER_PATH = "scaler_diff.pkl"
18
+ METADATA_PATH = "model_metadata.pkl"
19
+
20
+ # --- Gradio Interface Function ---
21
+ def run_prediction(ticker_string):
22
+ """
23
+ Takes a comma-separated string of tickers, runs prediction,
24
+ and returns the result as a JSON object or error string.
25
+ """
26
+ if not ticker_string:
27
+ return {"error": "Please enter at least one ticker symbol."}
28
+
29
+ # Split string into a list of tickers, removing whitespace
30
+ tickers = [ticker.strip().upper() for ticker in ticker_string.split(',') if ticker.strip()]
31
+
32
+ if not tickers:
33
+ return {"error": "No valid ticker symbols entered."}
34
+
35
+ print(f"Received request for tickers: {tickers}") # Log received tickers
36
+
37
+ try:
38
+ # Call your existing batch prediction function
39
+ # It already returns a dictionary suitable for JSON output
40
+ predictions = batch_predict_to_json(
41
+ ticker_symbols=tickers,
42
+ model_path=MODEL_PATH,
43
+ scaler_path=SCALER_PATH,
44
+ metadata_path=METADATA_PATH
45
+ )
46
+ print(f"Prediction successful for: {list(predictions.keys())}") # Log success
47
+ # Check for errors within the prediction results
48
+ errors = {k:v for k,v in predictions.items() if isinstance(v, dict) and 'error' in v}
49
+ if errors:
50
+ print(f"Errors occurred during prediction: {errors}") # Log errors
51
+ return predictions # Return the entire dictionary
52
+
53
+ except FileNotFoundError as e:
54
+ print(f"Error: Model file not found - {e}")
55
+ return {"error": f"Required file not found: {e}. Ensure model, scaler, and metadata files are uploaded correctly."}
56
+ except Exception as e:
57
+ print(f"An unexpected error occurred: {e}")
58
+ traceback.print_exc() # Print detailed traceback to logs
59
+ return {"error": f"An unexpected error occurred: {str(e)}"}
60
+
61
+ # --- Build Gradio Interface ---
62
+ # Use Markdown for a richer description
63
+ description = """
64
+ ## BiLSTM Stock Price Predictor (-15y / +15y)
65
+
66
+ Enter one or more stock ticker symbols (e.g., `AAPL`, `MSFT`, `GOOGL`), separated by commas.
67
+ The model will fetch historical data, predict future prices for the next 15 years using a BiLSTM model combined with Geometric Brownian Motion (GBM),
68
+ and return the historical data for the past 15 years (or less if unavailable) combined with the predictions.
69
+
70
+ **Note:**
71
+ * Predictions are based on historical 'Close' prices and involve inherent uncertainty. **This is not financial advice.**
72
+ * Fetching data and running predictions might take a moment, especially for multiple tickers.
73
+ * Ensure ticker symbols are valid on Yahoo Finance.
74
+ """
75
+
76
+ iface = gr.Interface(
77
+ fn=run_prediction,
78
+ inputs=gr.Textbox(
79
+ lines=1,
80
+ placeholder="Enter Ticker Symbols (e.g., AAPL, MSFT, GOOGL)",
81
+ label="Ticker Symbols (comma-separated)"
82
+ ),
83
+ outputs=gr.JSON(label="Prediction Results (Historical + Future Prices)"),
84
+ title="Stock Price Prediction",
85
+ description=description,
86
+ examples=[["AAPL"], ["MSFT,GOOGL,NVDA"]],
87
+ allow_flagging='never' # Optional: Disable flagging
88
+ )
89
+
90
+ # --- Launch the App ---
91
+ if __name__ == "__main__":
92
+ iface.launch() # Share=True is not needed when deploying on Spaces