File size: 3,808 Bytes
a24984d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0058845
a24984d
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# app.py
import gradio as gr
import json
import traceback

# Import necessary functions and the model class from your original script
# Make sure predict_stock_prices.py is in the same directory
from predict_stock_prices import (
    BiLSTMModel, # Need to import the class for joblib/torch to load model correctly
    predict_stock_prices,
    batch_predict_to_json # Assuming this function takes the list and paths
)

# --- Model Configuration ---
# Artifact file names; they must match the files uploaded to the Hugging Face
# Space. Being bare names, they are resolved against the process's current
# working directory when opened.
MODEL_PATH, SCALER_PATH, METADATA_PATH = (
    "bilstm_stock_model.pth",  # BiLSTM model file
    "scaler_diff.pkl",         # scaler file
    "model_metadata.pkl",      # model metadata file
)

# --- Gradio Interface Function ---
def run_prediction(ticker_string):
    """
    Run the batch prediction for a comma-separated list of ticker symbols.

    The input is split on commas. Each entry is stripped and upper-cased, and
    blank entries are dropped. Duplicates (e.g. "aapl, AAPL") are removed while
    keeping first-seen order, so each ticker is predicted only once.

    Args:
        ticker_string: Raw text from the Gradio textbox, e.g. "AAPL, msft".
            ``None`` or an empty string is treated as "no input".

    Returns:
        dict: The mapping returned by ``batch_predict_to_json``. If it arrives
        as a JSON string, it is decoded first. Per-ticker entries that carry
        an ``'error'`` key are returned unchanged so the UI can show them.
        For empty input, a missing artifact file, or any other exception
        raised while predicting, a single ``{"error": "<message>"}`` dict is
        returned instead of raising.
    """
    if not ticker_string:
        return {"error": "Please enter at least one ticker symbol."}

    # Normalise each comma-separated entry and drop blanks. dict.fromkeys
    # de-duplicates (e.g. "aapl, AAPL") while preserving the typed order, so
    # the same ticker is not predicted twice.
    tickers = list(dict.fromkeys(
        ticker.strip().upper() for ticker in ticker_string.split(',') if ticker.strip()
    ))

    if not tickers:
        return {"error": "No valid ticker symbols entered."}

    print(f"Received request for tickers: {tickers}")  # Log received tickers

    try:
        predictions = batch_predict_to_json(
            ticker_symbols=tickers,
            model_path=MODEL_PATH,
            scaler_path=SCALER_PATH,
            metadata_path=METADATA_PATH,
        )
        # NOTE(review): the helper is expected to return a dict, but its name
        # suggests it might return serialised JSON; accept both so the key
        # inspection below works either way.
        if isinstance(predictions, str):
            predictions = json.loads(predictions)

        # Entries that are dicts with an 'error' key are per-ticker failures.
        # Log them separately so the success line only lists tickers that
        # actually produced a prediction.
        errors = {k: v for k, v in predictions.items() if isinstance(v, dict) and 'error' in v}
        succeeded = [k for k in predictions if k not in errors]
        print(f"Prediction successful for: {succeeded}")
        if errors:
            print(f"Errors occurred during prediction: {errors}")
        return predictions  # Return the entire dictionary

    except FileNotFoundError as e:
        # Any of the model, scaler, or metadata files may be the missing one.
        print(f"Error: Required file not found - {e}")
        return {"error": f"Required file not found: {e}. Ensure model, scaler, and metadata files are uploaded correctly."}
    except Exception as e:
        # Top-level UI boundary: log the full traceback, but hand the user an
        # error payload instead of crashing the request.
        print(f"An unexpected error occurred: {e}")
        traceback.print_exc()
        return {"error": f"An unexpected error occurred: {e}"}

# --- Build Gradio Interface ---
# Markdown text shown at the top of the app; it is passed to gr.Interface as
# `description=` below, where headings, bullet points and backticks are
# rendered as formatted Markdown.
description = """
## BiLSTM Stock Price Predictor (-15y / +15y)

Enter one or more stock ticker symbols (e.g., `AAPL`, `MSFT`, `GOOGL`), separated by commas.
The model will fetch historical data, predict future prices for the next 15 years using a BiLSTM model combined with Geometric Brownian Motion (GBM),
and return the historical data for the past 15 years (or less if unavailable) combined with the predictions.

**Note:**
*   Predictions are based on historical 'Close' prices and involve inherent uncertainty. **This is not financial advice.**
*   Fetching data and running predictions might take a moment, especially for multiple tickers.
*   Ensure ticker symbols are valid on Yahoo Finance.
"""

# One free-text input (the ticker list) mapped through run_prediction to a
# JSON viewer. Components are created in the same order as before: textbox
# first, then the JSON output.
iface = gr.Interface(
    run_prediction,
    gr.Textbox(
        label="Ticker Symbols (comma-separated)",
        placeholder="Enter Ticker Symbols (e.g., AAPL, MSFT, GOOGL)",
        lines=1,
    ),
    gr.JSON(label="Prediction Results (Historical + Future Prices)"),
    title="Stock Price Prediction",
    description=description,
    # Clickable sample inputs; with cache_examples=False their outputs are
    # not precomputed when the app starts.
    examples=[["AAPL"], ["MSFT,GOOGL,NVDA"]],
    cache_examples=False,
    # Disables the flagging button.
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode`; verify against the Gradio version pinned for the Space.
    allow_flagging='never',
)

# --- Launch the App ---
# On Hugging Face Spaces the platform serves the app, so no share=True tunnel
# link is requested here.
if __name__ == "__main__":
    iface.launch()