Transformers
English
Hindi
Sanskrit
sovereign-ai
ecological-intelligence
indian-llm
environmental-protection
iamkoder001 committed on
Commit
eb7f1e0
·
verified ·
1 Parent(s): e65bc91

Update src/architecture/transformer.py

Browse files
Files changed (1) hide show
  1. src/architecture/transformer.py +28 -144
src/architecture/transformer.py CHANGED
@@ -1,167 +1,51 @@
1
  import torch
2
  import torch.nn as nn
3
  from torch.nn import functional as F
4
- import math
5
 
6
- # --- Sovereign Components ---
7
-
8
class RMSNorm(nn.Module):
    """Root-mean-square normalization with a learned per-channel gain.

    Unlike LayerNorm there is no mean subtraction and no bias term, which
    makes it cheaper and numerically more stable for deep stacks.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps  # floor under the mean-square to avoid divide-by-zero
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide each feature vector by its RMS: x / sqrt(mean(x^2) + eps).
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_sq + self.eps)

    def forward(self, x):
        # Do the reduction in float32 for stability, then restore the
        # caller's dtype before applying the learned scale.
        normalized = self._norm(x.float())
        return normalized.type_as(x) * self.weight
21
-
22
class SwiGLU(nn.Module):
    """SwiGLU feed-forward block: w3(silu(w1(x)) * w2(x)).

    Args:
        dim: input/output embedding dimension.
        hidden_dim: width of the gated inner projection. Defaults to
            4 * dim, matching the previously hard-coded expansion factor,
            so existing callers are unaffected.
    """

    def __init__(self, dim, hidden_dim=None):
        super().__init__()
        # Generalized: the 4x expansion used to be hard-coded.
        hidden_dim = 4 * dim if hidden_dim is None else hidden_dim
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)  # value projection
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)  # output projection

    def forward(self, x):
        # SiLU-gated linear unit, then project back to the model dimension.
        return self.w3(F.silu(self.w1(x)) * self.w2(x))
32
-
33
- # --- The Core Block ---
34
-
35
class AravalliBlock(nn.Module):
    """
    The fundamental unit of ARAVALLI-1 logic.

    A pre-norm decoder block: causal multi-head self-attention followed
    by a SwiGLU feed-forward network, each wrapped in a residual
    connection.

    Args:
        config: dict exposing config['model_params']['n_head'] and
            config['model_params']['n_embd'].

    Raises:
        ValueError: if n_embd is not evenly divisible by n_head (the
            original silently produced a malformed head split).
    """
    def __init__(self, config):
        super().__init__()
        self.n_head = config['model_params']['n_head']
        self.n_embd = config['model_params']['n_embd']
        if self.n_embd % self.n_head != 0:
            raise ValueError(
                f"n_embd ({self.n_embd}) must be divisible by n_head ({self.n_head})"
            )

        # Pre-norms for the two sub-layers
        self.attention_norm = RMSNorm(self.n_embd)
        self.ffn_norm = RMSNorm(self.n_embd)

        # Self-attention projections (bias-free)
        self.wq = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.wk = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.wv = nn.Linear(self.n_embd, self.n_embd, bias=False)
        self.wo = nn.Linear(self.n_embd, self.n_embd, bias=False)

        # Feed-forward network
        self.feed_forward = SwiGLU(self.n_embd)

    def forward(self, x):
        # 1. Attention sub-layer with residual connection
        h = x + self.wo(self._self_attention(self.attention_norm(x)))
        # 2. Feed-forward sub-layer with residual connection
        return h + self.feed_forward(self.ffn_norm(h))

    def _self_attention(self, x):
        """Causal multi-head attention; returns the (B, T, C) pre-wo output."""
        B, T, C = x.size()
        head_dim = C // self.n_head
        q = self.wq(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_head, head_dim).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_head, head_dim).transpose(1, 2)

        # Scaled dot-product attention
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        # Causal mask (the model cannot see the future), built as bool
        # directly on the input's device: the original allocated a float
        # mask on CPU and copied it to the device on every forward pass.
        causal = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device))
        att = att.masked_fill(~causal, float('-inf'))
        att = F.softmax(att, dim=-1)

        y = att @ v  # (B, n_head, T, head_dim)
        return y.transpose(1, 2).contiguous().view(B, T, C)
82
class AravalliModel(nn.Module):
    """
    The full Sovereign AI Model: ARAVALLI-1.
    A decoder-only transformer built from scratch for GOEC ecosystem mandates.

    NOTE(review): only token embeddings exist — there is no positional
    embedding/encoding module, so the model is position-agnostic apart
    from the causal mask inside the blocks. Confirm this is intentional.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        params = config['model_params']

        # Token embeddings (vocab_size -> n_embd)
        self.token_embedding = nn.Embedding(params['vocab_size'], params['n_embd'])

        # Stack of transformer blocks (the 'brain' layers)
        self.blocks = nn.ModuleList(
            AravalliBlock(config) for _ in range(params['n_layer'])
        )

        # Final normalization before the LM head
        self.final_norm = RMSNorm(params['n_embd'])

        # Language-modeling head: projects n_embd back to vocab_size
        self.lm_head = nn.Linear(params['n_embd'], params['vocab_size'], bias=False)

        # Weight tying: the embedding and the LM head share one parameter,
        # saving memory and typically improving quality.
        self.token_embedding.weight = self.lm_head.weight

        # Initialize all weights (after tying, so both views of the shared
        # parameter end up consistently initialized).
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-2-style init: N(0, 0.02) weights, zero biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Return (logits, loss); loss is None when targets is None.

        Args:
            idx: (B, T) long tensor of token ids.
            targets: optional (B, T) long tensor of next-token labels.
        """
        B, T = idx.size()

        x = self.token_embedding(idx)  # (B, T, n_embd)

        # Pass through the stack of AravalliBlocks
        for block in self.blocks:
            x = block(x)

        x = self.final_norm(x)

        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # Flatten (B, T, V) -> (B*T, V) for cross-entropy.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Simple greedy/sampled generation for the Secretariat Node."""
        # Hoisted: the original re-read config['model_params']['n_positions']
        # twice on every generation step.
        n_positions = self.config['model_params']['n_positions']
        for _ in range(max_new_tokens):
            # Crop the running sequence to the context window.
            idx_cond = idx if idx.size(1) <= n_positions else idx[:, -n_positions:]

            logits, _ = self(idx_cond)
            # Focus only on the last time step, scaled by temperature.
            logits = logits[:, -1, :] / temperature

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, idx_next), dim=1)
        return idx
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from torch.nn import functional as F
 
