"""
GLADIUS v2.0 — Sensory Cortex

Vision and audio perception modules.

Design principle:
    Every sensory input projects into the SAME hidden_dim manifold as text
    tokens. The transformer backbone doesn't know — and shouldn't know —
    whether it's processing text, image patches, or audio frames. All are
    tokens. All live in the same space.

Biological analogy:
    The thalamus doesn't differentiate modalities — it routes everything into
    cortical columns. The cortex learns what to do with each modality through
    the patterns in the data, not through hardcoded pathways.

The sensory cortex adds exactly two things:
    1. Encoders that project raw sensory data into hidden_dim
    2. Modality embeddings that let the transformer KNOW what it's looking at

That's it. No separate attention. No separate heads. The existing SLA²
attention, MoE router, memory, and cognition systems handle everything else.
They were always designed to — they just never had sensory data to work with.

Architecture:
    Image → VisionEncoder → [CLS] + patch_embeds + modality_embed
    Audio → AudioEncoder  → [CLS] + patch_embeds + modality_embed
    Text  → (embedded upstream) → token_embeds + modality_embed

All three produce (B, S, hidden_dim) tensors that SensoryCortex concatenates
into a unified sequence for the transformer backbone.
"""
import math
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import KernelConfig


# ── Configuration ──────────────────────────────────────────────

