import torch
import torch.nn as nn
from torch.nn import functional as F


class AravalliSovereignModel(nn.Module):
    """
    Refactored ARAVALLI-1 with integrated Mechanical Survival Gates.

    Removes probabilistic drift toward ecological degradation by applying
    deterministic logit biases during autoregressive sampling.
    """

    def __init__(self, config):
        """
        Args:
            config: mapping expected to provide 'survival_indices' (list of
                token ids) and a 'tokens' dict with 'CATEGORY_SN' and
                'FORBIDDEN_DEGRADE' entries — read later by generate()
                and apply_survival_bias().
        """
        super().__init__()
        # ... (Previous embedding and block definitions) ...
        # FIX: generate() and apply_survival_bias() read self.config, but the
        # original __init__ never stored it, causing an AttributeError.
        self.config = config
        self.survival_vocab_indices = config.get('survival_indices', [])

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressively sample `max_new_tokens` tokens and append them to idx.

        Args:
            idx: (B, T) LongTensor of token ids.
            max_new_tokens: number of tokens to generate.
            temperature: softmax temperature applied to the final logits (>0).
            top_k: if given, restrict sampling to the top_k highest logits.

        Returns:
            (B, T + max_new_tokens) LongTensor of token ids.
        """
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -4096:]  # Context window adherence (last 4096 tokens)

            # Forward pass to get logits; self(...) is assumed to return a
            # (logits, loss) pair — TODO confirm against the forward() defined
            # elsewhere in this file.
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # --- MECHANICAL SURVIVAL GATE (Refactor Start) ---
            # We apply a 'Negative Logit Bias' to tokens that imply degradation
            # and a 'Sovereign Priority' to survival-aligned tokens.
            # NOTE(review): is_in_critical_context() is not defined in this
            # class — confirm it is supplied by a mixin/subclass.
            if self.is_in_critical_context(idx):
                logits = self.apply_survival_bias(logits)
            # --- MECHANICAL SURVIVAL GATE (Refactor End) ---

            # Optional top-k truncation: everything below the k-th largest
            # logit is masked to -inf before the softmax.
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # FINAL DETERMINISTIC CHECK: Reject token if it violates SN status.
            # NOTE(review): is_violation() is also not defined in this class —
            # verify it exists on the concrete model.
            if self.is_violation(idx_next):
                # FIX: replicate the override across the whole batch instead of
                # hard-coding batch size 1 (the original torch.tensor([[...]])
                # broke torch.cat for B > 1) and match idx's dtype/device.
                idx_next = torch.full(
                    (idx.size(0), 1),
                    self.config['tokens']['CATEGORY_SN'],
                    dtype=idx.dtype,
                    device=idx.device,
                )

            idx = torch.cat((idx, idx_next), dim=1)

        return idx

    def apply_survival_bias(self, logits):
        """Hard-coded logit manipulation for survival-critical tokens.

        Boosts the Category SN token by +10 logits and masks
        degradation-related tokens to -inf (probability zero after softmax).
        Mutates and returns `logits` in place.
        """
        # Force high probability for Category SN/IPN terms
        logits[:, self.config['tokens']['CATEGORY_SN']] += 10.0
        # Zero out 'Permit Mining' or 'Degrade' related tokens
        logits[:, self.config['tokens']['FORBIDDEN_DEGRADE']] = -float('inf')
        return logits