| """ |
| GLADIUS β Gaussian Head Configuration |
| |
| All hyperparameters for the Gaussian specialist in one place. |
| """ |
|
|
| from dataclasses import dataclass |
|
|
|
|
@dataclass
class GaussianConfig:
    """Configuration for the Gaussian specialist head.

    Groups every hyperparameter of the Gaussian specialist: anchor/detail
    layout, per-Gaussian parameter widths, VQ codebook settings, VQ-VAE and
    head widths, rendering/loss weights, learning rates and step counts, and
    scene/scale bounds. ``estimate_new_params`` turns these values into a
    rough per-component parameter-count breakdown.
    """

    # --- Anchor / detail layout ---
    num_anchors: int = 64
    details_per_anchor: int = 32
    # The defaults satisfy max_gaussians == num_anchors * details_per_anchor
    # (64 * 32 == 2048). The two are NOT kept in sync automatically; see the
    # ``total_gaussians`` property for the derived value.
    max_gaussians: int = 2048

    # --- Per-Gaussian parameter widths ---
    # NOTE(review): 11 presumably = scale(3) + rotation quaternion(4) +
    # color(3) + opacity(1) -- confirm against the head implementation.
    param_dim: int = 11
    pos_dim: int = 3
    # The defaults satisfy full_dim == pos_dim + param_dim (3 + 11); this is
    # not enforced if either side is overridden.
    full_dim: int = 14

    # --- Vector-quantization codebook ---
    codebook_size: int = 4096
    codebook_dim: int = 64
    commitment_weight: float = 0.25
    ema_decay: float = 0.99
    codebook_reset_threshold: int = 2

    # --- VQ-VAE ---
    vqvae_hidden: int = 256
    vqvae_layers: int = 3

    # --- Head widths ---
    cross_attn_heads: int = 8
    anchor_hidden: int = 256
    detail_hidden: int = 256

    # --- Rendering and loss weights ---
    render_size: int = 64
    render_views: int = 4
    ssim_weight: float = 0.2
    l1_weight: float = 0.8

    # --- Learning rates and step counts ---
    vqvae_lr: float = 3e-4
    head_lr: float = 1e-4
    backbone_lr: float = 1e-5
    vqvae_steps: int = 50_000
    head_steps: int = 100_000
    joint_steps: int = 50_000

    # --- Scene / scale bounds ---
    scene_scale: float = 2.0
    # NOTE(review): a negative minimum with a 0.0 maximum suggests these are
    # log-space scale bounds -- confirm where they are applied.
    min_gaussian_scale: float = -6.0
    max_gaussian_scale: float = 0.0

    @property
    def total_gaussians(self) -> int:
        """Return the Gaussian count implied by the layout (anchors x details)."""
        return self.num_anchors * self.details_per_anchor

    def estimate_new_params(
        self, backbone_dim: int, *, layer_gates: int = 48
    ) -> dict:
        """Estimate the parameter count of the Gaussian specialist.

        Only weight-matrix products are counted; biases and normalization
        parameters are not included, so this is a lower-bound style estimate.

        Args:
            backbone_dim: Width of the backbone features that feed the head.
                Must be a positive integer.
            layer_gates: Number of scalar gate parameters added to the
                trainable total. Defaults to 48, the value this method
                previously hard-coded.

        Returns:
            Dict with the per-component counts ``'anchor_mlp'``,
            ``'cross_attention'``, ``'vq_logits'``, ``'pos_offset'`` and
            ``'layer_gates'``; their sum ``'total_trainable'``; the VQ-VAE size
            ``'vqvae_total'``; and ``'total_all'`` (trainable + VQ-VAE).

        Raises:
            ValueError: If ``backbone_dim`` < 1 or ``layer_gates`` < 0.
        """
        if backbone_dim < 1:
            raise ValueError(
                f"backbone_dim must be a positive integer, got {backbone_dim!r}"
            )
        if layer_gates < 0:
            raise ValueError(
                f"layer_gates must be non-negative, got {layer_gates!r}"
            )

        # Two weight matrices: backbone features -> anchor_hidden, then
        # anchor_hidden -> one (position + parameters) vector per anchor.
        per_gaussian = self.pos_dim + self.param_dim
        anchor_mlp = (backbone_dim * self.anchor_hidden
                      + self.anchor_hidden * self.num_anchors * per_gaussian)

        # Four backbone_dim x backbone_dim projections (presumably Q, K, V and
        # output); the head count does not change the parameter count.
        cross_attn = 4 * backbone_dim * backbone_dim

        # Projection from backbone features to logits over the codebook.
        vq_logits = backbone_dim * self.codebook_size

        # Projection from backbone features to a pos_dim offset.
        pos_offset = backbone_dim * self.pos_dim

        # VQ-VAE: encoder param_dim -> vqvae_hidden -> codebook_dim, decoder
        # mirrors it, plus the codebook table itself.
        # NOTE(review): only two weight matrices per side are counted,
        # regardless of vqvae_layers -- confirm against the VQ-VAE module.
        vqvae_encoder = (self.param_dim * self.vqvae_hidden
                         + self.vqvae_hidden * self.codebook_dim)
        vqvae_decoder = (self.codebook_dim * self.vqvae_hidden
                         + self.vqvae_hidden * self.param_dim)
        vqvae_codebook = self.codebook_size * self.codebook_dim

        total_trainable = (anchor_mlp + cross_attn + vq_logits
                           + pos_offset + layer_gates)
        total_vqvae = vqvae_encoder + vqvae_decoder + vqvae_codebook

        return {
            'anchor_mlp': anchor_mlp,
            'cross_attention': cross_attn,
            'vq_logits': vq_logits,
            'pos_offset': pos_offset,
            'layer_gates': layer_gates,
            'total_trainable': total_trainable,
            'vqvae_total': total_vqvae,
            'total_all': total_trainable + total_vqvae,
        }
|
|