File size: 2,263 Bytes
e418c5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import numpy as np
import torch
import sys, os
# Make the parent directory importable so `lstm.model` resolves when this
# file is run directly as a script.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from tqdm import tqdm
from lstm.model import STARM
# Print numpy arrays in plain decimal (no scientific notation).
np.set_printoptions(suppress=True)
# Everything below runs on CPU.
device = torch.device('cpu')
# NOTE(review): a no-op unless a CUDA context is initialized — presumably a
# leftover from GPU runs; confirm it can be removed.
torch.cuda.empty_cache()


# Load the preprocessed "striding" dataset: a dict holding the input
# windows under 'x' and the regression targets under 'y'.
data = torch.load('../../data/processed/striding/data.pt')
X = data['x']
y = data['y']

# Hold out the last 15% of samples. The split index is derived from the
# actual dataset length instead of the previous hard-coded sample count
# (54100), so the script keeps working if the dataset is regenerated at a
# different size. For a 54,100-row dataset the value is unchanged.
train_split = int(.85 * X.size(0))

X_train  = X[:train_split] #, X_test = X[:train_split], X[train_split:]
y_train  = y[:train_split] #, y_test = y[:train_split], y[train_split:]

# Move the training tensors onto the target device (CPU here).
X_train = X_train.to(device)
# X_test = X_test.to(device)
y_train = y_train.to(device)
# y_test = y_test.to(device)


# Training hyperparameters.
n_epoch = 1000        # number of passes over the training set
lr = 0.001            # Adam learning rate
input_size = 12       # features per timestep fed to the model
hidden_size = 2       # NOTE(review): unusually small vs input_size=12 — confirm intentional
num_lstm = 1          # number of stacked recurrent layers
num_classes = 12      # output dimension (matches input_size)

# Instantiate the model (presumably LSTM-based, given the `lstm.model`
# import) and place it on the chosen device.
lstm = STARM(
    num_classes,
    input_size,
    hidden_size,
    num_lstm
).to(device)

# Mean-squared-error regression objective, optimised with Adam.
loss_fn = torch.nn.MSELoss() 
optimiser = torch.optim.Adam(lstm.parameters(), lr=lr)


def training_loop(n_epochs, lstm, optimiser, loss_fn, X_train, y_train, batch_size=541000):
    """Train `lstm` on (X_train, y_train) with mini-batch gradient descent.

    Displays a tqdm progress bar per epoch and prints the average batch
    loss after each epoch. With the default batch_size the entire dataset
    fits in one batch, i.e. full-batch training.
    """
    n_samples = X_train.size(0)
    # Ceiling division: the last batch may be smaller than batch_size.
    num_batches = -(-n_samples // batch_size)

    for epoch in range(n_epochs):
        lstm.train()
        running_loss = 0.0

        with tqdm(total=num_batches, desc=f"Epoch {epoch + 1}/{n_epochs}", unit="batch") as pbar:
            for start in range(0, n_samples, batch_size):
                stop = start + batch_size
                X_batch, y_batch = X_train[start:stop], y_train[start:stop]

                # Standard step: forward, reset grads, backprop, update.
                preds = lstm(X_batch)
                optimiser.zero_grad()
                batch_loss = loss_fn(preds, y_batch)
                batch_loss.backward()
                optimiser.step()

                # Pull the scalar once and reuse it for both the running
                # total and the progress-bar readout.
                loss_val = batch_loss.item()
                running_loss += loss_val
                pbar.set_postfix({"loss": loss_val})
                pbar.update(1)

        print(f"Epoch {epoch + 1} completed. Average loss: {running_loss / num_batches:.5f}")


# Kick off training with the module-level hyperparameters and tensors.
training_loop(
    n_epochs=n_epoch,
    lstm=lstm,
    optimiser=optimiser,
    loss_fn=loss_fn,
    X_train=X_train,
    y_train=y_train,
)