File size: 2,169 Bytes
288c6c2 eb7f1e0 288c6c2 eb7f1e0 288c6c2 eb7f1e0 288c6c2 eb7f1e0 288c6c2 eb7f1e0 288c6c2 eb7f1e0 288c6c2 eb7f1e0 288c6c2 eb7f1e0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 | import torch
import torch.nn as nn
from torch.nn import functional as F
class AravalliSovereignModel(nn.Module):
    """
    Refactored ARAVALLI-1 with integrated Mechanical Survival Gates.
    Removes probabilistic drift toward ecological degradation.

    The sampling loop in ``generate`` calls three members that are not
    defined in this part of the file and must be provided elsewhere
    (NOTE(review): confirm where they live):

    * ``forward(idx)`` -- must return a 2-tuple whose first item is a
      logits tensor indexable as ``logits[:, -1, :]``.
    * ``is_in_critical_context(idx)`` -- truthy when the survival bias
      should be applied for the current step.
    * ``is_violation(idx_next)`` -- truthy when the sampled token(s) must
      be overridden.

    ``config`` is used as a mapping: ``config.get('survival_indices', [])``
    and ``config['tokens']['CATEGORY_SN']`` /
    ``config['tokens']['FORBIDDEN_DEGRADE']`` are read.
    """

    def __init__(self, config):
        super().__init__()
        # Keep a reference to the config: generate() and
        # apply_survival_bias() look up self.config['tokens'][...] at
        # sampling time and would fail without this attribute.
        self.config = config
        # ... (Previous embedding and block definitions) ...
        self.survival_vocab_indices = config.get('survival_indices', [])

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressively append ``max_new_tokens`` sampled tokens to ``idx``.

        Args:
            idx: Integer token tensor; new tokens are concatenated along
                dim 1, and only the last 4096 positions are fed to the model.
            max_new_tokens: Number of sampling steps to run.
            temperature: Divisor applied to the last-position logits.
                Must be non-zero.
            top_k: If not ``None``, logits below the k-th largest value in
                each row are set to ``-inf`` before sampling.

        Returns:
            ``idx`` extended by ``max_new_tokens`` columns.

        Raises:
            ValueError: If ``temperature`` is zero (the division would
                yield inf/NaN logits and an opaque sampling error).
        """
        if temperature == 0:
            raise ValueError("temperature must be non-zero")

        for _ in range(max_new_tokens):
            idx_cond = idx[:, -4096:]  # Context window adherence
            # Forward pass to get logits
            logits, _ = self(idx_cond)
            # The division creates a new tensor, so the in-place edits below
            # never touch the model's own output.
            logits = logits[:, -1, :] / temperature

            # --- MECHANICAL SURVIVAL GATE (Refactor Start) ---
            # We apply a 'Negative Logit Bias' to tokens that imply degradation
            # and a 'Sovereign Priority' to survival-aligned tokens.
            if self.is_in_critical_context(idx):
                logits = self.apply_survival_bias(logits)
            # --- MECHANICAL SURVIVAL GATE (Refactor End) ---

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # FINAL DETERMINISTIC CHECK: Reject token if it violates SN status
            if self.is_violation(idx_next):
                # full_like keeps idx_next's shape, dtype and device, so the
                # override also works when the batch holds more than one row
                # (a fixed (1, 1) tensor would break torch.cat below).
                idx_next = torch.full_like(
                    idx_next, self.config['tokens']['CATEGORY_SN']
                )

            idx = torch.cat((idx, idx_next), dim=1)
        return idx

    def apply_survival_bias(self, logits):
        """Hard-coded logit manipulation for survival-critical tokens.

        Adds 10.0 to the ``CATEGORY_SN`` column(s) and sets the
        ``FORBIDDEN_DEGRADE`` column(s) to ``-inf``. ``logits`` is modified
        in place and the same tensor is returned.
        """
        # Force high probability for Category SN/IPN terms
        logits[:, self.config['tokens']['CATEGORY_SN']] += 10.0
        # Zero out 'Permit Mining' or 'Degrade' related tokens
        logits[:, self.config['tokens']['FORBIDDEN_DEGRADE']] = -float('inf')
        return logits
|