## Developer: inkbytefo
## Modified: 2025-11-23

import torch
import torch.nn as nn
from typing import Optional
from .encoder import ByteLatentEncoder
from .layers import HybridBlock
from .reasoning import RecurrentReasoningBlock

class LocalAutoregressiveHead(nn.Module):
    """Decodes each latent patch vector into `patch_size` bytes with a GRU:
    teacher-forced during training, one byte at a time at inference."""

    def __init__(self, d_model: int, patch_size: int, hidden_dim: int = 256, vocab_size: int = 256):
        super().__init__()
        self.patch_size = patch_size
        self.vocab_size = vocab_size
        self.proj_latent = nn.Linear(d_model, hidden_dim)
        self.byte_emb = nn.Embedding(vocab_size, hidden_dim)
        # GRU input is [prev-byte embedding ; latent context], hence hidden_dim * 2.
        self.rnn = nn.GRU(hidden_dim * 2, hidden_dim, batch_first=True)
        self.head = nn.Linear(hidden_dim, vocab_size)

    def forward(self, latents: torch.Tensor, target_bytes: Optional[torch.Tensor] = None, temperature: float = 0.0):
        B, N, D = latents.shape
        latent_context = self.proj_latent(latents).view(B * N, 1, -1)
        
        if target_bytes is not None:
            # Teacher forcing: shift the targets right and prepend byte 0 as
            # the SOS marker (byte value 0 therefore doubles as start-of-patch).
            targets = target_bytes.view(B, N, self.patch_size)
            flat_targets = targets.contiguous().view(B * N, self.patch_size)
            sos = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
            rnn_inputs_bytes = torch.cat([sos, flat_targets[:, :-1]], dim=1)
            emb = self.byte_emb(rnn_inputs_bytes)
            latent_expanded = latent_context.expand(-1, self.patch_size, -1)
            rnn_input = torch.cat([emb, latent_expanded], dim=-1)
            out, _ = self.rnn(rnn_input)
            logits = self.head(out)
            return logits.view(B, N, self.patch_size, self.vocab_size)
        else:
            # Autoregressive decoding is handled by the _inference helper below.
            return self._inference(latents, latent_context, temperature)

    def _inference(self, latents: torch.Tensor, latent_context: torch.Tensor, temperature: float):
        # Decode patch_size bytes per latent, feeding each predicted byte back
        # in as the next input (byte 0 again serves as the SOS token).
        B, N, _ = latents.shape
        pred_bytes = []
        current_input = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
        hidden = None
        for i in range(self.patch_size):
            emb = self.byte_emb(current_input)
            rnn_in = torch.cat([emb, latent_context], dim=-1)
            out, hidden = self.rnn(rnn_in, hidden)
            logit = self.head(out)
            if temperature > 0:
                # Sample from the temperature-scaled distribution.
                probs = torch.nn.functional.softmax(logit / temperature, dim=-1)
                next_byte = torch.multinomial(probs.squeeze(1), 1)
            else:
                # Greedy decoding.
                next_byte = torch.argmax(logit, dim=-1)
            pred_bytes.append(next_byte)
            current_input = next_byte
        return torch.cat(pred_bytes, dim=1).view(B, N, self.patch_size)
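
# A minimal sketch of turning the teacher-forced logits above into a training
# loss. Assumption: `target_bytes` is a LongTensor of shape (B, N * patch_size);
# the helper name `compute_patch_loss` is illustrative, not part of this module.
def compute_patch_loss(logits: torch.Tensor, target_bytes: torch.Tensor) -> torch.Tensor:
    # logits: (B, N, patch_size, vocab_size) from LocalAutoregressiveHead.forward.
    B, N, P, V = logits.shape
    targets = target_bytes.view(B, N, P)
    # Cross-entropy over every byte position in every patch.
    return nn.functional.cross_entropy(logits.reshape(-1, V), targets.reshape(-1))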

class AGIFORMER(nn.Module):
    """Byte-level model: ByteLatentEncoder -> stack of HybridBlocks ->
    RecurrentReasoningBlock -> LocalAutoregressiveHead."""

    def __init__(
        self,
        d_model: int = 512,
        n_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        window_size: int = 128,
        vocab_size: int = 256,
        dropout: float = 0.1,
        thinking_steps: int = 3
    ):
        super().__init__()
        
        self.encoder = ByteLatentEncoder(d_model, patch_size, dropout)
        
        # Hybrid Blocks now use Hebbian Memory
        self.layers = nn.ModuleList([
            HybridBlock(d_model, num_heads, window_size, dropout)
            for _ in range(n_layers)
        ])
        
        self.norm_f = nn.LayerNorm(d_model)
        self.reasoning = RecurrentReasoningBlock(d_model, thinking_steps, dropout)
        self.head = LocalAutoregressiveHead(d_model, patch_size, vocab_size=vocab_size)

    def forward(self, x: torch.Tensor, target_bytes: Optional[torch.Tensor] = None, temperature: float = 0.0):
        x = self.encoder(x)
        for layer in self.layers:
            x = layer(x)
        x = self.norm_f(x)
        x = self.reasoning(x)
        # Returns logits (B, N, patch_size, vocab_size) in training mode, or
        # predicted bytes (B, N, patch_size) when target_bytes is None.
        out = self.head(x, target_bytes, temperature=temperature)
        return out
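
# Minimal smoke test, a sketch under assumptions: ByteLatentEncoder is assumed
# to take raw bytes of shape (B, L) with L divisible by patch_size and emit
# (B, L // patch_size, d_model) latents; these shapes are inferred from this
# file, not verified against encoder.py. Because of the relative imports above,
# run as a module, e.g. `python -m <package>.model` (package name hypothetical).
if __name__ == "__main__":
    model = AGIFORMER(d_model=128, n_layers=2, num_heads=4, patch_size=4)
    x = torch.randint(0, 256, (2, 64))  # (B=2, L=64) raw bytes
    train_logits = model(x, target_bytes=x)  # expected: (2, 16, 4, 256)
    print(train_logits.shape)
    sampled = model(x, temperature=1.0)  # expected: (2, 16, 4) predicted bytes
    print(sampled.shape)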