# starx/model/train/train.py
# Uploaded via huggingface_hub by recorderlegend1 (commit e418c5a, verified).
import numpy as np
import torch
import sys, os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tqdm import tqdm
from lstm.model import STARM
# Print floats in plain decimal notation (no scientific notation).
np.set_printoptions(suppress=True)
# All training in this script runs on CPU.
device = torch.device('cpu')
# NOTE(review): a no-op unless CUDA is initialised; harmless, but pointless
# given the CPU-only device above — confirm whether this is leftover code.
torch.cuda.empty_cache()
# Pre-processed tensors. Path is relative to the current working directory,
# so this script must be run from its own folder.
data = torch.load('../../data/processed/striding/data.pt')
# 85% train split. NOTE(review): 54100 is presumably the row count of
# data['x'] — confirm against data.pt; X.size(0) would be safer here.
train_split = int(.85*54100)
X = data['x']
y = data['y']
# Test split is currently disabled (commented out); only the first
# train_split rows are used.
X_train = X[:train_split] #, X_test = X[:train_split], X[train_split:]
y_train = y[:train_split] #, y_test = y[:train_split], y[train_split:]
X_train = X_train.to(device)
# X_test = X_test.to(device)
y_train = y_train.to(device)
# y_test = y_test.to(device)
# Hyperparameters.
n_epoch = 1000    # number of training epochs
lr = 0.001        # Adam learning rate
input_size = 12   # features per timestep fed to the model
hidden_size = 2   # hidden units per LSTM layer
num_lstm = 1      # number of stacked LSTM layers
num_classes = 12  # output dimension (matches input_size here)
# Model, loss, and optimiser. STARM's constructor signature is
# (num_classes, input_size, hidden_size, num_lstm) — defined in lstm/model.
lstm = STARM(
num_classes,
input_size,
hidden_size,
num_lstm
).to(device)
loss_fn = torch.nn.MSELoss()
optimiser = torch.optim.Adam(lstm.parameters(), lr=lr)
def training_loop(n_epochs, lstm, optimiser, loss_fn, X_train, y_train, batch_size=541000):
    """Train ``lstm`` on ``(X_train, y_train)`` with mini-batch gradient descent.

    Runs ``n_epochs`` sequential passes over the data (no shuffling),
    printing a tqdm progress bar per epoch and the average batch loss
    at the end of each epoch.

    Args:
        n_epochs: number of full passes over the training set.
        lstm: the model to train; its forward pass maps an X batch to
            predictions comparable with the matching y batch.
        optimiser: a ``torch.optim`` optimiser over ``lstm.parameters()``.
        loss_fn: callable with signature ``loss_fn(outputs, targets)``
            returning a scalar loss tensor.
        X_train: input tensor, sliced into batches along dim 0.
        y_train: target tensor, sliced in lockstep with ``X_train``.
        batch_size: rows per batch. NOTE(review): the default (541000)
            exceeds the dataset used by this script (~46k rows), so
            training is effectively full-batch — possibly a typo for
            54100; confirm intent before changing.
    """
    dataset_size = X_train.size(0)
    # Ceiling division so a final partial batch is still counted.
    num_batches = (dataset_size + batch_size - 1) // batch_size
    for epoch in range(n_epochs):
        lstm.train()
        epoch_loss = 0.0  # accumulates per-batch loss for the epoch average
        with tqdm(total=num_batches, desc=f"Epoch {epoch + 1}/{n_epochs}", unit="batch") as pbar:
            for start in range(0, dataset_size, batch_size):
                X_batch = X_train[start:start + batch_size]
                y_batch = y_train[start:start + batch_size]
                # Clear stale gradients BEFORE the forward pass. The
                # original zeroed them between forward and backward,
                # which works but is error-prone ordering; the standard
                # idiom is zero_grad -> forward -> loss -> backward -> step.
                optimiser.zero_grad()
                outputs = lstm(X_batch)
                loss = loss_fn(outputs, y_batch)
                loss.backward()
                optimiser.step()
                epoch_loss += loss.item()
                pbar.set_postfix({"loss": loss.item()})
                pbar.update(1)  # one batch done
        print(f"Epoch {epoch + 1} completed. Average loss: {epoch_loss / num_batches:.5f}")
# Kick off training with the hyperparameters and tensors prepared above;
# batch_size is left at its (oversized) default, so this is full-batch training.
training_loop(n_epochs=n_epoch,lstm=lstm,optimiser=optimiser,loss_fn=loss_fn,X_train=X_train,y_train=y_train)