"""
GLADIUS — Gaussian Head Data

Procedural Gaussian scene generation for training the specialist head.
No external datasets are needed: ground-truth Gaussians are generated from
geometric primitives (spheres, cubes, multi-object scenes).

Also includes a simple differentiable renderer for computing training loss.
It projects 3D Gaussians to the image plane and alpha-composites them.

Every Gaussian in this module is a 14-vector laid out as
    [pos(3), log_scale(3), rot(4), opacity_logit(1), color(3)]
i.e. scales are natural logs, opacity is pre-sigmoid, and colors are stored
directly in [0, 1].
"""

import math
import random
from typing import Optional

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


# Named colors available to procedural scenes (RGB in [0, 1]).
_COLOR_PALETTE = {
    'red': torch.tensor([0.9, 0.15, 0.1]),
    'green': torch.tensor([0.1, 0.85, 0.2]),
    'blue': torch.tensor([0.1, 0.2, 0.9]),
    'yellow': torch.tensor([0.95, 0.9, 0.1]),
    'purple': torch.tensor([0.7, 0.1, 0.8]),
    'orange': torch.tensor([0.95, 0.5, 0.05]),
    'white': torch.tensor([0.9, 0.9, 0.9]),
    'cyan': torch.tensor([0.1, 0.9, 0.9]),
}

# Primitive types generate_scene draws from (order matters for seeded runs).
_SHAPE_TYPES = ('sphere', 'cube')

# Splat responses with exponent below this (exp(-8) ~= 3.4e-4) count as zero.
_GAUSS_CUTOFF = -8.0


def fibonacci_sphere(n: int, device: str = 'cpu') -> torch.Tensor:
    """
    Generate n points spread near-uniformly on the unit sphere.

    Fibonacci / golden-spiral method: heights are evenly spaced in z and the
    azimuth advances by 2*pi/phi per point (phi = golden ratio).

    Returns:
        (n, 3) — unit vectors on the sphere surface
    """
    golden_ratio = (1 + math.sqrt(5)) / 2
    indices = torch.arange(n, dtype=torch.float32, device=device)
    theta = 2 * math.pi * indices / golden_ratio
    phi = torch.acos(1 - 2 * (indices + 0.5) / n)

    x = torch.sin(phi) * torch.cos(theta)
    y = torch.sin(phi) * torch.sin(theta)
    z = torch.cos(phi)
    return torch.stack([x, y, z], dim=-1)


def random_quaternion(n: int, device: str = 'cpu') -> torch.Tensor:
    """Generate n random unit quaternions, uniform on SO(3) (Shoemake's method)."""
    u = torch.rand(n, 3, device=device)
    q = torch.stack([
        torch.sqrt(1 - u[:, 0]) * torch.sin(2 * math.pi * u[:, 1]),
        torch.sqrt(1 - u[:, 0]) * torch.cos(2 * math.pi * u[:, 1]),
        torch.sqrt(u[:, 0]) * torch.sin(2 * math.pi * u[:, 2]),
        torch.sqrt(u[:, 0]) * torch.cos(2 * math.pi * u[:, 2]),
    ], dim=-1)
    return q


def quaternion_from_direction(normals: torch.Tensor) -> torch.Tensor:
    """
    Compute quaternions that rotate the +Z axis onto the given directions.

    Uses the half-angle construction q = [z x n, 1 + z.n] (vector part first,
    scalar last) followed by normalization. When n is (nearly) antiparallel to
    +Z that construction collapses to ~0, so a 180-degree rotation about +X is
    used instead.

    Args:
        normals: (N, 3) — unit vectors
    Returns:
        (N, 4) — unit quaternions ordered [x, y, z, w]
    """
    # new_tensor keeps the caller's dtype/device (float64 inputs work too).
    z_axis = normals.new_tensor([0.0, 0.0, 1.0]).expand_as(normals)
    cross = torch.cross(z_axis, normals, dim=-1)
    dot = (z_axis * normals).sum(dim=-1, keepdim=True)
    q = torch.cat([cross, 1.0 + dot], dim=-1)

    anti = dot.squeeze(-1) < -0.999
    if anti.any():
        # Any axis perpendicular to Z works for the half-turn; use +X.
        q[anti] = normals.new_tensor([1.0, 0.0, 0.0, 0.0])
    return F.normalize(q, dim=-1)


