# Author: PatrickHaller
# Uploaded file: ngme.py
# Revision: af87020
import math
from typing import Optional, List
from functools import lru_cache
import unittest
import torch
import torch.nn.functional as F
# Hand-picked n-gram weight tables. Entry k holds k + 1 increasing weights
# that sum to 1.
# NOTE(review): not referenced by the code visible in this module; n_dist()
# computes its weights from `strats` instead. Confirm before relying on it.
n_dists = {
    0: [1],
    1: [0.4, 0.6],
    2: [0.2, 0.3, 0.5],
    3: [0.1, 0.2, 0.3, 0.4],
    4: [0.1, 0.15, 0.2, 0.25, 0.3],
}
# Weighting strategies: map an n-gram order x (1, 2, ...) to an unnormalised
# weight. n_dist() divides the results by their total.
# NOTE(review): "exp" is quadratic (x ** 2), not exponential; the key is kept
# as-is because callers select strategies by this name.
strats = {
    "linear": lambda x: x,
    "log": lambda x: math.log(x + 1),
    "exp": lambda x: x**2,
}
@lru_cache(maxsize=5)
def soft_dist(n):
    """Return a uniform distribution: a list of ``n`` copies of ``1 / n``.

    Results are memoised, so repeated calls return the same list object.
    Callers must not mutate it.
    """
    # Divide before building the list so that n == 0 still raises
    # ZeroDivisionError.
    share = 1 / n
    return [share for _ in range(n)]
@lru_cache(maxsize=5)
def n_dist(n: int, strategy: str) -> list[float]:
    """Return ``n`` normalised weights for the n-gram orders ``1..n``.

    Order ``k`` gets the raw weight ``strats[strategy](k)``. The raw weights
    are then divided by their total, so a non-empty result sums to 1.

    Args:
        n: number of n-gram orders to weight. ``n <= 0`` yields ``[]``.
        strategy: key into ``strats`` (e.g. "linear", "log", "exp").

    Returns:
        A list of ``n`` floats. Results are memoised with ``lru_cache`` and
        the same list object is returned on repeated calls, so callers must
        not mutate it.

    Raises:
        KeyError: if ``strategy`` is not a key of ``strats``.
    """
    weigh = strats[strategy]
    raw = [weigh(k) for k in range(1, n + 1)]
    # Compute the normaliser once. The previous code re-summed the whole list
    # for every element (O(n^2)).
    total = sum(raw)
    return [x / total for x in raw]
def soft_n_hot(
    input,
    num_classes: int,
    strategy: Optional[str],
):
    """Encode stacked index tensors as a single soft multi-hot tensor.

    Every index held by ``input[i]`` is marked with the weight
    ``soft_labels[i]``. The labels come from
    ``n_dist(input.size(0), strategy)`` when ``strategy`` is truthy; otherwise
    every label is ``1``.

    Args:
        input: integer index tensor whose first dimension enumerates the
            sequences to combine (presumably one per n-gram order -- confirm
            with callers). The remaining dimensions are kept.
        num_classes: size of the appended last (class) dimension.
        strategy: weighting strategy name passed to ``n_dist``, or a falsy
            value for uniform weight 1.

    Returns:
        A float tensor of shape ``input.shape[1:] + (num_classes,)`` on
        ``input.device``.

    Note:
        ``scatter_`` overwrites rather than accumulates. When two rows of
        ``input`` hit the same class at the same position, the later row's
        weight wins.
    """
    shape = list(input.size())[1:]
    shape.append(num_classes)
    # Allocate directly on the target device instead of creating the tensor
    # on the CPU and copying it over.
    ret = torch.zeros(shape, device=input.device)
    if strategy:
        soft_labels = n_dist(input.size(0), strategy)
    else:
        soft_labels = [1] * input.size(0)
    for i, t in enumerate(input):
        ret.scatter_(-1, t.unsqueeze(-1), soft_labels[i])
    return ret
