| """ |
| GLADIUS — Gaussian Specialist Head |
| |
| The specialist module that plugs into WYRM's NexusRouter. |
| Generates 3D Gaussian Splat parameters from backbone hidden states. |
| |
| Two-stage hierarchical generation: |
| Stage 1 (Anchors): Direct regression of coarse Gaussians from pooled hidden state |
| Stage 2 (Details): VQ-coded fine Gaussians via cross-attention to backbone features |
| |
| Depth profile integration: learned per-layer gates determine which backbone |
| layers contribute to structure (anchors) vs detail. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
|
|
| from .config import GaussianConfig |
| from .vqvae import GaussianVQVAE, GaussianVQDecoder |
|
|
|
|
class AnchorHead(nn.Module):
    """
    Regression head that emits the coarse "anchor" Gaussians.

    A three-layer MLP turns one pooled backbone vector into
    ``config.num_anchors * config.full_dim`` raw numbers. These are reshaped
    to one row per anchor and then squashed component-wise by ``_activate``.

    Row layout assumed by ``_activate`` (14 floats per anchor):
        [0:3]    position
        [3:6]    scale
        [6:10]   rotation (4 values, L2-normalised)
        [10:11]  opacity
        [11:14]  colour (the "sh_dc" term in the original notes)

    NOTE(review): the slices above are hard-coded, so the head presumably
    expects ``config.full_dim == 14`` - confirm against GaussianConfig.
    """

    def __init__(self, backbone_dim: int, config: GaussianConfig):
        super().__init__()
        self.config = config
        self.n_anchors = config.num_anchors
        self.out_dim = config.full_dim

        width = config.anchor_hidden
        stages = [
            nn.Linear(backbone_dim, width),
            nn.LayerNorm(width),
            nn.GELU(),
            nn.Linear(width, width),
            nn.LayerNorm(width),
            nn.GELU(),
            # One flat vector holding every anchor's raw parameters.
            nn.Linear(width, self.n_anchors * self.out_dim),
        ]
        self.net = nn.Sequential(*stages)

        self._init_weights()

    def _init_weights(self):
        """Give the output layer near-zero weights and an all-zero bias."""
        last = self.net[-1]
        nn.init.normal_(last.weight, std=0.01)
        nn.init.zeros_(last.bias)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pooled: (B, backbone_dim) pooled backbone output.

        Returns:
            (B, num_anchors, 14) activated anchor Gaussians.
        """
        batch = pooled.shape[0]
        flat = self.net(pooled)
        per_anchor = flat.view(batch, self.n_anchors, self.out_dim)
        return self._activate(per_anchor)

    def _activate(self, raw: torch.Tensor) -> torch.Tensor:
        """Map raw regression outputs to valid Gaussian parameters."""
        cfg = self.config
        pieces = (
            # Position: bounded to [-scene_scale, scene_scale].
            torch.tanh(raw[..., 0:3]) * cfg.scene_scale,
            # Scale: hard-clamped into the configured range.
            torch.clamp(
                raw[..., 3:6],
                cfg.min_gaussian_scale,
                cfg.max_gaussian_scale,
            ),
            # Rotation: unit-length 4-vector.
            F.normalize(raw[..., 6:10], dim=-1),
            # Opacity and colour: squashed into (0, 1).
            torch.sigmoid(raw[..., 10:11]),
            torch.sigmoid(raw[..., 11:14]),
        )
        return torch.cat(pieces, dim=-1)
