"""
GLADIUS — Gaussian Specialist Head

The specialist module that plugs into WYRM's NexusRouter.
Generates 3D Gaussian Splat parameters from backbone hidden states.

Two-stage hierarchical generation:
    Stage 1 (Anchors): direct regression of coarse Gaussians from a pooled hidden state.
    Stage 2 (Details): VQ-coded fine Gaussians via cross-attention to backbone features.

Depth profile integration: learned per-layer gates determine which backbone
layers contribute to structure (anchors) vs. detail.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from .config import GaussianConfig
from .vqvae import GaussianVQVAE, GaussianVQDecoder


class AnchorHead(nn.Module):
    """
    Generate K coarse anchor Gaussians from the backbone's pooled representation.

    Each anchor is ``config.full_dim`` (14) floats laid out as
    position(3) + scale(3) + rotation(4) + opacity(1) + color(3).
    The color slot (described as SH DC in the original design) is squashed to
    [0, 1] by a sigmoid in ``_activate``, i.e. it is treated as RGB.

    Uses direct regression — anchors need precise continuous placement.
    """

    def __init__(self, backbone_dim: int, config: GaussianConfig):
        super().__init__()
        self.config = config
        self.n_anchors = config.num_anchors
        # Expected to be 14; the slices in _activate assume exactly this layout.
        self.out_dim = config.full_dim

        self.net = nn.Sequential(
            nn.Linear(backbone_dim, config.anchor_hidden),
            nn.LayerNorm(config.anchor_hidden),
            nn.GELU(),
            nn.Linear(config.anchor_hidden, config.anchor_hidden),
            nn.LayerNorm(config.anchor_hidden),
            nn.GELU(),
            nn.Linear(config.anchor_hidden, self.n_anchors * self.out_dim),
        )
        self._init_weights()

    def _init_weights(self):
        """Initialize the output layer small so raw outputs (and positions) start near 0."""
        nn.init.normal_(self.net[-1].weight, std=0.01)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pooled: (B, backbone_dim) — pooled backbone output.

        Returns:
            anchors: (B, K, 14) — K activated anchor Gaussians.
        """
        B = pooled.shape[0]
        raw = self.net(pooled).view(B, self.n_anchors, self.out_dim)  # (B, K, 14)
        return self._activate(raw)

    def _activate(self, raw: torch.Tensor) -> torch.Tensor:
        """Apply per-component activations to the raw (..., 14) regression output."""
        # Positions bounded to [-scene_scale, +scene_scale].
        pos = torch.tanh(raw[..., :3]) * self.config.scene_scale
        # Hard clamp to [min_gaussian_scale, max_gaussian_scale]. The original comment
        # calls this "log-scale"; whether the config bounds are log values is not
        # visible here — confirm against GaussianConfig.
        # NOTE(review): clamp has zero gradient outside the bounds. If 0 lies outside
        # them, the near-zero initial outputs stay pinned; consider a soft bound.
        scale = raw[..., 3:6].clamp(
            self.config.min_gaussian_scale, self.config.max_gaussian_scale
        )
        rot = F.normalize(raw[..., 6:10], dim=-1)  # unit quaternion
        opacity = torch.sigmoid(raw[..., 10:11])  # [0, 1]
        color = torch.sigmoid(raw[..., 11:14])  # [0, 1] RGB
        return torch.cat([pos, scale, rot, opacity, color], dim=-1)


