"""Synthetic_Stock_Data/src/ae_model.py

LSTM autoencoder over quarterly stock-price sequences, with an Optuna
hyperparameter search followed by a final training run on the best parameters.
"""
import json
import os
import sys
import warnings

import optuna
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

warnings.simplefilter(action='ignore', category=FutureWarning)

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.logger import get_logger  # pylint: disable=import-error
logger = get_logger(__name__)
class QuarterlyStockDataset(Dataset):
    """Non-overlapping OHLCV windows, one (sequence, ticker_id) pair per window."""

    def __init__(self, df, sequence_length=90):
        try:
            self.sequence_length = sequence_length
            self.samples = []
            df = df.sort_values(by=["Ticker", "Date"]).reset_index(drop=True)
            tickers = df['Ticker'].unique()
            feature_cols = ['Open', 'High', 'Low', 'Close', 'Volume']
            for ticker in tickers:
                ticker_df = df[df['Ticker'] == ticker]
                data = ticker_df[feature_cols].values
                ticker_id = ticker_df['Ticker_Encoded'].iloc[0]
                # Step by sequence_length so consecutive windows do not overlap.
                for i in range(0, len(data) - sequence_length + 1, sequence_length):
                    window = data[i:i + sequence_length]
                    self.samples.append((torch.tensor(window, dtype=torch.float32),
                                         torch.tensor(ticker_id, dtype=torch.long)))
            logger.info("Created %d quarterly sequences across %d tickers.",
                        len(self.samples), len(tickers))
        except Exception as e:
            logger.error("Error batching dataset: %s", e)
            raise

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
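
# Shape sketch for the dataset above (with the default sequence_length=90 and
# the five OHLCV feature columns):
#     window tensor : (90, 5), float32
#     ticker id     : scalar, long
# A DataLoader with batch_size=64, as used below, therefore yields batch_x of
# shape (64, 90, 5) and batch_ticker of shape (64,).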
class LSTMAutoEncoder(nn.Module):
    """Seq2seq LSTM autoencoder conditioned on a learned per-ticker embedding."""

    def __init__(self, input_dim, num_tickers, embed_dim=8, hidden_size=64, latent_dim=16, num_layers=1):
        super().__init__()
        self.ticker_embed = nn.Embedding(num_tickers, embed_dim)
        # Encoder: compress the (features + ticker embedding) sequence into a latent vector.
        self.encoder = nn.LSTM(input_dim + embed_dim, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc_enc = nn.Linear(hidden_size, latent_dim)
        # Decoder: expand the latent vector back into a full-length feature sequence.
        self.fc_dec = nn.Linear(latent_dim + embed_dim, hidden_size)
        self.decoder = nn.LSTM(hidden_size, input_dim, num_layers=num_layers, batch_first=True)

    def forward(self, x, ticker_id):
        # Broadcast the ticker embedding across every timestep and concatenate with the features.
        ticker_emb = self.ticker_embed(ticker_id).unsqueeze(1).repeat(1, x.size(1), 1)
        x_in = torch.cat([x, ticker_emb], dim=2)
        # Encode: project the final hidden state into the latent space.
        enc_out, _ = self.encoder(x_in)
        latent = self.fc_enc(enc_out[:, -1, :])
        # Condition the latent on the ticker embedding and repeat it for each timestep.
        latent_cat = torch.cat([latent, self.ticker_embed(ticker_id)], dim=1)
        latent_cat = latent_cat.unsqueeze(1).repeat(1, x.size(1), 1)
        # Decode back to the input feature dimension.
        dec_input = self.fc_dec(latent_cat)
        out_dec, _ = self.decoder(dec_input)
        return out_dec
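
# Forward-pass shape walkthrough for a batch of B sequences of length T with
# F = input_dim features:
#     ticker_emb : (B, T, embed_dim)      embedding repeated per timestep
#     x_in       : (B, T, F + embed_dim)  encoder input
#     latent     : (B, latent_dim)        projected final encoder hidden state
#     dec_input  : (B, T, hidden_size)    conditioned latent, one copy per timestep
#     out_dec    : (B, T, F)              reconstructed sequence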
def objective(trial, df, sequence_length=90, device='cpu'):
    """Optuna objective: train briefly on pre-2023 data, return 2023 validation loss."""
    try:
        num_layers = trial.suggest_int("num_layers", 1, 3)
        hidden_size = trial.suggest_categorical("hidden_size", [32, 64, 128])
        latent_dim = trial.suggest_categorical("latent_dim", [8, 16, 32])
        # suggest_loguniform is deprecated; suggest_float(..., log=True) is the current API.
        lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
        embed_dim = trial.suggest_categorical("embed_dim", [4, 8, 16])
        # Chronological split: train on pre-2023 data, validate on 2023.
        train_df = df[df['Date'] < '2023-01-01']
        val_df = df[(df['Date'] >= '2023-01-01') & (df['Date'] < '2024-01-01')]
        train_dataset = QuarterlyStockDataset(train_df, sequence_length)
        val_dataset = QuarterlyStockDataset(val_df, sequence_length)
        train_loader = DataLoader(train_dataset, batch_size=64, shuffle=False)
        val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
        num_tickers = df['Ticker_Encoded'].nunique()
        model = LSTMAutoEncoder(
            input_dim=5, num_tickers=num_tickers, embed_dim=embed_dim,
            hidden_size=hidden_size, latent_dim=latent_dim, num_layers=num_layers
        ).to(device)
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        epochs = 20
        for epoch in range(epochs):
            model.train()
            for batch_x, batch_ticker in train_loader:
                batch_x, batch_ticker = batch_x.to(device), batch_ticker.to(device)
                optimizer.zero_grad()
                recon = model(batch_x, batch_ticker)
                loss = criterion(recon, batch_x)
                loss.backward()
                optimizer.step()
        # Score the trial on the held-out 2023 windows.
        model.eval()
        total_val_loss = 0
        with torch.no_grad():
            for batch_x, batch_ticker in val_loader:
                batch_x, batch_ticker = batch_x.to(device), batch_ticker.to(device)
                recon = model(batch_x, batch_ticker)
                loss = criterion(recon, batch_x)
                total_val_loss += loss.item()
        avg_val_loss = total_val_loss / len(val_loader)
        return avg_val_loss
    except Exception as e:
        logger.error("Error training model: %s", e)
        raise
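
# A possible extension (not part of the script above): run validation inside
# the epoch loop and let Optuna prune unpromising trials early, e.g.
#     trial.report(epoch_val_loss, epoch)
#     if trial.should_prune():
#         raise optuna.TrialPruned()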
if __name__ == "__main__":
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    processed_data_path = os.path.join(base_dir, 'data', 'processed', 'stock_data.parquet')
    model_path = os.path.join(base_dir, 'models', 'lstm_autoencoder.pth')
    loss_path = os.path.join(base_dir, 'resources', 'loss_values.json')
    hyperparams_path = os.path.join(base_dir, 'models', 'hyperparameters.json')
    # Make sure the output directories exist before anything is written to them.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    os.makedirs(os.path.dirname(loss_path), exist_ok=True)
    df = pd.read_parquet(processed_data_path)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    study = optuna.create_study(direction="minimize")
    study.optimize(lambda trial: objective(trial, df, device=device), n_trials=10)
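    # Note: create_study uses Optuna's TPE sampler by default. For a
    # reproducible search one could pass a seeded sampler instead, e.g.
    #     optuna.create_study(direction="minimize",
    #                         sampler=optuna.samplers.TPESampler(seed=42))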
    best_trial = study.best_trial
    best_params = best_trial.params
    # Retrain on all pre-2024 data (train + validation years) with the best hyperparameters.
    train_df = df[df['Date'] < '2024-01-01']
    full_dataset = QuarterlyStockDataset(train_df, sequence_length=90)
    full_loader = DataLoader(full_dataset, batch_size=64, shuffle=False)
    num_tickers = df['Ticker_Encoded'].nunique()
    best_model = LSTMAutoEncoder(
        input_dim=5,
        num_tickers=num_tickers,
        embed_dim=best_params.get('embed_dim', 8),
        hidden_size=best_params['hidden_size'],
        latent_dim=best_params['latent_dim'],
        num_layers=best_params['num_layers']
    ).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(best_model.parameters(), lr=best_params['lr'])
    epochs = 50
    train_losses = []
    for epoch in range(epochs):
        best_model.train()
        total_loss = 0
        for batch_x, batch_ticker in full_loader:
            batch_x, batch_ticker = batch_x.to(device), batch_ticker.to(device)
            optimizer.zero_grad()
            recon = best_model(batch_x, batch_ticker)
            loss = criterion(recon, batch_x)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(full_loader)
        train_losses.append(avg_loss)
        print(f"Epoch [{epoch+1}/{epochs}] Loss: {avg_loss:.6f}")
    torch.save(best_model.state_dict(), model_path)
    with open(loss_path, 'w') as f:
        json.dump(train_losses, f)
    with open(hyperparams_path, 'w') as f:
        json.dump(best_params, f)
    print("Model, losses, and hyperparameters saved successfully.")