Upload 13 files
Browse files- 158b_train_sample/__init__.py +0 -0
- 158b_train_sample/drna_restore_train.py +281 -0
- 158b_train_sample/drna_swi_mount.py +447 -0
- drna/__init__.py +0 -0
- drna/drna.py +179 -0
- drna/drna_moe.py +226 -0
- drna/drna_restore.py +137 -0
- drna/drna_swi_restore.py +142 -0
- drna/drna_swi_triox.py +323 -0
- drna/drna_swiglu.py +192 -0
- drna/drna_triox.py +310 -0
- drna/drna_triox_log.txt +433 -0
- drna/drna_vlayer.py +173 -0
158b_train_sample/__init__.py
ADDED
|
File without changes
|
158b_train_sample/drna_restore_train.py
ADDED
|
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import sys
|
| 4 |
+
import signal
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torch.utils.checkpoint as checkpoint
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from datasets import load_dataset
|
| 11 |
+
from safetensors.torch import save_file, load_file
|
| 12 |
+
from torch.utils.data import IterableDataset, DataLoader
|
| 13 |
+
|
| 14 |
+
'''
|
| 15 |
+
こちらは学習コードです、これはモデルコードから機能等を参照して読みだして実行します
|
| 16 |
+
次元などの変更もこの学習コード側で行います、モデルコード側は基準として触らずに保持します
|
| 17 |
+
---
|
| 18 |
+
utf16をトークナイザ代替にする
|
| 19 |
+
世界中のさまざまな言語、絵文字、特殊記号、ソースコードのインデントまで、あらゆる文字を100%表現できる
|
| 20 |
+
(事前に決めた3万〜10万語の「辞書」にない言葉で[UNK]を生じることがなくなる)
|
| 21 |
+
1つのアーキテクチャで全メディアを等価に処理できる究極のマルチモーダルが理論上可能になる
|
| 22 |
+
Vocab Sizeを「65,536」(16bit境界)にジャストフィットさせることによるVRAM効率と計算効率を最大化します
|
| 23 |
+
ハードウェア的な特性(メモリのビット幅、アライメント、並列計算の仕組み)に合致しオーバーヘッドを生じません
|
| 24 |
+
(トークナイザやVAEなどの外付けはパディング処理などで空白を埋めるようなムダを生じる)
|
| 25 |
+
スマートフォンやエッジデバイスでも標準的なutf16なら確実に動作可能です
|
| 26 |
+
欠点は以下のみ、単純に少し長く学習するだけで解消します
|
| 27 |
+
コンテキスト長(トークン効率)の悪化、意味的抽象化(セマンティクス)をゼロから自力で学習しなければならない
|
| 28 |
+
'''
|
| 29 |
+
|
| 30 |
+
# モデル定義ファイルから必要なクラスをインポート
|
| 31 |
+
from drna_swi_mount import (
|
| 32 |
+
DRNA_Model,
|
| 33 |
+
DRNA_Block,
|
| 34 |
+
TernaryTrainingManager,
|
| 35 |
+
get_ternary_schedule
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# グローバル変数:Ctrl+C ハンドラから安全にアクセスするため
|
| 39 |
+
training_interrupted = False
|
| 40 |
+
current_step_global = 0
|
| 41 |
+
|
| 42 |
+
def sigint_handler(signum, frame):
|
| 43 |
+
'''Ctrl+C (SIGINT) をエレガントにキャッチするハンドラ'''
|
| 44 |
+
global training_interrupted
|
| 45 |
+
print("\n\n[!!] Ctrl+C (SIGINT) を検知しました。現在のステップで安全に緊急保存処理へ移行します...")
|
| 46 |
+
training_interrupted = True
|
| 47 |
+
|
| 48 |
+
# SIGINT ハンドラを登録
|
| 49 |
+
signal.signal(signal.SIGINT, sigint_handler)
|
| 50 |
+
|
| 51 |
+
# UTF16 トークナイザ
|
| 52 |
+
def encode_utf16(text: str, seq_len: int) -> torch.Tensor:
|
| 53 |
+
tokens = [ord(char) for char in text if ord(char) < 65536]
|
| 54 |
+
if len(tokens) < seq_len + 1:
|
| 55 |
+
tokens = tokens + [0] * (seq_len + 1 - len(tokens))
|
| 56 |
+
else:
|
| 57 |
+
tokens = tokens[:seq_len + 1]
|
| 58 |
+
return torch.tensor(tokens, dtype=torch.long)
|
| 59 |
+
|
| 60 |
+
# Webテキスト逐次読み込み(ストリーミング)
|
| 61 |
+
class StreamingWebTextDataset(IterableDataset):
|
| 62 |
+
def __init__(self, seq_len: int):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.seq_len = seq_len
|
| 65 |
+
print(">>> OpenWebText からデータをストリーミング接続中... (初回は数分かかる場合があります)")
|
| 66 |
+
self.dataset = load_dataset("openwebtext", split="train", streaming=True, trust_remote_code=True)
|
| 67 |
+
|
| 68 |
+
def __iter__(self):
|
| 69 |
+
buffer = ""
|
| 70 |
+
for item in self.dataset:
|
| 71 |
+
text = item["text"].strip()
|
| 72 |
+
if not text:
|
| 73 |
+
continue
|
| 74 |
+
buffer += text + " "
|
| 75 |
+
while len(buffer) >= self.seq_len + 1:
|
| 76 |
+
chunk = buffer[:self.seq_len + 1]
|
| 77 |
+
buffer = buffer[self.seq_len:]
|
| 78 |
+
token_tensor = encode_utf16(chunk, self.seq_len)
|
| 79 |
+
yield token_tensor[:-1], token_tensor[1:]
|
| 80 |
+
|
| 81 |
+
# 💎 Gradient Checkpointing マウント関数
|
| 82 |
+
def apply_gradient_checkpointing(model: nn.Module):
|
| 83 |
+
'''
|
| 84 |
+
フック構造確定後に呼び出し、DRNA_Block全体をGCで包む
|
| 85 |
+
これにより、再計算時にもフック(3値ブレンド)が正常に働き、w2の勾配消失を防ぐ
|
| 86 |
+
'''
|
| 87 |
+
print(">>> [成功] Gradient Checkpointing を各 DRNA_Block (レイヤー全体) に適用中...")
|
| 88 |
+
|
| 89 |
+
def make_checkpoint_forward(block_module):
|
| 90 |
+
original_forward = block_module.forward
|
| 91 |
+
|
| 92 |
+
def checkpoint_forward(*args, **kwargs):
|
| 93 |
+
x, cos, sin = args[0], args[1], args[2]
|
| 94 |
+
mask = kwargs.get('mask', None)
|
| 95 |
+
# GCでスキップされてしまう、そのフックを直接呼び出す
|
| 96 |
+
for name, module in block_module.named_modules():
|
| 97 |
+
if isinstance(module, nn.Linear):
|
| 98 |
+
# モジュールに登録されている forward_pre_hook を探して直接実行する
|
| 99 |
+
for hook in module._forward_pre_hooks.values():
|
| 100 |
+
# フックを偽装実行(module と inputs を渡せば、内部で正しく3値化���れる)
|
| 101 |
+
hook(module, (x,))
|
| 102 |
+
return checkpoint.checkpoint(
|
| 103 |
+
original_forward, x, cos, sin, mask,
|
| 104 |
+
use_reentrant=False
|
| 105 |
+
)
|
| 106 |
+
return checkpoint_forward
|
| 107 |
+
|
| 108 |
+
for name, module in model.named_modules():
|
| 109 |
+
if isinstance(module, DRNA_Block):
|
| 110 |
+
module.forward = make_checkpoint_forward(module)
|
| 111 |
+
print(f" -> {name} (DRNA_Block) の計算空間を GC で保護しました。")
|
| 112 |
+
|
| 113 |
+
# 対話型セットアップ関数
|
| 114 |
+
def select_model_setup(vocab_size, d_model, n_layers, n_heads, d_ff):
|
| 115 |
+
print(" D-RNA Trio 3値学習 セットアップモードの選択")
|
| 116 |
+
print("1: 新規モデルを初期化して作成し学習開始")
|
| 117 |
+
print("2: 既存の通常モデルや3値モデル、途中保存をマウントして学習開始")
|
| 118 |
+
choice = input("選択してください (1 or 2): ").strip()
|
| 119 |
+
|
| 120 |
+
model = DRNA_Model(
|
| 121 |
+
vocab_size=vocab_size, d_model=d_model,
|
| 122 |
+
n_layers=n_layers, n_heads=n_heads, d_ff=d_ff
|
| 123 |
+
).cuda()
|
| 124 |
+
|
| 125 |
+
if choice == "2":
|
| 126 |
+
checkpoint_path = input("読み込む重みファイルのパスを入力してください ").strip().strip("'\"")
|
| 127 |
+
if not os.path.exists(checkpoint_path):
|
| 128 |
+
print(f"エラー: パスが見つかりません。新規作成します")
|
| 129 |
+
return model
|
| 130 |
+
|
| 131 |
+
print("\n>>> [理論通りに復元開始] 通常/3値/途中保存の重みから連続勾配空間を復元中...")
|
| 132 |
+
_ = TernaryTrainingManager(model, warmup_steps=10, max_lambda=1.0)
|
| 133 |
+
checkpoint = load_file(checkpoint_path, device='cuda')
|
| 134 |
+
|
| 135 |
+
with torch.no_grad():
|
| 136 |
+
for name, module in model.named_modules():
|
| 137 |
+
if isinstance(module, nn.Linear):
|
| 138 |
+
if "embed" in name or "output_head" in name:
|
| 139 |
+
continue
|
| 140 |
+
|
| 141 |
+
# 統一された単一の「.weight」から実数聖域「raw_weight」へバトンを戻す
|
| 142 |
+
target_key = f"{name}.weight" if f"{name}.weight" in checkpoint else f"{name}.raw_weight"
|
| 143 |
+
if target_key in checkpoint:
|
| 144 |
+
module.raw_weight.copy_(checkpoint[target_key].cuda())
|
| 145 |
+
print(">>> マウントおよび連続勾配空間の再展開が完了しました \n")
|
| 146 |
+
model.is_already_mounted = True
|
| 147 |
+
else:
|
| 148 |
+
model.is_already_mounted = False
|
| 149 |
+
|
| 150 |
+
return model
|
| 151 |
+
|
| 152 |
+
def select_save_precision():
|
| 153 |
+
print(" エクスポート精度の選択")
|
| 154 |
+
print("1: bfloat16 (推奨)")
|
| 155 |
+
print("2: float16")
|
| 156 |
+
print("3: float32")
|
| 157 |
+
choice = input("選択してください (1, 2, 3): ").strip()
|
| 158 |
+
if choice == "2": return torch.float16, "fp16"
|
| 159 |
+
if choice == "3": return torch.float32, "fp32"
|
| 160 |
+
return torch.bfloat16, "bf16"
|
| 161 |
+
|
| 162 |
+
# メイン学習ループ
|
| 163 |
+
def main():
|
| 164 |
+
global training_interrupted, current_step_global
|
| 165 |
+
'''次元などの変更はこちらで行います'''
|
| 166 |
+
# パラメータ設定
|
| 167 |
+
vocab_size = 65536 # UTF-16 全域 (BMP)
|
| 168 |
+
d_model = 256
|
| 169 |
+
n_layers = 16 # GCにより、12GB VRAM環境でも16層でも軽快に回ります
|
| 170 |
+
n_heads = 8
|
| 171 |
+
d_ff = 1024 # (d_model * 4)
|
| 172 |
+
|
| 173 |
+
seq_len = 256
|
| 174 |
+
batch_size = 8
|
| 175 |
+
max_train_steps = 3001
|
| 176 |
+
|
| 177 |
+
# モデル初期化 / ロード
|
| 178 |
+
model = select_model_setup(vocab_size, d_model, n_layers, n_heads, d_ff)
|
| 179 |
+
|
| 180 |
+
# 3値誘導マネージャー外付け
|
| 181 |
+
warmup_steps = 300
|
| 182 |
+
max_lambda = 1.0
|
| 183 |
+
manager = TernaryTrainingManager(model, warmup_steps=warmup_steps, max_lambda=max_lambda)
|
| 184 |
+
|
| 185 |
+
# 🎯フック構造確定の「後」にGCをマウント
|
| 186 |
+
apply_gradient_checkpointing(model)
|
| 187 |
+
|
| 188 |
+
print(">>> 英語Webテキストをストリーミング中...")
|
| 189 |
+
dataset = StreamingWebTextDataset(seq_len=seq_len)
|
| 190 |
+
dataloader = DataLoader(dataset, batch_size=batch_size)
|
| 191 |
+
|
| 192 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
|
| 193 |
+
criterion = nn.CrossEntropyLoss()
|
| 194 |
+
|
| 195 |
+
print(f">>> 学習を開始します (最大 {max_train_steps} ステップ) ...")
|
| 196 |
+
print("※ 学習途中で安全に終了して実数保存したい場合は [Ctrl+C] を押してください")
|
| 197 |
+
model.train()
|
| 198 |
+
|
| 199 |
+
current_step = 0
|
| 200 |
+
save_mode = "crystallized"
|
| 201 |
+
timestamp = datetime.now().strftime("%y%m%d_%H%M")
|
| 202 |
+
|
| 203 |
+
for inputs, targets in dataloader:
|
| 204 |
+
if current_step >= max_train_steps or training_interrupted:
|
| 205 |
+
if training_interrupted:
|
| 206 |
+
save_mode = "emergency"
|
| 207 |
+
break
|
| 208 |
+
|
| 209 |
+
current_step_global = current_step
|
| 210 |
+
inputs = inputs.cuda()
|
| 211 |
+
targets = targets.cuda()
|
| 212 |
+
|
| 213 |
+
optimizer.zero_grad()
|
| 214 |
+
|
| 215 |
+
# 順伝播
|
| 216 |
+
outputs = model(inputs, pad_id=0)
|
| 217 |
+
|
| 218 |
+
# 損失計算
|
| 219 |
+
task_loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))
|
| 220 |
+
loss = manager.amend_loss(task_loss, step=current_step, total_steps=max_train_steps)
|
| 221 |
+
|
| 222 |
+
loss.backward()
|
| 223 |
+
optimizer.step()
|
| 224 |
+
|
| 225 |
+
if current_step % 100 == 0:
|
| 226 |
+
blend_ratio = get_ternary_schedule(current_step, max_train_steps, warmup_steps)
|
| 227 |
+
|
| 228 |
+
print(f"Step {current_step:3d}/{max_train_steps} | "
|
| 229 |
+
f"CE-Loss: {task_loss.item():.4f} | "
|
| 230 |
+
f"Trio Blend: {blend_ratio * 100:6.2f}%")
|
| 231 |
+
|
| 232 |
+
current_step += 1
|
| 233 |
+
|
| 234 |
+
# 💾 修正箇所: 2重保存バグを根絶するスマートなキー管理保存
|
| 235 |
+
# まずベースとなる全パラメータの辞書を取得(※この時点では PyTorch の仕様で両方入っている)
|
| 236 |
+
state_dict = model.state_dict()
|
| 237 |
+
output_state_dict = {}
|
| 238 |
+
|
| 239 |
+
if save_mode == "crystallized":
|
| 240 |
+
print("\n>>> [通常終了] 指定ステップに到達したため、3値重みの結晶化(export_ternary)を実行中...")
|
| 241 |
+
# モデル側のコアロジックを呼び出し、raw_weight と weight を完全に同じ3値で同期化
|
| 242 |
+
crystallized_model = manager.export_ternary()
|
| 243 |
+
state_dict = crystallized_model.state_dict()
|
| 244 |
+
|
| 245 |
+
# 保存フェーズ: 中身は同一なので、余計な '.raw_weight' キーだけを完全に排除
|
| 246 |
+
for k, v in state_dict.items():
|
| 247 |
+
if ".raw_weight" in k:
|
| 248 |
+
continue # 重複保存しない
|
| 249 |
+
output_state_dict[k] = v
|
| 250 |
+
|
| 251 |
+
output_filename = f"drna_pure_158_{timestamp}.safetensors"
|
| 252 |
+
else:
|
| 253 |
+
print(f"\n>>> [緊急停止] Step {current_step_global} で中断されました。結晶化はせず、実数勾配空間を抽出中...")
|
| 254 |
+
|
| 255 |
+
# 保存フェーズ: フックの残骸 '.weight' を捨て、実数本体 '.raw_weight' の名前を '.weight' にリネームして詰め替える
|
| 256 |
+
for k, v in state_dict.items():
|
| 257 |
+
if ".weight" in k and not any(x in k for x in ["embed", "output_head", "final_norm"]):
|
| 258 |
+
# 3値化対象レイヤーの通常の '.weight' キー(バッファの残骸)は無視
|
| 259 |
+
continue
|
| 260 |
+
if ".raw_weight" in k:
|
| 261 |
+
# 実数の本体である '.raw_weight' を通常の '.weight' という名前にリネームして昇格
|
| 262 |
+
new_key = k.replace(".raw_weight", ".weight")
|
| 263 |
+
output_state_dict[new_key] = v
|
| 264 |
+
else:
|
| 265 |
+
# それ以外の全レイヤー(Embedding、Norm等)はそのまま保持
|
| 266 |
+
output_state_dict[k] = v
|
| 267 |
+
|
| 268 |
+
output_filename = f"drna_int_step{current_step_global}_{timestamp}_realw.safetensors"
|
| 269 |
+
|
| 270 |
+
# ユーザー指定の精度に一括キャストして safetensors でクリーンに書き出し
|
| 271 |
+
target_dtype, dtype_str = select_save_precision()
|
| 272 |
+
final_state_dict = {k: v.to(target_dtype).cpu() for k, v in output_state_dict.items()}
|
| 273 |
+
|
| 274 |
+
output_path = os.path.join(os.path.dirname(__file__ if __file__ else "."), f"{dtype_str}_{output_filename}")
|
| 275 |
+
|
| 276 |
+
save_file(final_state_dict, output_path)
|
| 277 |
+
print(f" ==> [完了] 2重保存を完全に排除したクリーンな1倍サイズモデルを保存しました: {output_path}")
|
| 278 |
+
print("プロセスを正常に終了します")
|
| 279 |
+
|
| 280 |
+
if __name__ == "__main__":
|
| 281 |
+
main()
|
158b_train_sample/drna_swi_mount.py
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
'''
|
| 8 |
+
このコードは純粋3値学習を行うためのモデルコード(定義)です、通常はこちらを参照し呼び出すだけでOKです
|
| 9 |
+
既存の重みをもつモデルを読み込んで3値へ転換も可能です(モデルの次元などは学習コード側で適切に合わせてください)
|
| 10 |
+
---
|
| 11 |
+
# チェックポイント読み込み(D-RNA-Trio の学習済みモデル) / fp8/16/32 相当の連続重みの復元
|
| 12 |
+
2〜3 レイヤ:fp8相当、4〜6 レイヤ:bf16相当、7+ レイヤ:fp16/32 相当
|
| 13 |
+
固有位相(K-energy × RoPE 周波数成分)を重ね合わせ+正規化(二重らせんの共鳴収縮により歪みを抑制)
|
| 14 |
+
特に K-for-phase の「逐次カスケード的変化」が位相シフトとなり自然な凸凹パターンを生成
|
| 15 |
+
---
|
| 16 |
+
D‑RNA: Dual‑Helix Resonance Neural Architecture (DRNA) Pre-Norm・Kv-RoPE 版
|
| 17 |
+
仕様:Pre-Norm(RMSNorm)、GELU(Activation)、Kv-RoPE(head_dim)、mask(padding + causal)
|
| 18 |
+
Transformerの全接続性を継承しつつ、二重らせん(Dual-Helix)構造による
|
| 19 |
+
「共鳴収縮」(Resonant Contraction)を物理的に再現したニューラルアーキテクチャです
|
| 20 |
+
D-RNA の位相設計と Trio Induction system により3値学習を STE に頼らず安定的に行えます
|
| 21 |
+
---
|
| 22 |
+
これは STE で機能しない optimiser などを3値学習へ活用できるようになります
|
| 23 |
+
将来的に3値モデルを位相差の重ねによる疑似重みをつくり、これを学習対象にして3値学習もおこなえるはずです
|
| 24 |
+
つまり学習元も3値モデルにできるはずです、推論も学習も3値で済むようになる最初の1歩です
|
| 25 |
+
'''
|
| 26 |
+
|
| 27 |
+
#---restore mount START---
|
| 28 |
+
class DRNAWeightRestorer:
|
| 29 |
+
'''
|
| 30 |
+
D-RNA の固有位相を利用した fp8/16/32 重み復元器
|
| 31 |
+
チェックポイント読み込み → 位相抽出 → 重ね合わせ → 正規化 で連続分布を生成
|
| 32 |
+
'''
|
| 33 |
+
|
| 34 |
+
def __init__(self, d_model: int = 256, num_layers: int = 16):
|
| 35 |
+
self.d_model = d_model
|
| 36 |
+
self.num_layers = num_layers
|
| 37 |
+
|
| 38 |
+
# チェックポイントから -/0/+ 重みを読み出す
|
| 39 |
+
def load_ternary_weights(self, checkpoint_path: str) -> List[torch.Tensor]:
|
| 40 |
+
'''ロードされたチェックポイントから各レイヤの重みを抽出する'''
|
| 41 |
+
checkpoint = torch.load(checkpoint_path, map_location='cuda')
|
| 42 |
+
|
| 43 |
+
layer_weights = []
|
| 44 |
+
for i in range(self.num_layers):
|
| 45 |
+
# QKV レイヤー(norm1 ブランチ)
|
| 46 |
+
qkv_weight = checkpoint[f'layers.{i}.qkv.weight']
|
| 47 |
+
|
| 48 |
+
# MLP 入力/出力レイヤー(norm2 ブランチ)
|
| 49 |
+
mlp_0_weight = checkpoint[f'layers.{i}.mlp.0.weight']
|
| 50 |
+
mlp_3_weight = checkpoint[f'layers.{i}.mlp.3.weight']
|
| 51 |
+
|
| 52 |
+
layer_weights.extend([qkv_weight, mlp_0_weight, mlp_3_weight])
|
| 53 |
+
|
| 54 |
+
return layer_weights
|
| 55 |
+
|
| 56 |
+
# K-for-phase のカスケード効果から各レイヤの固有位相を抽出する
|
| 57 |
+
def extract_phase_offsets(self, max_seq_len: int = 256) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
| 58 |
+
'''DRNA の固有位相(K-energy × RoPE 周波数成分)を計算し返す'''
|
| 59 |
+
|
| 60 |
+
num_layer_groups = self.num_layers * 3
|
| 61 |
+
|
| 62 |
+
phase_data_list = []
|
| 63 |
+
base_energy = torch.arange(128, dtype=torch.float32, device='cuda') / self.d_model
|
| 64 |
+
|
| 65 |
+
for i in range(num_layer_groups):
|
| 66 |
+
# K-energy のカスケード効果:レイヤごとに滑らかにシフト(K-for-phase 同様に)
|
| 67 |
+
k_energy = base_energy * (i % self.num_layers + 1)
|
| 68 |
+
|
| 69 |
+
# DRNA の動的位相変調: tanh(k_energy) × π
|
| 70 |
+
rt_phase = math.tanh(k_energy) * math.pi
|
| 71 |
+
|
| 72 |
+
# RoPE 的な周波数成分の抽出(各レイヤ固有の周波数パターン)
|
| 73 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, 128, 2).float() / self.d_model))
|
| 74 |
+
|
| 75 |
+
freqs = torch.einsum("i,j->ij",
|
| 76 |
+
torch.arange(max_seq_len, device='cuda'),
|
| 77 |
+
inv_freq)
|
| 78 |
+
|
| 79 |
+
# cos / sin の位相シフトを生成
|
| 80 |
+
cos_shift = torch.cos(freqs * rt_phase.unsqueeze(-1))
|
| 81 |
+
sin_shift = torch.sin(freqs * rt_phase.unsqueeze(-1))
|
| 82 |
+
|
| 83 |
+
phase_data_list.append((cos_shift, sin_shift))
|
| 84 |
+
|
| 85 |
+
return phase_data_list
|
| 86 |
+
|
| 87 |
+
# 固有位相を用いて重ね合わせ(RoPE 的な位相変調を適用)
|
| 88 |
+
def superimpose_weights(
|
| 89 |
+
self,
|
| 90 |
+
layer_weights: List[torch.Tensor],
|
| 91 |
+
phase_offsets: List[Tuple[torch.Tensor, torch.Tensor]]
|
| 92 |
+
) -> torch.Tensor:
|
| 93 |
+
'''各レイヤの -/0/+ 重みに固有位相を適用し重ね合わせる'''
|
| 94 |
+
|
| 95 |
+
reconstructed = torch.zeros_like(layer_weights[0])
|
| 96 |
+
|
| 97 |
+
for i, w_i in enumerate(layer_weights):
|
| 98 |
+
cos_shift, sin_shift = phase_offsets[i]
|
| 99 |
+
|
| 100 |
+
# DRNA の K-for-phase 同様の位相変調: d_cos / d_sin を用いる重ね合わせ
|
| 101 |
+
contribution = w_i * (cos_shift - rotate_half(w_i) * sin_shift)
|
| 102 |
+
reconstructed += contribution
|
| 103 |
+
|
| 104 |
+
return reconstructed
|
| 105 |
+
|
| 106 |
+
# RMSNorm 的な正規化(3 値歪みのリセット)
|
| 107 |
+
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
| 108 |
+
'''重ね合わせ後の分布を RMSNorm で安定化する'''
|
| 109 |
+
|
| 110 |
+
mean = x.mean(-1, keepdim=True)
|
| 111 |
+
var = x.var(-1, keepdim=True, unbiased=False)
|
| 112 |
+
|
| 113 |
+
# 中心化・標準化(3 値歪みのリセット)
|
| 114 |
+
return (x - mean) * torch.rsqrt(var + 1e-8)
|
| 115 |
+
|
| 116 |
+
# 完全復元パイプライン
|
| 117 |
+
def restore_from_checkpoint(
|
| 118 |
+
self,
|
| 119 |
+
checkpoint_path: str,
|
| 120 |
+
max_seq_len: int = 256
|
| 121 |
+
) -> torch.Tensor:
|
| 122 |
+
'''チェックポイントを読み込み、fp8/16 相当の連続重みを生成する'''
|
| 123 |
+
|
| 124 |
+
# -1.0/0.0/1.0 重みの読み出し
|
| 125 |
+
layer_weights = self.load_ternary_weights(checkpoint_path)
|
| 126 |
+
|
| 127 |
+
# 固有位相(K-energy × RoPE)の抽出
|
| 128 |
+
phase_offsets = self.extract_phase_offsets(max_seq_len)
|
| 129 |
+
|
| 130 |
+
# 重ね合わせ(RoPE 的な位相変調を適用)
|
| 131 |
+
continuous_weight = self.superimpose_weights(
|
| 132 |
+
layer_weights,
|
| 133 |
+
phase_offsets)
|
| 134 |
+
|
| 135 |
+
# RMSNorm 的な正規化
|
| 136 |
+
normalized_weight = self.normalize(continuous_weight)
|
| 137 |
+
|
| 138 |
+
return normalized_weight
|
| 139 |
+
|
| 140 |
+
# 補助関数: DRNA rotate_half 再現
|
| 141 |
+
def rotate_half(x):
|
| 142 |
+
'''D-RNA の K-for-phase 同様の回転操作(位相変調の右辺項)'''
|
| 143 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 144 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 145 |
+
#---restore mount END---
|
| 146 |
+
|
| 147 |
+
#---drna-swi-triox START---
|
| 148 |
+
# 3値誘導制御コア(モデルの書き換え、3値ブレンド、ペナルティ計算、結晶化)
|
| 149 |
+
class TernaryTrainingManager:
|
| 150 |
+
'''
|
| 151 |
+
D-RNAのコードに触れることなく、3値誘導の全ライフサイクルを統括する抽象化マネージャー
|
| 152 |
+
'''
|
| 153 |
+
def __init__(self, model, warmup_steps=100, max_lambda=1.0):
|
| 154 |
+
self.model = model
|
| 155 |
+
self.warmup_steps = warmup_steps
|
| 156 |
+
self.max_lambda = max_lambda
|
| 157 |
+
self.current_step = 0
|
| 158 |
+
self.total_steps = 0
|
| 159 |
+
|
| 160 |
+
# 内部で利用するステッププロバイダー関数
|
| 161 |
+
def step_provider():
|
| 162 |
+
return self.current_step, self.total_steps
|
| 163 |
+
|
| 164 |
+
# モデル内の全2次元重みにフックを外付け
|
| 165 |
+
for name, module in self.model.named_modules():
|
| 166 |
+
if isinstance(module, nn.Linear):
|
| 167 |
+
# ★ 拡張性:LLMの生命線である「embed」と「output_head」は絶対に3値化しない
|
| 168 |
+
if "embed" in name or "output_head" in name:
|
| 169 |
+
continue
|
| 170 |
+
if not hasattr(module, "raw_weight"):
|
| 171 |
+
# 元の実数重みを聖域(raw_weight)に退避
|
| 172 |
+
module.register_parameter("raw_weight", nn.Parameter(module.weight.data.clone()))
|
| 173 |
+
delattr(module, "weight")
|
| 174 |
+
module.register_buffer("weight", module.raw_weight.data.clone())
|
| 175 |
+
module.register_forward_pre_hook(TernaryWeightHook(step_provider, self.warmup_steps))
|
| 176 |
+
|
| 177 |
+
def amend_loss(self, task_loss, step, total_steps):
|
| 178 |
+
'''
|
| 179 |
+
【ループ内抽象化用】 メインのタスク損失(CrossEntropy等)を受け取り、
|
| 180 |
+
現在のステップに応じた3値結晶化ペナルティを自動計算して合算した損失を返す
|
| 181 |
+
'''
|
| 182 |
+
self.current_step = step
|
| 183 |
+
self.total_steps = total_steps
|
| 184 |
+
|
| 185 |
+
blend_ratio = get_ternary_schedule(step, total_steps, self.warmup_steps)
|
| 186 |
+
current_lambda = blend_ratio * self.max_lambda
|
| 187 |
+
|
| 188 |
+
if current_lambda == 0.0:
|
| 189 |
+
return task_loss
|
| 190 |
+
|
| 191 |
+
# 3値誘導トリプレット・ペナルティ自動計算
|
| 192 |
+
ternary_penalty = 0.0
|
| 193 |
+
for name, param in self.model.named_parameters():
|
| 194 |
+
# 2次元以上の重み(LinearやEmbeddingの実数聖域)のみを対象とする
|
| 195 |
+
if "raw_weight" in name and param.dim() >= 2:
|
| 196 |
+
# W * (W - 1) * (W + 1) = W^3 - W を0に近づける(-1, 0, 1への収束強制力)
|
| 197 |
+
ternary_penalty += torch.mean(param * (param - 1.0) * (param + 1.0)) ** 2
|
| 198 |
+
|
| 199 |
+
return task_loss + current_lambda * ternary_penalty
|
| 200 |
+
|
| 201 |
+
def export_ternary(self):
|
| 202 |
+
'''学習終了後、モデルの全2次元重みを完全な[-1.0, 0.0, 1.0]へ固定(結晶化)する'''
|
| 203 |
+
with torch.no_grad():
|
| 204 |
+
for name, module in self.model.named_modules():
|
| 205 |
+
if isinstance(module, nn.Linear):
|
| 206 |
+
# 拡張性:LLMの生命線である「embed」と「output_head」は絶対に3値化しない
|
| 207 |
+
if "embed" in name or "output_head" in name:
|
| 208 |
+
continue
|
| 209 |
+
# raw_weight があればそっちを本体として使う
|
| 210 |
+
if hasattr(module, "raw_weight"):
|
| 211 |
+
param = module.raw_weight
|
| 212 |
+
else:
|
| 213 |
+
param = module.weight
|
| 214 |
+
# 学習中と同じ写像:tanh(3w) で soft 3値ターゲットを作る
|
| 215 |
+
soft = torch.tanh(param * 3.0)
|
| 216 |
+
# soft を hard 3値に潰す
|
| 217 |
+
ternary = torch.zeros_like(soft)
|
| 218 |
+
ternary[soft > 0.08] = 1.0
|
| 219 |
+
ternary[soft < -0.08] = -1.0
|
| 220 |
+
# 実数パラメータ自体を3値に上書きして完全固定
|
| 221 |
+
if hasattr(module, "raw_weight"):
|
| 222 |
+
module.raw_weight.copy_(ternary)
|
| 223 |
+
module.weight.copy_(ternary)
|
| 224 |
+
|
| 225 |
+
return self.model
|
| 226 |
+
|
| 227 |
+
# 3値誘導外付けフックシステム(2次元重みを3値ブレンド)
|
| 228 |
+
def get_ternary_schedule(step, total_steps, warmup_steps=100):
|
| 229 |
+
'''逆転コサインアニーリングスケジューラ'''
|
| 230 |
+
if step < warmup_steps:
|
| 231 |
+
return 0.0
|
| 232 |
+
anneal_steps = total_steps - warmup_steps
|
| 233 |
+
progress = (step - warmup_steps) / anneal_steps
|
| 234 |
+
return 1.0 - (0.5 * (1.0 + math.cos(progress * math.pi)))
|
| 235 |
+
|
| 236 |
+
def get_soft_ternary_weight(param, step, total_steps, warmup_steps=100):
|
| 237 |
+
'''勾配直通バイパス型 3値ブレンド関数'''
|
| 238 |
+
if param.dim() < 2: # 1次元パラメータは保護
|
| 239 |
+
return param
|
| 240 |
+
blend_ratio = get_ternary_schedule(step, total_steps, warmup_steps)
|
| 241 |
+
if blend_ratio == 0.0:
|
| 242 |
+
return param
|
| 243 |
+
with torch.no_grad():
|
| 244 |
+
ternary_target = torch.tanh(param * 3.0)
|
| 245 |
+
# 【最重要】生の勾配を裏に直通させるバイパス構造
|
| 246 |
+
return param + blend_ratio * (ternary_target - param)
|
| 247 |
+
|
| 248 |
+
class TernaryWeightHook:
|
| 249 |
+
def __init__(self, step_provider, warmup_steps=100):
|
| 250 |
+
self.step_provider = step_provider
|
| 251 |
+
self.warmup_steps = warmup_steps
|
| 252 |
+
|
| 253 |
+
def __call__(self, module, inputs):
|
| 254 |
+
step, total_steps = self.step_provider()
|
| 255 |
+
if step is not None and total_steps is not None:
|
| 256 |
+
# バックアップした実数重み(raw_weight)から疑似3値重みを計算し一時的に上書き
|
| 257 |
+
module.weight.data = get_soft_ternary_weight(module.raw_weight, step, total_steps, self.warmup_steps)
|
| 258 |
+
|
| 259 |
+
def apply_trio_induction(model, step_provider, warmup_steps=100):
|
| 260 |
+
'''モデル側ではなく外側から3値化プラグインを刺す関数'''
|
| 261 |
+
for name, module in model.named_modules():
|
| 262 |
+
if isinstance(module, nn.Linear):
|
| 263 |
+
# ★ 拡張性:LLMの生命線である「embed」と「output_head」は絶対に3値化しない
|
| 264 |
+
if "embed" in name or "output_head" in name:
|
| 265 |
+
continue
|
| 266 |
+
if not hasattr(module, "raw_weight"):
|
| 267 |
+
# 元の実数重みを聖域(raw_weight)に退避
|
| 268 |
+
module.register_parameter("raw_weight", nn.Parameter(module.weight.data.clone()))
|
| 269 |
+
delattr(module, "weight")
|
| 270 |
+
module.register_buffer("weight", module.raw_weight.data.clone())
|
| 271 |
+
module.register_forward_pre_hook(TernaryWeightHook(step_provider, warmup_steps))
|
| 272 |
+
|
| 273 |
+
# D-RNA 3値化のパニックを100%いなすコンサルタント(RMSNorm 修正)
|
| 274 |
+
class RMSNorm(nn.Module):
|
| 275 |
+
'''【修正版】3値化の歪みをリセットする中心化・標準化型防波堤'''
|
| 276 |
+
def __init__(self, d_model, eps=1e-8):
|
| 277 |
+
super().__init__()
|
| 278 |
+
self.eps = eps
|
| 279 |
+
# 3値化の歪みによってズレた「音量の軸」を、フル精度の実数で強制的に中心に戻すバイアス
|
| 280 |
+
self.bias = nn.Parameter(torch.zeros(d_model))
|
| 281 |
+
# 3値化のせいで極端にインフレ・デフレした次元ごとの音量を、個別にジャストフィットさせる実数スケール
|
| 282 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
| 283 |
+
|
| 284 |
+
def forward(self, x):
|
| 285 |
+
# 各次元ごとの平均と分散を計算
|
| 286 |
+
mean = x.mean(-1, keepdim=True)
|
| 287 |
+
var = x.var(-1, keepdim=True, unbiased=False)
|
| 288 |
+
|
| 289 |
+
# 3値化による「歪み・偏り」を、ここで完全にリセット(標準化)する
|
| 290 |
+
x_normed = (x - mean) * torch.rsqrt(var + self.eps)
|
| 291 |
+
|
| 292 |
+
# リセットされた綺麗な状態に対して、フル精度のオプティマイザが最適なスケールとオフセットを施す
|
| 293 |
+
return self.weight * x_normed + self.bias
|
| 294 |
+
|
| 295 |
+
class DRNA_RoPE(nn.Module):
|
| 296 |
+
'''二重らせんの位相を決定する回転場'''
|
| 297 |
+
def __init__(self, head_dim, base=10000):
|
| 298 |
+
super().__init__()
|
| 299 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
| 300 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 301 |
+
|
| 302 |
+
def forward(self, x, seq_len):
|
| 303 |
+
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 304 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 305 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 306 |
+
return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
|
| 307 |
+
|
| 308 |
+
def apply_drna_rope(q, k, cos, sin):
|
| 309 |
+
'''Kによる動的位相変調済み cos/sin を受け取る'''
|
| 310 |
+
def rotate_half(x):
|
| 311 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 312 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 313 |
+
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 314 |
+
|
| 315 |
+
class DRNA_Block(nn.Module):
|
| 316 |
+
'''DRNA共鳴ブロック:安定性を高めたPre-Norm直列共鳴構造'''
|
| 317 |
+
def __init__(self, d_model, n_heads, head_dim, d_ff=None, dropout=0.1):
|
| 318 |
+
super().__init__()
|
| 319 |
+
self.n_heads = n_heads
|
| 320 |
+
self.head_dim = head_dim
|
| 321 |
+
|
| 322 |
+
# らせんA:回想系 (Attention)
|
| 323 |
+
self.norm1 = RMSNorm(d_model) # 演算の前に配置
|
| 324 |
+
self.qkv = nn.Linear(d_model, d_model * 3)
|
| 325 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 326 |
+
|
| 327 |
+
# らせんB:記憶系 (MLP)
|
| 328 |
+
self.norm2 = RMSNorm(d_model) # 演算の前に配置
|
| 329 |
+
|
| 330 |
+
# SwiGLUの定石:パラメータ量をGELU版(×4倍)と合わせるため、約2.67倍(8/3)にする
|
| 331 |
+
if d_ff is None:
|
| 332 |
+
d_ff = int(2 * (d_model * 4) / 3)
|
| 333 |
+
|
| 334 |
+
# SwiGLUは入力に対して2つのLinear(Gate・Up)を並列に走らせます
|
| 335 |
+
self.w1 = nn.Linear(d_model, d_ff, bias=False) # ゲート用
|
| 336 |
+
self.w3 = nn.Linear(d_model, d_ff, bias=False) # 値用(アッププロジェクション)
|
| 337 |
+
self.w2 = nn.Linear(d_ff, d_model, bias=False) # ダウンプロジェクション
|
| 338 |
+
|
| 339 |
+
self.dropout = nn.Dropout(dropout)
|
| 340 |
+
|
| 341 |
+
def forward(self, x, cos, sin, mask=None):
|
| 342 |
+
b, s, d = x.shape
|
| 343 |
+
|
| 344 |
+
# 共通の残差(ベースとなる螺旋の軸)
|
| 345 |
+
residual = x
|
| 346 |
+
|
| 347 |
+
# らせんA (Attention) 並列方式
|
| 348 |
+
x_norm1 = self.norm1(x)
|
| 349 |
+
|
| 350 |
+
# QKV生成 (3倍のまま)
|
| 351 |
+
# ※ self.qkv(x_norm) が x_norm1 になっているか確認してください
|
| 352 |
+
qkv = self.qkv(x_norm1).reshape(b, s, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 353 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 354 |
+
|
| 355 |
+
# K因果的動的回転:一つ前の単語の K が今の単語の座標を決める
|
| 356 |
+
# 自分の情報で自分を回さないよう、Kを1つ未来にシフトさせる
|
| 357 |
+
# これにより「赤い(K)」が「猫(Q,K)」の位相を決定する構造になる
|
| 358 |
+
k_for_phase = torch.cat([torch.zeros_like(k[:, :, :1, :]), k[:, :, :-1, :]], dim=2)
|
| 359 |
+
|
| 360 |
+
# Kのエネルギーを位相(回転角)に変換
|
| 361 |
+
rt_phase = torch.tanh(k_for_phase) * math.pi
|
| 362 |
+
|
| 363 |
+
# 静的RoPE (cos, sin) を動的位相 (rt_phase) で加法定理により変調
|
| 364 |
+
# ※ apply_drna_rope に rt_phase を渡せるように関数側を調整するか、
|
| 365 |
+
# ここで dynamic_cos / sin を作って渡します
|
| 366 |
+
d_cos = (cos * torch.cos(rt_phase)) - (sin * torch.sin(rt_phase))
|
| 367 |
+
d_sin = (sin * torch.cos(rt_phase)) + (cos * torch.sin(rt_phase))
|
| 368 |
+
|
| 369 |
+
# 2重らせんをつくる (変調された座標で回転)
|
| 370 |
+
q, k = apply_drna_rope(q, k, d_cos, d_sin)
|
| 371 |
+
|
| 372 |
+
# Attention計算
|
| 373 |
+
attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 374 |
+
if mask is not None:
|
| 375 |
+
attn = attn + mask
|
| 376 |
+
|
| 377 |
+
attn = F.softmax(attn, dim=-1)
|
| 378 |
+
a_out_raw = (attn @ v).transpose(1, 2).reshape(b, s, d)
|
| 379 |
+
a_out = self.out_proj(a_out_raw)
|
| 380 |
+
|
| 381 |
+
# らせんB (MLP)
|
| 382 |
+
x_norm2 = self.norm2(x)
|
| 383 |
+
|
| 384 |
+
# SwiGLUのコア数式: Swish(xW1) * xW3(F.silu:PyTorchのSwish関数)
|
| 385 |
+
gate = F.silu(self.w1(x_norm2))
|
| 386 |
+
current_value = self.w3(x_norm2)
|
| 387 |
+
|
| 388 |
+
# 要素積(共鳴収縮の表現としても非常に相性が良いです)
|
| 389 |
+
swiglu_out = gate * current_value
|
| 390 |
+
|
| 391 |
+
# 最終投影
|
| 392 |
+
m_out = self.w2(swiglu_out)
|
| 393 |
+
|
| 394 |
+
# 並列方式
|
| 395 |
+
x = residual + self.dropout(a_out) + self.dropout(m_out)
|
| 396 |
+
|
| 397 |
+
return x
|
| 398 |
+
|
| 399 |
+
class DRNA_Model(nn.Module):
|
| 400 |
+
'''汎用 DRNA モデルコンテナ(安定化 Pre-Norm 版)'''
|
| 401 |
+
def __init__(self, vocab_size, d_model=256, n_layers=16, n_heads=8, d_ff=1024):
|
| 402 |
+
super().__init__()
|
| 403 |
+
self.embed = nn.Embedding(vocab_size, d_model)
|
| 404 |
+
self.head_dim = d_model // n_heads
|
| 405 |
+
self.rope = DRNA_RoPE(self.head_dim)
|
| 406 |
+
|
| 407 |
+
self.layers = nn.ModuleList([
|
| 408 |
+
DRNA_Block(d_model, n_heads, self.head_dim, d_ff) for _ in range(n_layers)
|
| 409 |
+
])
|
| 410 |
+
|
| 411 |
+
# Pre-Norm構造の場合、最終レイヤーの後に全体のNormを置くのが一般的
|
| 412 |
+
self.final_norm = RMSNorm(d_model)
|
| 413 |
+
self.output_head = nn.Linear(d_model, vocab_size)
|
| 414 |
+
|
| 415 |
+
def forward(self, x, mask=None, pad_id=None):
|
| 416 |
+
b, s = x.shape
|
| 417 |
+
device = x.device
|
| 418 |
+
inputs = x
|
| 419 |
+
x = self.embed(x)
|
| 420 |
+
|
| 421 |
+
if mask is None or mask.sum() == 0:
|
| 422 |
+
# pad_id が整数(int/long)として有効な場合のみ pad_mask を作成
|
| 423 |
+
# 退避させた inputs を使ってパディングを判定し因果マスクを準備
|
| 424 |
+
p_id = pad_id.item() if isinstance(pad_id, torch.Tensor) else pad_id
|
| 425 |
+
pad_mask = (inputs != p_id).unsqueeze(1).unsqueeze(2) if isinstance(p_id, (int, float)) else torch.ones((1, 1, 1, s), device=device, dtype=torch.bool)
|
| 426 |
+
causal = torch.triu(torch.ones(s, s, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
|
| 427 |
+
|
| 428 |
+
# 環境(fp16/32)に応じた最小値を安全に自動計算
|
| 429 |
+
inf_value = torch.finfo(x.dtype).min if x.dtype != torch.float16 else -65500.0
|
| 430 |
+
|
| 431 |
+
# ゼロ初期化テンソルに無効領域をインプレースで直接埋める(カンニングの完全遮断)
|
| 432 |
+
mask = torch.zeros((b, 1, s, s), device=device, dtype=x.dtype).masked_fill_(causal | (~pad_mask), inf_value)
|
| 433 |
+
|
| 434 |
+
cos, sin = self.rope(x, x.size(1))
|
| 435 |
+
|
| 436 |
+
for layer in self.layers:
|
| 437 |
+
x = layer(x, cos, sin, mask=mask)
|
| 438 |
+
|
| 439 |
+
x = self.final_norm(x) # 出力前の最終同期
|
| 440 |
+
return self.output_head(x)
|
| 441 |
+
#---drna-swi-triox END---
|
| 442 |
+
|
| 443 |
+
'''
|
| 444 |
+
汎用型 D-RNA (Pre-Norm) License: Apache License 2.0 https://github.com/muooon/DRNA
|
| 445 |
+
Attention is all you need_started, Resonance is all you need_endure,
|
| 446 |
+
Neocognitron ― Transformer ― D‑RNA Dream Resonance Never Adjourns — it goes on...
|
| 447 |
+
'''
|
drna/__init__.py
ADDED
|
File without changes
|
drna/drna.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
'''
|
| 7 |
+
D‑RNA: Dual‑Helix Resonance Neural Architecture (DRNA) Pre-Norm・Kv-RoPE 版
|
| 8 |
+
仕様:Pre-Norm(RMSNorm)、GELU(Activation)、Kv-RoPE(head_dim)、mask(padding + causal)
|
| 9 |
+
Transformerの全接続性を継承しつつ、二重らせん(Dual-Helix)構造による
|
| 10 |
+
「共鳴収縮」(Resonant Contraction)を物理的に再現したニューラルアーキテクチャです
|
| 11 |
+
螺旋の同期:Attention(文脈の回想)とMLP(知識の定着)を並列配置し RoPE で情報を同期
|
| 12 |
+
位相の保持:RoPE(Phase Field)を回転場として利用し、安定した相対位置を保ち早期収束を両立
|
| 13 |
+
高密度圧縮:Pre-Norm により、各らせんを安定的に収縮させ、全結合により記憶を定着させる
|
| 14 |
+
'''
|
| 15 |
+
|
| 16 |
+
class RMSNorm(nn.Module):
|
| 17 |
+
def __init__(self, d_model, eps=1e-6):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.eps = eps
|
| 20 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
# 2乗平均の平方根で割る(平均を減算しない・中心化をしない)
|
| 24 |
+
norm = x.pow(2).mean(-1, keepdim=True)
|
| 25 |
+
x_normed = x * torch.rsqrt(norm + self.eps)
|
| 26 |
+
return self.weight * x_normed
|
| 27 |
+
|
| 28 |
+
class DRNA_RoPE(nn.Module):
|
| 29 |
+
"""二重らせんの位相を決定する回転場"""
|
| 30 |
+
def __init__(self, head_dim, base=10000):
|
| 31 |
+
super().__init__()
|
| 32 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
| 33 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 34 |
+
|
| 35 |
+
def forward(self, x, seq_len):
|
| 36 |
+
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 37 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 38 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 39 |
+
return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
|
| 40 |
+
|
| 41 |
+
def apply_drna_rope(q, k, cos, sin):
|
| 42 |
+
"""Kによる動的位相変調済み cos/sin を受け取る"""
|
| 43 |
+
def rotate_half(x):
|
| 44 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 45 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 46 |
+
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 47 |
+
|
| 48 |
+
class DRNA_Block(nn.Module):
|
| 49 |
+
"""DRNA共鳴ブロック:安定性を高めたPre-Norm直列共鳴構造"""
|
| 50 |
+
def __init__(self, d_model, n_heads, head_dim, d_ff=None, dropout=0.1):
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.n_heads = n_heads
|
| 53 |
+
self.head_dim = head_dim
|
| 54 |
+
|
| 55 |
+
# らせんA: 回想系 (Attention)
|
| 56 |
+
self.norm1 = RMSNorm(d_model) # 演算の前に配置
|
| 57 |
+
self.qkv = nn.Linear(d_model, d_model * 3)
|
| 58 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 59 |
+
|
| 60 |
+
# らせんB: 記憶系 (MLP)
|
| 61 |
+
self.norm2 = RMSNorm(d_model) # 演算の前に配置
|
| 62 |
+
d_ff = d_ff or d_model * 4
|
| 63 |
+
self.mlp = nn.Sequential(
|
| 64 |
+
nn.Linear(d_model, d_ff),
|
| 65 |
+
nn.GELU(), # VRAM抑制は ReLU (別レイヤの干渉で0勾配にならない「可能性」あり)
|
| 66 |
+
nn.Dropout(dropout),
|
| 67 |
+
nn.Linear(d_ff, d_model)
|
| 68 |
+
)
|
| 69 |
+
self.dropout = nn.Dropout(dropout)
|
| 70 |
+
|
| 71 |
+
def forward(self, x, cos, sin, mask=None):
|
| 72 |
+
b, s, d = x.shape
|
| 73 |
+
|
| 74 |
+
# 1. 共通の残差(ベースとなる螺旋の軸)
|
| 75 |
+
residual = x
|
| 76 |
+
|
| 77 |
+
# 2. らせんA (Attention) 並列方式
|
| 78 |
+
x_norm1 = self.norm1(x)
|
| 79 |
+
|
| 80 |
+
# QKV生成 (3倍のまま)
|
| 81 |
+
# ※ self.qkv(x_norm) が x_norm1 になっているか確認してください
|
| 82 |
+
qkv = self.qkv(x_norm1).reshape(b, s, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 83 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 84 |
+
|
| 85 |
+
# K因果的動的回転:一つ前の単語の K が今の単語の座標を決める
|
| 86 |
+
# 自分の情報で自分を回さないよう、Kを1つ未来にシフトさせる
|
| 87 |
+
# これにより「赤い(K)」が「猫(Q,K)」の位相を決定する構造になる
|
| 88 |
+
k_for_phase = torch.cat([torch.zeros_like(k[:, :, :1, :]), k[:, :, :-1, :]], dim=2)
|
| 89 |
+
|
| 90 |
+
# Kのエネルギーを位相(回転角)に変換
|
| 91 |
+
rt_phase = torch.tanh(k_for_phase) * math.pi
|
| 92 |
+
|
| 93 |
+
# 静的RoPE (cos, sin) を動的位相 (rt_phase) で加法定理により変調
|
| 94 |
+
# ※ apply_drna_rope に rt_phase を渡せるように関数側を調整するか、
|
| 95 |
+
# ここで dynamic_cos / sin を作って渡します
|
| 96 |
+
d_cos = (cos * torch.cos(rt_phase)) - (sin * torch.sin(rt_phase))
|
| 97 |
+
d_sin = (sin * torch.cos(rt_phase)) + (cos * torch.sin(rt_phase))
|
| 98 |
+
|
| 99 |
+
# 2重らせんをつくる (変調された座標で回転)
|
| 100 |
+
q, k = apply_drna_rope(q, k, d_cos, d_sin)
|
| 101 |
+
|
| 102 |
+
# Attention計算
|
| 103 |
+
attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 104 |
+
if mask is not None:
|
| 105 |
+
attn = attn + mask
|
| 106 |
+
|
| 107 |
+
attn = F.softmax(attn, dim=-1)
|
| 108 |
+
a_out_raw = (attn @ v).transpose(1, 2).reshape(b, s, d)
|
| 109 |
+
a_out = self.out_proj(a_out_raw)
|
| 110 |
+
|
| 111 |
+
# 3. らせんB (MLP)
|
| 112 |
+
x_norm2 = self.norm2(x)
|
| 113 |
+
m_out = self.mlp(x_norm2)
|
| 114 |
+
|
| 115 |
+
# 4. 並列方式
|
| 116 |
+
x = residual + self.dropout(a_out) + self.dropout(m_out)
|
| 117 |
+
|
| 118 |
+
return x
|
| 119 |
+
|
| 120 |
+
class DRNA_Model(nn.Module):
|
| 121 |
+
"""汎用 DRNA モデルコンテナ(安定化 Pre-Norm 版)"""
|
| 122 |
+
def __init__(self, vocab_size, d_model=256, n_layers=16, n_heads=8, d_ff=1024):
|
| 123 |
+
super().__init__()
|
| 124 |
+
self.embed = nn.Embedding(vocab_size, d_model)
|
| 125 |
+
self.head_dim = d_model // n_heads
|
| 126 |
+
self.rope = DRNA_RoPE(self.head_dim)
|
| 127 |
+
|
| 128 |
+
self.layers = nn.ModuleList([
|
| 129 |
+
DRNA_Block(d_model, n_heads, self.head_dim, d_ff) for _ in range(n_layers)
|
| 130 |
+
])
|
| 131 |
+
|
| 132 |
+
# Pre-Norm構造の場合、最終レイヤーの後に全体のNormを置くのが一般的
|
| 133 |
+
self.final_norm = RMSNorm(d_model)
|
| 134 |
+
self.output_head = nn.Linear(d_model, vocab_size)
|
| 135 |
+
|
| 136 |
+
def forward(self, x, mask=None, pad_id=None):
|
| 137 |
+
b, s = x.shape
|
| 138 |
+
device = x.device
|
| 139 |
+
inputs = x
|
| 140 |
+
x = self.embed(x)
|
| 141 |
+
|
| 142 |
+
if mask is None or mask.sum() == 0:
|
| 143 |
+
# pad_id が整数(int/long)として有効な場合のみ pad_mask を作成
|
| 144 |
+
# 退避させた inputs を使ってパディングを判定し因果マスクを準備
|
| 145 |
+
p_id = pad_id.item() if isinstance(pad_id, torch.Tensor) else pad_id
|
| 146 |
+
pad_mask = (inputs != p_id).unsqueeze(1).unsqueeze(2) if isinstance(p_id, (int, float)) else torch.ones((1, 1, 1, s), device=device, dtype=torch.bool)
|
| 147 |
+
causal = torch.triu(torch.ones(s, s, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
|
| 148 |
+
|
| 149 |
+
# 環境(fp16/32)に応じた最小値を安全に自動計算
|
| 150 |
+
inf_value = torch.finfo(x.dtype).min if x.dtype != torch.float16 else -65500.0
|
| 151 |
+
|
| 152 |
+
# ゼロ初期化テンソルに無効領域をインプレースで直接埋める(カンニングの完全遮断)
|
| 153 |
+
mask = torch.zeros((b, 1, s, s), device=device, dtype=x.dtype).masked_fill_(causal | (~pad_mask), inf_value)
|
| 154 |
+
|
| 155 |
+
cos, sin = self.rope(x, x.size(1))
|
| 156 |
+
|
| 157 |
+
for layer in self.layers:
|
| 158 |
+
x = layer(x, cos, sin, mask=mask)
|
| 159 |
+
|
| 160 |
+
x = self.final_norm(x) # 出力前の最終同期
|
| 161 |
+
return self.output_head(x)
|
| 162 |
+
|
| 163 |
+
'''
|
| 164 |
+
260520:maskの微調整(AMP対応)/MoE-LoRA版、vlayer版、D-RNAの活用例を汎用コード化
|
| 165 |
+
260507:Kによる回転で文脈に単語を沿わせ2重らせんの干渉による取捨選択とホログラム合成を可能にする
|
| 166 |
+
260505:model構成から学習解像度を自動化、汎用 mask の精度への適正化、RMSNormへの移行
|
| 167 |
+
260503:padding を引数で指定できるよう変更
|
| 168 |
+
# 例:一般的な Tokenizer の pad_id が 0 の場合
|
| 169 |
+
output = model(input_ids, pad_id=0)
|
| 170 |
+
# 例:Hugging Face 等の tokenizer を使っている場合
|
| 171 |
+
output = model(input_ids, pad_id=tokenizer.pad_token_id)
|
| 172 |
+
260502:変数名を正確化(head_dim)、汎用 mask に変更し padding 等に対応可
|
| 173 |
+
'''
|
| 174 |
+
|
| 175 |
+
'''
|
| 176 |
+
汎用型 D-RNA (Pre-Norm) License: Apache License 2.0 https://github.com/muooon/DRNA
|
| 177 |
+
Attention is all you need_started, Resonance is all you need_endure,
|
| 178 |
+
Neocognitron ― Transformer ― D‑RNA Dream Resonance Never Adjourns — it goes on...
|
| 179 |
+
'''
|
drna/drna_moe.py
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
'''
|
| 7 |
+
D‑RNA: Dual‑Helix Resonance Neural Architecture (DRNA) Pre-Norm・Kv-RoPE・MoE‑LoRA 版
|
| 8 |
+
仕様:Pre-Norm(RMSNorm)、GELU(Activation)、Kv-RoPE(head_dim)、mask(padding + causal)
|
| 9 |
+
汎用コードのコア(Base Weight)を不動の岩盤(ランダム初期化・完全固定)としつつ、
|
| 10 |
+
Attention および MLP の全線形層に LoRA による複数 expert を配置する D-RNA 応用型です
|
| 11 |
+
入力トークン、あるいはシーケンス特性に応じて動的に expert を選択・融合する MoE的拡張版
|
| 12 |
+
'''
|
| 13 |
+
|
| 14 |
+
class RMSNorm(nn.Module):
|
| 15 |
+
def __init__(self, d_model, eps=1e-6):
|
| 16 |
+
super().__init__()
|
| 17 |
+
self.eps = eps
|
| 18 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
| 19 |
+
|
| 20 |
+
def forward(self, x):
|
| 21 |
+
norm = x.pow(2).mean(-1, keepdim=True)
|
| 22 |
+
x_normed = x * torch.rsqrt(norm + self.eps)
|
| 23 |
+
return self.weight * x_normed
|
| 24 |
+
|
| 25 |
+
class DRNA_RoPE(nn.Module):
|
| 26 |
+
"""二重らせんの位相を決定する回転場"""
|
| 27 |
+
def __init__(self, head_dim, base=10000):
|
| 28 |
+
super().__init__()
|
| 29 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
| 30 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 31 |
+
|
| 32 |
+
def forward(self, x, seq_len):
|
| 33 |
+
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 34 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 35 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 36 |
+
return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
|
| 37 |
+
|
| 38 |
+
def apply_drna_rope(q, k, cos, sin):
|
| 39 |
+
def rotate_half(x):
|
| 40 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 41 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 42 |
+
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 43 |
+
|
| 44 |
+
class MoELoRALinear(nn.Module):
|
| 45 |
+
"""固定されたベース線形層に対して、複数のLoRA Expertを動的にルーティングするMoE構造"""
|
| 46 |
+
def __init__(self, in_features, out_features, r=16, lora_alpha=16, num_experts=4):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.in_features = in_features
|
| 49 |
+
self.out_features = out_features
|
| 50 |
+
self.r = r
|
| 51 |
+
self.scaling = lora_alpha / r
|
| 52 |
+
self.num_experts = num_experts
|
| 53 |
+
|
| 54 |
+
# 不動の岩盤 (ランダム初期化のまま勾配を固定)
|
| 55 |
+
self.base_weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
|
| 56 |
+
self.base_bias = nn.Parameter(torch.empty(out_features), requires_grad=False)
|
| 57 |
+
nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5))
|
| 58 |
+
bound = 1 / math.sqrt(in_features) if in_features > 0 else 0
|
| 59 |
+
nn.init.uniform_(self.base_bias, -bound, bound)
|
| 60 |
+
|
| 61 |
+
# MoE-LoRA Experts (学習対象)
|
| 62 |
+
self.lora_A = nn.Parameter(torch.randn(num_experts, r, in_features) / math.sqrt(in_features))
|
| 63 |
+
self.lora_B = nn.Parameter(torch.zeros(num_experts, out_features, r))
|
| 64 |
+
|
| 65 |
+
# どのExpertをどれだけ使うかを決定するルーター (シーケンス/トークン単位)
|
| 66 |
+
self.router = nn.Linear(in_features, num_experts)
|
| 67 |
+
|
| 68 |
+
def forward(self, x):
|
| 69 |
+
# 共通のベース出力を計算
|
| 70 |
+
base_out = F.linear(x, self.base_weight, self.base_bias)
|
| 71 |
+
|
| 72 |
+
# ルーティング重みの計算 (Softmaxによるソフトな結合、またはTop-kへの拡張も可能)
|
| 73 |
+
# x: (B, S, in_features) -> router_logits: (B, S, num_experts)
|
| 74 |
+
router_logits = self.router(x)
|
| 75 |
+
router_weights = F.softmax(router_logits, dim=-1).unsqueeze(-1) # (B, S, num_experts)
|
| 76 |
+
|
| 77 |
+
# 全ExpertのLoRA出力をループなしで一斉に並列計算
|
| 78 |
+
# self.lora_A の形状: (num_experts, r, in_features)
|
| 79 |
+
lora_A_out = torch.einsum("bsi,ejr->bsej", x, self.lora_A)
|
| 80 |
+
|
| 81 |
+
# self.lora_B の形状: (num_experts, out_features, r)
|
| 82 |
+
# -> lora_out_all: (B, S, num_experts, out_features)
|
| 83 |
+
lora_out_all = torch.einsum("bsej,ekj->bsek", lora_A_out, self.lora_B) * self.scaling
|
| 84 |
+
# ルーターの重みで加重平均して合流
|
| 85 |
+
# (B, S, num_experts, out_features) * (B, S, num_experts, 1) -> sum over experts
|
| 86 |
+
lora_out_total = (lora_out_all * router_weights).sum(dim=2)
|
| 87 |
+
|
| 88 |
+
return base_out + lora_out_total
|
| 89 |
+
|
| 90 |
+
class DRNA_MoE_Block(nn.Module):
|
| 91 |
+
"""DRNA共鳴ブロック:MoE-LoRA並列構造"""
|
| 92 |
+
def __init__(self, d_model, n_heads, head_dim, d_ff=None, dropout=0.1, r=16, num_experts=4):
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.n_heads = n_heads
|
| 95 |
+
self.head_dim = head_dim
|
| 96 |
+
|
| 97 |
+
# らせんA: 回想系 (Attention)
|
| 98 |
+
self.norm1 = RMSNorm(d_model)
|
| 99 |
+
self.qkv = MoELoRALinear(d_model, d_model * 3, r=r, num_experts=num_experts)
|
| 100 |
+
self.out_proj = MoELoRALinear(d_model, d_model, r=r, num_experts=num_experts)
|
| 101 |
+
|
| 102 |
+
# らせんB: 記憶系 (MLP)
|
| 103 |
+
self.norm2 = RMSNorm(d_model)
|
| 104 |
+
d_ff = d_ff or d_model * 4
|
| 105 |
+
|
| 106 |
+
# 汎用コードの直列���述を維持するため、Sequentialではなく個別にMoEレイヤを定義
|
| 107 |
+
self.mlp_in = MoELoRALinear(d_model, d_ff, r=r, num_experts=num_experts)
|
| 108 |
+
self.mlp_out = MoELoRALinear(d_ff, d_model, r=r, num_experts=num_experts)
|
| 109 |
+
self.activation = nn.GELU()
|
| 110 |
+
self.mlp_dropout = nn.Dropout(dropout)
|
| 111 |
+
|
| 112 |
+
self.dropout = nn.Dropout(dropout)
|
| 113 |
+
|
| 114 |
+
def forward(self, x, cos, sin, mask=None):
|
| 115 |
+
b, s, d = x.shape
|
| 116 |
+
residual = x
|
| 117 |
+
|
| 118 |
+
# らせんA (Attention)
|
| 119 |
+
x_norm1 = self.norm1(x)
|
| 120 |
+
qkv = self.qkv(x_norm1).reshape(b, s, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 121 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 122 |
+
|
| 123 |
+
k_for_phase = torch.cat([torch.zeros_like(k[:, :, :1, :]), k[:, :, :-1, :]], dim=2)
|
| 124 |
+
rt_phase = torch.tanh(k_for_phase) * math.pi
|
| 125 |
+
|
| 126 |
+
d_cos = (cos * torch.cos(rt_phase)) - (sin * torch.sin(rt_phase))
|
| 127 |
+
d_sin = (sin * torch.cos(rt_phase)) + (cos * torch.sin(rt_phase))
|
| 128 |
+
|
| 129 |
+
q, k = apply_drna_rope(q, k, d_cos, d_sin)
|
| 130 |
+
|
| 131 |
+
attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 132 |
+
if mask is not None:
|
| 133 |
+
attn = attn + mask
|
| 134 |
+
|
| 135 |
+
attn = F.softmax(attn, dim=-1)
|
| 136 |
+
a_out_raw = (attn @ v).transpose(1, 2).reshape(b, s, d)
|
| 137 |
+
a_out = self.out_proj(a_out_raw)
|
| 138 |
+
|
| 139 |
+
# らせんB (MLP)
|
| 140 |
+
x_norm2 = self.norm2(x)
|
| 141 |
+
m_intermediate = self.activation(self.mlp_in(x_norm2))
|
| 142 |
+
m_out = self.mlp_out(self.mlp_dropout(m_intermediate))
|
| 143 |
+
|
| 144 |
+
# 並列方式
|
| 145 |
+
x = residual + self.dropout(a_out) + self.dropout(m_out)
|
| 146 |
+
return x
|
| 147 |
+
|
| 148 |
+
class DRNA_MoE_Model(nn.Module):
|
| 149 |
+
"""汎用 DRNA モデルコンテナ(ランダム初期化ベース+MoE-LoRA拡張・UTF-8対応版)"""
|
| 150 |
+
def __init__(self, vocab_size=256, d_model=256, n_layers=16, n_heads=8, d_ff=1024, lora_r=16, num_experts=4):
|
| 151 |
+
super().__init__()
|
| 152 |
+
# UTF-8直受けのため、vocab_sizeは固定で256を指定可能
|
| 153 |
+
self.embed = nn.Embedding(vocab_size, d_model)
|
| 154 |
+
self.head_dim = d_model // n_heads
|
| 155 |
+
self.rope = DRNA_RoPE(self.head_dim)
|
| 156 |
+
|
| 157 |
+
self.layers = nn.ModuleList([
|
| 158 |
+
DRNA_MoE_Block(d_model, n_heads, self.head_dim, d_ff, r=lora_r, num_experts=num_experts) for _ in range(n_layers)
|
| 159 |
+
])
|
| 160 |
+
|
| 161 |
+
self.final_norm = RMSNorm(d_model)
|
| 162 |
+
self.output_head = nn.Linear(d_model, vocab_size)
|
| 163 |
+
|
| 164 |
+
def forward(self, x, mask=None, pad_id=None):
|
| 165 |
+
b, s = x.shape
|
| 166 |
+
device = x.device
|
| 167 |
+
inputs = x
|
| 168 |
+
x = self.embed(x)
|
| 169 |
+
|
| 170 |
+
if mask is None or mask.sum() == 0:
|
| 171 |
+
# 退避させた inputs でパディング位置を判定
|
| 172 |
+
# パッドマスクとコーザルマスクを判定(pad_idの型チェック&テンソルバグ修正)
|
| 173 |
+
p_id = pad_id.item() if isinstance(pad_id, torch.Tensor) else pad_id
|
| 174 |
+
pad_mask = (inputs != p_id).unsqueeze(1).unsqueeze(2) if isinstance(p_id, (int, float)) else torch.ones((1, 1, 1, s), device=device, dtype=torch.bool)
|
| 175 |
+
causal = torch.triu(torch.ones(s, s, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
|
| 176 |
+
|
| 177 |
+
# 環境(fp16/32)に応じた最小値を安全に自動計算
|
| 178 |
+
inf_value = torch.finfo(x.dtype).min if x.dtype != torch.float16 else -65500.0
|
| 179 |
+
|
| 180 |
+
# ゼロ初期化テンソルに無効領域をインプレースで直接埋める(カンニングの完全遮断)
|
| 181 |
+
mask = torch.zeros((b, 1, s, s), device=device, dtype=x.dtype).masked_fill_(causal | (~pad_mask), inf_value)
|
| 182 |
+
|
| 183 |
+
cos, sin = self.rope(x, x.size(1))
|
| 184 |
+
|
| 185 |
+
for layer in self.layers:
|
| 186 |
+
x = layer(x, cos, sin, mask=mask)
|
| 187 |
+
|
| 188 |
+
x = self.final_norm(x)
|
| 189 |
+
return self.output_head(x)
|
| 190 |
+
|
| 191 |
+
'''
|
| 192 |
+
260520:LoRA による MoE(Mixture of Experts) 多重拡張化をする Plugin型です
|
| 193 |
+
学習元モデルをランダム初期化による完全カオス状態(凍結運用)とすることで"公開鍵"的に扱うことが可能となる
|
| 194 |
+
これにより LoRA は"秘密鍵"的に扱える、つまりこの学習元モデルでしか機能しないためセキュリティを堅持できる
|
| 195 |
+
学習元モデルを非公開で秘匿することで情報流出などを抑止した運用も可能になる
|
| 196 |
+
---
|
| 197 |
+
MoE-LoRAによる多重知性拡張-LoRA形式は高効率な学習と推論を可能とします
|
| 198 |
+
学習元モデル単体では具体的な知識や偏りを持たない安全な共通基盤になります(ランダム行列表なので)
|
| 199 |
+
このLoRAは学習元モデルでしか使えないため、LoRA単体だけではなにもできません
|
| 200 |
+
完全非公開・秘匿とすることで、情報流出リスクを極小化したクローズド運用も実現可能です
|
| 201 |
+
学習元モデルのみを公開するオープン運用でMoE-LoRAをみんなで自由に作成し機能向上を図ることも可能です
|
| 202 |
+
---
|
| 203 |
+
マルチモーダル対応:画像・動画・音��などの学習も可能です(必要ならトークナイザ等の差し替えも可能です)
|
| 204 |
+
VAEなどの外付けをせず、UTF8 を用いたトークンで 16x16 パッド化などでViT的な学習もできます
|
| 205 |
+
事前学習(プレトレーニング)もLoRAで行うことで高速化効率化を果たします(応答専用LoRAも学習可)
|
| 206 |
+
純粋な古代語LoRAなどをつくることで現代語に浸食されたり現代語を破壊するような干渉も防げます
|
| 207 |
+
'''
|
| 208 |
+
|
| 209 |
+
'''
|
| 210 |
+
[ 9B D-RNA ランダムベース重み (共通基盤:完全固定) ] (Dim4096/Layer32などを想定)
|
| 211 |
+
├──► 【言語LoRA (各国語/英語)】 ────► 日常対話・創造的執筆
|
| 212 |
+
├──► 【視覚パッチLoRA (512px)】 ────► 画像認識・マルチモーダル
|
| 213 |
+
├──► 【時間補間LoRA (fps拡張)】 ──► 滑らかな動画生成
|
| 214 |
+
├──► 【超解像LoRA (4K拡大)】 ────► 究極のディテール補間
|
| 215 |
+
└──► 【聴覚パッチLoRA (音声波形)】 ─► 音声認識
|
| 216 |
+
画像:512px生成(主題と背景と物体等の対応関係)─►アップスケール(局所超解像MoE-LoRA)
|
| 217 |
+
動画:512px生成(8fps)─►フレーム補間(60fps)─►空間アップスケール(4K超解像)
|
| 218 |
+
このように位置関係LoRA、精細化LoRA、動体追従LoRA、などを付け足すだけで拡張可能です
|
| 219 |
+
言語も種類問わず、プログラム言語専用、国や地域の固有のもの、歴史的なもの、を干渉防止で学習可能です
|
| 220 |
+
'''
|
| 221 |
+
|
| 222 |
+
'''
|
| 223 |
+
汎用型 D-RNA (Pre-Norm) License: Apache License 2.0 https://github.com/muooon/DRNA
|
| 224 |
+
Attention is all you need_started, Resonance is all you need_endure,
|
| 225 |
+
Neocognitron ― Transformer ― D‑RNA Dream Resonance Never Adjourns — it goes on...
|
| 226 |
+
'''
|
drna/drna_restore.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
|
| 5 |
+
'''
|
| 6 |
+
# チェックポイント読み込み(D-RNA-Trio の学習済みモデル)
|
| 7 |
+
restorer = DRNAWeightRestorer(d_model=256, num_layers=16)
|
| 8 |
+
# fp8/16/32 相当の連続重みの復元
|
| 9 |
+
continuous_weights = restorer.restore_from_checkpoint(
|
| 10 |
+
'path/to/drna_trio_checkpoint.pth')
|
| 11 |
+
# 結果の確認(多峰性の凸凹分布が観測される)
|
| 12 |
+
print(f"重み分布の最小値: {continuous_weights.min():.4f}")
|
| 13 |
+
print(f"重み分布の最大値: {continuous_weights.max():.4f}")
|
| 14 |
+
print(f"重み分布の標準偏差: {continuous_weights.std():.4f}")
|
| 15 |
+
# 仕組み
|
| 16 |
+
2〜3 レイヤ:fp8相当、4〜6 レイヤ:bf16相当、7+ レイヤ:fp16/32 相当
|
| 17 |
+
固有位相(K-energy × RoPE 周波数成分)を重ね合わせ+正規化(二重らせんの共鳴収縮により歪みを抑制)
|
| 18 |
+
特に K-for-phase の「逐次カスケード的変化」が位相シフトとなり自然な凸凹パターンを生成
|
| 19 |
+
'''
|
| 20 |
+
|
| 21 |
+
class DRNAWeightRestorer:
|
| 22 |
+
"""
|
| 23 |
+
D-RNA の固有位相を利用した fp8/16/32 重み復元器
|
| 24 |
+
チェックポイント読み込み → 位相抽出 → 重ね合わせ → 正規化 で連続分布を生成
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, d_model: int = 256, num_layers: int = 16):
|
| 28 |
+
self.d_model = d_model
|
| 29 |
+
self.num_layers = num_layers
|
| 30 |
+
|
| 31 |
+
# チェックポイントから -/0/+ 重みを読み出す
|
| 32 |
+
def load_ternary_weights(self, checkpoint_path: str) -> List[torch.Tensor]:
|
| 33 |
+
"""ロードされたチェックポイントから各レイヤの重みを抽出する"""
|
| 34 |
+
checkpoint = torch.load(checkpoint_path, map_location='cuda')
|
| 35 |
+
|
| 36 |
+
layer_weights = []
|
| 37 |
+
for i in range(self.num_layers):
|
| 38 |
+
# QKV レイヤー(norm1 ブランチ)
|
| 39 |
+
qkv_weight = checkpoint[f'layers.{i}.qkv.weight']
|
| 40 |
+
|
| 41 |
+
# MLP 入力/出力レイヤー(norm2 ブランチ)
|
| 42 |
+
mlp_0_weight = checkpoint[f'layers.{i}.mlp.0.weight']
|
| 43 |
+
mlp_3_weight = checkpoint[f'layers.{i}.mlp.3.weight']
|
| 44 |
+
|
| 45 |
+
layer_weights.extend([qkv_weight, mlp_0_weight, mlp_3_weight])
|
| 46 |
+
|
| 47 |
+
return layer_weights
|
| 48 |
+
|
| 49 |
+
# K-for-phase のカスケード効果から各レイヤの固有位相を抽出する
|
| 50 |
+
def extract_phase_offsets(self, max_seq_len: int = 256) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
| 51 |
+
"""DRNA の固有位相(K-energy × RoPE 周波数成分)を計算し返す"""
|
| 52 |
+
|
| 53 |
+
num_layer_groups = self.num_layers * 3
|
| 54 |
+
|
| 55 |
+
phase_data_list = []
|
| 56 |
+
base_energy = torch.arange(128, dtype=torch.float32, device='cuda') / self.d_model
|
| 57 |
+
|
| 58 |
+
for i in range(num_layer_groups):
|
| 59 |
+
# K-energy のカスケード効果:レイヤごとに滑らかにシフト(K-for-phase 同様に)
|
| 60 |
+
k_energy = base_energy * (i % self.num_layers + 1)
|
| 61 |
+
|
| 62 |
+
# DRNA の動的位相変調: tanh(k_energy) × π
|
| 63 |
+
rt_phase = math.tanh(k_energy) * math.pi
|
| 64 |
+
|
| 65 |
+
# RoPE 的な周波数成分の抽出(各レイヤ固有の周波数パターン)
|
| 66 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, 128, 2).float() / self.d_model))
|
| 67 |
+
|
| 68 |
+
freqs = torch.einsum("i,j->ij",
|
| 69 |
+
torch.arange(max_seq_len, device='cuda'),
|
| 70 |
+
inv_freq)
|
| 71 |
+
|
| 72 |
+
# cos/sin の位相シフトを生成
|
| 73 |
+
cos_shift = torch.cos(freqs * rt_phase.unsqueeze(-1))
|
| 74 |
+
sin_shift = torch.sin(freqs * rt_phase.unsqueeze(-1))
|
| 75 |
+
|
| 76 |
+
phase_data_list.append((cos_shift, sin_shift))
|
| 77 |
+
|
| 78 |
+
return phase_data_list
|
| 79 |
+
|
| 80 |
+
# 固有位相を用いて重ね合わせ(RoPE 的な位相変調を適用)
|
| 81 |
+
def superimpose_weights(
|
| 82 |
+
self,
|
| 83 |
+
layer_weights: List[torch.Tensor],
|
| 84 |
+
phase_offsets: List[Tuple[torch.Tensor, torch.Tensor]]
|
| 85 |
+
) -> torch.Tensor:
|
| 86 |
+
"""各レイヤの -/0/+ 重みに固有位相を適用し重ね合わせる"""
|
| 87 |
+
|
| 88 |
+
reconstructed = torch.zeros_like(layer_weights[0])
|
| 89 |
+
|
| 90 |
+
for i, w_i in enumerate(layer_weights):
|
| 91 |
+
cos_shift, sin_shift = phase_offsets[i]
|
| 92 |
+
|
| 93 |
+
# DRNA の K-for-phase 同様の位相変調: d_cos / d_sin を用いる重ね合わせ
|
| 94 |
+
contribution = w_i * (cos_shift - rotate_half(w_i) * sin_shift)
|
| 95 |
+
reconstructed += contribution
|
| 96 |
+
|
| 97 |
+
return reconstructed
|
| 98 |
+
|
| 99 |
+
# RMSNorm 的な正規化(3 値歪みのリセット)
|
| 100 |
+
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
| 101 |
+
"""重ね合わせ後の分布を RMSNorm で安定化する"""
|
| 102 |
+
|
| 103 |
+
mean = x.mean(-1, keepdim=True)
|
| 104 |
+
var = x.var(-1, keepdim=True, unbiased=False)
|
| 105 |
+
|
| 106 |
+
# 中心化・標準化(3 値歪みのリセット)
|
| 107 |
+
return (x - mean) * torch.rsqrt(var + 1e-8)
|
| 108 |
+
|
| 109 |
+
# 完全復元パイプライン
|
| 110 |
+
def restore_from_checkpoint(
|
| 111 |
+
self,
|
| 112 |
+
checkpoint_path: str,
|
| 113 |
+
max_seq_len: int = 256
|
| 114 |
+
) -> torch.Tensor:
|
| 115 |
+
"""チェ���クポイントを読み込み、fp8/16 相当の連続重みを生成する"""
|
| 116 |
+
|
| 117 |
+
# -1.0/0.0/1.0 重みの読み出し
|
| 118 |
+
layer_weights = self.load_ternary_weights(checkpoint_path)
|
| 119 |
+
|
| 120 |
+
# 固有位相(K-energy × RoPE)の抽出
|
| 121 |
+
phase_offsets = self.extract_phase_offsets(max_seq_len)
|
| 122 |
+
|
| 123 |
+
# 重ね合わせ(RoPE 的な位相変調を適用)
|
| 124 |
+
continuous_weight = self.superimpose_weights(
|
| 125 |
+
layer_weights,
|
| 126 |
+
phase_offsets)
|
| 127 |
+
|
| 128 |
+
# RMSNorm 的な正規化
|
| 129 |
+
normalized_weight = self.normalize(continuous_weight)
|
| 130 |
+
|
| 131 |
+
return normalized_weight
|
| 132 |
+
|
| 133 |
+
# 補助関数: DRNA rotate_half 再現
|
| 134 |
+
def rotate_half(x):
|
| 135 |
+
"""D-RNA の K-for-phase 同様の回転操作(位相変調の右辺項)"""
|
| 136 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 137 |
+
return torch.cat((-x2, x1), dim=-1)
|
drna/drna_swi_restore.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import math
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
|
| 5 |
+
'''
|
| 6 |
+
# チェックポイント読み込み(D-RNA-Trio の学習済みモデル)
|
| 7 |
+
restorer = DRNAWeightRestorer(d_model=256, num_layers=16)
|
| 8 |
+
# fp8/16/32 相当の連続重みの復元
|
| 9 |
+
continuous_weights = restorer.restore_from_checkpoint(
|
| 10 |
+
'path/to/drna_trio_checkpoint.pth')
|
| 11 |
+
# 結果の確認(多峰性の凸凹分布が観測される)
|
| 12 |
+
print(f"重み分布の最小値: {continuous_weights.min():.4f}")
|
| 13 |
+
print(f"重み分布の最大値: {continuous_weights.max():.4f}")
|
| 14 |
+
print(f"重み分布の標準偏差: {continuous_weights.std():.4f}")
|
| 15 |
+
# 仕組み
|
| 16 |
+
2〜3 レイヤ:fp8相当、4〜6 レイヤ:bf16相当、7+ レイヤ:fp16/32 相当
|
| 17 |
+
固有位相(K-energy × RoPE 周波数成分)を重ね合わせ+正規化(二重らせんの共鳴収縮により歪みを抑制)
|
| 18 |
+
特に K-for-phase の「逐次カスケード的変化」が位相シフトとなり自然な凸凹パターンを生成
|
| 19 |
+
'''
|
| 20 |
+
|
| 21 |
+
class DRNAWeightRestorer:
|
| 22 |
+
"""
|
| 23 |
+
D-RNA の固有位相を利用した fp8/16/32 重み復元器
|
| 24 |
+
チェックポイント読み込み → 位相抽出 → 重ね合わせ → 正規化 で連続分布を生成
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, d_model: int = 256, num_layers: int = 16):
|
| 28 |
+
self.d_model = d_model
|
| 29 |
+
self.num_layers = num_layers
|
| 30 |
+
|
| 31 |
+
# チェックポイントから -/0/+ 重みを読み出す
|
| 32 |
+
def load_ternary_weights(self, checkpoint_path: str) -> List[torch.Tensor]:
|
| 33 |
+
"""ロードされたチェックポイントから各レイヤの重みを抽出する"""
|
| 34 |
+
checkpoint = torch.load(checkpoint_path, map_location='cuda')
|
| 35 |
+
|
| 36 |
+
layer_weights = []
|
| 37 |
+
for i in range(self.num_layers):
|
| 38 |
+
# QKV レイヤー(norm1 ブランチ)
|
| 39 |
+
qkv_weight = checkpoint[f'layers.{i}.qkv.weight']
|
| 40 |
+
|
| 41 |
+
# MLP 入力/出力レイヤー(norm2 ブランチ)w3 も取得する
|
| 42 |
+
# 旧 mlp.0.weight 相当 ──> ゲート用 w1
|
| 43 |
+
mlp_w1_weight = checkpoint[f'layers.{i}.w1.weight']
|
| 44 |
+
# SwiGLU対応による拡張部分 ──> アッププロジェクション w3
|
| 45 |
+
mlp_w3_weight = checkpoint[f'layers.{i}.w3.weight']
|
| 46 |
+
# 旧 mlp.3 (絞るLinear) ──> ダウンプロジェクション w2
|
| 47 |
+
mlp_w2_weight = checkpoint[f'layers.{i}.w2.weight']
|
| 48 |
+
|
| 49 |
+
# 4つの重みをすべてリストに格納(qkv, w1, w3, w2)
|
| 50 |
+
layer_weights.extend([qkv_weight, mlp_w1_weight, mlp_w3_weight, mlp_w2_weight])
|
| 51 |
+
|
| 52 |
+
return layer_weights
|
| 53 |
+
|
| 54 |
+
# K-for-phase のカスケード効果から各レイヤの固有位相を抽出する
|
| 55 |
+
def extract_phase_offsets(self, max_seq_len: int = 256) -> List[Tuple[torch.Tensor, torch.Tensor]]:
|
| 56 |
+
"""DRNA の固有位相(K-energy × RoPE 周波数成分)を計算し返す"""
|
| 57 |
+
|
| 58 |
+
num_layer_groups = self.num_layers * 4
|
| 59 |
+
|
| 60 |
+
phase_data_list = []
|
| 61 |
+
base_energy = torch.arange(128, dtype=torch.float32, device='cuda') / self.d_model
|
| 62 |
+
|
| 63 |
+
for i in range(num_layer_groups):
|
| 64 |
+
# K-energy のカスケード効果:レイヤごとに滑らかにシフト(K-for-phase 同様に)
|
| 65 |
+
k_energy = base_energy * (i % self.num_layers + 1)
|
| 66 |
+
|
| 67 |
+
# DRNA の動的位相変調: tanh(k_energy) × π
|
| 68 |
+
rt_phase = math.tanh(k_energy) * math.pi
|
| 69 |
+
|
| 70 |
+
# RoPE 的な周波数成分の抽出(各レイヤ固有の周波数パターン)
|
| 71 |
+
inv_freq = 1.0 / (10000 ** (torch.arange(0, 128, 2).float() / self.d_model))
|
| 72 |
+
|
| 73 |
+
freqs = torch.einsum("i,j->ij",
|
| 74 |
+
torch.arange(max_seq_len, device='cuda'),
|
| 75 |
+
inv_freq)
|
| 76 |
+
|
| 77 |
+
# cos/sin の位相シフトを生成
|
| 78 |
+
cos_shift = torch.cos(freqs * rt_phase.unsqueeze(-1))
|
| 79 |
+
sin_shift = torch.sin(freqs * rt_phase.unsqueeze(-1))
|
| 80 |
+
|
| 81 |
+
phase_data_list.append((cos_shift, sin_shift))
|
| 82 |
+
|
| 83 |
+
return phase_data_list
|
| 84 |
+
|
| 85 |
+
# 固有位相を用いて重ね合わせ(RoPE 的な位相変調を適用)
|
| 86 |
+
def superimpose_weights(
|
| 87 |
+
self,
|
| 88 |
+
layer_weights: List[torch.Tensor],
|
| 89 |
+
phase_offsets: List[Tuple[torch.Tensor, torch.Tensor]]
|
| 90 |
+
) -> torch.Tensor:
|
| 91 |
+
"""各レイヤの -/0/+ 重みに固有位相を適用し重ね合わせる"""
|
| 92 |
+
|
| 93 |
+
reconstructed = torch.zeros_like(layer_weights[0])
|
| 94 |
+
|
| 95 |
+
for i, w_i in enumerate(layer_weights):
|
| 96 |
+
cos_shift, sin_shift = phase_offsets[i]
|
| 97 |
+
|
| 98 |
+
# DRNA の K-for-phase 同様の位相変調: d_cos / d_sin を用いる重ね合わせ
|
| 99 |
+
contribution = w_i * (cos_shift - rotate_half(w_i) * sin_shift)
|
| 100 |
+
reconstructed += contribution
|
| 101 |
+
|
| 102 |
+
return reconstructed
|
| 103 |
+
|
| 104 |
+
# RMSNorm 的な正規化(3 値歪みのリセット)
|
| 105 |
+
def normalize(self, x: torch.Tensor) -> torch.Tensor:
|
| 106 |
+
"""重ね合わせ後の分布を RMSNorm で安定化する"""
|
| 107 |
+
|
| 108 |
+
mean = x.mean(-1, keepdim=True)
|
| 109 |
+
var = x.var(-1, keepdim=True, unbiased=False)
|
| 110 |
+
|
| 111 |
+
# 中心化・標準化(3 値歪みのリセット)
|
| 112 |
+
return (x - mean) * torch.rsqrt(var + 1e-8)
|
| 113 |
+
|
| 114 |
+
# 完全復元パイプライン
|
| 115 |
+
def restore_from_checkpoint(
|
| 116 |
+
self,
|
| 117 |
+
checkpoint_path: str,
|
| 118 |
+
max_seq_len: int = 256
|
| 119 |
+
) -> torch.Tensor:
|
| 120 |
+
"""チェックポイントを読み込み、fp8/16 相当の連続重みを生成する"""
|
| 121 |
+
|
| 122 |
+
# -1.0/0.0/1.0 重みの読み出し
|
| 123 |
+
layer_weights = self.load_ternary_weights(checkpoint_path)
|
| 124 |
+
|
| 125 |
+
# 固有位相(K-energy × RoPE)の抽出
|
| 126 |
+
phase_offsets = self.extract_phase_offsets(max_seq_len)
|
| 127 |
+
|
| 128 |
+
# 重ね合わせ(RoPE 的な位相変調を適用)
|
| 129 |
+
continuous_weight = self.superimpose_weights(
|
| 130 |
+
layer_weights,
|
| 131 |
+
phase_offsets)
|
| 132 |
+
|
| 133 |
+
# RMSNorm 的な正規化
|
| 134 |
+
normalized_weight = self.normalize(continuous_weight)
|
| 135 |
+
|
| 136 |
+
return normalized_weight
|
| 137 |
+
|
| 138 |
+
# 補助関数: DRNA rotate_half 再現
|
| 139 |
+
def rotate_half(x):
|
| 140 |
+
"""D-RNA の K-for-phase 同様の回転操作(位相変調の右辺項)"""
|
| 141 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 142 |
+
return torch.cat((-x2, x1), dim=-1)
|
drna/drna_swi_triox.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
'''
|
| 7 |
+
D‑RNA: Dual‑Helix Resonance Neural Architecture (DRNA) Pre-Norm・Kv-RoPE 版
|
| 8 |
+
仕様:Pre-Norm(RMSNorm)、GELU(Activation)、Kv-RoPE(head_dim)、mask(padding + causal)
|
| 9 |
+
Transformerの全接続性を継承しつつ、二重らせん(Dual-Helix)構造による
|
| 10 |
+
「共鳴収縮」(Resonant Contraction)を物理的に再現したニューラルアーキテクチャです
|
| 11 |
+
D-RNA の位相設計と Trio Induction system により3値学習を STE に頼らず安定的に行えます
|
| 12 |
+
これは STE で機能しない optimiser などを3値学習へ活用できるようになります
|
| 13 |
+
将来的に3値モデルを位相差の重ねによる疑似重みをつくり、これを学習対象にして3値学習もおこなえるはずです
|
| 14 |
+
つまり学習元も3値モデルにできるはずです、推論も学習も3値で済むようになる最初の1歩です
|
| 15 |
+
'''
|
| 16 |
+
|
| 17 |
+
# 3値誘導制御コア(モデルの書き換え、3値ブレンド、ペナルティ計算、結晶化)
|
| 18 |
+
class TernaryTrainingManager:
|
| 19 |
+
"""
|
| 20 |
+
D-RNAのコードに触れることなく、3値誘導の全ライフサイクルを統括する抽象化マネージャー。
|
| 21 |
+
"""
|
| 22 |
+
def __init__(self, model, warmup_steps=100, max_lambda=1.0):
|
| 23 |
+
self.model = model
|
| 24 |
+
self.warmup_steps = warmup_steps
|
| 25 |
+
self.max_lambda = max_lambda
|
| 26 |
+
self.current_step = 0
|
| 27 |
+
self.total_steps = 0
|
| 28 |
+
|
| 29 |
+
# 内部で利用するステッププロバイダー関数
|
| 30 |
+
def step_provider():
|
| 31 |
+
return self.current_step, self.total_steps
|
| 32 |
+
|
| 33 |
+
# 1. モデル内の全2次元重みにフックを外付け
|
| 34 |
+
for name, module in self.model.named_modules():
|
| 35 |
+
if isinstance(module, nn.Linear):
|
| 36 |
+
# ★ 拡張性:LLMの生命線である「embed」と「output_head」は絶対に3値化しない
|
| 37 |
+
if "embed" in name or "output_head" in name:
|
| 38 |
+
continue
|
| 39 |
+
if not hasattr(module, "raw_weight"):
|
| 40 |
+
# 元の実数重みを聖域(raw_weight)に退避
|
| 41 |
+
module.register_parameter("raw_weight", nn.Parameter(module.weight.data.clone()))
|
| 42 |
+
delattr(module, "weight")
|
| 43 |
+
module.register_buffer("weight", module.raw_weight.data.clone())
|
| 44 |
+
module.register_forward_pre_hook(TernaryWeightHook(step_provider, self.warmup_steps))
|
| 45 |
+
|
| 46 |
+
def amend_loss(self, task_loss, step, total_steps):
|
| 47 |
+
"""
|
| 48 |
+
【ループ内抽象化用】
|
| 49 |
+
メインのタスク損失(CrossEntropy等)を受け取り、
|
| 50 |
+
現在のステップに応じた3値結晶化ペナルティを自動計算して合算した損失を返します。
|
| 51 |
+
"""
|
| 52 |
+
self.current_step = step
|
| 53 |
+
self.total_steps = total_steps
|
| 54 |
+
|
| 55 |
+
blend_ratio = get_ternary_schedule(step, total_steps, self.warmup_steps)
|
| 56 |
+
current_lambda = blend_ratio * self.max_lambda
|
| 57 |
+
|
| 58 |
+
if current_lambda == 0.0:
|
| 59 |
+
return task_loss
|
| 60 |
+
|
| 61 |
+
# 3値誘導トリプレット・ペナルティ自動計算
|
| 62 |
+
ternary_penalty = 0.0
|
| 63 |
+
for name, param in self.model.named_parameters():
|
| 64 |
+
# 2次元以上の重み(LinearやEmbeddingの実数聖域)のみを対象とする
|
| 65 |
+
if "raw_weight" in name and param.dim() >= 2:
|
| 66 |
+
# W * (W - 1) * (W + 1) = W^3 - W を0に近づける(-1, 0, 1への収束強制力)
|
| 67 |
+
ternary_penalty += torch.mean(param * (param - 1.0) * (param + 1.0)) ** 2
|
| 68 |
+
|
| 69 |
+
return task_loss + current_lambda * ternary_penalty
|
| 70 |
+
|
| 71 |
+
def export_ternary(self):
|
| 72 |
+
"""学習終了後、モデルの全2次元重みを完全な[-1.0, 0.0, 1.0]へ固定(結晶化)する"""
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
for name, module in self.model.named_modules():
|
| 75 |
+
if isinstance(module, nn.Linear):
|
| 76 |
+
# 拡張性:LLMの生命線である「embed」と「output_head」は絶対に3値化しない
|
| 77 |
+
if "embed" in name or "output_head" in name:
|
| 78 |
+
continue
|
| 79 |
+
# raw_weight があればそっちを本体として使う
|
| 80 |
+
if hasattr(module, "raw_weight"):
|
| 81 |
+
param = module.raw_weight
|
| 82 |
+
else:
|
| 83 |
+
param = module.weight
|
| 84 |
+
# 学習中と同じ写像:tanh(3w) で soft 3値ターゲットを作る
|
| 85 |
+
soft = torch.tanh(param * 3.0)
|
| 86 |
+
# soft を hard 3値に潰す
|
| 87 |
+
ternary = torch.zeros_like(soft)
|
| 88 |
+
ternary[soft > 0.08] = 1.0
|
| 89 |
+
ternary[soft < -0.08] = -1.0
|
| 90 |
+
# 実数パラメータ自体を3値に上書きして完全固定
|
| 91 |
+
if hasattr(module, "raw_weight"):
|
| 92 |
+
module.raw_weight.copy_(ternary)
|
| 93 |
+
module.weight.copy_(ternary)
|
| 94 |
+
|
| 95 |
+
return self.model
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# 3値誘導外付けフックシステム(2次元重みを3値ブレンド)
|
| 99 |
+
def get_ternary_schedule(step, total_steps, warmup_steps=100):
|
| 100 |
+
"""逆転コサインアニーリングスケジューラ"""
|
| 101 |
+
if step < warmup_steps:
|
| 102 |
+
return 0.0
|
| 103 |
+
anneal_steps = total_steps - warmup_steps
|
| 104 |
+
progress = (step - warmup_steps) / anneal_steps
|
| 105 |
+
return 1.0 - (0.5 * (1.0 + math.cos(progress * math.pi)))
|
| 106 |
+
|
| 107 |
+
def get_soft_ternary_weight(param, step, total_steps, warmup_steps=100):
|
| 108 |
+
"""勾配直通バイパス型 3値ブレンド関数"""
|
| 109 |
+
if param.dim() < 2: # 1次元パラメータは保護
|
| 110 |
+
return param
|
| 111 |
+
blend_ratio = get_ternary_schedule(step, total_steps, warmup_steps)
|
| 112 |
+
if blend_ratio == 0.0:
|
| 113 |
+
return param
|
| 114 |
+
with torch.no_grad():
|
| 115 |
+
ternary_target = torch.tanh(param * 3.0)
|
| 116 |
+
# 【最重要】生の勾配を裏に直通させるバイパス構造
|
| 117 |
+
return param + blend_ratio * (ternary_target - param)
|
| 118 |
+
|
| 119 |
+
class TernaryWeightHook:
|
| 120 |
+
def __init__(self, step_provider, warmup_steps=100):
|
| 121 |
+
self.step_provider = step_provider
|
| 122 |
+
self.warmup_steps = warmup_steps
|
| 123 |
+
|
| 124 |
+
def __call__(self, module, inputs):
|
| 125 |
+
step, total_steps = self.step_provider()
|
| 126 |
+
if step is not None and total_steps is not None:
|
| 127 |
+
# バックアップした実数重み(raw_weight)から疑似3値重みを計算し、一時的に上書き
|
| 128 |
+
module.weight.data = get_soft_ternary_weight(module.raw_weight, step, total_steps, self.warmup_steps)
|
| 129 |
+
|
| 130 |
+
def apply_trio_induction(model, step_provider, warmup_steps=100):
|
| 131 |
+
"""モデルを汚さずに外側から3値化プラグインを刺す関数"""
|
| 132 |
+
for name, module in model.named_modules():
|
| 133 |
+
if isinstance(module, nn.Linear):
|
| 134 |
+
# ★ 拡張性:LLMの生命線である「embed」と「output_head」は絶対に3値化しない
|
| 135 |
+
if "embed" in name or "output_head" in name:
|
| 136 |
+
continue
|
| 137 |
+
if not hasattr(module, "raw_weight"):
|
| 138 |
+
# 元の実数重みを聖域(raw_weight)に退避
|
| 139 |
+
module.register_parameter("raw_weight", nn.Parameter(module.weight.data.clone()))
|
| 140 |
+
delattr(module, "weight")
|
| 141 |
+
module.register_buffer("weight", module.raw_weight.data.clone())
|
| 142 |
+
module.register_forward_pre_hook(TernaryWeightHook(step_provider, warmup_steps))
|
| 143 |
+
|
| 144 |
+
# D-RNA 3値化のパニックを100%いなすコンサルタント(RMSNorm 修正)
|
| 145 |
+
class RMSNorm(nn.Module):
|
| 146 |
+
"""【修正版】3値化の歪みをリセットする中心化・標準化型防波堤"""
|
| 147 |
+
def __init__(self, d_model, eps=1e-8):
|
| 148 |
+
super().__init__()
|
| 149 |
+
self.eps = eps
|
| 150 |
+
# 3値化の歪みによってズレた「音量の軸」を、フル精度の実数で強制的に中心に戻すバイアス
|
| 151 |
+
self.bias = nn.Parameter(torch.zeros(d_model))
|
| 152 |
+
# 3値化のせいで極端にインフレ・デフレした次元ごとの音量を、個別にジャストフィットさせる実数スケール
|
| 153 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
# 1. 各次元ごとの平均と分散を計算
|
| 157 |
+
mean = x.mean(-1, keepdim=True)
|
| 158 |
+
var = x.var(-1, keepdim=True, unbiased=False)
|
| 159 |
+
|
| 160 |
+
# 2. 3値化による「歪み・偏り」を、ここで完全にリセット(標準化)する
|
| 161 |
+
x_normed = (x - mean) * torch.rsqrt(var + self.eps)
|
| 162 |
+
|
| 163 |
+
# 3. リセットされた綺麗な状態に対して、フル精度のオプティマイザが最適なスケールとオフセットを施す
|
| 164 |
+
return self.weight * x_normed + self.bias
|
| 165 |
+
|
| 166 |
+
class DRNA_RoPE(nn.Module):
|
| 167 |
+
"""二重らせんの位相を決定する回転場"""
|
| 168 |
+
def __init__(self, head_dim, base=10000):
|
| 169 |
+
super().__init__()
|
| 170 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
| 171 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 172 |
+
|
| 173 |
+
def forward(self, x, seq_len):
|
| 174 |
+
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 175 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 176 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 177 |
+
return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
|
| 178 |
+
|
| 179 |
+
def apply_drna_rope(q, k, cos, sin):
|
| 180 |
+
"""Kによる動的位相変調済み cos/sin を受け取る"""
|
| 181 |
+
def rotate_half(x):
|
| 182 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 183 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 184 |
+
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 185 |
+
|
| 186 |
+
class DRNA_Block(nn.Module):
|
| 187 |
+
"""DRNA共鳴ブロック:安定性を高めたPre-Norm直列共鳴構造"""
|
| 188 |
+
def __init__(self, d_model, n_heads, head_dim, d_ff=None, dropout=0.1):
|
| 189 |
+
super().__init__()
|
| 190 |
+
self.n_heads = n_heads
|
| 191 |
+
self.head_dim = head_dim
|
| 192 |
+
|
| 193 |
+
# らせんA:回想系 (Attention)
|
| 194 |
+
self.norm1 = RMSNorm(d_model) # 演算の前に配置
|
| 195 |
+
self.qkv = nn.Linear(d_model, d_model * 3)
|
| 196 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 197 |
+
|
| 198 |
+
# らせんB:記憶系 (MLP)
|
| 199 |
+
self.norm2 = RMSNorm(d_model) # 演算の前に配置
|
| 200 |
+
|
| 201 |
+
# SwiGLUの定石:パラメータ量をGELU版(×4倍)と合わせるため、約2.67倍(8/3)にする
|
| 202 |
+
if d_ff is None:
|
| 203 |
+
d_ff = int(2 * (d_model * 4) / 3)
|
| 204 |
+
|
| 205 |
+
# SwiGLUは入力に対して2つのLinear(Gate・Up)を並列に走らせます
|
| 206 |
+
self.w1 = nn.Linear(d_model, d_ff, bias=False) # ゲート用
|
| 207 |
+
self.w3 = nn.Linear(d_model, d_ff, bias=False) # 値用(アッププロジェクション)
|
| 208 |
+
self.w2 = nn.Linear(d_ff, d_model, bias=False) # ダウンプロジェクション
|
| 209 |
+
|
| 210 |
+
self.dropout = nn.Dropout(dropout)
|
| 211 |
+
|
| 212 |
+
def forward(self, x, cos, sin, mask=None):
|
| 213 |
+
b, s, d = x.shape
|
| 214 |
+
|
| 215 |
+
# 共通の残差(ベースとなる螺旋の軸)
|
| 216 |
+
residual = x
|
| 217 |
+
|
| 218 |
+
# らせんA (Attention) 並列方式
|
| 219 |
+
x_norm1 = self.norm1(x)
|
| 220 |
+
|
| 221 |
+
# QKV生成 (3倍のまま)
|
| 222 |
+
# ※ self.qkv(x_norm) が x_norm1 になっているか確認してください
|
| 223 |
+
qkv = self.qkv(x_norm1).reshape(b, s, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 224 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 225 |
+
|
| 226 |
+
# K因果的動的回転:一つ前の単語の K が今の単語の座標を決める
|
| 227 |
+
# 自分の情報で自分を回さないよう、Kを1つ未来にシフトさせる
|
| 228 |
+
# これにより「赤い(K)」が「猫(Q,K)」の位相を決定する構造になる
|
| 229 |
+
k_for_phase = torch.cat([torch.zeros_like(k[:, :, :1, :]), k[:, :, :-1, :]], dim=2)
|
| 230 |
+
|
| 231 |
+
# Kのエネルギーを位相(回転角)に変換
|
| 232 |
+
rt_phase = torch.tanh(k_for_phase) * math.pi
|
| 233 |
+
|
| 234 |
+
# 静的RoPE (cos, sin) を動的位相 (rt_phase) で加法定理により変調
|
| 235 |
+
# ※ apply_drna_rope に rt_phase を渡せるように関数側を調整するか、
|
| 236 |
+
# ここで dynamic_cos / sin を作って渡します
|
| 237 |
+
d_cos = (cos * torch.cos(rt_phase)) - (sin * torch.sin(rt_phase))
|
| 238 |
+
d_sin = (sin * torch.cos(rt_phase)) + (cos * torch.sin(rt_phase))
|
| 239 |
+
|
| 240 |
+
# 2重らせんをつくる (変調された座標で回転)
|
| 241 |
+
q, k = apply_drna_rope(q, k, d_cos, d_sin)
|
| 242 |
+
|
| 243 |
+
# Attention計算
|
| 244 |
+
attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 245 |
+
if mask is not None:
|
| 246 |
+
attn = attn + mask
|
| 247 |
+
|
| 248 |
+
attn = F.softmax(attn, dim=-1)
|
| 249 |
+
a_out_raw = (attn @ v).transpose(1, 2).reshape(b, s, d)
|
| 250 |
+
a_out = self.out_proj(a_out_raw)
|
| 251 |
+
|
| 252 |
+
# らせんB (MLP)
|
| 253 |
+
x_norm2 = self.norm2(x)
|
| 254 |
+
|
| 255 |
+
# SwiGLUのコア数式: Swish(xW1) * xW3(F.silu:PyTorchのSwish関数)
|
| 256 |
+
gate = F.silu(self.w1(x_norm2))
|
| 257 |
+
current_value = self.w3(x_norm2)
|
| 258 |
+
|
| 259 |
+
# 要素積(共鳴収縮の表現としても非常に相性が良いです)
|
| 260 |
+
swiglu_out = gate * current_value
|
| 261 |
+
|
| 262 |
+
# 最終投影
|
| 263 |
+
m_out = self.w2(swiglu_out)
|
| 264 |
+
|
| 265 |
+
# 並列方式
|
| 266 |
+
x = residual + self.dropout(a_out) + self.dropout(m_out)
|
| 267 |
+
|
| 268 |
+
return x
|
| 269 |
+
|
| 270 |
+
class DRNA_Model(nn.Module):
|
| 271 |
+
"""汎用 DRNA モデルコンテナ(安定化 Pre-Norm 版)"""
|
| 272 |
+
def __init__(self, vocab_size, d_model=256, n_layers=16, n_heads=8, d_ff=1024):
|
| 273 |
+
super().__init__()
|
| 274 |
+
self.embed = nn.Embedding(vocab_size, d_model)
|
| 275 |
+
self.head_dim = d_model // n_heads
|
| 276 |
+
self.rope = DRNA_RoPE(self.head_dim)
|
| 277 |
+
|
| 278 |
+
self.layers = nn.ModuleList([
|
| 279 |
+
DRNA_Block(d_model, n_heads, self.head_dim, d_ff) for _ in range(n_layers)
|
| 280 |
+
])
|
| 281 |
+
|
| 282 |
+
# Pre-Norm構造の場合、最終レイヤーの後に全体のNormを置くのが一般的
|
| 283 |
+
self.final_norm = RMSNorm(d_model)
|
| 284 |
+
self.output_head = nn.Linear(d_model, vocab_size)
|
| 285 |
+
|
| 286 |
+
def forward(self, x, mask=None, pad_id=None):
|
| 287 |
+
b, s = x.shape
|
| 288 |
+
device = x.device
|
| 289 |
+
inputs = x
|
| 290 |
+
x = self.embed(x)
|
| 291 |
+
|
| 292 |
+
if mask is None or mask.sum() == 0:
|
| 293 |
+
# pad_id が整数(int/long)として有効な場合のみ pad_mask を作成
|
| 294 |
+
# 退避させた inputs を使ってパディングを判定し因果マスクを準備
|
| 295 |
+
p_id = pad_id.item() if isinstance(pad_id, torch.Tensor) else pad_id
|
| 296 |
+
pad_mask = (inputs != p_id).unsqueeze(1).unsqueeze(2) if isinstance(p_id, (int, float)) else torch.ones((1, 1, 1, s), device=device, dtype=torch.bool)
|
| 297 |
+
causal = torch.triu(torch.ones(s, s, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
|
| 298 |
+
|
| 299 |
+
# 環境(fp16/32)に応じた最小値を安全に自動計算
|
| 300 |
+
inf_value = torch.finfo(x.dtype).min if x.dtype != torch.float16 else -65500.0
|
| 301 |
+
|
| 302 |
+
# ゼロ初期化テンソルに無効領域をインプレースで直接埋める(カンニングの完全遮断)
|
| 303 |
+
mask = torch.zeros((b, 1, s, s), device=device, dtype=x.dtype).masked_fill_(causal | (~pad_mask), inf_value)
|
| 304 |
+
|
| 305 |
+
cos, sin = self.rope(x, x.size(1))
|
| 306 |
+
|
| 307 |
+
for layer in self.layers:
|
| 308 |
+
x = layer(x, cos, sin, mask=mask)
|
| 309 |
+
|
| 310 |
+
x = self.final_norm(x) # 出力前の最終同期
|
| 311 |
+
return self.output_head(x)
|
| 312 |
+
|
| 313 |
+
'''
|
| 314 |
+
260528:3値モデル(-、0、+)学習対応/活用例"D-RNA-SwiGLU-Trio"版を追加(Trio Induction)
|
| 315 |
+
260528:3値モデル(-、0、+)学習対応/活用例"D-RNA-Trio"版を追加(Trio Induction)
|
| 316 |
+
260520:maskの微調整(AMP対応)/MoE-LoRA版、vlayer版、D-RNAの活用例を汎用コード化
|
| 317 |
+
'''
|
| 318 |
+
|
| 319 |
+
'''
|
| 320 |
+
汎用型 D-RNA (Pre-Norm) License: Apache License 2.0 https://github.com/muooon/DRNA
|
| 321 |
+
Attention is all you need_started, Resonance is all you need_endure,
|
| 322 |
+
Neocognitron ― Transformer ― D‑RNA Dream Resonance Never Adjourns — it goes on...
|
| 323 |
+
'''
|
drna/drna_swiglu.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
'''
|
| 7 |
+
D‑RNA: Dual‑Helix Resonance Neural Architecture (DRNA) Pre-Norm・Kv-RoPE 版
|
| 8 |
+
仕様:Pre-Norm(RMSNorm)、SwiGLU(Activation)、Kv-RoPE(head_dim)、mask(padding + causal)
|
| 9 |
+
Transformerの全接続性を継承しつつ、二重らせん(Dual-Helix)構造による
|
| 10 |
+
「共鳴収縮」(Resonant Contraction)を物理的に再現したニューラルアーキテクチャです
|
| 11 |
+
螺旋の同期:Attention(文脈の回想)とMLP(知識の定着)を並列配置し RoPE で情報を同期
|
| 12 |
+
位相の保持:RoPE(Phase Field)を回転場として利用し、安定した相対位置を保ち早期収束を両立
|
| 13 |
+
高密度圧縮:Pre-Norm により、各らせんを安定的に収縮させ、全結合により記憶を定着させる
|
| 14 |
+
'''
|
| 15 |
+
|
| 16 |
+
class RMSNorm(nn.Module):
|
| 17 |
+
def __init__(self, d_model, eps=1e-6):
|
| 18 |
+
super().__init__()
|
| 19 |
+
self.eps = eps
|
| 20 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
| 21 |
+
|
| 22 |
+
def forward(self, x):
|
| 23 |
+
# 2乗平均の平方根で割る(平均を減算しない・中心化をしない)
|
| 24 |
+
norm = x.pow(2).mean(-1, keepdim=True)
|
| 25 |
+
x_normed = x * torch.rsqrt(norm + self.eps)
|
| 26 |
+
return self.weight * x_normed
|
| 27 |
+
|
| 28 |
+
class DRNA_RoPE(nn.Module):
|
| 29 |
+
"""二重らせんの位相を決定する回転場"""
|
| 30 |
+
def __init__(self, head_dim, base=10000):
|
| 31 |
+
super().__init__()
|
| 32 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
| 33 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 34 |
+
|
| 35 |
+
def forward(self, x, seq_len):
|
| 36 |
+
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 37 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 38 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 39 |
+
return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
|
| 40 |
+
|
| 41 |
+
def apply_drna_rope(q, k, cos, sin):
|
| 42 |
+
"""Kによる動的位相変調済み cos/sin を受け取る"""
|
| 43 |
+
def rotate_half(x):
|
| 44 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 45 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 46 |
+
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 47 |
+
|
| 48 |
+
class DRNA_Block(nn.Module):
|
| 49 |
+
"""DRNA共鳴ブロック:安定性を高めたPre-Norm直列共鳴構造"""
|
| 50 |
+
def __init__(self, d_model, n_heads, head_dim, d_ff=None, dropout=0.1):
|
| 51 |
+
super().__init__()
|
| 52 |
+
self.n_heads = n_heads
|
| 53 |
+
self.head_dim = head_dim
|
| 54 |
+
|
| 55 |
+
# らせんA:回想系 (Attention)
|
| 56 |
+
self.norm1 = RMSNorm(d_model) # 演算の前に配置
|
| 57 |
+
self.qkv = nn.Linear(d_model, d_model * 3)
|
| 58 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 59 |
+
|
| 60 |
+
# らせんB:記憶系 (MLP)
|
| 61 |
+
self.norm2 = RMSNorm(d_model) # 演算の前に配置
|
| 62 |
+
|
| 63 |
+
# SwiGLUの定石:パラメータ量をGELU版(×4倍)と合わせるため、約2.67倍(8/3)にする
|
| 64 |
+
if d_ff is None:
|
| 65 |
+
d_ff = int(2 * (d_model * 4) / 3)
|
| 66 |
+
|
| 67 |
+
# SwiGLUは入力に対して2つのLinear(Gate・Up)を並列に走らせます
|
| 68 |
+
self.w1 = nn.Linear(d_model, d_ff, bias=False) # ゲート用
|
| 69 |
+
self.w3 = nn.Linear(d_model, d_ff, bias=False) # 値用(アッププロジェクション)
|
| 70 |
+
self.w2 = nn.Linear(d_ff, d_model, bias=False) # ダウンプロジェクション
|
| 71 |
+
|
| 72 |
+
self.dropout = nn.Dropout(dropout)
|
| 73 |
+
|
| 74 |
+
def forward(self, x, cos, sin, mask=None):
|
| 75 |
+
b, s, d = x.shape
|
| 76 |
+
|
| 77 |
+
# 共通の残差(ベースとなる螺旋の軸)
|
| 78 |
+
residual = x
|
| 79 |
+
|
| 80 |
+
# らせんA (Attention) 並列方式
|
| 81 |
+
x_norm1 = self.norm1(x)
|
| 82 |
+
|
| 83 |
+
# QKV生成 (3倍のまま)
|
| 84 |
+
# ※ self.qkv(x_norm) が x_norm1 になっているか確認してください
|
| 85 |
+
qkv = self.qkv(x_norm1).reshape(b, s, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 86 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 87 |
+
|
| 88 |
+
# K因果的動的回転:一つ前の単語の K が今の単語の座標を決める
|
| 89 |
+
# 自分の情報で自分を回さないよう、Kを1つ未来にシフトさせる
|
| 90 |
+
# これにより「赤い(K)」が「猫(Q,K)」の位相を決定する構造になる
|
| 91 |
+
k_for_phase = torch.cat([torch.zeros_like(k[:, :, :1, :]), k[:, :, :-1, :]], dim=2)
|
| 92 |
+
|
| 93 |
+
# Kのエネルギーを位相(回転角)に変換
|
| 94 |
+
rt_phase = torch.tanh(k_for_phase) * math.pi
|
| 95 |
+
|
| 96 |
+
# 静的RoPE (cos, sin) を動的位相 (rt_phase) で加法定理により変調
|
| 97 |
+
# ※ apply_drna_rope に rt_phase を渡せるように関数側を調整するか、
|
| 98 |
+
# ここで dynamic_cos / sin を作って渡します
|
| 99 |
+
d_cos = (cos * torch.cos(rt_phase)) - (sin * torch.sin(rt_phase))
|
| 100 |
+
d_sin = (sin * torch.cos(rt_phase)) + (cos * torch.sin(rt_phase))
|
| 101 |
+
|
| 102 |
+
# 2重らせんをつくる (変調された座標で回転)
|
| 103 |
+
q, k = apply_drna_rope(q, k, d_cos, d_sin)
|
| 104 |
+
|
| 105 |
+
# Attention計算
|
| 106 |
+
attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 107 |
+
if mask is not None:
|
| 108 |
+
attn = attn + mask
|
| 109 |
+
|
| 110 |
+
attn = F.softmax(attn, dim=-1)
|
| 111 |
+
a_out_raw = (attn @ v).transpose(1, 2).reshape(b, s, d)
|
| 112 |
+
a_out = self.out_proj(a_out_raw)
|
| 113 |
+
|
| 114 |
+
# らせんB (MLP)
|
| 115 |
+
x_norm2 = self.norm2(x)
|
| 116 |
+
|
| 117 |
+
# SwiGLUのコア数式: Swish(xW1) * xW3(F.silu:PyTorchのSwish関数)
|
| 118 |
+
gate = F.silu(self.w1(x_norm2))
|
| 119 |
+
current_value = self.w3(x_norm2)
|
| 120 |
+
|
| 121 |
+
# 要素積(共鳴収縮の表現としても非常に相性が良いです)
|
| 122 |
+
swiglu_out = gate * current_value
|
| 123 |
+
|
| 124 |
+
# 最終投影
|
| 125 |
+
m_out = self.w2(swiglu_out)
|
| 126 |
+
|
| 127 |
+
# 並列方式
|
| 128 |
+
x = residual + self.dropout(a_out) + self.dropout(m_out)
|
| 129 |
+
|
| 130 |
+
return x
|
| 131 |
+
|
| 132 |
+
class DRNA_Model(nn.Module):
|
| 133 |
+
"""汎用 DRNA モデルコンテナ(安定化 Pre-Norm 版)"""
|
| 134 |
+
def __init__(self, vocab_size, d_model=256, n_layers=16, n_heads=8, d_ff=1024):
|
| 135 |
+
super().__init__()
|
| 136 |
+
self.embed = nn.Embedding(vocab_size, d_model)
|
| 137 |
+
self.head_dim = d_model // n_heads
|
| 138 |
+
self.rope = DRNA_RoPE(self.head_dim)
|
| 139 |
+
|
| 140 |
+
self.layers = nn.ModuleList([
|
| 141 |
+
DRNA_Block(d_model, n_heads, self.head_dim, d_ff) for _ in range(n_layers)
|
| 142 |
+
])
|
| 143 |
+
|
| 144 |
+
# Pre-Norm構造の場合、最終レイヤーの後に全体のNormを置くのが一般的
|
| 145 |
+
self.final_norm = RMSNorm(d_model)
|
| 146 |
+
self.output_head = nn.Linear(d_model, vocab_size)
|
| 147 |
+
|
| 148 |
+
def forward(self, x, mask=None, pad_id=None):
|
| 149 |
+
b, s = x.shape
|
| 150 |
+
device = x.device
|
| 151 |
+
inputs = x
|
| 152 |
+
x = self.embed(x)
|
| 153 |
+
|
| 154 |
+
if mask is None or mask.sum() == 0:
|
| 155 |
+
# pad_id が整数(int/long)として有効な場合のみ pad_mask を作成
|
| 156 |
+
# 退避させた inputs を使ってパディングを判定し因果マスクを準備
|
| 157 |
+
p_id = pad_id.item() if isinstance(pad_id, torch.Tensor) else pad_id
|
| 158 |
+
pad_mask = (inputs != p_id).unsqueeze(1).unsqueeze(2) if isinstance(p_id, (int, float)) else torch.ones((1, 1, 1, s), device=device, dtype=torch.bool)
|
| 159 |
+
causal = torch.triu(torch.ones(s, s, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
|
| 160 |
+
|
| 161 |
+
# 環境(fp16/32)に応じた最小値を安全に自動計算
|
| 162 |
+
inf_value = torch.finfo(x.dtype).min if x.dtype != torch.float16 else -65500.0
|
| 163 |
+
|
| 164 |
+
# ゼロ初期化テンソルに無効領域をインプレースで直接埋める(カンニングの完全遮断)
|
| 165 |
+
mask = torch.zeros((b, 1, s, s), device=device, dtype=x.dtype).masked_fill_(causal | (~pad_mask), inf_value)
|
| 166 |
+
|
| 167 |
+
cos, sin = self.rope(x, x.size(1))
|
| 168 |
+
|
| 169 |
+
for layer in self.layers:
|
| 170 |
+
x = layer(x, cos, sin, mask=mask)
|
| 171 |
+
|
| 172 |
+
x = self.final_norm(x) # 出力前の最終同期
|
| 173 |
+
return self.output_head(x)
|
| 174 |
+
|
| 175 |
+
'''
|
| 176 |
+
260528:SwiGLU構成/動的ゲート制御(d_ff で Dropout せず、w2 で d_model 凝縮直後で適用)
|
| 177 |
+
260520:maskの微調整(AMP対応)/MoE-LoRA版、vlayer版、D-RNAの活用例を汎用コード化
|
| 178 |
+
260507:Kによる回転で文脈に単語を沿わせ2重らせんの干渉による取捨選択とホログラム合成を可能にする
|
| 179 |
+
260505:model構成から学習解像度を自動化、汎用 mask の精度への適正化、RMSNormへの移行
|
| 180 |
+
260503:padding を引数で指定できるよう変更
|
| 181 |
+
# 例:一般的な Tokenizer の pad_id が 0 の場合
|
| 182 |
+
output = model(input_ids, pad_id=0)
|
| 183 |
+
# 例:Hugging Face 等の tokenizer を使っている場合
|
| 184 |
+
output = model(input_ids, pad_id=tokenizer.pad_token_id)
|
| 185 |
+
260502:変数名を正確化(head_dim)、汎用 mask に変更し padding 等に対応可
|
| 186 |
+
'''
|
| 187 |
+
|
| 188 |
+
'''
|
| 189 |
+
汎用型 D-RNA (Pre-Norm) License: Apache License 2.0 https://github.com/muooon/DRNA
|
| 190 |
+
Attention is all you need_started, Resonance is all you need_endure,
|
| 191 |
+
Neocognitron ― Transformer ― D‑RNA Dream Resonance Never Adjourns — it goes on...
|
| 192 |
+
'''
|
drna/drna_triox.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
'''
|
| 7 |
+
D‑RNA: Dual‑Helix Resonance Neural Architecture (DRNA) Pre-Norm・Kv-RoPE 版
|
| 8 |
+
仕様:Pre-Norm(RMSNorm)、GELU(Activation)、Kv-RoPE(head_dim)、mask(padding + causal)
|
| 9 |
+
Transformerの全接続性を継承しつつ、二重らせん(Dual-Helix)構造による
|
| 10 |
+
「共鳴収縮」(Resonant Contraction)を物理的に再現したニューラルアーキテクチャです
|
| 11 |
+
D-RNA の位相設計と Trio Induction system により3値学習を STE に頼らず安定的に行えます
|
| 12 |
+
これは STE で機能しない optimiser などを3値学習へ活用できるようになります
|
| 13 |
+
将来的に3値モデルを位相差の重ねによる疑似重みをつくり、これを学習対象にして3値学習もおこなえるはずです
|
| 14 |
+
つまり学習元も3値モデルにできるはずです、推論も学習も3値で済むようになる最初の1歩です
|
| 15 |
+
'''
|
| 16 |
+
|
| 17 |
+
# 3値誘導制御コア(モデルの書き換え、3値ブレンド、ペナルティ計算、結晶化)
|
| 18 |
+
class TernaryTrainingManager:
|
| 19 |
+
"""
|
| 20 |
+
D-RNAのコードに触れることなく、3値誘導の全ライフサイクルを統括する抽象化マネージャー。
|
| 21 |
+
"""
|
| 22 |
+
def __init__(self, model, warmup_steps=100, max_lambda=1.0):
|
| 23 |
+
self.model = model
|
| 24 |
+
self.warmup_steps = warmup_steps
|
| 25 |
+
self.max_lambda = max_lambda
|
| 26 |
+
self.current_step = 0
|
| 27 |
+
self.total_steps = 0
|
| 28 |
+
|
| 29 |
+
# 内部で利用するステッププロバイダー関数
|
| 30 |
+
def step_provider():
|
| 31 |
+
return self.current_step, self.total_steps
|
| 32 |
+
|
| 33 |
+
# 1. モデル内の全2次元重みにフックを外付け
|
| 34 |
+
for name, module in self.model.named_modules():
|
| 35 |
+
if isinstance(module, nn.Linear):
|
| 36 |
+
# ★ 拡張性:LLMの生命線である「embed」と「output_head」は絶対に3値化しない
|
| 37 |
+
if "embed" in name or "output_head" in name:
|
| 38 |
+
continue
|
| 39 |
+
if not hasattr(module, "raw_weight"):
|
| 40 |
+
# 元の実数重みを聖域(raw_weight)に退避
|
| 41 |
+
module.register_parameter("raw_weight", nn.Parameter(module.weight.data.clone()))
|
| 42 |
+
delattr(module, "weight")
|
| 43 |
+
module.register_buffer("weight", module.raw_weight.data.clone())
|
| 44 |
+
module.register_forward_pre_hook(TernaryWeightHook(step_provider, self.warmup_steps))
|
| 45 |
+
|
| 46 |
+
def amend_loss(self, task_loss, step, total_steps):
|
| 47 |
+
"""
|
| 48 |
+
【ループ内抽象化用】
|
| 49 |
+
メインのタスク損失(CrossEntropy等)を受け取り、
|
| 50 |
+
現在のステップに応じた3値結晶化ペナルティを自動計算して合算した損失を返します。
|
| 51 |
+
"""
|
| 52 |
+
self.current_step = step
|
| 53 |
+
self.total_steps = total_steps
|
| 54 |
+
|
| 55 |
+
blend_ratio = get_ternary_schedule(step, total_steps, self.warmup_steps)
|
| 56 |
+
current_lambda = blend_ratio * self.max_lambda
|
| 57 |
+
|
| 58 |
+
if current_lambda == 0.0:
|
| 59 |
+
return task_loss
|
| 60 |
+
|
| 61 |
+
# 3値誘導トリプレット・ペナルティ自動計算
|
| 62 |
+
ternary_penalty = 0.0
|
| 63 |
+
for name, param in self.model.named_parameters():
|
| 64 |
+
# 2次元以上の重み(LinearやEmbeddingの実数聖域)のみを対象とする
|
| 65 |
+
if "raw_weight" in name and param.dim() >= 2:
|
| 66 |
+
# W * (W - 1) * (W + 1) = W^3 - W を0に近づける(-1, 0, 1への収束強制力)
|
| 67 |
+
ternary_penalty += torch.mean(param * (param - 1.0) * (param + 1.0)) ** 2
|
| 68 |
+
|
| 69 |
+
return task_loss + current_lambda * ternary_penalty
|
| 70 |
+
|
| 71 |
+
def export_ternary(self):
|
| 72 |
+
"""学習終了後、モデルの全2次元重みを完全な[-1.0, 0.0, 1.0]へ固定(結晶化)する"""
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
for name, module in self.model.named_modules():
|
| 75 |
+
if isinstance(module, nn.Linear):
|
| 76 |
+
# 拡張性:LLMの生命線である「embed」と「output_head」は絶対に3値化しない
|
| 77 |
+
if "embed" in name or "output_head" in name:
|
| 78 |
+
continue
|
| 79 |
+
# raw_weight があればそっちを本体として使う
|
| 80 |
+
if hasattr(module, "raw_weight"):
|
| 81 |
+
param = module.raw_weight
|
| 82 |
+
else:
|
| 83 |
+
param = module.weight
|
| 84 |
+
# 学習中と同じ写像:tanh(3w) で soft 3値ターゲットを作る
|
| 85 |
+
soft = torch.tanh(param * 3.0)
|
| 86 |
+
# soft を hard 3値に潰す
|
| 87 |
+
ternary = torch.zeros_like(soft)
|
| 88 |
+
ternary[soft > 0.08] = 1.0
|
| 89 |
+
ternary[soft < -0.08] = -1.0
|
| 90 |
+
# 実数パラメータ自体を3値に上書きして完全固定
|
| 91 |
+
if hasattr(module, "raw_weight"):
|
| 92 |
+
module.raw_weight.copy_(ternary)
|
| 93 |
+
module.weight.copy_(ternary)
|
| 94 |
+
|
| 95 |
+
return self.model
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
# 3値誘導外付けフックシステム(2次元重みを3値ブレンド)
|
| 99 |
+
def get_ternary_schedule(step, total_steps, warmup_steps=100):
|
| 100 |
+
"""逆転コサインアニーリングスケジューラ"""
|
| 101 |
+
if step < warmup_steps:
|
| 102 |
+
return 0.0
|
| 103 |
+
anneal_steps = total_steps - warmup_steps
|
| 104 |
+
progress = (step - warmup_steps) / anneal_steps
|
| 105 |
+
return 1.0 - (0.5 * (1.0 + math.cos(progress * math.pi)))
|
| 106 |
+
|
| 107 |
+
def get_soft_ternary_weight(param, step, total_steps, warmup_steps=100):
|
| 108 |
+
"""勾配直通バイパス型 3値ブレンド関数"""
|
| 109 |
+
if param.dim() < 2: # 1次元パラメータは保護
|
| 110 |
+
return param
|
| 111 |
+
blend_ratio = get_ternary_schedule(step, total_steps, warmup_steps)
|
| 112 |
+
if blend_ratio == 0.0:
|
| 113 |
+
return param
|
| 114 |
+
with torch.no_grad():
|
| 115 |
+
ternary_target = torch.tanh(param * 3.0)
|
| 116 |
+
# 【最重要】生の勾配を裏に直通させるバイパス構造
|
| 117 |
+
return param + blend_ratio * (ternary_target - param)
|
| 118 |
+
|
| 119 |
+
class TernaryWeightHook:
|
| 120 |
+
def __init__(self, step_provider, warmup_steps=100):
|
| 121 |
+
self.step_provider = step_provider
|
| 122 |
+
self.warmup_steps = warmup_steps
|
| 123 |
+
|
| 124 |
+
def __call__(self, module, inputs):
|
| 125 |
+
step, total_steps = self.step_provider()
|
| 126 |
+
if step is not None and total_steps is not None:
|
| 127 |
+
# バックアップした実数重み(raw_weight)から疑似3値重みを計算し、一時的に上書き
|
| 128 |
+
module.weight.data = get_soft_ternary_weight(module.raw_weight, step, total_steps, self.warmup_steps)
|
| 129 |
+
|
| 130 |
+
def apply_trio_induction(model, step_provider, warmup_steps=100):
|
| 131 |
+
"""モデルを汚さずに外側から3値化プラグインを刺す関数"""
|
| 132 |
+
for name, module in model.named_modules():
|
| 133 |
+
if isinstance(module, nn.Linear):
|
| 134 |
+
# ★ 拡張性:LLMの生命線である「embed」と「output_head」は絶対に3値化しない
|
| 135 |
+
if "embed" in name or "output_head" in name:
|
| 136 |
+
continue
|
| 137 |
+
if not hasattr(module, "raw_weight"):
|
| 138 |
+
# 元の実数重みを聖域(raw_weight)に退避
|
| 139 |
+
module.register_parameter("raw_weight", nn.Parameter(module.weight.data.clone()))
|
| 140 |
+
delattr(module, "weight")
|
| 141 |
+
module.register_buffer("weight", module.raw_weight.data.clone())
|
| 142 |
+
module.register_forward_pre_hook(TernaryWeightHook(step_provider, warmup_steps))
|
| 143 |
+
|
| 144 |
+
# D-RNA 3値化のパニックを100%いなすコンサルタント(RMSNorm 修正)
|
| 145 |
+
class RMSNorm(nn.Module):
|
| 146 |
+
"""【修正版】3値化の歪みをリセットする中心化・標準化型防波堤"""
|
| 147 |
+
def __init__(self, d_model, eps=1e-8):
|
| 148 |
+
super().__init__()
|
| 149 |
+
self.eps = eps
|
| 150 |
+
# 3値化の歪みによってズレた「音量の軸」を、フル精度の実数で強制的に中心に戻すバイアス
|
| 151 |
+
self.bias = nn.Parameter(torch.zeros(d_model))
|
| 152 |
+
# 3値化のせいで極端にインフレ・デフレした次元ごとの音量を、個別にジャストフィットさせる実数スケール
|
| 153 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
# 1. 各次元ごとの平均と分散を計算
|
| 157 |
+
mean = x.mean(-1, keepdim=True)
|
| 158 |
+
var = x.var(-1, keepdim=True, unbiased=False)
|
| 159 |
+
|
| 160 |
+
# 2. 3値化による「歪み・偏り」を、ここで完全にリセット(標準化)する
|
| 161 |
+
x_normed = (x - mean) * torch.rsqrt(var + self.eps)
|
| 162 |
+
|
| 163 |
+
# 3. リセットされた綺麗な状態に対して、フル精度のオプティマイザが最適なスケールとオフセットを施す
|
| 164 |
+
return self.weight * x_normed + self.bias
|
| 165 |
+
|
| 166 |
+
class DRNA_RoPE(nn.Module):
|
| 167 |
+
"""二重らせんの位相を決定する回転場"""
|
| 168 |
+
def __init__(self, head_dim, base=10000):
|
| 169 |
+
super().__init__()
|
| 170 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
| 171 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 172 |
+
|
| 173 |
+
def forward(self, x, seq_len):
|
| 174 |
+
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 175 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 176 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 177 |
+
return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
|
| 178 |
+
|
| 179 |
+
def apply_drna_rope(q, k, cos, sin):
|
| 180 |
+
"""Kによる動的位相変調済み cos/sin を受け取る"""
|
| 181 |
+
def rotate_half(x):
|
| 182 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 183 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 184 |
+
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 185 |
+
|
| 186 |
+
class DRNA_Block(nn.Module):
|
| 187 |
+
"""DRNA共鳴ブロック:安定性を高めたPre-Norm直列共鳴構造"""
|
| 188 |
+
def __init__(self, d_model, n_heads, head_dim, d_ff=None, dropout=0.1):
|
| 189 |
+
super().__init__()
|
| 190 |
+
self.n_heads = n_heads
|
| 191 |
+
self.head_dim = head_dim
|
| 192 |
+
|
| 193 |
+
# らせんA: 回想系 (Attention)
|
| 194 |
+
self.norm1 = RMSNorm(d_model) # 演算の前に配置
|
| 195 |
+
self.qkv = nn.Linear(d_model, d_model * 3)
|
| 196 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 197 |
+
|
| 198 |
+
# らせんB: 記憶系 (MLP)
|
| 199 |
+
self.norm2 = RMSNorm(d_model) # 演算の前に配置
|
| 200 |
+
d_ff = d_ff or d_model * 4
|
| 201 |
+
self.mlp = nn.Sequential(
|
| 202 |
+
nn.Linear(d_model, d_ff),
|
| 203 |
+
nn.GELU(), # VRAM抑制は ReLU (別レイヤの干渉で0勾配にならない「可能性」あり)
|
| 204 |
+
nn.Dropout(dropout),
|
| 205 |
+
nn.Linear(d_ff, d_model)
|
| 206 |
+
)
|
| 207 |
+
self.dropout = nn.Dropout(dropout)
|
| 208 |
+
|
| 209 |
+
def forward(self, x, cos, sin, mask=None):
|
| 210 |
+
b, s, d = x.shape
|
| 211 |
+
|
| 212 |
+
# 1. 共通の残差(ベースとなる螺旋の軸)
|
| 213 |
+
residual = x
|
| 214 |
+
|
| 215 |
+
# 2. らせんA (Attention) 並列方式
|
| 216 |
+
x_norm1 = self.norm1(x)
|
| 217 |
+
|
| 218 |
+
# QKV生成 (3倍のまま)
|
| 219 |
+
# ※ self.qkv(x_norm) が x_norm1 になっているか確認してください
|
| 220 |
+
qkv = self.qkv(x_norm1).reshape(b, s, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 221 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 222 |
+
|
| 223 |
+
# K因果的動的回転:一つ前の単語の K が今の単語の座標を決める
|
| 224 |
+
# 自分の情報で自分を回さないよう、Kを1つ未来にシフトさせる
|
| 225 |
+
# これにより「赤い(K)」が「猫(Q,K)」の位相を決定する構造になる
|
| 226 |
+
k_for_phase = torch.cat([torch.zeros_like(k[:, :, :1, :]), k[:, :, :-1, :]], dim=2)
|
| 227 |
+
|
| 228 |
+
# Kのエネルギーを位相(回転角)に変換
|
| 229 |
+
rt_phase = torch.tanh(k_for_phase) * math.pi
|
| 230 |
+
|
| 231 |
+
# 静的RoPE (cos, sin) を動的位相 (rt_phase) で加法定理により変調
|
| 232 |
+
# ※ apply_drna_rope に rt_phase を渡せるように関数側を調整するか、
|
| 233 |
+
# ここで dynamic_cos / sin を作って渡します
|
| 234 |
+
d_cos = (cos * torch.cos(rt_phase)) - (sin * torch.sin(rt_phase))
|
| 235 |
+
d_sin = (sin * torch.cos(rt_phase)) + (cos * torch.sin(rt_phase))
|
| 236 |
+
|
| 237 |
+
# 2重らせんをつくる (変調された座標で回転)
|
| 238 |
+
q, k = apply_drna_rope(q, k, d_cos, d_sin)
|
| 239 |
+
|
| 240 |
+
# Attention計算
|
| 241 |
+
attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 242 |
+
if mask is not None:
|
| 243 |
+
attn = attn + mask
|
| 244 |
+
|
| 245 |
+
attn = F.softmax(attn, dim=-1)
|
| 246 |
+
a_out_raw = (attn @ v).transpose(1, 2).reshape(b, s, d)
|
| 247 |
+
a_out = self.out_proj(a_out_raw)
|
| 248 |
+
|
| 249 |
+
# 3. らせんB (MLP)
|
| 250 |
+
x_norm2 = self.norm2(x)
|
| 251 |
+
m_out = self.mlp(x_norm2)
|
| 252 |
+
|
| 253 |
+
# 4. 並列方式
|
| 254 |
+
x = residual + self.dropout(a_out) + self.dropout(m_out)
|
| 255 |
+
|
| 256 |
+
return x
|
| 257 |
+
|
| 258 |
+
class DRNA_Model(nn.Module):
|
| 259 |
+
"""汎用 DRNA モデルコンテナ(安定化 Pre-Norm 版)"""
|
| 260 |
+
def __init__(self, vocab_size, d_model=256, n_layers=16, n_heads=8, d_ff=1024):
|
| 261 |
+
super().__init__()
|
| 262 |
+
self.embed = nn.Embedding(vocab_size, d_model)
|
| 263 |
+
self.head_dim = d_model // n_heads
|
| 264 |
+
self.rope = DRNA_RoPE(self.head_dim)
|
| 265 |
+
|
| 266 |
+
self.layers = nn.ModuleList([
|
| 267 |
+
DRNA_Block(d_model, n_heads, self.head_dim, d_ff) for _ in range(n_layers)
|
| 268 |
+
])
|
| 269 |
+
|
| 270 |
+
# Pre-Norm構造の場合、最終レイヤーの後に全体のNormを置くのが一般的
|
| 271 |
+
self.final_norm = RMSNorm(d_model)
|
| 272 |
+
self.output_head = nn.Linear(d_model, vocab_size)
|
| 273 |
+
|
| 274 |
+
def forward(self, x, mask=None, pad_id=None):
|
| 275 |
+
b, s = x.shape
|
| 276 |
+
device = x.device
|
| 277 |
+
inputs = x
|
| 278 |
+
x = self.embed(x)
|
| 279 |
+
|
| 280 |
+
if mask is None or mask.sum() == 0:
|
| 281 |
+
# pad_id が整数(int/long)として有効な場合のみ pad_mask を作成
|
| 282 |
+
# 退避させた inputs を使ってパディングを判定し因果マスクを準備
|
| 283 |
+
p_id = pad_id.item() if isinstance(pad_id, torch.Tensor) else pad_id
|
| 284 |
+
pad_mask = (inputs != p_id).unsqueeze(1).unsqueeze(2) if isinstance(p_id, (int, float)) else torch.ones((1, 1, 1, s), device=device, dtype=torch.bool)
|
| 285 |
+
causal = torch.triu(torch.ones(s, s, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
|
| 286 |
+
|
| 287 |
+
# 環境(fp16/32)に応じた最小値を安全に自動計算
|
| 288 |
+
inf_value = torch.finfo(x.dtype).min if x.dtype != torch.float16 else -65500.0
|
| 289 |
+
|
| 290 |
+
# ゼロ初期化テンソルに無効領域をインプレースで直接埋める(カンニングの完全遮断)
|
| 291 |
+
mask = torch.zeros((b, 1, s, s), device=device, dtype=x.dtype).masked_fill_(causal | (~pad_mask), inf_value)
|
| 292 |
+
|
| 293 |
+
cos, sin = self.rope(x, x.size(1))
|
| 294 |
+
|
| 295 |
+
for layer in self.layers:
|
| 296 |
+
x = layer(x, cos, sin, mask=mask)
|
| 297 |
+
|
| 298 |
+
x = self.final_norm(x) # 出力前の最終同期
|
| 299 |
+
return self.output_head(x)
|
| 300 |
+
|
| 301 |
+
'''
|
| 302 |
+
260528:3値モデル(-、0、+)学習対応/活用例"D-RNA-Trio"版を追加(Trio Induction)
|
| 303 |
+
260520:maskの微調整(AMP対応)/MoE-LoRA版、vlayer版、D-RNAの活用例を汎用コード化
|
| 304 |
+
'''
|
| 305 |
+
|
| 306 |
+
'''
|
| 307 |
+
汎用型 D-RNA (Pre-Norm) License: Apache License 2.0 https://github.com/muooon/DRNA
|
| 308 |
+
Attention is all you need_started, Resonance is all you need_endure,
|
| 309 |
+
Neocognitron ― Transformer ― D‑RNA Dream Resonance Never Adjourns — it goes on...
|
| 310 |
+
'''
|
drna/drna_triox_log.txt
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
none-STE、ternary_schedule、amend_loss、
|
| 2 |
+
---
|
| 3 |
+
1.58-bit Ternary Model、
|
| 4 |
+
This directory contains the training logs for our STE-free Ternary (1.58-bit) Model
|
| 5 |
+
---
|
| 6 |
+
(英語OpenWebText・10万歩)を開始します...
|
| 7 |
+
Step 0/32000 | CE Loss: 12.2871 | Trio Blend: 0.0%
|
| 8 |
+
Step 100/32000 | CE Loss: 8.5651 | Trio Blend: 0.0%
|
| 9 |
+
Step 200/32000 | CE Loss: 5.9627 | Trio Blend: 0.0%
|
| 10 |
+
Step 300/32000 | CE Loss: 5.4165 | Trio Blend: 0.0%
|
| 11 |
+
Step 400/32000 | CE Loss: 5.2959 | Trio Blend: 0.0%
|
| 12 |
+
Step 500/32000 | CE Loss: 4.6366 | Trio Blend: 0.0%
|
| 13 |
+
Step 600/32000 | CE Loss: 5.0238 | Trio Blend: 0.0%
|
| 14 |
+
Step 700/32000 | CE Loss: 4.3707 | Trio Blend: 0.0%
|
| 15 |
+
Step 800/32000 | CE Loss: 7.5176 | Trio Blend: 0.0%
|
| 16 |
+
Step 900/32000 | CE Loss: 4.5658 | Trio Blend: 0.0%
|
| 17 |
+
Step 1000/32000 | CE Loss: 4.1014 | Trio Blend: 0.0%
|
| 18 |
+
Step 1100/32000 | CE Loss: 4.3125 | Trio Blend: 0.0%
|
| 19 |
+
Step 1200/32000 | CE Loss: 5.5312 | Trio Blend: 0.0%
|
| 20 |
+
Step 1300/32000 | CE Loss: 5.2188 | Trio Blend: 0.0%
|
| 21 |
+
Step 1400/32000 | CE Loss: 4.1875 | Trio Blend: 0.0%
|
| 22 |
+
Step 1500/32000 | CE Loss: 3.2500 | Trio Blend: 0.0%
|
| 23 |
+
Step 1600/32000 | CE Loss: 2.9844 | Trio Blend: 0.0%
|
| 24 |
+
Step 1700/32000 | CE Loss: 3.2812 | Trio Blend: 0.0%
|
| 25 |
+
Step 1800/32000 | CE Loss: 2.8438 | Trio Blend: 0.0%
|
| 26 |
+
Step 1900/32000 | CE Loss: 2.9062 | Trio Blend: 0.0%
|
| 27 |
+
Step 2000/32000 | CE Loss: 3.2812 | Trio Blend: 0.0%
|
| 28 |
+
Step 2100/32000 | CE Loss: 4.1875 | Trio Blend: 0.0%
|
| 29 |
+
Step 2200/32000 | CE Loss: 3.2344 | Trio Blend: 0.0%
|
| 30 |
+
Step 2300/32000 | CE Loss: 3.4844 | Trio Blend: 0.0%
|
| 31 |
+
Step 2400/32000 | CE Loss: 3.2344 | Trio Blend: 0.0%
|
| 32 |
+
Step 2500/32000 | CE Loss: 4.3438 | Trio Blend: 0.0%
|
| 33 |
+
Step 2600/32000 | CE Loss: 3.0781 | Trio Blend: 0.0%
|
| 34 |
+
Step 2700/32000 | CE Loss: 3.6250 | Trio Blend: 0.0%
|
| 35 |
+
Step 2800/32000 | CE Loss: 3.2188 | Trio Blend: 0.0%
|
| 36 |
+
Step 2900/32000 | CE Loss: 2.7188 | Trio Blend: 0.0%
|
| 37 |
+
Step 3000/32000 | CE Loss: 2.9375 | Trio Blend: 0.0%
|
| 38 |
+
Step 3100/32000 | CE Loss: 5.3438 | Trio Blend: 0.0%
|
| 39 |
+
Step 3200/32000 | CE Loss: 3.7656 | Trio Blend: 0.0%
|
| 40 |
+
Step 3300/32000 | CE Loss: 2.7031 | Trio Blend: 0.0%
|
| 41 |
+
Step 3400/32000 | CE Loss: 2.6094 | Trio Blend: 0.0%
|
| 42 |
+
Step 3500/32000 | CE Loss: 4.0938 | Trio Blend: 0.0%
|
| 43 |
+
Step 3600/32000 | CE Loss: 3.0000 | Trio Blend: 0.0%
|
| 44 |
+
Step 3700/32000 | CE Loss: 3.0000 | Trio Blend: 0.0%
|
| 45 |
+
Step 3800/32000 | CE Loss: 3.0938 | Trio Blend: 0.0%
|
| 46 |
+
Step 3900/32000 | CE Loss: 2.7656 | Trio Blend: 0.0%
|
| 47 |
+
Step 4000/32000 | CE Loss: 2.8281 | Trio Blend: 0.0%
|
| 48 |
+
Step 4100/32000 | CE Loss: 2.7500 | Trio Blend: 0.0%
|
| 49 |
+
Step 4200/32000 | CE Loss: 2.8906 | Trio Blend: 0.0%
|
| 50 |
+
Step 4300/32000 | CE Loss: 3.1250 | Trio Blend: 0.0%
|
| 51 |
+
Step 4400/32000 | CE Loss: 2.6094 | Trio Blend: 0.0%
|
| 52 |
+
Step 4500/32000 | CE Loss: 2.8750 | Trio Blend: 0.0%
|
| 53 |
+
Step 4600/32000 | CE Loss: 2.4688 | Trio Blend: 0.0%
|
| 54 |
+
Step 4700/32000 | CE Loss: 2.6875 | Trio Blend: 0.0%
|
| 55 |
+
Step 4800/32000 | CE Loss: 2.4531 | Trio Blend: 0.0%
|
| 56 |
+
Step 4900/32000 | CE Loss: 2.6250 | Trio Blend: 0.0%
|
| 57 |
+
Step 5000/32000 | CE Loss: 2.6875 | Trio Blend: 0.0%
|
| 58 |
+
Step 5100/32000 | CE Loss: 2.7344 | Trio Blend: 0.0%
|
| 59 |
+
Step 5200/32000 | CE Loss: 2.6406 | Trio Blend: 0.0%
|
| 60 |
+
Step 5300/32000 | CE Loss: 2.8281 | Trio Blend: 0.0%
|
| 61 |
+
Step 5400/32000 | CE Loss: 2.6094 | Trio Blend: 0.0%
|
| 62 |
+
Step 5500/32000 | CE Loss: 2.5938 | Trio Blend: 0.0%
|
| 63 |
+
Step 5600/32000 | CE Loss: 2.7031 | Trio Blend: 0.0%
|
| 64 |
+
Step 5700/32000 | CE Loss: 3.0156 | Trio Blend: 0.0%
|
| 65 |
+
Step 5800/32000 | CE Loss: 2.5781 | Trio Blend: 0.0%
|
| 66 |
+
Step 5900/32000 | CE Loss: 2.7188 | Trio Blend: 0.0%
|
| 67 |
+
Step 6000/32000 | CE Loss: 2.6250 | Trio Blend: 0.0%
|
| 68 |
+
Step 6100/32000 | CE Loss: 3.5312 | Trio Blend: 0.0%
|
| 69 |
+
Step 6200/32000 | CE Loss: 2.8125 | Trio Blend: 0.0%
|
| 70 |
+
Step 6300/32000 | CE Loss: 2.7656 | Trio Blend: 0.0%
|
| 71 |
+
Step 6400/32000 | CE Loss: 2.6875 | Trio Blend: 0.1%
|
| 72 |
+
Step 6500/32000 | CE Loss: 2.7031 | Trio Blend: 0.1%
|
| 73 |
+
Step 6600/32000 | CE Loss: 2.8594 | Trio Blend: 0.1%
|
| 74 |
+
Step 6700/32000 | CE Loss: 2.6562 | Trio Blend: 0.2%
|
| 75 |
+
Step 6800/32000 | CE Loss: 2.8906 | Trio Blend: 0.2%
|
| 76 |
+
Step 6900/32000 | CE Loss: 2.4844 | Trio Blend: 0.3%
|
| 77 |
+
Step 7000/32000 | CE Loss: 2.5469 | Trio Blend: 0.4%
|
| 78 |
+
Step 7100/32000 | CE Loss: 2.8906 | Trio Blend: 0.4%
|
| 79 |
+
Step 7200/32000 | CE Loss: 2.6094 | Trio Blend: 0.5%
|
| 80 |
+
Step 7300/32000 | CE Loss: 2.7188 | Trio Blend: 0.6%
|
| 81 |
+
Step 7400/32000 | CE Loss: 2.4531 | Trio Blend: 0.7%
|
| 82 |
+
Step 7500/32000 | CE Loss: 3.1094 | Trio Blend: 0.8%
|
| 83 |
+
Step 7600/32000 | CE Loss: 2.7500 | Trio Blend: 0.9%
|
| 84 |
+
Step 7700/32000 | CE Loss: 2.7656 | Trio Blend: 1.1%
|
| 85 |
+
Step 7800/32000 | CE Loss: 2.6094 | Trio Blend: 1.2%
|
| 86 |
+
Step 7900/32000 | CE Loss: 2.4844 | Trio Blend: 1.3%
|
| 87 |
+
Step 8000/32000 | CE Loss: 2.6094 | Trio Blend: 1.5%
|
| 88 |
+
Step 8100/32000 | CE Loss: 2.8438 | Trio Blend: 1.6%
|
| 89 |
+
Step 8200/32000 | CE Loss: 2.6250 | Trio Blend: 1.8%
|
| 90 |
+
Step 8300/32000 | CE Loss: 2.7656 | Trio Blend: 1.9%
|
| 91 |
+
Step 8400/32000 | CE Loss: 2.5781 | Trio Blend: 2.1%
|
| 92 |
+
Step 8500/32000 | CE Loss: 3.0312 | Trio Blend: 2.3%
|
| 93 |
+
Step 8600/32000 | CE Loss: 2.6719 | Trio Blend: 2.4%
|
| 94 |
+
Step 8700/32000 | CE Loss: 3.6250 | Trio Blend: 2.6%
|
| 95 |
+
Step 8800/32000 | CE Loss: 2.8125 | Trio Blend: 2.8%
|
| 96 |
+
Step 8900/32000 | CE Loss: 2.6094 | Trio Blend: 3.0%
|
| 97 |
+
Step 9000/32000 | CE Loss: 2.6562 | Trio Blend: 3.2%
|
| 98 |
+
Step 9100/32000 | CE Loss: 2.9062 | Trio Blend: 3.5%
|
| 99 |
+
Step 9200/32000 | CE Loss: 2.6875 | Trio Blend: 3.7%
|
| 100 |
+
Step 9300/32000 | CE Loss: 2.6875 | Trio Blend: 3.9%
|
| 101 |
+
Step 9400/32000 | CE Loss: 2.7188 | Trio Blend: 4.2%
|
| 102 |
+
Step 9500/32000 | CE Loss: 2.4219 | Trio Blend: 4.4%
|
| 103 |
+
Step 9600/32000 | CE Loss: 2.6562 | Trio Blend: 4.7%
|
| 104 |
+
Step 9700/32000 | CE Loss: 2.7500 | Trio Blend: 4.9%
|
| 105 |
+
Step 9800/32000 | CE Loss: 2.7500 | Trio Blend: 5.2%
|
| 106 |
+
Step 9900/32000 | CE Loss: 2.7188 | Trio Blend: 5.4%
|
| 107 |
+
Step 10000/32000 | CE Loss: 2.6562 | Trio Blend: 5.7%
|
| 108 |
+
Step 10100/32000 | CE Loss: 2.8594 | Trio Blend: 6.0%
|
| 109 |
+
Step 10200/32000 | CE Loss: 2.4531 | Trio Blend: 6.3%
|
| 110 |
+
Step 10300/32000 | CE Loss: 2.5781 | Trio Blend: 6.6%
|
| 111 |
+
Step 10400/32000 | CE Loss: 2.7344 | Trio Blend: 6.9%
|
| 112 |
+
Step 10500/32000 | CE Loss: 2.5156 | Trio Blend: 7.2%
|
| 113 |
+
Step 10600/32000 | CE Loss: 4.3438 | Trio Blend: 7.5%
|
| 114 |
+
Step 10700/32000 | CE Loss: 2.5781 | Trio Blend: 7.8%
|
| 115 |
+
Step 10800/32000 | CE Loss: 2.7344 | Trio Blend: 8.2%
|
| 116 |
+
Step 10900/32000 | CE Loss: 2.2812 | Trio Blend: 8.5%
|
| 117 |
+
Step 11000/32000 | CE Loss: 2.8438 | Trio Blend: 8.9%
|
| 118 |
+
Step 11100/32000 | CE Loss: 2.6562 | Trio Blend: 9.2%
|
| 119 |
+
Step 11200/32000 | CE Loss: 2.6875 | Trio Blend: 9.5%
|
| 120 |
+
Step 11300/32000 | CE Loss: 2.5312 | Trio Blend: 9.9%
|
| 121 |
+
Step 11400/32000 | CE Loss: 2.5000 | Trio Blend: 10.3%
|
| 122 |
+
Step 11500/32000 | CE Loss: 2.5156 | Trio Blend: 10.6%
|
| 123 |
+
Step 11600/32000 | CE Loss: 2.5156 | Trio Blend: 11.0%
|
| 124 |
+
Step 11700/32000 | CE Loss: 2.6562 | Trio Blend: 11.4%
|
| 125 |
+
Step 11800/32000 | CE Loss: 2.5625 | Trio Blend: 11.8%
|
| 126 |
+
Step 11900/32000 | CE Loss: 2.7188 | Trio Blend: 12.2%
|
| 127 |
+
Step 12000/32000 | CE Loss: 2.7344 | Trio Blend: 12.6%
|
| 128 |
+
Step 12100/32000 | CE Loss: 2.9688 | Trio Blend: 13.0%
|
| 129 |
+
Step 12200/32000 | CE Loss: 4.2500 | Trio Blend: 13.4%
|
| 130 |
+
Step 12300/32000 | CE Loss: 2.7656 | Trio Blend: 13.8%
|
| 131 |
+
Step 12400/32000 | CE Loss: 2.7031 | Trio Blend: 14.2%
|
| 132 |
+
Step 12500/32000 | CE Loss: 3.3125 | Trio Blend: 14.6%
|
| 133 |
+
Step 12600/32000 | CE Loss: 2.5625 | Trio Blend: 15.1%
|
| 134 |
+
Step 12700/32000 | CE Loss: 2.7031 | Trio Blend: 15.5%
|
| 135 |
+
Step 12800/32000 | CE Loss: 2.7500 | Trio Blend: 15.9%
|
| 136 |
+
Step 12900/32000 | CE Loss: 2.6562 | Trio Blend: 16.4%
|
| 137 |
+
Step 13000/32000 | CE Loss: 2.7656 | Trio Blend: 16.8%
|
| 138 |
+
Step 13100/32000 | CE Loss: 2.5312 | Trio Blend: 17.3%
|
| 139 |
+
Step 13200/32000 | CE Loss: 2.7188 | Trio Blend: 17.8%
|
| 140 |
+
Step 13300/32000 | CE Loss: 2.6250 | Trio Blend: 18.2%
|
| 141 |
+
Step 13400/32000 | CE Loss: 2.6875 | Trio Blend: 18.7%
|
| 142 |
+
Step 13500/32000 | CE Loss: 3.0625 | Trio Blend: 19.2%
|
| 143 |
+
Step 13600/32000 | CE Loss: 2.9219 | Trio Blend: 19.6%
|
| 144 |
+
Step 13700/32000 | CE Loss: 2.6562 | Trio Blend: 20.1%
|
| 145 |
+
Step 13800/32000 | CE Loss: 3.0469 | Trio Blend: 20.6%
|
| 146 |
+
Step 13900/32000 | CE Loss: 2.7188 | Trio Blend: 21.1%
|
| 147 |
+
Step 14000/32000 | CE Loss: 2.7500 | Trio Blend: 21.6%
|
| 148 |
+
Step 14100/32000 | CE Loss: 2.6562 | Trio Blend: 22.1%
|
| 149 |
+
Step 14200/32000 | CE Loss: 2.4531 | Trio Blend: 22.6%
|
| 150 |
+
Step 14300/32000 | CE Loss: 2.7812 | Trio Blend: 23.1%
|
| 151 |
+
Step 14400/32000 | CE Loss: 2.5469 | Trio Blend: 23.6%
|
| 152 |
+
Step 14500/32000 | CE Loss: 2.8438 | Trio Blend: 24.1%
|
| 153 |
+
Step 14600/32000 | CE Loss: 3.0781 | Trio Blend: 24.7%
|
| 154 |
+
Step 14700/32000 | CE Loss: 2.9062 | Trio Blend: 25.2%
|
| 155 |
+
Step 14800/32000 | CE Loss: 2.9844 | Trio Blend: 25.7%
|
| 156 |
+
Step 14900/32000 | CE Loss: 2.7031 | Trio Blend: 26.2%
|
| 157 |
+
Step 15000/32000 | CE Loss: 2.6562 | Trio Blend: 26.8%
|
| 158 |
+
Step 15100/32000 | CE Loss: 2.6250 | Trio Blend: 27.3%
|
| 159 |
+
Step 15200/32000 | CE Loss: 2.7500 | Trio Blend: 27.8%
|
| 160 |
+
Step 15300/32000 | CE Loss: 2.8594 | Trio Blend: 28.4%
|
| 161 |
+
Step 15400/32000 | CE Loss: 3.3750 | Trio Blend: 28.9%
|
| 162 |
+
Step 15500/32000 | CE Loss: 3.0938 | Trio Blend: 29.5%
|
| 163 |
+
Step 15600/32000 | CE Loss: 3.7188 | Trio Blend: 30.0%
|
| 164 |
+
Step 15700/32000 | CE Loss: 3.2188 | Trio Blend: 30.6%
|
| 165 |
+
Step 15800/32000 | CE Loss: 2.5156 | Trio Blend: 31.1%
|
| 166 |
+
Step 15900/32000 | CE Loss: 2.6562 | Trio Blend: 31.7%
|
| 167 |
+
Step 16000/32000 | CE Loss: 2.8906 | Trio Blend: 32.3%
|
| 168 |
+
Step 16100/32000 | CE Loss: 2.6250 | Trio Blend: 32.8%
|
| 169 |
+
Step 16200/32000 | CE Loss: 2.8594 | Trio Blend: 33.4%
|
| 170 |
+
Step 16300/32000 | CE Loss: 3.0625 | Trio Blend: 34.0%
|
| 171 |
+
Step 16400/32000 | CE Loss: 2.7656 | Trio Blend: 34.5%
|
| 172 |
+
Step 16500/32000 | CE Loss: 3.2969 | Trio Blend: 35.1%
|
| 173 |
+
Step 16600/32000 | CE Loss: 2.6406 | Trio Blend: 35.7%
|
| 174 |
+
Step 16700/32000 | CE Loss: 3.0781 | Trio Blend: 36.3%
|
| 175 |
+
Step 16800/32000 | CE Loss: 2.8438 | Trio Blend: 36.9%
|
| 176 |
+
Step 16900/32000 | CE Loss: 2.8438 | Trio Blend: 37.4%
|
| 177 |
+
Step 17000/32000 | CE Loss: 2.7031 | Trio Blend: 38.0%
|
| 178 |
+
Step 17100/32000 | CE Loss: 2.7969 | Trio Blend: 38.6%
|
| 179 |
+
Step 17200/32000 | CE Loss: 2.5938 | Trio Blend: 39.2%
|
| 180 |
+
Step 17300/32000 | CE Loss: 2.5938 | Trio Blend: 39.8%
|
| 181 |
+
Step 17400/32000 | CE Loss: 2.8438 | Trio Blend: 40.4%
|
| 182 |
+
Step 17500/32000 | CE Loss: 3.1875 | Trio Blend: 41.0%
|
| 183 |
+
Step 17600/32000 | CE Loss: 2.8438 | Trio Blend: 41.6%
|
| 184 |
+
Step 17700/32000 | CE Loss: 2.7656 | Trio Blend: 42.2%
|
| 185 |
+
Step 17800/32000 | CE Loss: 2.9062 | Trio Blend: 42.8%
|
| 186 |
+
Step 17900/32000 | CE Loss: 2.6875 | Trio Blend: 43.4%
|
| 187 |
+
Step 18000/32000 | CE Loss: 3.1406 | Trio Blend: 44.0%
|
| 188 |
+
Step 18100/32000 | CE Loss: 4.0312 | Trio Blend: 44.6%
|
| 189 |
+
Step 18200/32000 | CE Loss: 3.1406 | Trio Blend: 45.2%
|
| 190 |
+
Step 18300/32000 | CE Loss: 3.2031 | Trio Blend: 45.8%
|
| 191 |
+
Step 18400/32000 | CE Loss: 3.2969 | Trio Blend: 46.4%
|
| 192 |
+
Step 18500/32000 | CE Loss: 3.0625 | Trio Blend: 47.0%
|
| 193 |
+
Step 18600/32000 | CE Loss: 3.7031 | Trio Blend: 47.6%
|
| 194 |
+
Step 18700/32000 | CE Loss: 3.1719 | Trio Blend: 48.2%
|
| 195 |
+
Step 18800/32000 | CE Loss: 3.3906 | Trio Blend: 48.8%
|
| 196 |
+
Step 18900/32000 | CE Loss: 3.1406 | Trio Blend: 49.4%
|
| 197 |
+
Step 19000/32000 | CE Loss: 2.8594 | Trio Blend: 50.0%
|
| 198 |
+
Step 19100/32000 | CE Loss: 2.8906 | Trio Blend: 50.6%
|
| 199 |
+
Step 19200/32000 | CE Loss: 3.2344 | Trio Blend: 51.2%
|
| 200 |
+
Step 19300/32000 | CE Loss: 2.6562 | Trio Blend: 51.8%
|
| 201 |
+
Step 19400/32000 | CE Loss: 3.5312 | Trio Blend: 52.4%
|
| 202 |
+
Step 19500/32000 | CE Loss: 3.1719 | Trio Blend: 53.0%
|
| 203 |
+
Step 19600/32000 | CE Loss: 4.2500 | Trio Blend: 53.6%
|
| 204 |
+
Step 19700/32000 | CE Loss: 3.5625 | Trio Blend: 54.2%
|
| 205 |
+
Step 19800/32000 | CE Loss: 2.9688 | Trio Blend: 54.8%
|
| 206 |
+
Step 19900/32000 | CE Loss: 8.0000 | Trio Blend: 55.4%
|
| 207 |
+
Step 20000/32000 | CE Loss: 4.5312 | Trio Blend: 56.0%
|
| 208 |
+
Step 20100/32000 | CE Loss: 6.0000 | Trio Blend: 56.6%
|
| 209 |
+
Step 20200/32000 | CE Loss: 2.9531 | Trio Blend: 57.2%
|
| 210 |
+
Step 20300/32000 | CE Loss: 3.8281 | Trio Blend: 57.8%
|
| 211 |
+
Step 20400/32000 | CE Loss: 5.4688 | Trio Blend: 58.4%
|
| 212 |
+
Step 20500/32000 | CE Loss: 7.0625 | Trio Blend: 59.0%
|
| 213 |
+
Step 20600/32000 | CE Loss: 3.7500 | Trio Blend: 59.6%
|
| 214 |
+
Step 20700/32000 | CE Loss: 3.1875 | Trio Blend: 60.2%
|
| 215 |
+
Step 20800/32000 | CE Loss: 3.5312 | Trio Blend: 60.8%
|
| 216 |
+
Step 20900/32000 | CE Loss: 3.3594 | Trio Blend: 61.4%
|
| 217 |
+
Step 21000/32000 | CE Loss: 3.3125 | Trio Blend: 62.0%
|
| 218 |
+
Step 21100/32000 | CE Loss: 3.2031 | Trio Blend: 62.6%
|
| 219 |
+
Step 21200/32000 | CE Loss: 3.1406 | Trio Blend: 63.1%
|
| 220 |
+
Step 21300/32000 | CE Loss: 3.4531 | Trio Blend: 63.7%
|
| 221 |
+
Step 21400/32000 | CE Loss: 3.0625 | Trio Blend: 64.3%
|
| 222 |
+
Step 21500/32000 | CE Loss: 3.7031 | Trio Blend: 64.9%
|
| 223 |
+
Step 21600/32000 | CE Loss: 3.4062 | Trio Blend: 65.5%
|
| 224 |
+
Step 21700/32000 | CE Loss: 4.8438 | Trio Blend: 66.0%
|
| 225 |
+
Step 21800/32000 | CE Loss: 3.0000 | Trio Blend: 66.6%
|
| 226 |
+
Step 21900/32000 | CE Loss: 6.6250 | Trio Blend: 67.2%
|
| 227 |
+
Step 22000/32000 | CE Loss: 3.7344 | Trio Blend: 67.7%
|
| 228 |
+
Step 22100/32000 | CE Loss: 3.9844 | Trio Blend: 68.3%
|
| 229 |
+
Step 22200/32000 | CE Loss: 3.5625 | Trio Blend: 68.9%
|
| 230 |
+
Step 22300/32000 | CE Loss: 11.6875 | Trio Blend: 69.4%
|
| 231 |
+
Step 22400/32000 | CE Loss: 3.8906 | Trio Blend: 70.0%
|
| 232 |
+
Step 22500/32000 | CE Loss: 3.4688 | Trio Blend: 70.5%
|
| 233 |
+
Step 22600/32000 | CE Loss: 3.3125 | Trio Blend: 71.1%
|
| 234 |
+
Step 22700/32000 | CE Loss: 3.4062 | Trio Blend: 71.6%
|
| 235 |
+
Step 22800/32000 | CE Loss: 3.3906 | Trio Blend: 72.2%
|
| 236 |
+
Step 22900/32000 | CE Loss: 2.8438 | Trio Blend: 72.7%
|
| 237 |
+
Step 23000/32000 | CE Loss: 3.2031 | Trio Blend: 73.2%
|
| 238 |
+
Step 23100/32000 | CE Loss: 2.7031 | Trio Blend: 73.8%
|
| 239 |
+
Step 23200/32000 | CE Loss: 5.6562 | Trio Blend: 74.3%
|
| 240 |
+
Step 23300/32000 | CE Loss: 3.5469 | Trio Blend: 74.8%
|
| 241 |
+
Step 23400/32000 | CE Loss: 7.6875 | Trio Blend: 75.3%
|
| 242 |
+
Step 23500/32000 | CE Loss: 3.3594 | Trio Blend: 75.9%
|
| 243 |
+
Step 23600/32000 | CE Loss: 3.6406 | Trio Blend: 76.4%
|
| 244 |
+
Step 23700/32000 | CE Loss: 3.3750 | Trio Blend: 76.9%
|
| 245 |
+
Step 23800/32000 | CE Loss: 3.3125 | Trio Blend: 77.4%
|
| 246 |
+
Step 23900/32000 | CE Loss: 3.0625 | Trio Blend: 77.9%
|
| 247 |
+
Step 24000/32000 | CE Loss: 4.1875 | Trio Blend: 78.4%
|
| 248 |
+
Step 24100/32000 | CE Loss: 4.3438 | Trio Blend: 78.9%
|
| 249 |
+
Step 24200/32000 | CE Loss: 4.1250 | Trio Blend: 79.4%
|
| 250 |
+
Step 24300/32000 | CE Loss: 4.4062 | Trio Blend: 79.9%
|
| 251 |
+
Step 24400/32000 | CE Loss: 4.1875 | Trio Blend: 80.4%
|
| 252 |
+
Step 24500/32000 | CE Loss: 3.0312 | Trio Blend: 80.8%
|
| 253 |
+
Step 24600/32000 | CE Loss: 3.6406 | Trio Blend: 81.3%
|
| 254 |
+
Step 24700/32000 | CE Loss: 3.8906 | Trio Blend: 81.8%
|
| 255 |
+
Step 24800/32000 | CE Loss: 4.0312 | Trio Blend: 82.2%
|
| 256 |
+
Step 24900/32000 | CE Loss: 4.4062 | Trio Blend: 82.7%
|
| 257 |
+
Step 25000/32000 | CE Loss: 3.4688 | Trio Blend: 83.2%
|
| 258 |
+
Step 25100/32000 | CE Loss: 5.0938 | Trio Blend: 83.6%
|
| 259 |
+
Step 25200/32000 | CE Loss: 3.7344 | Trio Blend: 84.1%
|
| 260 |
+
Step 25300/32000 | CE Loss: 3.6719 | Trio Blend: 84.5%
|
| 261 |
+
Step 25400/32000 | CE Loss: 3.3438 | Trio Blend: 84.9%
|
| 262 |
+
Step 25500/32000 | CE Loss: 3.1094 | Trio Blend: 85.4%
|
| 263 |
+
Step 25600/32000 | CE Loss: 3.6094 | Trio Blend: 85.8%
|
| 264 |
+
Step 25700/32000 | CE Loss: 3.2656 | Trio Blend: 86.2%
|
| 265 |
+
Step 25800/32000 | CE Loss: 3.2344 | Trio Blend: 86.6%
|
| 266 |
+
Step 25900/32000 | CE Loss: 4.0625 | Trio Blend: 87.0%
|
| 267 |
+
Step 26000/32000 | CE Loss: 9.1250 | Trio Blend: 87.4%
|
| 268 |
+
Step 26100/32000 | CE Loss: 4.4688 | Trio Blend: 87.8%
|
| 269 |
+
Step 26200/32000 | CE Loss: 5.7500 | Trio Blend: 88.2%
|
| 270 |
+
Step 26300/32000 | CE Loss: 3.2188 | Trio Blend: 88.6%
|
| 271 |
+
Step 26400/32000 | CE Loss: 3.5625 | Trio Blend: 89.0%
|
| 272 |
+
Step 26500/32000 | CE Loss: 6.2812 | Trio Blend: 89.4%
|
| 273 |
+
Step 26600/32000 | CE Loss: 3.8438 | Trio Blend: 89.7%
|
| 274 |
+
Step 26700/32000 | CE Loss: 3.0938 | Trio Blend: 90.1%
|
| 275 |
+
Step 26800/32000 | CE Loss: 4.6875 | Trio Blend: 90.5%
|
| 276 |
+
Step 26900/32000 | CE Loss: 3.0469 | Trio Blend: 90.8%
|
| 277 |
+
Step 27000/32000 | CE Loss: 5.5000 | Trio Blend: 91.1%
|
| 278 |
+
Step 27100/32000 | CE Loss: 3.3594 | Trio Blend: 91.5%
|
| 279 |
+
Step 27200/32000 | CE Loss: 6.0312 | Trio Blend: 91.8%
|
| 280 |
+
Step 27300/32000 | CE Loss: 5.5938 | Trio Blend: 92.2%
|
| 281 |
+
Step 27400/32000 | CE Loss: 3.7656 | Trio Blend: 92.5%
|
| 282 |
+
Step 27500/32000 | CE Loss: 3.3594 | Trio Blend: 92.8%
|
| 283 |
+
Step 27600/32000 | CE Loss: 3.9062 | Trio Blend: 93.1%
|
| 284 |
+
Step 27700/32000 | CE Loss: 5.8750 | Trio Blend: 93.4%
|
| 285 |
+
Step 27800/32000 | CE Loss: 6.3125 | Trio Blend: 93.7%
|
| 286 |
+
Step 27900/32000 | CE Loss: 4.4062 | Trio Blend: 94.0%
|
| 287 |
+
Step 28000/32000 | CE Loss: 4.0625 | Trio Blend: 94.3%
|
| 288 |
+
Step 28100/32000 | CE Loss: 3.4531 | Trio Blend: 94.6%
|
| 289 |
+
Step 28200/32000 | CE Loss: 3.7031 | Trio Blend: 94.8%
|
| 290 |
+
Step 28300/32000 | CE Loss: 3.5000 | Trio Blend: 95.1%
|
| 291 |
+
Step 28400/32000 | CE Loss: 7.4688 | Trio Blend: 95.3%
|
| 292 |
+
Step 28500/32000 | CE Loss: 4.0312 | Trio Blend: 95.6%
|
| 293 |
+
Step 28600/32000 | CE Loss: 3.7812 | Trio Blend: 95.8%
|
| 294 |
+
Step 28700/32000 | CE Loss: 4.4062 | Trio Blend: 96.1%
|
| 295 |
+
Step 28800/32000 | CE Loss: 4.2500 | Trio Blend: 96.3%
|
| 296 |
+
Step 28900/32000 | CE Loss: 5.2500 | Trio Blend: 96.5%
|
| 297 |
+
Step 29000/32000 | CE Loss: 3.3281 | Trio Blend: 96.8%
|
| 298 |
+
Step 29100/32000 | CE Loss: 3.7031 | Trio Blend: 97.0%
|
| 299 |
+
Step 29200/32000 | CE Loss: 4.5312 | Trio Blend: 97.2%
|
| 300 |
+
Step 29300/32000 | CE Loss: 5.6250 | Trio Blend: 97.4%
|
| 301 |
+
Step 29400/32000 | CE Loss: 4.5625 | Trio Blend: 97.6%
|
| 302 |
+
Step 29500/32000 | CE Loss: 3.5938 | Trio Blend: 97.7%
|
| 303 |
+
Step 29600/32000 | CE Loss: 9.6250 | Trio Blend: 97.9%
|
| 304 |
+
Step 29700/32000 | CE Loss: 4.9375 | Trio Blend: 98.1%
|
| 305 |
+
Step 29800/32000 | CE Loss: 4.9062 | Trio Blend: 98.2%
|
| 306 |
+
Step 29900/32000 | CE Loss: 3.4531 | Trio Blend: 98.4%
|
| 307 |
+
Step 30000/32000 | CE Loss: 3.9062 | Trio Blend: 98.5%
|
| 308 |
+
Step 30100/32000 | CE Loss: 4.4688 | Trio Blend: 98.7%
|
| 309 |
+
Step 30200/32000 | CE Loss: 3.6250 | Trio Blend: 98.8%
|
| 310 |
+
Step 30300/32000 | CE Loss: 4.4688 | Trio Blend: 98.9%
|
| 311 |
+
Step 30400/32000 | CE Loss: 3.3125 | Trio Blend: 99.1%
|
| 312 |
+
Step 30500/32000 | CE Loss: 3.7500 | Trio Blend: 99.2%
|
| 313 |
+
Step 30600/32000 | CE Loss: 3.6094 | Trio Blend: 99.3%
|
| 314 |
+
Step 30700/32000 | CE Loss: 5.5000 | Trio Blend: 99.4%
|
| 315 |
+
Step 30800/32000 | CE Loss: 5.0000 | Trio Blend: 99.5%
|
| 316 |
+
Step 30900/32000 | CE Loss: 4.8125 | Trio Blend: 99.6%
|
| 317 |
+
Step 31000/32000 | CE Loss: 3.5625 | Trio Blend: 99.6%
|
| 318 |
+
Step 31100/32000 | CE Loss: 4.0000 | Trio Blend: 99.7%
|
| 319 |
+
Step 31200/32000 | CE Loss: 4.2188 | Trio Blend: 99.8%
|
| 320 |
+
Step 31300/32000 | CE Loss: 5.3438 | Trio Blend: 99.8%
|
| 321 |
+
Step 31400/32000 | CE Loss: 4.9062 | Trio Blend: 99.9%
|
| 322 |
+
Step 31500/32000 | CE Loss: 2.9062 | Trio Blend: 99.9%
|
| 323 |
+
Step 31600/32000 | CE Loss: 3.4375 | Trio Blend: 99.9%
|
| 324 |
+
Step 31700/32000 | CE Loss: 3.6250 | Trio Blend: 100.0%
|
| 325 |
+
Step 31800/32000 | CE Loss: 3.0469 | Trio Blend: 100.0%
|
| 326 |
+
Step 31900/32000 | CE Loss: 3.6875 | Trio Blend: 100.0%
|
| 327 |
+
|
| 328 |
+
--- フェーズ1(英語)完了。3値結晶化モデルを抽出・保存します ---
|
| 329 |
+
🎉 D-RNAの構造を完全に守った3値結晶化ファイルを出力しました: drna_trio_phase1_english.safetensors
|
| 330 |
+
|
| 331 |
+
---
|
| 332 |
+
drna_trio_phase1_english.safetensors
|
| 333 |
+
|
| 334 |
+
Layer Name | Mean | Std | Max | Min | SVD
|
| 335 |
+
embed.weight | 0.0016 | 1.0007 | 4.47 | -4.28 | 5.50
|
| 336 |
+
---------------------------------------------------------------------------------------------
|
| 337 |
+
layers.0.mlp.0.weight | 0.0084 | 0.4821 | 1.00 | -1.00 | 5.66
|
| 338 |
+
layers.0.mlp.3.weight | -0.0194 | 0.3737 | 1.00 | -1.00 | 5.66
|
| 339 |
+
layers.0.out_proj.weight | 0.0145 | 0.2285 | 1.00 | -1.00 | 5.16
|
| 340 |
+
layers.0.qkv.weight | -0.0396 | 0.6020 | 1.00 | -1.00 | 5.40
|
| 341 |
+
layers.1.mlp.0.weight | -0.0237 | 0.4926 | 1.00 | -1.00 | 5.67
|
| 342 |
+
layers.1.mlp.3.weight | -0.0035 | 0.3485 | 1.00 | -1.00 | 5.66
|
| 343 |
+
layers.1.out_proj.weight | 0.0151 | 0.2843 | 1.00 | -1.00 | 5.36
|
| 344 |
+
layers.1.qkv.weight | 0.0143 | 0.5119 | 1.00 | -1.00 | 5.52
|
| 345 |
+
layers.10.mlp.0.weight | -0.0294 | 0.4374 | 1.00 | -1.00 | 5.65
|
| 346 |
+
layers.10.mlp.3.weight | -0.0018 | 0.3587 | 1.00 | -1.00 | 5.58
|
| 347 |
+
layers.10.out_proj.weight | 0.0134 | 0.3689 | 1.00 | -1.00 | 5.34
|
| 348 |
+
layers.10.qkv.weight | -0.0128 | 0.4449 | 1.00 | -1.00 | 5.56
|
| 349 |
+
layers.11.mlp.0.weight | -0.0230 | 0.3980 | 1.00 | -1.00 | 5.66
|
| 350 |
+
layers.11.mlp.3.weight | 0.0114 | 0.3651 | 1.00 | -1.00 | 5.59
|
| 351 |
+
layers.11.out_proj.weight | -0.0027 | 0.3156 | 1.00 | -1.00 | 5.29
|
| 352 |
+
layers.11.qkv.weight | 0.0132 | 0.3869 | 1.00 | -1.00 | 5.52
|
| 353 |
+
layers.12.mlp.0.weight | -0.0023 | 0.4047 | 1.00 | -1.00 | 5.66
|
| 354 |
+
layers.12.mlp.3.weight | -0.0210 | 0.3773 | 1.00 | -1.00 | 5.60
|
| 355 |
+
layers.12.out_proj.weight | 0.0310 | 0.3670 | 1.00 | -1.00 | 5.34
|
| 356 |
+
layers.12.qkv.weight | -0.0145 | 0.4328 | 1.00 | -1.00 | 5.56
|
| 357 |
+
layers.13.mlp.0.weight | -0.0134 | 0.3875 | 1.00 | -1.00 | 5.66
|
| 358 |
+
layers.13.mlp.3.weight | -0.0131 | 0.3814 | 1.00 | -1.00 | 5.61
|
| 359 |
+
layers.13.out_proj.weight | -0.0035 | 0.3549 | 1.00 | -1.00 | 5.32
|
| 360 |
+
layers.13.qkv.weight | -0.0154 | 0.4175 | 1.00 | -1.00 | 5.54
|
| 361 |
+
layers.14.mlp.0.weight | -0.0096 | 0.4027 | 1.00 | -1.00 | 5.66
|
| 362 |
+
layers.14.mlp.3.weight | -0.0101 | 0.3817 | 1.00 | -1.00 | 5.60
|
| 363 |
+
layers.14.out_proj.weight | -0.0135 | 0.3729 | 1.00 | -1.00 | 5.35
|
| 364 |
+
layers.14.qkv.weight | 0.0009 | 0.4264 | 1.00 | -1.00 | 5.57
|
| 365 |
+
layers.15.mlp.0.weight | 0.0066 | 0.4065 | 1.00 | -1.00 | 5.66
|
| 366 |
+
layers.15.mlp.3.weight | 0.0256 | 0.3939 | 1.00 | -1.00 | 5.59
|
| 367 |
+
layers.15.out_proj.weight | 0.0044 | 0.3477 | 1.00 | -1.00 | 5.30
|
| 368 |
+
layers.15.qkv.weight | -0.0036 | 0.4371 | 1.00 | -1.00 | 5.55
|
| 369 |
+
layers.16.mlp.0.weight | 0.0119 | 0.4346 | 1.00 | -1.00 | 5.66
|
| 370 |
+
layers.16.mlp.3.weight | 0.0189 | 0.3736 | 1.00 | -1.00 | 5.66
|
| 371 |
+
layers.16.out_proj.weight | -0.0043 | 0.3657 | 1.00 | -1.00 | 5.36
|
| 372 |
+
layers.16.qkv.weight | 0.0089 | 0.4011 | 1.00 | -1.00 | 5.54
|
| 373 |
+
layers.17.mlp.0.weight | -0.0105 | 0.4586 | 1.00 | -1.00 | 5.67
|
| 374 |
+
layers.17.mlp.3.weight | 0.0027 | 0.3710 | 1.00 | -1.00 | 5.66
|
| 375 |
+
layers.17.out_proj.weight | -0.0062 | 0.3893 | 1.00 | -1.00 | 5.40
|
| 376 |
+
layers.17.qkv.weight | 0.0083 | 0.4309 | 1.00 | -1.00 | 5.56
|
| 377 |
+
layers.18.mlp.0.weight | 0.0088 | 0.4437 | 1.00 | -1.00 | 5.67
|
| 378 |
+
layers.18.mlp.3.weight | 0.0181 | 0.3938 | 1.00 | -1.00 | 5.64
|
| 379 |
+
layers.18.out_proj.weight | -0.0287 | 0.4066 | 1.00 | -1.00 | 5.39
|
| 380 |
+
layers.18.qkv.weight | 0.0100 | 0.4649 | 1.00 | -1.00 | 5.56
|
| 381 |
+
layers.19.mlp.0.weight | -0.0041 | 0.4362 | 1.00 | -1.00 | 5.66
|
| 382 |
+
layers.19.mlp.3.weight | -0.0261 | 0.4043 | 1.00 | -1.00 | 5.63
|
| 383 |
+
layers.19.out_proj.weight | -0.0105 | 0.3573 | 1.00 | -1.00 | 5.37
|
| 384 |
+
layers.19.qkv.weight | -0.0117 | 0.4573 | 1.00 | -1.00 | 5.56
|
| 385 |
+
layers.2.mlp.0.weight | 0.0038 | 0.5067 | 1.00 | -1.00 | 5.67
|
| 386 |
+
layers.2.mlp.3.weight | 0.0038 | 0.3484 | 1.00 | -1.00 | 5.65
|
| 387 |
+
layers.2.out_proj.weight | 0.0002 | 0.2741 | 1.00 | -1.00 | 5.38
|
| 388 |
+
layers.2.qkv.weight | 0.0065 | 0.4754 | 1.00 | -1.00 | 5.52
|
| 389 |
+
layers.20.mlp.0.weight | 0.0185 | 0.4301 | 1.00 | -1.00 | 5.66
|
| 390 |
+
layers.20.mlp.3.weight | 0.0076 | 0.4125 | 1.00 | -1.00 | 5.63
|
| 391 |
+
layers.20.out_proj.weight | -0.0074 | 0.4323 | 1.00 | -1.00 | 5.41
|
| 392 |
+
layers.20.qkv.weight | -0.0095 | 0.4289 | 1.00 | -1.00 | 5.56
|
| 393 |
+
layers.21.mlp.0.weight | 0.0185 | 0.4340 | 1.00 | -1.00 | 5.67
|
| 394 |
+
layers.21.mlp.3.weight | 0.0117 | 0.4229 | 1.00 | -1.00 | 5.63
|
| 395 |
+
layers.21.out_proj.weight | 0.0379 | 0.4586 | 1.00 | -1.00 | 5.43
|
| 396 |
+
layers.21.qkv.weight | 0.0181 | 0.4517 | 1.00 | -1.00 | 5.53
|
| 397 |
+
layers.22.mlp.0.weight | 0.0127 | 0.4215 | 1.00 | -1.00 | 5.67
|
| 398 |
+
layers.22.mlp.3.weight | -0.0142 | 0.4287 | 1.00 | -1.00 | 5.66
|
| 399 |
+
layers.22.out_proj.weight | 0.0090 | 0.4570 | 1.00 | -1.00 | 5.46
|
| 400 |
+
layers.22.qkv.weight | -0.0030 | 0.4303 | 1.00 | -1.00 | 5.53
|
| 401 |
+
layers.23.mlp.0.weight | 0.0242 | 0.4349 | 1.00 | -1.00 | 5.66
|
| 402 |
+
layers.23.mlp.3.weight | -0.0092 | 0.3675 | 1.00 | -1.00 | 5.67
|
| 403 |
+
layers.23.out_proj.weight | -0.0005 | 0.3852 | 1.00 | -1.00 | 5.38
|
| 404 |
+
layers.23.qkv.weight | 0.0115 | 0.4324 | 1.00 | -1.00 | 5.52
|
| 405 |
+
layers.3.mlp.0.weight | -0.0189 | 0.4833 | 1.00 | -1.00 | 5.67
|
| 406 |
+
layers.3.mlp.3.weight | 0.0102 | 0.3595 | 1.00 | -1.00 | 5.65
|
| 407 |
+
layers.3.out_proj.weight | -0.0025 | 0.3115 | 1.00 | -1.00 | 5.31
|
| 408 |
+
layers.3.qkv.weight | 0.0251 | 0.4530 | 1.00 | -1.00 | 5.53
|
| 409 |
+
layers.4.mlp.0.weight | 0.0041 | 0.4719 | 1.00 | -1.00 | 5.67
|
| 410 |
+
layers.4.mlp.3.weight | -0.0088 | 0.3623 | 1.00 | -1.00 | 5.65
|
| 411 |
+
layers.4.out_proj.weight | 0.0069 | 0.3526 | 1.00 | -1.00 | 5.35
|
| 412 |
+
layers.4.qkv.weight | 0.0226 | 0.4593 | 1.00 | -1.00 | 5.56
|
| 413 |
+
layers.5.mlp.0.weight | 0.0029 | 0.5050 | 1.00 | -1.00 | 5.66
|
| 414 |
+
layers.5.mlp.3.weight | 0.0079 | 0.3396 | 1.00 | -1.00 | 5.62
|
| 415 |
+
layers.5.out_proj.weight | 0.0014 | 0.3697 | 1.00 | -1.00 | 5.38
|
| 416 |
+
layers.5.qkv.weight | -0.0267 | 0.4430 | 1.00 | -1.00 | 5.56
|
| 417 |
+
layers.6.mlp.0.weight | 0.0226 | 0.4660 | 1.00 | -1.00 | 5.67
|
| 418 |
+
layers.6.mlp.3.weight | 0.0037 | 0.3739 | 1.00 | -1.00 | 5.63
|
| 419 |
+
layers.6.out_proj.weight | -0.0061 | 0.3420 | 1.00 | -1.00 | 5.34
|
| 420 |
+
layers.6.qkv.weight | 0.0197 | 0.4442 | 1.00 | -1.00 | 5.53
|
| 421 |
+
layers.7.mlp.0.weight | -0.0398 | 0.4449 | 1.00 | -1.00 | 5.66
|
| 422 |
+
layers.7.mlp.3.weight | 0.0068 | 0.3722 | 1.00 | -1.00 | 5.64
|
| 423 |
+
layers.7.out_proj.weight | -0.0282 | 0.3675 | 1.00 | -1.00 | 5.37
|
| 424 |
+
layers.7.qkv.weight | 0.0120 | 0.4651 | 1.00 | -1.00 | 5.56
|
| 425 |
+
layers.8.mlp.0.weight | -0.0240 | 0.4524 | 1.00 | -1.00 | 5.66
|
| 426 |
+
layers.8.mlp.3.weight | 0.0231 | 0.3776 | 1.00 | -1.00 | 5.60
|
| 427 |
+
layers.8.out_proj.weight | 0.0170 | 0.3391 | 1.00 | -1.00 | 5.33
|
| 428 |
+
layers.8.qkv.weight | -0.0096 | 0.4836 | 1.00 | -1.00 | 5.55
|
| 429 |
+
layers.9.mlp.0.weight | 0.0147 | 0.4235 | 1.00 | -1.00 | 5.66
|
| 430 |
+
layers.9.mlp.3.weight | 0.0186 | 0.3676 | 1.00 | -1.00 | 5.60
|
| 431 |
+
layers.9.out_proj.weight | -0.0043 | 0.3399 | 1.00 | -1.00 | 5.31
|
| 432 |
+
layers.9.qkv.weight | 0.0182 | 0.4266 | 1.00 | -1.00 | 5.55
|
| 433 |
+
output_head.weight | 0.0031 | 0.0523 | 0.06 | -0.06 | 2.68
|
drna/drna_vlayer.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torch.utils.checkpoint as checkpoint
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
'''
|
| 8 |
+
D‑RNA: Dual‑Helix Resonance Neural Architecture (DRNA) Pre-Norm・Kv-RoPE・vlayer版
|
| 9 |
+
仕様:Pre-Norm(RMSNorm)、GELU(Activation)、Kv-RoPE(head_dim)、mask(padding + causal)
|
| 10 |
+
仮想レイヤ再帰(p_layers)、Gradient Checkpointing対応
|
| 11 |
+
'''
|
| 12 |
+
|
| 13 |
+
class RMSNorm(nn.Module):
|
| 14 |
+
def __init__(self, d_model, eps=1e-6):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.eps = eps
|
| 17 |
+
self.weight = nn.Parameter(torch.ones(d_model))
|
| 18 |
+
|
| 19 |
+
def forward(self, x):
|
| 20 |
+
norm = x.pow(2).mean(-1, keepdim=True)
|
| 21 |
+
x_normed = x * torch.rsqrt(norm + self.eps)
|
| 22 |
+
return self.weight * x_normed
|
| 23 |
+
|
| 24 |
+
class DRNA_RoPE(nn.Module):
|
| 25 |
+
"""二重らせんの位相を決定する回転場"""
|
| 26 |
+
def __init__(self, head_dim, base=10000):
|
| 27 |
+
super().__init__()
|
| 28 |
+
inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
|
| 29 |
+
self.register_buffer("inv_freq", inv_freq)
|
| 30 |
+
|
| 31 |
+
def forward(self, x, seq_len):
|
| 32 |
+
t = torch.arange(seq_len, device=x.device).type_as(self.inv_freq)
|
| 33 |
+
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
|
| 34 |
+
emb = torch.cat((freqs, freqs), dim=-1)
|
| 35 |
+
return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]
|
| 36 |
+
|
| 37 |
+
def apply_drna_rope(q, k, cos, sin):
|
| 38 |
+
def rotate_half(x):
|
| 39 |
+
x1, x2 = x.chunk(2, dim=-1)
|
| 40 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 41 |
+
return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
|
| 42 |
+
|
| 43 |
+
class DRNA_Block(nn.Module):
|
| 44 |
+
"""DRNA共鳴ブロック:安定性を高めたPre-Norm直列共鳴構造"""
|
| 45 |
+
def __init__(self, d_model, n_heads, head_dim, d_ff=None, dropout=0.1):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.n_heads = n_heads
|
| 48 |
+
self.head_dim = head_dim
|
| 49 |
+
|
| 50 |
+
# らせんA: 回想系 (Attention)
|
| 51 |
+
self.norm1 = RMSNorm(d_model)
|
| 52 |
+
self.qkv = nn.Linear(d_model, d_model * 3)
|
| 53 |
+
self.out_proj = nn.Linear(d_model, d_model)
|
| 54 |
+
|
| 55 |
+
# らせんB: 記憶系 (MLP)
|
| 56 |
+
self.norm2 = RMSNorm(d_model)
|
| 57 |
+
d_ff = d_ff or d_model * 4
|
| 58 |
+
self.mlp = nn.Sequential(
|
| 59 |
+
nn.Linear(d_model, d_ff),
|
| 60 |
+
nn.GELU(),
|
| 61 |
+
nn.Dropout(dropout),
|
| 62 |
+
nn.Linear(d_ff, d_model)
|
| 63 |
+
)
|
| 64 |
+
self.dropout = nn.Dropout(dropout)
|
| 65 |
+
|
| 66 |
+
def forward(self, x, cos, sin, mask=None):
|
| 67 |
+
b, s, d = x.shape
|
| 68 |
+
residual = x
|
| 69 |
+
|
| 70 |
+
# らせんA (Attention) 並列方式
|
| 71 |
+
x_norm1 = self.norm1(x)
|
| 72 |
+
qkv = self.qkv(x_norm1).reshape(b, s, 3, self.n_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 73 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 74 |
+
|
| 75 |
+
# K因果的動的回転
|
| 76 |
+
k_for_phase = torch.cat([torch.zeros_like(k[:, :, :1, :]), k[:, :, :-1, :]], dim=2)
|
| 77 |
+
rt_phase = torch.tanh(k_for_phase) * math.pi
|
| 78 |
+
|
| 79 |
+
# 静的RoPEを動的位相で変調
|
| 80 |
+
d_cos = (cos * torch.cos(rt_phase)) - (sin * torch.sin(rt_phase))
|
| 81 |
+
d_sin = (sin * torch.cos(rt_phase)) + (cos * torch.sin(rt_phase))
|
| 82 |
+
|
| 83 |
+
# 2重らせん回転
|
| 84 |
+
q, k = apply_drna_rope(q, k, d_cos, d_sin)
|
| 85 |
+
|
| 86 |
+
# Attention計算
|
| 87 |
+
attn = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))
|
| 88 |
+
if mask is not None:
|
| 89 |
+
attn = attn + mask
|
| 90 |
+
|
| 91 |
+
attn = F.softmax(attn, dim=-1)
|
| 92 |
+
a_out_raw = (attn @ v).transpose(1, 2).reshape(b, s, d)
|
| 93 |
+
a_out = self.out_proj(a_out_raw)
|
| 94 |
+
|
| 95 |
+
# らせんB (MLP)
|
| 96 |
+
x_norm2 = self.norm2(x)
|
| 97 |
+
m_out = self.mlp(x_norm2)
|
| 98 |
+
|
| 99 |
+
# 並列結合
|
| 100 |
+
x = residual + self.dropout(a_out) + self.dropout(m_out)
|
| 101 |
+
return x
|
| 102 |
+
|
| 103 |
+
class DRNA_Recurrent_Model(nn.Module):
|
| 104 |
+
"""汎用 DRNA モデルコンテナ(レイヤ回帰・仮想レイヤ組み込み版)"""
|
| 105 |
+
def __init__(self, vocab_size, d_model=256, n_layers=4, p_layers=4, n_heads=8, d_ff=1024):
|
| 106 |
+
super().__init__()
|
| 107 |
+
self.embed = nn.Embedding(vocab_size, d_model)
|
| 108 |
+
self.head_dim = d_model // n_heads
|
| 109 |
+
self.rope = DRNA_RoPE(self.head_dim)
|
| 110 |
+
self.p_layers = p_layers # 仮想レイヤーの周回数(再帰回数)
|
| 111 |
+
|
| 112 |
+
self.layers = nn.ModuleList([
|
| 113 |
+
DRNA_Block(d_model, n_heads, self.head_dim, d_ff) for _ in range(n_layers)
|
| 114 |
+
])
|
| 115 |
+
|
| 116 |
+
self.final_norm = RMSNorm(d_model)
|
| 117 |
+
self.output_head = nn.Linear(d_model, vocab_size)
|
| 118 |
+
|
| 119 |
+
def forward(self, x, mask=None, pad_id=None):
|
| 120 |
+
b, s = x.shape
|
| 121 |
+
device = x.device
|
| 122 |
+
inputs = x
|
| 123 |
+
x = self.embed(x)
|
| 124 |
+
|
| 125 |
+
if mask is None or mask.sum() == 0:
|
| 126 |
+
# 退避させた inputs でパディング位置を判定
|
| 127 |
+
# パッドマスクとコーザルマスクを判定(pad_idの型チェック&テンソルバグ修正)
|
| 128 |
+
p_id = pad_id.item() if isinstance(pad_id, torch.Tensor) else pad_id
|
| 129 |
+
pad_mask = (inputs != p_id).unsqueeze(1).unsqueeze(2) if isinstance(p_id, (int, float)) else torch.ones((1, 1, 1, s), device=device, dtype=torch.bool)
|
| 130 |
+
causal = torch.triu(torch.ones(s, s, device=device), diagonal=1).bool().unsqueeze(0).unsqueeze(0)
|
| 131 |
+
|
| 132 |
+
# 環境(fp16/32)に応じた最小値を安全に自動計算
|
| 133 |
+
inf_value = torch.finfo(x.dtype).min if x.dtype != torch.float16 else -65500.0
|
| 134 |
+
|
| 135 |
+
# ゼロ初期化テンソルに無効領域をインプレースで直接埋める(カンニングの完全遮断)
|
| 136 |
+
mask = torch.zeros((b, 1, s, s), device=device, dtype=x.dtype).masked_fill_(causal | (~pad_mask), inf_value)
|
| 137 |
+
|
| 138 |
+
cos, sin = self.rope(x, x.size(1))
|
| 139 |
+
|
| 140 |
+
# 実レイヤーのループ
|
| 141 |
+
for layer in self.layers:
|
| 142 |
+
# 🌀 仮想レイヤーの再帰ループ
|
| 143 |
+
for _ in range(self.p_layers):
|
| 144 |
+
if self.training:
|
| 145 |
+
# 確実に勾配追跡を有効化するため、x のrequires_gradを確認・保障
|
| 146 |
+
if not x.requires_grad:
|
| 147 |
+
x.requires_grad_()
|
| 148 |
+
# グラディエント・チェックポインティングによるVRAM抑制
|
| 149 |
+
x = checkpoint.checkpoint(
|
| 150 |
+
layer,
|
| 151 |
+
x,
|
| 152 |
+
cos,
|
| 153 |
+
sin,
|
| 154 |
+
mask,
|
| 155 |
+
use_reentrant=False
|
| 156 |
+
)
|
| 157 |
+
else:
|
| 158 |
+
x = layer(x, cos, sin, mask=mask)
|
| 159 |
+
|
| 160 |
+
x = self.final_norm(x)
|
| 161 |
+
return self.output_head(x)
|
| 162 |
+
|
| 163 |
+
'''
|
| 164 |
+
260520:p_layers による回帰で仮想レイヤをつくる(グラディエント・チェックポインティング活用/VRAM抑制)
|
| 165 |
+
「仮想レイヤ」は物理的なレイヤ数を最小限に、同一レイヤ内で自己周回(再帰)させ、安定的な効率化を実現する
|
| 166 |
+
グラディエント・チェックポインティング(GC)による再帰は中間計算(勾配)を必要時に再計算することで省VRAM化をします
|
| 167 |
+
'''
|
| 168 |
+
|
| 169 |
+
'''
|
| 170 |
+
汎用型 D-RNA (Pre-Norm) License: Apache License 2.0 https://github.com/muooon/DRNA
|
| 171 |
+
Attention is all you need_started, Resonance is all you need_endure,
|
| 172 |
+
Neocognitron ― Transformer ― D‑RNA Dream Resonance Never Adjourns — it goes on...
|
| 173 |
+
'''
|