# Source: WYRM kernel source (v27 FINAL), uploaded by amuzetnoM (revision 9463e5c, verified).
"""
GLADIUS — Gaussian Specialist Head
The specialist module that plugs into WYRM's NexusRouter.
Generates 3D Gaussian Splat parameters from backbone hidden states.
Two-stage hierarchical generation:
Stage 1 (Anchors): Direct regression of coarse Gaussians from pooled hidden state
Stage 2 (Details): VQ-coded fine Gaussians via cross-attention to backbone features
Depth profile integration: learned per-layer gates determine which backbone
layers contribute to structure (anchors) vs detail.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from .config import GaussianConfig
from .vqvae import GaussianVQVAE, GaussianVQDecoder
class AnchorHead(nn.Module):
    """
    Generates K coarse anchor Gaussians from the backbone's pooled representation.

    Each anchor: position(3) + scale(3) + rotation(4) + opacity(1) + sh_dc(3) = 14 floats.
    Uses direct regression — anchors need precise continuous placement.
    """

    def __init__(self, backbone_dim: int, config: GaussianConfig):
        """
        Args:
            backbone_dim: size of the last dimension of the pooled input.
            config: read for num_anchors, full_dim, anchor_hidden, scene_scale,
                min_gaussian_scale and max_gaussian_scale.
        """
        super().__init__()
        self.config = config
        self.n_anchors = config.num_anchors
        # NOTE: _activate slices channels 0..13, so it assumes full_dim == 14.
        self.out_dim = config.full_dim  # 14

        self.net = nn.Sequential(
            nn.Linear(backbone_dim, config.anchor_hidden),
            nn.LayerNorm(config.anchor_hidden),
            nn.GELU(),
            nn.Linear(config.anchor_hidden, config.anchor_hidden),
            nn.LayerNorm(config.anchor_hidden),
            nn.GELU(),
            nn.Linear(config.anchor_hidden, self.n_anchors * self.out_dim),
        )
        self._init_weights()

    def _init_weights(self):
        """Give the output layer small weights and a zero bias."""
        # Initialize last layer small — anchors start near origin
        nn.init.normal_(self.net[-1].weight, std=0.01)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, pooled: torch.Tensor) -> torch.Tensor:
        """
        Args:
            pooled: (..., backbone_dim) — pooled backbone output, typically
                (B, backbone_dim). Any number of leading dims is accepted.
        Returns:
            anchors: (..., K, 14) — K anchor Gaussians with full params
        """
        raw = self.net(pooled)  # (..., K * out_dim)
        # Split the flat output per anchor while keeping every leading dim, so
        # unbatched and multi-dim batched inputs work as well as (B, D).
        raw = raw.view(*pooled.shape[:-1], self.n_anchors, self.out_dim)
        # Activate each component appropriately
        return self._activate(raw)

    def _activate(self, raw: torch.Tensor) -> torch.Tensor:
        """Apply per-component activations to raw output.

        Channel layout along the last dim:
        0:3 position, 3:6 scale, 6:10 rotation, 10:11 opacity, 11:14 color.
        """
        cfg = self.config
        pos = torch.tanh(raw[..., :3]) * cfg.scene_scale  # [-scale, +scale]
        # Hard clamp into the configured range (original note: log-scale clamped)
        scale = raw[..., 3:6].clamp(cfg.min_gaussian_scale, cfg.max_gaussian_scale)
        rot = F.normalize(raw[..., 6:10], dim=-1)  # Unit quaternion
        opacity = torch.sigmoid(raw[..., 10:11])  # [0, 1]
        color = torch.sigmoid(raw[..., 11:14])  # [0, 1] RGB
        return torch.cat([pos, scale, rot, opacity, color], dim=-1)
class DetailHead(nn.Module):
    """
    Generates M fine detail Gaussians per anchor using VQ tokens.

    Cross-attends from anchor queries to backbone hidden states,
    then predicts VQ codebook indices + continuous position offsets.
    """

    def __init__(self, backbone_dim: int, config: GaussianConfig,
                 offset_scale: float = 0.5):
        """
        Args:
            backbone_dim: feature size of the backbone hidden states; also the
                embedding size of the cross-attention.
            config: read for details_per_anchor, full_dim, detail_hidden,
                cross_attn_heads and codebook_size.
            offset_scale: bound on each predicted offset component; offsets are
                tanh-squashed into [-offset_scale, +offset_scale]. Defaults to
                0.5, the value that was previously hard-coded.
        """
        super().__init__()
        self.config = config
        self.n_details = config.details_per_anchor
        self.backbone_dim = backbone_dim
        self.offset_scale = offset_scale

        # Anchor → query expansion: each anchor generates M query vectors
        self.query_expand = nn.Sequential(
            nn.Linear(config.full_dim, config.detail_hidden),
            nn.GELU(),
            nn.Linear(config.detail_hidden, self.n_details * backbone_dim),
        )

        # Cross-attention: detail queries attend to backbone features
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=backbone_dim,
            num_heads=config.cross_attn_heads,
            batch_first=True,
            dropout=0.0,
        )
        self.cross_norm = nn.LayerNorm(backbone_dim)

        # VQ index prediction
        self.vq_proj = nn.Sequential(
            nn.Linear(backbone_dim, backbone_dim),
            nn.GELU(),
            nn.Linear(backbone_dim, config.codebook_size),
        )

        # Continuous position offset (relative to anchor)
        self.offset_proj = nn.Sequential(
            nn.Linear(backbone_dim, config.detail_hidden),
            nn.GELU(),
            nn.Linear(config.detail_hidden, 3),
        )

    def forward(self, anchors: torch.Tensor, backbone_features: torch.Tensor) -> dict:
        """
        Args:
            anchors: (B, K, 14) — anchor Gaussians
            backbone_features: (B, S, backbone_dim) — backbone hidden states
        Returns:
            dict with vq_logits (B, K*M, codebook_size), vq_indices (B, K*M),
            pos_offsets (B, K*M, 3) and detail_features (B, K*M, backbone_dim)
        """
        B, K, _ = anchors.shape

        # Expand each anchor into M query vectors
        queries = self.query_expand(anchors)  # (B, K, M * D)
        queries = queries.view(B, K * self.n_details, self.backbone_dim)  # (B, K*M, D)

        # Cross-attend to backbone. The attention weights are never used, so
        # skip computing them.
        detail_features, _ = self.cross_attn(
            queries, backbone_features, backbone_features, need_weights=False
        )  # (B, K*M, D)
        detail_features = self.cross_norm(detail_features + queries)  # Residual

        # Predict VQ indices (argmax is non-differentiable; train on vq_logits)
        vq_logits = self.vq_proj(detail_features)  # (B, K*M, codebook_size)
        vq_indices = vq_logits.argmax(dim=-1)  # (B, K*M)

        # Predict position offsets
        pos_offsets = self.offset_proj(detail_features)  # (B, K*M, 3)
        # Scale offsets — details should be close to their anchor
        pos_offsets = torch.tanh(pos_offsets) * self.offset_scale

        return {
            'vq_logits': vq_logits,
            'vq_indices': vq_indices,
            'pos_offsets': pos_offsets,
            'detail_features': detail_features,
        }
class GaussianSpecialist(nn.Module):
    """
    The complete Gaussian specialist head for WYRM.

    Plugs into NexusRouter as specialist index N.
    Generates 3D Gaussian splat scenes from backbone hidden states.

    Two-stage generation:
        1. Anchor Head → K coarse Gaussians (direct regression)
        2. Detail Head → K*M fine Gaussians (VQ-coded)

    Depth profile integration: learned per-layer gates select which
    backbone layers contribute to structure vs detail.
    """

    def __init__(self, backbone_dim: int, num_backbone_layers: int,
                 config: GaussianConfig, vqvae: GaussianVQVAE = None):
        """
        Args:
            backbone_dim: feature size of the backbone hidden states.
            num_backbone_layers: number of per-layer outputs that forward()
                will receive; sizes the two layer gates.
            config: passed to both heads; also read in _decode_details.
            vqvae: optional VQ-VAE exposing `.decoder` and `.quantizer.embed`.
                When omitted, forward() returns no decoded details.
        """
        super().__init__()
        self.config = config
        self.backbone_dim = backbone_dim
        self.num_layers = num_backbone_layers

        # ── Sub-modules ──
        self.anchor_head = AnchorHead(backbone_dim, config)
        self.detail_head = DetailHead(backbone_dim, config)

        # ── VQ-VAE decoder (frozen — pre-trained in Phase 1) ──
        self._attach_vqvae(vqvae)

        # ── Depth Profile Gates ──
        # Learned: which backbone layers matter for anchors vs details.
        # Zero init makes the softmax weights uniform at the start.
        self.anchor_layer_gate = nn.Parameter(torch.zeros(num_backbone_layers))
        self.detail_layer_gate = nn.Parameter(torch.zeros(num_backbone_layers))

        # ── Projection for pooling ──
        self.pool_proj = nn.Linear(backbone_dim, backbone_dim)

    def _attach_vqvae(self, vqvae) -> None:
        """Register the frozen VQ decoder and a detached copy of the codebook."""
        if vqvae is None:
            self.vq_decoder = None
            # Keep the attribute a buffer slot in both branches.
            self.register_buffer('vq_codebook', None)
            return

        self.vq_decoder = vqvae.decoder
        # Freeze VQ decoder.
        # NOTE(review): this only disables gradients; as a submodule the decoder
        # still follows this module's train()/eval() mode.
        for p in self.vq_decoder.parameters():
            p.requires_grad = False
        # Also keep the quantizer's codebook for decoding. detach() first so the
        # buffer is a plain leaf tensor even if `embed` is an nn.Parameter.
        self.register_buffer('vq_codebook', vqvae.quantizer.embed.detach().clone())

    @staticmethod
    def _mix_layers(gate: torch.Tensor, layer_outputs) -> torch.Tensor:
        """Softmax-weighted sum of the per-layer hidden states."""
        weights = torch.softmax(gate, dim=0)
        return sum(w * h for w, h in zip(weights, layer_outputs))

    def forward(self, layer_outputs: list[torch.Tensor]) -> dict:
        """
        Args:
            layer_outputs: List of (B, S, D) tensors from each backbone layer.
                Must contain exactly `num_backbone_layers` entries.
        Returns:
            dict with:
                anchors: (B, K, 14) — anchor Gaussians
                details: (B, K*M, 14) — detail Gaussians (if VQ decoder available)
                all_gaussians: (B, K + K*M, 14) — concatenated scene
                    (if VQ decoder available)
                vq_logits: (B, K*M, codebook_size) — for training loss
                vq_indices: (B, K*M) — argmax of vq_logits
                pos_offsets: (B, K*M, 3) — detail position offsets
        Raises:
            ValueError: if the number of layer outputs does not match the gates.
        """
        n_given = len(layer_outputs)
        if n_given == 0 or n_given != self.num_layers:
            # zip() would silently drop the surplus gates or layers and leave
            # mixing weights that no longer sum to 1.
            raise ValueError(
                f"expected {self.num_layers} layer outputs, got {n_given}"
            )

        # ── Depth-profiled aggregation for anchors ──
        anchor_hidden = self._mix_layers(self.anchor_layer_gate, layer_outputs)  # (B, S, D)
        # Pool → single vector per batch
        pooled = self.pool_proj(anchor_hidden.mean(dim=1))  # (B, D)

        # ── Stage 1: Generate anchors ──
        anchors = self.anchor_head(pooled)  # (B, K, 14)

        # ── Depth-profiled aggregation for details ──
        detail_hidden = self._mix_layers(self.detail_layer_gate, layer_outputs)  # (B, S, D)

        # ── Stage 2: Generate details ──
        detail_out = self.detail_head(anchors, detail_hidden)

        result = {
            'anchors': anchors,
            'vq_logits': detail_out['vq_logits'],
            'vq_indices': detail_out['vq_indices'],
            'pos_offsets': detail_out['pos_offsets'],
        }

        # ── Decode VQ indices to full Gaussian params (if decoder available) ──
        if self.vq_decoder is not None and self.vq_codebook is not None:
            details = self._decode_details(
                anchors, detail_out['vq_indices'], detail_out['pos_offsets']
            )
            result['details'] = details
            result['all_gaussians'] = torch.cat([anchors, details], dim=1)

        return result

    def _decode_details(self, anchors: torch.Tensor, vq_indices: torch.Tensor,
                        pos_offsets: torch.Tensor) -> torch.Tensor:
        """
        Decode VQ indices + anchor positions → full detail Gaussians.

        Args:
            anchors: (B, K, 14)
            vq_indices: (B, K*M)
            pos_offsets: (B, K*M, 3)
        Returns:
            details: (B, K*M, 14) — fully parameterized detail Gaussians
        Raises:
            ValueError: if the detail count is not a multiple of the anchor count.
        """
        B, KM = vq_indices.shape
        # Take K from the tensor actually passed in, not from config, so the
        # anchor/detail pairing always matches the inputs.
        K = anchors.shape[1]
        if K == 0 or KM % K != 0:
            raise ValueError(
                f"{KM} detail Gaussians cannot be split evenly over {K} anchors"
            )
        M = KM // K
        cfg = self.config

        # Decode VQ → (scale, rot, opacity, color)
        z_q = self.vq_codebook[vq_indices.reshape(-1)]  # (B*K*M, codebook_dim)
        params = self.vq_decoder(z_q)  # (B*K*M, param_dim=11)
        params = params.view(B, KM, cfg.param_dim)

        # Extract components
        scale = params[..., :3].clamp(cfg.min_gaussian_scale, cfg.max_gaussian_scale)
        rot = F.normalize(params[..., 3:7], dim=-1)
        opacity = torch.sigmoid(params[..., 7:8])
        color = torch.sigmoid(params[..., 8:11])

        # Compute world-space positions: anchor_pos + offset
        anchor_pos = anchors[:, :, :3]  # (B, K, 3)
        # Repeat each anchor M times: details [k*M, (k+1)*M) belong to anchor k
        anchor_pos_expanded = anchor_pos.unsqueeze(2).expand(B, K, M, 3).reshape(B, KM, 3)
        world_pos = anchor_pos_expanded + pos_offsets

        return torch.cat([world_pos, scale, rot, opacity, color], dim=-1)

    def get_depth_profile(self) -> dict:
        """Return the learned depth profile weights for analysis."""
        with torch.no_grad():
            anchor_w = torch.softmax(self.anchor_layer_gate, dim=0)
            detail_w = torch.softmax(self.detail_layer_gate, dim=0)
        return {
            'anchor_weights': anchor_w.cpu().tolist(),
            'detail_weights': detail_w.cpu().tolist(),
            'anchor_peak_layer': anchor_w.argmax().item(),
            'detail_peak_layer': detail_w.argmax().item(),
        }

    def count_params(self) -> dict:
        """Count trainable parameters by component."""
        def count(module):
            return sum(p.numel() for p in module.parameters() if p.requires_grad)

        # Apply the same requires_grad filter to the gates as everywhere else.
        gates = (self.anchor_layer_gate, self.detail_layer_gate)
        return {
            'anchor_head': count(self.anchor_head),
            'detail_head': count(self.detail_head),
            'pool_proj': count(self.pool_proj),
            'layer_gates': sum(g.numel() for g in gates if g.requires_grad),
            'total': sum(p.numel() for p in self.parameters() if p.requires_grad),
        }