Spaces:
Paused
Paused
Upload 6 files
Browse files- app.py +320 -0
- config.py +149 -0
- layers.py +449 -0
- model.py +228 -0
- oheng_moe.py +292 -0
- requirements.txt +5 -0
app.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
"""
AETHER-Net 0.8B — Inference Test Space

Loads the private model and tests text generation.
HF Space: T4 GPU; the HF_TOKEN secret is required.

Deploy: FINAL-Bench/aether-net-test
"""
import os
import sys
import time
import json
import torch
import torch.nn.functional as F
import gradio as gr
from pathlib import Path
from huggingface_hub import hf_hub_download, snapshot_download

# ── Config ──
MODEL_REPO = "FINAL-Bench/AETHER-Net-0.8B"  # private weights repo
DONOR_REPO = "Qwen/Qwen3.5-0.8B"  # For tokenizer
HF_TOKEN = os.getenv("HF_TOKEN")  # None when the Space secret is not set
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {DEVICE}")
print(f"HF_TOKEN: {'set' if HF_TOKEN else 'NOT SET'}")

# ── Download model weights from private repo ──
# Best-effort: on failure model_dir stays None and load_model() falls
# back to a randomly initialised model instead of crashing the Space.
print(f"Downloading AETHER-Net weights from {MODEL_REPO}...")

model_dir = None
try:
    model_dir = snapshot_download(
        MODEL_REPO, token=HF_TOKEN,
        allow_patterns=["model.safetensors", "config.json"],
    )
    print(f" Model downloaded to: {model_dir}")
except Exception as e:
    print(f" Download failed: {e}")

# Source files are co-located in the same directory
# (config.py / model.py / layers.py are imported by load_model()).
APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, APP_DIR)

# ── Load model ──
# Module-level singletons, populated lazily by load_model() so all
# Gradio callbacks share one model/tokenizer instance.
MODEL = None
TOKENIZER = None
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def load_model():
    """Lazily build the tokenizer and AETHER-Net model into module globals.

    Returns:
        bool: True on success, False on any failure (failure details are
        printed to the Space logs).  Idempotent — returns immediately once
        MODEL is populated.
    """
    global MODEL, TOKENIZER

    if MODEL is not None:
        return True

    # Load tokenizer from donor
    print("Loading tokenizer...")
    from transformers import AutoTokenizer
    try:
        TOKENIZER = AutoTokenizer.from_pretrained(
            DONOR_REPO, trust_remote_code=True, token=HF_TOKEN
        )
        print(f" Tokenizer loaded: vocab_size={TOKENIZER.vocab_size}")
    except Exception as e:
        print(f" Tokenizer failed: {e}")
        return False

    # Load AETHER-Net
    print("Loading AETHER-Net model...")
    try:
        from config import AetherNetConfig
        from model import AetherNetModel

        # Load config
        config_path = Path(model_dir) / "config.json" if model_dir else None
        if config_path and config_path.exists():
            with open(config_path) as f:
                cfg_dict = json.load(f)
            # Filter valid fields: drop any metadata keys in config.json
            # that are not dataclass fields, so the constructor cannot fail.
            valid_fields = {k for k in AetherNetConfig.__dataclass_fields__}
            filtered = {k: v for k, v in cfg_dict.items() if k in valid_fields}
            config = AetherNetConfig(**filtered)
            print(f" Config loaded: hidden={config.hidden_size}, layers={config.num_layers}")
        else:
            print(" No config.json, using defaults")
            # Fallback hyperparameters for the 0.8B variant — presumably
            # mirror the published config.json; TODO confirm they stay in sync.
            config = AetherNetConfig(
                hidden_size=1024, intermediate_size=3584,
                num_layers=25, num_attention_heads=16, num_kv_heads=2,
                head_dim=64, vocab_size=248320,
                max_position_embeddings=4096,
                expert_intermediate_size=716,
                overcome_gate_hidden=64,
                sliding_window_size=1024,
                gdn_state_size=64, mamba2_state_size=64,
                tie_word_embeddings=True,
            )

        model = AetherNetModel(config)

        # Load weights
        weights_path = Path(model_dir) / "model.safetensors" if model_dir else None
        if weights_path and weights_path.exists():
            from safetensors.torch import load_file
            state = load_file(str(weights_path), device="cpu")
            # strict=False: tolerate missing/unexpected tensors (e.g. tied
            # embeddings) rather than aborting the whole Space.
            model.load_state_dict(state, strict=False)
            print(f" Weights loaded: {len(state)} tensors")
        else:
            print(" ⚠️ No weights found, using random init")

        model = model.to(DEVICE).eval()
        MODEL = model

        params = sum(p.numel() for p in model.parameters())
        mem = params * 2 / 1e9  # BF16 estimate (2 bytes per parameter)
        print(f" Model ready: {params:,} params (~{mem:.1f}GB)")
        return True

    except Exception as e:
        import traceback
        print(f" Model load failed: {e}")
        traceback.print_exc()
        return False
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# โโ Generation โโ
|
| 127 |
+
@torch.no_grad()
def generate(prompt, max_tokens=128, temperature=0.8, top_k=50, top_p=0.9):
    """Generate text from *prompt* with temperature / top-k / top-p sampling.

    Args:
        prompt: input text to continue.
        max_tokens: maximum number of new tokens to generate.
        temperature: 0 selects greedy decoding; >0 scales the logits.
        top_k: keep only the k most likely tokens (0 disables the filter).
        top_p: nucleus threshold (1.0 disables the filter).

    Returns:
        The decoded text followed by a throughput-stats footer, or an
        error string if the model cannot be loaded.
    """
    if MODEL is None:
        success = load_model()
        if not success:
            return "❌ Model failed to load. Check logs."

    # Tokenize
    input_ids = TOKENIZER.encode(prompt, return_tensors="pt").to(DEVICE)
    generated = input_ids.clone()

    t0 = time.time()

    for _ in range(max_tokens):
        # Keep context within the model's 4096-token position budget.
        if generated.shape[1] > 4096:
            generated = generated[:, -4096:]

        outputs = MODEL(input_ids=generated)
        logits = outputs["logits"][:, -1, :]

        if temperature > 0:
            logits = logits / temperature

            # Top-k — clamp k to vocab size: torch.topk raises when
            # k > logits.size(-1) (fixes a crash for small vocabularies).
            k = min(int(top_k), logits.size(-1))
            if k > 0:
                values, _ = torch.topk(logits, k)
                min_val = values[:, -1].unsqueeze(-1)
                logits = torch.where(logits < min_val, torch.full_like(logits, -float('inf')), logits)

            # Top-p (nucleus): mask tokens once cumulative probability
            # exceeds top_p; the shift keeps at least the top-1 token.
            if top_p < 1.0:
                sorted_logits, sorted_indices = torch.sort(logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)  # computed once (was twice)
                cum_probs = torch.cumsum(sorted_probs, dim=-1)
                mask = cum_probs - sorted_probs > top_p
                sorted_logits[mask] = -float('inf')
                # Un-sort the filtered logits back to vocab order.
                logits = logits.scatter(1, sorted_indices, sorted_logits)

            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
        else:
            # temperature == 0 → deterministic greedy decoding
            next_token = logits.argmax(dim=-1, keepdim=True)

        generated = torch.cat([generated, next_token], dim=-1)

        # EOS check — guard against tokenizers that define no EOS id.
        if TOKENIZER.eos_token_id is not None and next_token.item() == TOKENIZER.eos_token_id:
            break

    elapsed = time.time() - t0
    tokens_generated = generated.shape[1] - input_ids.shape[1]
    tps = tokens_generated / elapsed if elapsed > 0 else 0

    output_text = TOKENIZER.decode(generated[0], skip_special_tokens=True)
    stats = f"\n\n---\n📊 {tokens_generated} tokens | {tps:.1f} tok/s | {elapsed:.2f}s"

    return output_text + stats
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_model_info():
    """Return a Markdown summary of the loaded model's architecture.

    Triggers a lazy load_model() on first call; returns a plain error
    string when loading fails.
    """
    if MODEL is None:
        load_model()

    if MODEL is None:
        return "Model not loaded"

    num_layers = len(MODEL.layers)

    info = "## AETHER-Net 0.8B — Architecture Info\n\n"
    info += "| Item | Value |\n|---|---|\n"
    info += f"| Device | {DEVICE} |\n"
    info += f"| Parameters | {sum(p.numel() for p in MODEL.parameters()):,} |\n"
    info += f"| Layers | {num_layers} |\n"
    info += f"| Vocab | {MODEL.config.vocab_size:,} |\n"
    info += f"| Hidden | {MODEL.config.hidden_size} |\n"

    # Layer types
    from config import LAYER_TYPES, LAYER_TO_ELEMENT, ELEMENTS
    info += "\n### Layer Map\n\n"
    info += "| Layer | Type | Element |\n|---|---|---|\n"
    for i in range(num_layers):
        lt = LAYER_TYPES[i]
        elem = LAYER_TO_ELEMENT[i]
        info += f"| {i} | {lt.upper()} | {elem} |\n"

    # Oheng status: per element, the mean sigmoid(alpha) over the
    # generate-boost gates of that element's layers.
    info += "\n### Oheng Status\n\n"
    for elem in ELEMENTS:
        # Fix: was hardcoded range(25); use the actual layer count so the
        # report stays consistent with the Layer Map above.
        layers = [i for i in range(num_layers) if LAYER_TO_ELEMENT[i] == elem]
        alphas = []
        for li in layers:
            gb = MODEL.layers[li].moe.generate_boost
            if gb is not None:
                a = torch.sigmoid(gb.alpha).detach()
                eidx = ELEMENTS.index(elem)
                if eidx < a.shape[0]:
                    alphas.append(a[eidx].item())
        avg = sum(alphas) / len(alphas) if alphas else 0
        info += f"- {elem}: α={avg:.4f}\n"

    return info
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
# โโ Gradio UI โโ
|
| 232 |
+
# ── Gradio UI ──
TITLE = """
<div style="text-align:center; padding:15px 0;">
  <h1>🌌 AETHER-Net 0.8B — Inference Test</h1>
  <p style="color:#666;">Cross-Architecture Knowledge Distillation from Qwen3.5-0.8B</p>
  <p style="color:#999; font-size:0.9em;">5×5 Magic Square | Oheng MoE | 5 Attention Types</p>
</div>
"""

