File size: 8,373 Bytes
ca19627
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
"""
AETHER-Net: Main Model
Adaptive Elemental Transformer-Hybrid Efficient Recurrent Network

25-layer hybrid LLM with 5Γ—5 Latin orthogonal magic square layout
and Oheng (δΊ”θ‘Œ) MoE routing.
"""
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Tuple

from config import AetherNetConfig, ELEMENTS, LAYER_TO_ELEMENT, ELEMENT_LAYERS
from layers import RMSNorm, build_attention
from oheng_moe import OhengMoE


class AetherNetBlock(nn.Module):
    """One pre-norm transformer block of AETHER-Net.

    Data flow:
        x -> RMSNorm -> Attention -> +residual -> RMSNorm -> OhengMoE -> +residual
    """

    def __init__(self, config: AetherNetConfig, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        # Attention variant and element assignment are both determined by the
        # layer's position in the 5x5 magic-square layout (resolved by config).
        self.layer_type = config.get_layer_type(layer_idx)
        self.element = config.get_layer_element(layer_idx)

        # Pre-normalisation, one RMSNorm per sub-block.
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)

        # Attention module selected by the layer type.
        self.attention = build_attention(self.layer_type, config)

        # Mixture-of-experts FFN with Oheng routing.
        self.moe = OhengMoE(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        element_states: Optional[Dict[str, torch.Tensor]] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run the attention and MoE sub-blocks, each with a residual skip."""
        # Attention sub-block: pre-norm, attend, add residual.
        attn_out = self.attention(
            self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        hidden_states = hidden_states + attn_out

        # MoE FFN sub-block: pre-norm, route through experts, add residual.
        moe_out = self.moe(
            self.post_attention_layernorm(hidden_states),
            element_states=element_states,
        )
        return hidden_states + moe_out


class AetherNetModel(nn.Module):
    """AETHER-Net Language Model.

    Architecture:
    - Embedding -> 25 x AetherNetBlock -> RMSNorm -> LM Head
    - Blocks arranged in a 5x5 Latin orthogonal magic square
    - Oheng MoE with generating (sangsaeng) and overcoming (sanggeuk) connections
    - Element states flow between element groups for structural self-verification
    """

    def __init__(self, config: AetherNetConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # 25 transformer blocks (count actually comes from config.num_layers)
        self.layers = nn.ModuleList([
            AetherNetBlock(config, layer_idx=i)
            for i in range(config.num_layers)
        ])

        # Final norm applied before the LM head
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)

        # LM Head: projects hidden states to vocabulary logits (no bias)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: LM head shares the embedding matrix
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

        # Initialize weights. Runs after tying, so the shared tensor is simply
        # re-initialized through whichever module touches it last — harmless.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Init Linear/Embedding weights ~ N(0, initializer_range); zero biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Run the full model.

        Args:
            input_ids: (B, L) token ids. NOTE(review): typed Optional but is
                required — passing None raises on the shape unpack below.
            attention_mask: optional mask forwarded to every block.
            position_ids: optional (B, L) positions; defaults to 0..L-1.
            labels: optional (B, L) targets; when given, a shifted
                cross-entropy loss is computed (ignore_index=-100).
            encoder_hidden_states: optional cross-attention inputs,
                forwarded to every block.

        Returns:
            Dict with "loss" (None if no labels), "logits" (B, L, vocab),
            and "element_states" (per-element detached mean activations).
        """
        B, L = input_ids.shape

        # Position IDs
        if position_ids is None:
            position_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)

        # Embed
        hidden_states = self.embed_tokens(input_ids)

        # -- Element state tracking for Oheng connections --
        # Each element group accumulates a running mean of its layers' outputs,
        # consumed by later layers for generating/overcoming routing.
        element_states: Dict[str, torch.Tensor] = {}
        element_layer_counts: Dict[str, int] = {e: 0 for e in ELEMENTS}

        # -- Forward through the 25 layers --
        for i, layer in enumerate(self.layers):
            element = LAYER_TO_ELEMENT[i]

            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                element_states=element_states,
                encoder_hidden_states=encoder_hidden_states,
            )

            # Update element state (running average of this element's layer outputs)
            element_layer_counts[element] += 1
            count = element_layer_counts[element]
            if element in element_states:
                # Incremental cumulative mean: mean_k = mean_{k-1}*(k-1)/k + x_k/k.
                # Detached so element states carry no gradient between layers.
                element_states[element] = (
                    element_states[element] * (count - 1) / count
                    + hidden_states.detach() / count
                )
            else:
                element_states[element] = hidden_states.detach()

        # Final norm
        hidden_states = self.norm(hidden_states)

        # LM Head
        logits = self.lm_head(hidden_states)

        # Loss: next-token prediction (shift logits left, labels right)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return {
            "loss": loss,
            "logits": logits,
            "element_states": element_states,
        }

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component.

        Assumes each layer's OhengMoE exposes `experts`, `shared_expert`,
        `router`, and optional `generate_boost` / `overcome_gate` submodules
        (TODO confirm against oheng_moe.py). Note: with tied embeddings the
        shared matrix is counted under both "embedding" and "lm_head".
        """
        counts = {
            "embedding": sum(p.numel() for p in self.embed_tokens.parameters()),
            "lm_head": sum(p.numel() for p in self.lm_head.parameters()),
            "norm": sum(p.numel() for p in self.norm.parameters()),
        }

        attn_total = 0
        moe_total = 0
        generate_total = 0
        overcome_total = 0

        for layer in self.layers:
            # Per-layer norms are attributed to the attention bucket.
            attn_total += sum(p.numel() for p in layer.attention.parameters())
            attn_total += sum(p.numel() for p in layer.input_layernorm.parameters())
            attn_total += sum(p.numel() for p in layer.post_attention_layernorm.parameters())

            moe_total += sum(p.numel() for p in layer.moe.experts.parameters())
            moe_total += sum(p.numel() for p in layer.moe.shared_expert.parameters())
            moe_total += sum(p.numel() for p in layer.moe.router.parameters())

            # Oheng connection modules are optional per layer.
            if layer.moe.generate_boost is not None:
                generate_total += sum(p.numel() for p in layer.moe.generate_boost.parameters())
            if layer.moe.overcome_gate is not None:
                overcome_total += sum(p.numel() for p in layer.moe.overcome_gate.parameters())

        counts["attention_layers"] = attn_total
        counts["moe_experts"] = moe_total
        counts["oheng_generate"] = generate_total
        counts["oheng_overcome"] = overcome_total
        # Sum of the buckets above ("total" is added after, so not double-counted).
        counts["total"] = sum(counts.values())

        return counts

    def get_layer_map(self) -> List[Dict]:
        """Return human-readable layer map for diagnostics."""
        result = []
        for i, layer in enumerate(self.layers):
            result.append({
                "layer": i,
                "type": layer.layer_type,
                "element": layer.element,
                "element_idx": ELEMENTS.index(layer.element),
                # "phase": column within the 5x5 layout — presumably; derived
                # only from i % 5 here, verify against config's square.
                "phase": i % 5,
                "attn_class": layer.attention.__class__.__name__,
            })
        return result