def generate_sphere_gaussians(center: torch.Tensor, radius: float,
                              color: torch.Tensor, n_gaussians: int,
                              device: str = 'cpu') -> torch.Tensor:
    """
    Generate flat Gaussian splats covering a sphere surface.

    Args:
        center: (3,) — sphere center
        radius: sphere radius
        color: (3,) — base RGB in [0, 1]; each splat gets small random jitter
        n_gaussians: number of splats (<= 0 yields an empty result)
        device: device for the output tensor
    Returns:
        (n_gaussians, 14) — [pos(3), log_scale(3), rot(4), opacity(1), color(3)]
    """
    if n_gaussians <= 0:
        return torch.zeros(0, 14, device=device)

    normals = fibonacci_sphere(n_gaussians, device)
    positions = center.unsqueeze(0) + radius * normals

    # Half the side of the square patch (area 4*pi*r^2 / n) each point covers.
    splat_radius = radius * math.sqrt(4 * math.pi / n_gaussians) * 0.5
    log_scale = math.log(max(splat_radius, 1e-6))
    scales = torch.full((n_gaussians, 3), log_scale, device=device)
    scales[:, 2] = log_scale - 1.5  # thin along local Z, which is aligned to the normal

    rotations = quaternion_from_direction(normals)
    opacity = torch.full((n_gaussians, 1), 2.0, device=device)  # sigmoid(2) ~= 0.88

    colors = color.unsqueeze(0).expand(n_gaussians, 3).clone()
    colors += torch.randn_like(colors) * 0.02  # slight per-splat variation
    colors = colors.clamp(0, 1)

    return torch.cat([positions, scales, rotations, opacity, colors], dim=-1)


def generate_cube_gaussians(center: torch.Tensor, size: float,
                            color: torch.Tensor, n_gaussians: int,
                            device: str = 'cpu') -> torch.Tensor:
    """
    Generate flat Gaussian splats scattered uniformly over an axis-aligned cube.

    Splats are split evenly across the 6 faces (+X, -X, +Y, -Y, +Z, -Z); the
    first ``n_gaussians % 6`` faces get one extra.

    Args:
        center: (3,) — cube center
        size: cube edge length
        color: (3,) — base RGB in [0, 1]; each splat gets small random jitter
        n_gaussians: total number of splats (<= 0 yields an empty result)
        device: device for the output tensor
    Returns:
        (n_gaussians, 14) — [pos(3), log_scale(3), rot(4), opacity(1), color(3)]
    """
    if n_gaussians <= 0:
        return torch.zeros(0, 14, device=device)

    per_face = n_gaussians // 6
    remainder = n_gaussians - per_face * 6
    half = size / 2

    # Footprint from the nominal per-face count; clamped to 1 so cubes with
    # fewer than 6 splats do not divide by zero.
    splat_radius = size / math.sqrt(max(per_face, 1)) * 0.6
    log_scale = math.log(max(splat_radius, 1e-6))

    face_normals = torch.tensor([
        [1, 0, 0], [-1, 0, 0],
        [0, 1, 0], [0, -1, 0],
        [0, 0, 1], [0, 0, -1],
    ], dtype=torch.float32, device=device)

    faces = []
    for i, normal in enumerate(face_normals):
        n = per_face + (1 if i < remainder else 0)
        if n == 0:
            continue

        # Uniform samples in [-half, half)^2 on the face plane.
        uv = torch.rand(n, 2, device=device) * size - half

        # The face's normal axis is fixed at the face plane; the other two axes
        # (in increasing order) take the uv offsets around the cube center.
        axis = int(normal.abs().argmax())
        tangent_axes = [a for a in range(3) if a != axis]
        positions = torch.empty(n, 3, device=device)
        positions[:, axis] = center[axis] + normal[axis] * half
        positions[:, tangent_axes] = uv + center[tangent_axes]

        scales = torch.full((n, 3), log_scale, device=device)
        scales[:, 2] = log_scale - 1.5  # thin along the face normal

        rotations = quaternion_from_direction(normal.unsqueeze(0).expand(n, 3))
        opacity = torch.full((n, 1), 2.0, device=device)

        colors = color.unsqueeze(0).expand(n, 3).clone()
        colors += torch.randn_like(colors) * 0.02
        colors = colors.clamp(0, 1)

        faces.append(torch.cat([positions, scales, rotations, opacity, colors], dim=-1))

    return torch.cat(faces, dim=0)


