# amuzetnoM's picture
# WYRM kernel source (v27 FINAL)
# 9463e5c verified
"""
GLADIUS v2.0 — Sensory Cortex
Vision and Audio perception modules.
Design principle: Every sensory input projects into the SAME hidden_dim
manifold as text tokens. The transformer backbone doesn't know — and
shouldn't know — whether it's processing text, image patches, or audio
frames. All are tokens. All live in the same space.
Biological analogy: The thalamus doesn't differentiate modalities —
it routes everything into cortical columns. The cortex learns what
to do with each modality through the patterns in the data, not
through hardcoded pathways.
The sensory cortex adds exactly two things:
1. Encoders that project raw sensory data into hidden_dim
2. Modality embeddings that let the transformer KNOW what it's looking at
That's it. No separate attention. No separate heads. The existing
SLA² attention, MoE router, memory, and cognition systems handle
everything else. They were always designed to — they just never had
sensory data to work with.
Architecture:
Image → VisionEncoder → [CLS] + patch_embeds + modality_embed
Audio → AudioEncoder → [CLS] + patch_embeds + modality_embed
Text → TokenEmbed → token_embeds + modality_embed
All three produce (B, S, hidden_dim) tensors that concatenate into
a unified sequence for the transformer backbone.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass
from typing import Optional, Tuple
from .config import KernelConfig
# ── Configuration ──────────────────────────────────────────────
@dataclass
class VisionConfig:
    """Configuration for the vision sensory cortex.

    Images are square (``image_size`` x ``image_size``) and are cut into
    non-overlapping square patches of side ``patch_size``.

    Raises:
        ValueError: if ``image_size``, ``patch_size`` or ``in_channels`` is
            not positive, or if ``image_size`` is not an exact multiple of
            ``patch_size``.
    """

    image_size: int = 28        # Input image side length in pixels (square)
    patch_size: int = 4         # Patch side length; must divide image_size
    in_channels: int = 1        # Grayscale=1, RGB=3
    use_cls_token: bool = True  # Prepend [CLS] vision token
    dropout: float = 0.1        # Dropout probability

    def __post_init__(self) -> None:
        # Fail at construction time: without this check the floor division
        # in ``num_patches`` silently drops the remainder and the mismatch
        # only surfaces later as an opaque reshape error.
        if self.image_size <= 0 or self.patch_size <= 0 or self.in_channels <= 0:
            raise ValueError(
                "image_size, patch_size and in_channels must be positive, got "
                f"image_size={self.image_size}, patch_size={self.patch_size}, "
                f"in_channels={self.in_channels}"
            )
        if self.image_size % self.patch_size != 0:
            raise ValueError(
                f"image_size ({self.image_size}) must be divisible by "
                f"patch_size ({self.patch_size})"
            )

    @property
    def num_patches(self) -> int:
        """Number of patches per image: (image_size / patch_size) squared."""
        return (self.image_size // self.patch_size) ** 2

    @property
    def patch_dim(self) -> int:
        """Length of one flattened patch: patch_size * patch_size * in_channels."""
        return self.patch_size * self.patch_size * self.in_channels
@dataclass
class AudioConfig:
    """Configuration for the audio sensory cortex.

    A mel spectrogram of ``n_mels`` x ``n_frames`` is cut into
    non-overlapping patches of ``patch_size_freq`` x ``patch_size_time``.

    Raises:
        ValueError: if any size is not positive, if ``n_mels`` is not an
            exact multiple of ``patch_size_freq``, or if ``n_frames`` is not
            an exact multiple of ``patch_size_time``.
    """

    n_mels: int = 64            # Mel frequency bins
    n_frames: int = 128         # Time frames (1-2 sec at 16kHz)
    patch_size_freq: int = 8    # Frequency patch size; must divide n_mels
    patch_size_time: int = 8    # Time patch size; must divide n_frames
    use_cls_token: bool = True  # Prepend [CLS] audio token
    dropout: float = 0.1        # Dropout probability

    def __post_init__(self) -> None:
        # Fail at construction time: without this check the floor divisions
        # in ``num_patches`` silently drop the remainder and the mismatch
        # only surfaces later as an opaque reshape error.
        sizes = (self.n_mels, self.n_frames, self.patch_size_freq, self.patch_size_time)
        if any(size <= 0 for size in sizes):
            raise ValueError(
                "n_mels, n_frames, patch_size_freq and patch_size_time must be "
                f"positive, got n_mels={self.n_mels}, n_frames={self.n_frames}, "
                f"patch_size_freq={self.patch_size_freq}, "
                f"patch_size_time={self.patch_size_time}"
            )
        if self.n_mels % self.patch_size_freq != 0:
            raise ValueError(
                f"n_mels ({self.n_mels}) must be divisible by "
                f"patch_size_freq ({self.patch_size_freq})"
            )
        if self.n_frames % self.patch_size_time != 0:
            raise ValueError(
                f"n_frames ({self.n_frames}) must be divisible by "
                f"patch_size_time ({self.patch_size_time})"
            )

    @property
    def num_patches(self) -> int:
        """Number of patches: frequency patches times time patches."""
        return (self.n_mels // self.patch_size_freq) * (self.n_frames // self.patch_size_time)

    @property
    def patch_dim(self) -> int:
        """Length of one flattened patch: patch_size_freq * patch_size_time."""
        return self.patch_size_freq * self.patch_size_time
# ── Modality Embeddings ───────────────────────────────────────
class ModalityEmbedding(nn.Module):
    """
    Learned per-modality identifier vector.

    Every modality owns one row of an embedding table. That row is
    broadcast over a whole token sequence so the caller can add it to the
    hidden states, marking which modality each token came from. The signal
    is learned data, not architecture.
    """

    def __init__(self, num_modalities: int, hidden_dim: int):
        super().__init__()
        # Row convention used by callers: 0=text, 1=vision, 2=audio (extensible)
        table = nn.Embedding(num_modalities, hidden_dim)
        nn.init.normal_(table.weight, mean=0.0, std=0.02)
        self.embed = table

    def forward(self, modality_id: int, seq_len: int, batch_size: int = 1) -> torch.Tensor:
        """
        Build the modality embedding for one block of tokens.

        Args:
            modality_id: row of the embedding table to use.
            seq_len: number of tokens S in the block.
            batch_size: batch dimension B.

        Returns:
            (B, S, D) tensor in which every position holds the same row;
            meant to be ADDED to the hidden states.
        """
        target_device = self.embed.weight.device
        shape = (batch_size, seq_len)
        # A constant index grid, looked up through the table.
        index_grid = torch.full(shape, modality_id, dtype=torch.long, device=target_device)
        return self.embed(index_grid)
# ── Vision Cortex ─────────────────────────────────────────────
class VisionEncoder(nn.Module):
    """
    Image -> patches -> hidden_dim projections.

    No convolutions and no pretrained backbone. Each image is cut into
    non-overlapping square patches; every patch is flattened and linearly
    projected to ``config.hidden_dim``; an optional learned [CLS] token is
    prepended; learned position embeddings are added; the result passes
    through LayerNorm and dropout.

    Example sizes: a 28x28 single-channel image with patch_size=4 gives
    49 patches of 4*4*1 = 16 values each. A 224x224 RGB image with
    patch_size=16 gives 196 patches of 16*16*3 = 768 values each.

    Args:
        config: kernel configuration; only ``hidden_dim`` is read here.
        vision_config: vision settings (image size, patch size, channels,
            [CLS] flag, dropout).
    """

    def __init__(self, config: "KernelConfig", vision_config: "VisionConfig"):
        super().__init__()
        self.config = config
        self.v_config = vision_config

        # Patch embedding: flattened patch -> linear -> hidden_dim
        self.patch_embed = nn.Linear(vision_config.patch_dim, config.hidden_dim, bias=False)

        # Learnable position embeddings, one per patch (+1 slot for [CLS])
        num_pos = vision_config.num_patches + (1 if vision_config.use_cls_token else 0)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_pos, config.hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Optional [CLS] token; the attribute exists only when enabled
        if vision_config.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)

        # Final normalization and dropout applied to the token sequence
        self.norm = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(vision_config.dropout)

        self._init_weights()

    def _init_weights(self):
        """Re-initialize the patch projection with Xavier-uniform weights."""
        nn.init.xavier_uniform_(self.patch_embed.weight)

    def patchify(self, images: torch.Tensor) -> torch.Tensor:
        """
        Convert images to patch sequences.

        Args:
            images: (B, C, H, W) tensor with H == W == image_size and
                C == in_channels.

        Returns:
            (B, num_patches, patch_dim) flattened patches in row-major patch
            order; within a patch the layout is (channel, row, column).

        Raises:
            AssertionError: if the spatial size or the channel count does
                not match the vision config.
        """
        B, C, H, W = images.shape
        p = self.v_config.patch_size
        size = self.v_config.image_size
        assert H == W == size, \
            f"Expected {size}×{size}, got {H}×{W}"
        # Without this check a wrong channel count can still reshape into
        # (B, -1, patch_dim), yielding scrambled patches instead of an error.
        assert C == self.v_config.in_channels, \
            f"Expected {self.v_config.in_channels} channel(s), got {C}"

        # (B, C, H, W) -> (B, C, H//p, p, W//p, p)
        patches = images.reshape(B, C, H // p, p, W // p, p)
        # Bring the patch-grid axes forward: (B, H//p, W//p, C, p, p)
        patches = patches.permute(0, 2, 4, 1, 3, 5)
        # Flatten grid and patch contents: (B, num_patches, patch_dim)
        patches = patches.reshape(B, -1, self.v_config.patch_dim)
        return patches

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images into the token manifold.

        Args:
            images: (B, C, H, W) image tensor matching the vision config.

        Returns:
            (B, num_patches [+ 1], hidden_dim) vision tokens; the extra
            leading token is [CLS] when ``use_cls_token`` is set.
        """
        B = images.shape[0]

        # 1. Patchify
        patches = self.patchify(images)              # (B, N, patch_dim)

        # 2. Linear projection into hidden_dim
        x = self.patch_embed(patches)                # (B, N, D)

        # 3. Prepend [CLS] if configured
        if self.v_config.use_cls_token:
            cls = self.cls_token.expand(B, -1, -1)   # (B, 1, D)
            x = torch.cat([cls, x], dim=1)           # (B, N+1, D)

        # 4. Add positional embeddings
        x = x + self.pos_embed[:, :x.shape[1], :]

        # 5. Normalize and dropout
        x = self.norm(x)
        x = self.dropout(x)
        return x
