|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
class ByteLatentEncoder(nn.Module):
    """Encodes raw byte sequences into latent patch representations.

    This module replaces traditional tokenizers by learning to compress
    raw bytes directly into a higher-dimensional latent space.

    Args:
        d_model: Latent dimensionality. Must be even, because RoPE
            rotates channels in pairs.
        patch_size: Number of consecutive bytes compressed into one
            latent patch by a non-overlapping strided convolution.
        dropout: Dropout probability applied to the final latents.
        max_len: Nominal maximum byte-sequence length.
            NOTE(review): currently stored but not enforced anywhere in
            this class — kept for interface compatibility; confirm
            whether callers rely on it.
    """

    def __init__(
        self,
        d_model: int,
        patch_size: int = 4,
        dropout: float = 0.1,
        max_len: int = 4096
    ):
        super().__init__()
        # Fail fast: an odd d_model would otherwise surface only at
        # forward time as a cryptic broadcasting error inside apply_rope
        # (the cos/sin table would have d_model + 1 channels).
        if d_model % 2 != 0:
            raise ValueError(f"d_model must be even for RoPE, got {d_model}")

        self.d_model = d_model
        self.patch_size = patch_size
        # Previously accepted but silently discarded; retained so the
        # value is at least inspectable.
        self.max_len = max_len

        # One learned embedding per possible byte value (0-255).
        self.byte_embedding = nn.Embedding(256, d_model)

        # Non-overlapping strided conv: each window of `patch_size` byte
        # embeddings is compressed into one latent vector. With
        # padding=0 and stride == kernel_size, trailing bytes that do
        # not fill a whole patch are dropped.
        self.patch_conv = nn.Conv1d(
            in_channels=d_model,
            out_channels=d_model,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0
        )

        # Inverse frequencies for rotary position embeddings (RoPE),
        # standard 10000^(-2i/d) schedule. Registered as a buffer so it
        # moves with .to(device) and is saved in state_dict without
        # being a learnable parameter.
        self.register_buffer(
            "inv_freq",
            1.0 / (10000 ** (torch.arange(0, d_model, 2).float() / d_model))
        )

        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def apply_rope(self, x: torch.Tensor) -> torch.Tensor:
        """Apply rotary position embeddings to a (B, N, D) tensor.

        Uses the GPT-NeoX "rotate-half" formulation: the channel
        dimension is split into two halves that act as the coordinate
        pairs of the rotation. Position 0 is left unchanged
        (cos = 1, sin = 0).
        """
        B, N, D = x.shape

        # Per-position angles: (N, D/2), duplicated to (N, D) so the
        # same angle multiplies both halves of each channel pair.
        t = torch.arange(N, device=x.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        # rotate_half: (x1, x2) -> (-x2, x1), the 90-degree companion
        # term in x*cos + rotate_half(x)*sin.
        x1 = x[..., :D//2]
        x2 = x[..., D//2:]
        rotate_half_x = torch.cat((-x2, x1), dim=-1)

        return x * emb.cos() + rotate_half_x * emb.sin()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (Batch, Seq_Len) tensor of uint8 bytes (0-255)

        Returns:
            latents: (Batch, Seq_Len // patch_size, d_model). Trailing
            bytes beyond the last full patch are discarded by the
            strided convolution.
        """
        # nn.Embedding requires int64 indices; incoming bytes are uint8.
        x = self.byte_embedding(x.long())   # (B, S, D)

        x = x.transpose(1, 2)               # (B, D, S) — Conv1d layout

        x = self.patch_conv(x)              # (B, D, S // patch_size)

        x = x.transpose(1, 2)               # (B, S // patch_size, D)

        # Positions here are patch indices, not byte offsets.
        x = self.apply_rope(x)

        x = self.norm(x)
        x = self.dropout(x)

        return x
|
|
|