class DetailHead(nn.Module):
    """
    Generate M fine detail Gaussians per anchor using VQ tokens.

    Each anchor is expanded into M query vectors, which cross-attend to the
    backbone hidden states. The attended features then predict VQ codebook
    logits/indices and a continuous position offset relative to the anchor.

    Args:
        backbone_dim: Width D of the backbone hidden states (also the attention width;
            must be divisible by ``config.cross_attn_heads``).
        config: Gaussian head configuration.
        max_offset: Half-width of the per-axis position offset around the anchor,
            in the same units as anchor positions. Defaults to 0.5.
    """

    def __init__(self, backbone_dim: int, config: GaussianConfig, max_offset: float = 0.5):
        super().__init__()
        self.config = config
        self.n_details = config.details_per_anchor
        self.backbone_dim = backbone_dim
        self.max_offset = max_offset

        # Anchor → query expansion: each 14-float anchor yields M query vectors of width D.
        self.query_expand = nn.Sequential(
            nn.Linear(config.full_dim, config.detail_hidden),
            nn.GELU(),
            nn.Linear(config.detail_hidden, self.n_details * backbone_dim),
        )

        # Cross-attention: detail queries attend to backbone features.
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=backbone_dim,
            num_heads=config.cross_attn_heads,
            batch_first=True,
            dropout=0.0,
        )
        self.cross_norm = nn.LayerNorm(backbone_dim)

        # VQ index prediction (logits over the codebook).
        self.vq_proj = nn.Sequential(
            nn.Linear(backbone_dim, backbone_dim),
            nn.GELU(),
            nn.Linear(backbone_dim, config.codebook_size),
        )

        # Continuous position offset relative to the parent anchor.
        self.offset_proj = nn.Sequential(
            nn.Linear(backbone_dim, config.detail_hidden),
            nn.GELU(),
            nn.Linear(config.detail_hidden, 3),
        )

    def forward(self, anchors: torch.Tensor, backbone_features: torch.Tensor) -> dict:
        """
        Args:
            anchors: (B, K, 14) — anchor Gaussians.
            backbone_features: (B, S, backbone_dim) — backbone hidden states.

        Returns:
            dict with:
                vq_logits: (B, K*M, codebook_size)
                vq_indices: (B, K*M) — argmax of vq_logits (not differentiable)
                pos_offsets: (B, K*M, 3) — in [-max_offset, +max_offset]
                detail_features: (B, K*M, D) — post-attention features
            Detail i = k*M + m belongs to anchor k.
        """
        B, K, _ = anchors.shape

        # Expand each anchor into M query vectors: (B, K, M*D) → (B, K*M, D).
        queries = self.query_expand(anchors).view(B, K * self.n_details, self.backbone_dim)

        # Cross-attend to the backbone. The attention map is never used, so skip
        # computing it (need_weights=False also enables the fused attention path).
        # NOTE(review): no key_padding_mask is passed, so padded positions are attended.
        attended, _ = self.cross_attn(
            queries, backbone_features, backbone_features, need_weights=False
        )
        detail_features = self.cross_norm(attended + queries)  # residual + norm

        vq_logits = self.vq_proj(detail_features)  # (B, K*M, codebook_size)
        vq_indices = vq_logits.argmax(dim=-1)  # (B, K*M)

        # Keep details close to their anchor: tanh bounds each axis to ±max_offset.
        pos_offsets = torch.tanh(self.offset_proj(detail_features)) * self.max_offset

        return {
            'vq_logits': vq_logits,
            'vq_indices': vq_indices,
            'pos_offsets': pos_offsets,
            'detail_features': detail_features,
        }


