# HuBrainV5-Preview / model / HungarianEncoder.py
# Author: Braien — uploaded via huggingface_hub (commit 900b898, verified)
import torch
import torch.nn as nn
import math
class RotaryPositionalEmbedding(nn.Module):
def __init__(self, dim, max_len=5000):
super().__init__()
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq)
self.max_len = max_len
def forward(self, seq_len, device):
t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos(), emb.sin()
def apply_rope(x, cos, sin):
    """Apply rotate-half rotary position encoding to the last dim of `x`.

    `cos`/`sin` must broadcast against `x` and match its last dimension
    (e.g. the tables returned by `RotaryPositionalEmbedding.forward`).
    """
    half = x.shape[-1] // 2
    lower = x[..., :half]
    upper = x[..., half:]
    # Rotate-half: (a, b) -> (-b, a), then mix with the angle tables.
    rotated = torch.cat((-upper, lower), dim=-1)
    return x * cos + rotated * sin
class HungarianEncoder(nn.Module):
    """Transformer encoder over features split into anchor and context channels.

    The first `anchor_dim` (256) channels are position-free "anchor" features
    (POS + word bits, i.e. 4 heads x 64); the remaining `context_dim` channels
    (20 heads x 64 at the default `d_model`) carry rotary position information.
    RoPE is applied only to the context slice before encoding, and LayerNorm
    is applied only to the context slice after encoding.
    """

    def __init__(
        self, d_model=1536, nhead=24, num_layers=12, dim_feedforward=4096, dropout=0.1
    ):
        super().__init__()
        self.d_model = d_model
        # Fixed split point between anchor and rotary-context channels.
        self.anchor_dim = 256
        self.context_dim = d_model - self.anchor_dim
        self.rope = RotaryPositionalEmbedding(self.context_dim)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.norm = nn.LayerNorm(self.context_dim)

    def forward(self, src, src_key_padding_mask=None, predict=None):
        """Encode `src` of shape (batch, seq, d_model).

        `predict`, when given, is an index/boolean mask of positions to blank
        out before encoding (MLM-style). Returns a tensor of the same shape
        as `src`; in eval mode without `predict`, the original anchor channels
        are spliced back into the output.
        """
        seq_len = src.shape[1]
        x = src.clone()
        if predict is not None:
            # Sentinel is -1.0 rather than 0.0: zero is valid data (the 00 bit).
            x[predict] = -1.0

        cos, sin = self.rope(seq_len, src.device)
        cos, sin = cos[None, ...], sin[None, ...]  # add batch axis for broadcasting

        anchor_part = x[..., : self.anchor_dim]
        context_part = apply_rope(x[..., self.anchor_dim :], cos, sin)
        encoded = self.transformer(
            torch.cat([anchor_part, context_part], dim=-1),
            src_key_padding_mask=src_key_padding_mask,
        )

        # Normalize only the context channels; anchor channels pass through raw.
        out = torch.cat(
            [
                encoded[..., : self.anchor_dim],
                self.norm(encoded[..., self.anchor_dim :]),
            ],
            dim=-1,
        )

        # Training, or eval-with-predict (MLM): return the full encoded output.
        if self.training or predict is not None:
            return out
        # Plain-embedding eval path: keep the ORIGINAL anchors, encoded context.
        return torch.cat(
            [src[..., : self.anchor_dim], out[..., self.anchor_dim :]], dim=-1
        )