def n_hot(
    t,
    num_clases,
    ngram_sequences: Optional[torch.Tensor] = None,
    unk_idx: Optional[int] = None,
):
    """Build an n-hot (multi-hot) encoding of token indices.

    Three input layouts are handled:

    * ``ngram_sequences`` given: ``t`` holds the unigram indices. Each element
      of ``ngram_sequences`` (any iterable of index tensors shaped like ``t``,
      e.g. a list, or a tensor with a leading n-gram dimension) is OR-ed in.
      The result has shape ``t.shape + (num_clases,)``.
    * ``t`` is 2-D and no n-grams are given: plain one-hot encoding of ``t``.
    * otherwise: the first dimension of ``t`` enumerates the n-gram sequences,
      and all of them are OR-ed together. The result has shape
      ``t.shape[1:] + (num_clases,)``.

    Args:
        t: integer index tensor.
        num_clases: number of classes (size of the appended last dimension).
            The misspelt name is kept for keyword-argument compatibility.
        ngram_sequences: optional extra index tensors to mark as hot. They are
            never modified.
        unk_idx: if given, n-gram positions equal to ``unk_idx`` are replaced
            by the unigram index at the same position. Unknown n-grams
            therefore add nothing beyond the unigram itself.

    Returns:
        A float tensor of 0s and 1s on ``t.device``.
    """
    shape = list(t.size())
    if ngram_sequences is not None:
        shape.append(num_clases)
        ret = torch.zeros(shape, device=t.device)
        ret.scatter_(-1, t.unsqueeze(-1), 1)
        for seq in ngram_sequences:
            if unk_idx is not None:
                # Substitute out-of-place. The previous `seq[mask] = t[mask]`
                # silently rewrote the caller's ngram_sequences tensors.
                seq = torch.where(torch.eq(seq, unk_idx), t, seq)
            ret.scatter_(-1, seq.unsqueeze(-1), 1)
        return ret
    if len(shape) == 2:
        return F.one_hot(t, num_classes=num_clases).float()
    # Expect that the first dimension of `t` enumerates the n-gram sequences.
    ret = torch.zeros(shape[1:] + [num_clases], device=t.device)
    for seq in t:
        ret.scatter_(-1, seq.unsqueeze(-1), 1)
    return ret