class GaussianSpecialist(nn.Module):
    """
    The complete Gaussian specialist head for WYRM.

    Plugs into NexusRouter as specialist index N.
    Generates 3D Gaussian splat scenes from backbone hidden states.

    Two-stage generation:
        1. Anchor Head → K coarse Gaussians (direct regression)
        2. Detail Head → K*M fine Gaussians (VQ-coded)

    Depth profile integration: learned per-layer gates (softmax over layers)
    select which backbone layers contribute to structure vs. detail.
    """

    def __init__(self, backbone_dim: int, num_backbone_layers: int,
                 config: GaussianConfig, vqvae: GaussianVQVAE = None):
        super().__init__()
        self.config = config
        self.backbone_dim = backbone_dim
        self.num_layers = num_backbone_layers

        # ── Sub-modules ──
        self.anchor_head = AnchorHead(backbone_dim, config)
        self.detail_head = DetailHead(backbone_dim, config)

        # ── VQ-VAE decoder (frozen — pre-trained in Phase 1) ──
        if vqvae is not None:
            self.vq_decoder = vqvae.decoder
            # Freeze the VQ decoder's parameters.
            # NOTE(review): requires_grad=False does not put the decoder in eval mode;
            # if it contains dropout/batch-norm it follows this module's train() state.
            for p in self.vq_decoder.parameters():
                p.requires_grad = False
            # Snapshot the codebook for decoding. detach() is required: if `embed` is a
            # Parameter, a plain clone() would keep autograd history, letting gradients
            # flow back into the "frozen" codebook and breaking copy.deepcopy().
            self.register_buffer('vq_codebook', vqvae.quantizer.embed.detach().clone())
        else:
            self.vq_decoder = None
            self.vq_codebook = None

        # ── Depth Profile Gates ──
        # Learned logits over backbone layers; zeros → uniform weighting at init.
        self.anchor_layer_gate = nn.Parameter(torch.zeros(num_backbone_layers))
        self.detail_layer_gate = nn.Parameter(torch.zeros(num_backbone_layers))

        # ── Projection for pooling ──
        self.pool_proj = nn.Linear(backbone_dim, backbone_dim)

    @staticmethod
    def _aggregate_layers(gate: torch.Tensor, layer_outputs) -> torch.Tensor:
        """Softmax-weighted sum of per-layer (B, S, D) hidden states using `gate` logits."""
        weights = torch.softmax(gate, dim=0)
        return sum(w * h for w, h in zip(weights, layer_outputs))

    def forward(self, layer_outputs: list[torch.Tensor]) -> dict:
        """
        Args:
            layer_outputs: Sequence of (B, S, D) tensors, exactly one per backbone
                layer (``num_backbone_layers`` entries).

        Returns:
            dict with:
                anchors: (B, K, 14) — anchor Gaussians
                vq_logits: (B, K*M, codebook_size) — for training loss
                vq_indices: (B, K*M) — predicted codebook indices
                pos_offsets: (B, K*M, 3) — detail position offsets
            and, only when a VQ decoder/codebook is available:
                details: (B, K*M, 14) — detail Gaussians
                all_gaussians: (B, K + K*M, 14) — concatenated scene

        Raises:
            ValueError: if the number of layer outputs does not match the number of
                layer gates (zip would otherwise silently drop layers and the
                softmax weights would no longer sum to 1).
        """
        if len(layer_outputs) != self.num_layers:
            raise ValueError(
                f"expected {self.num_layers} layer outputs (one per backbone layer), "
                f"got {len(layer_outputs)}"
            )

        # ── Stage 1: depth-profiled aggregation → pooled vector → anchors ──
        anchor_hidden = self._aggregate_layers(self.anchor_layer_gate, layer_outputs)
        # NOTE(review): plain mean over S includes any padding positions.
        pooled = self.pool_proj(anchor_hidden.mean(dim=1))  # (B, D)
        anchors = self.anchor_head(pooled)  # (B, K, 14)

        # ── Stage 2: depth-profiled aggregation → details ──
        detail_hidden = self._aggregate_layers(self.detail_layer_gate, layer_outputs)
        detail_out = self.detail_head(anchors, detail_hidden)

        result = {
            'anchors': anchors,
            'vq_logits': detail_out['vq_logits'],
            'vq_indices': detail_out['vq_indices'],
            'pos_offsets': detail_out['pos_offsets'],
        }

        # ── Decode VQ indices to full Gaussian params (if decoder available) ──
        if self.vq_decoder is not None and self.vq_codebook is not None:
            details = self._decode_details(
                anchors, detail_out['vq_indices'], detail_out['pos_offsets']
            )
            result['details'] = details
            result['all_gaussians'] = torch.cat([anchors, details], dim=1)

        return result

    def _decode_details(self, anchors: torch.Tensor, vq_indices: torch.Tensor,
                        pos_offsets: torch.Tensor) -> torch.Tensor:
        """
        Decode VQ indices + anchor positions → full detail Gaussians.

        Args:
            anchors: (B, K, 14)
            vq_indices: (B, K*M) — detail i = k*M + m belongs to anchor k
            pos_offsets: (B, K*M, 3)

        Returns:
            details: (B, K*M, 14) — world position + decoded scale/rot/opacity/color.
        """
        B, KM = vq_indices.shape
        M = self.config.details_per_anchor

        # Decode VQ → (scale, rot, opacity, color); the decoder is expected to emit
        # config.param_dim (11) values per Gaussian.
        z_q = self.vq_codebook[vq_indices.reshape(-1)]  # (B*K*M, codebook_dim)
        params = self.vq_decoder(z_q).reshape(B, KM, self.config.param_dim)

        scale = params[..., :3].clamp(
            self.config.min_gaussian_scale, self.config.max_gaussian_scale
        )
        rot = F.normalize(params[..., 3:7], dim=-1)
        opacity = torch.sigmoid(params[..., 7:8])
        color = torch.sigmoid(params[..., 8:11])

        # World-space positions: each anchor position repeated M times, plus offsets.
        anchor_pos = anchors[..., :3].repeat_interleave(M, dim=1)  # (B, K*M, 3)
        world_pos = anchor_pos + pos_offsets

        return torch.cat([world_pos, scale, rot, opacity, color], dim=-1)

    def get_depth_profile(self) -> dict:
        """Return the learned (softmaxed) per-layer gate weights and their peak layers."""
        with torch.no_grad():
            anchor_w = torch.softmax(self.anchor_layer_gate, dim=0)
            detail_w = torch.softmax(self.detail_layer_gate, dim=0)
        return {
            'anchor_weights': anchor_w.cpu().tolist(),
            'detail_weights': detail_w.cpu().tolist(),
            'anchor_peak_layer': anchor_w.argmax().item(),
            'detail_peak_layer': detail_w.argmax().item(),
        }

    def count_params(self) -> dict:
        """Count trainable parameters by component (frozen VQ decoder excluded from total)."""
        def count(module):
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        return {
            'anchor_head': count(self.anchor_head),
            'detail_head': count(self.detail_head),
            'pool_proj': count(self.pool_proj),
            'layer_gates': self.anchor_layer_gate.numel() + self.detail_layer_gate.numel(),
            'total': sum(p.numel() for p in self.parameters() if p.requires_grad),
        }