| """ |
| GLADIUS β Gaussian Head Data |
| |
| Procedural Gaussian scene generation for training the specialist head. |
| No external datasets needed β generates ground truth Gaussians from |
| geometric primitives (spheres, cubes, multi-object scenes). |
| |
| Also includes a differentiable 2D renderer for computing training loss. |
| The renderer projects 3D Gaussians to 2D and alpha-composites them. |
| """ |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import Dataset |
| import math |
| import random |
|
|
|
|
def fibonacci_sphere(n: int, device: str = 'cpu') -> torch.Tensor:
    """
    Place n points quasi-uniformly on the unit sphere using the
    golden-spiral (Fibonacci lattice) construction.

    Args:
        n: number of points to generate
        device: torch device for the result

    Returns: (n, 3) - unit vectors on the sphere surface
    """
    idx = torch.arange(n, dtype=torch.float32, device=device)
    phi_golden = (1 + math.sqrt(5)) / 2

    # Longitude advances by the golden angle per point; latitude is chosen
    # so that cos(polar) strides evenly through (-1, 1).
    azimuth = 2 * math.pi * idx / phi_golden
    polar = torch.acos(1 - 2 * (idx + 0.5) / n)

    sin_polar = torch.sin(polar)
    return torch.stack(
        [sin_polar * torch.cos(azimuth),
         sin_polar * torch.sin(azimuth),
         torch.cos(polar)],
        dim=-1,
    )
|
|
|
|
def random_quaternion(n: int, device: str = 'cpu') -> torch.Tensor:
    """Sample n unit quaternions uniformly over SO(3) (Shoemake's method)."""
    u1, u2, u3 = torch.rand(n, 3, device=device).unbind(dim=-1)

    # Two orthogonal circles weighted by sqrt(1-u1) and sqrt(u1) give a
    # uniform distribution on the 3-sphere.
    a = torch.sqrt(1 - u1)
    b = torch.sqrt(u1)
    two_pi = 2 * math.pi
    return torch.stack([
        a * torch.sin(two_pi * u2),
        a * torch.cos(two_pi * u2),
        b * torch.sin(two_pi * u3),
        b * torch.cos(two_pi * u3),
    ], dim=-1)
|
|
|
|
def quaternion_from_direction(normals: torch.Tensor) -> torch.Tensor:
    """
    Quaternions that rotate the +Z axis onto each given unit normal.

    Args:
        normals: (N, 3) - unit direction vectors

    Returns: (N, 4) - unit quaternions in [x, y, z, w] layout
    """
    count = normals.shape[0]
    z_axis = torch.tensor([0.0, 0.0, 1.0], device=normals.device).expand(count, 3)

    # Half-angle shortcut: q = [z x n, 1 + z.n] is (unnormalized) the
    # rotation taking z to n.
    axis = torch.cross(z_axis, normals, dim=-1)
    cos_angle = (z_axis * normals).sum(dim=-1, keepdim=True)
    quats = torch.cat([axis, 1.0 + cos_angle], dim=-1)

    # Nearly antiparallel normals collapse q toward zero; substitute an
    # explicit 180-degree rotation about the X axis for those rows.
    flipped = cos_angle.squeeze(-1) < -0.999
    if flipped.any():
        k = int(flipped.sum())
        fix = torch.zeros(k, 4, device=normals.device)
        fix[:, 0] = 1.0
        quats[flipped] = fix

    return F.normalize(quats, dim=-1)
