bbkdevops's picture
download
raw
25.2 kB
"""
SpectralMind — Pure Mathematics Compact Training for TinyMind
=============================================================
3 นวัตกรรมทางคณิตศาสตร์:
1. StiefelLinear
W = U · diag(σ) · Vᵀ ทั้ง U, V ∈ Stiefel manifold
Retraction: Cayley transform (ไม่ต้อง QR decomposition)
Compression: r(m+n+1) vs m·n → 8-16x ที่ rank=32
2. BloomEmbedding
k ตารางขนาด (B, d/k) แทน (V, d)
Lookup: e(t) = concat(T_i[h_i(t)]) ผ่าน universal hash family
Guaranteed approximation error O(1/B) โดย Johnson-Lindenstrauss lemma
Compression: 8-16x บน vocabulary embedding
3. LowRankFFN
W = U·Vᵀ + residual แทน W เต็ม
Cold-start: initialize U,V จาก SVD ของ random matrix
Progressive rank: เริ่ม rank-8, เพิ่มได้ตาม training
คณิตศาสตร์พื้นฐาน:
- Stiefel(n, r) = {U ∈ ℝⁿˣʳ : UᵀU = Iᵣ}
- Cayley retraction: U_new = (I + τA)⁻¹(I − τA)U, A = ½(GUᵀ − UGᵀ)
- Universal hash: h_i(t) = (a_i·t + b_i) mod B, a_i ∈ [1, 2³¹−1]
- JL lemma: E[‖φ(x) − φ(y)‖²] = ‖x − y‖² สำหรับ random projection φ
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from .config import OmegaConfig
# ─────────────────────────────────────────────────────────────────────────────
# 1. StiefelLinear — Weight Matrix on the Stiefel Manifold
# ─────────────────────────────────────────────────────────────────────────────
class StiefelLinear(nn.Module):
    """
    Rank-r linear layer parametrized as W = V · diag(σ) · Uᵀ.

    U ∈ ℝ^{in_f×r} and V ∈ ℝ^{out_f×r} are initialized with orthonormal
    columns (points on Stiefel(in_f, r) and Stiefel(out_f, r)). σ ∈ ℝʳ holds
    learnable singular values. ``forward`` computes x · U · diag(σ) · Vᵀ (+ b).

    Staying on the Stiefel manifold is NOT automatic: a plain optimizer step
    moves U and V off it. Call ``retract`` after every ``optimizer.step()``.

    Parameter count: r·(in_f + out_f) + r (+ out_f with bias) vs in_f·out_f.
    Example: in = out = 512, r = 32 gives ≈ 32.8K vs 262K, i.e. about 8x.
    """

    def __init__(self, in_f: int, out_f: int, rank: int, bias: bool = False):
        super().__init__()
        self.in_f = in_f
        self.out_f = out_f
        self.rank = rank
        # Random initial factors with orthonormal columns. If rank > dim,
        # orthogonal_ makes the rows orthonormal instead.
        U_init = torch.empty(in_f, rank)
        V_init = torch.empty(out_f, rank)
        nn.init.orthogonal_(U_init)
        nn.init.orthogonal_(V_init)
        self.U = nn.Parameter(U_init)  # (in_f, r)
        self.V = nn.Parameter(V_init)  # (out_f, r)
        self.sigma = nn.Parameter(torch.ones(rank) * 0.1)  # (r,)
        self.b = nn.Parameter(torch.zeros(out_f)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (..., in_f) -> (..., out_f) through the rank-r bottleneck."""
        h = x @ self.U  # (..., r)
        h = h * self.sigma  # (..., r)
        out = h @ self.V.T  # (..., out_f)
        return out if self.b is None else out + self.b

    @torch.no_grad()
    def retract(self, tau: float = 1.0) -> None:
        """
        Cayley step along the descent direction, then re-project onto Stiefel.

        For each factor X ∈ {U, V} with Euclidean gradient G = X.grad:

            W     = G·Xᵀ − X·Gᵀ                        (skew-symmetric, n×n)
            X_new = (I + τ/2·W)⁻¹ (I − τ/2·W) X        (Cayley transform)

        This equals (I + τA)⁻¹(I − τA)X with A = ½(GXᵀ − XGᵀ), and for small τ
        it moves X against the Riemannian gradient.

        W is never materialized. With L = [G, X] and R = [X, −G] (both n×2r),
        W = L·Rᵀ, and the Sherman–Morrison–Woodbury identity gives

            X_new = X − τ · L · (I₂ᵣ + τ/2 · RᵀL)⁻¹ · RᵀX

        This costs O(n·r²) instead of O(n³).

        The Cayley map is orthogonal and therefore preserves XᵀX. It cannot
        undo drift off the manifold caused by an ordinary optimizer update.
        The result is therefore projected onto the manifold with the polar
        decomposition, which gives the closest matrix with orthonormal columns
        in Frobenius norm.

        Args:
            tau: Cayley step size. ``tau=0`` performs only the projection,
                which is appropriate when the optimizer has already applied
                the gradient step.

        Factors whose ``.grad`` is None are left untouched.
        """
        for param in (self.U, self.V):
            if param.grad is None:
                continue
            X = param.data
            # linalg routines do not support half types; compute in >= fp32.
            work_dtype = torch.promote_types(X.dtype, torch.float32)
            Xw = X.to(work_dtype)
            G = param.grad.detach().to(work_dtype)

            L = torch.cat([G, Xw], dim=1)  # (n, 2r)
            R = torch.cat([Xw, -G], dim=1)  # (n, 2r)
            eye = torch.eye(L.shape[1], device=X.device, dtype=work_dtype)
            inner = eye + (tau / 2.0) * (R.T @ L)  # (2r, 2r), always invertible
            Y = Xw - tau * (L @ torch.linalg.solve(inner, R.T @ Xw))  # (n, r)

            # Polar projection: Y = P·S·Qh  ->  P·Qh has orthonormal columns.
            P, _, Qh = torch.linalg.svd(Y, full_matrices=False)
            param.data.copy_(P @ Qh)

    def param_count(self) -> int:
        """Number of trainable scalars: U, V, sigma and the optional bias."""
        bias = self.out_f if self.b is not None else 0
        return self.in_f * self.rank + self.out_f * self.rank + self.rank + bias
