mrshibly's picture
Update app.py
c9c1d04 verified
raw
history blame
2.42 kB
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from sklearn.preprocessing import MinMaxScaler
import pandas as pd
# Read the historical quotes and keep only the closing price, indexed by date
df = pd.read_csv('HistoricalQuotes.csv')
df['Date'] = pd.to_datetime(df['Date'], format='%m/%d/%Y')
df = (
    df.set_index('Date')[[' Close/Last']]
    .rename(columns={' Close/Last': 'Close'})
    .sort_index()
)
# Strip any "$" characters from the price strings before converting to float
df['Close'] = df['Close'].replace({r'\$': ''}, regex=True).astype(float)

# Fit a [0, 1] min-max scaler on the full closing-price history
scaler = MinMaxScaler(feature_range=(0, 1)).fit(
    df['Close'].to_numpy().reshape(-1, 1)
)
# Define LSTM model (matching trained model: hidden_size=50)
class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear layer applied to the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of each LSTM layer's hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values produced by the final linear layer.
        dropout: Dropout probability passed to ``nn.LSTM``.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=2, output_size=1, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the prediction for the last time step of each sequence.

        Args:
            x: Batch-first input tensor; ``x.size(0)`` is used as the batch
                size for the initial states.

        Returns:
            The output of the linear layer applied to the LSTM output at the
            final time step (``out[:, -1, :]``).
        """
        # Read the sizes from the LSTM itself: the constructor arguments are
        # not in scope here, and referring to them by bare name raised
        # NameError on every forward pass.
        state_shape = (self.lstm.num_layers, x.size(0), self.lstm.hidden_size)
        # new_zeros keeps the initial states on the same device/dtype as x.
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        return self.fc(out[:, -1, :])
# Rebuild the network and restore its saved weights onto the CPU
model = LSTMModel()
_checkpoint = torch.load('lstm_model.pth', map_location='cpu')
model.load_state_dict(_checkpoint)
model.eval()  # switch to evaluation mode for inference
def forecast(past_prices, steps=30):
    """Forecast future prices from a comma-separated price history.

    The last 60 prices are scaled with the module-level ``scaler`` and fed to
    the module-level ``model``. Each prediction is appended to the window and
    the oldest value dropped, so later steps are conditioned on earlier
    predictions.

    Args:
        past_prices: Comma-separated prices in chronological order (the last
            60 entries form the model input). ``$`` characters and
            surrounding whitespace are removed; blank entries, e.g. from a
            trailing comma, are ignored.
        steps: Number of future values to predict; must be at least 1.

    Returns:
        The forecast rendered as a one-column table string, or a string
        starting with ``"Error:"`` when the input is invalid or prediction
        fails. The function never raises.
    """
    try:
        tokens = (tok.strip().replace('$', '') for tok in past_prices.split(','))
        # Skip blanks so a trailing comma does not abort the whole request.
        prices = [float(tok) for tok in tokens if tok]
        if len(prices) < 60:
            return "Error: At least 60 prices required."
        steps = int(steps)
        if steps < 1:
            return "Error: steps must be a positive integer."
        # Only the most recent 60 values are used, so only those are scaled.
        window = scaler.transform(np.array(prices[-60:]).reshape(-1, 1))
        current_seq = torch.from_numpy(window.reshape(1, 60, 1)).float()
        predictions = []
        with torch.no_grad():
            for _ in range(steps):
                pred_scaled = model(current_seq).item()
                predictions.append(pred_scaled)
                # Slide the window: drop the oldest step, append the prediction.
                next_step = torch.tensor([[[pred_scaled]]]).float()
                current_seq = torch.cat((current_seq[:, 1:, :], next_step), dim=1)
        predictions = scaler.inverse_transform(np.array(predictions).reshape(-1, 1)).flatten()
        return pd.DataFrame({'Forecast': predictions}).to_string()
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"Error: {str(e)}"
# Build the Gradio UI: one textbox in, one textbox out, wired to forecast()
_prices_box = gr.Textbox(label="Past Prices (comma-separated, e.g., 273.36,273.52,...)")
_forecast_box = gr.Textbox(label="Forecasted Prices")
iface = gr.Interface(fn=forecast, inputs=_prices_box, outputs=_forecast_box)

# Start the app; share=True asks Gradio for a public share link
iface.launch(share=True)