Spaces:
Sleeping
Sleeping
File size: 2,383 Bytes
0c0d46a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
import gradio as gr
import yfinance as yf
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
# Path to the serialized Keras model, relative to the working directory.
MODEL_PATH = "./best.keras"
# Load the pre-trained model once at import time; fetch_and_predict reuses
# this single instance for every request.
model = tf.keras.models.load_model(MODEL_PATH)
# Min-max scale the closing prices.
def create_scaler(df):
    """Fit a MinMaxScaler on ``df['Close']`` and return it with the scaled data.

    The scaler is fit on every row passed in.

    Args:
        df: DataFrame that has a ``'Close'`` column.

    Returns:
        ``(scaler, scaled)``:
        - ``scaler`` is the fitted MinMaxScaler with feature range (0, 1).
        - ``scaled`` is the transformed close column as a single-column 2-D
          array (the input is reshaped with ``reshape(-1, 1)``).
    """
    closes = df['Close'].values.reshape(-1, 1)
    fitted = MinMaxScaler(feature_range=(0, 1))
    scaled = fitted.fit_transform(closes)
    return fitted, scaled
# create input output sequence
def create_sequence(scaled_df, window=60, n_future=1):
    """Slice a scaled series into sliding input windows and their targets.

    Args:
        scaled_df: array indexed by row, e.g. the ``(n, 1)`` array produced
            by ``create_scaler``.
        window: number of consecutive rows in each input sample.
        n_future: offset of the target, counted from the row right after
            the window.

    Returns:
        ``(X, y)`` where ``X[k]`` is ``scaled_df[k:k + window]`` and ``y[k]``
        is ``scaled_df[k + window + n_future]``. If the series is too short
        for a single sample, both arrays are empty but keep their trailing
        dimensions: ``(0, window, ...)`` and ``(0, ...)``.
    """
    data = np.asarray(scaled_df)
    # A start index k is valid while its target index
    # k + window + n_future <= len(data) - 1, which gives this many samples.
    n_samples = max(len(data) - window - n_future, 0)
    if n_samples == 0:
        # Keep the shapes meaningful so callers can test len(X) == 0.
        return (np.empty((0, window) + data.shape[1:], dtype=data.dtype),
                np.empty((0,) + data.shape[1:], dtype=data.dtype))
    # NOTE(review): with n_future=1 the target lies two rows past the last
    # input row (row k + window is skipped). This alignment is kept on the
    # assumption that best.keras was trained with it; confirm before changing.
    X = np.array([data[k:k + window] for k in range(n_samples)])
    y = np.array([data[k + window + n_future] for k in range(n_samples)])
    return X, y
def _download_history(ticker, period):
    """Download price history for *ticker* over *period* via yfinance.

    Raises:
        gr.Error: if yfinance raises; the message is shown in the UI.
    """
    try:
        df = yf.download(ticker, period=period)
    except Exception as e:  # network/parsing failures surface as varied types
        raise gr.Error(f"Error downloading data: {e}") from e
    # yf.download may return (field, ticker) MultiIndex columns even for one
    # ticker; keep only the field level so df['Close'] is a plain column.
    if isinstance(df.columns, pd.MultiIndex):
        df.columns = df.columns.get_level_values(0)
    return df


def _plot_predictions(actual, predicted, ticker):
    """Return a matplotlib Figure comparing actual and predicted prices."""
    # Use the object-oriented API on a dedicated figure rather than pyplot's
    # implicit "current figure", which concurrent requests would share.
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(actual, label='Actual Prices')
    ax.plot(predicted, label='Predicted Prices')
    ax.set_title(f'Stock Price Prediction (LSTM) - [{str(ticker)}]')
    ax.set_xlabel('Time')
    ax.set_ylabel('Stock Price')
    ax.legend()
    ax.tick_params(axis='x', labelrotation=45)
    # gr.Plot only needs the Figure object. Unregister it from pyplot so one
    # open figure does not accumulate per request in the long-running server.
    plt.close(fig)
    return fig


def fetch_and_predict(ticker, period):
    """Download prices for *ticker*, run the LSTM model, and plot the result.

    Args:
        ticker: ticker symbol typed by the user (e.g. "AAPL").
        period: yfinance period string typed by the user (e.g. "1y").

    Returns:
        A matplotlib Figure of actual vs. predicted closing prices, mapped
        back from the scaler's [0, 1] range to price units.

    Raises:
        gr.Error: if an input is empty, the download fails or returns no
            data, there are too few rows to build one input window, or the
            model raises. Gradio displays the message to the user (the
            output is a Plot, which cannot render a returned string).
    """
    ticker = (ticker or "").strip()
    period = (period or "").strip()
    if not ticker:
        raise gr.Error("Please enter a stock ticker.")
    if not period:
        raise gr.Error("Please enter a period (e.g. '1y').")

    df = _download_history(ticker, period)
    if df.empty or 'Close' not in df.columns:
        raise gr.Error(f"No price data found for '{ticker}' over period '{period}'.")

    scaler, scaled = create_scaler(df)
    X, y = create_sequence(scaled)
    # create_sequence needs a full window plus the target offset; checking its
    # output avoids duplicating its length arithmetic here.
    if len(X) == 0:
        raise gr.Error("Not enough data for predictions. Please select a longer period.")

    try:
        yhat = model.predict(X)
    except Exception as e:
        raise gr.Error(f"Error during prediction: {e}") from e

    # Undo the min-max scaling so the y-axis really is "Stock Price".
    # Assumes the model emits one value per window; confirm if that changes.
    actual = scaler.inverse_transform(np.asarray(y).reshape(-1, 1))
    predicted = scaler.inverse_transform(np.asarray(yhat).reshape(-1, 1))
    return _plot_predictions(actual, predicted, ticker)
# Free-text inputs, created in the same order the Interface receives them.
ticker_box = gr.Textbox(label="Stock Ticker", placeholder="Enter stock ticker (e.g., DAL, AAPL)")
period_box = gr.Textbox(label="Period", placeholder="Enter period (e.g., '1y')")

# Web UI: the two inputs feed fetch_and_predict, whose matplotlib figure is
# shown by the Plot output. Nothing runs until the user submits (live=False).
interface = gr.Interface(
    fn=fetch_and_predict,
    inputs=[ticker_box, period_box],
    outputs=gr.Plot(),
    title="Stock Price Prediction",
    description="Enter the stock ticker and period, then click the button to fetch data and predict prices.",
    theme="huggingface",
    live=False,
    allow_flagging="never",
)
interface.launch()
|