## Developer: inkbytefo
## Modified: 2025-11-23
import torch
import torch.nn as nn
from typing import Optional
from .encoder import ByteLatentEncoder
from .layers import HybridBlock
from .reasoning import RecurrentReasoningBlock


class LocalAutoregressiveHead(nn.Module):
    """Decodes each patch latent back into `patch_size` bytes with a small GRU."""

    def __init__(self, d_model, patch_size, hidden_dim=256):
        super().__init__()
        self.patch_size = patch_size
        self.proj_latent = nn.Linear(d_model, hidden_dim)
        self.byte_emb = nn.Embedding(256, hidden_dim)  # one embedding per byte value
        self.rnn = nn.GRU(hidden_dim * 2, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, 256)  # per-step logits over the 256 byte values

    def forward(self, latents, target_bytes: Optional[torch.Tensor] = None, temperature=0.0):
        B, N, D = latents.shape
        # Each latent conditions the decoding of its own patch of bytes.
        latent_context = self.proj_latent(latents).view(B * N, 1, -1)
        if target_bytes is not None:
            # Teacher forcing: shift the targets right and prepend a zero SOS byte.
            targets = target_bytes.view(B, N, self.patch_size)
            flat_targets = targets.contiguous().view(B * N, self.patch_size)
            sos = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
            rnn_inputs_bytes = torch.cat([sos, flat_targets[:, :-1]], dim=1)
            emb = self.byte_emb(rnn_inputs_bytes)
            latent_expanded = latent_context.expand(-1, self.patch_size, -1)
            rnn_input = torch.cat([emb, latent_expanded], dim=-1)
            out, _ = self.rnn(rnn_input)
            logits = self.head(out)
            return logits.view(B, N, self.patch_size, 256)
        else:
            # No targets: decode the patch bytes autoregressively.
            return self._inference(latents, latent_context, temperature)

    def _inference(self, latents, latent_context, temperature):
        """Autoregressive decoding: greedy if temperature == 0, sampled otherwise."""
        B, N, _ = latents.shape
        pred_bytes = []
        current_input = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
        hidden = None
        for _ in range(self.patch_size):
            emb = self.byte_emb(current_input)
            rnn_in = torch.cat([emb, latent_context], dim=-1)
            out, hidden = self.rnn(rnn_in, hidden)
            logit = self.head(out)
            if temperature > 0:
                probs = torch.nn.functional.softmax(logit / temperature, dim=-1)
                next_byte = torch.multinomial(probs.squeeze(1), 1)
            else:
                next_byte = torch.argmax(logit, dim=-1)
            pred_bytes.append(next_byte)
            current_input = next_byte
        return torch.cat(pred_bytes, dim=1).view(B, N, self.patch_size)
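

# A minimal usage sketch for LocalAutoregressiveHead (never called by the module
# itself). The shapes are assumptions about how the head is driven: latents are
# (B, N, d_model) and target_bytes are (B, N * patch_size) with values in [0, 255].
def _local_head_example():
    head = LocalAutoregressiveHead(d_model=512, patch_size=4)
    latents = torch.randn(2, 8, 512)
    targets = torch.randint(0, 256, (2, 8 * 4))
    train_logits = head(latents, targets)     # (2, 8, 4, 256), ready for a CE loss
    sampled = head(latents, temperature=1.0)  # (2, 8, 4) sampled bytes
    return train_logits.shape, sampled.shape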


class AGIFORMER(nn.Module):
    """Byte-latent model: encoder -> hybrid blocks -> recurrent reasoning -> byte head."""

    def __init__(
        self,
        d_model: int = 512,
        n_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        window_size: int = 128,
        vocab_size: int = 256,
        dropout: float = 0.1,
        thinking_steps: int = 3
    ):
        super().__init__()
        self.encoder = ByteLatentEncoder(d_model, patch_size, dropout)
        # Hybrid Blocks now use Hebbian Memory
        self.layers = nn.ModuleList([
            HybridBlock(d_model, num_heads, window_size, dropout)
            for _ in range(n_layers)
        ])
        self.norm_f = nn.LayerNorm(d_model)
        self.reasoning = RecurrentReasoningBlock(d_model, thinking_steps, dropout)
        self.head = LocalAutoregressiveHead(d_model, patch_size)

    def forward(self, x, target_bytes: Optional[torch.Tensor] = None, temperature=0.0):
        x = self.encoder(x)    # raw bytes -> patch latents
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        x = self.reasoning(x)  # recurrent "thinking" refinement of the latents
        logits = self.head(x, target_bytes, temperature=temperature)
        return logits
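

# Hedged smoke test: a sketch only, assuming ByteLatentEncoder consumes a
# (B, seq_len) LongTensor of byte ids with seq_len a multiple of patch_size,
# so the head sees N = seq_len // patch_size patch latents.
if __name__ == "__main__":
    model = AGIFORMER(d_model=512, n_layers=2, patch_size=4)
    x = torch.randint(0, 256, (2, 64))     # two sequences of 64 bytes
    logits = model(x, target_bytes=x)      # training path; (2, 16, 4, 256) under the assumptions above
    generated = model(x, temperature=0.8)  # sampling path; (2, 16, 4) predicted bytes
    print(tuple(logits.shape), tuple(generated.shape))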