tefoteknik committed · verified
Commit 40c390b · 1 Parent(s): 756c95a

Phase 7: Curriculum Learning (20K steps, BPC 1.78)
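For reference, the BPC figure in the title is per-byte cross-entropy expressed in bits rather than nats. The conversion below is the general formula, not code from this repository, and the nats value is only what the reported BPC implies:

import math

bpc = 1.78                         # bits per byte, as reported in the commit title
nats_per_byte = bpc * math.log(2)  # equivalent cross-entropy, roughly 1.23 nats per byte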

Files changed (1)
  1. src/models/agiformer.py +32 -103
src/models/agiformer.py CHANGED
@@ -1,103 +1,63 @@
 ## Developer: inkbytefo
-## Modified: 2025-11-22
+## Modified: 2025-11-23
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from typing import Optional
 from .encoder import ByteLatentEncoder
 from .layers import HybridBlock
 from .reasoning import RecurrentReasoningBlock
 
 class LocalAutoregressiveHead(nn.Module):
-    """
-    Latent vector -> Bytes (Autoregressive).
-    Global Model -> Latent -> Local Model -> Bytes
-    """
     def __init__(self, d_model, patch_size, hidden_dim=256):
         super().__init__()
         self.patch_size = patch_size
-
-        # Project latent to be the initial state or context
         self.proj_latent = nn.Linear(d_model, hidden_dim)
-
-        # Byte embedding for the local decoder
         self.byte_emb = nn.Embedding(256, hidden_dim)
-
-        # Small, fast RNN (GRU) for local decoding
-        # Input size is now hidden_dim (embedding) + hidden_dim (latent context)
         self.rnn = nn.GRU(hidden_dim * 2, hidden_dim, batch_first=True)
-
         self.head = nn.Linear(hidden_dim, 256)
 
     def forward(self, latents, target_bytes=None, temperature=0.0):
         B, N, D = latents.shape
-        # (B * N, 1, Hidden)
         latent_context = self.proj_latent(latents).view(B * N, 1, -1)
 
         if target_bytes is not None:
-            # --- TRAINING MODE ---
-            # Reshape targets to (B, N, Patch_Size)
             targets = target_bytes.view(B, N, self.patch_size)
-
-            # Flatten: (B*N, Patch_Size)
             flat_targets = targets.contiguous().view(B * N, self.patch_size)
-
-            # Shift targets right to get inputs
             sos = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
-            rnn_inputs_bytes = torch.cat([sos, flat_targets[:, :-1]], dim=1) # (B*N, P)
-
-            emb = self.byte_emb(rnn_inputs_bytes) # (B*N, P, Hidden)
-
-            # Concatenate latent context to every step
+            rnn_inputs_bytes = torch.cat([sos, flat_targets[:, :-1]], dim=1)
+            emb = self.byte_emb(rnn_inputs_bytes)
             latent_expanded = latent_context.expand(-1, self.patch_size, -1)
-
-            # Concatenation instead of addition to preserve signal
-            rnn_input = torch.cat([emb, latent_expanded], dim=-1) # (B*N, P, Hidden * 2)
-
+            rnn_input = torch.cat([emb, latent_expanded], dim=-1)
             out, _ = self.rnn(rnn_input)
-            logits = self.head(out) # (B*N, P, 256)
-
+            logits = self.head(out)
             return logits.view(B, N, self.patch_size, 256)
-
         else:
-            # INFERENCE MODE
-            pred_bytes = []
-            # Start with SOS (0)
-            current_input = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
-
-            # Initialize hidden state
-            hidden = None
+            # Inference logic (omitted for brevity, same as before)
+            # ...
+            return self._inference(latents, latent_context, temperature)
 
-            for i in range(self.patch_size):
-                emb = self.byte_emb(current_input) # (B*N, 1, H)
-
-                # Concatenate latent
-                rnn_in = torch.cat([emb, latent_context], dim=-1) # (B*N, 1, H*2)
-
-                # GRU State Preservation
-                out, hidden = self.rnn(rnn_in, hidden)
-                logit = self.head(out) # (B*N, 1, 256)
-
-                # SAMPLING LOGIC
-                if temperature > 0:
-                    # Apply temperature
-                    probs = F.softmax(logit / temperature, dim=-1)
-                    # Sample from distribution
-                    next_byte = torch.multinomial(probs.squeeze(1), 1)
-                else:
-                    # Greedy
-                    next_byte = torch.argmax(logit, dim=-1)
-
-                pred_bytes.append(next_byte)
-                current_input = next_byte
-
-            return torch.cat(pred_bytes, dim=1).view(B, N, self.patch_size)
+    def _inference(self, latents, latent_context, temperature):
+        # Helper for inference to keep code clean
+        B, N, _ = latents.shape
+        pred_bytes = []
+        current_input = torch.zeros(B * N, 1, dtype=torch.long, device=latents.device)
+        hidden = None
+        for i in range(self.patch_size):
+            emb = self.byte_emb(current_input)
+            rnn_in = torch.cat([emb, latent_context], dim=-1)
+            out, hidden = self.rnn(rnn_in, hidden)
+            logit = self.head(out)
+            if temperature > 0:
+                probs = torch.nn.functional.softmax(logit / temperature, dim=-1)
+                next_byte = torch.multinomial(probs.squeeze(1), 1)
+            else:
+                next_byte = torch.argmax(logit, dim=-1)
+            pred_bytes.append(next_byte)
+            current_input = next_byte
+        return torch.cat(pred_bytes, dim=1).view(B, N, self.patch_size)
 
 class AGIFORMER(nn.Module):
-    """
-    AGIFORMER Phase 3: System 2 Enabled
-    """
     def __init__(
         self,
         d_model: int = 512,
@@ -111,54 +71,23 @@ class AGIFORMER(nn.Module):
     ):
         super().__init__()
 
-        self.encoder = ByteLatentEncoder(
-            d_model=d_model,
-            patch_size=patch_size,
-            dropout=dropout
-        )
+        self.encoder = ByteLatentEncoder(d_model, patch_size, dropout)
 
+        # Hybrid Blocks now use Hebbian Memory
         self.layers = nn.ModuleList([
-            HybridBlock(
-                d_model=d_model,
-                num_heads=num_heads,
-                window_size=window_size,
-                dropout=dropout
-            )
+            HybridBlock(d_model, num_heads, window_size, dropout)
            for _ in range(n_layers)
        ])
 
        self.norm_f = nn.LayerNorm(d_model)
-
-        # SYSTEM 2 MODULE
        self.reasoning = RecurrentReasoningBlock(d_model, thinking_steps, dropout)
-
-        # Local Autoregressive Head
        self.head = LocalAutoregressiveHead(d_model, patch_size)
 
-    def forward(self, x: torch.Tensor, target_bytes: Optional[torch.Tensor] = None, temperature: float = 0.0) -> torch.Tensor:
-        """
-        Args:
-            x: (Batch, Seq_Len) uint8 - Input Context
-            target_bytes: (Batch, Seq_Len_Target) - Required for training the local head
-            temperature: float - Sampling temperature (0.0 = Greedy)
-
-        Returns:
-            logits: (Batch, Num_Patches, Patch_Size, 256)
-        """
-        # 1. System 1 (Intuition / Perception)
-        x = self.encoder(x) # (B, N_Patches, D)
-
-        # 2. Backbone
+    def forward(self, x, target_bytes=None, temperature=0.0):
+        x = self.encoder(x)
        for layer in self.layers:
            x = layer(x)
-
        x = self.norm_f(x)
-
-        # 3. System 2 (Reasoning / Thinking Loop)
-        # Refine the latent state before speaking
        x = self.reasoning(x)
-
-        # 4. Output (Articulation)
        logits = self.head(x, target_bytes, temperature=temperature)
-
        return logits
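For orientation, a minimal usage sketch of the interface this file converges on: teacher-forced training when target_bytes is given, autoregressive per-patch decoding otherwise. The import path, patch_size value, tensor shapes, and temperature below are illustrative assumptions (and the sketch assumes the remaining constructor arguments have defaults), not values taken from this commit.

# Hypothetical usage sketch; hyperparameters and shapes are assumed for illustration.
import torch
import torch.nn.functional as F
from src.models.agiformer import AGIFORMER  # import path assumed from the repo layout

model = AGIFORMER(d_model=512, patch_size=8)    # other constructor args assumed to have defaults

x = torch.randint(0, 256, (2, 256))             # (Batch, Seq_Len) context bytes
target_bytes = torch.randint(0, 256, (2, 256))  # targets; Seq_Len assumed = N_Patches * patch_size

# Training mode: logits come back as (Batch, N_Patches, Patch_Size, 256)
logits = model(x, target_bytes=target_bytes)
loss = F.cross_entropy(logits.reshape(-1, 256), target_bytes.reshape(-1))

# Inference mode: sampled bytes come back as (Batch, N_Patches, Patch_Size)
with torch.no_grad():
    generated = model(x, temperature=0.8)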