import torch
import torch.nn as nn
from typing import Optional

from .encoder import ByteLatentEncoder
from .layers import HybridBlock
from .reasoning import RecurrentReasoningBlock

class LocalAutoregressiveHead(nn.Module):
    """Decodes each patch latent back into its bytes with a small GRU decoder."""

    def __init__(self, d_model, patch_size, hidden_dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.proj_latent = nn.Linear(d_model, hidden_dim)
        self.byte_emb = nn.Embedding(256, hidden_dim)
        # The GRU sees the previous byte's embedding concatenated with the patch latent.
        self.rnn = nn.GRU(hidden_dim * 2, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, 256)
    def forward(self, latents, target_bytes: Optional[torch.Tensor] = None, temperature: float = 0.0):
        B, N, D = latents.shape
        # One conditioning vector per patch: (B * N, 1, hidden_dim).
        latent_context = self.proj_latent(latents).view(B * N, 1, -1)

        if target_bytes is not None:
            # Teacher forcing: shift the targets right and prepend a zero start token.
            targets = target_bytes.view(B, N, self.patch_size)
            flat_targets = targets.contiguous().view(B * N, self.patch_size)
            sos = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
            rnn_inputs_bytes = torch.cat([sos, flat_targets[:, :-1]], dim=1)
            emb = self.byte_emb(rnn_inputs_bytes)
            latent_expanded = latent_context.expand(-1, self.patch_size, -1)
            rnn_input = torch.cat([emb, latent_expanded], dim=-1)
            out, _ = self.rnn(rnn_input)
            logits = self.head(out)
            return logits.view(B, N, self.patch_size, 256)
        else:
            # No targets given: decode the patch bytes autoregressively.
            return self._inference(latents, latent_context, temperature)
    def _inference(self, latents, latent_context, temperature):
        B, N, _ = latents.shape
        pred_bytes = []
        # Every patch starts from the zero byte used as the start-of-sequence token.
        current_input = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
        hidden = None
        for _ in range(self.patch_size):
            emb = self.byte_emb(current_input)
            rnn_in = torch.cat([emb, latent_context], dim=-1)
            out, hidden = self.rnn(rnn_in, hidden)
            logit = self.head(out)
            if temperature > 0:
                # Sample from the temperature-scaled distribution.
                probs = torch.nn.functional.softmax(logit / temperature, dim=-1)
                next_byte = torch.multinomial(probs.squeeze(1), 1)
            else:
                # Greedy decoding.
                next_byte = torch.argmax(logit, dim=-1)
            pred_bytes.append(next_byte)
            current_input = next_byte
        return torch.cat(pred_bytes, dim=1).view(B, N, self.patch_size)
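
# Shape reference for LocalAutoregressiveHead, derived from the code above
# (B = batch size, N = patches per sequence; concrete sizes are illustrative, not enforced):
#   latents       (B, N, d_model)
#   target_bytes  any layout holding B * N * patch_size byte ids (viewed as (B, N, patch_size))
#   training out  (B, N, patch_size, 256) logits over byte values
#   sampling out  (B, N, patch_size) decoded byte ids
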
class AGIFORMER(nn.Module):
    def __init__(
        self,
        d_model: int = 512,
        n_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        window_size: int = 128,
        vocab_size: int = 256,
        dropout: float = 0.1,
        thinking_steps: int = 3
    ):
        super().__init__()

        # Byte-level input is grouped into patches and embedded into latent vectors.
        self.encoder = ByteLatentEncoder(d_model, patch_size, dropout)

        # Stack of hybrid attention blocks over the patch sequence.
        self.layers = nn.ModuleList([
            HybridBlock(d_model, num_heads, window_size, dropout)
            for _ in range(n_layers)
        ])

        self.norm_f = nn.LayerNorm(d_model)
        self.reasoning = RecurrentReasoningBlock(d_model, thinking_steps, dropout)
        self.head = LocalAutoregressiveHead(d_model, patch_size)
    def forward(self, x, target_bytes: Optional[torch.Tensor] = None, temperature: float = 0.0):
        x = self.encoder(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        # Recurrent "thinking" refinement of the patch latents before byte decoding.
        x = self.reasoning(x)
        logits = self.head(x, target_bytes, temperature=temperature)
        return logits
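
# Minimal smoke test (an illustrative sketch, not part of the model code). It only
# exercises LocalAutoregressiveHead, which is fully defined in this file; AGIFORMER
# additionally needs ByteLatentEncoder, HybridBlock and RecurrentReasoningBlock from
# the sibling modules. Because this file uses relative imports, run it as a module
# (e.g. `python -m <package>.model`; the package and module names here are assumptions).
if __name__ == "__main__":
    torch.manual_seed(0)

    head = LocalAutoregressiveHead(d_model=512, patch_size=4)
    B, N = 2, 8                                   # 2 sequences, 8 patches each
    latents = torch.randn(B, N, 512)              # one latent vector per patch
    targets = torch.randint(0, 256, (B, N * 4))   # flattened ground-truth bytes

    # Teacher-forced pass: per-byte logits over the 256 possible byte values.
    logits = head(latents, target_bytes=targets)
    print(logits.shape)   # torch.Size([2, 8, 4, 256])

    # Autoregressive sampling pass: decoded byte ids per patch.
    sampled = head(latents, temperature=1.0)
    print(sampled.shape)  # torch.Size([2, 8, 4])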