|
|
|
|
class DetailHead(nn.Module):
    """
    Predicts the fine "detail" Gaussians that hang off each anchor.

    Every anchor row is expanded into ``config.details_per_anchor`` query
    vectors. The queries cross-attend to the backbone features, and two small
    MLPs read the attended result: one scores the VQ codebook entries, the
    other regresses a bounded 3-D position offset.
    """

    def __init__(self, backbone_dim: int, config: GaussianConfig):
        super().__init__()
        self.config = config
        self.n_details = config.details_per_anchor
        self.backbone_dim = backbone_dim

        dim = backbone_dim
        hidden = config.detail_hidden

        # Anchor parameters -> n_details query vectors of width backbone_dim.
        self.query_expand = nn.Sequential(
            nn.Linear(config.full_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, self.n_details * dim),
        )

        # Queries attend over the backbone sequence (batch-first layout).
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=config.cross_attn_heads,
            batch_first=True,
            dropout=0.0,
        )
        self.cross_norm = nn.LayerNorm(dim)

        # Classifier over codebook entries.
        self.vq_proj = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Linear(dim, config.codebook_size),
        )

        # Regressor for the per-detail xyz offset.
        self.offset_proj = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, 3),
        )

    def forward(self, anchors: torch.Tensor, backbone_features: torch.Tensor) -> dict:
        """
        Args:
            anchors: (B, K, full_dim) anchor Gaussians.
            backbone_features: (B, S, backbone_dim) backbone hidden states.

        Returns:
            dict with
                'vq_logits':       (B, K*M, codebook_size)
                'vq_indices':      (B, K*M) argmax of the logits
                'pos_offsets':     (B, K*M, 3), each value in [-0.5, 0.5]
                'detail_features': (B, K*M, backbone_dim)
            where M is ``details_per_anchor`` and detail j of anchor k sits
            at row ``k * M + j``.
        """
        batch, n_anchors, _ = anchors.shape
        n_queries = n_anchors * self.n_details

        # (B, K, M * D) -> (B, K * M, D)
        seeds = self.query_expand(anchors).view(batch, n_queries, self.backbone_dim)

        # Residual cross-attention block; the attention weights are discarded.
        attended = self.cross_attn(seeds, backbone_features, backbone_features)[0]
        feats = self.cross_norm(attended + seeds)

        logits = self.vq_proj(feats)
        codes = logits.argmax(dim=-1)

        # tanh bounds the offset; 0.5 caps its magnitude per axis.
        offsets = torch.tanh(self.offset_proj(feats)) * 0.5

        return {
            'vq_logits': logits,
            'vq_indices': codes,
            'pos_offsets': offsets,
            'detail_features': feats,
        }