# ── Audio Cortex ──────────────────────────────────────────────
class AudioEncoder(nn.Module):
    """
    Mel spectrogram -> patches -> hidden_dim projections.

    A spectrogram is 2D (frequency x time), so it is treated like an
    image: cut into non-overlapping patches, flatten and linearly project
    each patch, then add position information.

    Position information has two learned parts that are summed:
      * a 2D decomposition, ``freq_embed[f] + time_embed[t]``, added to the
        patch at grid cell (f, t);
      * an absolute embedding ``pos_embed`` over the final sequence
        (including the [CLS] slot when enabled).

    Input is a pre-computed mel spectrogram of shape
    (B, 1, n_mels, n_frames); the mel transform is not part of this module.

    Args:
        config: kernel configuration; only ``hidden_dim`` is read here.
        audio_config: audio settings (spectrogram size, patch sizes,
            [CLS] flag, dropout).
    """

    def __init__(self, config: "KernelConfig", audio_config: "AudioConfig"):
        super().__init__()
        self.config = config
        self.a_config = audio_config

        # Patch embedding: flattened freq x time patch -> hidden_dim
        self.patch_embed = nn.Linear(audio_config.patch_dim, config.hidden_dim, bias=False)

        # Absolute position embedding over the whole sequence (+1 for [CLS])
        n_freq_patches = audio_config.n_mels // audio_config.patch_size_freq
        n_time_patches = audio_config.n_frames // audio_config.patch_size_time
        num_pos = n_freq_patches * n_time_patches + (1 if audio_config.use_cls_token else 0)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_pos, config.hidden_dim))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

        # Per-axis embeddings (2D positional decomposition)
        self.freq_embed = nn.Parameter(torch.zeros(1, n_freq_patches, config.hidden_dim))
        self.time_embed = nn.Parameter(torch.zeros(1, n_time_patches, config.hidden_dim))
        nn.init.trunc_normal_(self.freq_embed, std=0.02)
        nn.init.trunc_normal_(self.time_embed, std=0.02)

        # Optional [CLS] token; the attribute exists only when enabled
        if audio_config.use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_dim))
            nn.init.trunc_normal_(self.cls_token, std=0.02)

        self.norm = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(audio_config.dropout)

        self._init_weights()

    def _init_weights(self):
        """Re-initialize the patch projection with Xavier-uniform weights."""
        nn.init.xavier_uniform_(self.patch_embed.weight)

    def patchify(self, mel: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        """
        Convert a mel spectrogram to a patch sequence.

        Args:
            mel: (B, 1, n_mels, n_frames) mel spectrogram.

        Returns:
            patches: (B, num_patches, patch_dim), frequency-major patch order
            n_freq: number of patches along the frequency axis
            n_time: number of patches along the time axis

        Raises:
            AssertionError: if ``mel`` is not single-channel or its
                (n_mels, n_frames) size differs from the audio config.
        """
        cfg = self.a_config
        B, C, n_bins, n_steps = mel.shape
        assert C == 1, f"Expected a single-channel spectrogram, got {C} channels"
        # A spectrogram with a different grid but the same patch count (for
        # example a transposed input) would otherwise pass through and be
        # paired with the wrong frequency/time position embeddings.
        assert n_bins == cfg.n_mels and n_steps == cfg.n_frames, \
            f"Expected ({cfg.n_mels}, {cfg.n_frames}) mel spectrogram, got ({n_bins}, {n_steps})"

        pf = cfg.patch_size_freq
        pt = cfg.patch_size_time
        n_freq = n_bins // pf
        n_time = n_steps // pt

        # (B, 1, F, T) -> (B, 1, n_freq, pf, n_time, pt)
        patches = mel.reshape(B, C, n_freq, pf, n_time, pt)
        # Bring the patch-grid axes forward: (B, n_freq, n_time, 1, pf, pt)
        patches = patches.permute(0, 2, 4, 1, 3, 5)
        # Flatten grid and patch contents: (B, N, patch_dim)
        patches = patches.reshape(B, n_freq * n_time, pf * pt)
        return patches, n_freq, n_time

    def forward(self, mel: torch.Tensor) -> torch.Tensor:
        """
        Encode a mel spectrogram into the token manifold.

        Args:
            mel: (B, 1, n_mels, n_frames) mel spectrogram.

        Returns:
            (B, num_patches [+ 1], hidden_dim) audio tokens; the extra
            leading token is [CLS] when ``use_cls_token`` is set.
        """
        B = mel.shape[0]

        # 1. Patchify
        patches, n_freq, n_time = self.patchify(mel)  # (B, N, patch_dim)

        # 2. Project to hidden_dim
        x = self.patch_embed(patches)                 # (B, N, D)

        # 3. Add the 2D positional decomposition (a broadcast SUM):
        #    (1, n_freq, 1, D) + (1, 1, n_time, D) -> (1, n_freq, n_time, D)
        #    flattened to (1, N, D) in the same frequency-major order as the patches.
        pos_2d = self.freq_embed.unsqueeze(2) + self.time_embed.unsqueeze(1)
        pos_2d = pos_2d.reshape(1, n_freq * n_time, -1)
        x = x + pos_2d

        # 4. Prepend [CLS] if configured
        if self.a_config.use_cls_token:
            cls = self.cls_token.expand(B, -1, -1)
            x = torch.cat([cls, x], dim=1)

        # 5. Add the absolute positional embedding (covers [CLS] when present)
        x = x + self.pos_embed[:, :x.shape[1], :]

        # 6. Normalize and dropout
        x = self.norm(x)
        x = self.dropout(x)
        return x
# ── Unified Sensory Cortex ────────────────────────────────────
class SensoryCortex(nn.Module):
    """
    The sensory integration layer.

    Owns the optional vision/audio encoders plus a shared modality
    embedding, and assembles one token sequence for the transformer
    backbone:

      1. Modality-specific encoding (vision, audio), applied when the
         encoder is configured AND the matching input is given.
      2. Modality embedding added to every token
         (0=text, 1=vision, 2=audio).
      3. Concatenation along the sequence axis in the fixed order
         [vision] [audio] [text].

    The cortex is optional: with text-only input it returns the text
    embeddings with the text modality embedding added.

    Usage:
        cortex = SensoryCortex(config, vision_config, audio_config)
        # Vision + text
        tokens, modality_mask = cortex(text_embeds=text, images=images)
        # Audio + text
        tokens, modality_mask = cortex(text_embeds=text, audio=mel)
        # Full multimodal
        tokens, modality_mask = cortex(text_embeds=text, images=images, audio=mel)
        # Text only
        tokens, modality_mask = cortex(text_embeds=text)
    """

    def __init__(
        self,
        config: KernelConfig,
        vision_config: Optional[VisionConfig] = None,
        audio_config: Optional[AudioConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.has_vision = vision_config is not None
        self.has_audio = audio_config is not None

        # Modality ids: 0=text (always), 1=vision, 2=audio
        num_modalities = 3
        self.modality_embed = ModalityEmbedding(num_modalities, config.hidden_dim)

        # Sensory encoders; the attributes exist only when configured
        if self.has_vision:
            self.vision = VisionEncoder(config, vision_config)
            self.vision_config = vision_config
        if self.has_audio:
            self.audio = AudioEncoder(config, audio_config)
            self.audio_config = audio_config

        # Learnable scalar reserved for cross-modal (global) position
        # encoding. It is not read by forward() in this class, and
        # param_count() does not include it.
        self.global_pos_scale = nn.Parameter(torch.ones(1))

    def _tag(self, tokens: torch.Tensor, modality_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add the modality embedding to (B, S, D) tokens and build their (B, S) labels."""
        batch, seq_len = tokens.shape[0], tokens.shape[1]
        tagged = tokens + self.modality_embed(modality_id, seq_len, batch)
        labels = torch.full(
            (batch, seq_len), modality_id,
            dtype=torch.long, device=tokens.device,
        )
        return tagged, labels

    def forward(
        self,
        text_embeds: Optional[torch.Tensor] = None,
        images: Optional[torch.Tensor] = None,
        audio: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Combine all available modalities into a unified sequence.

        Order: [vision_tokens] [audio_tokens] [text_tokens].

        An input whose encoder was not configured (for example ``images``
        on a cortex built without a vision config) is ignored.

        Args:
            text_embeds: (B, S_text, D) already embedded text tokens
            images: (B, C, H, W) images for the vision encoder
            audio: (B, 1, n_mels, n_frames) mel spectrogram

        Returns:
            unified: (B, S_total, D) unified token sequence
            modality_mask: (B, S_total) integer labels
                (0=text, 1=vision, 2=audio)

        Raises:
            ValueError: if no input is given, or if every given input
                belongs to a modality without a configured encoder.
        """
        if text_embeds is None and images is None and audio is None:
            raise ValueError("At least one modality must be provided")

        segments = []

        # 1. Vision
        if images is not None and self.has_vision:
            segments.append(self._tag(self.vision(images), 1))

        # 2. Audio
        if audio is not None and self.has_audio:
            segments.append(self._tag(self.audio(audio), 2))

        # 3. Text (always last)
        if text_embeds is not None:
            segments.append(self._tag(text_embeds, 0))

        # Guard the concatenation: torch.cat on an empty list fails with an
        # error that does not point at the real cause.
        if not segments:
            raise ValueError(
                "No tokens were produced: inputs were provided only for "
                "modalities whose encoder is not configured"
            )

        unified = torch.cat([tokens for tokens, _ in segments], dim=1)        # (B, S_total, D)
        modality_mask = torch.cat([labels for _, labels in segments], dim=1)  # (B, S_total)
        return unified, modality_mask

    def param_count(self) -> dict:
        """
        Report parameter counts by component.

        Returns:
            dict with 'modality_embed', optionally 'vision' and 'audio',
            and 'total' (the sum of the listed components).
        """
        def count(module: nn.Module) -> int:
            return sum(p.numel() for p in module.parameters())

        counts = {'modality_embed': count(self.modality_embed)}
        if self.has_vision:
            counts['vision'] = count(self.vision)
        if self.has_audio:
            counts['audio'] = count(self.audio)
        counts['total'] = sum(counts.values())
        return counts
# ── Preset Configurations ─────────────────────────────────────
def mnist_vision_config() -> VisionConfig:
    """Preset for MNIST: 28x28 grayscale images, 4x4 patches (49 patches), with [CLS]."""
    settings = {
        "image_size": 28,
        "patch_size": 4,
        "in_channels": 1,
        "use_cls_token": True,
    }
    return VisionConfig(**settings)
def cifar_vision_config() -> VisionConfig:
    """Preset for CIFAR-10: 32x32 RGB images, 4x4 patches (64 patches), with [CLS]."""
    settings = {
        "image_size": 32,
        "patch_size": 4,
        "in_channels": 3,
        "use_cls_token": True,
    }
    return VisionConfig(**settings)
def imagenet_vision_config() -> VisionConfig:
    """Preset for ImageNet: 224x224 RGB images, 16x16 patches (196 patches), with [CLS]."""
    settings = {
        "image_size": 224,
        "patch_size": 16,
        "in_channels": 3,
        "use_cls_token": True,
    }
    return VisionConfig(**settings)
def speech_audio_config() -> AudioConfig:
    """Preset for speech (~2 sec at 16kHz): 64 mel bands x 128 frames, 8x8 patches, with [CLS]."""
    settings = {
        "n_mels": 64,
        "n_frames": 128,
        "patch_size_freq": 8,
        "patch_size_time": 8,
        "use_cls_token": True,
    }
    return AudioConfig(**settings)
def music_audio_config() -> AudioConfig:
    """Preset for music (~4 sec at 22kHz): 128 mel bands x 256 frames, 16x16 patches, with [CLS]."""
    settings = {
        "n_mels": 128,
        "n_frames": 256,
        "patch_size_freq": 16,
        "patch_size_time": 16,
        "use_cls_token": True,
    }
    return AudioConfig(**settings)