def generate_scene(n_objects: Optional[int] = None,
                   n_gaussians_per_object: int = 128,
                   device: str = 'cpu',
                   scene_scale: float = 2.0) -> dict:
    """
    Generate a random procedural scene of colored spheres and cubes.

    Args:
        n_objects: number of objects; drawn uniformly from 1..5 when None.
        n_gaussians_per_object: splats per primitive.
        device: device for the generated tensors.
        scene_scale: object centers are drawn uniformly from an axis-aligned
            cube of side 1.2 * scene_scale centered at the origin.
    Returns:
        dict with:
            gaussians: (N, 14) — all Gaussians in the scene (N == 0 for no objects)
            description: str — e.g. "a red sphere and a blue cube"
            objects: list of dicts — shape, color name, center (CPU), size
            n_gaussians: int — N
    """
    if n_objects is None:
        n_objects = random.randint(1, 5)

    color_choices = list(_COLOR_PALETTE.items())
    all_gaussians = []
    objects = []
    descriptions = []

    for _ in range(n_objects):
        shape = random.choice(_SHAPE_TYPES)
        color_name, color_val = random.choice(color_choices)
        color_val = color_val.to(device)

        center = (torch.rand(3, device=device) - 0.5) * scene_scale * 1.2
        size = random.uniform(0.3, 0.8)  # sphere radius or cube edge length

        if shape == 'sphere':
            gs = generate_sphere_gaussians(center, size, color_val,
                                           n_gaussians_per_object, device)
        else:
            gs = generate_cube_gaussians(center, size, color_val,
                                         n_gaussians_per_object, device)

        all_gaussians.append(gs)
        objects.append({'shape': shape, 'color': color_name,
                        'center': center.cpu(), 'size': size})
        descriptions.append(f"a {color_name} {shape}")

    if all_gaussians:
        gaussians = torch.cat(all_gaussians, dim=0)
    else:
        gaussians = torch.zeros(0, 14, device=device)

    return {
        'gaussians': gaussians,
        'description': " and ".join(descriptions),
        'objects': objects,
        'n_gaussians': gaussians.shape[0],
    }


