File size: 5,731 Bytes
ca19627
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
AETHER-Net Configuration
Adaptive Elemental Transformer-Hybrid Efficient Recurrent Network

5Γ—5 Latin Orthogonal Magic Square Layout + Oheng(δΊ”θ‘Œ) MoE Routing
"""
from dataclasses import dataclass, field
from typing import List, Tuple

# ── 5Γ—5 Latin Orthogonal Magic Square ──
# Every row (element group) and every column (phase) holds each of the
# five attention types exactly once, so no phase accumulates a bias
# toward any single mechanism.
MAGIC_SQUARE = [
    # Phase1    Phase2    Phase3    Phase4    Phase5
    ["gdn",    "full",   "mamba2", "slide",  "cross"],   # Wood (木)
    ["slide",  "gdn",    "full",   "cross",  "mamba2"],  # Fire (火)
    ["full",   "cross",  "slide",  "mamba2", "gdn"],     # Earth (土)
    ["mamba2", "slide",  "cross",  "gdn",    "full"],    # Metal (金)
    ["cross",  "mamba2", "gdn",    "full",   "slide"],   # Water (水)
]

# Row-major flattening: the 25-layer attention-type schedule.
LAYER_TYPES = [layer_type for element_row in MAGIC_SQUARE for layer_type in element_row]

# ── Oheng (δΊ”θ‘Œ) Element System ──
ELEMENTS = ["wood", "fire", "earth", "metal", "water"]

# Generating cycle (상생): Wood→Fire→Earth→Metal→Water→Wood
GENERATE = {"wood": "fire", "fire": "earth", "earth": "metal", "metal": "water", "water": "wood"}
# Inverse lookup: which element generates me?
GENERATE_REVERSE = {child: parent for parent, child in GENERATE.items()}

# Overcoming cycle (상극): Wood⊣Earth, Earth⊣Water, Water⊣Fire, Fire⊣Metal, Metal⊣Wood
OVERCOME = {"wood": "earth", "earth": "water", "water": "fire", "fire": "metal", "metal": "wood"}
# Inverse lookup: which element overcomes me?
OVERCOME_REVERSE = {victim: victor for victor, victim in OVERCOME.items()}

# Element β†’ Layer indices (0-based): each element owns one contiguous
# stripe of 5 layers, in ELEMENTS order (wood=0-4 ... water=20-24).
ELEMENT_LAYERS = {
    name: list(range(base, base + 5))
    for base, name in zip(range(0, 25, 5), ("wood", "fire", "earth", "metal", "water"))
}

# Element β†’ Expert indices (0-based): experts are partitioned the same
# way as layers, 5 experts per element group. Independent list copies so
# mutating one table cannot corrupt the other.
ELEMENT_EXPERTS = {name: layer_ids.copy() for name, layer_ids in ELEMENT_LAYERS.items()}

# Layer index β†’ element name (inverse of ELEMENT_LAYERS).
LAYER_TO_ELEMENT = {
    layer_idx: name
    for name, layer_ids in ELEMENT_LAYERS.items()
    for layer_idx in layer_ids
}


@dataclass
class AetherNetConfig:
    """Configuration for AETHER-Net model.

    The 25-layer attention schedule comes from the module-level magic
    square (LAYER_TYPES); MoE experts are split into five element groups
    routed with the Oheng generate/overcome relations.
    """

    # ── Model dimensions ──
    hidden_size: int = 4096
    intermediate_size: int = 11008  # FFN intermediate (SwiGLU)
    num_layers: int = 25
    num_attention_heads: int = 32
    num_kv_heads: int = 8  # GQA for Full Attention layers
    head_dim: int = 128  # hidden_size // num_attention_heads
    vocab_size: int = 151936  # Qwen tokenizer
    max_position_embeddings: int = 262144
    rope_theta: float = 10000000.0

    # ── Layer schedule (from magic square) ──
    layer_types: List[str] = field(default_factory=lambda: LAYER_TYPES)

    # ── MoE Configuration ──
    num_experts: int = 25
    num_experts_per_group: int = 5
    num_element_groups: int = 5
    top_k: int = 2
    num_shared_experts: int = 1
    expert_intermediate_size: int = 2752  # intermediate_size // 4 (per expert)
    moe_jitter_eps: float = 0.01

    # ── Oheng (δΊ”θ‘Œ) routing ──
    use_generate_boost: bool = True
    use_overcome_gate: bool = True
    generate_alpha_init: float = 0.1  # learnable soft scalar
    overcome_gate_hidden: int = 256  # critic head hidden dim

    # ── Attention-specific ──
    sliding_window_size: int = 4096
    gdn_state_size: int = 128  # Gated DeltaNet state dimension
    mamba2_state_size: int = 128
    mamba2_conv_size: int = 4
    mamba2_expand: int = 2

    # ── Training / Inference ──
    rms_norm_eps: float = 1e-6
    initializer_range: float = 0.02
    tie_word_embeddings: bool = False
    use_cache: bool = True
    torch_dtype: str = "bfloat16"

    # ── Donor transplant info (metadata) ──
    primary_donor: str = "Qwen/Qwen3.5-27B"
    secondary_donor: str = "meta-llama/Llama-3.1-8B"

    def get_layer_type(self, layer_idx: int) -> str:
        """Return the attention type for layer *layer_idx*.

        Raises IndexError if layer_idx is out of range.
        """
        return self.layer_types[layer_idx]

    def get_layer_element(self, layer_idx: int) -> str:
        """Return the Oheng element name owning layer *layer_idx*.

        Raises KeyError if layer_idx is not in the 0-24 schedule.
        """
        return LAYER_TO_ELEMENT[layer_idx]

    def get_element_expert_range(self, element: str) -> Tuple[int, int]:
        """Return the half-open [start, end) expert-index range for *element*."""
        indices = ELEMENT_EXPERTS[element]
        return (indices[0], indices[-1] + 1)

    def summary(self) -> str:
        """Return a human-readable architecture summary.

        Parameter figures are rough estimates (norms, biases, and small
        router/gate heads are ignored).
        """
        type_counts = {}
        for t in self.layer_types:
            type_counts[t] = type_counts.get(t, 0) + 1
        # Expert (routed MoE) parameters: 3 SwiGLU matrices per expert.
        expert_params = (
            self.num_experts * self.expert_intermediate_size * self.hidden_size * 3 * 2
        )
        # Dense parameters — active on every token regardless of routing.
        dense_params = (
            self.num_layers * self.hidden_size * self.hidden_size * 4  # attention projections
            + self.vocab_size * self.hidden_size * 2  # embeddings + LM head
        )
        total_params_b = (expert_params + dense_params) / 1e9
        # Per token, only top_k routed experts plus the shared expert(s)
        # fire out of num_experts total; dense params always count.
        # (The old estimate divided by num_experts_per_group, treating
        # 3/5 of ALL parameters as active.)
        active_fraction = (self.top_k + self.num_shared_experts) / self.num_experts
        active_params_b = (dense_params + expert_params * active_fraction) / 1e9
        lines = [
            "═" * 60,
            "  AETHER-Net Architecture Summary",
            "═" * 60,
            f"  Layers:         {self.num_layers} (5Γ—5 magic square)",
            f"  Hidden dim:     {self.hidden_size}",
            f"  Attention mix:  {type_counts}",
            f"  MoE:            {self.num_experts} experts / {self.num_element_groups} groups / top-{self.top_k}",
            f"  Est. total:     ~{total_params_b:.1f}B params",
            f"  Est. active:    ~{active_params_b:.1f}B params",
            f"  Context:        {self.max_position_embeddings:,} tokens",
            f"  Oheng generate: {self.use_generate_boost} (Ξ±={self.generate_alpha_init})",
            f"  Oheng overcome: {self.use_overcome_gate}",
            f"  Primary donor:  {self.primary_donor}",
            f"  Secondary donor:{self.secondary_donor}",
            "═" * 60,
        ]
        return "\n".join(lines)