# Training and evaluation utilities for the dijkprofile annotator.
| import dijkprofile_annotator.preprocessing as preprocessing | |
| import dijkprofile_annotator.utils as utils | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from dijkprofile_annotator.models import Dijknet | |
| from PIL import Image | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
def get_loss_train(model, data_train, criterion):
    """Compute mean accuracy and mean loss of ``model`` over a training set.

    Args:
        model (torch.nn.Module): model to use for prediction.
        data_train (torch.utils.data.DataLoader): loader yielding
            (profile, mask) batches.
        criterion (callable): loss function, e.g. ``nn.CrossEntropyLoss``.

    Returns:
        float: mean accuracy over all batches (0.0 for an empty loader).
        float: mean loss over all batches (0.0 for an empty loader).
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.eval()
    total_acc = 0.0
    total_loss = 0.0
    n_batches = 0
    for profile, masks in data_train:
        with torch.no_grad():
            profile = torch.Tensor(profile).to(device)
            masks = torch.Tensor(masks).to(device)
            outputs = model(profile)
            loss = criterion(outputs, masks)
            # class index with the highest score at each position
            preds = torch.argmax(outputs, dim=1).float()
            acc = accuracy_check_for_batch(masks.cpu(), preds.cpu(), profile.size()[0])
            total_acc += acc
            total_loss += loss.cpu().item()
        n_batches += 1
    if n_batches == 0:
        # original raised NameError here (loop variable used after an empty loop)
        return 0.0, 0.0
    return total_acc / n_batches, total_loss / n_batches
def accuracy_check(mask, prediction):
    """Element-wise accuracy of ``prediction`` against ``mask``.

    Args:
        mask (torch.Tensor, PIL Image or str): labels, or a path to an
            image file containing them.
        prediction (torch.Tensor, PIL Image or str): predictions, same
            accepted types as ``mask``.

    Returns:
        float: fraction of positions where prediction equals mask.
    """
    np_ims = []
    for item in (mask, prediction):
        if isinstance(item, str):
            # a string is interpreted as a path to an image file
            item = np.array(Image.open(item))
        elif 'PIL' in str(type(item)):
            item = np.array(item)
        elif isinstance(item, torch.Tensor):
            # detach + cpu first: plain .numpy() fails on CUDA tensors and
            # on tensors that are part of an autograd graph
            item = item.detach().cpu().numpy()
        np_ims.append(item)
    compare = np.equal(np_ims[0], np_ims[1])
    accuracy = np.sum(compare)
    return accuracy / len(np_ims[0].flatten())
def accuracy_check_for_batch(masks, predictions, batch_size):
    """Mean per-sample accuracy over one batch.

    Args:
        masks (torch.Tensor): ground-truth labels, first dimension is batch.
        predictions (torch.Tensor): predicted labels, first dimension is batch.
        batch_size (int): number of samples in the batch.

    Returns:
        float: average of ``accuracy_check`` over the batch samples.
    """
    sample_accuracies = [
        accuracy_check(masks[idx], predictions[idx]) for idx in range(batch_size)
    ]
    return sum(sample_accuracies) / batch_size
def train(annotation_tuples,
          epochs=100,
          batch_size_train=32,
          batch_size_val=512,
          num_workers=6,
          custom_scaler_path=None,
          class_list='simple',
          test_size=0.2,
          max_profile_size=512,
          shuffle=True):
    """Train a Dijknet segmentation model on annotated dijk profiles.

    Args:
        annotation_tuples (list): annotation file tuples, forwarded to
            ``preprocessing.load_datasets``.
        epochs (int, optional): number of training epochs. Defaults to 100.
        batch_size_train (int, optional): training batch size. Defaults to 32.
        batch_size_val (int, optional): validation batch size. Defaults to 512.
        num_workers (int, optional): DataLoader worker processes. Defaults to 6.
        custom_scaler_path (str, optional): path to a pre-fit scaler; None
            fits a fresh one. Defaults to None.
        class_list (str, optional): name of the class mapping to use.
            Defaults to 'simple'.
        test_size (float, optional): fraction of data held out for
            validation. Defaults to 0.2.
        max_profile_size (int, optional): profiles are sized to this length.
            Defaults to 512.
        shuffle (bool, optional): shuffle the training set each epoch.
            Defaults to True.

    Raises:
        NotImplementedError: when the given class_list is not implemented
            (raised by ``utils.get_class_dict``).

    Returns:
        Dijknet: the trained model.
    """
    print("loading datasets")
    train_dataset, test_dataset = preprocessing.load_datasets(
        annotation_tuples,
        custom_scaler_path=custom_scaler_path,
        test_size=test_size,
        max_profile_size=max_profile_size)
    print("loaded datasets:")
    print(f"  train: {len(train_dataset)} samples")
    print(f"  test: {len(test_dataset)} samples")

    class_dict, _, class_weights = utils.get_class_dict(class_list)
    print(f"constructing model with {len(class_dict)} output classes")
    model = Dijknet(1, len(class_dict))

    # DataLoader parameters. drop_last replaces the original manual skip of
    # the final partial batch (the loss reshape assumes a full batch).
    train_params = {'batch_size': batch_size_train,
                    'shuffle': shuffle,
                    'num_workers': num_workers,
                    'drop_last': True}
    params_val = {'batch_size': batch_size_val,
                  'shuffle': False,
                  'num_workers': num_workers}
    training_generator = DataLoader(train_dataset, **train_params)
    validation_generator = DataLoader(test_dataset, **params_val)

    # CUDA for PyTorch
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # weighted loss to compensate for class imbalance
    criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(class_weights).to(device))
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    print("starting training.")
    for epoch in range(epochs):
        print("epoch: {}".format(epoch))
        # --- training ---
        loss_list = []
        model.train()
        for local_batch, local_labels in tqdm(training_generator):
            # Transfer to GPU
            local_batch = local_batch.to(device)
            local_labels = local_labels.to(device).long()
            outputs = model(local_batch)
            # CrossEntropyLoss expects targets shaped (batch, positions)
            local_labels = local_labels.reshape(local_labels.shape[0], -1)
            loss = criterion(outputs, local_labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.detach().cpu().numpy())
        # report average loss over the epoch
        print("training loss: ", np.mean(loss_list))

        # --- validation ---
        model.eval()
        batch_accuracies = []
        batch_accuracies_iso = []
        batch_loss_val = []
        with torch.no_grad():  # no autograd graph needed during validation
            for local_batch, local_labels in validation_generator:
                local_batch = local_batch.to(device)
                local_labels = local_labels.to(device).long()
                outputs = model(local_batch)
                loss = criterion(outputs, local_labels.reshape(local_labels.shape[0], -1))
                batch_loss_val.append(loss.detach().cpu().numpy())
                # post-process so predicted classes occur in sequential order
                # (the unused 'first' method computation of the original was
                # dead code and has been removed)
                outputs_iso = utils.force_sequential_predictions(outputs, method='isotonic')
                batch_accuracies.append(
                    _validation_accuracy(outputs, local_batch, local_labels))
                batch_accuracies_iso.append(
                    _validation_accuracy(outputs_iso, local_batch, local_labels))
        print("validation accuracy: {}".format(np.mean(batch_accuracies)))
        print("validation accuracy isotonic regression: {}".format(np.mean(batch_accuracies_iso)))
        print("validation loss: {}".format(np.mean(batch_loss_val)))
        print("=" * 50)
    return model


def _validation_accuracy(outputs, local_batch, local_labels):
    """Mean fraction of profile positions predicted correctly for one batch."""
    preds = torch.argmax(outputs, dim=1).cpu().reshape(local_batch.shape[0], 1, -1)
    compare = preds == local_labels.cpu()
    acc = np.sum(compare.numpy(), axis=2) / int(local_batch.shape[-1])
    return np.mean(acc, axis=0)[0]