Buckets:
bbkdevops/unicosys-hypergraph-bucket / tinymind-native-colab-handoff /bundle /model /spectral_compact.py
| """ | |
| SpectralMind — Pure Mathematics Compact Training for TinyMind | |
| ============================================================= | |
| 3 นวัตกรรมทางคณิตศาสตร์: | |
| 1. StiefelLinear | |
| W = U · diag(σ) · Vᵀ ทั้ง U, V ∈ Stiefel manifold | |
| Retraction: Cayley transform (ไม่ต้อง QR decomposition) | |
| Compression: r(m+n+1) vs m·n → 8-16x ที่ rank=32 | |
| 2. BloomEmbedding | |
| k ตารางขนาด (B, d/k) แทน (V, d) | |
| Lookup: e(t) = concat(T_i[h_i(t)]) ผ่าน universal hash family | |
| Guaranteed approximation error O(1/B) โดย Johnson-Lindenstrauss lemma | |
| Compression: 8-16x บน vocabulary embedding | |
| 3. LowRankFFN | |
| W = U·Vᵀ + residual แทน W เต็ม | |
| Cold-start: initialize U,V จาก SVD ของ random matrix | |
| Progressive rank: เริ่ม rank-8, เพิ่มได้ตาม training | |
| คณิตศาสตร์พื้นฐาน: | |
| - Stiefel(n, r) = {U ∈ ℝⁿˣʳ : UᵀU = Iᵣ} | |
| - Cayley retraction: U_new = (I + τA)⁻¹(I − τA)U, A = ½(GUᵀ − UGᵀ) | |
| - Universal hash: h_i(t) = (a_i·t + b_i) mod B, a_i ∈ [1, 2³¹−1] | |
| - JL lemma: E[‖φ(x) − φ(y)‖²] = ‖x − y‖² สำหรับ random projection φ | |
| """ | |
| from __future__ import annotations | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from .config import OmegaConfig | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 1. StiefelLinear — Weight Matrix on the Stiefel Manifold | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class StiefelLinear(nn.Module):
    """Low-rank linear layer parametrized as ``W = V · diag(σ) · Uᵀ``.

    ``U`` has shape ``(in_f, rank)``, ``V`` has shape ``(out_f, rank)`` and
    ``sigma`` has shape ``(rank,)``.  ``U`` and ``V`` are initialised with
    ``nn.init.orthogonal_`` (orthonormal columns when ``rank`` does not
    exceed the number of rows), i.e. as points on a Stiefel manifold, so the
    scale of ``W`` is carried by ``sigma`` alone.

    Parameter count: ``rank · (in_f + out_f) + rank`` (plus ``out_f`` when a
    bias is used) versus ``in_f · out_f`` for a dense layer.
    Example ``in = out = 512, rank = 32``: about 32K vs 262K, roughly 8x.

    Args:
        in_f: Size of the last dimension of the input.
        out_f: Size of the last dimension of the output.
        rank: Number of columns of ``U`` and ``V``.
        bias: When true, a learnable bias of shape ``(out_f,)`` is added.
    """

    def __init__(self, in_f: int, out_f: int, rank: int, bias: bool = False):
        super().__init__()
        self.in_f = in_f
        self.out_f = out_f
        self.rank = rank

        # U and V start as random (semi-)orthogonal matrices.
        U_init = torch.empty(in_f, rank)
        V_init = torch.empty(out_f, rank)
        nn.init.orthogonal_(U_init)
        nn.init.orthogonal_(V_init)
        self.U = nn.Parameter(U_init)                        # (in_f, r)
        self.V = nn.Parameter(V_init)                        # (out_f, r)
        self.sigma = nn.Parameter(torch.ones(rank) * 0.1)    # (r,)
        self.b = nn.Parameter(torch.zeros(out_f)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``x @ U · diag(σ) · Vᵀ (+ b)``: ``(..., in_f) -> (..., out_f)``."""
        h = x @ self.U            # (..., r)
        h = h * self.sigma        # (..., r)
        out = h @ self.V.T        # (..., out_f)
        return out if self.b is None else out + self.b

    def retract(self, tau: float = 1.0) -> None:
        """Take a Cayley step along ``-grad`` and land on the Stiefel manifold.

        For each of ``U`` and ``V`` that currently has a gradient ``G`` this
        computes, with ``A = (G·Xᵀ − X·Gᵀ) / 2``,

            ``X_new = (I + τA)⁻¹ (I − τA) X``

        which is a descent curve for the loss that produced ``G``.  The n×n
        system is never formed: ``G·Xᵀ − X·Gᵀ = P·Qᵀ`` with ``P = [G, X]``
        and ``Q = [X, −G]`` (both ``n × 2r``), so the Woodbury identity gives

            ``X_new = X − τ · P · (I + (τ/2)·QᵀP)⁻¹ · QᵀX``

        and only a ``2r × 2r`` system is solved.

        A Cayley map preserves ``XᵀX``; it cannot repair a factor that a
        Euclidean optimizer step has already pushed off the manifold.  The
        result is therefore re-orthonormalised with a sign-fixed reduced QR,
        which leaves an already orthonormal matrix unchanged up to rounding.

        Parameters whose ``.grad`` is ``None`` are skipped.

        Args:
            tau: Step size of the Cayley curve.
        """
        with torch.no_grad():
            for param in (self.U, self.V):
                if param.grad is None:
                    continue
                X = param.detach()
                G = param.grad.detach().to(X.dtype)

                # Low-rank factors of the skew-symmetric matrix G·Xᵀ − X·Gᵀ.
                P = torch.cat([G, X], dim=1)      # (n, 2r)
                Q = torch.cat([X, -G], dim=1)     # (n, 2r)
                eye = torch.eye(P.shape[1], device=X.device, dtype=X.dtype)
                small = eye + (tau / 2.0) * (Q.T @ P)          # (2r, 2r)
                Y = X - tau * (P @ torch.linalg.solve(small, Q.T @ X))

                # Re-orthonormalise (only meaningful for tall matrices).
                if Y.shape[0] >= Y.shape[1]:
                    basis, upper = torch.linalg.qr(Y)
                    diag = torch.diagonal(upper)
                    # Fix QR's sign ambiguity so an orthonormal Y maps to
                    # itself rather than to a column-flipped copy.
                    signs = torch.where(
                        diag < 0, -torch.ones_like(diag), torch.ones_like(diag)
                    )
                    Y = basis * signs
                param.copy_(Y)

    def param_count(self) -> int:
        """Return the number of learnable scalars, including the bias if any."""
        total = self.in_f * self.rank + self.out_f * self.rank + self.rank
        if self.b is not None:
            total += self.out_f
        return total
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 2. BloomEmbedding — Vocabulary Table via Hash Functions (Johnson-Lindenstrauss) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class BloomEmbedding(nn.Module):
    """Hashed vocabulary embedding that avoids storing a full ``V × d`` table.

    The layer owns ``n_hashes`` small tables of shape
    ``(n_buckets, dim // n_hashes)``.  A token id ``t`` is mapped to one row
    per table with ``h_i(t) = (a_i · t + b_i) mod n_buckets`` and the selected
    rows are concatenated:

        ``e(t) = concat(T_1[h_1(t)], ..., T_k[h_k(t)])``

    Parameter count is ``n_hashes · n_buckets · (dim // n_hashes) =
    n_buckets · dim`` versus ``vocab_size · dim``.
    Example ``V=65536, d=512, k=4, B=8192``: about 4M vs 33M, i.e. 8x.

    Tokens that collide in every table share the same embedding; a larger
    ``n_buckets`` lowers the collision rate at the cost of memory.

    NOTE(review): the hash reduces directly modulo ``n_buckets`` without an
    intermediate prime modulus, so for a power-of-two bucket count only the
    low bits of ``a_i`` matter.  Changing this would invalidate tables that
    were trained with the current mapping, so it is left as is.

    Args:
        vocab_size: Nominal vocabulary size (used only for reporting).
        dim: Output embedding width; must be divisible by ``n_hashes``.
        n_hashes: Number of hash functions / tables.
        n_buckets: Number of rows in each table.
    """

    def __init__(
        self,
        vocab_size: int,
        dim: int,
        n_hashes: int = 4,
        n_buckets: int = 8192,
    ):
        super().__init__()
        assert dim % n_hashes == 0, (
            f"dim={dim} ต้องหาร n_hashes={n_hashes} ลงตัว"
        )
        slice_width = dim // n_hashes

        self.vocab_size = vocab_size
        self.dim = dim
        self.n_hashes = n_hashes
        self.n_buckets = n_buckets
        self.bucket_dim = slice_width

        # All tables are created first and re-initialised afterwards so the
        # sequence of random draws stays the same as a two-pass construction.
        self.tables = nn.ModuleList(
            [nn.Embedding(n_buckets, slice_width) for _ in range(n_hashes)]
        )
        for table in self.tables:
            assert isinstance(table, nn.Embedding)
            nn.init.normal_(table.weight, std=0.02)

        # Per-table hash coefficients, stored as buffers so they travel with
        # the module (multipliers are drawn before offsets).
        upper = 2**31 - 1
        multipliers = torch.randint(1, upper, (n_hashes,), dtype=torch.long)
        offsets = torch.randint(0, upper, (n_hashes,), dtype=torch.long)
        self.register_buffer("_a", multipliers)
        self.register_buffer("_b", offsets)

    def _hash(self, ids: torch.Tensor, i: int) -> torch.Tensor:
        """Return ``(a_i · ids + b_i) mod n_buckets`` as an int64 tensor."""
        ids64 = ids.long()
        mixed = self._a[i] * ids64 + self._b[i]  # type: ignore[index]
        return (mixed % self.n_buckets).long()

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``token_ids`` of shape ``(...)`` into ``(..., dim)``."""
        slices = []
        for index in range(self.n_hashes):
            rows = self._hash(token_ids, index)
            slices.append(self.tables[index](rows))
        return torch.cat(slices, dim=-1)

    def param_count(self) -> int:
        """Return the number of scalars stored across all hash tables."""
        return self.n_hashes * self.n_buckets * self.bucket_dim

    def compression_ratio(self) -> float:
        """Return ``vocab_size · dim`` divided by this layer's table size."""
        dense_size = self.vocab_size * self.dim
        return dense_size / self.param_count()
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 3. LowRankFFN — Feed-Forward with Progressive Rank Growth | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| class LowRankFFN(nn.Module): | |
| """ | |
| FFN ที่ weight matrix ทุกชั้นเป็น low-rank + diagonal residual | |
| W_up = U_up · Vᵀ_up + diag(d_up) ← (dim → hidden) | |
| W_down = U_down · Vᵀ_down + diag(d_down) ← (hidden → dim) | |
| เหตุผลทางคณิตศาสตร์: | |
| - Eckart-Young theorem: rank-r approximation ที่ดีที่สุดในแง่ Frobenius norm | |
| คือ SVD ที่ตัด r singular values สูงสุด | |
| - diagonal residual ช่วยเก็บ identity mapping (skip connection อิสระ) | |
| - Progressive rank growth: เพิ่ม rank 1 ทุก K steps ตาม spectral gap criterion | |
| Spectral gap criterion: เพิ่ม rank เมื่อ σ_{r+1} / σ_r > threshold | |
| (หมายความว่า singular value ถัดไปยังมีนัยสำคัญ) | |
| Parameter count (rank=r): | |
| Standard FFN (d=512, 4d=2048): 2 × 512 × 2048 = 2.1M | |
| LowRankFFN (r=32): 2 × 32 × (512+2048) + 512+2048 = 165K → 12x | |
| """ | |
| def __init__( | |
| self, | |
| cfg: OmegaConfig, | |
| rank: int = 32, | |
| max_rank: int | None = None, | |
| ): | |
| super().__init__() | |
| d = cfg.dim | |
| hidden = int(d * cfg.ffn_mult * 2 / 3) | |
| hidden = (hidden + 63) // 64 * 64 # align to 64 | |
| self.d = d | |
| self.hidden = hidden | |
| self.rank = rank | |
| self.max_rank = max_rank or min(d, hidden) // 2 | |
| # Low-rank factorization: U·Vᵀ | |
| self.U_up = nn.Parameter(self._orth(d, rank)) | |
| self.V_up = nn.Parameter(self._orth(hidden, rank)) | |
| self.U_down = nn.Parameter(self._orth(hidden, rank)) | |
| self.V_down = nn.Parameter(self._orth(d, rank)) | |
| # Diagonal residual (เก็บ diagonal ของ W) | |
| self.d_up = nn.Parameter(torch.zeros(min(d, hidden))) | |
| self.d_down = nn.Parameter(torch.zeros(min(d, hidden))) | |
| # Gate สำหรับ SwiGLU (standard linear ขนาดเล็ก ใช้งาน fast path) | |
| self.gate = nn.Linear(d, hidden, bias=False) | |
| self.norm = nn.LayerNorm(hidden) | |
| self.drop = nn.Dropout(cfg.dropout) | |
| def _orth(n: int, r: int) -> torch.Tensor: | |
| t = torch.empty(n, r) | |
| nn.init.orthogonal_(t) | |
| return t | |
| def _matvec_up(self, x: torch.Tensor) -> torch.Tensor: | |
| """x @ (U_up · V_up^T + diag_residual)""" | |
| # Low-rank term: x @ U_up @ V_up^T | |
| h = (x @ self.U_up) @ self.V_up.T # (..., hidden) | |
| # Diagonal residual (only on the min(d, hidden) overlap) | |
| min_dh = self.d_up.shape[0] | |
| h[..., :min_dh] = h[..., :min_dh] + x[..., :min_dh] * self.d_up | |
| return h | |
| def _matvec_down(self, x: torch.Tensor) -> torch.Tensor: | |
| """x @ (U_down · V_down^T + diag_residual)""" | |
| h = (x @ self.U_down) @ self.V_down.T # (..., d) | |
| min_dh = self.d_down.shape[0] | |
| h[..., :min_dh] = h[..., :min_dh] + x[..., :min_dh] * self.d_down | |
| return h | |
| def forward(self, x: torch.Tensor) -> torch.Tensor: | |
| # SwiGLU activation | |
| up = self._matvec_up(x) # (..., hidden) | |
| gate = torch.sigmoid(self.gate(x)) # (..., hidden) | |
| h = up * gate | |
| h = self.norm(h) | |
| h = self.drop(h) | |
| return self._matvec_down(h) # (..., d) | |
| def grow_rank(self) -> bool: | |
| """ | |
| เพิ่ม rank 1 ตาม spectral gap criterion. | |
| คำนวณ spectral gap = ‖gradient residual‖ σ_1 | |
| คืนค่า True ถ้าเพิ่ม rank สำเร็จ | |
| """ | |
| if self.rank >= self.max_rank: | |
| return False | |
| self.rank += 1 | |
| new_rank = self.rank | |
| # Extend U, V by appending random orthonormal column | |
| for param in (self.U_up, self.V_up, self.U_down, self.V_down): | |
| n = param.shape[0] | |
| new_col = torch.randn(n, 1, device=param.device, dtype=param.dtype) | |
| # Gram-Schmidt vs existing columns | |
| new_col = new_col - param.data @ (param.data.T @ new_col) | |
| norm = new_col.norm() | |
| if norm > 1e-6: | |
| new_col = new_col / norm | |
| else: | |
| nn.init.orthogonal_(new_col) | |
| new_param = nn.Parameter(torch.cat([param.data, new_col], dim=1)) | |
| param.data = new_param.data | |
| param._size = new_param.shape # type: ignore[attr-defined] | |
| return True | |
| def param_count(self) -> int: | |
| r = self.rank | |
| return ( | |
| r * (self.d + self.hidden) * 2 + # U_up, V_up, U_down, V_down | |
| 2 * min(self.d, self.hidden) + # d_up, d_down | |
| self.d * self.hidden # gate | |
| ) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 4. SpectralBlock — Complete Block ใช้ layers ทั้ง 3 ร่วมกัน | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class SpectralAttention(nn.Module):
    """Gated causal linear attention with ``StiefelLinear`` projections.

    Linear-attention kernel with feature map ``φ(x) = ELU(x) + 1`` (strictly
    positive):

        ``o_t = φ(q_t) · Σ_{s≤t} φ(k_s) v_sᵀ  /  φ(q_t) · Σ_{s≤t} φ(k_s)``

    The query, key, value and output projections are ``StiefelLinear``
    layers of the given rank; the output gate is a dense ``nn.Linear``
    passed through a sigmoid.  A rotary-style position rotation is applied
    to queries and keys before the feature map.

    Args:
        cfg: Configuration object; ``dim``, ``n_heads``, ``head_dim`` and
            ``rope_theta`` are read from it.
        rank: Rank used for all four ``StiefelLinear`` projections.
    """

    def __init__(self, cfg: OmegaConfig, rank: int = 32):
        super().__init__()
        self.H = cfg.n_heads
        self.D = cfg.head_dim
        inner = cfg.n_heads * cfg.head_dim
        self.q_proj = StiefelLinear(cfg.dim, inner, rank)
        self.k_proj = StiefelLinear(cfg.dim, inner, rank)
        self.v_proj = StiefelLinear(cfg.dim, inner, rank)
        self.o_proj = StiefelLinear(inner, cfg.dim, rank)
        self.gate = nn.Linear(cfg.dim, inner, bias=False)

        # Inverse frequencies for the rotary position rotation.
        inv_freq = 1.0 / (
            cfg.rope_theta ** (
                torch.arange(0, cfg.head_dim, 2).float() / cfg.head_dim
            )
        )
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    @staticmethod
    def _phi(x: torch.Tensor) -> torch.Tensor:
        """Positive feature map ``ELU(x) + 1``."""
        return F.elu(x) + 1.0

    def _rope(self, x: torch.Tensor, offset: int = 0) -> torch.Tensor:
        """Rotate ``x`` of shape ``(B, T, H, D)`` by position.

        Positions run from ``offset`` to ``offset + T - 1`` along dim 1.
        """
        T = x.shape[1]
        t = torch.arange(offset, offset + T, device=x.device, dtype=self.inv_freq.dtype)  # type: ignore[attr-defined]
        freqs = torch.outer(t, self.inv_freq)            # (T, D/2)
        emb = torch.cat([freqs, freqs], dim=-1)          # (T, D)
        cos_e = emb.cos()[None, :, None, :]              # (1, T, 1, D)
        sin_e = emb.sin()[None, :, None, :]
        half = x.shape[-1] // 2
        x1, x2 = x[..., :half], x[..., half:]
        rot = torch.cat([-x2, x1], dim=-1)
        return x * cos_e + rot * sin_e

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: dict | None = None,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, dict | None]:
        """Run causal linear attention over ``x`` of shape ``(B, T, dim)``.

        Args:
            x: Input activations.
            kv_cache: ``None`` for stateless (training) use.  Otherwise a
                dict that may hold ``"S"`` (``(B, H, D, D)``), ``"z"``
                (``(B, H, D)``) and ``"offset"`` from earlier calls; an empty
                dict starts a fresh state.
            mask: Accepted for interface compatibility; it is not read.

        Returns:
            ``(output, new_cache)``.  ``output`` has shape ``(B, T, dim)``;
            ``new_cache`` is ``None`` when ``kv_cache`` was ``None`` and
            otherwise a new dict with the state after this chunk.

        Every position only sees itself and earlier positions, in both the
        stateless and the cached mode.  In cached mode the carried state is
        added to the within-chunk prefix sums, so a multi-token chunk yields
        the same outputs as feeding the same tokens one at a time.
        """
        B, T, _ = x.shape
        H, D = self.H, self.D
        q = self.q_proj(x).view(B, T, H, D)
        k = self.k_proj(x).view(B, T, H, D)
        v = self.v_proj(x).view(B, T, H, D)
        g = torch.sigmoid(self.gate(x))                  # (B, T, H*D)

        offset = kv_cache["offset"] if kv_cache else 0
        q = self._rope(q, offset)
        k = self._rope(k, offset)
        qp, kp = self._phi(q), self._phi(k)              # strictly positive

        # Causal prefix sums within the chunk.
        kv = torch.einsum("bthd,bthe->bthde", kp, v)     # (B, T, H, D, D)
        S_cum = kv.cumsum(dim=1)
        z_cum = kp.cumsum(dim=1)                         # (B, T, H, D)

        new_cache: dict | None = None
        if kv_cache is not None:
            S_prev = kv_cache.get("S")
            z_prev = kv_cache.get("z")
            # State handed to the next call: previous state + whole chunk.
            S_next = torch.einsum("bthd,bthe->bhde", kp, v)
            z_next = kp.sum(1)
            if S_prev is not None:
                S_cum = S_cum + S_prev.unsqueeze(1)
                S_next = S_next + S_prev
            if z_prev is not None:
                z_cum = z_cum + z_prev.unsqueeze(1)
                z_next = z_next + z_prev
            new_cache = {"S": S_next, "z": z_next, "offset": offset + T}

        o = torch.einsum("bthd,bthde->bthe", qp, S_cum)
        # Clamp guards the division against a vanishing normaliser.
        d = torch.einsum("bthd,bthd->bth", qp, z_cum).clamp(min=1e-6).unsqueeze(-1)
        o = o / d

        o = o.reshape(B, T, H * D) * g
        return self.o_proj(o), new_cache
class SpectralBlock(nn.Module):
    """Pre-norm residual block: ``SpectralAttention`` then ``LowRankFFN``.

    Both sub-layers read an ``nn.RMSNorm``-normalised copy of the stream and
    add their result back onto it.

    Args:
        cfg: Configuration object; ``dim`` is read here and the object is
            forwarded to both sub-layers.
        attn_rank: Rank passed to ``SpectralAttention``.
        ffn_rank: Rank passed to ``LowRankFFN``.
    """

    def __init__(self, cfg: OmegaConfig, attn_rank: int = 32, ffn_rank: int = 32):
        super().__init__()
        width = cfg.dim
        self.norm1 = nn.RMSNorm(width)  # type: ignore[attr-defined]
        self.norm2 = nn.RMSNorm(width)  # type: ignore[attr-defined]
        self.attn = SpectralAttention(cfg, rank=attn_rank)
        self.ffn = LowRankFFN(cfg, rank=ffn_rank)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: dict | None = None,
        mask: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, dict | None]:
        """Return ``(updated stream, cache returned by the attention layer)``."""
        # Attention sub-layer with residual connection.
        attn_out, updated_cache = self.attn(self.norm1(x), kv_cache, mask)
        stream = x + attn_out
        # Feed-forward sub-layer with residual connection.
        ffn_out = self.ffn(self.norm2(stream))
        return stream + ffn_out, updated_cache
| # ───────────────────────────────────────────────────────────────────────────── | |
| # 5. SpectralMindModel — Full Model ด้วย BloomEmbedding + SpectralBlocks | |
| # ───────────────────────────────────────────────────────────────────────────── | |
class SpectralMindModel(nn.Module):
    """Language model built from ``BloomEmbedding`` and ``SpectralBlock``.

    Layout: hashed embedding -> ``cfg.n_layers`` spectral blocks -> RMSNorm
    -> low-rank ``StiefelLinear`` head of rank
    ``min(2 · attn_rank, cfg.dim // 4)``.

    Size relations:
        - BloomEmbedding: ``B · d`` instead of ``V · d``.
        - StiefelLinear attention projections: about ``4r · 2d`` instead of
          ``4d²``.
        - LowRankFFN factors: ``2r · (d + h)`` instead of ``2 · d · h``.

    Figures quoted by the original author (not verified here): a standard
    TinyMind "tiny" model has ~120M parameters, SpectralMind "nano" ~2-4M
    and "micro" ~8-12M.

    Args:
        cfg: Configuration object; ``vocab_size``, ``dim``, ``n_layers`` and
            ``eos_token_id`` are read in this class, and the object is passed
            on to every block.
        attn_rank: Rank of the attention projections.
        ffn_rank: Initial rank of the feed-forward factors.
        bloom_buckets: Rows per hash table of the embedding.
        bloom_hashes: Number of hash tables of the embedding.
    """

    def __init__(
        self,
        cfg: OmegaConfig,
        attn_rank: int = 32,
        ffn_rank: int = 32,
        bloom_buckets: int = 8192,
        bloom_hashes: int = 4,
    ):
        super().__init__()
        self.cfg = cfg

        # Hashed embedding in place of a dense vocabulary table.
        self.embed = BloomEmbedding(
            vocab_size=cfg.vocab_size,
            dim=cfg.dim,
            n_hashes=bloom_hashes,
            n_buckets=bloom_buckets,
        )

        # Stack of spectral blocks.
        self.blocks = nn.ModuleList([
            SpectralBlock(cfg, attn_rank=attn_rank, ffn_rank=ffn_rank)
            for _ in range(cfg.n_layers)
        ])

        # Output normalisation and low-rank vocabulary projection.
        self.norm_out = nn.RMSNorm(cfg.dim)  # type: ignore[attr-defined]
        lm_rank = min(attn_rank * 2, cfg.dim // 4)
        self.lm_head = StiefelLinear(cfg.dim, cfg.vocab_size, rank=lm_rank)

        self._init_sigma()

    def _init_sigma(self) -> None:
        """Re-initialise singular values and FFN diagonals.

        Every ``StiefelLinear.sigma`` is drawn from ``N(0, 1/rank)`` and every
        ``LowRankFFN`` diagonal from ``N(0, 0.02²)``.
        """
        for m in self.modules():
            if isinstance(m, StiefelLinear):
                nn.init.normal_(m.sigma, std=1.0 / math.sqrt(m.rank))
            if isinstance(m, LowRankFFN):
                nn.init.normal_(m.d_up, std=0.02)
                nn.init.normal_(m.d_down, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        kv_caches: list[dict] | None = None,
    ) -> dict[str, torch.Tensor]:
        """Compute logits and, optionally, the loss and updated caches.

        Args:
            input_ids: Token ids; the blocks unpack their input as
                ``(B, T, dim)``, so this is expected to be ``(B, T)``.
            attention_mask: Forwarded to every block as ``mask``.
            labels: If given, a next-token cross-entropy loss is computed
                with logits shifted left and labels shifted right; label
                ``-100`` is ignored.
            kv_caches: One cache dict per block, or ``None`` / an empty list
                for stateless use.

        Returns:
            A dict with ``"logits"``, plus ``"loss"`` when ``labels`` is
            given, plus ``"kv_caches"`` (a list of dicts) when at least one
            block returned a cache.
        """
        x = self.embed(input_ids)
        new_caches: list[dict] = []

        for i, block in enumerate(self.blocks):
            cache_in = kv_caches[i] if kv_caches else None
            x, cache_out = block(x, kv_cache=cache_in, mask=attention_mask)
            if cache_out is not None:
                new_caches.append(cache_out)

        x = self.norm_out(x)
        logits = self.lm_head(x)

        result: dict[str, torch.Tensor] = {"logits": logits}
        if labels is not None:
            loss = F.cross_entropy(
                logits[..., :-1, :].contiguous().view(-1, self.cfg.vocab_size),
                labels[..., 1:].contiguous().view(-1),
                ignore_index=-100,
            )
            result["loss"] = loss
        if new_caches:
            result["kv_caches"] = new_caches  # type: ignore[assignment]
        return result

    def generate(
        self,
        input_ids: torch.Tensor,
        max_new_tokens: int = 256,
        temperature: float = 0.8,
        top_p: float = 0.9,
    ) -> torch.Tensor:
        """Sample a continuation with temperature and nucleus (top-p) sampling.

        The module is switched to eval mode and the whole loop runs under
        ``torch.no_grad()`` so no autograd graph accumulates across steps.
        All tokens but the last are fed once to fill the caches; afterwards
        one token is fed per step.

        Batches are supported: a row that has produced ``cfg.eos_token_id``
        is padded with that id on later steps, and generation stops once
        every row has finished (or after ``max_new_tokens`` steps).  If
        ``cfg.eos_token_id`` is ``None`` no early stop takes place.

        Args:
            input_ids: Prompt of shape ``(B, T)`` with ``T >= 1``.
            max_new_tokens: Maximum number of tokens to append.
            temperature: Logit divisor, floored at ``1e-5``.
            top_p: Cumulative-probability cut-off for nucleus sampling.

        Returns:
            ``input_ids`` with the sampled tokens appended along dim 1.

        Raises:
            ValueError: If the prompt contains no tokens.
        """
        if input_ids.shape[1] == 0:
            raise ValueError("generate() requires a non-empty prompt")

        self.eval()
        eos = self.cfg.eos_token_id
        generated = input_ids.clone()
        caches: list[dict] = [{} for _ in self.blocks]
        finished = torch.zeros(
            generated.shape[0], dtype=torch.bool, device=generated.device
        )

        with torch.no_grad():
            # Prefill: consume everything except the last prompt token.
            if generated.shape[1] > 1:
                out = self.forward(generated[:, :-1], kv_caches=caches)
                if "kv_caches" in out:
                    caches = out["kv_caches"]  # type: ignore[assignment]

            for _ in range(max_new_tokens):
                out = self.forward(generated[:, -1:], kv_caches=caches)
                if "kv_caches" in out:
                    caches = out["kv_caches"]  # type: ignore[assignment]
                logits = out["logits"][:, -1, :].float() / max(temperature, 1e-5)

                # Nucleus filter: drop tokens once the probability mass in
                # front of them already exceeds top_p (the top token stays).
                sorted_logits, sorted_idx = torch.sort(logits, descending=True)
                sorted_probs = F.softmax(sorted_logits, dim=-1)
                mass = torch.cumsum(sorted_probs, dim=-1)
                sorted_logits[mass - sorted_probs > top_p] = float("-inf")
                logits.scatter_(1, sorted_idx, sorted_logits)

                next_tok = torch.multinomial(F.softmax(logits, dim=-1), 1)
                if eos is not None:
                    # Rows that already ended keep emitting EOS as padding.
                    next_tok = torch.where(
                        finished.unsqueeze(1),
                        torch.full_like(next_tok, eos),
                        next_tok,
                    )
                    finished = finished | (next_tok.squeeze(1) == eos)
                generated = torch.cat([generated, next_tok], dim=1)
                if eos is not None and bool(finished.all()):
                    break
        return generated

    def count_params(self) -> str:
        """Return a two-line summary of parameter counts and embedding savings."""
        total = sum(p.numel() for p in self.parameters())
        train = sum(p.numel() for p in self.parameters() if p.requires_grad)
        embed_b = self.embed.param_count()
        embed_s = self.cfg.vocab_size * self.cfg.dim
        return (
            f"Total {total/1e6:.2f}M | Trainable {train/1e6:.2f}M\n"
            f"Embed: {embed_b/1e6:.2f}M (Bloom) vs {embed_s/1e6:.2f}M (standard) "
            f"= {embed_s/embed_b:.1f}x compression"
        )

    def all_stiefel_layers(self) -> list[StiefelLinear]:
        """Return every ``StiefelLinear`` sub-module, e.g. to call ``retract``."""
        return [m for m in self.modules() if isinstance(m, StiefelLinear)]
Xet Storage Details
- Size:
- 25.2 kB
- Xet hash:
- 71743867d53797f2dd810ecebfa44d4e7fcbea546c7072b9343800715c88bd65
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.