class ProceduralGaussianDataset(Dataset):
    """
    On-the-fly procedural Gaussian scene dataset.

    Generates random scenes of spheres and cubes; nothing is stored, so
    ``size`` only sets the length reported by ``len()``. Each item is the dict
    returned by ``generate_scene`` (keys: gaussians, description, objects,
    n_gaussians).

    For VQ-VAE Phase 1: use .get_gaussian_params() to get pooled params.
    For Head Phase 2: use __getitem__ for (description, gaussians) pairs.
    """

    def __init__(self, size: int = 10000, gaussians_per_object: int = 128,
                 max_objects: int = 5, scene_scale: float = 2.0,
                 device: str = 'cpu', fixed_seed: bool = False):
        self.size = size
        self.gaussians_per_object = gaussians_per_object
        self.max_objects = max_objects
        self.scene_scale = scene_scale
        self.device = device
        self.fixed_seed = fixed_seed

    def __len__(self) -> int:
        return self.size

    def _random_scene(self, device: str) -> dict:
        """Build one scene with 1..max_objects objects using this dataset's settings."""
        return generate_scene(
            n_objects=random.randint(1, self.max_objects),
            n_gaussians_per_object=self.gaussians_per_object,
            device=device,
            scene_scale=self.scene_scale,
        )

    def __getitem__(self, idx: int) -> dict:
        """
        Generate one random scene.

        ``idx`` only matters when ``fixed_seed`` is set: the global torch and
        ``random`` generators are then re-seeded with ``idx`` so the same index
        always yields the same scene (deterministic eval). NOTE: this reseeds
        the process-wide generators and leaves them in that state.
        """
        if self.fixed_seed:
            torch.manual_seed(idx)
            random.seed(idx)
        return self._random_scene(self.device)

    def get_gaussian_params(self, n_samples: int = 100000) -> torch.Tensor:
        """
        Get a flat pool of Gaussian non-position params for VQ-VAE training.

        Scenes are generated on the CPU until at least ``n_samples`` Gaussians
        are collected; the pool is then truncated to exactly ``n_samples``.

        Returns:
            (n_samples, 11) — scale(3), rot(4), opacity(1), color(3)
        Raises:
            ValueError: if gaussians_per_object <= 0 (the pool could never fill).
        """
        if n_samples <= 0:
            return torch.zeros(0, 11)
        if self.gaussians_per_object <= 0:
            raise ValueError("gaussians_per_object must be positive to sample params")

        chunks = []
        collected = 0
        while collected < n_samples:
            # Columns [3:14] = scale(3) + rot(4) + opacity(1) + color(3)
            params = self._random_scene('cpu')['gaussians'][:, 3:]
            chunks.append(params)
            collected += params.shape[0]

        return torch.cat(chunks, dim=0)[:n_samples]


# ═══ Differentiable 2D Renderer ═══

def _look_at_rotation(camera_pos: torch.Tensor,
                      camera_target: torch.Tensor) -> torch.Tensor:
    """Return the (3, 3) world->camera rotation with rows [right, up, -forward]."""
    forward = F.normalize(camera_target - camera_pos, dim=0)
    right = torch.cross(forward, forward.new_tensor([0.0, 1.0, 0.0]), dim=0)
    if right.norm() < 1e-6:
        # Looking straight along +/-Y: world up is parallel to forward, so
        # fall back to +Z to build a valid basis.
        right = torch.cross(forward, forward.new_tensor([0.0, 0.0, 1.0]), dim=0)
    right = F.normalize(right, dim=0)
    up = torch.cross(right, forward, dim=0)
    return torch.stack([right, up, -forward], dim=0)


