| | import torch
|
| | import torch.nn as nn
|
| | import math
|
| |
|
| |
|
class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) cos/sin table generator.

    Precomputes the standard RoPE inverse-frequency schedule
    10000^(-2k/dim) for k = 0..dim/2-1 and, on demand, emits the
    cos/sin tables for a given sequence length.

    Args:
        dim: size of the feature dimension the rotation targets (even).
        max_len: kept for interface compatibility; not used internally.
    """

    def __init__(self, dim, max_len=5000):
        super().__init__()
        # Frequency schedule over the even feature indices.
        exponents = torch.arange(0, dim, 2).float() / dim
        # Buffer (not a parameter): moves with the module's device/dtype.
        self.register_buffer("inv_freq", 1.0 / (10000 ** exponents))
        self.max_len = max_len

    def forward(self, seq_len, device):
        """Return (cos, sin) tables, each of shape (seq_len, dim)."""
        positions = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        # Outer product: one rotation angle per (position, frequency) pair.
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        # Duplicate along the feature axis so the table spans the full dim.
        table = torch.cat((angles, angles), dim=-1)
        return table.cos(), table.sin()
|
| |
|
| |
|
def apply_rope(x, cos, sin):
    """Apply a rotary positional embedding to the last dimension of *x*.

    Uses the standard "rotate-half" formulation: the last dimension is
    split into two halves, the halves are swapped with a sign flip on the
    second, and the result is mixed with the original via the broadcast
    cos/sin tables: x * cos + rotate_half(x) * sin.
    """
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    rotated = torch.cat((-second, first), dim=-1)
    return x * cos + rotated * sin
|
| |
|
| |
|
class HungarianEncoder(nn.Module):
    """Transformer encoder over tokens whose features are split in two.

    Each token's d_model features are laid out as [anchor | context]:
    the first 256 dims ("anchor") receive no positional treatment, while
    the remaining dims ("context") are rotated with RoPE before encoding
    and LayerNorm-ed after encoding.
    """

    def __init__(
        self, d_model=1536, nhead=24, num_layers=12, dim_feedforward=4096, dropout=0.1
    ):
        super().__init__()
        self.d_model = d_model

        # Fixed feature layout along the last dimension: [anchor | context].
        self.anchor_dim = 256
        self.context_dim = d_model - self.anchor_dim

        # RoPE tables are sized for the context slice only.
        self.rope = RotaryPositionalEmbedding(self.context_dim)

        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        # Applied to the context slice of the encoder output only.
        self.norm = nn.LayerNorm(self.context_dim)

    def forward(self, src, src_key_padding_mask=None, predict=None):
        """Encode *src* of shape (batch, seq, d_model).

        Args:
            src: input token features.
            src_key_padding_mask: optional padding mask forwarded to the
                transformer encoder.
            predict: optional mask/index selecting tokens to blank out
                (set to -1.0) before encoding, i.e. masked prediction.

        Returns:
            (batch, seq, d_model) tensor. In training mode, or when
            *predict* is given, the full encoder output (context slice
            normalized). Otherwise the ORIGINAL anchor slice is spliced
            back in front of the encoded context.
        """
        _, seq_len, _ = src.shape
        masked = src.clone()

        if predict is not None:
            # Blank out the positions the model must reconstruct.
            masked[predict] = -1.0

        cos, sin = self.rope(seq_len, src.device)
        # Add a leading batch axis so the tables broadcast over the batch.
        cos, sin = cos.unsqueeze(0), sin.unsqueeze(0)

        # Rotate only the context slice; anchors stay position-free.
        anchor_part = masked[:, :, : self.anchor_dim]
        context_part = apply_rope(masked[:, :, self.anchor_dim :], cos, sin)

        encoded = self.transformer(
            torch.cat([anchor_part, context_part], dim=-1),
            src_key_padding_mask=src_key_padding_mask,
        )

        # Re-assemble with the context slice normalized.
        output = torch.cat(
            [
                encoded[:, :, : self.anchor_dim],
                self.norm(encoded[:, :, self.anchor_dim :]),
            ],
            dim=-1,
        )

        if self.training or predict is not None:
            return output

        # Inference without a predict mask: keep the original (unmasked)
        # anchors and splice in the freshly encoded context.
        return torch.cat(
            [src[:, :, : self.anchor_dim], output[:, :, self.anchor_dim :]], dim=-1
        )
|
| |
|