|
|
|
|
def generate_sphere_gaussians(center: torch.Tensor, radius: float,
                              color: torch.Tensor, n_gaussians: int,
                              device: str = 'cpu') -> torch.Tensor:
    """
    Build a splat set covering a sphere's surface.

    Args:
        center: (3,) sphere center in world space
        radius: sphere radius
        color: (3,) base RGB in [0, 1]
        n_gaussians: number of splats to place
        device: torch device

    Returns: (n_gaussians, 14) -
        [pos(3), log-scale(3), rot quat(4), opacity logit(1), color(3)]
    """
    # Quasi-uniform surface coverage; the direction from the center is
    # also the outward surface normal at each point.
    normals = fibonacci_sphere(n_gaussians, device)
    positions = center.unsqueeze(0) + radius * normals

    # Disc radius ~ sqrt(surface area per splat) * 0.5 so neighbours
    # overlap; stored in log-space. Local Z is squashed to a thin disc.
    splat_radius = radius * math.sqrt(4 * math.pi / n_gaussians) * 0.5
    log_scale = math.log(max(splat_radius, 1e-6))
    scales = torch.full((n_gaussians, 3), log_scale, device=device)
    scales[:, 2] = log_scale - 1.5

    # Orient each disc tangent to the surface (local Z -> outward normal).
    rotations = quaternion_from_direction(normals)

    # Opacity stored as a logit; sigmoid(2.0) ~ 0.88.
    opacity = torch.full((n_gaussians, 1), 2.0, device=device)

    # Slight per-splat color jitter, clamped back into [0, 1].
    colors = color.unsqueeze(0).expand(n_gaussians, 3).clone()
    colors = (colors + torch.randn_like(colors) * 0.02).clamp(0, 1)

    return torch.cat([positions, scales, rotations, opacity, colors], dim=-1)
|
|
|
|
def generate_cube_gaussians(center: torch.Tensor, size: float,
                            color: torch.Tensor, n_gaussians: int,
                            device: str = 'cpu') -> torch.Tensor:
    """
    Generate Gaussians for the six faces of an axis-aligned cube.

    Args:
        center: (3,) cube center in world space
        size: edge length
        color: (3,) base RGB in [0, 1]
        n_gaussians: total number of splats, distributed over the faces
        device: torch device

    Returns: (n_gaussians, 14) -
        [pos(3), log-scale(3), rot quat(4), opacity logit(1), color(3)]
    """
    per_face = n_gaussians // 6
    remainder = n_gaussians - per_face * 6
    half = size / 2

    all_gaussians = []

    face_normals = torch.tensor([
        [1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]
    ], dtype=torch.float32, device=device)

    # BUGFIX: when n_gaussians < 6, per_face == 0 but the remainder faces
    # still receive splats, and the original splat-radius formula divided
    # by sqrt(0) (ZeroDivisionError). Use a density of at least 1.
    density = max(per_face, 1)

    for i, normal in enumerate(face_normals):
        # Spread the remainder over the first `remainder` faces.
        n = per_face + (1 if i < remainder else 0)
        if n == 0:
            continue

        # Random face-local (u, v) coordinates in [-half, half).
        uv = torch.rand(n, 2, device=device) * size - half

        face_center = center + normal * half

        # Map (u, v) into world space on the matching axis-aligned plane.
        abs_normal = normal.abs()
        if abs_normal[0] > 0.5:
            positions = torch.stack([
                torch.full((n,), face_center[0].item(), device=device),
                uv[:, 0] + center[1],
                uv[:, 1] + center[2],
            ], dim=-1)
        elif abs_normal[1] > 0.5:
            positions = torch.stack([
                uv[:, 0] + center[0],
                torch.full((n,), face_center[1].item(), device=device),
                uv[:, 1] + center[2],
            ], dim=-1)
        else:
            positions = torch.stack([
                uv[:, 0] + center[0],
                uv[:, 1] + center[1],
                torch.full((n,), face_center[2].item(), device=device),
            ], dim=-1)

        # Overlapping discs sized from the per-face density (log-space),
        # flattened along the local Z axis.
        splat_radius = size / math.sqrt(density) * 0.6
        log_scale = math.log(max(splat_radius, 1e-6))
        scales = torch.full((n, 3), log_scale, device=device)
        scales[:, 2] = log_scale - 1.5

        # Discs face outward; opacity is a logit (sigmoid(2.0) ~ 0.88).
        rotations = quaternion_from_direction(normal.unsqueeze(0).expand(n, 3))
        opacity = torch.full((n, 1), 2.0, device=device)
        colors = color.unsqueeze(0).expand(n, 3).clone()
        colors += torch.randn_like(colors) * 0.02
        colors = colors.clamp(0, 1)

        all_gaussians.append(torch.cat([positions, scales, rotations, opacity, colors], dim=-1))

    if not all_gaussians:
        # n_gaussians == 0: return an empty, correctly-shaped tensor
        # instead of letting torch.cat crash on an empty list.
        return torch.empty(0, 14, device=device)

    return torch.cat(all_gaussians, dim=0)
