bbkdevops's picture
download
raw
9.47 kB
"""
TinyMind Omega — Full Model (KV-Cache + Checkpoint-Efficient)
"""
import torch
import torch.nn as nn
import torch.utils.checkpoint as ckpt_util
from .config import OmegaConfig
from .layers import GatedLinearAttention, SelectiveSSM, KANFeedForward, RMSNorm
from .purefield import PureFieldBlock, PureFieldShared
from .pure_lattice_cnn import PureLatticeCNNConfig, PureLatticeCNNCore
from .self_assessment_core import SelfAssessmentCore, SelfAssessmentCoreConfig
class OmegaBlock(nn.Module):
def __init__(
self,
cfg: OmegaConfig,
layer_type: str,
layer_index: int = 0,
purefield_shared: PureFieldShared | None = None,
):
super().__init__()
self.layer_type = layer_type
self.is_purefield = layer_type == "P"
self.mixer: PureFieldBlock | SelectiveSSM | GatedLinearAttention
if self.is_purefield:
self.mixer = PureFieldBlock(cfg, layer_index=layer_index, shared=purefield_shared)
self.norm1 = None
self.norm2 = None
self.ffn = None
else:
self.norm1 = RMSNorm(cfg.dim)
self.norm2 = RMSNorm(cfg.dim)
if layer_type == "S":
self.mixer: SelectiveSSM | GatedLinearAttention = SelectiveSSM(cfg)
else:
self.mixer = GatedLinearAttention(cfg)
self.ffn = KANFeedForward(cfg)
self.use_grad_ckpt: bool = False
def _forward_body(
self,
x: torch.Tensor,
kv_cache: dict | None,
mask: torch.Tensor | None,
) -> tuple[torch.Tensor, dict | None]:
if self.is_purefield:
return self.mixer(x, kv_cache=kv_cache, mask=mask) # type: ignore[misc]
assert self.norm1 is not None and self.norm2 is not None and self.ffn is not None
mx, new_cache = self.mixer(self.norm1(x), kv_cache, mask)
x = x + mx
x = x + self.ffn(self.norm2(x))
return x, new_cache
def forward(
self,
x: torch.Tensor,
kv_cache: dict | None = None,
mask: torch.Tensor | None = None,
) -> tuple[torch.Tensor, dict | None]:
if self.use_grad_ckpt and self.training:
# gradient checkpointing: rematerialise activations, ประหยัด ~40% VRAM
def ckpt_fn(x_: torch.Tensor) -> torch.Tensor:
out, _ = self._forward_body(x_, None, mask)
return out
out_x: torch.Tensor = ckpt_util.checkpoint(ckpt_fn, x, use_reentrant=False) # type: ignore[assignment]
return out_x, None
return self._forward_body(x, kv_cache, mask)
class OmegaModel(nn.Module):
def __init__(self, cfg: OmegaConfig):
super().__init__()
self.cfg = cfg
self.embed = nn.Embedding(cfg.vocab_size, cfg.dim, padding_idx=cfg.pad_token_id)
self.cnn_stem = (
PureLatticeCNNCore(
PureLatticeCNNConfig(
dim=cfg.dim,
hidden_mult=cfg.cnn_hidden_mult,
kernel_sizes=cfg.cnn_kernel_sizes,
dilations=cfg.cnn_dilations,
dropout=cfg.dropout,
residual_scale=cfg.cnn_residual_scale,
)
)
if cfg.cnn_core_enabled
else None
)
self.self_assessment = (
SelfAssessmentCore(
SelfAssessmentCoreConfig(
dim=cfg.dim,
inner_dim=cfg.dim * cfg.self_assessment_inner_mult,
recursion_steps=cfg.self_assessment_steps,
residual_scale=cfg.self_assessment_residual_scale,
dropout=cfg.dropout,
)
)
if cfg.self_assessment_enabled
else None
)
self.layer_self_assessment = (
SelfAssessmentCore(
SelfAssessmentCoreConfig(
dim=cfg.dim,
inner_dim=cfg.dim * cfg.self_assessment_inner_mult,
recursion_steps=cfg.self_assessment_steps,
residual_scale=cfg.self_assessment_residual_scale,
dropout=cfg.dropout,
)
)
if cfg.self_assessment_enabled and cfg.self_assessment_frequency > 0
else None
)
pattern = (cfg.layer_pattern * (cfg.n_layers // len(cfg.layer_pattern) + 1))[:cfg.n_layers]
if cfg.architecture_mode == "purefield":
pattern = "P" * cfg.n_layers
self.purefield_shared = PureFieldShared(cfg) if "P" in pattern else None
self.blocks: nn.ModuleList[OmegaBlock] = nn.ModuleList(
[OmegaBlock(cfg, t, layer_index=i, purefield_shared=self.purefield_shared) for i, t in enumerate(pattern)]
)
self.norm_out = RMSNorm(cfg.dim)
self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False)
if cfg.tie_word_embeddings:
self.lm_head.weight = self.embed.weight
self._init_weights()
def _init_weights(self):
nn.init.normal_(self.embed.weight, std=0.02)
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.zeros_(m.bias)
def enable_grad_checkpointing(self):
"""เปิด gradient checkpointing — ประหยัด VRAM ~40% ขณะ train"""
for block in self.blocks:
assert isinstance(block, OmegaBlock)
block.use_grad_ckpt = True
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
labels: torch.Tensor | None = None,
kv_caches: list[dict] | None = None, # per-layer caches for inference
) -> dict[str, torch.Tensor]:
x = self.embed(input_ids)
if self.cnn_stem is not None:
x, _cnn_state = self.cnn_stem(x)
new_caches: list[dict] = []
layer_assessments: list[dict[str, torch.Tensor]] = []
for i, block in enumerate(self.blocks):
cache_in = kv_caches[i] if kv_caches else None
x, cache_out = block(x, kv_cache=cache_in, mask=attention_mask)
if cache_out is not None:
new_caches.append(cache_out)
if self.layer_self_assessment is not None and (i + 1) % max(1, self.cfg.self_assessment_frequency) == 0:
x, layer_report = self.layer_self_assessment(x)
layer_assessments.append(layer_report)
assessment_report = None
if self.self_assessment is not None:
x, assessment_report = self.self_assessment(x)
x = self.norm_out(x)
logits = self.lm_head(x)
result: dict[str, torch.Tensor] = {"logits": logits}
if assessment_report is not None:
result["self_assessment"] = assessment_report # type: ignore[assignment]
if layer_assessments:
result["layer_self_assessments"] = layer_assessments # type: ignore[assignment]
if labels is not None:
loss = nn.functional.cross_entropy(
logits[..., :-1, :].contiguous().view(-1, self.cfg.vocab_size),
labels[..., 1:].contiguous().view(-1),
ignore_index=-100,
)
result["loss"] = loss
if new_caches:
result["kv_caches"] = new_caches # type: ignore[assignment]
return result
@torch.no_grad()
def generate(
self,
input_ids: torch.Tensor,
max_new_tokens: int = 512,
temperature: float = 0.8,
top_p: float = 0.9,
repetition_penalty: float = 1.1,
) -> torch.Tensor:
self.eval()
generated = input_ids.clone()
caches: list[dict] = [{} for _ in self.blocks]
# Prefill all but the last token so decode consumes each token once.
if generated.shape[1] > 1:
out = self.forward(generated[:, :-1], kv_caches=caches)
if "kv_caches" in out:
caches = out["kv_caches"] # type: ignore[assignment]
for _ in range(max_new_tokens):
# Decode: one token at a time with KV cache
last_tok = generated[:, -1:]
out = self.forward(last_tok, kv_caches=caches)
if "kv_caches" in out:
caches = out["kv_caches"] # type: ignore[assignment]
logits = out["logits"][:, -1, :].float() / max(temperature, 1e-5)
# Repetition penalty
for tid in generated[0].tolist():
logits[0, tid] /= repetition_penalty
# Top-p nucleus sampling
sv, si = torch.sort(logits, descending=True)
cp = torch.cumsum(torch.softmax(sv, dim=-1), dim=-1)
sv[cp - torch.softmax(sv, dim=-1) > top_p] = float("-inf")
logits.scatter_(1, si, sv)
next_tok = torch.multinomial(torch.softmax(logits, dim=-1), 1)
generated = torch.cat([generated, next_tok], dim=1)
if next_tok.item() == self.cfg.eos_token_id:
break
return generated
def count_params(self) -> str:
n = sum(p.numel() for p in self.parameters())
t = sum(p.numel() for p in self.parameters() if p.requires_grad)
return f"Total {n/1e6:.1f}M | Trainable {t/1e6:.1f}M"

Xet Storage Details

Size:
9.47 kB
·
Xet hash:
10e139815b3e1bf8d1093047a0bec643c60a6976692430782e2c9c94462a10e9

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.