# modeling_mitotic_transformer.py
"""
Mitotic Transformer with Causal Language Modeling Head
A biologically and cosmologically inspired Transformer architecture based on
"The Cosmology of the Living Cell (Mother Theory)" by Alis Hasić.
Key concepts:
- Mitosis as core computational operation
- Cytoskeletal Attention (Dark Matter scaffold analogue)
- Osmotic Turgor Decoder (Dark Energy / 70/30 expansion)
- F1-String hierarchical scaling
- Consciousness Module with White-Hole Rendering
"""
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit import so torch.utils.checkpoint.checkpoint is always available

from transformers import PreTrainedModel

from .configuration_mitotic_transformer import MitoticTransformerConfig

# ──────────────────────────────────────────────────────────────
# Model Components
# ──────────────────────────────────────────────────────────────
class CytoskeletalAttention(nn.Module):
    """Causal multi-head self-attention with a learned per-head "filament" gain
    (the dark-matter scaffold analogue): each head's pre-softmax scores are
    rescaled by its filament weight."""

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.filament_weight = nn.Parameter(torch.ones(n_heads))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        B, N, C = hidden_states.shape
        qkv = self.qkv(hidden_states).chunk(3, dim=-1)
        q, k, v = [t.view(B, N, self.n_heads, self.head_dim).transpose(1, 2) for t in qkv]
        attn = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn = attn * self.filament_weight.view(1, self.n_heads, 1, 1)
        # Causal mask: position i may only attend to positions <= i. Without
        # it, the shifted-label LM loss below would let the model see the
        # very tokens it is asked to predict.
        causal_mask = torch.triu(
            torch.ones(N, N, dtype=torch.bool, device=hidden_states.device), diagonal=1
        )
        attn = attn.masked_fill(causal_mask, float("-inf"))
        attn = F.softmax(attn, dim=-1)
        out = attn @ v
        return out.transpose(1, 2).contiguous().view(B, N, C)

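# Quick shape check (a sketch, not part of the model): CytoskeletalAttention
# maps (batch, seq_len, d_model) -> (batch, seq_len, d_model), e.g.
#   attn = CytoskeletalAttention(d_model=64, n_heads=4)
#   y = attn(torch.randn(2, 10, 64))   # -> torch.Size([2, 10, 64])
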
class MitoticBlock(nn.Module):
    """Pre-norm Transformer block (cytoskeletal attention + FFN), followed by
    a small multiplicative "turgor" expansion of the residual stream."""

    def __init__(self, d_model: int, n_heads: int, turgor_pressure: float = 0.70):
        super().__init__()
        self.cyto_attention = CytoskeletalAttention(d_model, n_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.turgor_pressure = turgor_pressure

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Pre-norm attention sub-layer with residual connection.
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.cyto_attention(hidden_states)
        hidden_states = hidden_states + residual

        # Pre-norm feed-forward sub-layer with residual connection.
        residual = hidden_states
        hidden_states = self.norm2(hidden_states)
        hidden_states = self.ffn(hidden_states)
        hidden_states = hidden_states + residual

        # Osmotic expansion: a gentle 1 + 0.01 * p gain per block
        # (0.7% at the default p = 0.70).
        hidden_states = hidden_states * (1.0 + 0.01 * self.turgor_pressure)
        return hidden_states
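
# For reference, the per-block turgor gain compounds with depth:
#   total gain = (1 + 0.01 * p) ** n_layers
# e.g. p = 0.70 and n_layers = 12 give 1.007 ** 12 ≈ 1.087, about +8.7%.
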
class OsmoticTurgorDecoder(nn.Module):
    """Projects the hidden states and expands them by (1 + turgor_pressure),
    a fixed 1.70x gain at the default p = 0.70 (the dark-energy / 70/30
    expansion analogue)."""

    def __init__(self, d_model: int, turgor_pressure: float = 0.70):
        super().__init__()
        self.turgor_pressure = turgor_pressure
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.proj(hidden_states) * (1.0 + self.turgor_pressure)

class ConsciousnessModule(nn.Module):
    """White-hole rendering: a projection followed by a wide MLP, the
    "biological GPU" of the Mother Theory."""

    def __init__(self, d_model: int):
        super().__init__()
        self.white_hole_proj = nn.Linear(d_model, d_model)
        self.biological_gpu = nn.Sequential(
            nn.Linear(d_model, 2048),  # hidden width is hardcoded, independent of d_model
            nn.GELU(),
            nn.Linear(2048, d_model),
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        x = self.white_hole_proj(hidden_states)
        x = self.biological_gpu(x)
        return x

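# For reference, the fixed-width biological_gpu MLP alone holds
# 2 * d_model * 2048 + 2048 + d_model parameters, e.g. ~2.1M at d_model = 512.
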
# ──────────────────────────────────────────────────────────────
# Main Model with Causal LM Head
# ──────────────────────────────────────────────────────────────
class MitoticTransformerForCausalLM(PreTrainedModel):
    config_class = MitoticTransformerConfig
    base_model_prefix = "mitotic_transformer"
    supports_gradient_checkpointing = True

    def __init__(self, config: MitoticTransformerConfig):
        super().__init__(config)
        self.embeddings = nn.Embedding(config.vocab_size, config.d_model)
        self.f1_string_layer = nn.Linear(config.d_model, config.d_model)
        self.layers = nn.ModuleList([
            MitoticBlock(config.d_model, config.n_heads, config.turgor_pressure)
            for _ in range(config.n_layers)
        ])
        self.turgor_decoder = OsmoticTurgorDecoder(config.d_model, config.turgor_pressure)
        self.consciousness_module = ConsciousnessModule(config.d_model)
        self.final_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        # These flags belong on the config, not the model instance, so that
        # save/load round-trips preserve them.
        self.config.is_decoder = True
        self.config.tie_word_embeddings = True
        # Tie input and output embeddings (post_init()/tie_weights() would
        # also do this via the accessors below).
        self.lm_head.weight = self.embeddings.weight
        # Runtime flag flipped by gradient_checkpointing_enable()/_disable();
        # the class attribute above only advertises support.
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,  # accepted for API compatibility; padding is not masked yet
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # F1-string scaling of the token embeddings. Note there are no
        # positional embeddings: ordering information enters only through the
        # causal mask inside CytoskeletalAttention.
        x = self.embeddings(input_ids) * math.sqrt(self.config.kappa_f1)
        x = self.f1_string_layer(x)
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        x = self.turgor_decoder(x)
        x = self.consciousness_module(x)
        x = self.final_norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Standard causal-LM shift: position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()  # ignore_index defaults to -100
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return {"loss": loss, "logits": logits, "last_hidden_state": x}