# ─────────────────────────────────────────────────────────────────────────────
# 2. BloomEmbedding — Vocabulary Table via Hash Functions (Johnson-Lindenstrauss)
# ─────────────────────────────────────────────────────────────────────────────
class BloomEmbedding(nn.Module):
    """
    Hashed vocabulary embedding that never stores the full V×d table.

    There are k = n_hashes tables T_i of shape (B, d/k), with B = n_buckets.
    A token id t is embedded as

        e(t) = concat(T_1[h_1(t)], ..., T_k[h_k(t)])

    Each h_i is drawn from the Carter–Wegman universal family

        h_i(t) = ((a_i · t + b_i) mod p) mod B,   p = 2³¹ − 1 (prime)

    with random a_i ∈ [1, p−1] and b_i ∈ [0, p−1], stored as buffers so the
    hashes are saved with the state dict. Reducing modulo the prime first
    matters: (a·t + b) mod B alone gives h(t) = h(t + B) for every table, so
    tokens B apart would always share an embedding. Two distinct tokens share
    a full embedding only if they collide in all k tables.

    Parameter count: k·B·(d/k) = B·d vs V·d, i.e. V/B compression.
    Example: V=65536, d=512, B=8192 gives 4.2M vs 33.6M, i.e. 8x.
    """

    # Mersenne prime modulus of the hash family (must exceed n_buckets).
    _PRIME = 2**31 - 1

    def __init__(
        self,
        vocab_size: int,
        dim: int,
        n_hashes: int = 4,
        n_buckets: int = 8192,
    ):
        super().__init__()
        assert dim % n_hashes == 0, (
            f"dim={dim} ต้องหาร n_hashes={n_hashes} ลงตัว"
        )
        self.vocab_size = vocab_size
        self.dim = dim
        self.n_hashes = n_hashes
        self.n_buckets = n_buckets
        self.bucket_dim = dim // n_hashes
        # k embedding tables of shape (B, d/k).
        self.tables = nn.ModuleList([
            nn.Embedding(n_buckets, self.bucket_dim)
            for _ in range(n_hashes)
        ])
        for emb in self.tables:
            assert isinstance(emb, nn.Embedding)
            nn.init.normal_(emb.weight, std=0.02)
        # Hash parameters. randint's upper bound is exclusive, so
        # a ∈ [1, p − 1] and b ∈ [0, p − 1] for p = 2³¹ − 1.
        a = torch.randint(1, 2**31 - 1, (n_hashes,), dtype=torch.long)
        b = torch.randint(0, 2**31 - 1, (n_hashes,), dtype=torch.long)
        self.register_buffer("_a", a)
        self.register_buffer("_b", b)

    def _hash(self, ids: torch.Tensor, i: int) -> torch.Tensor:
        """Bucket index h_i(t) = ((a_i·t + b_i) mod p) mod B, as int64."""
        # Reduce t mod p first so that a_i·t < 2⁶² and cannot overflow int64.
        t = ids.long() % self._PRIME
        return ((self._a[i] * t + self._b[i]) % self._PRIME) % self.n_buckets  # type: ignore[index]

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Map integer ids of shape (...,) to embeddings of shape (..., dim)."""
        parts = [
            self.tables[i](self._hash(token_ids, i))
            for i in range(self.n_hashes)
        ]
        return torch.cat(parts, dim=-1)

    def param_count(self) -> int:
        """Trainable scalars across all k tables (= n_buckets · dim)."""
        return self.n_hashes * self.n_buckets * self.bucket_dim

    def compression_ratio(self) -> float:
        """Size of a dense vocab_size × dim table divided by ``param_count()``."""
        standard = self.vocab_size * self.dim
        bloom = self.param_count()
        return standard / bloom
# ─────────────────────────────────────────────────────────────────────────────
# 3. LowRankFFN — Feed-Forward with Progressive Rank Growth
# ─────────────────────────────────────────────────────────────────────────────
class LowRankFFN(nn.Module):
    """
    Gated feed-forward layer with low-rank + diagonal up/down projections.

        W_up   = U_up · V_upᵀ     + diag(d_up)      (dim → hidden)
        W_down = U_down · V_downᵀ + diag(d_down)    (hidden → dim)

    The diagonal residuals cover only the first min(dim, hidden) coordinates.
    Forward pass:

        h = (x · W_up) ⊙ sigmoid(gate(x))    # sigmoid-gated (GLU-style)
        y = dropout(LayerNorm(h)) · W_down

    hidden = int(dim · ffn_mult · 2/3), rounded up to a multiple of 64.

    Parameter count (see ``param_count``):

        2·r·(dim + hidden) + 2·min(dim, hidden) + dim·hidden + 2·hidden

    The dense ``gate`` (dim·hidden) is not factorized and dominates the total
    for small ranks. The low-rank factors alone are 2·r·(dim + hidden).

    ``grow_rank`` appends one column to every factor per call, up to
    ``max_rank`` (default min(dim, hidden) // 2).
    """

    def __init__(
        self,
        cfg: OmegaConfig,
        rank: int = 32,
        max_rank: int | None = None,
    ):
        super().__init__()
        d = cfg.dim
        hidden = int(d * cfg.ffn_mult * 2 / 3)
        hidden = (hidden + 63) // 64 * 64  # round up to a multiple of 64
        self.d = d
        self.hidden = hidden
        self.rank = rank
        # NOTE: an explicit max_rank of 0 is falsy and is treated like None.
        self.max_rank = max_rank or min(d, hidden) // 2
        # Low-rank factors, orthogonally initialized.
        self.U_up = nn.Parameter(self._orth(d, rank))
        self.V_up = nn.Parameter(self._orth(hidden, rank))
        self.U_down = nn.Parameter(self._orth(hidden, rank))
        self.V_down = nn.Parameter(self._orth(d, rank))
        # Diagonal residuals on the min(d, hidden) overlap, zero-initialized.
        self.d_up = nn.Parameter(torch.zeros(min(d, hidden)))
        self.d_down = nn.Parameter(torch.zeros(min(d, hidden)))
        # Dense gate projection (not factorized).
        self.gate = nn.Linear(d, hidden, bias=False)
        self.norm = nn.LayerNorm(hidden)
        self.drop = nn.Dropout(cfg.dropout)

    @staticmethod
    def _orth(n: int, r: int) -> torch.Tensor:
        """Return an (n, r) tensor initialized by ``nn.init.orthogonal_``."""
        t = torch.empty(n, r)
        nn.init.orthogonal_(t)
        return t

    @staticmethod
    def _add_diag(h: torch.Tensor, x: torch.Tensor, diag: torch.Tensor) -> torch.Tensor:
        """Return h + x[..., :m] * diag, zero-padded to h's width (m = len(diag))."""
        m = diag.shape[0]
        # Out-of-place, so the autograd output h is never modified in place.
        return h + F.pad(x[..., :m] * diag, (0, h.shape[-1] - m))

    def _matvec_up(self, x: torch.Tensor) -> torch.Tensor:
        """x · (U_up · V_upᵀ + diag(d_up)):  (..., d) -> (..., hidden)."""
        h = (x @ self.U_up) @ self.V_up.T
        return self._add_diag(h, x, self.d_up)

    def _matvec_down(self, x: torch.Tensor) -> torch.Tensor:
        """x · (U_down · V_downᵀ + diag(d_down)):  (..., hidden) -> (..., d)."""
        h = (x @ self.U_down) @ self.V_down.T
        return self._add_diag(h, x, self.d_down)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sigmoid-gated low-rank FFN: (..., d) -> (..., d)."""
        up = self._matvec_up(x)  # (..., hidden)
        gate = torch.sigmoid(self.gate(x))  # (..., hidden)
        h = self.drop(self.norm(up * gate))
        return self._matvec_down(h)  # (..., d)

    @torch.no_grad()
    def grow_rank(self) -> bool:
        """
        Increase the rank of all four low-rank factors by one.

        No spectral criterion is evaluated here. Each call grows the rank
        unconditionally until ``max_rank``; deciding *when* to grow is up to
        the caller.

        Each factor receives a new random unit-norm column orthogonal to the
        span of its current columns. The span is taken via QR, because trained
        factors are generally no longer orthonormal. The new u·vᵀ term changes
        the layer's output.

        Factors are resized in place through ``.data``. Consequently:
          * stale ``.grad`` tensors (old shape) are dropped, otherwise the next
            backward pass would fail to accumulate into them;
          * an optimizer already holding these parameters keeps per-parameter
            state of the old shape, so the caller must reset or rebuild it.

        Returns:
            True if the rank was increased, False if already at ``max_rank``.
        """
        if self.rank >= self.max_rank:
            return False
        for param in (self.U_up, self.V_up, self.U_down, self.V_down):
            old = param.data
            n = old.shape[0]
            work = old.float()  # QR does not support half precision
            Q, _ = torch.linalg.qr(work, mode="reduced")  # basis of span(old)
            new_col = torch.randn(n, 1, device=old.device, dtype=work.dtype)
            new_col = new_col - Q @ (Q.T @ new_col)
            norm = new_col.norm()
            if norm > 1e-6:
                new_col = new_col / norm
            else:
                nn.init.orthogonal_(new_col)
            param.grad = None
            param.data = torch.cat([old, new_col.to(old.dtype)], dim=1)
        self.rank += 1
        return True

    def param_count(self) -> int:
        """Number of trainable scalars, matching ``sum(p.numel() for p in self.parameters())``."""
        r = self.rank
        return (
            r * (self.d + self.hidden) * 2  # U_up, V_up, U_down, V_down
            + 2 * min(self.d, self.hidden)  # d_up, d_down
            + self.d * self.hidden  # gate (dense)
            + 2 * self.hidden  # LayerNorm weight + bias
        )
# ─────────────────────────────────────────────────────────────────────────────
# 4. SpectralAttention + SpectralBlock — transformer block built from StiefelLinear attention and LowRankFFN
# ─────────────────────────────────────────────────────────────────────────────
class SpectralAttention(nn.Module):
    """
    Causal gated linear attention with low-rank (StiefelLinear) projections.

    Kernelized attention with the strictly positive feature map
    φ(x) = ELU(x) + 1:

        o_t = φ(q_t) · S_t / (φ(q_t) · z_t)
        S_t = Σ_{s≤t} φ(k_s)ᵀ v_s,   z_t = Σ_{s≤t} φ(k_s)

    Rotary position embedding is applied to q and k before φ. The result is
    multiplied by sigmoid(gate(x)) before the output projection.

    W_q, W_k, W_v and W_o are StiefelLinear(rank=r). ``gate`` is a dense
    nn.Linear(dim, H·D) and is not factorized.

    Streaming / KV cache: pass a dict as ``kv_cache`` (an empty dict starts a
    new sequence). The returned dict holds the running state S (B, H, D, D),
    z (B, H, D) and the position ``offset``. Chunks of any length are
    processed causally, so a multi-token prefill followed by token-by-token
    decoding reproduces the uncached forward pass.

    ``mask`` is accepted for interface compatibility but is ignored: no
    padding mask is applied.
    """

    def __init__(self, cfg: OmegaConfig, rank: int = 32):
        super().__init__()
        self.H = cfg.n_heads
        self.D = cfg.head_dim
        inner = cfg.n_heads * cfg.head_dim
        self.q_proj = StiefelLinear(cfg.dim, inner, rank)
        self.k_proj = StiefelLinear(cfg.dim, inner, rank)
        self.v_proj = StiefelLinear(cfg.dim, inner, rank)
        self.o_proj = StiefelLinear(inner, cfg.dim, rank)
        self.gate = nn.Linear(cfg.dim, inner, bias=False)
        # Rotary embedding inverse frequencies (rotate-half formulation).
        inv_freq = 1.0 / (
            cfg.rope_theta ** (
                torch.arange(0, cfg.head_dim, 2).float() / cfg.head_dim
            )
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @staticmethod
    def _phi(x: torch.Tensor) -> torch.Tensor:
        """Positive feature map ELU(x) + 1."""
        return F.elu(x) + 1.0

    def _rope(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """Apply rotary embedding to x of shape (B, T, H, D) at positions offset..offset+T-1."""
        T = x.shape[1]
        t = torch.arange(offset, offset + T, device=x.device, dtype=self.inv_freq.dtype)  # type: ignore[attr-defined]
        freqs = torch.outer(t, self.inv_freq)  # (T, D/2)
        emb = torch.cat([freqs, freqs], dim=-1)  # (T, D)
        # Match x's dtype so half-precision inputs are not promoted to fp32.
        cos_e = emb.cos().to(x.dtype)[None, :, None, :]  # (1, T, 1, D)
        sin_e = emb.sin().to(x.dtype)[None, :, None, :]
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        rot = torch.cat([-x2, x1], dim=-1)
        return x * cos_e + rot * sin_e

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: dict | None = None,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, dict | None]:
        """
        Args:
            x: (B, T, dim) input.
            kv_cache: None for a full causal pass, or a dict carrying
                S / z / offset from a previous call ({} starts fresh).
            mask: ignored.

        Returns:
            (output of shape (B, T, dim), updated cache dict or None).
        """
        B, T, _ = x.shape
        H, D = self.H, self.D
        q = self.q_proj(x).view(B, T, H, D)
        k = self.k_proj(x).view(B, T, H, D)
        v = self.v_proj(x).view(B, T, H, D)
        g = torch.sigmoid(self.gate(x))  # (B, T, H*D)
        offset = kv_cache["offset"] if kv_cache else 0
        q = self._rope(q, offset)
        k = self._rope(k, offset)
        qp, kp = self._phi(q), self._phi(k)  # strictly positive features

        # Prefix sums give every position its own causal state within the chunk.
        S_cum = torch.einsum("bthd,bthe->bthde", kp, v).cumsum(dim=1)  # (B,T,H,D,D)
        z_cum = kp.cumsum(dim=1)  # (B,T,H,D)
        if kv_cache is not None:
            # Add the state carried over from earlier chunks, if any.
            S_prev = kv_cache.get("S")
            if S_prev is not None:
                S_cum = S_cum + S_prev.unsqueeze(1)
            z_prev = kv_cache.get("z")
            if z_prev is not None:
                z_cum = z_cum + z_prev.unsqueeze(1)
            # clone() so the cache does not keep the whole (B,T,H,D,D) buffer alive.
            kv_cache = {
                "S": S_cum[:, -1].clone(),
                "z": z_cum[:, -1].clone(),
                "offset": offset + T,
            }

        o = torch.einsum("bthd,bthde->bthe", qp, S_cum)
        den = torch.einsum("bthd,bthd->bth", qp, z_cum).clamp(min=1e-6).unsqueeze(-1)
        o = o / den
        o = o.reshape(B, T, H * D) * g
        return self.o_proj(o), kv_cache
class SpectralBlock(nn.Module):
    """
    Pre-norm residual block built from the compact layers of this module.

        h   = x + SpectralAttention(RMSNorm₁(x))
        out = h + LowRankFFN(RMSNorm₂(h))

    The KV cache and mask are handed to the attention sub-layer, and the
    attention's updated cache is returned together with the block output.
    """

    def __init__(self, cfg: OmegaConfig, attn_rank: int = 32, ffn_rank: int = 32):
        super().__init__()
        width = cfg.dim
        self.norm1 = nn.RMSNorm(width)  # type: ignore[attr-defined]
        self.norm2 = nn.RMSNorm(width)  # type: ignore[attr-defined]
        self.attn = SpectralAttention(cfg, rank=attn_rank)
        self.ffn = LowRankFFN(cfg, rank=ffn_rank)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: dict | None = None,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, dict | None]:
        """Return (block output of the same shape as x, updated attention cache)."""
        mixed, updated_cache = self.attn(self.norm1(x), kv_cache=kv_cache, mask=mask)
        residual = x + mixed
        out = residual + self.ffn(self.norm2(residual))
        return out, updated_cache
# ─────────────────────────────────────────────────────────────────────────────
# 5. SpectralMindModel — Full Model ด้วย BloomEmbedding + SpectralBlocks
# ─────────────────────────────────────────────────────────────────────────────
class SpectralMindModel(nn.Module):
    """
    Decoder-only LM: BloomEmbedding → n_layers × SpectralBlock → RMSNorm →
    StiefelLinear LM head.

    Where the parameter savings come from:
      - BloomEmbedding: B·d instead of V·d.
      - SpectralAttention: four StiefelLinear projections (≈ r·(in + out)
        each) instead of 4·d², plus one dense gate (dim × H·D) that is not
        factorized.
      - LowRankFFN: 2·r·(d + h) factor parameters instead of 2·d·h, plus a
        dense d·h gate.
      - LM head: StiefelLinear with rank min(2·attn_rank, dim // 4).

    Use ``count_params()`` to get the actual totals for a given config.
    """

    def __init__(
        self,
        cfg: OmegaConfig,
        attn_rank: int = 32,
        ffn_rank: int = 32,
        bloom_buckets: int = 8192,
        bloom_hashes: int = 4,
    ):
        super().__init__()
        self.cfg = cfg
        # Hashed embedding instead of a dense vocab_size × dim table.
        self.embed = BloomEmbedding(
            vocab_size=cfg.vocab_size,
            dim=cfg.dim,
            n_hashes=bloom_hashes,
            n_buckets=bloom_buckets,
        )
        self.blocks = nn.ModuleList([
            SpectralBlock(cfg, attn_rank=attn_rank, ffn_rank=ffn_rank)
            for _ in range(cfg.n_layers)
        ])
        self.norm_out = nn.RMSNorm(cfg.dim)  # type: ignore[attr-defined]
        # Low-rank vocabulary projection.
        lm_rank = min(attn_rank * 2, cfg.dim // 4)
        self.lm_head = StiefelLinear(cfg.dim, cfg.vocab_size, rank=lm_rank)
        self._init_sigma()

    def _init_sigma(self) -> None:
        """Re-initialize singular values and diagonal residuals of all sub-layers."""
        for m in self.modules():
            if isinstance(m, StiefelLinear):
                # σ_i ~ N(0, 1/r). NOTE(review): with orthonormal U, V and
                # isotropic x this gives E‖Wx‖² ≈ ‖x‖²/in_f, not ‖x‖² —
                # confirm the intended output scale.
                nn.init.normal_(m.sigma, std=1.0 / math.sqrt(m.rank))
            if isinstance(m, LowRankFFN):
                nn.init.normal_(m.d_up, std=0.02)
                nn.init.normal_(m.d_down, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        kv_caches: list[dict] | None = None,
    ) -> dict[str, torch.Tensor]:
        """
        Run the model.

        Args:
            input_ids: (batch, seq) token ids.
            attention_mask: forwarded to the blocks. The attention layer
                currently ignores it.
            labels: optional (batch, seq) targets. The loss is next-token
                cross-entropy (logits[t] vs labels[t+1]) with -100 ignored.
            kv_caches: optional per-layer cache dicts for streaming.

        Returns:
            dict with "logits" (batch, seq, vocab), plus "loss" when labels
            are given and "kv_caches" when caches were passed.
        """
        x = self.embed(input_ids)
        new_caches: list[dict] = []
        for i, block in enumerate(self.blocks):
            cache_in = kv_caches[i] if kv_caches else None
            x, cache_out = block(x, kv_cache=cache_in, mask=attention_mask)
            if cache_out is not None:
                new_caches.append(cache_out)
        x = self.norm_out(x)
        logits = self.lm_head(x)
        result: dict[str, torch.Tensor] = {"logits": logits}
        if labels is not None:
            loss = F.cross_entropy(
                logits[..., :-1, :].contiguous().view(-1, self.cfg.vocab_size),
                labels[..., 1:].contiguous().view(-1),
                ignore_index=-100,
            )
            result["loss"] = loss
        if new_caches:
            result["kv_caches"] = new_caches  # type: ignore[assignment]
        return result

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 0.9,
    ) -> torch.Tensor:
        """
        Sample a continuation with temperature and nucleus (top-p) sampling.

        All prompt tokens except the last are run once to fill the per-layer
        caches. After that, one token at a time is fed back through the model.

        Args:
            input_ids: (batch, prompt_len) token ids, prompt_len >= 1.
            max_new_tokens: upper bound on the number of sampled tokens.
            temperature: logits are divided by max(temperature, 1e-5).
            top_p: nucleus threshold. The most likely token is always kept.

        Returns:
            (batch, prompt_len + n_new) ids. Generation stops early once every
            row has produced ``cfg.eos_token_id``. Rows that finish earlier are
            padded with that EOS id.

        Raises:
            ValueError: if the prompt is empty.

        The module's previous train/eval mode is restored before returning.
        """
        if input_ids.shape[1] == 0:
            raise ValueError("generate() requires at least one prompt token")
        eos = getattr(self.cfg, "eos_token_id", None)
        was_training = self.training
        self.eval()
        try:
            generated = input_ids.clone()
            caches: list[dict] = [{} for _ in self.blocks]
            if generated.shape[1] > 1:
                out = self.forward(generated[:, :-1], kv_caches=caches)
                if "kv_caches" in out:
                    caches = out["kv_caches"]  # type: ignore[assignment]
            finished = torch.zeros(
                generated.shape[0], dtype=torch.bool, device=generated.device
            )
            for _ in range(max_new_tokens):
                out = self.forward(generated[:, -1:], kv_caches=caches)
                if "kv_caches" in out:
                    caches = out["kv_caches"]  # type: ignore[assignment]
                logits = out["logits"][:, -1, :].float() / max(temperature, 1e-5)
                # Nucleus filter: drop tokens whose preceding cumulative
                # probability already exceeds top_p.
                sv, si = torch.sort(logits, descending=True)
                sp = F.softmax(sv, dim=-1)
                cp = torch.cumsum(sp, dim=-1)
                sv[cp - sp > top_p] = float("-inf")
                logits.scatter_(1, si, sv)
                next_tok = torch.multinomial(F.softmax(logits, dim=-1), 1)
                if eos is not None:
                    # Rows that already emitted EOS keep emitting EOS.
                    next_tok = next_tok.masked_fill(finished.unsqueeze(-1), eos)
                    finished |= next_tok.squeeze(-1) == eos
                generated = torch.cat([generated, next_tok], dim=1)
                if eos is not None and bool(finished.all()):
                    break
            return generated
        finally:
            self.train(was_training)

    def count_params(self) -> str:
        """Human-readable summary of total/trainable parameters and embedding compression."""
        total = sum(p.numel() for p in self.parameters())
        train = sum(p.numel() for p in self.parameters() if p.requires_grad)
        embed_b = self.embed.param_count()
        embed_s = self.cfg.vocab_size * self.cfg.dim
        return (
            f"Total {total/1e6:.2f}M | Trainable {train/1e6:.2f}M\n"
            f"Embed: {embed_b/1e6:.2f}M (Bloom) vs {embed_s/1e6:.2f}M (standard) "
            f"= {embed_s/embed_b:.1f}x compression"
        )

    def all_stiefel_layers(self) -> list[StiefelLinear]:
        """All StiefelLinear sub-modules, e.g. for calling ``retract()`` after each optimizer step."""
        return [m for m in self.modules() if isinstance(m, StiefelLinear)]

Xet Storage Details

Size:
25.2 kB
·
Xet hash:
71743867d53797f2dd810ecebfa44d4e7fcbea546c7072b9343800715c88bd65

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.