from typing import Callable, List, NamedTuple, Tuple, Union

import torch
from nltk import Tree
from torch import Tensor, nn
from torch.distributions.utils import lazy_property
from torchrua import C, L, Z, segment_mean
from transformers.models.roberta.modeling_roberta import PreTrainedModel, RobertaModel

from tmp.configuration_parserker import ParserkerConfig

Frames = Union[List[Tensor], Tuple[Tensor, ...]]


def diag(tensor: Tensor, offset: int) -> Tensor:
    # view of the `offset`-th diagonal of the (t, t) chart dimensions
    return tensor.diagonal(offset=offset, dim1=1, dim2=2)


def diag_scatter(chart: Tensor, score: Tensor, offset: int) -> None:
    # write `score` onto the `offset`-th diagonal of the chart, in place
    chart.diagonal(offset=offset, dim1=1, dim2=2)[...] = score


def left(chart: Tensor, offset: int) -> Tensor:
    # strided view of chart[:, i, i:i + offset]: the scores of all left
    # sub-spans of each span of width `offset`
    b, t, _, *size = chart.size()
    c, n, m, *stride = chart.stride()
    return chart.as_strided(
        size=(b, t - offset, offset, *size),
        stride=(c, n + m, m, *stride),
    )


def right(chart: Tensor, offset: int) -> Tensor:
    # strided view of chart[:, i + 1:i + offset + 1, i + offset]: the scores
    # of all right sub-spans of each span of width `offset`
    b, t, _, *size = chart.size()
    c, n, m, *stride = chart.stride()
    return chart[:, 1:, offset:].as_strided(
        size=(b, t - offset, offset, *size),
        stride=(c, n + m, n, *stride),
    )


def to_hex(x: int, num_bits: int) -> str:
    # zero-padded upper-case hex string wide enough for `num_bits` bits
    return f'{x:0{(num_bits + 3) // 4}X}'


def bits_to_long(tensor: Tensor) -> Tensor:
    # pack the trailing bit dimension into a single integer, little-endian
    *_, num_bits = tensor.size()
    index = torch.arange(num_bits, dtype=torch.long, device=tensor.device)
    return (tensor << index).sum(dim=-1)


def long_to_bits(tensor: Tensor, num_bits: int) -> Tensor:
    # inverse of bits_to_long: unpack an integer into `num_bits` bits
    index = torch.arange(num_bits, dtype=torch.long, device=tensor.device)
    return (tensor[..., None] >> index) & 1


def max_values(tensor: Tensor, dim: int, keepdim: bool = False) -> Tensor:
    # torch.max without the indices, matching torch.logsumexp's signature;
    # renamed from `max` so it no longer shadows the builtin
    return torch.max(tensor, dim=dim, keepdim=keepdim).values


class Semiring(NamedTuple):
    zero: float
    one: float
    add: Callable
    mul: Callable
    sum: Callable
    prod: Callable


# log semiring: `sum` marginalizes, so the chart recursion yields log-partitions
Log = Semiring(
    zero=-float('inf'), one=0.,
    add=torch.logaddexp, mul=torch.add,
    sum=torch.logsumexp, prod=torch.sum,
)

# max semiring: `sum` maximizes, so the same recursion yields Viterbi scores
Max = Semiring(
    zero=-float('inf'), one=0.,
    add=torch.maximum, mul=torch.add,
    sum=max_values, prod=torch.sum,
)


def cumsum(tensor: Tensor) -> Tensor:
    # two-dimensional prefix sums over the upper triangle, so the total score
    # of any span can later be read off in O(1)
    b, t1, t2, k = tensor.size()
    assert t1 == t2, f'{t1} != {t2}'

    p1 = tensor.permute(0, 3, 1, 2).triu()
    c1 = p1.cumsum(dim=-1)
    c2 = c1.flip(dims=[-2]).cumsum(dim=-2).flip(dims=[-2])
    return c2.permute(0, 2, 3, 1)


def cky_partitions(logits: Tensor, token_sizes: Tensor, semiring: Semiring) -> Tuple[Tensor, List[Tensor]]:
    logits = cumsum(logits)
    # pair each accumulated score with a zero so the semiring can skip it
    logits = torch.stack([torch.zeros_like(logits), logits], dim=-1)

    b, t, _, k, _ = logits.size()
    chart = torch.full_like(logits[..., 0, 0], fill_value=semiring.zero, requires_grad=False)

    # width-0 spans: single tokens fill the main diagonal
    z = diag(logits, offset=0)[..., None].permute([0, 3, 4, 1, 2])
    frames = [z]
    z = semiring.sum(z, dim=-1)
    z = semiring.prod(z, dim=-1)
    diag_scatter(chart, z[..., 0], offset=0)

    # CKY inner loop: one diagonal per span width, combining the chart scores
    # of all left and right sub-spans
    index = torch.arange(t, dtype=chart.dtype, device=chart.device)
    for w in range(1, t):
        z = diag(logits, offset=w)[..., None].permute([0, 3, 4, 1, 2])
        z = z - left(logits, offset=w) - right(logits, offset=w)
        z = z / ((1 + index[:w]) * (w - index[:w]))[:, None, None]
        frames.append(z)

        z = semiring.sum(z, dim=-1)
        z = semiring.prod(z, dim=-1)

        xyz = semiring.mul(z, semiring.mul(left(chart, offset=w), right(chart, offset=w)))
        score = semiring.sum(xyz, dim=-1)
        diag_scatter(chart, score, offset=w)

    # read out the score of the full span [0, n - 1] for each sequence
    index = torch.arange(b, dtype=torch.long, device=chart.device)
    return chart[index, 0, token_sizes - 1], frames
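
# A minimal sketch, not part of the model: both semirings share the chart
# recursion above, with Log yielding log-partitions and Max yielding Viterbi
# scores, so logsumexp-over-trees must dominate max-over-trees. All shapes
# and sizes below are assumptions chosen for this demo only.
def _demo_cky_partitions() -> None:
    b, t, k = 2, 5, 4
    logits = torch.randn((b, t, t, k))
    token_sizes = torch.tensor([5, 3], dtype=torch.long)

    log_z, _ = cky_partitions(logits=logits, token_sizes=token_sizes, semiring=Log)
    max_z, _ = cky_partitions(logits=logits, token_sizes=token_sizes, semiring=Max)

    # the best tree's score can never exceed the log-partition over all trees
    assert torch.all(max_z <= log_z + 1e-4)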
class Distribution(object):
    # distribution over binary trees; the best tree is recovered with the
    # gradient trick: differentiating the Max-semiring partition w.r.t. the
    # per-width score frames marks the spans of the argmax tree
    def __init__(self, logits: Tensor, token_sizes: Tensor) -> None:
        super(Distribution, self).__init__()
        self.logits = logits
        self.token_sizes = token_sizes

    @lazy_property
    def log_partitions(self):
        partitions, frames = cky_partitions(
            logits=self.logits, token_sizes=self.token_sizes,
            semiring=Log,
        )
        return partitions, frames

    @lazy_property
    def max(self):
        partitions, frames = cky_partitions(
            logits=self.logits, token_sizes=self.token_sizes,
            semiring=Max,
        )
        return partitions, frames

    @lazy_property
    def marginals(self) -> Frames:
        # span marginals: gradients of the log-partition w.r.t. the frames
        partitions, frames = self.log_partitions
        return torch.autograd.grad(
            partitions, frames, torch.ones_like(partitions),
            create_graph=True, retain_graph=True,
            only_inputs=True, allow_unused=True,
        )

    @lazy_property
    def grads(self) -> Frames:
        # indicators of the best tree: gradients of the Viterbi score
        partitions, frames = self.max
        return torch.autograd.grad(
            partitions, frames, torch.ones_like(partitions),
            create_graph=False, retain_graph=False,
            only_inputs=True, allow_unused=True,
        )

    @staticmethod
    def gather(marginals: Frames, grads: Frames, spans: Tensor):
        # collect marginals, split indices, and span annotations for the
        # spans selected by `grads`
        b, _, _, k, _ = marginals[0].size()

        xs, ys, zs = [], [], []
        for w, (x, grad) in enumerate(zip(marginals, grads)):
            mask, y = grad.max(dim=-1, keepdim=True)
            mask = mask.sum(dim=-2, keepdim=True) > 0
            z = diag(spans, offset=w)[..., None, None, None]

            xs.append(torch.masked_select(x, mask))
            ys.append(torch.masked_select(y, mask))
            zs.append(torch.masked_select(z, mask))

        xs = torch.cat(xs, dim=0).view((-1, k, 2))
        ys = torch.cat(ys, dim=0).view((-1, k))
        zs = torch.cat(zs, dim=0)
        return xs, ys, zs

    @lazy_property
    def argmax(self) -> C:
        b, t, _, _, _ = self.grads[0].size()
        b = torch.arange(b, dtype=torch.long, device=self.grads[0].device)
        x = torch.arange(t, dtype=torch.long, device=self.grads[0].device)
        y = torch.arange(t, dtype=torch.long, device=self.grads[0].device)
        b, x, y = torch.broadcast_tensors(b[:, None, None], x[None, :, None], y[None, None, :])

        # walk the widths and pick out (batch, start, end, label) for every
        # span with a non-zero gradient, i.e. every span of the best tree
        data = []
        for w, grad in enumerate(self.grads):
            mask, z = grad.max(dim=-1, keepdim=False)
            mask = mask.sum(dim=-1, keepdim=False) > 0
            data.append(torch.stack([
                torch.masked_select(diag(b, offset=w)[..., None], mask),
                torch.masked_select(diag(x, offset=w)[..., None], mask),
                torch.masked_select(diag(y, offset=w)[..., None], mask),
                torch.masked_select(bits_to_long(z), mask),
            ], dim=-1))

        # group the spans by sequence; a binary tree over n tokens has 2n - 1 spans
        data = torch.cat(data, dim=0)
        b = torch.argsort(data[..., 0], dim=0, descending=False)
        return C(data=data[b, 1:], token_sizes=self.token_sizes * 2 - 1)
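
# A minimal sketch, not part of the model: Distribution wires cky_partitions
# into autograd, so marginals come back as one gradient frame per span width
# and argmax packs the best tree's spans into a torchrua `C` sequence. The
# shapes and sizes below are assumptions chosen for this demo only.
def _demo_distribution() -> None:
    b, t, k = 2, 5, 4
    logits = torch.randn((b, t, t, k), requires_grad=True)
    token_sizes = torch.tensor([5, 3], dtype=torch.long)

    dist = Distribution(logits=logits, token_sizes=token_sizes)

    log_z, frames = dist.log_partitions
    assert log_z.size() == (b,)

    # one frame (and one marginal tensor) per span width 0 .. t - 1
    assert len(dist.marginals) == len(frames) == t

    # 2n - 1 spans per sequence: one per token plus one per internal node
    spans = dist.argmax
    assert torch.equal(spans.token_sizes, token_sizes * 2 - 1)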
class HashLayer(nn.Module):
    # scores every token pair with num_bits independent low-rank bilinear
    # forms, one score per label bit
    def __init__(self, config: ParserkerConfig) -> None:
        super(HashLayer, self).__init__()
        self.num_bits = config.num_bits
        self.bit_size = (config.hidden_size + config.num_bits - 1) // config.num_bits
        self.scale = self.bit_size ** -0.5

        self.q_proj = nn.Linear(config.hidden_size, self.num_bits * self.bit_size, bias=True)
        self.k_proj = nn.Linear(config.hidden_size, self.num_bits * self.bit_size, bias=True)

    def forward(self, q: Tensor, k: Tensor) -> Tensor:
        q = self.q_proj(q).unflatten(dim=-1, sizes=(self.num_bits, 1, self.bit_size))
        k = self.k_proj(k).unflatten(dim=-1, sizes=(self.num_bits, self.bit_size, 1))
        # (b, t, t, num_bits) scaled dot-product scores for all token pairs
        return (q[:, :, None] @ k[:, None, :]).flatten(start_dim=-3).transpose(1, 2) * self.scale


class ParserkerModel(PreTrainedModel):
    config_class = ParserkerConfig
    base_model_prefix = "backbone"
    _tied_weights_keys = {}

    def __init__(self, config: ParserkerConfig, **kwargs):
        super(ParserkerModel, self).__init__(config=config, **kwargs)
        self.pad_token_id = config.pad_token_id
        self.num_bits = config.num_bits

        self.backbone = RobertaModel(config, add_pooling_layer=False)
        self.hash_layer = HashLayer(config)

    @property
    def all_tied_weights_keys(self):
        return getattr(self, "_tied_weights_keys", [])

    def forward(self, input_ids: Z, duration: Z) -> L:
        out = self.backbone.forward(
            input_ids=input_ids.left(self.pad_token_id).data,
            attention_mask=input_ids.bmask(),
            return_dict=True,
        )

        # pool sub-word states into word states, drop the boundary tokens,
        # then score every word pair
        tensor = L(data=out.last_hidden_state, token_sizes=input_ids.cat().token_sizes)
        tensor, token_sizes = tensor.seg(duration, segment_mean).trunc((1, 1))
        logits = self.hash_layer(tensor, tensor)
        return L(data=logits, token_sizes=token_sizes)

    def parse(self, input_ids: Z, duration: C):
        logits, token_sizes = self(input_ids, duration)
        logits = logits.clone().requires_grad_(True)

        dist = Distribution(logits=logits, token_sizes=token_sizes)
        return dist.argmax

    def to_tree(self, words, spans) -> Tree:
        # rebuild an nltk Tree from (start, end, label) spans: visit the spans
        # right-to-left, popping every span already on the stack that the
        # current span covers and adopting it as a child
        stack = []
        for x, y, z in sorted(spans, key=lambda item: (item[0], -item[1]), reverse=True):
            children = []
            while len(stack) > 0:
                xx, yy, zz = stack.pop()
                if x <= xx and yy <= y:
                    children.append(zz)
                else:
                    stack.append((xx, yy, zz))
                    break

            # spans without children are single tokens; attach a placeholder leaf
            if len(children) == 0:
                children = ['__tok']

            stack.append((x, y, Tree(to_hex(z, self.num_bits), children)))

        # exactly one span must cover the whole sentence
        [(_, _, tree)] = stack
        for index in range(len(tree.leaves())):
            position = tree.leaf_treeposition(index)
            tree[position] = words[index]

        return tree
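
# A minimal sketch, not part of the model: `to_tree` only consults
# `self.num_bits`, so a SimpleNamespace stand-in for `self` suffices here;
# the words, spans, and num_bits below are toy assumptions for this demo.
def _demo_to_tree() -> None:
    from types import SimpleNamespace

    words = ['the', 'cat', 'sat']
    # (start, end, label) spans: one per token plus one per internal node,
    # i.e. 2n - 1 spans for a binary tree over n = 3 tokens
    spans = [(0, 0, 1), (1, 1, 2), (2, 2, 3), (1, 2, 4), (0, 2, 5)]

    tree = ParserkerModel.to_tree(SimpleNamespace(num_bits=8), words, spans)
    # yields (05 (01 the) (04 (02 cat) (03 sat))), with hex labels
    assert tree.leaves() == words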