Spaces:
Running
Running
| from model import EventDetector | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from torchvision import transforms | |
| from dataloader import GolfDB, ToTensor, Normalize | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from util import correct_preds | |
def _sample_probs(model, images, seq_length, device):
    """Run ``model`` over one clip in consecutive windows of ``seq_length`` frames.

    Full samples do not fit into GPU memory, so the clip is split along dim 1
    into windows of at most ``seq_length`` frames (the last one may be
    shorter). Each window's softmax output is converted to numpy, and the
    results are concatenated along axis 0.

    Args:
        model: network called as ``model(window)``. Its output is assumed to
            hold class scores on dim 1 — TODO confirm against the model.
        images: tensor whose dim 1 is the frame axis (the caller uses batch
            size 1).
        seq_length: number of frames per forward pass; must be >= 1.
        device: device each window is moved to before the forward pass.

    Returns:
        numpy array of softmax probabilities, window results stacked on axis 0.

    Raises:
        ValueError: if the clip has zero frames.
    """
    n_frames = images.shape[1]
    if n_frames == 0:
        # Without this check there would be no probabilities to score for the clip.
        raise ValueError('cannot evaluate a sample with zero frames')
    chunks = []
    # Inference only: skip building autograd graphs to save (GPU) memory.
    with torch.no_grad():
        for start in range(0, n_frames, seq_length):
            # Slicing past the end is clamped, so the final short window needs
            # no special case.
            window = images[:, start:start + seq_length]
            logits = model(window.to(device))
            chunks.append(F.softmax(logits, dim=1).cpu().numpy())
    return np.concatenate(chunks, axis=0)


def eval(model, split, seq_length, n_cpu, disp):
    """Evaluate ``model`` on a GolfDB validation split and return the mean score.

    Loads ``data/val_split_{split}.pkl`` with videos from ``data/videos_160/``.
    It normalizes frames with the given per-channel mean/std and runs the
    model over each sample in windows of ``seq_length`` frames. The
    concatenated probabilities are then scored with ``correct_preds``.

    Args:
        model: the network to evaluate. Inputs are sent to the device of its
            first parameter. This function does not call ``model.eval()``;
            the caller should.
        split: validation split index substituted into the data file name.
        seq_length: frames per forward pass; must be >= 1.
        n_cpu: number of DataLoader worker processes.
        disp: if truthy, print each sample index together with its score.

    Returns:
        ``np.mean`` over samples of the fifth value returned by
        ``correct_preds``. This is presumably the Percentage of Correct
        Events (PCE); confirm in ``util``.

    Raises:
        ValueError: if ``seq_length`` < 1 or a sample has zero frames.
    """
    if seq_length < 1:
        # A non-positive window never advances through the clip (infinite loop).
        raise ValueError('seq_length must be >= 1, got {}'.format(seq_length))
    dataset = GolfDB(data_file='data/val_split_{}.pkl'.format(split),
                     vid_dir='data/videos_160/',
                     seq_length=seq_length,
                     transform=transforms.Compose([ToTensor(),
                                                   Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
                     train=False)
    data_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=False,
                             num_workers=n_cpu,
                             drop_last=False)
    # Follow the model's placement instead of hard-coding the default CUDA device.
    device = next(model.parameters()).device
    correct = []
    for i, sample in enumerate(data_loader):
        images, labels = sample['images'], sample['labels']
        probs = _sample_probs(model, images, seq_length, device)
        # labels presumably has a leading batch dim of 1 — squeezed away here.
        _, _, _, _, c = correct_preds(probs, labels.squeeze())
        if disp:
            print(i, c)
        correct.append(c)
    PCE = np.mean(correct)
    return PCE
if __name__ == '__main__':
    # Run settings: validation split index, frames per forward pass, loader workers.
    split, seq_length, n_cpu = 1, 64, 6

    # Architecture settings; presumably these must match the ones used to
    # train the checkpoint loaded below — confirm before changing them.
    model = EventDetector(
        pretrain=True,
        width_mult=1.,
        lstm_layers=1,
        lstm_hidden=256,
        bidirectional=True,
        dropout=False,
    )
    checkpoint = torch.load('models_v1/swingnet_1800.pth.tar')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.cuda()
    model.eval()

    PCE = eval(model, split, seq_length, n_cpu, disp=True)
    print('Average PCE: {}'.format(PCE))