# Synthetic_Stock_Data / src / ae_latent_extract.py
# Raheel Abdul Rehman
# Prod Publish
# bbf5d55
import os
import sys
import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.model import LSTMAutoEncoder, QuarterlyStockDataset # pylint: disable=import-error
from src.logger import get_logger
logger = get_logger(__name__)
def encode(model, x, ticker_id):
    """Encode a batch of sequences into latent vectors.

    The ticker embedding is looked up once per sample, tiled along dim 1 of
    ``x`` and concatenated to ``x`` along dim 2 before the encoder runs. The
    encoder output at the last position of dim 1 is then projected by
    ``model.fc_enc``.

    Args:
        model: Object exposing ``ticker_embed``, ``encoder`` and ``fc_enc``
            callables (an ``LSTMAutoEncoder`` in this script).
        x: 3-D tensor; presumably ``(batch, seq_len, features)`` with a
            batch-first encoder -- confirm against ``QuarterlyStockDataset``.
        ticker_id: Tensor of ticker indices, one per sample in ``x``.

    Returns:
        The ``model.fc_enc`` projection of the last encoder step per sample.
    """
    ticker_emb = model.ticker_embed(ticker_id).unsqueeze(1).repeat(1, x.size(1), 1)
    x_in = torch.cat([x, ticker_emb], dim=2)
    enc_out, _ = model.encoder(x_in)
    return model.fc_enc(enc_out[:, -1, :])


def _extract_ticker_latents(model, dataset, device, batch_size):
    """Encode every batch of *dataset*; return the stacked latents or ``None`` if it yields no batches."""
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    chunks = []
    with torch.no_grad():
        for batch_x, batch_ticker in loader:
            batch_x, batch_ticker = batch_x.to(device), batch_ticker.to(device)
            chunks.append(encode(model, batch_x, batch_ticker).cpu().numpy())
    if not chunks:
        return None
    return np.concatenate(chunks, axis=0)


def main():
    """Extract latent vectors for every ticker and save them as ``.npy`` files.

    Reads the processed parquet file, loads the trained autoencoder weights,
    encodes each ticker's sequences and writes two arrays under
    ``<base_dir>/../GAN/data/processed``: the latent vectors and a parallel
    array holding the ticker each vector belongs to.

    Raises:
        ValueError: If no ticker produced any latent vectors.
    """
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    processed_data_path = os.path.join(base_dir, "data", "processed", "stock_data.parquet")
    model_path = os.path.join(base_dir, "models", "lstm_autoencoder.pth")
    latent_vectors_path = os.path.join(base_dir, "..", "GAN", "data", "processed", "latent_vectors.npy")
    ticker_mapping_path = os.path.join(base_dir, "..", "GAN", "data", "processed", "ticker_mapping.npy")

    # Single source of truth for the window length used both to skip short
    # tickers and to build the dataset.
    sequence_length = 90
    batch_size = 64

    device = "cuda" if torch.cuda.is_available() else "cpu"

    df = pd.read_parquet(processed_data_path)
    tickers = df["Ticker"].unique()
    num_tickers = df["Ticker_Encoded"].nunique()

    # load_state_dict is strict by default, so these hyperparameters must
    # match the ones the checkpoint was trained with.
    model = LSTMAutoEncoder(
        input_dim=5,
        num_tickers=num_tickers,
        embed_dim=16,
        hidden_size=64,
        latent_dim=32,
        num_layers=2,
    ).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    all_latents = []
    all_tickers = []
    for ticker in tickers:
        ticker_df = df[df["Ticker"] == ticker].copy()
        if len(ticker_df) < sequence_length:
            continue  # not enough rows for this ticker to be processed

        dataset = QuarterlyStockDataset(ticker_df, sequence_length=sequence_length)
        ticker_latents = _extract_ticker_latents(model, dataset, device, batch_size)
        if ticker_latents is None:
            continue

        all_latents.append(ticker_latents)
        all_tickers.extend([ticker] * len(ticker_latents))
        logger.info(f"Extracted {len(ticker_latents)} latent vectors for {ticker}.")

    if not all_latents:
        # np.concatenate([]) would raise an opaque "need at least one array"
        # ValueError; fail with a message that names the actual cause.
        raise ValueError(
            f"No latent vectors were extracted: no ticker produced any "
            f"sequences of length {sequence_length}."
        )

    all_latents = np.concatenate(all_latents, axis=0)
    all_tickers = np.array(all_tickers)

    # np.save does not create parent directories; make sure they exist so the
    # run does not fail after all the encoding work is done.
    for out_path in (latent_vectors_path, ticker_mapping_path):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    np.save(latent_vectors_path, all_latents)
    np.save(ticker_mapping_path, all_tickers)
    logger.info(f"Saved {len(all_latents)} latent vectors to {latent_vectors_path}")
    logger.info(f"Saved ticker mapping to {ticker_mapping_path}")


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Top-level boundary: log the failure, then re-raise so the process
        # still exits with an error.
        logger.error("Error extracting latent space vectors: %s", e)
        raise