4
 
5
class AravalliSovereignModel(nn.Module):
    """
    Refactored ARAVALLI-1 with integrated Mechanical Survival Gates.

    Applies deterministic logit manipulation during generation so that
    survival-aligned tokens are boosted and degradation-related tokens
    are suppressed ("removes probabilistic drift").

    NOTE(review): this class is a skeleton — the embedding/transformer
    blocks and forward() are defined elsewhere; confirm before use.
    """
    def __init__(self, config):
        super().__init__()
        # Keep the full config: generate() and apply_survival_bias() read
        # config['tokens'] at sampling time. (Bug fix: the original
        # referenced self.config without ever storing it -> AttributeError.)
        self.config = config
        # ... (Previous embedding and block definitions) ...
        self.survival_vocab_indices = config.get('survival_indices', [])
        # Context-window length; falls back to the previously hard-coded 4096.
        self.max_context = config.get('model_params', {}).get('n_positions', 4096)

    # --- Survival-gate hooks ------------------------------------------
    # The original called these without defining them (NameError at
    # runtime). Defaults are conservative no-ops so the class is
    # runnable; production wiring is expected to override them.
    def is_in_critical_context(self, idx):
        """Return True when the running sequence is survival-critical.

        Default: always False (no bias applied) — override to enable the gate.
        """
        return False

    def is_violation(self, idx_next):
        """Return True when a sampled token violates SN status.

        Default: always False (no token rejected) — override to enable the check.
        """
        return False

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """Autoregressive sampling with deterministic survival gating.

        Args:
            idx: (B, T) long tensor of token ids to continue.
            max_new_tokens: number of tokens to append.
            temperature: logit divisor (>0); lower = greedier.
            top_k: if set, restrict sampling to the k highest logits.
        """
        for _ in range(max_new_tokens):
            # Context-window adherence
            idx_cond = idx[:, -self.max_context:]

            # Forward pass to get logits
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :] / temperature

            # --- MECHANICAL SURVIVAL GATE ---
            # Negative logit bias for degradation tokens, sovereign
            # priority for survival-aligned tokens.
            if self.is_in_critical_context(idx):
                logits = self.apply_survival_bias(logits)

            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')

            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)

            # FINAL DETERMINISTIC CHECK: reject token if it violates SN status.
            if self.is_violation(idx_next):
                # full_like is batch-size-safe and keeps dtype/device.
                # (Bug fix: the original built a hard-coded 1x1 tensor,
                # breaking any batch size > 1 and default-int dtype.)
                idx_next = torch.full_like(
                    idx_next, self.config['tokens']['CATEGORY_SN']
                )

            idx = torch.cat((idx, idx_next), dim=1)
        return idx

    def apply_survival_bias(self, logits):
        """Hard-coded logit manipulation for survival-critical tokens.

        Boosts the CATEGORY_SN token by +10 and forces forbidden
        degradation tokens to -inf (probability zero after softmax).
        Mutates and returns `logits` in place.

        Raises:
            KeyError: if config['tokens'] lacks the expected entries.
        """
        # Force high probability for Category SN/IPN terms
        logits[:, self.config['tokens']['CATEGORY_SN']] += 10.0
        # Zero out 'Permit Mining' / 'Degrade' related tokens
        logits[:, self.config['tokens']['FORBIDDEN_DEGRADE']] = -float('inf')
        return logits