| | import torch |
| | import torch.nn as nn |
| | from torch.nn import functional as F |
| |
|
class AravalliSovereignModel(nn.Module):
    """
    Refactored ARAVALLI-1 with integrated Mechanical Survival Gates.

    Generation applies two hard-coded gates on top of ordinary sampling:

    * When ``is_in_critical_context(idx)`` is truthy, the last-position
      logits are rewritten by ``apply_survival_bias``. That method boosts the
      ``CATEGORY_SN`` token and masks out the ``FORBIDDEN_DEGRADE`` token.
    * When ``is_violation(idx_next)`` is truthy, the sampled tokens are
      replaced with the ``CATEGORY_SN`` token.

    NOTE(review): ``forward`` (called via ``self(idx_cond)``),
    ``is_in_critical_context`` and ``is_violation`` are not defined in this
    class as shown. They are presumably supplied by a subclass or elsewhere
    in the file; confirm.
    """

    # Number of trailing tokens fed to the model per step when the config
    # does not set ``context_window``.
    DEFAULT_CONTEXT_WINDOW = 4096

    # Amount added to the CATEGORY_SN logit by apply_survival_bias().
    SURVIVAL_LOGIT_BONUS = 10.0

    def __init__(self, config):
        """
        Args:
            config: dict-like configuration. Keys read:
                ``survival_indices`` (optional, default ``[]``);
                ``context_window`` (optional, default 4096);
                ``tokens``, a mapping holding the ``CATEGORY_SN`` and
                ``FORBIDDEN_DEGRADE`` vocabulary indices. ``tokens`` is only
                read during generation.
        """
        super().__init__()
        # generate() and apply_survival_bias() read self.config['tokens'].
        # Without this assignment those reads raise AttributeError.
        self.config = config
        # NOTE(review): not read anywhere in this class; kept in case it is
        # used elsewhere.
        self.survival_vocab_indices = config.get('survival_indices', [])
        self.context_window = config.get(
            'context_window', self.DEFAULT_CONTEXT_WINDOW
        )

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressively sample ``max_new_tokens`` tokens and append them to ``idx``.

        Args:
            idx: integer token tensor, sliced as ``idx[:, -n:]`` and extended
                along dim 1. Assumed (batch, seq); TODO confirm against callers.
            max_new_tokens: number of sampling steps (one new column per step).
            temperature: divisor for the last-position logits.
                NOTE(review): 0 produces inf/nan logits, and
                ``torch.multinomial`` will then raise.
            top_k: if not None, logits below the ``top_k``-th largest in each
                row are set to -inf before sampling.

        Returns:
            ``idx`` with ``max_new_tokens`` sampled columns appended.
        """
        for _ in range(max_new_tokens):
            # Feed only the trailing context window to the model.
            idx_cond = idx[:, -self.context_window:]

            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # Gate 1: the full history (not just the window) decides whether
            # the survival bias is applied.
            if self.is_in_critical_context(idx):
                logits = self.apply_survival_bias(logits)

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # Gate 2: override a flagged sample. full_like keeps idx_next's
            # (batch, 1) shape, dtype and device, so the cat below also works
            # for batch sizes > 1. A hard-coded [[token]] tensor did not.
            if self.is_violation(idx_next):
                idx_next = torch.full_like(
                    idx_next, self.config['tokens']['CATEGORY_SN']
                )

            idx = torch.cat((idx, idx_next), dim=1)
        return idx

    def apply_survival_bias(self, logits):
        """
        Hard-coded logit manipulation for survival-critical tokens.

        Adds ``SURVIVAL_LOGIT_BONUS`` (10.0) to the ``CATEGORY_SN`` logit and
        sets the ``FORBIDDEN_DEGRADE`` logit to -inf, so that token cannot be
        sampled. ``logits`` is modified in place and also returned.

        Args:
            logits: tensor indexed as ``logits[:, token]``. Assumed
                (batch, vocab).
        """
        tokens = self.config['tokens']
        logits[:, tokens['CATEGORY_SN']] += self.SURVIVAL_LOGIT_BONUS
        logits[:, tokens['FORBIDDEN_DEGRADE']] = -float('inf')
        return logits
| |
|