import torch
import torch.nn as nn
import math


class RotaryPositionalEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) cos/sin table generator.

    Produces position-dependent cos/sin tables for the rotate-half RoPE
    formulation (see ``apply_rope``).
    """

    def __init__(self, dim, max_len=5000):
        """
        Args:
            dim: Feature dimension the tables are built for (must be even).
            max_len: Kept for interface compatibility.
                NOTE(review): currently unused — confirm whether a length
                cap / cache was intended.
        """
        super().__init__()
        # Standard RoPE inverse-frequency schedule: 10000^(-2i/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        # Buffer (not a Parameter): moves with the module across devices,
        # excluded from gradient updates.
        self.register_buffer("inv_freq", inv_freq)
        self.max_len = max_len

    def forward(self, seq_len, device):
        """Return ``(cos, sin)`` tables, each of shape ``(seq_len, dim)``.

        Args:
            seq_len: Number of positions to generate tables for.
            device: Device on which to build the position indices.
        """
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        # Outer product: freqs[i, j] = position_i * inv_freq_j,
        # shape (seq_len, dim // 2).
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Duplicate along the feature axis so each half of the vector sees
        # the same frequencies — this pairs with the rotate-half scheme in
        # apply_rope, where the two halves form the (x1, x2) rotation pairs.
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos(), emb.sin()


def apply_rope(x, cos, sin):
    """Apply rotary position embedding using the rotate-half formulation.

    Args:
        x: Tensor whose last dimension matches the cos/sin tables.
        cos: Cosine table, broadcastable against ``x``.
        sin: Sine table, broadcastable against ``x``.

    Returns:
        ``x`` with each (x1, x2) half-pair rotated by the position angle:
        ``x * cos + rotate_half(x) * sin``.
    """
    half = x.shape[-1] // 2
    x1 = x[..., :half]
    x2 = x[..., half:]
    # rotate_half: (x1, x2) -> (-x2, x1), i.e. a 90° rotation of each pair.
    x_rotated = torch.cat((-x2, x1), dim=-1)
    return (x * cos) + (x_rotated * sin)


class HungarianEncoder(nn.Module):
    """Transformer encoder that splits features into anchor + context slices.

    The first ``anchor_dim`` (256) feature channels are treated as
    position-independent "anchors" and passed through untouched; the
    remaining ``context_dim`` channels get RoPE applied before entering the
    transformer, and are LayerNorm-ed on the way out.
    """

    def __init__(
        self, d_model=1536, nhead=24, num_layers=12, dim_feedforward=4096, dropout=0.1
    ):
        """
        Args:
            d_model: Total feature width; must exceed the 256 anchor dims
                and yield an even context width for RoPE.
            nhead: Number of attention heads (default 24 → 64-dim heads).
            num_layers: Number of stacked encoder layers.
            dim_feedforward: FFN hidden width per layer.
            dropout: Dropout probability inside each encoder layer.
        """
        super().__init__()
        self.d_model = d_model
        # Anchors (POS + 63 Word) = 256 dims (4 heads * 64 at the defaults).
        self.anchor_dim = 256
        # Context = remainder (1280 dims, i.e. 20 heads * 64, at the defaults).
        self.context_dim = d_model - self.anchor_dim
        # RoPE is sized for the context slice only — anchors stay positionless.
        self.rope = RotaryPositionalEmbedding(self.context_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,  # inputs are (batch, seq, feature)
            norm_first=True,  # pre-norm residual blocks
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        # Normalizes only the context slice of the output; anchors are raw.
        self.norm = nn.LayerNorm(self.context_dim)

    def forward(self, src, src_key_padding_mask=None, predict=None):
        """Encode ``src``, optionally masking positions for MLM prediction.

        Args:
            src: Input of shape (batch, seq, d_model).
            src_key_padding_mask: Optional padding mask forwarded to the
                transformer (True = ignore position).
            predict: Optional boolean index/mask selecting positions to
                blank out before encoding (MLM-style).

        Returns:
            Tensor of shape (batch, seq, d_model). In training (or eval
            with ``predict``) the full encoder output; in plain eval the
            original anchor channels are spliced back over the output.
        """
        B, S, D = src.shape
        # Clone so masking never mutates the caller's tensor.
        src_to_process = src.clone()
        if predict is not None:
            # -1.0 is the mask value because 0.0 is valid data (the 00 bit).
            src_to_process[predict] = -1.0

        cos, sin = self.rope(S, src.device)
        # Add a leading batch axis so the (S, context_dim) tables broadcast
        # over (B, S, context_dim).
        cos = cos.unsqueeze(0)
        sin = sin.unsqueeze(0)

        # Split: anchors are position-free, only the context gets RoPE.
        anchors = src_to_process[:, :, : self.anchor_dim]
        context = src_to_process[:, :, self.anchor_dim :]
        context_rotated = apply_rope(context, cos, sin)
        src_ready = torch.cat([anchors, context_rotated], dim=-1)

        output = self.transformer(src_ready, src_key_padding_mask=src_key_padding_mask)

        # Split output: anchor channels stay raw, context gets LayerNorm.
        out_anchor = output[:, :, : self.anchor_dim]
        out_context = output[:, :, self.anchor_dim :]
        out_context = self.norm(out_context)
        output = torch.cat([out_anchor, out_context], dim=-1)

        if self.training:
            return output
        else:
            # Eval mode: if predict is set (MLM), return the raw output —
            # the masked positions' predictions live in the output itself.
            if predict is not None:
                return output
            # Plain embedding (no predict): splice the ORIGINAL anchors back
            # in so they pass through the encoder losslessly.
            return torch.cat(
                [src[:, :, : self.anchor_dim], output[:, :, self.anchor_dim :]], dim=-1
            )