@dataclass
class VisionConfig:
    """Configuration for the vision sensory cortex."""
    image_size: int = 28        # Input image side length (images are square)
    patch_size: int = 4         # Patch side length (image_size must be divisible by it)
    in_channels: int = 1        # Grayscale=1, RGB=3
    use_cls_token: bool = True  # Prepend a learned [CLS] vision token
    dropout: float = 0.1

    @property
    def num_patches(self) -> int:
        """Number of non-overlapping square patches per image."""
        return (self.image_size // self.patch_size) ** 2

    @property
    def patch_dim(self) -> int:
        """Length of one flattened patch: patch_size * patch_size * in_channels."""
        return self.patch_size * self.patch_size * self.in_channels


@dataclass
class AudioConfig:
    """Configuration for the audio sensory cortex."""
    n_mels: int = 64            # Mel frequency bins
    n_frames: int = 128         # Time frames (1-2 sec at 16kHz)
    patch_size_freq: int = 8    # Frequency patch size (n_mels must be divisible by it)
    patch_size_time: int = 8    # Time patch size (n_frames must be divisible by it)
    use_cls_token: bool = True  # Prepend a learned [CLS] audio token
    dropout: float = 0.1

    @property
    def num_patches(self) -> int:
        """Number of non-overlapping freq×time patches per spectrogram."""
        return (self.n_mels // self.patch_size_freq) * (self.n_frames // self.patch_size_time)

    @property
    def patch_dim(self) -> int:
        """Length of one flattened single-channel patch: patch_size_freq * patch_size_time."""
        return self.patch_size_freq * self.patch_size_time


# ── Modality Embeddings ───────────────────────────────────────

class ModalityEmbedding(nn.Module):
    """
    Learned modality identifier.

    Each modality (text, vision, audio) gets a unique learned embedding that is
    added to every token of that modality. This is how the transformer knows
    WHAT it's processing — not through architecture, but through a learned
    signal in the data.

    Analogy: synaesthesia. The modality embedding IS the "color" of a sound or
    the "texture" of a word. It's information, not structure.
    """

    def __init__(self, num_modalities: int, hidden_dim: int):
        super().__init__()
        # 0=text, 1=vision, 2=audio (extensible)
        self.embed = nn.Embedding(num_modalities, hidden_dim)
        nn.init.normal_(self.embed.weight, mean=0.0, std=0.02)

    def forward(self, modality_id: int, seq_len: int, batch_size: int = 1) -> torch.Tensor:
        """
        Build the modality embedding for a block of tokens.

        Args:
            modality_id: row of the embedding table to use (0=text, 1=vision, 2=audio).
            seq_len: number of tokens in the block.
            batch_size: batch dimension of the block.

        Returns:
            (batch_size, seq_len, hidden_dim) embedding to ADD to hidden states.
            Out-of-range ids raise inside nn.Embedding.
        """
        ids = torch.full(
            (batch_size, seq_len), modality_id,
            dtype=torch.long, device=self.embed.weight.device,
        )
        return self.embed(ids)


# ── Vision Cortex ─────────────────────────────────────────────

class VisionEncoder(nn.Module):
    """
    Image → Patches → hidden_dim projections.

    No convolutions. No pretrained backbone. Pure patch embedding with
    positional encoding — the simplest possible thing that projects visual
    data into the token manifold. ViT showed this works at scale.

    At 28×28 (MNIST), we get 49 patches of 4×4 pixels = 16 dimensions each,
    projected to hidden_dim. For larger images, increase image_size and
    patch_size proportionally. At 224×224 RGB with patch_size=16: 196 patches,
    768-dim patches.

    Raises:
        ValueError: at construction if image_size is not divisible by a
            positive patch_size (such a config could never be patchified).
    """

    def __init__(self, config: KernelConfig, vision_config: VisionConfig):
        super().__init__()
        if vision_config.patch_size <= 0 or vision_config.image_size % vision_config.patch_size != 0:
            raise ValueError(
                f"image_size ({vision_config.image_size}) must be divisible by "
                f"a positive patch_size ({vision_config.patch_size})"
            )
        self.config = config
        self.v_config = vision_config

        # Patch embedding: flatten patch → linear → hidden_dim
        self.patch_embed = nn.Linear(vision_config.patch_dim, config.hidden_dim, bias=False)

        # Learnable position embeddings for [CLS] (if used) + patches
        num_pos = vision_config.num_patches + (1 if vision_config.use_cls_token else 0)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_pos, config.hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Optional [CLS] token — summarizes the full image
        if vision_config.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Layer norm before handing tokens to the backbone
        self.norm = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(vision_config.dropout)

        self._init_weights()

    def _init_weights(self):
        # Xavier init for patch embedding — critical for gradient flow
        nn.init.xavier_uniform_(self.patch_embed.weight)

    def patchify(self, images: torch.Tensor) -> torch.Tensor:
        """
        Convert images to patch sequences.

        Args:
            images: (B, C, H, W) — pixel values [0, 1]

        Returns:
            (B, num_patches, patch_dim) — flattened patches in row-major patch order

        Raises:
            ValueError: if images is not 4-D, is not image_size × image_size,
                or does not have in_channels channels. (A wrong channel count
                can otherwise reshape "successfully" into scrambled patches.)
        """
        if images.dim() != 4:
            raise ValueError(f"Expected images of shape (B, C, H, W), got {tuple(images.shape)}")
        B, C, H, W = images.shape
        p = self.v_config.patch_size
        size = self.v_config.image_size
        if not (H == W == size):
            raise ValueError(f"Expected {size}×{size}, got {H}×{W}")
        if C != self.v_config.in_channels:
            raise ValueError(f"Expected {self.v_config.in_channels} channel(s), got {C}")

        # (B, C, H, W) → (B, C, H//p, p, W//p, p) → (B, H//p, W//p, C, p, p) → (B, N, patch_dim)
        patches = images.reshape(B, C, H // p, p, W // p, p)
        patches = patches.permute(0, 2, 4, 1, 3, 5)
        patches = patches.reshape(B, -1, self.v_config.patch_dim)
        return patches

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images into the token manifold.

        Args:
            images: (B, C, H, W) pixel values normalized to [0, 1]

        Returns:
            (B, num_patches [+ 1], hidden_dim) — vision tokens
        """
        B = images.shape[0]

        # 1. Patchify
        patches = self.patchify(images)  # (B, N, patch_dim)

        # 2. Linear projection into hidden_dim
        x = self.patch_embed(patches)  # (B, N, D)

        # 3. Prepend [CLS] if configured
        if self.v_config.use_cls_token:
            cls = self.cls_token.expand(B, -1, -1)  # (B, 1, D)
            x = torch.cat([cls, x], dim=1)          # (B, N+1, D)

        # 4. Add positional embeddings
        x = x + self.pos_embed[:, :x.shape[1], :]

        # 5. Normalize and dropout
        x = self.norm(x)
        x = self.dropout(x)
        return x


# ── Audio Cortex ──────────────────────────────────────────────

class AudioEncoder(nn.Module):
    """
    Mel spectrogram → Patches → hidden_dim projections.

    Audio is inherently 2D when represented as a spectrogram: frequency × time.
    We treat it exactly like an image — patch it, project it, add position
    embeddings. The Time2Vec engine in GLADIUS already provides temporal
    awareness, so audio patches inherit temporal context from the backbone
    for free.

    Input: Pre-computed mel spectrogram (B, 1, n_mels, n_frames)
    The mel transform happens OUTSIDE the model (preprocessing). This keeps
    the kernel clean — raw audio processing is a sensor, not cognition.

    Raises:
        ValueError: at construction if n_mels / n_frames are not divisible by
            their (positive) patch sizes.
    """

    def __init__(self, config: KernelConfig, audio_config: AudioConfig):
        super().__init__()
        for axis, size, patch in (
            ("n_mels", audio_config.n_mels, audio_config.patch_size_freq),
            ("n_frames", audio_config.n_frames, audio_config.patch_size_time),
        ):
            if patch <= 0 or size % patch != 0:
                raise ValueError(f"{axis} ({size}) must be divisible by a positive patch size ({patch})")
        self.config = config
        self.a_config = audio_config

        # Patch embedding: flatten freq×time patch → hidden_dim
        self.patch_embed = nn.Linear(audio_config.patch_dim, config.hidden_dim, bias=False)

        # Absolute positional embedding over [CLS] (if used) + all patches
        n_freq_patches = audio_config.n_mels // audio_config.patch_size_freq
        n_time_patches = audio_config.n_frames // audio_config.patch_size_time
        num_pos = n_freq_patches * n_time_patches + (1 if audio_config.use_cls_token else 0)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_pos, config.hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Frequency and time axis embeddings (2D positional decomposition),
        # applied to patches in addition to pos_embed
        self.freq_embed = nn.Parameter(torch.zeros(1, n_freq_patches, config.hidden_dim))
        self.time_embed = nn.Parameter(torch.zeros(1, n_time_patches, config.hidden_dim))
        nn.init.trunc_normal_(self.freq_embed, std=0.02)
        nn.init.trunc_normal_(self.time_embed, std=0.02)

        # [CLS] for audio
        if audio_config.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(audio_config.dropout)

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.patch_embed.weight)

    def patchify(self, mel: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """
        Convert mel spectrogram to patch sequence.

        Args:
            mel: (B, 1, n_mels, n_frames) — mel spectrogram

        Returns:
            patches: (B, num_patches, patch_dim), frequency-major patch order
            n_freq: number of frequency patches
            n_time: number of time patches

        Raises:
            ValueError: if mel is not exactly (B, 1, n_mels, n_frames). Other
                shapes with the same patch count would otherwise be accepted
                silently with mismatched positional embeddings.
        """
        cfg = self.a_config
        expected = (1, cfg.n_mels, cfg.n_frames)
        if mel.dim() != 4 or tuple(mel.shape[1:]) != expected:
            raise ValueError(
                f"Expected mel of shape (B, {expected[0]}, {expected[1]}, {expected[2]}), "
                f"got {tuple(mel.shape)}"
            )
        # Note: avoid naming a local `F`, which would shadow torch.nn.functional
        B, C, n_bins, n_steps = mel.shape
        pf = cfg.patch_size_freq
        pt = cfg.patch_size_time
        n_freq = n_bins // pf
        n_time = n_steps // pt

        # (B, 1, F, T) → (B, 1, n_freq, pf, n_time, pt) → (B, n_freq, n_time, 1, pf, pt) → (B, N, pf*pt)
        patches = mel.reshape(B, C, n_freq, pf, n_time, pt)
        patches = patches.permute(0, 2, 4, 1, 3, 5)
        patches = patches.reshape(B, n_freq * n_time, pf * pt)
        return patches, n_freq, n_time

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Encode mel spectrogram into the token manifold.

        Args:
            mel: (B, 1, n_mels, n_frames) mel spectrogram, normalized

        Returns:
            (B, num_patches [+ 1], hidden_dim) — audio tokens
        """
        B = mel.shape[0]

        # 1. Patchify
        patches, n_freq, n_time = self.patchify(mel)  # (B, N, patch_dim)

        # 2. Project to hidden_dim
        x = self.patch_embed(patches)  # (B, N, D)

        # 3. Add 2D positional decomposition:
        #    (1, n_freq, 1, D) + (1, 1, n_time, D) → (1, n_freq, n_time, D) → (1, N, D)
        #    Flattening is frequency-major, matching the patch order from patchify.
        pos_2d = self.freq_embed.unsqueeze(2) + self.time_embed.unsqueeze(1)
        x = x + pos_2d.reshape(1, n_freq * n_time, -1)

        # 4. Prepend [CLS] if configured
        if self.a_config.use_cls_token:
            cls = self.cls_token.expand(B, -1, -1)
            x = torch.cat([cls, x], dim=1)

        # Absolute positional embedding over [CLS] (if any) + patches
        x = x + self.pos_embed[:, :x.shape[1], :]

        # 5. Normalize and dropout
        x = self.norm(x)
        x = self.dropout(x)
        return x


# ── Unified Sensory Cortex ────────────────────────────────────

class SensoryCortex(nn.Module):
    """
    The sensory integration layer.

    Manages all modality encoders and produces a unified token sequence for
    the transformer backbone. Handles:
        1. Modality-specific encoding (vision, audio)
        2. Modality embedding injection (so the transformer knows WHAT)
        3. Sequence construction (concatenation: vision, audio, then text)

    The cortex is OPTIONAL — the kernel works exactly as before with text-only
    input. Sensory data is additive, never required. Images/audio passed to a
    cortex built without the matching encoder are ignored.

    Usage:
        cortex = SensoryCortex(config, vision_config, audio_config)

        # Vision only
        tokens, modality_mask = cortex(text_embeds=text, images=images)

        # Audio only
        tokens, modality_mask = cortex(text_embeds=text, audio=mel)

        # Full multimodal
        tokens, modality_mask = cortex(text_embeds=text, images=images, audio=mel)

        # Text only (passthrough + text modality embedding)
        tokens, modality_mask = cortex(text_embeds=text)
    """

    # Modality ids used for both the embedding table and the returned mask
    MODALITY_TEXT = 0
    MODALITY_VISION = 1
    MODALITY_AUDIO = 2

    def __init__(
        self,
        config: KernelConfig,
        vision_config: Optional[VisionConfig] = None,
        audio_config: Optional[AudioConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.has_vision = vision_config is not None
        self.has_audio = audio_config is not None

        # Count modalities: 0=text (always), 1=vision, 2=audio
        num_modalities = 3
        self.modality_embed = ModalityEmbedding(num_modalities, config.hidden_dim)

        # Sensory encoders
        if self.has_vision:
            self.vision = VisionEncoder(config, vision_config)
            self.vision_config = vision_config

        if self.has_audio:
            self.audio = AudioEncoder(config, audio_config)
            self.audio_config = audio_config

        # Intended for cross-modal (global) position awareness in the unified
        # sequence. NOTE(review): forward() does not currently use this
        # parameter, so it never receives a gradient (DistributedDataParallel
        # may need find_unused_parameters=True). It is kept so existing
        # checkpoints with a 'global_pos_scale' key still load.
        self.global_pos_scale = nn.Parameter(torch.ones(1))

    def forward(
        self,
        text_embeds: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        audio: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Combine all available modalities into a unified sequence.

        Order: [vision_tokens] [audio_tokens] [text_tokens]
        Vision/audio BEFORE text — the creature sees and hears before it speaks.
        This is the natural order.

        Args:
            text_embeds: (B, S_text, D) — already embedded text tokens
            images: (B, C, H, W) — raw pixel values [0, 1]; ignored without a vision encoder
            audio: (B, 1, n_mels, n_frames) — mel spectrogram; ignored without an audio encoder

        Returns:
            unified: (B, S_total, D) — unified token sequence
            modality_mask: (B, S_total) long — modality labels (0=text, 1=vision, 2=audio)

        Raises:
            ValueError: if no modality is given, if none of the given modalities
                has a configured encoder, or if the used modalities disagree on
                batch size.
        """
        if text_embeds is None and images is None and audio is None:
            raise ValueError("At least one modality must be provided")

        # Encode each usable modality, in sequence order (vision, audio, text).
        segments = []  # (tokens (B, S_m, D), modality_id)
        if images is not None and self.has_vision:
            segments.append((self.vision(images), self.MODALITY_VISION))
        if audio is not None and self.has_audio:
            segments.append((self.audio(audio), self.MODALITY_AUDIO))
        if text_embeds is not None:
            segments.append((text_embeds, self.MODALITY_TEXT))

        if not segments:
            raise ValueError(
                "None of the provided modalities has a configured encoder "
                f"(has_vision={self.has_vision}, has_audio={self.has_audio})"
            )

        batch = segments[0][0].shape[0]
        if any(tokens.shape[0] != batch for tokens, _ in segments):
            sizes = [tokens.shape[0] for tokens, _ in segments]
            raise ValueError(f"All modalities must share the same batch size, got {sizes}")

        pieces = []
        labels = []
        for tokens, modality_id in segments:
            seq_len = tokens.shape[1]
            pieces.append(tokens + self.modality_embed(modality_id, seq_len, batch))
            # Label on the tokens' own device so the mask always matches `unified`
            labels.append(
                torch.full((batch, seq_len), modality_id, dtype=torch.long, device=tokens.device)
            )

        unified = torch.cat(pieces, dim=1)        # (B, S_total, D)
        modality_mask = torch.cat(labels, dim=1)  # (B, S_total)
        return unified, modality_mask

    def param_count(self) -> dict:
        """
        Report parameter count by component.

        'total' counts every parameter of the module (including ones not
        attributed to a named component, such as global_pos_scale).
        """
        counts = {'modality_embed': sum(p.numel() for p in self.modality_embed.parameters())}
        if self.has_vision:
            counts['vision'] = sum(p.numel() for p in self.vision.parameters())
        if self.has_audio:
            counts['audio'] = sum(p.numel() for p in self.audio.parameters())
        counts['total'] = sum(p.numel() for p in self.parameters())
        return counts


# ── Preset Configurations ─────────────────────────────────────

def mnist_vision_config() -> VisionConfig:
    """MNIST: 28×28 grayscale → 49 patches of 4×4."""
    return VisionConfig(
        image_size=28,
        patch_size=4,
        in_channels=1,
        use_cls_token=True,
    )


def cifar_vision_config() -> VisionConfig:
    """CIFAR-10: 32×32 RGB → 64 patches of 4×4."""
    return VisionConfig(
        image_size=32,
        patch_size=4,
        in_channels=3,
        use_cls_token=True,
    )


def imagenet_vision_config() -> VisionConfig:
    """ImageNet: 224×224 RGB → 196 patches of 16×16."""
    return VisionConfig(
        image_size=224,
        patch_size=16,
        in_channels=3,
        use_cls_token=True,
    )


def speech_audio_config() -> AudioConfig:
    """Speech: ~2 sec at 16kHz, 64 mel bands → 8×16 = 128 patches of 8×8."""
    return AudioConfig(
        n_mels=64,
        n_frames=128,
        patch_size_freq=8,
        patch_size_time=8,
        use_cls_token=True,
    )


def music_audio_config() -> AudioConfig:
    """Music: ~4 sec at 22kHz, 128 mel bands → 8×16 = 128 patches of 16×16."""
    return AudioConfig(
        n_mels=128,
        n_frames=256,
        patch_size_freq=16,
        patch_size_time=16,
        use_cls_token=True,
    )