with gr.Blocks(title="AETHER-Net Test") as app:
    gr.HTML(TITLE)

    with gr.Tabs():
        with gr.Tab("💬 Generate"):
            # (Korean) "Enter a prompt and AETHER-Net generates text."
            gr.Markdown("프롬프트를 입력하면 AETHER-Net이 텍스트를 생성합니다.")

            with gr.Row():
                with gr.Column(scale=3):
                    prompt = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter your prompt here...",
                        lines=3,
                        value="The theory of relativity explains that"
                    )
                with gr.Column(scale=1):
                    # Slider steps mirror generate()'s defaults.
                    max_tokens = gr.Slider(16, 512, value=128, step=16, label="Max Tokens")
                    temperature = gr.Slider(0.0, 2.0, value=0.8, step=0.1, label="Temperature")
                    top_k = gr.Slider(0, 100, value=50, step=5, label="Top-K")
                    top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")

            gen_btn = gr.Button("🚀 Generate", variant="primary", size="lg")
            output = gr.Textbox(label="Output", lines=12, interactive=False)

            gen_btn.click(
                fn=generate,
                inputs=[prompt, max_tokens, temperature, top_k, top_p],
                outputs=output,
            )

            gr.Markdown("### Quick Prompts")
            examples = gr.Examples(
                examples=[
                    ["The theory of relativity explains that"],
                    ["In Python, the most efficient way to sort a list is"],
                    ["The five elements of nature are"],
                    ["Artificial general intelligence requires"],
                    ["한국의 수도는"],
                    ["def fibonacci(n):"],
                ],
                inputs=prompt,
            )

        with gr.Tab("📊 Model Info"):
            info_btn = gr.Button("Load Model Info", variant="primary")
            info_output = gr.Markdown()
            info_btn.click(fn=get_model_info, outputs=info_output)

        with gr.Tab("ℹ️ About"):
            gr.Markdown("""
## AETHER-Net 0.8B

**Cross-Architecture Knowledge Distillation from Qwen3.5-0.8B**

### Method
- **Weight Transplant**: Qwen3.5-0.8B → AETHER-Net (5×5 Magic Square layout)
- **3-Stage MOHAWK Distillation**: KLD → Hidden Alignment → Oheng Regularization
- **Cost**: ~$0 (CPU-only, 100 steps demo)

### Architecture
- 25 Layers: 5 attention types × 5 elements
- GDN, Full, Mamba2, Sliding Window, Cross Attention
- Oheng MoE: 25 experts, 상생(Generate) + 상극(Overcome)

### Source
- Model: [FINAL-Bench/AETHER-Net-0.8B](https://huggingface.co/FINAL-Bench/AETHER-Net-0.8B) (private)
- Space: [FINAL-Bench/agi-model-gen](https://huggingface.co/spaces/FINAL-Bench/agi-model-gen)

---
© 2026 VIDRAFT / Ginigen AI
""")


# ── Preload model on startup ──
# Loading at import time keeps the first user request from blocking on
# the multi-GB snapshot download.
print("\n=== Pre-loading model ===")
load_model()
print("=== Ready ===\n")


if __name__ == "__main__":
    app.launch()
