""" Utility functions for data transformations. """ import logging import random import torch import stanza.models.common.seq2seq_constant as constant from stanza.models.common.doc import HEAD, ID, UPOS logger = logging.getLogger('stanza') def map_to_ids(tokens, vocab): ids = [vocab[t] if t in vocab else constant.UNK_ID for t in tokens] return ids def get_long_tensor(tokens_list, batch_size, pad_id=constant.PAD_ID): """ Convert (list of )+ tokens to a padded LongTensor. """ sizes = [] x = tokens_list while isinstance(x[0], list): sizes.append(max(len(y) for y in x)) x = [z for y in x for z in y] # TODO: pass in a device parameter and put it directly on the relevant device? # that might be faster than creating it and then moving it tokens = torch.LongTensor(batch_size, *sizes).fill_(pad_id) for i, s in enumerate(tokens_list): tokens[i, :len(s)] = torch.LongTensor(s) return tokens def get_float_tensor(features_list, batch_size): if features_list is None or features_list[0] is None: return None seq_len = max(len(x) for x in features_list) feature_len = len(features_list[0][0]) features = torch.FloatTensor(batch_size, seq_len, feature_len).zero_() for i,f in enumerate(features_list): features[i,:len(f),:] = torch.FloatTensor(f) return features def sort_all(batch, lens): """ Sort all fields by descending order of lens, and return the original indices. """ if batch == [[]]: return [[]], [] unsorted_all = [lens] + [range(len(lens))] + list(batch) sorted_all = [list(t) for t in zip(*sorted(zip(*unsorted_all), reverse=True))] return sorted_all[2:], sorted_all[1] def get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate, desired_ratio=0.1, max_ratio=0.5): """ Returns X so that if you randomly select X * N sentences, you get 10% The ratio will be chosen in the assumption that the final dataset is of size N rather than N + X * N. should_augment_predicate: returns True if the sentence has some feature which we may want to change occasionally. for example, depparse sentences which end in punct can_augment_predicate: in the depparse sentences example, it is technically possible for the punct at the end to be the parent of some other word in the sentence. in that case, the sentence should not be chosen. should be at least as restrictive as should_augment_predicate """ n_data = len(train_data) n_should_augment = sum(should_augment_predicate(sentence) for sentence in train_data) n_can_augment = sum(can_augment_predicate(sentence) for sentence in train_data) n_error = sum(can_augment_predicate(sentence) and not should_augment_predicate(sentence) for sentence in train_data) if n_error > 0: raise AssertionError("can_augment_predicate allowed sentences not allowed by should_augment_predicate") if n_can_augment == 0: logger.warning("Found no sentences which matched can_augment_predicate {}".format(can_augment_predicate)) return 0.0 n_needed = n_data * desired_ratio - (n_data - n_should_augment) # if we want 10%, for example, and more than 10% already matches, we can skip if n_needed < 0: return 0.0 ratio = n_needed / n_can_augment if ratio > max_ratio: return max_ratio return ratio def should_augment_nopunct_predicate(sentence): last_word = sentence[-1] return last_word.get(UPOS, None) == 'PUNCT' def can_augment_nopunct_predicate(sentence): """ Check that the sentence ends with PUNCT and also doesn't have any words which depend on the last word """ last_word = sentence[-1] if last_word.get(UPOS, None) != 'PUNCT': return False # don't cut off MWT if len(last_word[ID]) > 1: return False if any(len(word[ID]) == 1 and word[HEAD] == last_word[ID][0] for word in sentence): return False return True def augment_punct(train_data, augment_ratio, should_augment_predicate=should_augment_nopunct_predicate, can_augment_predicate=can_augment_nopunct_predicate, keep_original_sentences=True): """ Adds extra training data to compensate for some models having all sentences end with PUNCT Some of the models (for example, UD_Hebrew-HTB) have the flaw that all of the training sentences end with PUNCT. The model therefore learns to finish every sentence with punctuation, even if it is given a sentence with non-punct at the end. One simple way to fix this is to train on some fraction of training data with punct. Params: train_data: list of list of dicts, eg a conll doc augment_ratio: the fraction to augment. if None, a best guess is made to get to 10% should_augment_predicate: a function which returns T/F if a sentence already ends with not PUNCT can_augment_predicate: a function which returns T/F if it makes sense to remove the last PUNCT TODO: do this dynamically, as part of the DataLoader or elsewhere? One complication is the data comes back from the DataLoader as tensors & indices, so it is much more complicated to manipulate """ if len(train_data) == 0: return [] if augment_ratio is None: augment_ratio = get_augment_ratio(train_data, should_augment_predicate, can_augment_predicate) if augment_ratio <= 0: if keep_original_sentences: return list(train_data) else: return [] new_data = [] for sentence in train_data: if can_augment_predicate(sentence): if random.random() < augment_ratio and len(sentence) > 1: # todo: could deep copy the words # or not deep copy any of this new_sentence = list(sentence[:-1]) new_data.append(new_sentence) elif keep_original_sentences: new_data.append(new_sentence) return new_data