|
|
|
|
def generate_scene(n_objects: int = None, n_gaussians_per_object: int = 128,
                   device: str = 'cpu', scene_scale: float = 2.0) -> dict:
    """
    Build one random scene of colored spheres and cubes as Gaussian splats.

    Args:
        n_objects: object count; drawn uniformly from [1, 5] when None
        n_gaussians_per_object: splats allocated to each object
        device: torch device
        scene_scale: controls the extent of the placement volume

    Returns:
        dict with:
            gaussians: (N, 14) - all splats in the scene
            description: str - e.g. "a red sphere and a blue cube"
            objects: list of per-object metadata dicts
            n_gaussians: int - N
    """
    if n_objects is None:
        n_objects = random.randint(1, 5)

    palette = {
        'red': torch.tensor([0.9, 0.15, 0.1]),
        'green': torch.tensor([0.1, 0.85, 0.2]),
        'blue': torch.tensor([0.1, 0.2, 0.9]),
        'yellow': torch.tensor([0.95, 0.9, 0.1]),
        'purple': torch.tensor([0.7, 0.1, 0.8]),
        'orange': torch.tensor([0.95, 0.5, 0.05]),
        'white': torch.tensor([0.9, 0.9, 0.9]),
        'cyan': torch.tensor([0.1, 0.9, 0.9]),
    }
    palette_items = list(palette.items())
    shape_types = ['sphere', 'cube']

    splat_groups = []
    objects = []
    phrases = []

    for _ in range(n_objects):
        shape = random.choice(shape_types)
        color_name, base_color = random.choice(palette_items)
        base_color = base_color.to(device)

        # Random placement inside a slightly enlarged scene box.
        center = (torch.rand(3, device=device) - 0.5) * scene_scale * 1.2
        size = random.uniform(0.3, 0.8)

        builder = generate_sphere_gaussians if shape == 'sphere' else generate_cube_gaussians
        splat_groups.append(builder(center, size, base_color,
                                    n_gaussians_per_object, device))
        objects.append({'shape': shape, 'color': color_name,
                        'center': center.cpu(), 'size': size})
        phrases.append(f"a {color_name} {shape}")

    gaussians = torch.cat(splat_groups, dim=0)

    return {
        'gaussians': gaussians,
        'description': " and ".join(phrases),
        'objects': objects,
        'n_gaussians': gaussians.shape[0],
    }
|
|
|
|
class ProceduralGaussianDataset(Dataset):
    """
    On-the-fly procedural Gaussian scene dataset.

    Generates random scenes of spheres and cubes; nothing is cached, every
    __getitem__ call synthesizes a fresh scene dict.

    For VQ-VAE Phase 1: use .get_gaussian_params() for a pooled param tensor.
    For Head Phase 2: use __getitem__ for (text, gaussians) pairs.
    """

    def __init__(self, size: int = 10000, gaussians_per_object: int = 128,
                 max_objects: int = 5, scene_scale: float = 2.0,
                 device: str = 'cpu', fixed_seed: bool = False):
        # `size` is only the nominal epoch length; scenes are generated lazily.
        self.size = size
        self.gaussians_per_object = gaussians_per_object
        self.max_objects = max_objects
        self.scene_scale = scene_scale
        self.device = device
        # When True, all randomness is derived from the sample index so the
        # same idx always yields the same scene.
        self.fixed_seed = fixed_seed

    def __len__(self) -> int:
        return self.size

    def __getitem__(self, idx: int) -> dict:
        if self.fixed_seed:
            torch.manual_seed(idx)
            random.seed(idx)

        return generate_scene(
            n_objects=random.randint(1, self.max_objects),
            n_gaussians_per_object=self.gaussians_per_object,
            device=self.device,
            scene_scale=self.scene_scale,
        )

    def get_gaussian_params(self, n_samples: int = 100000) -> torch.Tensor:
        """
        Pool non-position Gaussian parameters from fresh random scenes,
        for VQ-VAE training.

        Returns: (n_samples, 11) - scale(3), rot(4), opacity(1), color(3)
        """
        pooled = []
        total = 0

        while total < n_samples:
            scene = generate_scene(
                n_objects=random.randint(1, self.max_objects),
                n_gaussians_per_object=self.gaussians_per_object,
                device='cpu',
                scene_scale=self.scene_scale,
            )
            # Drop the xyz position columns; keep the remaining 11 params.
            chunk = scene['gaussians'][:, 3:]
            pooled.append(chunk)
            total += chunk.shape[0]

        return torch.cat(pooled, dim=0)[:n_samples]