|
|
|
|
class GaussianSpecialist(nn.Module):
    """
    Complete Gaussian-splat specialist head.

    Per the original author notes this is meant to be registered as a
    specialist in WYRM's NexusRouter (the router is not part of this file).

    Two-stage generation:
        1. ``AnchorHead`` regresses K coarse Gaussians from a pooled vector.
        2. ``DetailHead`` predicts K*M VQ codes + position offsets, which are
           decoded into full Gaussians when a VQ-VAE was supplied.

    Depth profile: two learned gate vectors (one per stage) are softmaxed
    over the backbone layers and used to blend the per-layer hidden states,
    so each stage can favour different depths.
    """

    def __init__(self, backbone_dim: int, num_backbone_layers: int,
                 config: GaussianConfig, vqvae: GaussianVQVAE = None):
        """
        Args:
            backbone_dim: Width D of the backbone hidden states.
            num_backbone_layers: Number of layer outputs ``forward`` receives.
            config: Gaussian head hyper-parameters.
            vqvae: Optional trained VQ-VAE. When given, its decoder is
                attached and frozen IN PLACE (the caller's object has its
                parameters' ``requires_grad`` set to False) and its codebook
                is copied into a buffer. When None, ``forward`` returns
                codes and offsets only.
        """
        super().__init__()
        self.config = config
        self.backbone_dim = backbone_dim
        self.num_layers = num_backbone_layers

        self.anchor_head = AnchorHead(backbone_dim, config)
        self.detail_head = DetailHead(backbone_dim, config)

        if vqvae is not None:
            self.vq_decoder = vqvae.decoder
            for p in self.vq_decoder.parameters():
                p.requires_grad = False
            # detach() before clone(): if `embed` is an nn.Parameter, a plain
            # clone() would register a non-leaf buffer that stays attached to
            # the VQ-VAE's autograd graph (gradients would leak into the
            # codebook and copy.deepcopy of this module would fail).
            # NOTE(review): indexing in _decode_details assumes `embed` is
            # laid out as (num_codes, code_dim) - confirm against GaussianVQVAE.
            self.register_buffer(
                'vq_codebook', vqvae.quantizer.embed.detach().clone()
            )
        else:
            self.vq_decoder = None
            self.vq_codebook = None

        # Zero logits -> uniform softmax: every layer starts equally weighted.
        self.anchor_layer_gate = nn.Parameter(torch.zeros(num_backbone_layers))
        self.detail_layer_gate = nn.Parameter(torch.zeros(num_backbone_layers))

        # Applied to the sequence-mean of the blended anchor features.
        self.pool_proj = nn.Linear(backbone_dim, backbone_dim)

    @staticmethod
    def _mix_layers(gate: torch.Tensor, layer_outputs: list) -> torch.Tensor:
        """Blend per-layer tensors using softmax(gate) as convex weights."""
        weights = torch.softmax(gate, dim=0)
        return sum(w * h for w, h in zip(weights, layer_outputs))

    def forward(self, layer_outputs: list[torch.Tensor]) -> dict:
        """
        Args:
            layer_outputs: List of (B, S, D) tensors, exactly one per
                backbone layer (``num_backbone_layers`` entries).

        Returns:
            dict with:
                anchors:       (B, K, 14) anchor Gaussians
                vq_logits:     (B, K*M, codebook_size) for the training loss
                vq_indices:    (B, K*M) argmax codes
                pos_offsets:   (B, K*M, 3) detail position offsets
            and, only when a VQ decoder and codebook are attached:
                details:       (B, K*M, 14) decoded detail Gaussians
                all_gaussians: (B, K + K*M, 14) anchors followed by details

        Raises:
            ValueError: if ``layer_outputs`` is empty or its length differs
                from ``num_backbone_layers``. Without this check ``zip``
                would silently truncate, so the gate weights would no longer
                sum to one or trailing layers would be ignored.
        """
        n_given = len(layer_outputs)
        if n_given == 0 or n_given != self.num_layers:
            raise ValueError(
                f"expected {self.num_layers} layer outputs "
                f"(one per backbone layer), got {n_given}"
            )

        # Stage 1: blend layers, mean-pool over the sequence, regress anchors.
        anchor_hidden = self._mix_layers(self.anchor_layer_gate, layer_outputs)
        pooled = self.pool_proj(anchor_hidden.mean(dim=1))
        anchors = self.anchor_head(pooled)

        # Stage 2: an independently gated blend feeds the detail head.
        detail_hidden = self._mix_layers(self.detail_layer_gate, layer_outputs)
        detail_out = self.detail_head(anchors, detail_hidden)

        result = {
            'anchors': anchors,
            'vq_logits': detail_out['vq_logits'],
            'vq_indices': detail_out['vq_indices'],
            'pos_offsets': detail_out['pos_offsets'],
        }

        # Full decoding is only possible when a VQ-VAE was supplied.
        if self.vq_decoder is not None and self.vq_codebook is not None:
            details = self._decode_details(
                anchors, detail_out['vq_indices'], detail_out['pos_offsets']
            )
            result['details'] = details
            result['all_gaussians'] = torch.cat([anchors, details], dim=1)

        return result

    def _decode_details(self, anchors: torch.Tensor, vq_indices: torch.Tensor,
                        pos_offsets: torch.Tensor) -> torch.Tensor:
        """
        Turn VQ codes + offsets into fully parameterised detail Gaussians.

        Args:
            anchors: (B, K, 14)
            vq_indices: (B, K*M) integer codebook indices
            pos_offsets: (B, K*M, 3)

        Returns:
            (B, K*M, 14): position(3) + scale(3) + rotation(4) + opacity(1)
            + colour(3). Position is the owning anchor's position plus the
            offset; the remaining 11 values come from the VQ decoder.
        """
        B, KM = vq_indices.shape
        K = self.config.num_anchors
        M = self.config.details_per_anchor

        # Look up code vectors (row-indexed codebook) and decode them.
        z_q = self.vq_codebook[vq_indices.view(-1)]
        params = self.vq_decoder(z_q)
        params = params.view(B, KM, self.config.param_dim)

        # Same activations as the anchor head; the decoder output carries no
        # position, so every slice sits 3 columns earlier than in an anchor.
        scale = params[..., :3].clamp(self.config.min_gaussian_scale,
                                      self.config.max_gaussian_scale)
        rot = F.normalize(params[..., 3:7], dim=-1)
        opacity = torch.sigmoid(params[..., 7:8])
        color = torch.sigmoid(params[..., 8:11])

        # Repeat each anchor position M times so row k*M + j pairs detail j
        # with anchor k, matching the query ordering used by DetailHead.
        anchor_pos = anchors[:, :, :3]
        anchor_pos_expanded = (
            anchor_pos.unsqueeze(2).expand(B, K, M, 3).reshape(B, KM, 3)
        )
        world_pos = anchor_pos_expanded + pos_offsets

        return torch.cat([world_pos, scale, rot, opacity, color], dim=-1)

    def get_depth_profile(self) -> dict:
        """Return the softmaxed layer-gate weights and their peak layers."""
        with torch.no_grad():
            anchor_w = torch.softmax(self.anchor_layer_gate, dim=0)
            detail_w = torch.softmax(self.detail_layer_gate, dim=0)
        return {
            'anchor_weights': anchor_w.cpu().tolist(),
            'detail_weights': detail_w.cpu().tolist(),
            'anchor_peak_layer': anchor_w.argmax().item(),
            'detail_peak_layer': detail_w.argmax().item(),
        }

    def count_params(self) -> dict:
        """Count trainable parameters per component (frozen ones excluded)."""
        def count(module):
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        return {
            'anchor_head': count(self.anchor_head),
            'detail_head': count(self.detail_head),
            'pool_proj': count(self.pool_proj),
            'layer_gates': self.anchor_layer_gate.numel() + self.detail_layer_gate.numel(),
            'total': sum(p.numel() for p in self.parameters() if p.requires_grad),
        }
|
|