Buckets:
bbkdevops/unicosys-hypergraph-bucket / tinymind-native-colab-handoff /bundle /model /architecture.py
| """ | |
| TinyMind Omega — Full Model (KV-Cache + Checkpoint-Efficient) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.utils.checkpoint as ckpt_util | |
| from .config import OmegaConfig | |
| from .layers import GatedLinearAttention, SelectiveSSM, KANFeedForward, RMSNorm | |
| from .purefield import PureFieldBlock, PureFieldShared | |
| from .pure_lattice_cnn import PureLatticeCNNConfig, PureLatticeCNNCore | |
| from .self_assessment_core import SelfAssessmentCore, SelfAssessmentCoreConfig | |
| class OmegaBlock(nn.Module): | |
| def __init__( | |
| self, | |
| cfg: OmegaConfig, | |
| layer_type: str, | |
| layer_index: int = 0, | |
| purefield_shared: PureFieldShared | None = None, | |
| ): | |
| super().__init__() | |
| self.layer_type = layer_type | |
| self.is_purefield = layer_type == "P" | |
| self.mixer: PureFieldBlock | SelectiveSSM | GatedLinearAttention | |
| if self.is_purefield: | |
| self.mixer = PureFieldBlock(cfg, layer_index=layer_index, shared=purefield_shared) | |
| self.norm1 = None | |
| self.norm2 = None | |
| self.ffn = None | |
| else: | |
| self.norm1 = RMSNorm(cfg.dim) | |
| self.norm2 = RMSNorm(cfg.dim) | |
| if layer_type == "S": | |
| self.mixer: SelectiveSSM | GatedLinearAttention = SelectiveSSM(cfg) | |
| else: | |
| self.mixer = GatedLinearAttention(cfg) | |
| self.ffn = KANFeedForward(cfg) | |
| self.use_grad_ckpt: bool = False | |
| def _forward_body( | |
| self, | |
| x: torch.Tensor, | |
| kv_cache: dict | None, | |
| mask: torch.Tensor | None, | |
| ) -> tuple[torch.Tensor, dict | None]: | |
| if self.is_purefield: | |
| return self.mixer(x, kv_cache=kv_cache, mask=mask) # type: ignore[misc] | |
| assert self.norm1 is not None and self.norm2 is not None and self.ffn is not None | |
| mx, new_cache = self.mixer(self.norm1(x), kv_cache, mask) | |
| x = x + mx | |
| x = x + self.ffn(self.norm2(x)) | |
| return x, new_cache | |
| def forward( | |
| self, | |
| x: torch.Tensor, | |
| kv_cache: dict | None = None, | |
| mask: torch.Tensor | None = None, | |
| ) -> tuple[torch.Tensor, dict | None]: | |
| if self.use_grad_ckpt and self.training: | |
| # gradient checkpointing: rematerialise activations, ประหยัด ~40% VRAM | |
| def ckpt_fn(x_: torch.Tensor) -> torch.Tensor: | |
| out, _ = self._forward_body(x_, None, mask) | |
| return out | |
| out_x: torch.Tensor = ckpt_util.checkpoint(ckpt_fn, x, use_reentrant=False) # type: ignore[assignment] | |
| return out_x, None | |
| return self._forward_body(x, kv_cache, mask) | |
| class OmegaModel(nn.Module): | |
| def __init__(self, cfg: OmegaConfig): | |
| super().__init__() | |
| self.cfg = cfg | |
| self.embed = nn.Embedding(cfg.vocab_size, cfg.dim, padding_idx=cfg.pad_token_id) | |
| self.cnn_stem = ( | |
| PureLatticeCNNCore( | |
| PureLatticeCNNConfig( | |
| dim=cfg.dim, | |
| hidden_mult=cfg.cnn_hidden_mult, | |
| kernel_sizes=cfg.cnn_kernel_sizes, | |
| dilations=cfg.cnn_dilations, | |
| dropout=cfg.dropout, | |
| residual_scale=cfg.cnn_residual_scale, | |
| ) | |
| ) | |
| if cfg.cnn_core_enabled | |
| else None | |
| ) | |
| self.self_assessment = ( | |
| SelfAssessmentCore( | |
| SelfAssessmentCoreConfig( | |
| dim=cfg.dim, | |
| inner_dim=cfg.dim * cfg.self_assessment_inner_mult, | |
| recursion_steps=cfg.self_assessment_steps, | |
| residual_scale=cfg.self_assessment_residual_scale, | |
| dropout=cfg.dropout, | |
| ) | |
| ) | |
| if cfg.self_assessment_enabled | |
| else None | |
| ) | |
| self.layer_self_assessment = ( | |
| SelfAssessmentCore( | |
| SelfAssessmentCoreConfig( | |
| dim=cfg.dim, | |
| inner_dim=cfg.dim * cfg.self_assessment_inner_mult, | |
| recursion_steps=cfg.self_assessment_steps, | |
| residual_scale=cfg.self_assessment_residual_scale, | |
| dropout=cfg.dropout, | |
| ) | |
| ) | |
| if cfg.self_assessment_enabled and cfg.self_assessment_frequency > 0 | |
| else None | |
| ) | |
| pattern = (cfg.layer_pattern * (cfg.n_layers // len(cfg.layer_pattern) + 1))[:cfg.n_layers] | |
| if cfg.architecture_mode == "purefield": | |
| pattern = "P" * cfg.n_layers | |
| self.purefield_shared = PureFieldShared(cfg) if "P" in pattern else None | |
| self.blocks: nn.ModuleList[OmegaBlock] = nn.ModuleList( | |
| [OmegaBlock(cfg, t, layer_index=i, purefield_shared=self.purefield_shared) for i, t in enumerate(pattern)] | |
| ) | |
| self.norm_out = RMSNorm(cfg.dim) | |
| self.lm_head = nn.Linear(cfg.dim, cfg.vocab_size, bias=False) | |
| if cfg.tie_word_embeddings: | |
| self.lm_head.weight = self.embed.weight | |
| self._init_weights() | |
| def _init_weights(self): | |
| nn.init.normal_(self.embed.weight, std=0.02) | |
| for m in self.modules(): | |
| if isinstance(m, nn.Linear): | |
| nn.init.normal_(m.weight, std=0.02) | |
| if m.bias is not None: | |
| nn.init.zeros_(m.bias) | |
| def enable_grad_checkpointing(self): | |
| """เปิด gradient checkpointing — ประหยัด VRAM ~40% ขณะ train""" | |
| for block in self.blocks: | |
| assert isinstance(block, OmegaBlock) | |
| block.use_grad_ckpt = True | |
| def forward( | |
| self, | |
| input_ids: torch.Tensor, | |
| attention_mask: torch.Tensor | None = None, | |
| labels: torch.Tensor | None = None, | |
| kv_caches: list[dict] | None = None, # per-layer caches for inference | |
| ) -> dict[str, torch.Tensor]: | |
| x = self.embed(input_ids) | |
| if self.cnn_stem is not None: | |
| x, _cnn_state = self.cnn_stem(x) | |
| new_caches: list[dict] = [] | |
| layer_assessments: list[dict[str, torch.Tensor]] = [] | |
| for i, block in enumerate(self.blocks): | |
| cache_in = kv_caches[i] if kv_caches else None | |
| x, cache_out = block(x, kv_cache=cache_in, mask=attention_mask) | |
| if cache_out is not None: | |
| new_caches.append(cache_out) | |
| if self.layer_self_assessment is not None and (i + 1) % max(1, self.cfg.self_assessment_frequency) == 0: | |
| x, layer_report = self.layer_self_assessment(x) | |
| layer_assessments.append(layer_report) | |
| assessment_report = None | |
| if self.self_assessment is not None: | |
| x, assessment_report = self.self_assessment(x) | |
| x = self.norm_out(x) | |
| logits = self.lm_head(x) | |
| result: dict[str, torch.Tensor] = {"logits": logits} | |
| if assessment_report is not None: | |
| result["self_assessment"] = assessment_report # type: ignore[assignment] | |
| if layer_assessments: | |
| result["layer_self_assessments"] = layer_assessments # type: ignore[assignment] | |
| if labels is not None: | |
| loss = nn.functional.cross_entropy( | |
| logits[..., :-1, :].contiguous().view(-1, self.cfg.vocab_size), | |
| labels[..., 1:].contiguous().view(-1), | |
| ignore_index=-100, | |
| ) | |
| result["loss"] = loss | |
| if new_caches: | |
| result["kv_caches"] = new_caches # type: ignore[assignment] | |
| return result | |
| def generate( | |
| self, | |
| input_ids: torch.Tensor, | |
| max_new_tokens: int = 512, | |
| temperature: float = 0.8, | |
| top_p: float = 0.9, | |
| repetition_penalty: float = 1.1, | |
| ) -> torch.Tensor: | |
| self.eval() | |
| generated = input_ids.clone() | |
| caches: list[dict] = [{} for _ in self.blocks] | |
| # Prefill all but the last token so decode consumes each token once. | |
| if generated.shape[1] > 1: | |
| out = self.forward(generated[:, :-1], kv_caches=caches) | |
| if "kv_caches" in out: | |
| caches = out["kv_caches"] # type: ignore[assignment] | |
| for _ in range(max_new_tokens): | |
| # Decode: one token at a time with KV cache | |
| last_tok = generated[:, -1:] | |
| out = self.forward(last_tok, kv_caches=caches) | |
| if "kv_caches" in out: | |
| caches = out["kv_caches"] # type: ignore[assignment] | |
| logits = out["logits"][:, -1, :].float() / max(temperature, 1e-5) | |
| # Repetition penalty | |
| for tid in generated[0].tolist(): | |
| logits[0, tid] /= repetition_penalty | |
| # Top-p nucleus sampling | |
| sv, si = torch.sort(logits, descending=True) | |
| cp = torch.cumsum(torch.softmax(sv, dim=-1), dim=-1) | |
| sv[cp - torch.softmax(sv, dim=-1) > top_p] = float("-inf") | |
| logits.scatter_(1, si, sv) | |
| next_tok = torch.multinomial(torch.softmax(logits, dim=-1), 1) | |
| generated = torch.cat([generated, next_tok], dim=1) | |
| if next_tok.item() == self.cfg.eos_token_id: | |
| break | |
| return generated | |
| def count_params(self) -> str: | |
| n = sum(p.numel() for p in self.parameters()) | |
| t = sum(p.numel() for p in self.parameters() if p.requires_grad) | |
| return f"Total {n/1e6:.1f}M | Trainable {t/1e6:.1f}M" | |
Xet Storage Details
- Size:
- 9.47 kB
- Xet hash:
- 10e139815b3e1bf8d1093047a0bec643c60a6976692430782e2c9c94462a10e9
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.