|
|
|
|
| |
|
|
def render_gaussians_2d(gaussians: torch.Tensor, camera_pos: torch.Tensor,
                        camera_target: torch.Tensor, image_size: int = 64,
                        fov: float = 60.0, bg_color: torch.Tensor = None) -> torch.Tensor:
    """
    Simple differentiable Gaussian splatting renderer.

    Projects 3D Gaussians to 2D, evaluates an isotropic 2D Gaussian at each
    pixel, and alpha-composites front-to-back (depth-sorted).

    This is NOT a full 3DGS renderer - it's a training-grade approximation
    that provides useful gradients. Good enough for learning, not for display.

    Args:
        gaussians: (N, 14) - [pos(3), log-scale(3), rot(4), opacity logit(1), color(3)]
        camera_pos: (3,) - camera position in world space
        camera_target: (3,) - camera look-at point
        image_size: H = W of output image
        fov: vertical field of view in degrees
        bg_color: (3,) - background color, default black

    Returns:
        image: (3, H, W) - rendered RGB image in [0, 1]
    """
    device = gaussians.device
    H = W = image_size

    if bg_color is None:
        bg_color = torch.zeros(3, device=device)

    # Unpack raw params; scales are stored in log-space, opacity as a logit.
    pos = gaussians[:, :3]
    scale = gaussians[:, 3:6].exp()
    opacity = torch.sigmoid(gaussians[:, 10:11])
    color = gaussians[:, 11:14]

    # Look-at camera basis; camera looks down -Z in camera space.
    # NOTE(review): degenerate when forward is parallel to world up - assumed
    # not to occur for the orbit cameras used here.
    forward = F.normalize(camera_target - camera_pos, dim=0)
    world_up = torch.tensor([0.0, 1.0, 0.0], device=device)
    # torch.cross without an explicit dim is deprecated; linalg.cross uses dim=-1.
    right = F.normalize(torch.linalg.cross(forward, world_up), dim=0)
    up = torch.linalg.cross(right, forward)

    # World -> camera rigid transform.
    R = torch.stack([right, up, -forward], dim=0)
    t = -R @ camera_pos
    cam_pos = (R @ pos.t()).t() + t.unsqueeze(0)

    # Pinhole focal length from the vertical FOV.
    fov_rad = fov * math.pi / 180
    focal = 0.5 * H / math.tan(fov_rad / 2)

    # Cull splats at/behind the near plane (in front of camera means z < 0).
    valid = cam_pos[:, 2] < -0.1
    if valid.sum() == 0:
        return bg_color.view(3, 1, 1).expand(3, H, W)

    cam_pos_valid = cam_pos[valid]
    color_valid = color[valid]
    opacity_valid = opacity[valid]
    scale_valid = scale[valid]

    # Perspective projection to pixel coordinates.
    px = -focal * cam_pos_valid[:, 0] / cam_pos_valid[:, 2] + W / 2
    py = -focal * cam_pos_valid[:, 1] / cam_pos_valid[:, 2] + H / 2

    # Front-to-back order: nearest splats composited first.
    depths = -cam_pos_valid[:, 2]
    sort_idx = depths.argsort()

    px = px[sort_idx]
    py = py[sort_idx]
    color_sorted = color_valid[sort_idx]
    opacity_sorted = opacity_valid[sort_idx]
    scale_sorted = scale_valid[sort_idx]
    depth_sorted = depths[sort_idx]

    # Isotropic screen-space radius: mean of the two tangential scales,
    # perspective-divided; clamped so every splat covers at least a pixel.
    s2d = (scale_sorted[:, :2].mean(dim=-1) * focal / depth_sorted).clamp(min=0.5)

    # Pixel-center grid, (H, W).
    yy, xx = torch.meshgrid(
        torch.arange(H, dtype=torch.float32, device=device),
        torch.arange(W, dtype=torch.float32, device=device),
        indexing='ij'
    )

    # BUGFIX: the original initialized `image` to the background and then
    # added splat contributions on top, so the background was never occluded
    # (visible even behind opaque splats once bg_color != 0). Correct
    # front-to-back compositing is sum(c_i * a_i * T_i) + bg * T_final.
    image = torch.zeros(3, H, W, device=device)
    transmittance = torch.ones(H, W, device=device)

    # Chunk the per-splat footprint tensors to bound peak memory.
    M = px.shape[0]
    chunk_size = min(M, 512)

    for start in range(0, M, chunk_size):
        end = min(start + chunk_size, M)
        chunk_px = px[start:end]
        chunk_py = py[start:end]
        chunk_s = s2d[start:end]
        chunk_color = color_sorted[start:end]
        chunk_alpha = opacity_sorted[start:end]

        # Squared pixel offsets from each splat center: (chunk, H, W).
        dx = xx.unsqueeze(0) - chunk_px.view(-1, 1, 1)
        dy = yy.unsqueeze(0) - chunk_py.view(-1, 1, 1)

        inv_s2 = 1.0 / (chunk_s.view(-1, 1, 1) ** 2 + 1e-6)
        power = -0.5 * (dx ** 2 + dy ** 2) * inv_s2

        # Clamp the exponent to keep gradients finite far from the center.
        gauss = torch.exp(power.clamp(min=-8))
        alpha = chunk_alpha.view(-1, 1, 1) * gauss

        # Sequential over-compositing within the chunk (order matters).
        for i in range(end - start):
            a = alpha[i] * transmittance
            c = chunk_color[i]
            image += c.view(3, 1, 1) * a.unsqueeze(0)
            transmittance = transmittance * (1 - alpha[i])

    # Background shows through only where transmittance remains.
    image = image + bg_color.view(3, 1, 1) * transmittance.unsqueeze(0)

    return image.clamp(0, 1)
|
|
|
|
def multi_view_render(gaussians: torch.Tensor, n_views: int = 4,
                      image_size: int = 64, orbit_radius: float = 4.0,
                      fov: float = 60.0) -> torch.Tensor:
    """
    Render the same Gaussians from n_views cameras orbiting the origin.

    Cameras sit evenly spaced on a horizontal circle, lifted to a height of
    0.4 * orbit_radius, all looking at the origin.

    Args:
        gaussians: (N, 14)
        n_views: number of views
        image_size: H = W of each rendered image
        orbit_radius: camera distance from the origin
        fov: field of view in degrees

    Returns:
        images: (n_views, 3, H, W)
    """
    device = gaussians.device
    look_at = torch.zeros(3, device=device)
    frames = []

    for view_idx in range(n_views):
        azimuth = 2 * math.pi * view_idx / n_views
        eye = torch.tensor([
            orbit_radius * math.cos(azimuth),
            orbit_radius * 0.4,
            orbit_radius * math.sin(azimuth),
        ], device=device)
        frames.append(render_gaussians_2d(gaussians, eye, look_at,
                                          image_size=image_size, fov=fov))

    return torch.stack(frames, dim=0)
|
|