tefoteknik committed
Commit 4b2aac8 · verified · 1 Parent(s): a304b46

Upload src/models/agiformer.py with huggingface_hub

Files changed (1)
  1. src/models/agiformer.py +150 -0
src/models/agiformer.py ADDED
@@ -0,0 +1,150 @@
+ ## Developer: inkbytefo
+ ## Modified: 2025-11-22
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional
+ from .encoder import ByteLatentEncoder
+ from .layers import HybridBlock
+
+ class LocalAutoregressiveHead(nn.Module):
+     """
+     Latent vector -> Bytes (Autoregressive).
+     Global Model -> Latent -> Local Model -> Bytes
+     """
+     def __init__(self, d_model, patch_size, hidden_dim=256, vocab_size=256):
+         super().__init__()
+         self.patch_size = patch_size
+         self.vocab_size = vocab_size
+
+         # Project latent to be the initial state or context
+         self.proj_latent = nn.Linear(d_model, hidden_dim)
+
+         # Byte embedding for the local decoder
+         self.byte_emb = nn.Embedding(vocab_size, hidden_dim)
+
+         # Small, fast RNN (GRU) for local decoding
+         # Input size is hidden_dim (embedding) + hidden_dim (latent context)
+         self.rnn = nn.GRU(hidden_dim * 2, hidden_dim, batch_first=True)
+
+         self.head = nn.Linear(hidden_dim, vocab_size)
+
+     def forward(self, latents, target_bytes=None):
+         """
+         Args:
+             latents: (B, N_Patches, D_Model)
+             target_bytes: (B, L) - Required for training (Teacher Forcing);
+                 L must equal N_Patches * patch_size.
+         """
+         B, N, D = latents.shape
+
+         # (B * N, 1, Hidden)
+         latent_context = self.proj_latent(latents).view(B * N, 1, -1)
+
+         if target_bytes is not None:
+             # TRAINING MODE (Teacher Forcing)
+             # Reshape targets to (B, N, Patch_Size)
+             targets = target_bytes.view(B, N, self.patch_size)
+
+             # Flatten: (B*N, Patch_Size)
+             flat_targets = targets.contiguous().view(B * N, self.patch_size)
+
+             # Shift targets right to get inputs, prepending SOS (0)
+             sos = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
+             rnn_inputs_bytes = torch.cat([sos, flat_targets[:, :-1]], dim=1)  # (B*N, P)
+
+             emb = self.byte_emb(rnn_inputs_bytes)  # (B*N, P, Hidden)
+
+             # Concatenate latent context to every step
+             latent_expanded = latent_context.expand(-1, self.patch_size, -1)
+
+             # Concatenation instead of addition to preserve signal
+             rnn_input = torch.cat([emb, latent_expanded], dim=-1)  # (B*N, P, Hidden * 2)
+
+             out, _ = self.rnn(rnn_input)
+             logits = self.head(out)  # (B*N, P, Vocab)
+
+             return logits.view(B, N, self.patch_size, self.vocab_size)
+
+         else:
+             # INFERENCE MODE (greedy, one byte at a time)
+             pred_bytes = []
+             # Start with SOS (0); note this reuses byte value 0 as the start token
+             current_input = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
+
+             # Initialize hidden state
+             hidden = None  # Let the GRU default to zeros; the latent could instead seed the initial state if projected to (1, B*N, Hidden)
+
+             for _ in range(self.patch_size):
+                 emb = self.byte_emb(current_input)  # (B*N, 1, H)
+
+                 # Concatenate latent
+                 rnn_in = torch.cat([emb, latent_context], dim=-1)  # (B*N, 1, H*2)
+
+                 out, hidden = self.rnn(rnn_in, hidden)
+                 logit = self.head(out)  # (B*N, 1, Vocab)
+
+                 # Greedy decode
+                 next_byte = torch.argmax(logit, dim=-1)  # (B*N, 1)
+                 pred_bytes.append(next_byte)
+                 current_input = next_byte
+
+             return torch.cat(pred_bytes, dim=1).view(B, N, self.patch_size)
+
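+ # Example for LocalAutoregressiveHead (illustrative shapes, assumed here
+ # rather than taken from the repo):
+ #     head = LocalAutoregressiveHead(d_model=512, patch_size=4)
+ #     latents = torch.randn(2, 8, 512)                        # (B, N_Patches, D_Model)
+ #     logits = head(latents, torch.randint(0, 256, (2, 32)))  # teacher forcing: (2, 8, 4, 256)
+ #     bytes_out = head(latents)                               # greedy decode: (2, 8, 4)
+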
+ class AGIFORMER(nn.Module):
+     """
+     AGIFORMER: A Byte-Latent Hybrid Architecture.
+     """
+     def __init__(
+         self,
+         d_model: int = 512,
+         n_layers: int = 6,
+         num_heads: int = 8,
+         patch_size: int = 4,
+         window_size: int = 128,
+         vocab_size: int = 256,
+         dropout: float = 0.1
+     ):
+         super().__init__()
+
+         self.encoder = ByteLatentEncoder(
+             d_model=d_model,
+             patch_size=patch_size,
+             dropout=dropout
+         )
+
+         self.layers = nn.ModuleList([
+             HybridBlock(
+                 d_model=d_model,
+                 num_heads=num_heads,
+                 window_size=window_size,
+                 dropout=dropout
+             )
+             for _ in range(n_layers)
+         ])
+
+         self.norm_f = nn.LayerNorm(d_model)
+
+         # Local Autoregressive Head
+         self.head = LocalAutoregressiveHead(d_model, patch_size, vocab_size=vocab_size)
+
+     def forward(self, x: torch.Tensor, target_bytes: Optional[torch.Tensor] = None) -> torch.Tensor:
+         """
+         Args:
+             x: (Batch, Seq_Len) uint8 - Input Context
+             target_bytes: (Batch, Seq_Len_Target) - Required for training the local head
+
+         Returns:
+             logits: (Batch, Num_Patches, Patch_Size, Vocab_Size) in training mode;
+             greedy byte predictions of shape (Batch, Num_Patches, Patch_Size)
+             when target_bytes is None.
+         """
+         # 1. Encode bytes into patch latents
+         x = self.encoder(x)  # (B, N_Patches, D)
+
+         # 2. Backbone
+         for layer in self.layers:
+             x = layer(x)
+
+         x = self.norm_f(x)
+
+         # 3. Head (Local Autoregressive)
+         logits = self.head(x, target_bytes)
+
+         return logits
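
A minimal end-to-end sketch of the uploaded module, under stated assumptions: it presumes ByteLatentEncoder maps a (Batch, Seq_Len) byte tensor to (Batch, Seq_Len / patch_size, d_model) patch latents, that target_bytes supplies one target byte per input position, and that the script runs from the repo root so the src.models package resolves. Shapes and values below are illustrative, not taken from the repo.

import torch
import torch.nn.functional as F
from src.models.agiformer import AGIFORMER

model = AGIFORMER(d_model=512, n_layers=6, num_heads=8, patch_size=4)

B, L = 2, 64                                   # 64 bytes -> 16 patches of 4 bytes (assumed)
x = torch.randint(0, 256, (B, L))              # raw byte context
targets = torch.randint(0, 256, (B, L))        # per-byte targets (layout assumed)

# Training: per-byte logits, flattened for cross-entropy
logits = model(x, target_bytes=targets)        # (2, 16, 4, 256)
loss = F.cross_entropy(logits.reshape(-1, 256), targets.reshape(-1))

# Inference: greedy byte decoding from the patch latents
with torch.no_grad():
    pred = model(x)                            # (2, 16, 4)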