import gradio as gr
import yfinance as yf
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler

# Number of past closing prices fed to the model for each prediction.
WINDOW = 60
# Forecast offset used when building targets (see create_sequence).
N_FUTURE = 1

# Load the pre-trained Keras model once at import time so every request reuses it.
model = tf.keras.models.load_model("./best.keras")


def create_scaler(df):
    """Fit a MinMaxScaler on the 'Close' column of *df* and scale it to [0, 1].

    Args:
        df: DataFrame with a 'Close' column.

    Returns:
        A ``(scaler, scaled_df)`` tuple: the fitted scaler (needed to map model
        outputs back to prices) and the scaled closes as an ``(n, 1)`` array.
    """
    # NOTE(review): the scaler is fit on the requested period only, not on the
    # data the model was trained on; confirm this matches the training setup.
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaled_df = scaler.fit_transform(df["Close"].values.reshape(-1, 1))
    return scaler, scaled_df


def create_sequence(scaled_df, window=WINDOW, n_future=N_FUTURE):
    """Build sliding-window (input, target) pairs from a scaled series.

    Sample ``i`` uses ``scaled_df[i:i + window]`` as input and
    ``scaled_df[i + window + n_future]`` as target. Every valid target is
    used, including the last row of the series.

    Args:
        scaled_df: Scaled series, e.g. the ``(m, 1)`` array from create_scaler.
        window: Number of consecutive rows per input sequence.
        n_future: Target offset past the end of the window (see note below).

    Returns:
        ``(X, y)`` numpy arrays with ``max(0, m - window - n_future)`` samples.
        For an ``(m, 1)`` input, ``X`` is ``(n, window, 1)`` and ``y`` is ``(n, 1)``.
    """
    # NOTE(review): the row right after the window is index i + window, so this
    # target lies n_future steps beyond it. Kept unchanged because the model was
    # presumably trained with the same convention — confirm against training code.
    n_samples = max(0, len(scaled_df) - window - n_future)
    X = np.array([scaled_df[i:i + window] for i in range(n_samples)])
    y = np.array([scaled_df[i + window + n_future] for i in range(n_samples)])
    return X, y


def fetch_and_predict(ticker, period):
    """Download prices for *ticker*, run the model, and plot actual vs predicted.

    Args:
        ticker: Stock symbol, e.g. "AAPL".
        period: yfinance period string, e.g. "1y".

    Returns:
        A matplotlib Figure comparing actual and predicted closing prices
        (in price units) over the dates of each prediction target.

    Raises:
        gr.Error: on blank input, download failure, missing/insufficient data,
            or prediction failure. Gradio displays the message to the user;
            returning a string would not render in a gr.Plot output.
    """
    ticker = (ticker or "").strip()
    period = (period or "").strip()
    if not ticker:
        raise gr.Error("Please enter a stock ticker.")
    if not period:
        raise gr.Error("Please enter a period (e.g., '1y').")

    # Fetch historical stock data using yfinance.
    try:
        df = yf.download(ticker, period=period)
    except Exception as e:
        raise gr.Error(f"Error downloading data: {e}") from e
    if df is None or df.empty:
        raise gr.Error(f"No price data found for '{ticker}' over period '{period}'.")
    # Some yfinance versions return (field, ticker) MultiIndex columns.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)
    if "Close" not in df.columns:
        raise gr.Error(f"No closing prices found for '{ticker}'.")
    # NaN closes would propagate through the model into the plot.
    df = df.dropna(subset=["Close"])

    # At least one full window plus its target is needed for one sample.
    if len(df) < WINDOW + N_FUTURE + 1:
        raise gr.Error("Not enough data for predictions. Please select a longer period.")

    # Prepare data.
    scaler, scaled = create_scaler(df)
    X, y = create_sequence(scaled)

    # Predict stock prices.
    try:
        yhat = model.predict(X)
    except Exception as e:
        raise gr.Error(f"Error during prediction: {e}") from e

    # Map scaled values back to price units so the y-axis label is truthful.
    actual = scaler.inverse_transform(y.reshape(-1, 1)).ravel()
    # assumes the model emits exactly one value per sample — TODO confirm
    predicted = scaler.inverse_transform(np.asarray(yhat).reshape(-1, 1)).ravel()
    # Target i sits at row i + WINDOW + N_FUTURE, so these dates align with y.
    dates = df.index[WINDOW + N_FUTURE:]

    # Plot via an explicit Figure/Axes rather than pyplot's "current figure".
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(dates, actual, label="Actual Prices")
    ax.plot(dates, predicted, label="Predicted Prices")
    ax.set_title(f"Stock Price Prediction (LSTM) - [{ticker}]")
    ax.set_xlabel("Time")
    ax.set_ylabel("Stock Price")
    ax.legend()
    ax.tick_params(axis="x", labelrotation=45)
    # Remove the figure from pyplot's registry; otherwise every request leaves
    # an open figure behind in this long-running server process. The Figure
    # object itself stays usable for rendering.
    plt.close(fig)
    return fig


interface = gr.Interface(
    fn=fetch_and_predict,
    inputs=[
        gr.Textbox(label="Stock Ticker", placeholder="Enter stock ticker (e.g., DAL, AAPL)"),
        gr.Textbox(label="Period", placeholder="Enter period (e.g., '1y')"),
    ],
    outputs=gr.Plot(),
    live=False,
    allow_flagging="never",
    title="Stock Price Prediction",
    description="Enter the stock ticker and period, then click the button to fetch data and predict prices.",
    theme="huggingface",
)

if __name__ == "__main__":
    interface.launch()