| """ |
| GLADIUS v2.0 — The Kernel |
| |
| This is the core. Everything flows through here. |
| |
| Input → Embed → Memory Read → Time Stamp → Transformer Layers → |
| Router → Specialists → Tool Check → Modulate → Decode → |
| Memory Write → Cognition Check → Output |
| |
| Every decision is argmax S(x | context). |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import time as time_module |
|
|
| from .config import KernelConfig |
| from .embeddings import SharedEmbeddings |
| from .attention import TransformerLayer, RMSNorm |
| from .memory import ThreeTemperatureMemory |
| from .temporal import TimeEngine |
| from .cognition import CognitionLoop |
| from .modulator import Modulator |
| from .tools import ToolCortex |
| from .router import NexusRouter |
| from .senses import SensoryCortex, VisionConfig, AudioConfig |
|
|
|
|
class GladiusKernel(nn.Module):
    """
    The GLADIUS Kernel.

    Not a model. Not a wrapper. A kernel.
    Memory manages persistence. Cognition schedules thinking.
    Time provides awareness. Modulator controls voice.
    Tool Cortex provides hands. Specialists run ON this kernel.

    The optional ``vision_config`` / ``audio_config`` enable the sensory
    cortex. They are kept on the instance so ``save_checkpoint`` can persist
    them and ``load_checkpoint`` can rebuild a multimodal kernel.
    """

    def __init__(self, config: KernelConfig,
                 vision_config: VisionConfig | None = None,
                 audio_config: AudioConfig | None = None):
        """
        Build every kernel subsystem.

        Args:
            config: core kernel hyperparameters.
            vision_config: when given, enables the vision path of the sensory cortex.
            audio_config: when given, enables the audio path of the sensory cortex.
        """
        super().__init__()
        self.config = config
        # Kept so checkpoints can round-trip the sensory configuration.
        self.vision_config = vision_config
        self.audio_config = audio_config

        # Core subsystems (registration order defines state_dict key order).
        self.embeddings = SharedEmbeddings(config)
        self.memory = ThreeTemperatureMemory(config)
        self.time_engine = TimeEngine(config)
        self.cognition = CognitionLoop(config)
        self.modulator = Modulator(config)
        self.tool_cortex = ToolCortex(config)
        self.router = NexusRouter(config)

        # The sensory cortex exists only when at least one modality is configured.
        self.has_senses = vision_config is not None or audio_config is not None
        if self.has_senses:
            self.senses = SensoryCortex(config, vision_config, audio_config)
        else:
            self.senses = None

        self.layers = nn.ModuleList([
            TransformerLayer(config, layer_idx=i)
            for i in range(config.num_layers)
        ])

        self.final_norm = RMSNorm(config.hidden_dim)

        # Precomputed lower-triangular causal mask, shape (1, 1, max_seq_len, max_seq_len).
        self.register_buffer(
            'causal_mask',
            torch.tril(torch.ones(config.max_seq_len, config.max_seq_len))
            .unsqueeze(0).unsqueeze(0)
        )

        self._report_params()

    def _report_params(self):
        """Print total/trainable parameter counts and a float32 size estimate."""
        total = sum(p.numel() for p in self.parameters())
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"GLADIUS Kernel initialized: {total:,} params ({trainable:,} trainable)")
        # 4 bytes per parameter: the estimate assumes float32 storage.
        print(f"  Memory: {total * 4 / 1024 / 1024:.1f} MB (float32)")

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        timestamp: float | torch.Tensor | None = None,
        images: torch.Tensor | None = None,
        audio: torch.Tensor | None = None,
    ) -> dict:
        """
        Full forward pass through the kernel.

        Args:
            input_ids: (batch, seq_len) token IDs — can be None for pure sensory input
            timestamp: Unix timestamp (number or tensor), or None to skip time conditioning
            images: (batch, C, H, W) pixel values [0, 1] — vision input
            audio: (batch, 1, n_mels, n_frames) mel spectrogram — audio input

        Returns:
            dict with:
                logits: (batch, seq_len, vocab_size) — modulated output logits
                silence: (batch, 1) — silence gate value
                pixel_output: third output of the modulator
                mode: CognitiveMode — current cognitive mode
                importance: (batch, seq_len, 1) — memory importance scores
                modality_mask: (batch, seq_len) — 0=text, 1=vision, 2=audio (None if text-only)
                cognitive_state, mode_probs: extra outputs of cognition.heartbeat

        Raises:
            ValueError: if there is no usable input. This covers the case where
                only images/audio are given but the kernel was built without senses.
        """
        # Embed text tokens, if any.
        text_embeds = None
        if input_ids is not None:
            text_embeds = self.embeddings.embed(input_ids)

        # Fuse sensory input when the kernel has senses and any is provided.
        modality_mask = None
        has_sensory_input = images is not None or audio is not None
        if self.has_senses and has_sensory_input:
            x, modality_mask = self.senses(
                text_embeds=text_embeds,
                images=images,
                audio=audio,
            )
        elif text_embeds is not None:
            # NOTE: on a kernel without senses, any images/audio are ignored here.
            x = text_embeds
        elif has_sensory_input:
            raise ValueError(
                "images/audio were provided but this kernel has no senses; "
                "construct it with vision_config and/or audio_config"
            )
        else:
            raise ValueError("Must provide input_ids, images, or audio")
        B, S = x.shape[0], x.shape[1]

        # Memory read.
        x = self.memory.read(x)

        # Time stamp: one timestamp per batch row, added to every position.
        time_embed = None
        if timestamp is not None:
            if isinstance(timestamp, (int, float)):
                # Build on x's device so GPU runs don't hit a device mismatch.
                # NOTE(review): float32 spacing near current Unix times is ~128 s;
                # confirm TimeEngine's expected dtype before widening it.
                timestamp = torch.full(
                    (B,), float(timestamp), dtype=torch.float32, device=x.device
                )
            else:
                timestamp = timestamp.to(x.device)
            time_embed = self.time_engine(timestamp)
            x = x + time_embed.unsqueeze(1)

        # Transformer layers with a causal mask.
        if S <= self.config.max_seq_len:
            mask = self.causal_mask[:, :, :S, :S]
        else:
            # Longer than the precomputed buffer: build a fresh mask of the same dtype.
            mask = torch.tril(torch.ones(
                1, 1, S, S, device=x.device, dtype=self.causal_mask.dtype
            ))
        for layer in self.layers:
            x = layer(x, mask=mask)

        x = self.final_norm(x)

        # Tool check: add the tool cortex contribution when it activates.
        tool_result = self.tool_cortex.check_activation(x)
        if tool_result is not None:
            x = x + tool_result

        # Modulate + decode through the shared output head.
        logits, silence, pixel_output = self.modulator(
            x, self.embeddings.output_head, temporal_embedding=time_embed
        )

        # Memory write, cognition heartbeat, optional consolidation, time bookkeeping.
        importance = self.memory.write(x)
        mode, cognitive_state, mode_probs = self.cognition.heartbeat(x)
        if self.cognition.should_consolidate():
            self.memory.consolidate()
        self.time_engine.record_event()

        return {
            'logits': logits,
            'silence': silence,
            'pixel_output': pixel_output,
            'mode': mode,
            'importance': importance,
            'modality_mask': modality_mask,
            'cognitive_state': cognitive_state,
            'mode_probs': mode_probs,
        }

    @staticmethod
    def _sample_next_token(logits: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
        """
        Pick one token id per row from (batch, vocab) logits.

        ``temperature <= 0`` or ``top_k == 0`` selects greedy decoding. ``top_k``
        is clamped to the vocabulary size.

        Returns:
            (batch, 1) token ids.
        """
        if temperature <= 0:
            # Dividing by zero would produce inf/NaN probabilities; use greedy instead.
            return logits.argmax(dim=-1, keepdim=True)
        if temperature != 1.0:
            logits = logits / temperature
        if top_k > 0:
            k = min(top_k, logits.size(-1))
            topk_logits, topk_indices = logits.topk(k, dim=-1)
            probs = torch.softmax(topk_logits, dim=-1)
            sampled_idx = torch.multinomial(probs, 1)
            return topk_indices.gather(-1, sampled_idx)
        return logits.argmax(dim=-1, keepdim=True)

    @torch.no_grad()
    def generate(
        self,
        input_ids: torch.Tensor,
        max_tokens: int = 100,
        temperature: float = 1.0,
        top_k: int = 50,
        timestamp: float | None = None,
    ) -> torch.Tensor:
        """
        Autoregressive generation.

        Note: this switches the kernel to eval mode and leaves it there. Every
        step runs the full ``forward``, including its memory/cognition/time
        side effects.

        Args:
            input_ids: (1, seq_len) — prompt tokens
            max_tokens: maximum tokens to generate
            temperature: sampling temperature (1.0 = neutral, <= 0 = greedy)
            top_k: top-k sampling (0 = greedy); clamped to the vocabulary size
            timestamp: Unix timestamp (defaults to the current time, fixed for the whole call)

        Returns:
            (1, seq_len + generated) — full sequence
        """
        self.eval()

        if timestamp is None:
            timestamp = time_module.time()

        for _ in range(max_tokens):
            # Only the most recent max_seq_len tokens fit in the context window.
            context = input_ids[:, -self.config.max_seq_len:]

            result = self.forward(context, timestamp=timestamp)
            logits = result['logits'][:, -1, :]

            # The silence gate can stop generation before another token is emitted.
            if result['silence'].item() > self.config.silence_threshold:
                break

            next_token = self._sample_next_token(logits, temperature, top_k)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            if next_token.item() == self.config.eos_token_id:
                break

        return input_ids

    def save_checkpoint(self, path: str):
        """
        Save kernel weights and configs to ``path``.

        The file also stores the vision/audio configs, so ``load_checkpoint``
        can rebuild the senses. Memory state is written separately through
        ``memory.checkpoint`` to ``path + '.warm'``.
        """
        torch.save({
            'model_state_dict': self.state_dict(),
            'config': self.config,
            'vision_config': self.vision_config,
            'audio_config': self.audio_config,
        }, path)
        self.memory.checkpoint(path + '.warm')

    @staticmethod
    def _config_from_checkpoint(data: dict) -> KernelConfig:
        """Rebuild a KernelConfig from a checkpoint payload, tolerating schema drift."""
        import dataclasses

        cfg_raw = data['config']
        if isinstance(cfg_raw, dict):
            # Copy so the normalisation below never mutates the loaded payload.
            config_dict = dict(cfg_raw)
        elif dataclasses.is_dataclass(cfg_raw) and not isinstance(cfg_raw, type):
            config_dict = dataclasses.asdict(cfg_raw)
        else:
            config_dict = dict(cfg_raw)

        # Cold embedding dim is always forced to hidden_dim.
        config_dict['cold_embedding_dim'] = config_dict['hidden_dim']

        # Drop keys the current KernelConfig does not define.
        valid_fields = {f.name for f in dataclasses.fields(KernelConfig)}
        filtered_config = {k: v for k, v in config_dict.items() if k in valid_fields}

        # dtype may be serialised as a string (e.g. "torch.float16") or something unusable.
        if 'dtype' in filtered_config:
            dtype_val = filtered_config['dtype']
            if isinstance(dtype_val, str):
                resolved = getattr(torch, dtype_val.removeprefix('torch.'), None)
                filtered_config['dtype'] = (
                    resolved if isinstance(resolved, torch.dtype) else torch.float32
                )
            elif not isinstance(dtype_val, torch.dtype):
                filtered_config['dtype'] = torch.float32

        config = KernelConfig(**filtered_config)

        # A missing clock_mode is treated as 'continuous'.
        clock_mode = config_dict.get('clock_mode', 'continuous')
        if hasattr(config, 'clock_mode') or clock_mode != 'continuous':
            config.clock_mode = clock_mode

        # Size the tool table from the saved weights so the state dict fits.
        sd = data['model_state_dict']
        if 'tool_cortex.tool_embeddings' in sd:
            config.max_tools = sd['tool_cortex.tool_embeddings'].shape[0]

        return config

    @classmethod
    def load_checkpoint(cls, path: str,
                        map_location: str | torch.device | None = None,
                        vision_config: VisionConfig | None = None,
                        audio_config: AudioConfig | None = None) -> 'GladiusKernel':
        """
        Load a kernel from a checkpoint written by ``save_checkpoint``.

        Args:
            path: checkpoint file path.
            map_location: forwarded to ``torch.load``.
            vision_config / audio_config: override the sensory configs stored in
                the checkpoint. Older checkpoints store none, so pass these to
                rebuild a multimodal kernel from them.

        Weights are loaded with ``strict=False``. Missing or unexpected keys are
        printed rather than silently ignored.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects (needed for the
        # config objects). Only load checkpoints from trusted sources.
        data = torch.load(path, map_location=map_location, weights_only=False)

        config = cls._config_from_checkpoint(data)

        if vision_config is None:
            vision_config = data.get('vision_config')
        if audio_config is None:
            audio_config = data.get('audio_config')
        if vision_config is None and audio_config is None:
            kernel = cls(config)
        else:
            kernel = cls(config, vision_config=vision_config, audio_config=audio_config)

        load_result = kernel.load_state_dict(data['model_state_dict'], strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(f"  Checkpoint load: {len(load_result.missing_keys)} missing, "
                  f"{len(load_result.unexpected_keys)} unexpected keys (strict=False)")

        # The .warm memory file is optional; its absence is not an error.
        try:
            kernel.memory.restore(path + '.warm')
        except FileNotFoundError:
            pass
        return kernel
|
|