Upload genuine_model.py with huggingface_hub
genuine_model.py  +10 -10
@@ -109,13 +109,13 @@ class GenuinenessGate(nn.Module):
         return self.gate_fc(feat)
 
 class GenuineTransformer(nn.Module):
-    def __init__(self, d_model=
+    def __init__(self, d_model=512, n_heads=8, n_layers=12, vocab_size=1000):
         super().__init__()
         self.embedding = nn.Embedding(vocab_size, d_model)
         self.layers = nn.ModuleList([GenuineLayer(d_model, n_heads) for _ in range(n_layers)])
         self.gate = GenuinenessGate(d_model, n_heads)
         self.fc_out = nn.Linear(d_model, vocab_size)
-        self.register_buffer("freqs_cis", precompute_freqs_cis(d_model // n_heads,
+        self.register_buffer("freqs_cis", precompute_freqs_cis(d_model // n_heads, 256))
         self.n_heads = n_heads
 
     def forward(self, x, g_budget=12):
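The new line 118 pins the rotary-embedding cache to 256 positions (the old argument is truncated in this view), with a per-head dimension of d_model // n_heads = 64 under the new defaults. precompute_freqs_cis itself is not part of this diff; the sketch below assumes the standard LLaMA-style helper, which precomputes one unit-magnitude complex rotation per (position, frequency) pair.

import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # Assumed LLaMA-style helper; the real implementation is not shown in this diff.
    # One frequency per dimension pair: theta^(-2i/dim) for i in [0, dim/2).
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end).float()           # positions 0 .. end-1
    angles = torch.outer(t, freqs)          # [end, dim/2]
    # Unit-magnitude complex exponentials e^{i * pos * freq}.
    return torch.polar(torch.ones_like(angles), angles)  # complex64, [end, dim/2]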
@@ -125,26 +125,26 @@ class GenuineTransformer(nn.Module):
         x = self.embedding(x)
         all_entropies = []
         reasoning_layers = len(self.layers) // 2
-
+
         # V2.2: Global G-Budgeting
         total_steps = 0
-
+
         while total_steps < g_budget:
             loop_entropies = []
             for i in range(reasoning_layers):
                 x, attn, entropies = self.layers[i](x, freqs_cis)
                 loop_entropies.append(entropies)
             total_steps += 1
-
+
             all_entropies.extend(loop_entropies)
-
+
             # Gating check
             last_entropy = loop_entropies[-1]
             gate_signal = self.gate(x, last_entropy)  # [batch, 1]
-
+
             if gate_signal.mean() < 0.5:
                 break
-
+
         # Final decoding layers
         for i in range(reasoning_layers, len(self.layers)):
             x, attn, entropies = self.layers[i](x, freqs_cis)
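This hunk is the adaptive-compute core: the first half of the stack (reasoning_layers = 6 of 12 under the new defaults) is re-run until either the gate's mean activation drops below 0.5 or the global g_budget is spent, and only then do the remaining layers decode once. Note that forward reads a local freqs_cis, presumably sliced from the registered buffer in the unshown lines 122-124. A hypothetical usage sketch, assuming GenuineLayer and GenuinenessGate are defined earlier in the file and that forward ultimately returns logits from fc_out:

import torch

model = GenuineTransformer(d_model=512, n_heads=8, n_layers=12, vocab_size=1000)
tokens = torch.randint(0, 1000, (2, 64))  # [batch=2, seq=64] token ids, seq <= 256

# An easy batch can exit after one pass over the 6 reasoning layers
# (gate mean < 0.5); a hard one may loop up to g_budget passes.
logits = model(tokens, g_budget=12)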
@@ -173,11 +173,11 @@ class ThermodynamicRegularizer:
         # 1. Variance Reward (Internal Complexity)
         # Reward high entropy variance across heads for each token/layer
         var_h = torch.var(stack, dim=-1)  # [layers, batch, seq]
-
+
         # Apply layer-wise decay to prioritize early-layer complexity
         layer_weights = torch.pow(self.layer_decay, torch.arange(len(entropies), device=stack.device, dtype=torch.float)).view(-1, 1, 1)
         weighted_var = (var_h * layer_weights).mean()
-
+
         total_loss = -self.variance_weight * weighted_var
 
         # 2. Mechanical Penalty (Anti-Pattern Matching)
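The regularizer's first term rewards high variance of attention entropy across heads, geometrically down-weighted by depth so that early-layer complexity counts most; the leading minus sign turns the reward into a loss term. A minimal numeric sketch of the weighting, assuming layer_decay = 0.8 and variance_weight = 0.1 (neither constant appears in this hunk):

import torch

layer_decay, variance_weight = 0.8, 0.1      # assumed values, not in this hunk
stack = torch.rand(6, 2, 64, 8)              # [layers, batch, seq, heads] entropies
var_h = torch.var(stack, dim=-1)             # [6, 2, 64] variance across heads

weights = torch.pow(layer_decay, torch.arange(6, dtype=torch.float)).view(-1, 1, 1)
# weights = [1.0, 0.8, 0.64, 0.512, 0.4096, 0.32768] -> early layers dominate
weighted_var = (var_h * weights).mean()
loss = -variance_weight * weighted_var       # negative sign: variance is rewarded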