|
config.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AETHER-Net Configuration
|
| 3 |
+
Adaptive Elemental Transformer-Hybrid Efficient Recurrent Network
|
| 4 |
+
|
| 5 |
+
5ร5 Latin Orthogonal Magic Square Layout + Oheng(ไบ่ก) MoE Routing
|
| 6 |
+
"""
|
| 7 |
+
from dataclasses import dataclass, field
|
| 8 |
+
from typing import List, Tuple
|
| 9 |
+
|
| 10 |
+
# โโ 5ร5 Latin Orthogonal Magic Square โโ
|
| 11 |
+
# Each row (element group) and each column (phase) contains
|
| 12 |
+
# exactly one of each attention type โ zero carry-over bias.
|
| 13 |
+
MAGIC_SQUARE = [
|
| 14 |
+
# Phase1 Phase2 Phase3 Phase4 Phase5
|
| 15 |
+
["gdn", "full", "mamba2", "slide", "cross"], # ๆจ Wood
|
| 16 |
+
["slide", "gdn", "full", "cross", "mamba2"], # ็ซ Fire
|
| 17 |
+
["full", "cross", "slide", "mamba2", "gdn"], # ๅ Earth
|
| 18 |
+
["mamba2", "slide", "cross", "gdn", "full"], # ้ Metal
|
| 19 |
+
["cross", "mamba2", "gdn", "full", "slide"], # ๆฐด Water
|
| 20 |
+
]
|
| 21 |
+
|
| 22 |
+
# Flatten to 25-layer sequence (row-major)
|
| 23 |
+
LAYER_TYPES = [t for row in MAGIC_SQUARE for t in row]
|
| 24 |
+
|
| 25 |
+
# โโ Oheng (ไบ่ก) Element System โโ
|
| 26 |
+
ELEMENTS = ["wood", "fire", "earth", "metal", "water"]
|
| 27 |
+
|
| 28 |
+
# ์์ (Generate): ๆจโ็ซโๅโ้โๆฐดโๆจ
|
| 29 |
+
GENERATE = {"wood": "fire", "fire": "earth", "earth": "metal", "metal": "water", "water": "wood"}
|
| 30 |
+
GENERATE_REVERSE = {v: k for k, v in GENERATE.items()}
|
| 31 |
+
|
| 32 |
+
# ์๊ทน (Overcome): ๆจโฃๅ, ๅโฃๆฐด, ๆฐดโฃ็ซ, ็ซโฃ้, ้โฃๆจ
|
| 33 |
+
OVERCOME = {"wood": "earth", "earth": "water", "water": "fire", "fire": "metal", "metal": "wood"}
|
| 34 |
+
OVERCOME_REVERSE = {v: k for k, v in OVERCOME.items()}
|
| 35 |
+
|
| 36 |
+
# Element โ Layer indices (0-based)
|
| 37 |
+
ELEMENT_LAYERS = {
|
| 38 |
+
"wood": [0, 1, 2, 3, 4],
|
| 39 |
+
"fire": [5, 6, 7, 8, 9],
|
| 40 |
+
"earth": [10, 11, 12, 13, 14],
|
| 41 |
+
"metal": [15, 16, 17, 18, 19],
|
| 42 |
+
"water": [20, 21, 22, 23, 24],
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
# Element โ Expert indices (0-based, 5 experts per element)
|
| 46 |
+
ELEMENT_EXPERTS = {
|
| 47 |
+
"wood": [0, 1, 2, 3, 4],
|
| 48 |
+
"fire": [5, 6, 7, 8, 9],
|
| 49 |
+
"earth": [10, 11, 12, 13, 14],
|
| 50 |
+
"metal": [15, 16, 17, 18, 19],
|
| 51 |
+
"water": [20, 21, 22, 23, 24],
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
# Layer index โ element name
|
| 55 |
+
LAYER_TO_ELEMENT = {}
|
| 56 |
+
for elem, indices in ELEMENT_LAYERS.items():
|
| 57 |
+
for idx in indices:
|
| 58 |
+
LAYER_TO_ELEMENT[idx] = elem
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass
|
| 62 |
+
class AetherNetConfig:
|
| 63 |
+
"""Configuration for AETHER-Net model."""
|
| 64 |
+
|
| 65 |
+
# โโ Model dimensions โโ
|
| 66 |
+
hidden_size: int = 4096
|
| 67 |
+
intermediate_size: int = 11008 # FFN intermediate (SwiGLU)
|
| 68 |
+
num_layers: int = 25
|
| 69 |
+
num_attention_heads: int = 32
|
| 70 |
+
num_kv_heads: int = 8 # GQA for Full Attention layers
|
| 71 |
+
head_dim: int = 128 # hidden_size // num_attention_heads
|
| 72 |
+
vocab_size: int = 151936 # Qwen tokenizer
|
| 73 |
+
max_position_embeddings: int = 262144
|
| 74 |
+
rope_theta: float = 10000000.0
|
| 75 |
+
|
| 76 |
+
# โโ Layer schedule (from magic square) โโ
|
| 77 |
+
layer_types: List[str] = field(default_factory=lambda: LAYER_TYPES)
|
| 78 |
+
|
| 79 |
+
# โโ MoE Configuration โโ
|
| 80 |
+
num_experts: int = 25
|
| 81 |
+
num_experts_per_group: int = 5
|
| 82 |
+
num_element_groups: int = 5
|
| 83 |
+
top_k: int = 2
|
| 84 |
+
num_shared_experts: int = 1
|
| 85 |
+
expert_intermediate_size: int = 2752 # intermediate_size // 4 (per expert)
|
| 86 |
+
moe_jitter_eps: float = 0.01
|
| 87 |
+
|
| 88 |
+
# โโ Oheng (ไบ่ก) routing โโ
|
| 89 |
+
use_generate_boost: bool = True
|
| 90 |
+
use_overcome_gate: bool = True
|
| 91 |
+
generate_alpha_init: float = 0.1 # learnable soft scalar
|
| 92 |
+
overcome_gate_hidden: int = 256 # critic head hidden dim
|
| 93 |
+
|
| 94 |
+
# โโ Attention-specific โโ
|
| 95 |
+
sliding_window_size: int = 4096
|
| 96 |
+
gdn_state_size: int = 128 # Gated DeltaNet state dimension
|
| 97 |
+
mamba2_state_size: int = 128
|
| 98 |
+
mamba2_conv_size: int = 4
|
| 99 |
+
mamba2_expand: int = 2
|
| 100 |
+
|
| 101 |
+
# โโ Training / Inference โโ
|
| 102 |
+
rms_norm_eps: float = 1e-6
|
| 103 |
+
initializer_range: float = 0.02
|
| 104 |
+
tie_word_embeddings: bool = False
|
| 105 |
+
use_cache: bool = True
|
| 106 |
+
torch_dtype: str = "bfloat16"
|
| 107 |
+
|
| 108 |
+
# โโ Donor transplant info (metadata) โโ
|
| 109 |
+
primary_donor: str = "Qwen/Qwen3.5-27B"
|
| 110 |
+
secondary_donor: str = "meta-llama/Llama-3.1-8B"
|
| 111 |
+
|
| 112 |
+
def get_layer_type(self, layer_idx: int) -> str:
|
| 113 |
+
return self.layer_types[layer_idx]
|
| 114 |
+
|
| 115 |
+
def get_layer_element(self, layer_idx: int) -> str:
|
| 116 |
+
return LAYER_TO_ELEMENT[layer_idx]
|
| 117 |
+
|
| 118 |
+
def get_element_expert_range(self, element: str) -> Tuple[int, int]:
|
| 119 |
+
indices = ELEMENT_EXPERTS[element]
|
| 120 |
+
return (indices[0], indices[-1] + 1)
|
| 121 |
+
|
| 122 |
+
def summary(self) -> str:
|
| 123 |
+
type_counts = {}
|
| 124 |
+
for t in self.layer_types:
|
| 125 |
+
type_counts[t] = type_counts.get(t, 0) + 1
|
| 126 |
+
total_params_b = (
|
| 127 |
+
self.num_experts * self.expert_intermediate_size * self.hidden_size * 3 * 2 # experts
|
| 128 |
+
+ self.num_layers * self.hidden_size * self.hidden_size * 4 # attention projections
|
| 129 |
+
+ self.vocab_size * self.hidden_size * 2 # embeddings
|
| 130 |
+
) / 1e9
|
| 131 |
+
active_params_b = total_params_b * (self.top_k + self.num_shared_experts) / self.num_experts_per_group
|
| 132 |
+
lines = [
|
| 133 |
+
"โ" * 60,
|
| 134 |
+
" AETHER-Net Architecture Summary",
|
| 135 |
+
"โ" * 60,
|
| 136 |
+
f" Layers: {self.num_layers} (5ร5 magic square)",
|
| 137 |
+
f" Hidden dim: {self.hidden_size}",
|
| 138 |
+
f" Attention mix: {type_counts}",
|
| 139 |
+
f" MoE: {self.num_experts} experts / {self.num_element_groups} groups / top-{self.top_k}",
|
| 140 |
+
f" Est. total: ~{total_params_b:.1f}B params",
|
| 141 |
+
f" Est. active: ~{active_params_b:.1f}B params",
|
| 142 |
+
f" Context: {self.max_position_embeddings:,} tokens",
|
| 143 |
+
f" Oheng generate: {self.use_generate_boost} (ฮฑ={self.generate_alpha_init})",
|
| 144 |
+
f" Oheng overcome: {self.use_overcome_gate}",
|
| 145 |
+
f" Primary donor: {self.primary_donor}",
|
| 146 |
+
f" Secondary donor:{self.secondary_donor}",
|
| 147 |
+
"โ" * 60,
|
| 148 |
+
]
|
| 149 |
+
return "\n".join(lines)
|
layers.py
ADDED
|
@@ -0,0 +1,449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AETHER-Net Attention Layers
|
| 3 |
+
5 types: GDN, Full, Mamba2, Sliding Window, Cross Attention
|
| 4 |
+
|
| 5 |
+
Each layer follows the same interface:
|
| 6 |
+
forward(hidden_states, attention_mask=None, position_ids=None, **kwargs) -> hidden_states
|
| 7 |
+
"""
|
| 8 |
+
import math
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from typing import Optional, Tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class RMSNorm(nn.Module):
    """Root-mean-square LayerNorm (no mean subtraction, no bias).

    Statistics are computed in float32 for numerical stability, and the
    result is cast back to the *input* dtype.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        # BUG FIX: the original rebound `x` to the float32 product before
        # calling `.to(x.dtype)`, making the cast a no-op — bf16/fp16
        # inputs came back as float32. Save the input dtype first.
        input_dtype = x.dtype
        x = x.float()
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return (self.weight * x).to(input_dtype)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def rotate_half(x):
    """Rotate the last dimension: (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply rotary position embeddings to the query and key tensors."""
    def _embed(t):
        return t * cos + rotate_half(t) * sin
    return _embed(q), _embed(k)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class RotaryEmbedding(nn.Module):
    """Precomputes RoPE inverse frequencies and emits cos/sin tables."""

    def __init__(self, dim: int, max_seq_len: int = 262144, theta: float = 10000000.0):
        super().__init__()
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (theta ** exponents), persistent=False)
        self.max_seq_len = max_seq_len

    def forward(self, x, position_ids):
        # A [B, L] id matrix is assumed uniform across the batch; use row 0.
        if position_ids.dim() == 2:
            position_ids = position_ids[0]
        angles = torch.outer(position_ids.float(), self.inv_freq.to(position_ids.device))
        table = torch.cat((angles, angles), dim=-1)
        return table.cos().unsqueeze(0), table.sin().unsqueeze(0)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 54 |
+
# 1. FULL ATTENTION (Softmax, GQA, RoPE) โ O(nยฒ)
|
| 55 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 56 |
+
class FullAttention(nn.Module):
    """Standard grouped-query attention with RoPE.

    Kept for 5 layers — provides precise token-to-token reasoning.
    NOTE(review): the class comment mentions a KV cache, but forward()
    below recomputes K/V for the whole sequence every call — confirm
    whether caching lives elsewhere or is simply not implemented yet.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        # Number of query heads sharing each KV head (GQA fan-out).
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        # Output gate (Qwen3.5 style gated attention)
        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Causal GQA attention over the whole sequence.

        Args:
            hidden_states: [B, L, hidden_size] input activations.
            attention_mask: optional additive mask — assumed broadcastable
                to [B, H, L, L]; TODO confirm against the caller.
            position_ids: positions for RoPE; forwarded to rotary_emb
                (required — rotary_emb dereferences it unconditionally).
        """
        B, L, _ = hidden_states.shape

        # Project and reshape to [B, heads, L, head_dim].
        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(B, L, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # RoPE — cos/sin come back [1, L, head_dim]; add a head axis.
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        cos = cos.unsqueeze(1)
        sin = sin.unsqueeze(1)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # GQA: expand KV heads so each query head sees its shared KV head.
        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        # Scaled dot-product attention
        attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask, rebuilt per call — O(L²) temp; acceptable for the
        # short demo contexts this Space serves.
        causal = torch.triu(torch.full((L, L), float('-inf'), device=attn.device), diagonal=1)
        attn = attn + causal.unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            attn = attn + attention_mask

        # Softmax in float32 for numerical stability, then back to compute dtype.
        attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, L, -1)

        # Output gating (sigmoid gate computed from the raw layer input).
        gate = torch.sigmoid(self.gate(hidden_states))
        out = out * gate

        return self.o_proj(out)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 117 |
+
# 2. GATED DELTANET (GDN) โ O(n) linear time
|
| 118 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 119 |
+
class GatedDeltaNet(nn.Module):
|
| 120 |
+
"""Gated DeltaNet: Mamba-style gating + DeltaNet fast-weight update.
|
| 121 |
+
Core linear attention mechanism โ 10 layers (40% of model).
|
| 122 |
+
|
| 123 |
+
Implements: M_t = ฮฑ_t * M_{t-1} * (I - k_t * q_t^T) + k_t * v_t^T
|
| 124 |
+
with SiLU output gating for gradient flow stability.
|
| 125 |
+
|
| 126 |
+
Weight transplant: Q,K,V projections map directly from Qwen3.5 GDN layers.
|
| 127 |
+
"""
|
| 128 |
+
|
| 129 |
+
def __init__(self, config):
|
| 130 |
+
super().__init__()
|
| 131 |
+
self.hidden_size = config.hidden_size
|
| 132 |
+
self.num_heads = config.num_attention_heads
|
| 133 |
+
self.head_dim = config.head_dim
|
| 134 |
+
self.state_size = config.gdn_state_size
|
| 135 |
+
|
| 136 |
+
# Input projections (transplantable from Qwen3.5 GDN)
|
| 137 |
+
self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 138 |
+
self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 139 |
+
self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 140 |
+
self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)
|
| 141 |
+
|
| 142 |
+
# Decay gate (ฮฑ): controls memory decay speed
|
| 143 |
+
self.decay_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)
|
| 144 |
+
|
| 145 |
+
# Update gate (ฮฒ): controls state update strength
|
| 146 |
+
self.beta_proj = nn.Linear(config.hidden_size, self.num_heads, bias=True)
|
| 147 |
+
|
| 148 |
+
# Output gate (SiLU activation for gradient stability)
|
| 149 |
+
self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
|
| 150 |
+
|
| 151 |
+
# Short convolution for local context (replaces positional encoding)
|
| 152 |
+
self.conv1d = nn.Conv1d(
|
| 153 |
+
in_channels=config.hidden_size,
|
| 154 |
+
out_channels=config.hidden_size,
|
| 155 |
+
kernel_size=4, padding=3, groups=config.hidden_size, bias=True
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Run the gated delta-rule recurrence over the sequence.

        Args:
            hidden_states: [B, L, D] input activations.
            attention_mask, position_ids: accepted for interface parity with
                the other attention classes; not used by this implementation.

        Returns:
            [B, L, D] mixed activations after output gating and projection.
        """
        B, L, D = hidden_states.shape

        # Local context mixing via causal conv1d (trim the right padding so the
        # output length is exactly L and the conv stays causal)
        conv_out = self.conv1d(hidden_states.transpose(1, 2))[..., :L].transpose(1, 2)

        # Q and K read from the conv-mixed stream; V reads from the raw stream.
        q = self.q_proj(conv_out).view(B, L, self.num_heads, self.head_dim)
        k = self.k_proj(conv_out).view(B, L, self.num_heads, self.head_dim)
        v = self.v_proj(hidden_states).view(B, L, self.num_heads, self.head_dim)

        # L2 normalize Q, K (replaces softmax normalization)
        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        # Decay and update gates, one scalar per head per position
        alpha = torch.sigmoid(self.decay_proj(hidden_states)).unsqueeze(-1)  # [B, L, H, 1]
        beta = torch.sigmoid(self.beta_proj(hidden_states)).unsqueeze(-1)

        # Recurrent scan with delta rule
        # M_t = alpha * M_{t-1} * (I - beta * k * q^T) + beta * k * v^T
        # For efficiency, compute as: o_t = q^T @ M_t
        outputs = []
        # Per-head fast-weight matrix state, [B, H, D, D]
        state = torch.zeros(B, self.num_heads, self.head_dim, self.head_dim,
                            device=hidden_states.device, dtype=hidden_states.dtype)

        for t in range(L):
            q_t = q[:, t]  # [B, H, D]
            k_t = k[:, t]
            v_t = v[:, t]
            a_t = alpha[:, t]  # [B, H, 1]
            b_t = beta[:, t]

            # Delta rule update
            # Erase: state = alpha * state * (I - beta * k * q^T)
            # Write: state += beta * k * v^T
            # NOTE(review): the formula above is a matrix product with
            # (I - beta k q^T), but `state * erase` below is an ELEMENTWISE
            # (Hadamard) product with the outer product — and the erase term
            # uses q where classic DeltaNet uses k. Confirm this is the
            # intended variant before transplanting donor weights.
            erase = torch.einsum('bhd,bhe->bhde', k_t * b_t, q_t)
            write = torch.einsum('bhd,bhe->bhde', k_t * b_t, v_t)
            # NOTE(review): `write` is added after the alpha decay is applied,
            # i.e. the new write itself is not decayed (matches the docstring
            # formula).
            state = a_t.unsqueeze(-1) * (state - state * erase) + write

            # Read: o_t = q^T @ state
            o_t = torch.einsum('bhd,bhde->bhe', q_t, state)
            outputs.append(o_t)

        out = torch.stack(outputs, dim=1)  # [B, L, H, D]
        out = out.reshape(B, L, -1)

        # Output gating with SiLU
        gate = F.silu(self.gate(hidden_states))
        out = out * gate

        return self.o_proj(out)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 212 |
+
# 3. MAMBA2 โ O(n) with SSM state-space duality
|
| 213 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 214 |
+
class Mamba2Block(nn.Module):
    """Mamba-2 block with Structured State Space Duality.

    5 layers — provides state compression for memory efficiency.

    Weight transplant: via MOHAWK SSD duality from Llama-3.1 Q,K,V -> C,B,X.
    """

    def __init__(self, config):
        """Build input/conv/SSM/output sub-modules.

        Args:
            config: provides ``hidden_size``, ``mamba2_expand``,
                ``mamba2_state_size``, ``mamba2_conv_size`` and
                ``num_attention_heads``.
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        expand = config.mamba2_expand
        self.inner_size = config.hidden_size * expand
        self.state_size = config.mamba2_state_size
        self.conv_size = config.mamba2_conv_size
        self.num_heads = config.num_attention_heads

        # Input projection: x -> (z, x_ssm) split; z later gates the output
        self.in_proj = nn.Linear(config.hidden_size, self.inner_size * 2, bias=False)

        # Depthwise causal conv1d (left-padded; output trimmed to L in forward)
        self.conv1d = nn.Conv1d(
            self.inner_size, self.inner_size,
            kernel_size=self.conv_size, padding=self.conv_size - 1,
            groups=self.inner_size, bias=True
        )

        # SSM parameters: per-head step size dt, decay A (log-parameterized
        # over 1..H so each head starts at a different timescale), skip gain D
        self.dt_proj = nn.Linear(self.inner_size, self.num_heads, bias=True)
        self.A_log = nn.Parameter(torch.log(torch.arange(1, self.num_heads + 1, dtype=torch.float32)))
        self.D = nn.Parameter(torch.ones(self.num_heads))

        # B, C projections (state-space input/readout matrices)
        # NOTE(review): head_dim_ssm is computed but never used — dead local?
        head_dim_ssm = self.inner_size // self.num_heads
        self.B_proj = nn.Linear(self.inner_size, self.state_size * self.num_heads, bias=False)
        self.C_proj = nn.Linear(self.inner_size, self.state_size * self.num_heads, bias=False)

        # Output
        self.out_proj = nn.Linear(self.inner_size, config.hidden_size, bias=False)
        self.norm = RMSNorm(self.inner_size)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Selective-scan forward pass.

        Args:
            hidden_states: [B, L, hidden_size].
            attention_mask, position_ids: accepted for interface parity; unused.

        Returns:
            [B, L, hidden_size].
        """
        B, L, _ = hidden_states.shape

        # Input split: z gates the output, x feeds the SSM path
        zx = self.in_proj(hidden_states)
        z, x = zx.chunk(2, dim=-1)

        # Causal conv (trim right padding back to length L)
        x = self.conv1d(x.transpose(1, 2))[..., :L].transpose(1, 2)
        x = F.silu(x)

        # SSM parameters; A < 0 guarantees a decaying state
        A = -torch.exp(self.A_log)  # [H]
        dt = F.softplus(self.dt_proj(x))  # [B, L, H], positive step sizes

        B_state = self.B_proj(x).view(B, L, self.num_heads, self.state_size)
        C_state = self.C_proj(x).view(B, L, self.num_heads, self.state_size)

        # Discretize: A_bar = exp(dt * A), B_bar = dt * B
        dt_A = dt.unsqueeze(-1) * A.view(1, 1, -1, 1)  # [B, L, H, 1]
        A_bar = torch.exp(dt_A)
        B_bar = dt.unsqueeze(-1) * B_state  # [B, L, H, N]

        # Selective scan (sequential for correctness; replace with FLA parallel kernel)
        head_dim = self.inner_size // self.num_heads
        x_heads = x.view(B, L, self.num_heads, head_dim)

        outputs = []
        state = torch.zeros(B, self.num_heads, self.state_size, device=x.device, dtype=x.dtype)

        for t in range(L):
            # NOTE(review): x_heads[:, t, :, :1] feeds only the FIRST channel
            # of each head into the state update (broadcast across the state
            # dim), discarding the remaining head_dim-1 channels — confirm
            # this simplification is intentional and not a slicing bug.
            state = A_bar[:, t] * state + B_bar[:, t] * x_heads[:, t, :, :1].expand_as(B_bar[:, t])
            y_t = torch.sum(state * C_state[:, t], dim=-1)  # [B, H] — one scalar per head
            outputs.append(y_t)

        y = torch.stack(outputs, dim=1)  # [B, L, H]

        # Skip connection with D (mean over head channels as the skip signal)
        y = y + self.D.view(1, 1, -1) * x.view(B, L, self.num_heads, head_dim).mean(-1)

        # Broadcast the per-head scalar back to head_dim channels and gate with z
        y = y.unsqueeze(-1).expand(-1, -1, -1, head_dim).reshape(B, L, self.inner_size)
        y = self.norm(y)
        y = y * F.silu(z)

        return self.out_proj(y)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 303 |
+
# 4. SLIDING WINDOW ATTENTION โ O(n * w)
|
| 304 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 305 |
+
class SlidingWindowAttention(nn.Module):
    """Attention restricted to a fixed-size causal window.

    5 layers — complements GDN's global view with fine-grained local context.
    Position ``i`` may attend to position ``j`` iff ``0 <= i - j < window_size``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = config.head_dim
        self.window_size = config.sliding_window_size
        # GQA: each KV head serves num_kv_groups query heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads

        q_dim = self.num_heads * self.head_dim
        kv_dim = self.num_kv_heads * self.head_dim
        self.q_proj = nn.Linear(config.hidden_size, q_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, kv_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, kv_dim, bias=False)
        self.o_proj = nn.Linear(q_dim, config.hidden_size, bias=False)
        self.gate = nn.Linear(config.hidden_size, q_dim, bias=False)

        self.rotary_emb = RotaryEmbedding(self.head_dim, config.max_position_embeddings, config.rope_theta)

    def forward(self, hidden_states, attention_mask=None, position_ids=None, **kwargs):
        """Windowed causal attention over [B, L, D] inputs; returns [B, L, D]."""
        bsz, seq_len, _ = hidden_states.shape

        q = self.q_proj(hidden_states).view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(hidden_states).view(bsz, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Rotary position embedding on Q/K
        cos, sin = self.rotary_emb(hidden_states, position_ids)
        q, k = apply_rotary_pos_emb(q, k, cos.unsqueeze(1), sin.unsqueeze(1))

        # Replicate KV heads so every query head has a matching KV head
        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Sliding-window causal mask built from relative positions:
        # ban the future (i - j < 0) and anything older than the window
        # (i - j >= window_size). Equivalent to triu/tril masking.
        idx = torch.arange(seq_len, device=scores.device)
        rel = idx.view(-1, 1) - idx.view(1, -1)  # rel[i, j] = i - j
        banned = (rel < 0) | (rel >= self.window_size)
        scores = scores.masked_fill(banned.view(1, 1, seq_len, seq_len), float('-inf'))

        probs = F.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
        ctx = torch.matmul(probs, v).transpose(1, 2).contiguous().view(bsz, seq_len, -1)

        # Sigmoid output gating before the final projection
        return self.o_proj(ctx * torch.sigmoid(self.gate(hidden_states)))
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 360 |
+
# 5. CROSS ATTENTION โ for multimodal / tool bridging
|
| 361 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 362 |
+
class CrossAttention(nn.Module):
    """Cross attention for PROMETHEUS (world model) and HEPHAESTUS (embodiment) connection.

    5 layers — bridges AETHER-Net to external modalities.
    When no external context is provided, falls back to pure causal
    self-attention with sigmoid output gating.
    """

    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim

        # Self-attention path (default when no external context)
        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=False)

        # Cross-attention path (when external context available)
        self.cross_k_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.cross_v_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

        # Modality gate: per-token scalar that lerps between self and cross paths
        self.modality_gate = nn.Linear(config.hidden_size, 1, bias=True)
        # bias -2.0 => sigmoid ~ 0.12 at init: mostly self-attention by default
        nn.init.constant_(self.modality_gate.bias, -2.0)

        self.gate = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=False)

    def forward(self, hidden_states, attention_mask=None, position_ids=None,
                encoder_hidden_states=None, **kwargs):
        """Blend causal self-attention with (optional) cross-attention.

        Args:
            hidden_states: [B, L, D] decoder activations.
            attention_mask, position_ids: accepted for interface parity; unused.
            encoder_hidden_states: optional [B, S, D] external context. When
                given, the output is a per-token sigmoid blend of the cross
                and self paths; when None, only self-attention runs.

        Returns:
            [B, L, D] gated, projected output.
        """
        B, L, _ = hidden_states.shape

        q = self.q_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        if encoder_hidden_states is not None:
            # Cross-attention mode: no causal mask — queries may attend to
            # every encoder position
            k_cross = self.cross_k_proj(encoder_hidden_states).view(
                B, -1, self.num_heads, self.head_dim).transpose(1, 2)
            v_cross = self.cross_v_proj(encoder_hidden_states).view(
                B, -1, self.num_heads, self.head_dim).transpose(1, 2)

            attn_cross = torch.matmul(q, k_cross.transpose(-2, -1)) / math.sqrt(self.head_dim)
            attn_cross = F.softmax(attn_cross, dim=-1, dtype=torch.float32).to(q.dtype)
            out_cross = torch.matmul(attn_cross, v_cross)
            out_cross = out_cross.transpose(1, 2).contiguous().view(B, L, -1)

            # Self-attention path (always runs, causal)
            k_self = self.k_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
            v_self = self.v_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
            attn_self = torch.matmul(q, k_self.transpose(-2, -1)) / math.sqrt(self.head_dim)
            causal = torch.triu(torch.full((L, L), float('-inf'), device=attn_self.device), diagonal=1)
            attn_self = attn_self + causal.unsqueeze(0).unsqueeze(0)
            attn_self = F.softmax(attn_self, dim=-1, dtype=torch.float32).to(q.dtype)
            out_self = torch.matmul(attn_self, v_self).transpose(1, 2).contiguous().view(B, L, -1)

            # Blend via modality gate (starts near 0 => mostly self-attention)
            mg = torch.sigmoid(self.modality_gate(hidden_states))
            out = mg * out_cross + (1 - mg) * out_self
        else:
            # Pure causal self-attention fallback
            k = self.k_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(hidden_states).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
            attn = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            causal = torch.triu(torch.full((L, L), float('-inf'), device=attn.device), diagonal=1)
            attn = attn + causal.unsqueeze(0).unsqueeze(0)
            attn = F.softmax(attn, dim=-1, dtype=torch.float32).to(q.dtype)
            out = torch.matmul(attn, v).transpose(1, 2).contiguous().view(B, L, -1)

        # Sigmoid output gating (applied in both modes)
        gate = torch.sigmoid(self.gate(hidden_states))
        out = out * gate

        return self.o_proj(out)
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 435 |
+
# Factory
|
| 436 |
+
# โโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโโ
|
| 437 |
+
# Registry mapping magic-square layer-type codes to attention implementations.
ATTENTION_CLASSES = {
    "gdn": GatedDeltaNet,
    "full": FullAttention,
    "mamba2": Mamba2Block,
    "slide": SlidingWindowAttention,
    "cross": CrossAttention,
}
|
| 444 |
+
|
| 445 |
+
def build_attention(layer_type: str, config):
    """Instantiate the attention module registered under ``layer_type``.

    Raises:
        ValueError: if ``layer_type`` is not a key of ``ATTENTION_CLASSES``.
    """
    if layer_type not in ATTENTION_CLASSES:
        raise ValueError(f"Unknown attention type: {layer_type}. Choose from {list(ATTENTION_CLASSES.keys())}")
    return ATTENTION_CLASSES[layer_type](config)
|
model.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AETHER-Net: Main Model
|
| 3 |
+
Adaptive Elemental Transformer-Hybrid Efficient Recurrent Network
|
| 4 |
+
|
| 5 |
+
25-layer hybrid LLM with 5ร5 Latin orthogonal magic square layout
|
| 6 |
+
and Oheng (ไบ่ก) MoE routing.
|
| 7 |
+
"""
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
from typing import Dict, List, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
from config import AetherNetConfig, ELEMENTS, LAYER_TO_ELEMENT, ELEMENT_LAYERS
|
| 13 |
+
from layers import RMSNorm, build_attention
|
| 14 |
+
from oheng_moe import OhengMoE
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class AetherNetBlock(nn.Module):
    """Single AETHER-Net transformer block.

    Structure:
        x -> RMSNorm -> Attention -> residual -> RMSNorm -> OhengMoE -> residual -> out
    """

    def __init__(self, config: AetherNetConfig, layer_idx: int):
        """Build one block; attention flavor and element group come from the
        layer's position in the 5x5 magic-square layout (via config lookups).
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.layer_type = config.get_layer_type(layer_idx)
        self.element = config.get_layer_element(layer_idx)

        # Pre-norm: one norm before attention, one before the MoE FFN
        self.input_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, config.rms_norm_eps)

        # Attention (type determined by magic square)
        self.attention = build_attention(self.layer_type, config)

        # MoE FFN with Oheng routing
        self.moe = OhengMoE(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        element_states: Optional[Dict[str, torch.Tensor]] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pre-norm attention + MoE FFN, each with a residual connection.

        Args:
            hidden_states: [B, L, D] activations.
            attention_mask, position_ids: forwarded to the attention module.
            element_states: per-element pooled outputs consumed by Oheng MoE
                routing (generate/overcome connections).
            encoder_hidden_states: optional external context for
                cross-attention layers; forwarded to the attention module.

        Returns:
            [B, L, D] updated activations.
        """
        # Attention block with residual
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            encoder_hidden_states=encoder_hidden_states,
        )
        hidden_states = residual + hidden_states

        # MoE FFN block with residual
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.moe(hidden_states, element_states=element_states)
        hidden_states = residual + hidden_states

        return hidden_states
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class AetherNetModel(nn.Module):
    """AETHER-Net Language Model.

    Architecture:
    - Embedding -> 25 x AetherNetBlock -> RMSNorm -> LM Head
    - Blocks arranged in a 5x5 Latin orthogonal magic square
    - Oheng MoE with generate (sangsaeng) and overcome (sanggeuk) connections
    - Element states flow between element groups for structural self-verification
    """

    def __init__(self, config: AetherNetConfig):
        super().__init__()
        self.config = config

        # Token embedding
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)

        # 25 transformer blocks
        self.layers = nn.ModuleList([
            AetherNetBlock(config, layer_idx=i)
            for i in range(config.num_layers)
        ])

        # Final norm
        self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps)

        # LM Head
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Weight tying: lm_head shares the embedding matrix
        if config.tie_word_embeddings:
            self.lm_head.weight = self.embed_tokens.weight

        # Initialize all linear/embedding weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights;
        zero init for Linear biases."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """Full forward pass.

        Args:
            input_ids: [B, L] token ids. NOTE(review): annotated Optional but
                required — `input_ids.shape` below raises on None.
            attention_mask: forwarded to every block.
            position_ids: [B, L]; defaults to 0..L-1 per batch row.
            labels: [B, L] targets for next-token loss (-100 ignored).
            encoder_hidden_states: optional external context for cross layers.

        Returns:
            dict with "loss" (or None), "logits" [B, L, vocab], and
            "element_states" (per-element running averages, detached).
        """
        B, L = input_ids.shape

        # Position IDs
        if position_ids is None:
            position_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)

        # Embed
        hidden_states = self.embed_tokens(input_ids)

        # -- Element state tracking for Oheng connections --
        # Each element group accumulates its output for generate/overcome routing
        element_states: Dict[str, torch.Tensor] = {}
        element_layer_counts: Dict[str, int] = {e: 0 for e in ELEMENTS}

        # -- Forward through 25 layers --
        for i, layer in enumerate(self.layers):
            element = LAYER_TO_ELEMENT[i]

            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                element_states=element_states,
                encoder_hidden_states=encoder_hidden_states,
            )

            # Update element state: cumulative running average of this
            # element's layer outputs (detached, so gradients do not flow
            # through the cross-element routing signal)
            element_layer_counts[element] += 1
            count = element_layer_counts[element]
            if element in element_states:
                # Running (cumulative) average over this element's layers —
                # each layer contributes equally, unlike an exponential
                # moving average
                element_states[element] = (
                    element_states[element] * (count - 1) / count
                    + hidden_states.detach() / count
                )
            else:
                element_states[element] = hidden_states.detach()

        # Final norm
        hidden_states = self.norm(hidden_states)

        # LM Head
        logits = self.lm_head(hidden_states)

        # Next-token loss: shift logits left / labels right by one position
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = nn.functional.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return {
            "loss": loss,
            "logits": logits,
            "element_states": element_states,
        }

    def count_parameters(self) -> Dict[str, int]:
        """Count parameters by component.

        NOTE(review): when tie_word_embeddings is set, "embedding" and
        "lm_head" report the same shared Parameter, so "total" counts that
        matrix twice — confirm whether tied totals should be deduplicated.
        """
        counts = {
            "embedding": sum(p.numel() for p in self.embed_tokens.parameters()),
            "lm_head": sum(p.numel() for p in self.lm_head.parameters()),
            "norm": sum(p.numel() for p in self.norm.parameters()),
        }

        attn_total = 0
        moe_total = 0
        generate_total = 0
        overcome_total = 0

        for layer in self.layers:
            # Per-layer norms are folded into the attention bucket
            attn_total += sum(p.numel() for p in layer.attention.parameters())
            attn_total += sum(p.numel() for p in layer.input_layernorm.parameters())
            attn_total += sum(p.numel() for p in layer.post_attention_layernorm.parameters())

            moe_total += sum(p.numel() for p in layer.moe.experts.parameters())
            moe_total += sum(p.numel() for p in layer.moe.shared_expert.parameters())
            moe_total += sum(p.numel() for p in layer.moe.router.parameters())

            if layer.moe.generate_boost is not None:
                generate_total += sum(p.numel() for p in layer.moe.generate_boost.parameters())
            if layer.moe.overcome_gate is not None:
                overcome_total += sum(p.numel() for p in layer.moe.overcome_gate.parameters())

        counts["attention_layers"] = attn_total
        counts["moe_experts"] = moe_total
        counts["oheng_generate"] = generate_total
        counts["oheng_overcome"] = overcome_total
        # Must stay last: sums every bucket added above
        counts["total"] = sum(counts.values())

        return counts

    def get_layer_map(self) -> List[Dict]:
        """Return human-readable layer map for diagnostics.

        One dict per layer: index, magic-square type code, element name and
        index, phase (position within its row of 5), and attention class name.
        """
        result = []
        for i, layer in enumerate(self.layers):
            result.append({
                "layer": i,
                "type": layer.layer_type,
                "element": layer.element,
                "element_idx": ELEMENTS.index(layer.element),
                "phase": i % 5,
                "attn_class": layer.attention.__class__.__name__,
            })
        return result
|
oheng_moe.py
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Oheng (ไบ่ก) Mixture-of-Experts Router
|
| 3 |
+
|
| 4 |
+
Core innovation: 25 experts organized in 5 element groups with:
|
| 5 |
+
- ์์ (Generate) cycle: WoodโFireโEarthโMetalโWaterโWood
|
| 6 |
+
Previous element's output provides residual boost to next element.
|
| 7 |
+
- ์๊ทน (Overcome) cycle: WoodโฃEarth, EarthโฃWater, WaterโฃFire, FireโฃMetal, MetalโฃWood
|
| 8 |
+
Opposing element provides critic gating to suppress hallucinations.
|
| 9 |
+
- Loss-Free Balancing via dynamic expert bias (DeepSeek-style)
|
| 10 |
+
"""
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from typing import Dict, Optional, Tuple
|
| 15 |
+
|
| 16 |
+
from config import (
|
| 17 |
+
ELEMENTS, GENERATE, GENERATE_REVERSE, OVERCOME, OVERCOME_REVERSE,
|
| 18 |
+
ELEMENT_EXPERTS, LAYER_TO_ELEMENT,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Expert(nn.Module):
    """Single SwiGLU expert (split from donor MLP).

    Computes ``down( silu(gate(x)) * up(x) )`` with three bias-free linears.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # SwiGLU: SiLU-activated gate branch multiplied into the linear branch
        activated = F.silu(self.gate_proj(x))
        linear = self.up_proj(x)
        return self.down_proj(activated * linear)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SharedExpert(nn.Module):
    """Shared expert that processes all tokens (always active).

    Same SwiGLU form as ``Expert``: ``down( silu(gate(x)) * up(x) )``.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # SwiGLU feed-forward applied to every token (no routing)
        hidden = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(hidden)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class GenerateBoost(nn.Module):
    """Generate (sangsaeng) mechanism: previous element boosts the current one.

    Wood -> Fire -> Earth -> Metal -> Water -> Wood
    A learnable per-element scalar (sigmoid of ``alpha``) gates a linear
    projection of the previous element group's pooled state, added as a
    residual. The projection is zero-initialized so the boost starts at zero.
    """

    def __init__(self, hidden_size: int, num_elements: int = 5):
        super().__init__()
        # One learnable gate logit per element
        self.alpha = nn.Parameter(torch.full((num_elements,), 0.1))
        # Lightweight source -> target mapping; zero init = no boost at start
        self.proj = nn.Linear(hidden_size, hidden_size, bias=False)
        nn.init.zeros_(self.proj.weight)

    def forward(self, hidden: torch.Tensor, source_state: Optional[torch.Tensor],
                element_idx: int) -> torch.Tensor:
        """Add the gated boost from the previous element group.

        Args:
            hidden: current hidden states [B, L, D].
            source_state: previous element group's output [B, L, D], or None
                (no boost — hidden is returned unchanged).
            element_idx: index of the current element (0=wood, 1=fire, ...).

        Returns:
            Boosted hidden states [B, L, D].
        """
        if source_state is None:
            return hidden

        strength = torch.sigmoid(self.alpha[element_idx])
        return hidden + strength * self.proj(source_state)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class OvercomeGate(nn.Module):
    """Overcome (sanggeuk) mechanism: opposing element provides critic gating.

    Wood|Earth, Earth|Water, Water|Fire, Fire|Metal, Metal|Wood pairs.

    A lightweight critic head from the opposing element group produces a
    sigmoid gate that suppresses potentially erroneous activations. This is
    the structural self-verification mechanism that reduces hallucination.
    """

    def __init__(self, hidden_size: int, critic_hidden: int = 256, num_elements: int = 5):
        """Build one critic MLP per element pair.

        Args:
            hidden_size: model hidden dimension (critic input/output width).
            critic_hidden: bottleneck width of each critic MLP.
            num_elements: number of element pairs (one critic each).
        """
        super().__init__()
        # One critic head per element pair
        self.critics = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, critic_hidden, bias=False),
                nn.SiLU(),
                # bias=True so the gate can be opened at initialization
                nn.Linear(critic_hidden, hidden_size, bias=True),
            )
            for _ in range(num_elements)
        ])
        # Initialize to near-identity: zero final weight plus a positive bias
        # gives gate = sigmoid(4.0) ~= 0.98 at the start of training, so the
        # critic initially passes activations through almost unchanged.
        # (Fix: the previous weight-only zero init produced sigmoid(0) = 0.5,
        # silently halving every gated activation despite the stated intent.)
        for critic in self.critics:
            nn.init.zeros_(critic[-1].weight)
            nn.init.constant_(critic[-1].bias, 4.0)

    def forward(self, hidden: torch.Tensor, critic_source: Optional[torch.Tensor],
                element_idx: int) -> torch.Tensor:
        """Gate the current activations with the opposing element's critic.

        Args:
            hidden: current hidden states [B, L, D].
            critic_source: opposing element's output [B, L, D], or None
                (no gating — hidden is returned unchanged).
            element_idx: index of the current element.

        Returns:
            Gated hidden states [B, L, D].
        """
        if critic_source is None:
            return hidden

        gate = torch.sigmoid(self.critics[element_idx](critic_source))
        return hidden * gate
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class OhengRouter(nn.Module):
    """Top-K expert router with DeepSeek-style Loss-Free Balancing.

    Expert selection runs on bias-adjusted scores so that chronically
    underloaded experts get favored over time, while the mixing weights
    are computed from the raw (unbiased) scores so gradients stay clean.
    Intended routing stays within the current element group first, with
    overflow to adjacent groups via generate connections.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.top_k
        self.jitter_eps = config.moe_jitter_eps

        # Linear scorer: hidden vector -> one logit per expert.
        self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False)

        # Balancing bias (checkpointed): adjusted out-of-band, never
        # trained by gradient descent.
        self.register_buffer(
            "expert_bias",
            torch.zeros(config.num_experts),
            persistent=True,
        )

        # EMA of per-expert routing load, used to steer the bias
        # (transient — recomputed during training, not checkpointed).
        self.register_buffer(
            "expert_load_ema",
            torch.full((config.num_experts,), 1.0 / config.num_experts),
            persistent=False,
        )

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Route flattened tokens to experts.

        Args:
            hidden_states: [B*L, D] token representations.

        Returns:
            Tuple of:
            - expert indices [B*L, top_k],
            - mixing weights [B*L, top_k] (softmax of unbiased scores),
            - raw router logits [B*L, num_experts] for auxiliary logging.
        """
        scores = self.gate(hidden_states)  # [B*L, E]

        # Multiplicative jitter encourages exploration during training.
        if self.training and self.jitter_eps > 0:
            lo, hi = 1.0 - self.jitter_eps, 1.0 + self.jitter_eps
            scores = scores * torch.empty_like(scores).uniform_(lo, hi)

        # Select experts on the biased scores (Loss-Free Balancing)...
        _, chosen = torch.topk(scores + self.expert_bias.unsqueeze(0), self.top_k, dim=-1)

        # ...but weight them by the clean, unbiased scores.
        chosen_scores = scores.gather(1, chosen)
        mix = F.softmax(chosen_scores, dim=-1, dtype=torch.float32).to(hidden_states.dtype)

        # Bias update happens outside the gradient path, per batch.
        if self.training:
            self._update_bias(chosen)

        return chosen, mix, scores

    @torch.no_grad()
    def _update_bias(self, indices: torch.Tensor, momentum: float = 0.99, step: float = 0.001):
        """Nudge each expert's selection bias toward uniform load."""
        hits = torch.bincount(indices.reshape(-1), minlength=self.num_experts).float()
        total = max(hits.sum().item(), 1.0)
        self.expert_load_ema.mul_(momentum).add_(hits / total, alpha=1.0 - momentum)

        # Raise the bias of underloaded experts, lower it for overloaded.
        uniform = 1.0 / self.num_experts
        self.expert_bias.add_(step * (uniform - self.expert_load_ema))
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class OhengMoE(nn.Module):
    """Oheng MoE layer: routed experts plus Generate/Overcome couplings.

    Per-layer flow:
      1. Router picks top-K experts per token.
      2. Chosen experts transform their tokens, weighted by router scores.
      3. A shared expert processes every token.
      4. Generate boost injects the preceding element group's output.
      5. Overcome gate lets the opposing element suppress activations.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.layer_idx = layer_idx
        self.element = LAYER_TO_ELEMENT[layer_idx]
        self.element_idx = ELEMENTS.index(self.element)
        self.hidden_size = config.hidden_size
        self.top_k = config.top_k

        # Routed experts (only the router-selected top-K fire per token).
        self.experts = nn.ModuleList(
            Expert(config.hidden_size, config.expert_intermediate_size)
            for _ in range(config.num_experts)
        )

        # Shared expert — always active for every token.
        self.shared_expert = SharedExpert(config.hidden_size, config.expert_intermediate_size)

        # Top-K router with loss-free balancing.
        self.router = OhengRouter(config)

        # Optional inter-element couplings (Generate boost / Overcome gate).
        self.generate_boost = GenerateBoost(config.hidden_size) if config.use_generate_boost else None
        self.overcome_gate = (
            OvercomeGate(config.hidden_size, config.overcome_gate_hidden)
            if config.use_overcome_gate
            else None
        )

    def forward(self, hidden_states: torch.Tensor,
                element_states: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
        """Run the MoE layer.

        Args:
            hidden_states: [B, L, D] input activations.
            element_states: Latest output per element group, keyed by
                element name; feeds the Generate/Overcome couplings.

        Returns:
            [B, L, D] combined expert output.
        """
        batch, seq, dim = hidden_states.shape
        tokens = hidden_states.view(-1, dim)  # [B*L, D]

        chosen, mix, _ = self.router(tokens)  # [B*L, K] each

        # Masked dispatch: for each routing slot and each expert, run the
        # expert on just its tokens and accumulate the weighted outputs.
        routed = torch.zeros_like(tokens)
        for slot in range(self.top_k):
            slot_expert = chosen[:, slot]                 # [B*L]
            slot_weight = mix[:, slot].unsqueeze(-1)      # [B*L, 1]
            for eid, expert in enumerate(self.experts):
                picked = slot_expert == eid
                if not picked.any():
                    continue
                routed[picked] += slot_weight[picked] * expert(tokens[picked])

        # Shared expert contributes to every token unconditionally.
        combined = (routed + self.shared_expert(tokens)).view(batch, seq, dim)

        # Apply the Oheng couplings when element states are available.
        if element_states is not None:
            if self.generate_boost is not None:
                feeder = element_states.get(GENERATE_REVERSE.get(self.element))
                combined = self.generate_boost(combined, feeder, self.element_idx)
            if self.overcome_gate is not None:
                opponent = element_states.get(OVERCOME_REVERSE.get(self.element))
                combined = self.overcome_gate(combined, opponent, self.element_idx)

        return combined
|
requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=2.4.0
|
| 2 |
+
safetensors>=0.4.0
|
| 3 |
+
gradio>=5.0.0
|
| 4 |
+
transformers>=4.45.0
|
| 5 |
+
huggingface-hub>=0.25.0
|