from collections.abc import Callable

import torch
from torch import Tensor
| |
|
| |
|
def pad_sequences(sequences: list[Tensor], padding_value: int) -> Tensor:
    """
    Right-pad a list of 1d tensors and stack them into one 2d tensor.

    Row ``i`` holds ``sequences[i]`` followed by ``padding_value`` up to the length of the
    longest sequence, so the result has shape ``(len(sequences), max_len)``.
    """
    padded = torch.nn.utils.rnn.pad_sequence(
        sequences,
        batch_first=True,  # batch dimension first -> (num_sequences, max_len)
        padding_value=padding_value,
    )
    return padded
| |
|
| |
|
| | def _build_condition_mask(sentences: list[list[str]], condition_fn: callable, device) -> Tensor: |
| | masks = [ |
| | torch.tensor([condition_fn(word) for word in sentence], dtype=bool, device=device) |
| | for sentence in sentences |
| | ] |
| | return pad_sequences(masks, padding_value=False) |
| |
|
def build_padding_mask(sentences: list[list[str]], device) -> Tensor:
    """
    Build a right-padded bool mask marking real word positions.

    Every position that holds a word is True; positions added as padding (past the end of
    a shorter sentence) are False.
    """

    def _is_word(_word: str) -> bool:
        # Any present token counts, so only the padding ends up False.
        return True

    return _build_condition_mask(sentences, _is_word, device)
| |
|
def build_null_mask(sentences: list[list[str]], device) -> Tensor:
    """
    Build a right-padded bool mask that is True at non-null words.

    A position is False when it holds the ``"#NULL"`` token or when it is padding.
    """

    def _is_not_null(token: str) -> bool:
        return token != "#NULL"

    return _build_condition_mask(sentences, condition_fn=_is_not_null, device=device)
| |
|
| |
|
def pairwise_mask(masks1d: Tensor) -> Tensor:
    """
    Compute the per-batch outer product of a mask with itself.

    For a ``(batch, n)`` input, returns a ``(batch, n, n)`` tensor with
    ``masks2d[:, i, j] = masks1d[:, i] & masks1d[:, j]``.
    """
    along_cols = masks1d.unsqueeze(1)  # (batch, 1, n): varies with j
    along_rows = masks1d.unsqueeze(2)  # (batch, n, 1): varies with i
    return along_cols & along_rows
| |
|
| |
|
| | |
def replace_masked_values(tensor: Tensor, mask: Tensor, replace_with: float) -> Tensor:
    """
    Overwrite, in place, every element of ``tensor`` where ``mask`` is False.

    ``mask`` marks the values to KEEP: positions where it is True are left untouched, and
    all other positions are set to ``replace_with``. ``mask`` must have the same number of
    dimensions as ``tensor`` and be a bool tensor broadcastable to it, as required by
    ``Tensor.masked_fill_``.

    Returns:
        The same ``tensor`` object after in-place modification (returned for convenience).

    Raises:
        AssertionError: if ``tensor`` and ``mask`` differ in number of dimensions.
    """
    # The message must be an f-string, otherwise the braces are printed literally.
    assert tensor.dim() == mask.dim(), (
        f"tensor.dim() of {tensor.dim()} != mask.dim() of {mask.dim()}"
    )
    tensor.masked_fill_(~mask, replace_with)
    return tensor
| |
|
| |
|
def prepend_cls(sentences: list[list[str]]) -> list[list[str]]:
    """
    Return new sentence lists, each starting with a ``"[CLS]"`` token.

    The input sentences are not modified.
    """
    with_cls = []
    for tokens in sentences:
        with_cls.append(["[CLS]"] + list(tokens))
    return with_cls
| |
|
def remove_nulls(sentences: list[list[str]]) -> list[list[str]]:
    """
    Return new sentence lists with every ``"#NULL"`` token dropped.

    The input sentences are not modified.
    """

    def _without_nulls(tokens: list[str]) -> list[str]:
        return list(filter(lambda token: token != "#NULL", tokens))

    return [_without_nulls(tokens) for tokens in sentences]
| |
|
| | def add_nulls(sentences: list[list[str]], counting_mask) -> list[list[str]]: |
| | """ |
| | Return a copy of sentences with nulls restored according to counting masks. |
| | """ |
| | sentences_with_nulls = [] |
| | for sentence, counting_mask in zip(sentences, counting_mask, strict=True): |
| | sentence_with_nulls = [] |
| | assert 0 < len(counting_mask) |
| | |
| | sentence_with_nulls.extend(["#NULL"] * counting_mask[0]) |
| | for word, n_nulls_to_insert in zip(sentence, counting_mask[1:], strict=True): |
| | sentence_with_nulls.append(word) |
| | sentence_with_nulls.extend(["#NULL"] * n_nulls_to_insert) |
| | sentences_with_nulls.append(sentence_with_nulls) |
| | return sentences_with_nulls |