File size: 2,383 Bytes
0c0d46a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
import yfinance as yf
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
# Restore the pre-trained Keras model from disk once, when the module loads
MODEL_PATH = "./best.keras"
model = tf.keras.models.load_model(MODEL_PATH)

# scale the data
def create_scaler(df):
    """Fit a 0-1 min-max scaler on the ``Close`` column of *df*.

    Returns:
        A ``(scaler, scaled_df)`` tuple: the fitted ``MinMaxScaler`` and the
        scaled closing prices as a single-column 2-D array.
    """
    # The scaler expects a 2-D input, so turn the column into shape (n, 1).
    close_column = df['Close'].values.reshape(-1, 1)
    price_scaler = MinMaxScaler(feature_range=(0, 1))
    price_scaler.fit(close_column)
    return price_scaler, price_scaler.transform(close_column)
# create input output sequence
def create_sequence(scaled_df, window=60, n_future=1):
    """Cut a series into sliding input windows and their target rows.

    For every start row ``i`` the input is ``scaled_df[i:i + window]`` and the
    target is ``scaled_df[i + window + n_future]``.  Every start row whose
    target index is still inside the data is used, including the one whose
    target is the final row.

    NOTE(review): with the defaults the target is two rows past the end of
    the window (the row right after the window is skipped).  This offset is
    kept as-is because the saved model was presumably trained with the same
    alignment -- confirm against the training code before changing it.

    Args:
        scaled_df: Array-like whose first axis is time.
        window: Number of consecutive rows in each input sample.
        n_future: Extra offset added to the target index (see above).

    Returns:
        ``(X, y)`` arrays with shapes ``(n_samples, window, ...)`` and
        ``(n_samples, ...)``.  When the input is too short for a single
        sample, both arrays are empty but keep those trailing dimensions.

    Raises:
        ValueError: If ``window`` is less than 1 or ``n_future`` is negative.
    """
    if window < 1:
        raise ValueError("window must be at least 1")
    if n_future < 0:
        raise ValueError("n_future must not be negative")

    data = np.asarray(scaled_df)
    offset = window + n_future
    # The last usable start row is the one whose target is the final row.
    n_samples = len(data) - offset
    if n_samples <= 0:
        trailing = data.shape[1:]
        return (
            np.empty((0, window) + trailing, dtype=data.dtype),
            np.empty((0,) + trailing, dtype=data.dtype),
        )

    # Row k of the index grid is [k, k + 1, ..., k + window - 1].
    window_rows = np.arange(n_samples)[:, None] + np.arange(window)
    X = data[window_rows]
    y = data[offset:].copy()
    return X, y

def fetch_and_predict(ticker, period):
    """Download price history for *ticker* and plot the model's predictions.

    The closing prices are min-max scaled by ``create_scaler``, cut into
    input windows by ``create_sequence``, run through the module-level Keras
    ``model``, and mapped back through the scaler so the plotted values are
    in the same units as the downloaded prices.

    Args:
        ticker: Stock symbol passed to ``yf.download``.
        period: History length passed to ``yf.download`` as ``period``.

    Returns:
        A matplotlib figure with the actual and the predicted closing prices.

    Raises:
        gr.Error: If an input is blank, the download fails, there is too
            little data, or the prediction fails.  Errors are raised rather
            than returned because the output component is a plot and cannot
            display a message string.
    """
    ticker = (ticker or "").strip()
    period = (period or "").strip()
    if not ticker or not period:
        raise gr.Error("Please enter both a stock ticker and a period.")

    # Fetch historical stock data using yfinance
    try:
        prices = yf.download(ticker, period=period)
    except Exception as e:
        raise gr.Error(f"Error downloading data: {e}") from e

    not_enough = "Not enough data for predictions. Please select a longer period."
    if prices is None or prices.empty:
        raise gr.Error(not_enough)

    # Flatten MultiIndex columns to their first level so that 'Close' can be
    # looked up by name.
    if isinstance(prices.columns, pd.MultiIndex):
        prices.columns = prices.columns.get_level_values(0)
    if 'Close' not in prices.columns:
        raise gr.Error(not_enough)

    # Rows without a closing price would turn into NaN model inputs.
    prices = prices.dropna(subset=['Close'])
    if prices.shape[0] < 60:
        raise gr.Error(not_enough)

    # Prepare data
    scaler, scaled_close = create_scaler(prices)
    X, y = create_sequence(scaled_close)
    # The row-count check above does not guarantee that at least one full
    # window-plus-target sequence exists, so check the result as well.
    if len(X) == 0:
        raise gr.Error(not_enough)

    # Predicting stock prices
    try:
        yhat = model.predict(X)
    except Exception as e:
        raise gr.Error(f"Error during prediction: {e}") from e

    # Undo the 0-1 scaling so the y-axis really shows prices.
    # Assumes the model emits one value per input window -- TODO confirm.
    actual = scaler.inverse_transform(np.asarray(y).reshape(-1, 1)).ravel()
    predicted = scaler.inverse_transform(np.asarray(yhat).reshape(-1, 1)).ravel()

    # Plot on an explicit figure/axes pair instead of pyplot's implicit
    # "current figure", which is shared global state.
    fig, ax = plt.subplots(figsize=(14, 7))
    ax.plot(actual, label='Actual Prices')
    ax.plot(predicted, label='Predicted Prices')
    ax.set_title(f'Stock Price Prediction (LSTM) - [{str(ticker)}]')
    ax.set_xlabel('Time')
    ax.set_ylabel('Stock Price')
    ax.legend()
    ax.tick_params(axis='x', labelrotation=45)
    # Unregister the figure from pyplot so figures do not pile up across
    # requests; the returned Figure object itself remains usable.
    plt.close(fig)
    return fig

# Input widgets, listed in the positional order fetch_and_predict takes them
ticker_box = gr.Textbox(label="Stock Ticker", placeholder="Enter stock ticker (e.g., DAL, AAPL)")
period_box = gr.Textbox(label="Period", placeholder="Enter period (e.g., '1y')")

# Wire the prediction function to the two text inputs and a single plot output
interface = gr.Interface(
    fetch_and_predict,
    [ticker_box, period_box],
    gr.Plot(),
    title="Stock Price Prediction",
    description="Enter the stock ticker and period, then click the button to fetch data and predict prices.",
    theme="huggingface",
    live=False,
    allow_flagging="never",
)

# Start the web app
interface.launch()