| """ |
| GLADIUS v2.0 — Sensory Cortex |
| |
| Vision and Audio perception modules. |
| |
| Design principle: Every sensory input projects into the SAME hidden_dim |
| manifold as text tokens. The transformer backbone doesn't know — and |
| shouldn't know — whether it's processing text, image patches, or audio |
| frames. All are tokens. All live in the same space. |
| |
| Biological analogy: The thalamus doesn't differentiate modalities — |
| it routes everything into cortical columns. The cortex learns what |
| to do with each modality through the patterns in the data, not |
| through hardcoded pathways. |
| |
| The sensory cortex adds exactly two things: |
| 1. Encoders that project raw sensory data into hidden_dim |
| 2. Modality embeddings that let the transformer KNOW what it's looking at |
| |
| That's it. No separate attention. No separate heads. The existing |
| SLA² attention, MoE router, memory, and cognition systems handle |
| everything else. They were always designed to — they just never had |
| sensory data to work with. |
| |
| Architecture: |
| Image → PatchEncoder → [CLS] + patch_embeds + modality_embed |
| Audio → MelEncoder → [CLS] + frame_embeds + modality_embed |
| Text → TokenEmbed → token_embeds + modality_embed |
| |
| All three produce (B, S, hidden_dim) tensors that concatenate into |
| a unified sequence for the transformer backbone. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| from dataclasses import dataclass |
| from typing import Optional, Tuple |
|
|
| from .config import KernelConfig |
|
|
|
|
| |
|
|
@dataclass
class VisionConfig:
    """Configuration for the vision sensory cortex.

    Images are square (``image_size`` x ``image_size``) and are cut into
    non-overlapping square patches of side ``patch_size``.

    Attributes:
        image_size: Side length of the input image, in pixels.
        patch_size: Side length of one patch, in pixels.
        in_channels: Number of image channels.
        use_cls_token: Whether a learned [CLS] token is prepended.
        dropout: Dropout probability applied to the encoder output.

    Raises:
        ValueError: If any size is not positive, or if ``image_size`` is not
            a multiple of ``patch_size``.
    """

    image_size: int = 28
    patch_size: int = 4
    in_channels: int = 1
    use_cls_token: bool = True
    dropout: float = 0.1

    def __post_init__(self) -> None:
        # Fail at construction time: a non-tiling configuration would
        # otherwise make ``num_patches`` silently floor and only blow up
        # later, inside a tensor reshape, with an unrelated-looking message.
        for name in ("image_size", "patch_size", "in_channels"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive, got {getattr(self, name)}")
        if self.image_size % self.patch_size != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be a multiple of "
                f"patch_size ({self.patch_size})"
            )

    @property
    def num_patches(self) -> int:
        """Number of patches per image (patches per side, squared)."""
        return (self.image_size // self.patch_size) ** 2

    @property
    def patch_dim(self) -> int:
        """Length of one flattened patch: patch area times channel count."""
        return self.patch_size * self.patch_size * self.in_channels
|
|
|
|
@dataclass
class AudioConfig:
    """Configuration for the audio sensory cortex.

    A mel spectrogram of ``n_mels`` frequency bins by ``n_frames`` time
    frames is cut into non-overlapping patches of
    ``patch_size_freq`` x ``patch_size_time``.

    Attributes:
        n_mels: Number of mel frequency bins.
        n_frames: Number of time frames.
        patch_size_freq: Patch extent along the frequency axis.
        patch_size_time: Patch extent along the time axis.
        use_cls_token: Whether a learned [CLS] token is prepended.
        dropout: Dropout probability applied to the encoder output.

    Raises:
        ValueError: If any size is not positive, or if an axis length is not
            a multiple of its patch size.
    """

    n_mels: int = 64
    n_frames: int = 128
    patch_size_freq: int = 8
    patch_size_time: int = 8
    use_cls_token: bool = True
    dropout: float = 0.1

    def __post_init__(self) -> None:
        # Reject grids that do not tile the spectrogram exactly; otherwise
        # ``num_patches`` floors silently and the mismatch only surfaces as
        # a reshape error deep inside the encoder.
        for name in ("n_mels", "n_frames", "patch_size_freq", "patch_size_time"):
            if getattr(self, name) <= 0:
                raise ValueError(f"{name} must be positive, got {getattr(self, name)}")
        if self.n_mels % self.patch_size_freq != 0:
            raise ValueError(
                f"n_mels ({self.n_mels}) must be a multiple of "
                f"patch_size_freq ({self.patch_size_freq})"
            )
        if self.n_frames % self.patch_size_time != 0:
            raise ValueError(
                f"n_frames ({self.n_frames}) must be a multiple of "
                f"patch_size_time ({self.patch_size_time})"
            )

    @property
    def num_patches(self) -> int:
        """Total patch count: frequency patches times time patches."""
        freq_patches = self.n_mels // self.patch_size_freq
        time_patches = self.n_frames // self.patch_size_time
        return freq_patches * time_patches

    @property
    def patch_dim(self) -> int:
        """Length of one flattened (single-channel) patch."""
        return self.patch_size_freq * self.patch_size_time
|
|
|
|
| |
|
|
class ModalityEmbedding(nn.Module):
    """
    Learned per-modality identifier vector.

    Holds one trainable ``hidden_dim`` vector per modality id. The vector
    for a given modality is meant to be added to every token of that
    modality, so the backbone can tell modalities apart from a learned
    signal in the data rather than from a separate architectural pathway.
    """

    def __init__(self, num_modalities: int, hidden_dim: int):
        super().__init__()

        # Build the lookup table, then overwrite the default init with a
        # small-variance normal before registering it on the module.
        table = nn.Embedding(num_modalities, hidden_dim)
        nn.init.normal_(table.weight, mean=0.0, std=0.02)
        self.embed = table

    def forward(self, modality_id: int, seq_len: int, batch_size: int = 1) -> torch.Tensor:
        """
        Look up one modality's vector and tile it over a sequence.

        Args:
            modality_id: Row of the embedding table to use.
            seq_len: Number of token positions (S).
            batch_size: Number of sequences (B).

        Returns:
            (B, S, D) tensor to ADD to hidden states; every position holds
            the same embedding row.
        """
        weight = self.embed.weight
        grid_shape = (batch_size, seq_len)
        # Index tensor lives on the table's device so the lookup never
        # crosses devices.
        index_grid = torch.full(grid_shape, modality_id, dtype=torch.long, device=weight.device)
        return self.embed(index_grid)
|
|
|
|
| |
|
|
class VisionEncoder(nn.Module):
    """
    Image → Patches → hidden_dim projections.

    No convolutions and no pretrained backbone: the image is cut into
    non-overlapping patches, each patch is flattened and linearly projected
    to ``hidden_dim``, and a learned positional embedding is added.

    At 28×28 (MNIST) with 4×4 patches this yields 49 patches of 16 values
    each. At 224×224 RGB with patch_size=16: 196 patches of 768 values.
    """

    def __init__(self, config: KernelConfig, vision_config: VisionConfig):
        """
        Args:
            config: Kernel configuration; only ``hidden_dim`` is read here.
            vision_config: Image/patch geometry, CLS-token flag and dropout.
        """
        super().__init__()
        self.config = config
        self.v_config = vision_config

        # Flattened patch -> hidden_dim.
        self.patch_embed = nn.Linear(vision_config.patch_dim, config.hidden_dim, bias=False)

        # One learned position per patch, plus one slot for [CLS] if used.
        num_pos = vision_config.num_patches + (1 if vision_config.use_cls_token else 0)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_pos, config.hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        if vision_config.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(vision_config.dropout)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise the patch projection with Xavier-uniform weights."""
        nn.init.xavier_uniform_(self.patch_embed.weight)

    def patchify(self, images: torch.Tensor) -> torch.Tensor:
        """
        Convert images to patch sequences.

        Args:
            images: (B, C, H, W) pixel tensor. H and W must both equal
                ``image_size`` and C must equal ``in_channels``.

        Returns:
            (B, num_patches, patch_dim) flattened patches. Patches are
            ordered row-major over the patch grid; within a patch, values
            are ordered channel, then row, then column.

        Raises:
            ValueError: If ``images`` is not 4-D or its channel/spatial
                sizes do not match the configuration.
        """
        if images.dim() != 4:
            raise ValueError(
                f"Expected a 4-D (B, C, H, W) tensor, got shape {tuple(images.shape)}"
            )

        B, C, H, W = images.shape
        p = self.v_config.patch_size
        size = self.v_config.image_size

        # Explicit raises instead of ``assert`` so the checks survive ``-O``.
        if H != size or W != size:
            raise ValueError(f"Expected {size}×{size}, got {H}×{W}")
        # Without this check a wrong channel count is silently re-chunked
        # into patches that mix unrelated pixels.
        if C != self.v_config.in_channels:
            raise ValueError(
                f"Expected {self.v_config.in_channels} channel(s), got {C}"
            )

        # (B, C, H/p, p, W/p, p) -> (B, H/p, W/p, C, p, p) -> (B, N, C*p*p)
        patches = images.reshape(B, C, H // p, p, W // p, p)
        patches = patches.permute(0, 2, 4, 1, 3, 5)
        patches = patches.reshape(B, -1, self.v_config.patch_dim)

        return patches

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images into the token manifold.

        Args:
            images: (B, C, H, W) pixel tensor matching the configuration.

        Returns:
            (B, num_patches [+ 1], hidden_dim) vision tokens; the extra
            leading token is [CLS] when ``use_cls_token`` is set.

        Raises:
            ValueError: Propagated from ``patchify`` on a shape mismatch.
        """
        patches = self.patchify(images)
        B = patches.shape[0]

        x = self.patch_embed(patches)

        if self.v_config.use_cls_token:
            cls = self.cls_token.expand(B, -1, -1)
            x = torch.cat([cls, x], dim=1)

        x = x + self.pos_embed[:, :x.shape[1], :]

        x = self.norm(x)
        x = self.dropout(x)

        return x
|
|
|
|
| |
|
|
class AudioEncoder(nn.Module):
    """
    Mel spectrogram → Patches → hidden_dim projections.

    A spectrogram is a 2-D frequency × time grid, so it is handled like a
    single-channel image: cut into patches, project each patch linearly,
    add positional information.

    Two positional signals are added: a factorised 2-D one (a learned
    vector per frequency-patch row plus a learned vector per time-patch
    column) and a flat learned embedding per sequence slot, which also
    covers the [CLS] slot when it is used.

    Input is a pre-computed mel spectrogram of shape
    (B, 1, n_mels, n_frames); this module performs no mel transform.
    """

    def __init__(self, config: KernelConfig, audio_config: AudioConfig):
        """
        Args:
            config: Kernel configuration; only ``hidden_dim`` is read here.
            audio_config: Spectrogram/patch geometry, CLS flag and dropout.
        """
        super().__init__()
        self.config = config
        self.a_config = audio_config

        # Flattened patch -> hidden_dim.
        self.patch_embed = nn.Linear(audio_config.patch_dim, config.hidden_dim, bias=False)

        n_freq_patches = audio_config.n_mels // audio_config.patch_size_freq
        n_time_patches = audio_config.n_frames // audio_config.patch_size_time
        num_pos = n_freq_patches * n_time_patches + (1 if audio_config.use_cls_token else 0)

        # Flat per-slot positional embedding (includes the [CLS] slot).
        self.pos_embed = nn.Parameter(torch.zeros(1, num_pos, config.hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Factorised 2-D positional embedding: one table per grid axis.
        self.freq_embed = nn.Parameter(torch.zeros(1, n_freq_patches, config.hidden_dim))
        self.time_embed = nn.Parameter(torch.zeros(1, n_time_patches, config.hidden_dim))
        nn.init.trunc_normal_(self.freq_embed, std=0.02)
        nn.init.trunc_normal_(self.time_embed, std=0.02)

        if audio_config.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(audio_config.dropout)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise the patch projection with Xavier-uniform weights."""
        nn.init.xavier_uniform_(self.patch_embed.weight)

    def patchify(self, mel: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """
        Convert a mel spectrogram to a patch sequence.

        Args:
            mel: (B, 1, F, T) spectrogram whose F and T are multiples of
                the configured frequency/time patch sizes.

        Returns:
            patches: (B, n_freq * n_time, patch_dim), ordered with the
                frequency-patch index varying slowest.
            n_freq: number of frequency patches.
            n_time: number of time patches.

        Raises:
            ValueError: If ``mel`` is not 4-D, is not single-channel, or
                does not tile exactly into patches.
        """
        if mel.dim() != 4:
            raise ValueError(
                f"Expected a 4-D (B, 1, n_mels, n_frames) tensor, got shape {tuple(mel.shape)}"
            )

        # Named n_bins/n_steps rather than F/T so the module-level alias
        # ``F`` (torch.nn.functional) is not shadowed.
        B, C, n_bins, n_steps = mel.shape
        pf = self.a_config.patch_size_freq
        pt = self.a_config.patch_size_time

        if C != 1:
            raise ValueError(f"Expected a single-channel spectrogram, got {C} channels")
        if n_bins % pf != 0 or n_steps % pt != 0:
            raise ValueError(
                f"Spectrogram of {n_bins} bins by {n_steps} frames is not divisible "
                f"into {pf} by {pt} patches"
            )

        n_freq = n_bins // pf
        n_time = n_steps // pt

        # (B, 1, nf, pf, nt, pt) -> (B, nf, nt, 1, pf, pt) -> (B, nf*nt, pf*pt)
        patches = mel.reshape(B, C, n_freq, pf, n_time, pt)
        patches = patches.permute(0, 2, 4, 1, 3, 5)
        patches = patches.reshape(B, n_freq * n_time, pf * pt)

        return patches, n_freq, n_time

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Encode a mel spectrogram into the token manifold.

        Args:
            mel: (B, 1, n_mels, n_frames) spectrogram matching the
                configured size.

        Returns:
            (B, num_patches [+ 1], hidden_dim) audio tokens; the extra
            leading token is [CLS] when ``use_cls_token`` is set.

        Raises:
            ValueError: If the input's patch grid differs from the one the
                positional tables were built for, or on any ``patchify``
                validation failure.
        """
        patches, n_freq, n_time = self.patchify(mel)
        B = patches.shape[0]

        # The positional tables are sized from the config; a differently
        # sized input cannot be aligned with them.
        expected_grid = (self.freq_embed.shape[1], self.time_embed.shape[1])
        if (n_freq, n_time) != expected_grid:
            raise ValueError(
                f"Input yields a {n_freq} by {n_time} patch grid, but the encoder "
                f"was configured for {expected_grid[0]} by {expected_grid[1]}"
            )

        x = self.patch_embed(patches)

        # (1, nf, 1, D) + (1, 1, nt, D) -> (1, nf, nt, D), flattened in the
        # same frequency-major order that patchify produces.
        pos_2d = self.freq_embed.unsqueeze(2) + self.time_embed.unsqueeze(1)
        pos_2d = pos_2d.reshape(1, n_freq * n_time, -1)
        x = x + pos_2d

        if self.a_config.use_cls_token:
            cls = self.cls_token.expand(B, -1, -1)
            x = torch.cat([cls, x], dim=1)

        # Added whether or not a [CLS] token was prepended.
        x = x + self.pos_embed[:, :x.shape[1], :]

        x = self.norm(x)
        x = self.dropout(x)

        return x
|
|
|
|
| |
|
|
class SensoryCortex(nn.Module):
    """
    The sensory integration layer.

    Owns the optional vision and audio encoders and builds one unified
    token sequence for the transformer backbone:

    1. Encode each supplied modality that has a configured encoder.
    2. Add that modality's learned modality embedding to its tokens.
    3. Concatenate along the sequence axis in the fixed order
       [vision] [audio] [text].

    Text embeddings are passed through with only the modality embedding
    added, so the cortex also works with text-only input.

    Usage:
        cortex = SensoryCortex(config, vision_config, audio_config)

        # Vision + text
        tokens, modality_mask = cortex(text_embeds=text, images=images)

        # Audio + text
        tokens, modality_mask = cortex(text_embeds=text, audio=mel)

        # Full multimodal
        tokens, modality_mask = cortex(text_embeds=text, images=images, audio=mel)

        # Text only
        tokens, modality_mask = cortex(text_embeds=text)
    """

    # Modality ids: used as rows of the modality embedding table and as the
    # label values in the returned modality mask.
    TEXT_ID = 0
    VISION_ID = 1
    AUDIO_ID = 2

    def __init__(
        self,
        config: KernelConfig,
        vision_config: Optional[VisionConfig] = None,
        audio_config: Optional[AudioConfig] = None,
    ):
        """
        Args:
            config: Kernel configuration; ``hidden_dim`` is read here.
            vision_config: Enables the vision encoder when given.
            audio_config: Enables the audio encoder when given.
        """
        super().__init__()
        self.config = config
        self.has_vision = vision_config is not None
        self.has_audio = audio_config is not None

        # Text, vision, audio.
        num_modalities = 3
        self.modality_embed = ModalityEmbedding(num_modalities, config.hidden_dim)

        if self.has_vision:
            self.vision = VisionEncoder(config, vision_config)
            self.vision_config = vision_config

        if self.has_audio:
            self.audio = AudioEncoder(config, audio_config)
            self.audio_config = audio_config

        # NOTE: registered but not read anywhere in this class. It is kept
        # so the set of parameters (and state_dict keys) stays unchanged.
        self.global_pos_scale = nn.Parameter(torch.ones(1))

    def forward(
        self,
        text_embeds: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        audio: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Combine all available modalities into a unified sequence.

        Order: [vision_tokens] [audio_tokens] [text_tokens]

        An image or audio input is skipped when the matching encoder was
        not configured. The batch size and device are taken from the first
        supplied input in the order text, images, audio.

        Args:
            text_embeds: (B, S_text, D) already-embedded text tokens.
            images: (B, C, H, W) image tensor for the vision encoder.
            audio: (B, 1, n_mels, n_frames) mel spectrogram.

        Returns:
            unified: (B, S_total, D) unified token sequence.
            modality_mask: (B, S_total) int64 modality labels
                (0=text, 1=vision, 2=audio).

        Raises:
            ValueError: If no input is supplied, or if every supplied input
                belongs to a modality without a configured encoder.
        """
        reference = next(
            (t for t in (text_embeds, images, audio) if t is not None), None
        )
        if reference is None:
            raise ValueError("At least one modality must be provided")
        batch_size = reference.shape[0]
        device = reference.device

        # (modality id, tokens) pairs in output order. Encoders run in the
        # order vision, then audio.
        segments = []
        if images is not None and self.has_vision:
            segments.append((self.VISION_ID, self.vision(images)))
        if audio is not None and self.has_audio:
            segments.append((self.AUDIO_ID, self.audio(audio)))
        if text_embeds is not None:
            segments.append((self.TEXT_ID, text_embeds))

        if not segments:
            # Otherwise torch.cat would fail on an empty list with a
            # message that does not point at the real cause.
            raise ValueError(
                "No tokens were produced: inputs were supplied only for "
                "modalities that have no configured encoder"
            )

        sequences = []
        modality_labels = []
        for modality_id, tokens in segments:
            seq_len = tokens.shape[1]
            sequences.append(
                tokens + self.modality_embed(modality_id, seq_len, batch_size)
            )
            # Explicit integer dtype so the mask never depends on dtype
            # inference from the fill value.
            modality_labels.append(
                torch.full(
                    (batch_size, seq_len), modality_id,
                    dtype=torch.long, device=device,
                )
            )

        unified = torch.cat(sequences, dim=1)
        modality_mask = torch.cat(modality_labels, dim=1)

        return unified, modality_mask

    def param_count(self) -> dict:
        """
        Report parameter count by component.

        Returns:
            Dict with 'modality_embed', optionally 'vision' and 'audio',
            and 'total' (the sum of the listed components). Parameters
            registered directly on the cortex are not included.
        """
        def numel(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        counts = {'modality_embed': numel(self.modality_embed)}
        if self.has_vision:
            counts['vision'] = numel(self.vision)
        if self.has_audio:
            counts['audio'] = numel(self.audio)
        counts['total'] = sum(counts.values())
        return counts
|
|
|
|
| |
|
|
def mnist_vision_config() -> VisionConfig:
    """MNIST preset: 28×28 grayscale → 49 patches of 4×4, with a [CLS] token."""
    settings = {
        "image_size": 28,
        "patch_size": 4,
        "in_channels": 1,
        "use_cls_token": True,
    }
    return VisionConfig(**settings)
|
|
def cifar_vision_config() -> VisionConfig:
    """CIFAR-10 preset: 32×32 RGB → 64 patches of 4×4, with a [CLS] token."""
    settings = {
        "image_size": 32,
        "patch_size": 4,
        "in_channels": 3,
        "use_cls_token": True,
    }
    return VisionConfig(**settings)
|
|
def imagenet_vision_config() -> VisionConfig:
    """ImageNet preset: 224×224 RGB → 196 patches of 16×16, with a [CLS] token."""
    settings = {
        "image_size": 224,
        "patch_size": 16,
        "in_channels": 3,
        "use_cls_token": True,
    }
    return VisionConfig(**settings)
|
|
def speech_audio_config() -> AudioConfig:
    """Speech preset: 64 mel bands × 128 frames in 8×8 patches, with a [CLS] token.

    Nominally ~2 sec at 16kHz; the actual duration depends on the hop
    length used in preprocessing, which is not set here.
    """
    settings = {
        "n_mels": 64,
        "n_frames": 128,
        "patch_size_freq": 8,
        "patch_size_time": 8,
        "use_cls_token": True,
    }
    return AudioConfig(**settings)
|
|
def music_audio_config() -> AudioConfig:
    """Music preset: 128 mel bands × 256 frames in 16×16 patches, with a [CLS] token.

    Nominally ~4 sec at 22kHz; the actual duration depends on the hop
    length used in preprocessing, which is not set here.
    """
    settings = {
        "n_mels": 128,
        "n_frames": 256,
        "patch_size_freq": 16,
        "patch_size_time": 16,
        "use_cls_token": True,
    }
    return AudioConfig(**settings)
|
|