import os
import sys

import numpy as np
import torch

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from tqdm import tqdm

from lstm.model import STARM

np.set_printoptions(suppress=True)

device = torch.device('cpu')
# Releasing cached CUDA memory is only meaningful when training on a GPU.
if device.type == 'cuda':
    torch.cuda.empty_cache()

# NOTE(review): this path is resolved against the current working directory,
# not against this file's location (unlike the sys.path entry above) -- the
# script presumably must be launched from its own directory; confirm.
data = torch.load('../../data/processed/striding/data.pt')

X = data['x']
y = data['y']

# The first 85% of samples are used for training. The split is derived from
# the loaded data (previously a hard-coded 54100) so it stays correct if the
# processed dataset changes size.
train_split = int(.85 * X.size(0))

X_train = X[:train_split]
y_train = y[:train_split]
# Held-out split, currently unused:
# X_test, y_test = X[train_split:], y[train_split:]

X_train = X_train.to(device)
y_train = y_train.to(device)

# Hyperparameters
n_epoch = 1000
lr = 0.001
input_size = 12
hidden_size = 2
num_lstm = 1
num_classes = 12

lstm = STARM(
    num_classes,
    input_size,
    hidden_size,
    num_lstm
).to(device)

loss_fn = torch.nn.MSELoss()
optimiser = torch.optim.Adam(lstm.parameters(), lr=lr)


def training_loop(n_epochs, lstm, optimiser, loss_fn, X_train, y_train, batch_size=541000):
    """Train ``lstm`` on ``(X_train, y_train)`` with mini-batch updates.

    Samples are visited in their stored order (no shuffling). After each epoch
    the sample-weighted mean training loss is printed.

    Args:
        n_epochs: Number of passes over the training data.
        lstm: Model to train; called as ``lstm(X_batch)``.
        optimiser: Optimiser over ``lstm``'s parameters.
        loss_fn: Callable ``loss_fn(outputs, targets)`` returning a scalar
            tensor. Assumed to average over the batch (as ``MSELoss``'s
            default ``reduction='mean'`` does) for the epoch mean to be exact.
        X_train: Training inputs, batched along dim 0.
        y_train: Training targets, aligned with ``X_train`` along dim 0.
        batch_size: Samples per optimiser step.
            NOTE(review): the default 541000 is 10x the 54100 sample count
            this script originally assumed, so it likely yields a single
            full-data batch per epoch -- possibly a typo for 54100; confirm.

    Raises:
        ValueError: If ``X_train`` is empty, ``batch_size`` is not positive,
            or ``y_train`` has a different number of samples than ``X_train``.
    """
    dataset_size = X_train.size(0)
    if dataset_size == 0:
        raise ValueError("X_train is empty; nothing to train on")
    if batch_size < 1:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if y_train.size(0) != dataset_size:
        raise ValueError(
            f"X_train has {dataset_size} samples but y_train has {y_train.size(0)}"
        )

    num_batches = (dataset_size + batch_size - 1) // batch_size  # ceil division

    for epoch in range(n_epochs):
        lstm.train()
        epoch_loss = 0.0  # sum of (batch mean loss * batch size) over the epoch

        with tqdm(total=num_batches, desc=f"Epoch {epoch + 1}/{n_epochs}", unit="batch") as pbar:
            for start in range(0, dataset_size, batch_size):
                X_batch = X_train[start:start + batch_size]
                y_batch = y_train[start:start + batch_size]

                optimiser.zero_grad()
                outputs = lstm(X_batch)
                loss = loss_fn(outputs, y_batch)
                loss.backward()
                optimiser.step()

                batch_loss = loss.item()  # one host sync per batch
                # Weight by batch size so a shorter final batch does not skew
                # the epoch mean.
                epoch_loss += batch_loss * X_batch.size(0)
                pbar.set_postfix({"loss": batch_loss})
                pbar.update(1)

        print(f"Epoch {epoch + 1} completed. Average loss: {epoch_loss / dataset_size:.5f}")


training_loop(n_epochs=n_epoch, lstm=lstm, optimiser=optimiser, loss_fn=loss_fn, X_train=X_train, y_train=y_train)