def render_gaussians_2d(gaussians: torch.Tensor,
                        camera_pos: torch.Tensor,
                        camera_target: torch.Tensor,
                        image_size: int = 64,
                        fov: float = 60.0,
                        bg_color: torch.Tensor = None) -> torch.Tensor:
    """
    Simple differentiable Gaussian splatting renderer.

    Projects 3D Gaussians to the image plane, evaluates an isotropic 2D
    Gaussian per splat at each pixel, and alpha-composites front to back
    (depth-sorted) over the background.

    This is NOT a full 3DGS renderer: rotations are ignored and the footprint
    uses the mean of the X/Y scales. It is a training-grade approximation that
    provides useful gradients.

    Args:
        gaussians: (N, 14) — [pos(3), log_scale(3), rot(4), opacity_logit(1), color(3)]
        camera_pos: (3,) — camera position in world space
        camera_target: (3,) — camera look-at point
        image_size: H = W of output image
        fov: vertical field of view in degrees
        bg_color: (3,) — background color, default black
    Returns:
        image: (3, H, W) — rendered RGB in [0, 1]; row 0 is the top of the image
    """
    device = gaussians.device
    H = W = image_size

    if bg_color is None:
        bg_color = torch.zeros(3, device=device)
    bg = bg_color.view(3, 1, 1)

    pos = gaussians[:, :3]
    scale = gaussians[:, 3:6].exp()                 # from log-scale
    opacity = torch.sigmoid(gaussians[:, 10:11])    # (N, 1)
    color = gaussians[:, 11:14]                     # stored in [0, 1]

    # ── Camera setup: camera looks down its local -Z axis ──
    R = _look_at_rotation(camera_pos, camera_target)
    t = -R @ camera_pos
    cam = pos @ R.t() + t                           # (N, 3) camera-space positions

    focal = 0.5 * H / math.tan(math.radians(fov) / 2)

    valid = cam[:, 2] < -0.1                        # in front of the camera
    if not valid.any():
        return bg.expand(3, H, W).clamp(0, 1)

    cam = cam[valid]
    depth = -cam[:, 2]                              # positive distance along the view axis
    order = depth.argsort()                         # front to back
    cam, depth = cam[order], depth[order]
    color_s = color[valid][order]
    opacity_s = opacity[valid][order]
    scale_s = scale[valid][order]

    # Perspective projection to pixel coords; image rows grow downward, so
    # camera-space +Y (up) maps to smaller row indices.
    px = focal * cam[:, 0] / depth + W / 2
    py = H / 2 - focal * cam[:, 1] / depth

    # Approximate projected 2D std-dev: mean X/Y world scale scaled by focal/depth.
    s2d = (scale_s[:, :2].mean(dim=-1) * focal / depth).clamp(min=0.5)

    yy, xx = torch.meshgrid(
        torch.arange(H, dtype=torch.float32, device=device),
        torch.arange(W, dtype=torch.float32, device=device),
        indexing='ij',
    )

    image = torch.zeros(3, H, W, device=device)
    transmittance = torch.ones(H, W, device=device)

    M = px.shape[0]
    chunk_size = min(M, 512)
    for start in range(0, M, chunk_size):
        sl = slice(start, start + chunk_size)

        dx = xx - px[sl].view(-1, 1, 1)             # (C, H, W)
        dy = yy - py[sl].view(-1, 1, 1)
        inv_var = 1.0 / (s2d[sl].view(-1, 1, 1) ** 2 + 1e-6)
        power = -0.5 * (dx ** 2 + dy ** 2) * inv_var

        # Truncate the footprint instead of flooring it: clamping the exponent
        # at the cutoff would leak ~3e-4 * alpha of every splat into every pixel.
        gauss = torch.exp(power).masked_fill(power < _GAUSS_CUTOFF, 0.0)
        alpha = opacity_s[sl].view(-1, 1, 1) * gauss
        one_minus = 1.0 - alpha

        # Transmittance in front of each splat in the chunk (exclusive cumprod),
        # equivalent to compositing the splats one by one in depth order.
        t_before = torch.cumprod(
            torch.cat([transmittance.unsqueeze(0), one_minus[:-1]], dim=0), dim=0)
        image = image + torch.einsum('ck,chw->khw', color_s[sl], alpha * t_before)
        transmittance = t_before[-1] * one_minus[-1]

    # The background only shows through where the splats leave transmittance.
    image = image + bg * transmittance.unsqueeze(0)
    return image.clamp(0, 1)


def multi_view_render(gaussians: torch.Tensor,
                      n_views: int = 4,
                      image_size: int = 64,
                      orbit_radius: float = 4.0,
                      fov: float = 60.0,
                      bg_color: torch.Tensor = None) -> torch.Tensor:
    """
    Render Gaussians from evenly spaced viewpoints on an orbit around the origin.

    Cameras sit on a circle of radius ``orbit_radius`` in the XZ plane, raised
    by 0.4 * orbit_radius along +Y, all looking at the origin.

    Args:
        gaussians: (N, 14)
        n_views: number of views
        image_size: H = W
        orbit_radius: camera distance from the Y axis
        fov: vertical field of view in degrees
        bg_color: (3,) — background color passed to the renderer, default black
    Returns:
        images: (n_views, 3, H, W)
    """
    device = gaussians.device
    if n_views <= 0:
        return torch.zeros(0, 3, image_size, image_size, device=device)

    target = torch.zeros(3, device=device)
    images = []
    for i in range(n_views):
        angle = 2 * math.pi * i / n_views
        cam_pos = torch.tensor([
            orbit_radius * math.cos(angle),
            orbit_radius * 0.4,  # slight elevation
            orbit_radius * math.sin(angle),
        ], device=device)
        images.append(render_gaussians_2d(gaussians, cam_pos, target,
                                          image_size=image_size, fov=fov,
                                          bg_color=bg_color))
    return torch.stack(images, dim=0)