""" Finetuning functions for doing transfer learning to new datasets.
"""
from __future__ import print_function

import uuid
from time import sleep
from io import open

import math
import pickle
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import BatchSampler, SequentialSampler
from torch.nn.utils import clip_grad_norm

from sklearn.metrics import f1_score

from torchmoji.global_variables import (FINETUNING_METHODS,
                                        FINETUNING_METRICS,
                                        WEIGHTS_DIR)
from torchmoji.tokenizer import tokenize
from torchmoji.sentence_tokenizer import SentenceTokenizer

try:
    unicode
    IS_PYTHON2 = True
except NameError:
    unicode = str
    IS_PYTHON2 = False


def load_benchmark(path, vocab, extend_with=0):
    """ Loads the given benchmark dataset.

        Tokenizes the texts using the provided vocabulary, extending it with
        words from the training dataset if extend_with > 0. Splits them into
        three lists: training, validation and testing (in that order).

        Also calculates the maximum length of the texts and the
        suggested batch_size.

    # Arguments:
        path: Path to the dataset to be loaded.
        vocab: Vocabulary to be used for tokenizing texts.
        extend_with: If > 0, the vocabulary will be extended with up to
            extend_with tokens from the training set before tokenizing.

    # Returns:
        A dictionary with the following fields:
            texts: List of three lists, containing tokenized inputs for
                training, validation and testing (in that order).
            labels: List of three lists, containing labels for training,
                validation and testing (in that order).
            added: Number of tokens added to the vocabulary.
            batch_size: Batch size.
            maxlen: Maximum length of an input.
    """
    # Load the pickled dataset
    with open(path, 'rb') as dataset:
        if IS_PYTHON2:
            data = pickle.load(dataset)
        else:
            data = pickle.load(dataset, fix_imports=True)

    # Decode the texts to unicode
    try:
        texts = [unicode(x) for x in data['texts']]
    except UnicodeDecodeError:
        texts = [x.decode('utf-8') for x in data['texts']]

    # Extract the labels
    labels = [x['label'] for x in data['info']]

    batch_size, maxlen = calculate_batchsize_maxlen(texts)

    st = SentenceTokenizer(vocab, maxlen)

    # Split up the dataset, extending the existing vocabulary with up to
    # extend_with tokens from the training dataset
    texts, labels, added = st.split_train_val_test(texts,
                                                   labels,
                                                   [data['train_ind'],
                                                    data['val_ind'],
                                                    data['test_ind']],
                                                   extend_with=extend_with)
    return {'texts': texts,
            'labels': labels,
            'added': added,
            'batch_size': batch_size,
            'maxlen': maxlen}


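# Usage sketch for load_benchmark. Assumptions: the torchMoji vocabulary JSON is
# available at torchmoji.global_variables.VOCAB_PATH, and
# 'path/to/SS-Twitter.pickle' is a hypothetical benchmark pickle containing the
# 'texts', 'info', 'train_ind', 'val_ind' and 'test_ind' fields described above.
def _example_load_benchmark():
    import json
    from torchmoji.global_variables import VOCAB_PATH

    with open(VOCAB_PATH, 'r') as f:
        vocab = json.load(f)

    data = load_benchmark('path/to/SS-Twitter.pickle', vocab, extend_with=10000)
    train_X, val_X, test_X = data['texts']
    train_y, val_y, test_y = data['labels']
    print('Added {} tokens, maxlen={}, batch_size={}'.format(
        data['added'], data['maxlen'], data['batch_size']))

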
def calculate_batchsize_maxlen(texts):
    """ Calculates a suitable maximum sequence length for the provided texts
        (the 80th percentile of the tokenized lengths, rounded up to the
        nearest multiple of ten) and a matching batch size.

    # Arguments:
        texts: List of inputs.

    # Returns:
        Batch size,
        max length
    """
    def roundup(x):
        return int(math.ceil(x / 10.0)) * 10

    # Calculate the max length of sequences considered and adjust the
    # batch_size accordingly to prevent GPU overflow
    lengths = [len(tokenize(t)) for t in texts]
    maxlen = roundup(np.percentile(lengths, 80.0))
    batch_size = 250 if maxlen <= 100 else 50
    return batch_size, maxlen


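# Worked example with hypothetical token counts: for tokenized lengths
# [3, 5, 8, 12, 40] the 80th percentile is 17.6, which rounds up to maxlen 20
# and therefore selects batch_size 250.
def _example_calculate_batchsize_maxlen():
    texts = [u'short text', u'a slightly longer example sentence goes here']
    batch_size, maxlen = calculate_batchsize_maxlen(texts)
    print('batch_size={}, maxlen={}'.format(batch_size, maxlen))

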
def freeze_layers(model, unfrozen_types=[], unfrozen_keyword=None):
    """ Freezes all layers in the given model, except for ones that are
        explicitly specified to not be frozen.

    # Arguments:
        model: Model whose layers should be modified.
        unfrozen_types: List of layer types which shouldn't be frozen.
        unfrozen_keyword: Name keywords of layers that shouldn't be frozen.

    # Returns:
        Model with the selected layers frozen.
    """
    trainable_modules = [(n, m) for n, m in model.named_children()
                         if len([id(p) for p in m.parameters()]) != 0]
    for name, module in trainable_modules:
        trainable = (any(typ in str(module) for typ in unfrozen_types) or
                     (unfrozen_keyword is not None and unfrozen_keyword.lower() in name.lower()))
        change_trainable(module, trainable, verbose=False)
    return model


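# Minimal sketch: keep only modules whose name contains 'output_layer'
# trainable, mirroring what the 'last' finetuning method below does. The
# `model` argument is any torch.nn.Module with named children; no particular
# architecture is assumed.
def _example_freeze_layers(model):
    model = freeze_layers(model, unfrozen_keyword='output_layer')
    still_trainable = [n for n, p in model.named_parameters() if p.requires_grad]
    print('Trainable parameters after freezing:', still_trainable)

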
def change_trainable(module, trainable, verbose=False):
    """ Helper method that freezes or unfreezes a given layer.

    # Arguments:
        module: Module to be modified.
        trainable: Whether the layer should be frozen or unfrozen.
        verbose: Verbosity flag.
    """
    if verbose:
        print('Changing MODULE', module, 'to trainable =', trainable)
    for name, param in module.named_parameters():
        if verbose:
            print('Setting weight', name, 'to trainable =', trainable)
        param.requires_grad = trainable

    if verbose:
        action = 'Unfroze' if trainable else 'Froze'
        print("{} {}".format(action, module))


def find_f1_threshold(model, val_gen, test_gen, average='binary'):
    """ Choose a threshold for F1 based on the validation dataset
        (see https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4442797/
        for details on why to find another threshold than simply 0.5)

    # Arguments:
        model: pyTorch model
        val_gen: Validation set dataloader.
        test_gen: Testing set dataloader.

    # Returns:
        F1 score for the given data and
        the corresponding F1 threshold
    """
    thresholds = np.arange(0.01, 0.5, step=0.01)
    f1_scores = []

    model.eval()
    val_out = [(y, model(X)) for X, y in val_gen]
    y_val, y_pred_val = (list(t) for t in zip(*val_out))

    test_out = [(y, model(X)) for X, y in test_gen]
    y_test, y_pred_test = (list(t) for t in zip(*test_out))

    # Flatten the per-batch labels and predictions into single arrays
    y_val = np.concatenate([y.numpy().flatten() for y in y_val])
    y_pred_val = np.concatenate([p.data.cpu().numpy().flatten() for p in y_pred_val])
    y_test = np.concatenate([y.numpy().flatten() for y in y_test])
    y_pred_test = np.concatenate([p.data.cpu().numpy().flatten() for p in y_pred_test])

    # Pick the threshold that maximizes F1 on the validation set
    for t in thresholds:
        y_pred_val_ind = (y_pred_val > t)
        f1_val = f1_score(y_val, y_pred_val_ind, average=average)
        f1_scores.append(f1_val)

    # Evaluate on the test set using the chosen threshold
    best_t = thresholds[np.argmax(f1_scores)]
    y_pred_ind = (y_pred_test > best_t)
    f1_test = f1_score(y_test, y_pred_ind, average=average)
    return f1_test, best_t


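# Sketch of the threshold search in isolation, with hypothetical labels and
# predicted probabilities, decoupled from the model/DataLoader plumbing above.
def _example_threshold_search():
    y_val = np.array([0, 1, 1, 0, 1, 0])
    y_prob = np.array([0.10, 0.32, 0.44, 0.21, 0.48, 0.05])

    thresholds = np.arange(0.01, 0.5, step=0.01)
    scores = [f1_score(y_val, (y_prob > t).astype(int)) for t in thresholds]
    best_t = thresholds[np.argmax(scores)]
    print('Best F1 threshold on validation data:', best_t)

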
def finetune(model, texts, labels, nb_classes, batch_size, method,
             metric='acc', epoch_size=5000, nb_epochs=1000, embed_l2=1E-6,
             verbose=1):
    """ Compiles and finetunes the given pytorch model.

    # Arguments:
        model: Model to be finetuned
        texts: List of three lists, containing tokenized inputs for training,
            validation and testing (in that order).
        labels: List of three lists, containing labels for training,
            validation and testing (in that order).
        nb_classes: Number of classes in the dataset.
        batch_size: Batch size.
        method: Finetuning method to be used. For available methods, see
            FINETUNING_METHODS in global_variables.py.
        metric: Evaluation metric to be used. For available metrics, see
            FINETUNING_METRICS in global_variables.py.
        epoch_size: Number of samples in an epoch.
        nb_epochs: Number of epochs. Doesn't matter much as early stopping is used.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.

    # Returns:
        Model after finetuning,
        score after finetuning using the provided metric.
    """
    if method not in FINETUNING_METHODS:
        raise ValueError('ERROR (finetune): Invalid method parameter. '
                         'Available options: {}'.format(FINETUNING_METHODS))
    if metric not in FINETUNING_METRICS:
        raise ValueError('ERROR (finetune): Invalid metric parameter. '
                         'Available options: {}'.format(FINETUNING_METRICS))

    train_gen = get_data_loader(texts[0], labels[0], batch_size,
                                extended_batch_sampler=True, epoch_size=epoch_size)
    val_gen = get_data_loader(texts[1], labels[1], batch_size,
                              extended_batch_sampler=False)
    test_gen = get_data_loader(texts[2], labels[2], batch_size,
                               extended_batch_sampler=False)

    checkpoint_path = '{}/torchmoji-checkpoint-{}.bin' \
                      .format(WEIGHTS_DIR, str(uuid.uuid4()))

    if method in ['last', 'new']:
        lr = 0.001
    elif method in ['full', 'chain-thaw']:
        lr = 0.0001

    loss_op = nn.BCEWithLogitsLoss() if nb_classes <= 2 \
        else nn.CrossEntropyLoss()

    # Freeze layers if using the 'last' method
    if method == 'last':
        model = freeze_layers(model, unfrozen_keyword='output_layer')

    # Define the optimizer; for chain-thaw it is defined inside the training loop
    if method == 'last':
        adam = optim.Adam((p for p in model.parameters() if p.requires_grad), lr=lr)
    elif method in ['full', 'new']:
        # Apply L2 regularization only to the embedding layer and a fixed
        # 0.001 learning rate to the output layer
        embed_params_id = [id(p) for p in model.embed.parameters()]
        output_layer_params_id = [id(p) for p in model.output_layer.parameters()]
        base_params = [p for p in model.parameters()
                       if id(p) not in embed_params_id and id(p) not in output_layer_params_id and p.requires_grad]
        embed_params = [p for p in model.parameters() if id(p) in embed_params_id and p.requires_grad]
        output_layer_params = [p for p in model.parameters() if id(p) in output_layer_params_id and p.requires_grad]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_params, 'weight_decay': embed_l2},
            {'params': output_layer_params, 'lr': 0.001},
        ], lr=lr)

    # Training
    if verbose:
        print('Method: {}'.format(method))
        print('Metric: {}'.format(metric))
        print('Classes: {}'.format(nb_classes))

    if method == 'chain-thaw':
        result = chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op, embed_l2=embed_l2,
                            evaluate=metric, verbose=verbose)
    else:
        result = tune_trainable(model, loss_op, adam, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path,
                                evaluate=metric, verbose=verbose)
    return model, result


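# End-to-end usage sketch. Assumptions: torchmoji_transfer and PRETRAINED_PATH
# are the model constructor and pretrained weight path shipped with torchMoji
# (torchmoji.model_def / torchmoji.global_variables), and `data` is the
# dictionary returned by load_benchmark above; method='last' and nb_classes=2
# are illustrative choices only.
def _example_finetune(data, nb_classes=2):
    from torchmoji.model_def import torchmoji_transfer
    from torchmoji.global_variables import PRETRAINED_PATH

    model = torchmoji_transfer(nb_classes, PRETRAINED_PATH)
    model, result = finetune(model, data['texts'], data['labels'], nb_classes,
                             data['batch_size'], method='last', metric='acc')
    print('Test accuracy after finetuning: {}'.format(result))

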
def tune_trainable(model, loss_op, optim_op, train_gen, val_gen, test_gen,
                   nb_epochs, checkpoint_path, patience=5, evaluate='acc',
                   verbose=2):
    """ Finetunes the currently trainable weights of the given model.

    # Arguments:
        model: Model to be finetuned.
        loss_op: Loss operation to be used during training.
        optim_op: Optimizer to be used during training.
        train_gen: Training data iterator (DataLoader).
        val_gen: Validation data iterator (DataLoader).
        test_gen: Testing data iterator (DataLoader).
        nb_epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Patience for callback methods.
        evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
        verbose: Verbosity flag.

    # Returns:
        Score of the trained model on the test set, using the chosen
        evaluation method.
    """
    if verbose:
        print("Trainable weights: {}".format([n for n, p in model.named_parameters() if p.requires_grad]))
        print("Training...")
        if evaluate == 'acc':
            print("Evaluation on test set prior to training:", evaluate_using_acc(model, test_gen))
        elif evaluate == 'weighted_f1':
            print("Evaluation on test set prior to training:", evaluate_using_weighted_f1(model, test_gen, val_gen))

    fit_model(model, loss_op, optim_op, train_gen, val_gen, nb_epochs, checkpoint_path, patience)

    # Reload the best weights found to avoid overfitting.
    # Wait a bit to allow proper closing of the weights file.
    sleep(1)
    model.load_state_dict(torch.load(checkpoint_path))
    if verbose >= 2:
        print("Loaded weights from {}".format(checkpoint_path))

    if evaluate == 'acc':
        return evaluate_using_acc(model, test_gen)
    elif evaluate == 'weighted_f1':
        return evaluate_using_weighted_f1(model, test_gen, val_gen)


def evaluate_using_weighted_f1(model, test_gen, val_gen):
    """ Evaluation function using the weighted F1 score.

    # Arguments:
        model: Model to be evaluated.
        test_gen: Testing data iterator (DataLoader).
        val_gen: Validation data iterator (DataLoader), used to pick the
            decision threshold.

    # Returns:
        Weighted F1 score of the given model.
    """
    # Pick the threshold on the validation set, then evaluate on the test set.
    # sklearn expects average='weighted' for the weighted F1 score.
    f1_test, _ = find_f1_threshold(model, val_gen, test_gen, average='weighted')
    return f1_test


def evaluate_using_acc(model, test_gen):
    """ Evaluation function using accuracy.

    # Arguments:
        model: Model to be evaluated.
        test_gen: Testing data iterator (DataLoader)

    # Returns:
        Accuracy of the given model.
    """
    model.eval()
    accs = []
    for i, data in enumerate(test_gen):
        x, y = data
        outs = model(x)
        if model.nb_classes > 2:
            pred = torch.max(outs, 1)[1]
            acc = accuracy_score(y.squeeze().numpy(), pred.squeeze().numpy())
        else:
            pred = (outs >= 0).long()
            acc = (pred == y).double().sum() / len(pred)
        accs.append(acc)
    return np.mean(accs)


def chain_thaw(model, train_gen, val_gen, test_gen, nb_epochs, checkpoint_path, loss_op,
               patience=5, initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, evaluate='acc', verbose=1):
    """ Finetunes the given model using chain-thaw and evaluates it with the
        chosen metric.

    # Arguments:
        model: Model to be finetuned.
        train_gen: Training data iterator (DataLoader).
        val_gen: Validation data iterator (DataLoader).
        test_gen: Testing data iterator (DataLoader).
        nb_epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        loss_op: Loss function to be used during training.
        patience: Patience for early stopping.
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the output_layer layer)
        next_lr: Learning rate for every subsequent step.
        embed_l2: L2 regularization for the embedding layer.
        evaluate: Evaluation method to use. Can be 'acc' or 'weighted_f1'.
        verbose: Verbosity flag.

    # Returns:
        Score of the finetuned model on the test set, using the chosen
        evaluation method.
    """
    if verbose:
        print('Training...')

    # Train using chain-thaw
    train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
                        initial_lr, next_lr, embed_l2, verbose)

    if evaluate == 'acc':
        return evaluate_using_acc(model, test_gen)
    elif evaluate == 'weighted_f1':
        return evaluate_using_weighted_f1(model, test_gen, val_gen)


def train_by_chain_thaw(model, train_gen, val_gen, loss_op, patience, nb_epochs, checkpoint_path,
                        initial_lr=0.001, next_lr=0.0001, embed_l2=1E-6, verbose=1):
    """ Finetunes model using the chain-thaw method.

    This is done as follows:
    1) Freeze every layer except the last (output_layer) layer and train it.
    2) Freeze every layer except the first layer and train it.
    3) Freeze every layer except the second etc., until the second last layer.
    4) Unfreeze all layers and train entire model.

    # Arguments:
        model: Model to be trained.
        train_gen: Training data iterator (DataLoader).
        val_gen: Validation data iterator (DataLoader).
        loss_op: Loss function to be used.
        patience: Patience for early stopping.
        nb_epochs: Number of epochs.
        checkpoint_path: Where weight checkpoints should be saved.
        initial_lr: Initial learning rate. Will only be used for the first
            training step (i.e. the output_layer layer)
        next_lr: Learning rate for every subsequent step.
        embed_l2: L2 regularization for the embedding layer.
        verbose: Verbosity flag.
    """
    # Get the trainable layers (i.e. child modules with parameters)
    layers = [m for m in model.children() if len([id(p) for p in m.parameters()]) != 0]

    # Bring the output layer (the last child module) to the front so it is
    # trained first
    layers.insert(0, layers.pop(len(layers) - 1))

    # Add None to the end to signify finetuning all layers at once
    layers.append(None)

    lr = None
    # Finetune each layer one by one, then finetune all of them together at
    # the end
    for layer in layers:
        if lr is None:
            lr = initial_lr
        elif lr == initial_lr:
            lr = next_lr

        # Freeze all layers except the current one
        for _layer in layers:
            if _layer is not None:
                trainable = _layer == layer or layer is None
                change_trainable(_layer, trainable=trainable, verbose=False)

        # Verify that only the selected layer is trainable
        for _layer in model.children():
            assert all(p.requires_grad == (_layer == layer) for p in _layer.parameters()) or layer is None

        if verbose:
            if layer is None:
                print('Finetuning all layers')
            else:
                print('Finetuning {}'.format(layer))

        special_params = [id(p) for p in model.embed.parameters()]
        base_params = [p for p in model.parameters() if id(p) not in special_params and p.requires_grad]
        embed_parameters = [p for p in model.parameters() if id(p) in special_params and p.requires_grad]
        adam = optim.Adam([
            {'params': base_params},
            {'params': embed_parameters, 'weight_decay': embed_l2},
        ], lr=lr)

        fit_model(model, loss_op, adam, train_gen, val_gen, nb_epochs,
                  checkpoint_path, patience)

        # Reload the best weights found to avoid overfitting.
        # Wait a bit to allow proper closing of the weights file.
        sleep(1)
        model.load_state_dict(torch.load(checkpoint_path))
        if verbose >= 2:
            print("Loaded weights from {}".format(checkpoint_path))


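# Toy illustration of the chain-thaw visiting order, using a hypothetical tiny
# model. Like the torchMoji model, the output layer is assumed to be the last
# registered child module; the loop above then trains it first, each remaining
# layer in order, and finally the whole model (the None entry).
def _example_chain_thaw_order():
    class TinyModel(nn.Module):
        def __init__(self):
            super(TinyModel, self).__init__()
            self.embed = nn.Embedding(10, 4)
            self.hidden = nn.Linear(4, 4)
            self.output_layer = nn.Linear(4, 2)

    model = TinyModel()
    modules = [(n, m) for n, m in model.named_children()
               if len([id(p) for p in m.parameters()]) != 0]
    modules.insert(0, modules.pop(len(modules) - 1))
    print([name for name, _ in modules] + ['all layers'])
    # -> ['output_layer', 'embed', 'hidden', 'all layers']

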
def calc_loss(loss_op, pred, yv):
    if type(loss_op) is nn.CrossEntropyLoss:
        return loss_op(pred.squeeze(), yv.squeeze())
    else:
        return loss_op(pred.squeeze(), yv.squeeze().float())


def fit_model(model, loss_op, optim_op, train_gen, val_gen, epochs,
              checkpoint_path, patience):
    """ Analog to the Keras fit_generator function.

    # Arguments:
        model: Model to be finetuned.
        loss_op: loss operation (BCEWithLogitsLoss or CrossEntropy for e.g.)
        optim_op: optimization operation (Adam e.g.)
        train_gen: Training data iterator (DataLoader)
        val_gen: Validation data iterator (DataLoader)
        epochs: Number of epochs.
        checkpoint_path: Filepath where weights will be checkpointed to
            during training. This file will be rewritten by the function.
        patience: Patience for early stopping. Training stops after this many
            epochs without improvement of the validation loss.
    """
    # Save the initial weights so the checkpoint file always exists
    torch.save(model.state_dict(), checkpoint_path)

    model.eval()
    best_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
    print("original val loss", best_loss)

    epoch_without_impr = 0
    for epoch in range(epochs):
        for i, data in enumerate(train_gen):
            X_train, y_train = data
            X_train = Variable(X_train, requires_grad=False)
            y_train = Variable(y_train, requires_grad=False)
            model.train()
            optim_op.zero_grad()
            output = model(X_train)
            loss = calc_loss(loss_op, output, y_train)
            loss.backward()
            clip_grad_norm(model.parameters(), 1)
            optim_op.step()

            acc = evaluate_using_acc(model, [(X_train.data, y_train.data)])
            print("== Epoch", epoch, "step", i, "train loss", loss.data.cpu().numpy()[0], "train acc", acc)

        model.eval()
        acc = evaluate_using_acc(model, val_gen)
        print("val acc", acc)

        val_loss = np.mean([calc_loss(loss_op, model(Variable(xv)), Variable(yv)).data.cpu().numpy()[0] for xv, yv in val_gen])
        print("val loss", val_loss)
        if best_loss is not None and val_loss >= best_loss:
            epoch_without_impr += 1
            print('No improvement over previous best loss: ', best_loss)

        # Save the best model and reset the patience counter on improvement
        if best_loss is None or val_loss < best_loss:
            best_loss = val_loss
            epoch_without_impr = 0
            torch.save(model.state_dict(), checkpoint_path)
            print('Saving model at', checkpoint_path)

        # Early stopping
        if epoch_without_impr >= patience:
            break


def get_data_loader(X_in, y_in, batch_size, extended_batch_sampler=True, epoch_size=25000, upsample=False, seed=42):
    """ Returns a dataloader that enables larger epochs on small datasets and
        has upsampling functionality.

    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.
        batch_size: Batch size.
        extended_batch_sampler: Whether to use the extended batch sampler,
            which resamples the dataset to epoch_size samples per epoch.
        epoch_size: Number of samples in an epoch.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems.
        seed: Random number generator seed.

    # Returns:
        DataLoader.
    """
    dataset = DeepMojiDataset(X_in, y_in)

    if extended_batch_sampler:
        batch_sampler = DeepMojiBatchSampler(y_in, batch_size, epoch_size=epoch_size, upsample=upsample, seed=seed)
    else:
        batch_sampler = BatchSampler(SequentialSampler(y_in), batch_size, drop_last=False)

    return DataLoader(dataset, batch_sampler=batch_sampler, num_workers=0)


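# Usage sketch with toy numpy data (shapes and values are hypothetical). The
# extended batch sampler draws epoch_size indices with replacement, so a small
# dataset can still yield a full-length epoch.
def _example_get_data_loader():
    X = np.random.randint(0, 100, size=(40, 10))   # 40 sequences of 10 token ids
    y = np.random.randint(0, 2, size=(40,))
    loader = get_data_loader(X, y, batch_size=8,
                             extended_batch_sampler=True, epoch_size=32)
    for X_batch, y_batch in loader:
        print(X_batch.size(), y_batch.size())       # -> (8, 10) and (8,)

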
class DeepMojiDataset(Dataset):
    """ A simple Dataset class.

    # Arguments:
        X_in: Inputs of the given dataset.
        y_in: Outputs of the given dataset.

    # __getitem__ output:
        (torch.LongTensor, torch.LongTensor)
    """
    def __init__(self, X_in, y_in):
        # Convert to long tensors if necessary
        if not isinstance(X_in, torch.LongTensor):
            X_in = torch.from_numpy(X_in.astype('int64')).long()
        if not isinstance(y_in, torch.LongTensor):
            y_in = torch.from_numpy(y_in.astype('int64')).long()

        self.X_in = torch.split(X_in, 1, dim=0)
        self.y_in = torch.split(y_in, 1, dim=0)

    def __len__(self):
        return len(self.X_in)

    def __getitem__(self, idx):
        return self.X_in[idx].squeeze(), self.y_in[idx].squeeze()


class DeepMojiBatchSampler(object):
    """A Batch sampler that enables larger epochs on small datasets and
       has upsampling functionality.

    # Arguments:
        y_in: Labels of the dataset.
        batch_size: Batch size.
        epoch_size: Number of samples in an epoch.
        upsample: Whether upsampling should be done. This flag should only be
            set on binary class problems.
        seed: Random number generator seed.

    # __iter__ output:
        iterator of lists (batches) of indices in the dataset
    """

    def __init__(self, y_in, batch_size, epoch_size, upsample, seed):
        self.batch_size = batch_size
        self.epoch_size = epoch_size
        self.upsample = upsample

        np.random.seed(seed)

        if upsample:
            # Should only be used on binary class problems
            assert len(y_in.shape) == 1
            neg = np.where(y_in.numpy() == 0)[0]
            pos = np.where(y_in.numpy() == 1)[0]
            assert epoch_size % 2 == 0
            samples_pr_class = int(epoch_size / 2)
        else:
            ind = range(len(y_in))

        if not upsample:
            # Randomly sample observations, keeping the original distribution
            self.sample_ind = np.random.choice(ind, epoch_size, replace=True)
        else:
            # Randomly sample observations from each class separately
            sample_neg = np.random.choice(neg, samples_pr_class, replace=True)
            sample_pos = np.random.choice(pos, samples_pr_class, replace=True)
            concat_ind = np.concatenate((sample_neg, sample_pos), axis=0)

            # Shuffle to avoid the labels being in a specific order
            # (all negative then all positive)
            p = np.random.permutation(len(concat_ind))
            self.sample_ind = concat_ind[p]

            # Sanity check: the resampled labels should be roughly balanced
            label_dist = np.mean(y_in.numpy()[self.sample_ind])
            assert(label_dist > 0.45)
            assert(label_dist < 0.55)

    def __iter__(self):
        # Yield consecutive slices of the presampled indices
        for i in range(int(self.epoch_size/self.batch_size)):
            start = i * self.batch_size
            end = min(start + self.batch_size, self.epoch_size)
            yield self.sample_ind[start:end]

    def __len__(self):
        # Number of batches per epoch (rounded up)
        return (self.epoch_size + self.batch_size - 1) // self.batch_size
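

# Sketch of the upsampling behaviour in isolation (toy labels, hypothetical
# sizes): a heavily imbalanced label vector is resampled into an epoch of
# epoch_size indices with a roughly 50/50 class balance.
def _example_batch_sampler_upsampling():
    y = torch.LongTensor([0] * 30 + [1] * 10)
    sampler = DeepMojiBatchSampler(y, batch_size=8, epoch_size=32,
                                   upsample=True, seed=42)
    print('batches per epoch:', len(sampler))                            # -> 4
    print('positive fraction:', np.mean(y.numpy()[sampler.sample_ind]))  # -> ~0.5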