Spaces:
Runtime error
Runtime error
| import torch | |
| import numpy as np | |
| from torch import nn | |
| from data_loader import DataLoader | |
| from helper import ValTest | |
| from modality_lstm import ModalityLSTM | |
# ---------------------------------------------------------------------------
# Experiment configuration, model/optimiser construction and data loading.
# ---------------------------------------------------------------------------
batch_size = 32

# Use the first CUDA device when one is present, otherwise run on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Derived from the actual hardware instead of being hard-coded to True, so the
# script no longer fails with a CUDA error on CPU-only hosts.
train_on_gpu = torch.cuda.is_available()

# Model hyper-parameters, passed straight to ModalityLSTM below.
output_size = 5
hidden_dim = 128
trip_dim = 7  # also the width of each padding row built by pad_trajs
n_layers = 2
drop_prob = 0.2

net = ModalityLSTM(
    trip_dim,
    output_size,
    batch_size,
    hidden_dim,
    n_layers,
    train_on_gpu,
    drop_prob,
    lstm_drop_prob=0.2,
)

lr = 0.001
# Label value -1 is excluded from the loss; the training loop pads label
# sequences with -1.
loss_function = nn.CrossEntropyLoss(ignore_index=-1)
optimizer = torch.optim.Adam(net.parameters(), lr=lr)

# Training-loop schedule (all counts are in batches, except `epochs`).
epochs = 6
print_every = 5
log_every = 1
evaluate_every = 100
clip = 0.2  # gradient clipping

# Place the parameters on the same device the input batches are sent to.
net.to(device)
net.train()

dl = DataLoader(batchsize=batch_size, read_from_pickle=True)
dl.prepare_data()
def pad_trajs(trajs, lengths, feature_dim=None):
    """Pad every trajectory in place so that all share the longest length.

    Args:
        trajs: List of trajectories; each trajectory is a list of per-step
            feature rows. The inner lists are extended in place.
        lengths: Length of each trajectory. Its maximum is the pad target,
            so the batch does not have to be sorted longest-first.
        feature_dim: Number of values in each padding row. When omitted, the
            module-level ``trip_dim`` is used.

    Returns:
        The same ``trajs`` list object, with short trajectories extended by
        rows filled with ``-1``.
    """
    if not trajs:
        return trajs
    width = trip_dim if feature_dim is None else feature_dim
    target = max(lengths)
    for traj in trajs:
        missing = target - len(traj)
        if missing > 0:
            # Build a fresh row per step so padded rows never alias each other.
            traj.extend([-1] * width for _ in range(missing))
    return trajs
# ---------------------------------------------------------------------------
# Training driver: runs `epochs` passes over dl.batches(), periodically
# validates, then runs the test evaluator and writes the final weights.
# ---------------------------------------------------------------------------
losses, avg_losses = [], []

validator = ValTest(dl.val_batches, net, trip_dim, batch_size, device, loss_function, output_size, dl.get_val_size())
test = ValTest(dl.test_batches, net, trip_dim, batch_size, device, loss_function, output_size, dl.get_test_size())

for e in range(1, epochs + 1):
    print("epoch ", e)
    # The return value is not used below; the call is kept in case
    # init_hidden() has side effects on the network (not visible here).
    hidden = net.init_hidden()
    counter = 0
    torch.cuda.empty_cache()

    for train_sorted, labels_sorted in dl.batches():
        counter += 1

        lengths = [len(x) for x in train_sorted]
        print("Lengths are ", lengths)
        print("SUm of lengths", sum(lengths))

        # Make the batch rectangular so it can be turned into one array.
        train_sorted = pad_trajs(train_sorted, lengths)
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
        X = np.asarray(train_sorted, dtype=np.float64)
        input_tensor = torch.from_numpy(X)
        print("Input tensor is ", input_tensor.shape)
        input_tensor = input_tensor.to(device)

        net.zero_grad()
        output, max_padding_for_this_batch = net(input_tensor, lengths)
        print("Output is", output.shape)

        # Pad each label sequence (in place) with -1, the index the loss
        # ignores, up to the sequence length reported by the network.
        for labelz in labels_sorted:
            missing = max_padding_for_this_batch - len(labelz)
            if missing > 0:
                labelz.extend([-1] * missing)

        # Flatten labels and predictions to one row per (sequence, step).
        flat_rows = max_padding_for_this_batch * batch_size
        labels_for_loss = (
            torch.tensor(labels_sorted)
            .view(flat_rows, -1)
            .squeeze(1)
            .long()
            .to(device)
        )
        print("Labels for loss is", len(labels_for_loss))

        loss = loss_function(output.view(flat_rows, -1), labels_for_loss)
        loss.backward()
        nn.utils.clip_grad_norm_(net.parameters(), clip)
        optimizer.step()

        if counter % log_every == 0:
            losses.append(loss.item())

        if counter % print_every == 0:
            # Average over the losses actually recorded (at most the last 50)
            # instead of always dividing by 50, which under-reported the
            # average until 50 values had been logged.
            recent = losses[-50:]
            avg_losses.append(sum(recent) / max(len(recent), 1))
            print(
                f'Epoch: {e:2d}. {counter:d} of {int(dl.get_train_size() / batch_size):d} {avg_losses[-1]:f} Loss: {loss.item():.4f}')

        if counter % evaluate_every == 0:
            validator.run()
            # NOTE(review): the original indentation was lost; this checkpoint
            # is assumed to belong to the validation branch - confirm.
            torch.save(net.state_dict(), "Model_Wieghts")

print("Testing")
test.run()
torch.save(net.state_dict(), "Model_Wieghts")