# app.py
# Page-header residue kept as comments so the file parses as Python:
# uploaded by mrshibly; commit ac3e445 (verified), message "Create app.py"; size 2.51 kB.
import gradio as gr
import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import MinMaxScaler
from torch import nn
# Read the quotes file and reduce it to a date-indexed, chronologically
# ordered frame holding a single float 'Close' column.
df = pd.read_csv('HistoricalQuotes.csv')
df = (
    df.assign(Date=pd.to_datetime(df['Date'], format='%m/%d/%Y'))
    .set_index('Date')
    .loc[:, [' Close/Last']]
    .rename(columns={' Close/Last': 'Close'})
    .sort_index()
)
# Strip the dollar sign before converting the column to float.
df['Close'] = df['Close'].replace({r'\$': ''}, regex=True).astype(float)

# Fit the scaler on the closing prices as an (n, 1) array; the fitted scaler
# is reused later to scale user input and to un-scale predictions.
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(df[['Close']].to_numpy())
# Define LSTM model
class LSTMModel(nn.Module):
    """Stacked LSTM followed by a linear head applied to the last time step.

    Args:
        input_size: Number of features per time step.
        hidden_size: Width of the hidden state of every LSTM layer.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values produced by the linear head.
        dropout: Dropout probability handed to ``nn.LSTM``.
    """

    def __init__(self, input_size=1, hidden_size=50, num_layers=2, output_size=1, dropout=0.2):
        super().__init__()
        # Kept as plain ints (not parameters or buffers), so the state_dict
        # still contains only the ``lstm.*`` and ``fc.*`` entries.
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the prediction for the final time step of each sequence.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``; the LSTM is
                constructed with ``batch_first=True``.

        Returns:
            Tensor of shape ``(batch, output_size)``.
        """
        # Size the zero initial state from the configured layer count and
        # hidden width, and allocate it with x's dtype and device so the
        # model also works after being moved off the default device.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = x.new_zeros(state_shape)
        c0 = x.new_zeros(state_shape)
        out, _ = self.lstm(x, (h0, c0))
        # Only the output at the last time step feeds the linear head.
        return self.fc(out[:, -1, :])
# Load trained model
model = LSTMModel()
# map_location='cpu' remaps the stored tensors onto the CPU, so a checkpoint
# that was saved from a GPU still loads on a host without CUDA.
# NOTE(review): torch.load unpickles the file; only ship a trusted checkpoint.
model.load_state_dict(torch.load('lstm_model.pth', map_location='cpu'))
# Switch to evaluation mode so dropout is inactive during forecasting.
model.eval()
def forecast(past_prices, steps=30):
    """Forecast future closing prices from a comma-separated price history.

    The last 60 prices are scaled with the module-level ``scaler`` and fed to
    the module-level ``model``. Each prediction is appended to the window and
    the oldest value is dropped, so later steps are predicted from earlier
    predictions.

    Args:
        past_prices: Comma-separated prices. ``$`` characters and surrounding
            whitespace are ignored, as are blank entries (for example a
            trailing comma).
        steps: Number of future values to predict; must be at least 1.

    Returns:
        The forecast rendered with ``DataFrame.to_string()`` (column
        ``Forecast``, one row per step), or a string starting with
        ``"Error: "`` when the input is invalid or prediction fails.
    """
    window = 60  # number of most recent prices the model sees on each step

    try:
        # Drop blank tokens so "1, 2, 3," or an empty textbox produces the
        # clear "at least 60 prices" message instead of a float() error.
        cleaned = (tok.strip().replace('$', '') for tok in past_prices.split(','))
        prices = [float(tok) for tok in cleaned if tok]
        if len(prices) < window:
            return "Error: Please provide at least 60 prices."

        steps = int(steps)
        if steps < 1:
            return "Error: steps must be a positive integer."

        # Scale and shape the most recent window as (batch=1, seq=window, features=1).
        prices_scaled = scaler.transform(np.array(prices).reshape(-1, 1))
        current_seq = torch.from_numpy(prices_scaled[-window:].reshape(1, window, 1)).float()

        # Roll the window forward one prediction at a time; gradients are
        # never needed, so they are disabled once around the whole loop.
        predictions = []
        with torch.no_grad():
            for _ in range(steps):
                pred = model(current_seq)
                predictions.append(pred.item())
                current_seq = torch.cat((current_seq[:, 1:, :], pred.reshape(1, 1, 1)), dim=1)

        # Map the scaled predictions back to price units.
        predictions = scaler.inverse_transform(np.array(predictions).reshape(-1, 1)).flatten()
        return pd.DataFrame({'Forecast': predictions}).to_string()
    except Exception as e:
        # UI boundary: report any failure as text in the output box.
        return f"Error: {e}"
# Build the Gradio UI: one textbox for the price history, one for the forecast.
price_history_box = gr.Textbox(label="Past Prices (comma-separated, at least 60, e.g., 273.36,273.52,...)")
forecast_box = gr.Textbox(label="Forecasted Prices")
iface = gr.Interface(fn=forecast, inputs=price_history_box, outputs=forecast_box)

# Start the app; share=True additionally requests a public share link.
iface.launch(share=True)