File size: 2,744 Bytes
4322ff0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
from torch import Tensor


def pad_sequences(sequences: list[Tensor], padding_value: int) -> Tensor:
    """
    Combine variable-length tensors into one batch-first tensor.

    Shorter sequences are filled with `padding_value` at the end (right side) so that every
    row has the length of the longest sequence. The work is delegated to
    `torch.nn.utils.rnn.pad_sequence` with `batch_first=True`.
    """
    pad = torch.nn.utils.rnn.pad_sequence
    padded = pad(sequences, batch_first=True, padding_value=padding_value)
    return padded


def _build_condition_mask(sentences: list[list[str]], condition_fn: callable, device) -> Tensor:
    """
    Evaluate `condition_fn` on every word and collect the results into a padded boolean tensor.

    One boolean tensor is created per sentence on `device`; the per-sentence tensors are then
    right-padded with False via `pad_sequences`, so padded positions are always False.
    """
    per_sentence_masks = []
    for sentence in sentences:
        flags = [condition_fn(word) for word in sentence]
        per_sentence_masks.append(torch.tensor(flags, dtype=bool, device=device))
    return pad_sequences(per_sentence_masks, padding_value=False)

def build_padding_mask(sentences: list[list[str]], device) -> Tensor:
    """
    Build a boolean mask that is True at every word position and False at padded positions.
    """
    def always_true(word):
        # Every actual word counts; only the padding added afterwards ends up False.
        return True

    return _build_condition_mask(sentences, condition_fn=always_true, device=device)

def build_null_mask(sentences: list[list[str]], device) -> Tensor:
    """
    Build a boolean mask that is True for every word other than "#NULL".

    Positions holding the "#NULL" token, as well as padded positions, are False.
    """
    def is_not_null(word):
        return word != "#NULL"

    return _build_condition_mask(sentences, condition_fn=is_not_null, device=device)


def pairwise_mask(masks1d: Tensor) -> Tensor:
    """
    Expand a batch of 1d masks into a batch of 2d masks by AND-ing every pair of positions:
    masks2d[:, i, j] = masks1d[:, i] & masks1d[:, j].
    """
    # A new axis at dim 1 broadcasts the mask across rows, a new axis at dim 2 across columns.
    along_columns = masks1d.unsqueeze(1)
    along_rows = masks1d.unsqueeze(2)
    return along_columns & along_rows


# Credits: https://docs.allennlp.org/main/api/nn/util/#replace_masked_values
def replace_masked_values(tensor: Tensor, mask: Tensor, replace_with: float) -> None:
    """
    Replace all masked values in tensor with `replace_with`.

    The replacement happens in place: every position where `mask` is False is overwritten,
    positions where `mask` is True are left untouched. Nothing is returned.

    `tensor` and `mask` must have the same number of dimensions; otherwise an AssertionError
    reporting both dimensionalities is raised.
    """
    # The message must be an f-string, otherwise the placeholders are printed literally.
    assert tensor.dim() == mask.dim(), (
        f"tensor.dim() of {tensor.dim()} != mask.dim() of {mask.dim()}"
    )
    # `~mask` selects the positions to overwrite (those where the mask is False).
    tensor.masked_fill_(~mask, replace_with)


def prepend_cls(sentences: list[list[str]]) -> list[list[str]]:
    """
    Build new sentences that each start with the "[CLS]" token followed by the original words.

    The input sentences are not modified.
    """
    with_cls = []
    for sentence in sentences:
        with_cls.append(["[CLS]"] + list(sentence))
    return with_cls

def remove_nulls(sentences: list[list[str]]) -> list[list[str]]:
    """
    Build new sentences in which every "#NULL" token has been dropped.

    The input sentences are not modified; the order of the remaining words is preserved.
    """
    cleaned = []
    for sentence in sentences:
        kept_words = []
        for word in sentence:
            if word == "#NULL":
                continue
            kept_words.append(word)
        cleaned.append(kept_words)
    return cleaned

def add_nulls(sentences: list[list[str]], counting_mask) -> list[list[str]]:
    """
    Return a copy of sentences with nulls restored according to counting masks.
    """
    sentences_with_nulls = []
    for sentence, counting_mask in zip(sentences, counting_mask, strict=True):
        sentence_with_nulls = []
        assert 0 < len(counting_mask)
        # Account for leading (CLS) auxiliary token. 
        sentence_with_nulls.extend(["#NULL"] * counting_mask[0])
        for word, n_nulls_to_insert in zip(sentence, counting_mask[1:], strict=True):
            sentence_with_nulls.append(word)
            sentence_with_nulls.extend(["#NULL"] * n_nulls_to_insert)
        sentences_with_nulls.append(sentence_with_nulls)
    return sentences_with_nulls