class NGramsEmbedding(torch.nn.Embedding):
    """N-hot encoder: an embedding layer driven by n-hot vectors.

    Token indices (plus optional n-gram index sequences) are turned into an
    n-hot tensor by ``n_hot`` and multiplied with the embedding table. The
    output at each position is therefore the sum of the embedding rows of
    every hot index.

    NOTE(review): ``forward`` never consults ``padding_idx``, ``max_norm``,
    ``norm_type``, ``scale_grad_by_freq`` or ``sparse``. They are only handed
    to ``torch.nn.Embedding.__init__``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        max_norm: Optional[float] = None,
        norm_type: float = 2,
        scale_grad_by_freq: bool = False,
        sparse: bool = False,
        _weight: Optional[torch.Tensor] = None,
        device=None,
        dtype=None,
        unk_idx: Optional[int] = None
    ) -> None:
        base_options = dict(
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
            scale_grad_by_freq=scale_grad_by_freq,
            sparse=sparse,
            _weight=_weight,
            device=device,
            dtype=dtype,
        )
        super().__init__(num_embeddings, embedding_dim, **base_options)
        # The vocabulary size doubles as the width of the n-hot encoding.
        self.num_classes = num_embeddings
        # Index that n_hot treats as "unknown" in n-gram sequences; None disables it.
        self.unk_idx = unk_idx

    def forward(self, input: torch.Tensor, ngram_sequences: Optional[torch.Tensor] = None):
        """Embed ``input`` together with the optional ``ngram_sequences``."""
        encoded = n_hot(input, self.num_classes, ngram_sequences, self.unk_idx)
        return self._forward(encoded)

    def _forward(self, n_hot: torch.Tensor) -> torch.Tensor:
        """Project an n-hot tensor (last dim = ``num_embeddings``) onto the table."""
        # F.linear(x, A) computes x @ A.T, so passing weight.T yields x @ weight.
        return F.linear(n_hot, self.weight.t())
def collect_n_gram_sequences(**kwargs) -> List[torch.Tensor]:
    """Gather ``gram_2_sequence``, ``gram_3_sequence``, ... from keyword args.

    Collection starts at ``gram_2_sequence`` and stops at the first order
    whose keyword is missing or ``None``. Unrelated keyword arguments are
    ignored.

    Returns:
        The collected sequences, in increasing n-gram order.
    """
    sequences = []
    n = 2
    # Use .get() so that a missing key ends the scan. Previously the loop
    # bound came from len(kwargs), so any unrelated kwarg caused a KeyError.
    while (s := kwargs.get(f"gram_{n}_sequence")) is not None:
        sequences.append(s)
        n += 1
    return sequences
def shift_with_pad(target_tensor, n, from_tensor):
    """Drop the first ``n`` columns of ``target_tensor`` and back-fill the tail.

    With ``L = target_tensor.size(1)`` and ``n >= 1``, the result has
    ``L - 1`` columns: columns ``n:`` of ``target_tensor`` followed by the
    ``n - 1`` columns of ``from_tensor`` at indices ``[L - n, L - 1)``.
    ``from_tensor`` is presumably the unigram tensor with ``L - 1`` columns;
    confirm with callers.
    """
    last = target_tensor.size(1) - 1
    first = last - (n - 1)
    # Column indices of from_tensor that fill the gap left by the shift.
    fill_idx = torch.arange(first, last).to(target_tensor.device)
    tail = from_tensor.index_select(1, fill_idx)
    return torch.concat((target_tensor[:, n:], tail), dim=1)
class TestNGME(unittest.TestCase):
    """Checks for the encoding paths of ``n_hot``.

    Uses ``assertTrue`` instead of bare ``assert`` so the checks still run
    under ``python -O``.
    """

    def test_one_hot(self):
        # A plain 2-D index tensor reduces to a one-hot encoding.
        t = torch.tensor([[0, 1, 2]])
        ret = n_hot(t, 3)
        expected = torch.eye(3)
        # ret has shape (1, 3, 3); expected broadcasts over the leading dim.
        self.assertTrue(torch.all(torch.eq(ret, expected)))

    def test_multi_hot1(self):
        t = torch.tensor([[0, 1, 2]])
        # [ngram, batch, seq]: n_hot iterates over the first dimension.
        two_grams = torch.tensor([[[0, 1, 2]]])
        ret = n_hot(t, 3, two_grams)
        # The bigrams repeat the unigram indices, so nothing new is set.
        expected = torch.eye(3)
        self.assertTrue(torch.all(torch.eq(ret, expected)))

    def test_multi_hot2(self):
        t = torch.tensor([[0, 1, 2]])
        # [ngram, batch, seq] with two n-gram orders.
        two_three_grams = torch.tensor([[[1, 2, 0]], [[2, 0, 1]]])
        ret = n_hot(t, 3, two_three_grams)
        # Together the three sequences cover every class at every position.
        expected = torch.ones(3, 3)
        self.assertTrue(torch.all(torch.eq(ret, expected)))
class TestShifting(unittest.TestCase):
    """Checks for ``shift_with_pad``.

    These tests previously called an undefined ``_shift_with_pad`` and failed
    with NameError.
    """

    def test_two_gram(self):
        two_gram_batch = torch.tensor([[0, 1, 2, 3, 4]])
        from_tensor = torch.tensor([[-4, -3, -2, -1]])
        shifted = shift_with_pad(two_gram_batch, 2, from_tensor)
        # Columns 2: of the batch, then the last n - 1 = 1 column of from_tensor.
        expected = torch.tensor([[2, 3, 4, -1]])
        self.assertTrue(torch.equal(shifted, expected))

    def test_three_gram(self):
        three_gram_batch = torch.tensor([[0, 1, 2, 3, 4, 5, 6]])
        from_tensor = torch.tensor([[-6, -5, -4, -3, -2, -1]])
        shifted = shift_with_pad(three_gram_batch, 3, from_tensor)
        expected = torch.tensor([[3, 4, 5, 6, -2, -1]])
        self.assertTrue(torch.equal(shifted, expected))

    def test_three_gram_2(self):
        # Same as above, but with a batch of two rows.
        three_gram_batch = torch.tensor([[0, 1, 2, 3, 4, 5, 6], [0, 1, 2, 3, 4, 5, 6]])
        from_tensor = torch.tensor([[-6, -5, -4, -3, -2, -1], [-6, -5, -4, -3, -2, -1]])
        shifted = shift_with_pad(three_gram_batch, 3, from_tensor)
        expected = torch.tensor([[3, 4, 5, 6, -2, -1], [3, 4, 5, 6, -2, -1]])
        self.assertTrue(torch.equal(shifted, expected))
class TestNGramEmbeddings(unittest.TestCase):
    """End-to-end checks of ``NGramsEmbedding`` with an identity table.

    With ``weight = eye(10)``, each output row equals its n-hot vector, which
    makes sums of embeddings easy to compare. Uses ``assertTrue`` so the
    checks still run under ``python -O``.
    """

    @staticmethod
    def _identity_embedding(unk_idx=None):
        # Not collected by unittest (no "test" prefix): a 10x10 layer with one-hot rows.
        emb = NGramsEmbedding(10, 10, unk_idx=unk_idx)
        emb.weight = torch.nn.Parameter(torch.eye(10))
        return emb

    def test(self):
        emb = self._identity_embedding()
        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb2 = emb(torch.tensor([[4, 5, 6]]))
        emb3 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[4, 5, 6]])])
        # Disjoint uni- and bigram indices: the n-hot embedding is the sum.
        self.assertTrue(torch.all(torch.eq(torch.add(emb1, emb2), emb3)))

    def test_2(self):
        emb = self._identity_embedding()
        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb3 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[1, 2, 3]])])
        # Repeating an index is idempotent (multi-hot flags, not counts).
        self.assertTrue(torch.all(torch.eq(emb1, emb3)))

    def test_3_gram(self):
        emb = self._identity_embedding()
        emb1 = emb(torch.tensor([[1, 2, 3]]))
        emb2 = emb(torch.tensor([[4, 5, 6]]))
        emb3 = emb(torch.tensor([[7, 8, 9]]))
        emb4 = emb(torch.tensor([[1, 2, 3]]), [torch.tensor([[4, 5, 6]]), torch.tensor([[7, 8, 9]])])
        self.assertTrue(torch.all(torch.eq(torch.add(torch.add(emb1, emb2), emb3), emb4)))

    def test_ignore_indx(self):
        emb = self._identity_embedding(unk_idx=0)
        unigram = torch.tensor([[1, 2, 3]])
        bigram = torch.tensor([[0, 0, 0]])
        # An all-unknown bigram contributes nothing beyond the unigram.
        emb1 = emb(unigram, [bigram])
        emb2 = emb(unigram)
        self.assertTrue(torch.all(torch.eq(emb1, emb2)))

    def test_ignore_indx_2(self):
        emb = self._identity_embedding(unk_idx=0)
        # The unigram itself contains the unknown index here.
        unigram = torch.tensor([[0, 2, 3]])
        bigram = torch.tensor([[0, 0, 0]])
        emb1 = emb(unigram, [bigram])
        emb2 = emb(unigram)
        self.assertTrue(torch.all(torch.eq(emb1, emb2)))
# Run the unit tests above when this module is executed as a script.
if __name__ == "__main__":
    unittest.main()