"""
GLADIUS β€” Gaussian Head Data
Procedural Gaussian scene generation for training the specialist head.
No external datasets needed β€” generates ground truth Gaussians from
geometric primitives (spheres, cubes, multi-object scenes).
Also includes a differentiable 2D renderer for computing training loss.
The renderer projects 3D Gaussians to 2D and alpha-composites them.
"""
import math
import random
from typing import Optional

import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
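
# Typical usage (sketch; the surrounding training loop is assumed, not shown here):
#   ds = ProceduralGaussianDataset(size=1000)
#   sample = ds[0]                    # {'gaussians': (N, 14), 'description': str, ...}
#   views = multi_view_render(sample['gaussians'], n_views=4)  # (4, 3, 64, 64)
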
def fibonacci_sphere(n: int, device: str = 'cpu') -> torch.Tensor:
"""
Generate n points uniformly distributed on a unit sphere.
Fibonacci/golden spiral method.
Returns: (n, 3) β€” unit vectors on sphere surface
"""
golden_ratio = (1 + math.sqrt(5)) / 2
indices = torch.arange(n, dtype=torch.float32, device=device)
theta = 2 * math.pi * indices / golden_ratio
phi = torch.acos(1 - 2 * (indices + 0.5) / n)
x = torch.sin(phi) * torch.cos(theta)
y = torch.sin(phi) * torch.sin(theta)
z = torch.cos(phi)
return torch.stack([x, y, z], dim=-1)
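
# Sanity sketch: every returned point should lie on the unit sphere.
#   pts = fibonacci_sphere(512)
#   assert torch.allclose(pts.norm(dim=-1), torch.ones(512), atol=1e-5)
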
def random_quaternion(n: int, device: str = 'cpu') -> torch.Tensor:
"""Generate n random unit quaternions (uniform on SO(3))."""
u = torch.rand(n, 3, device=device)
q = torch.stack([
torch.sqrt(1 - u[:, 0]) * torch.sin(2 * math.pi * u[:, 1]),
torch.sqrt(1 - u[:, 0]) * torch.cos(2 * math.pi * u[:, 1]),
torch.sqrt(u[:, 0]) * torch.sin(2 * math.pi * u[:, 2]),
torch.sqrt(u[:, 0]) * torch.cos(2 * math.pi * u[:, 2]),
], dim=-1)
return q
def quaternion_from_direction(normals: torch.Tensor) -> torch.Tensor:
"""
Compute quaternion that aligns Z-axis with given normal vectors.
normals: (N, 3) β€” unit vectors
Returns: (N, 4) β€” unit quaternions
"""
N = normals.shape[0]
z = torch.tensor([0.0, 0.0, 1.0], device=normals.device).expand(N, 3)
# Cross product z Γ— normal
cross = torch.cross(z, normals, dim=-1)
dot = (z * normals).sum(dim=-1, keepdim=True)
# Quaternion: q = [cross, 1 + dot] then normalize
q = torch.cat([cross, 1.0 + dot], dim=-1)
# Handle anti-parallel case (dot β‰ˆ -1)
anti = dot.squeeze(-1) < -0.999
if anti.any():
# Use perpendicular vector
perp = torch.zeros_like(normals[anti])
perp[:, 0] = 1.0 # arbitrary perpendicular
q[anti] = torch.cat([perp, torch.zeros(anti.sum(), 1, device=normals.device)], dim=-1)
return F.normalize(q, dim=-1)
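
# Verification sketch (assumes the (x, y, z, w) component order built above):
# rotating the Z-axis by the returned quaternion should recover the normal.
# _rotate_z is a hypothetical helper, not part of this module.
#   def _rotate_z(q: torch.Tensor) -> torch.Tensor:  # (N, 4) -> (N, 3)
#       x, y, z, w = q.unbind(-1)
#       return torch.stack([2 * (x * z + w * y), 2 * (y * z - w * x),
#                           1 - 2 * (x * x + y * y)], dim=-1)
#   n = F.normalize(torch.randn(8, 3), dim=-1)
#   assert torch.allclose(_rotate_z(quaternion_from_direction(n)), n, atol=1e-5)
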
def generate_sphere_gaussians(center: torch.Tensor, radius: float,
color: torch.Tensor, n_gaussians: int,
device: str = 'cpu') -> torch.Tensor:
"""
Generate Gaussians for a sphere surface.
Returns: (n_gaussians, 14) β€” full Gaussian params
[pos(3), scale(3), rot(4), opacity(1), color(3)]
"""
# Points on sphere surface
normals = fibonacci_sphere(n_gaussians, device)
positions = center.unsqueeze(0) + radius * normals
# Scale: flat discs covering the surface
splat_radius = radius * math.sqrt(4 * math.pi / n_gaussians) * 0.5
log_scale = math.log(max(splat_radius, 1e-6))
scales = torch.full((n_gaussians, 3), log_scale, device=device)
scales[:, 2] = log_scale - 1.5 # Thin in normal direction
# Rotation: align with surface normal
rotations = quaternion_from_direction(normals)
# Opacity (pre-sigmoid): high opacity
opacity = torch.full((n_gaussians, 1), 2.0, device=device) # sigmoid(2) β‰ˆ 0.88
# Color: uniform with slight variation
colors = color.unsqueeze(0).expand(n_gaussians, 3).clone()
colors += torch.randn_like(colors) * 0.02 # Slight variation
colors = colors.clamp(0, 1)
return torch.cat([positions, scales, rotations, opacity, colors], dim=-1)
def generate_cube_gaussians(center: torch.Tensor, size: float,
color: torch.Tensor, n_gaussians: int,
device: str = 'cpu') -> torch.Tensor:
"""Generate Gaussians for a cube surface."""
per_face = n_gaussians // 6
remainder = n_gaussians - per_face * 6
half = size / 2
all_gaussians = []
# 6 faces: +X, -X, +Y, -Y, +Z, -Z
face_normals = torch.tensor([
[1, 0, 0], [-1, 0, 0], [0, 1, 0], [0, -1, 0], [0, 0, 1], [0, 0, -1]
], dtype=torch.float32, device=device)
for i, normal in enumerate(face_normals):
n = per_face + (1 if i < remainder else 0)
if n == 0:
continue
# Random points on face
uv = torch.rand(n, 2, device=device) * size - half
# Face center offset
face_center = center + normal * half
# Map UV to 3D based on face normal axis
abs_normal = normal.abs()
if abs_normal[0] > 0.5: # X face
positions = torch.stack([
torch.full((n,), face_center[0].item(), device=device),
uv[:, 0] + center[1],
uv[:, 1] + center[2],
], dim=-1)
elif abs_normal[1] > 0.5: # Y face
positions = torch.stack([
uv[:, 0] + center[0],
torch.full((n,), face_center[1].item(), device=device),
uv[:, 1] + center[2],
], dim=-1)
else: # Z face
positions = torch.stack([
uv[:, 0] + center[0],
uv[:, 1] + center[1],
torch.full((n,), face_center[2].item(), device=device),
], dim=-1)
        splat_radius = size / math.sqrt(n) * 0.6  # use n (> 0 here); per_face is 0 when n_gaussians < 6
        log_scale = math.log(max(splat_radius, 1e-6))
scales = torch.full((n, 3), log_scale, device=device)
scales[:, 2] = log_scale - 1.5
rotations = quaternion_from_direction(normal.unsqueeze(0).expand(n, 3))
opacity = torch.full((n, 1), 2.0, device=device)
colors = color.unsqueeze(0).expand(n, 3).clone()
colors += torch.randn_like(colors) * 0.02
colors = colors.clamp(0, 1)
all_gaussians.append(torch.cat([positions, scales, rotations, opacity, colors], dim=-1))
return torch.cat(all_gaussians, dim=0)
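
# Usage sketch for the primitive generators (both return the same (N, 14)
# layout: pos(3), log-scale(3), rot(4), pre-sigmoid opacity(1), color(3)):
#   red = torch.tensor([0.9, 0.15, 0.1])
#   sphere = generate_sphere_gaussians(torch.zeros(3), 0.5, red, 128)
#   cube = generate_cube_gaussians(torch.ones(3), 0.6, red, 128)
#   assert sphere.shape == (128, 14) and cube.shape == (128, 14)
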
def generate_scene(n_objects: Optional[int] = None, n_gaussians_per_object: int = 128,
                   device: str = 'cpu', scene_scale: float = 2.0) -> dict:
"""
Generate a random procedural scene with Gaussian splats.
Returns:
dict with:
gaussians: (N, 14) β€” all Gaussians in the scene
description: str β€” text description of the scene
objects: list of dicts β€” object metadata
"""
if n_objects is None:
n_objects = random.randint(1, 5)
all_gaussians = []
objects = []
descriptions = []
color_names = {
'red': torch.tensor([0.9, 0.15, 0.1]),
'green': torch.tensor([0.1, 0.85, 0.2]),
'blue': torch.tensor([0.1, 0.2, 0.9]),
'yellow': torch.tensor([0.95, 0.9, 0.1]),
'purple': torch.tensor([0.7, 0.1, 0.8]),
'orange': torch.tensor([0.95, 0.5, 0.05]),
'white': torch.tensor([0.9, 0.9, 0.9]),
'cyan': torch.tensor([0.1, 0.9, 0.9]),
}
color_list = list(color_names.items())
shape_types = ['sphere', 'cube']
for i in range(n_objects):
shape = random.choice(shape_types)
color_name, color_val = random.choice(color_list)
color_val = color_val.to(device)
# Random position within scene bounds
center = (torch.rand(3, device=device) - 0.5) * scene_scale * 1.2
size = random.uniform(0.3, 0.8)
if shape == 'sphere':
gs = generate_sphere_gaussians(center, size, color_val,
n_gaussians_per_object, device)
desc = f"a {color_name} sphere"
else:
gs = generate_cube_gaussians(center, size, color_val,
n_gaussians_per_object, device)
desc = f"a {color_name} cube"
all_gaussians.append(gs)
objects.append({'shape': shape, 'color': color_name, 'center': center.cpu(),
'size': size})
descriptions.append(desc)
gaussians = torch.cat(all_gaussians, dim=0)
description = " and ".join(descriptions)
return {
'gaussians': gaussians,
'description': description,
'objects': objects,
'n_gaussians': gaussians.shape[0],
}
class ProceduralGaussianDataset(Dataset):
"""
On-the-fly procedural Gaussian scene dataset.
Generates random scenes with spheres and cubes.
Each sample: (gaussians, description_text)
For VQ-VAE Phase 1: use .get_gaussian_params() to get pooled params.
For Head Phase 2: use __getitem__ for (text, gaussians) pairs.
"""
def __init__(self, size: int = 10000, gaussians_per_object: int = 128,
max_objects: int = 5, scene_scale: float = 2.0,
device: str = 'cpu', fixed_seed: bool = False):
self.size = size
self.gaussians_per_object = gaussians_per_object
self.max_objects = max_objects
self.scene_scale = scene_scale
self.device = device
self.fixed_seed = fixed_seed
def __len__(self) -> int:
return self.size
def __getitem__(self, idx: int) -> dict:
if self.fixed_seed:
# Deterministic for eval
torch.manual_seed(idx)
random.seed(idx)
scene = generate_scene(
n_objects=random.randint(1, self.max_objects),
n_gaussians_per_object=self.gaussians_per_object,
device=self.device,
scene_scale=self.scene_scale,
)
return scene
def get_gaussian_params(self, n_samples: int = 100000) -> torch.Tensor:
"""
Get a flat pool of Gaussian non-position params for VQ-VAE training.
Returns: (N, 11) β€” scale(3), rot(4), opacity(1), color(3)
"""
all_params = []
collected = 0
while collected < n_samples:
scene = generate_scene(
n_objects=random.randint(1, self.max_objects),
n_gaussians_per_object=self.gaussians_per_object,
device='cpu',
scene_scale=self.scene_scale,
)
# Extract non-position params: [3:14] = scale(3) + rot(4) + opacity(1) + color(3)
params = scene['gaussians'][:, 3:] # (N, 11)
all_params.append(params)
collected += params.shape[0]
return torch.cat(all_params, dim=0)[:n_samples]
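
# Sketch: pooling non-position params for VQ-VAE Phase 1. The column layout
# follows the generators above: scale(3) + rot(4) + opacity(1) + color(3).
#   ds = ProceduralGaussianDataset()
#   pool = ds.get_gaussian_params(n_samples=50_000)   # (50000, 11)
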
# ═══ Differentiable 2D Renderer ═══
def render_gaussians_2d(gaussians: torch.Tensor, camera_pos: torch.Tensor,
camera_target: torch.Tensor, image_size: int = 64,
                        fov: float = 60.0, bg_color: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Simple differentiable Gaussian splatting renderer.
Projects 3D Gaussians to 2D, evaluates 2D Gaussian at each pixel,
alpha-composites front-to-back (depth-sorted).
This is NOT a full 3DGS renderer β€” it's a training-grade approximation
that provides useful gradients. Good enough for learning, not for display.
Args:
gaussians: (N, 14) β€” [pos(3), scale(3), rot(4), opacity(1), color(3)]
camera_pos: (3,) β€” camera position in world space
camera_target: (3,) β€” camera look-at point
image_size: H = W of output image
fov: vertical field of view in degrees
bg_color: (3,) β€” background color, default black
Returns:
image: (3, H, W) β€” rendered RGB image [0, 1]
"""
device = gaussians.device
N = gaussians.shape[0]
H = W = image_size
if bg_color is None:
bg_color = torch.zeros(3, device=device)
# Parse Gaussian params
pos = gaussians[:, :3] # (N, 3)
scale = gaussians[:, 3:6].exp() # (N, 3) β€” from log-scale
opacity = torch.sigmoid(gaussians[:, 10:11]) # (N, 1)
color = gaussians[:, 11:14] # (N, 3) β€” already [0,1] from sigmoid in head
# ── Camera setup ──
forward = F.normalize(camera_target - camera_pos, dim=0)
world_up = torch.tensor([0.0, 1.0, 0.0], device=device)
    # Assumes forward is not parallel to world_up (cameras in this file orbit
    # with a fixed elevation, so this holds)
    right = F.normalize(torch.cross(forward, world_up, dim=0), dim=0)
    up = torch.cross(right, forward, dim=0)
# View matrix (world β†’ camera)
R = torch.stack([right, up, -forward], dim=0) # (3, 3)
t = -R @ camera_pos # (3,)
# Project Gaussians to camera space
cam_pos = (R @ pos.t()).t() + t.unsqueeze(0) # (N, 3)
# Perspective projection
fov_rad = fov * math.pi / 180
focal = 0.5 * H / math.tan(fov_rad / 2)
# Filter: only Gaussians in front of camera
valid = cam_pos[:, 2] < -0.1 # Z points backward in our convention
if valid.sum() == 0:
return bg_color.view(3, 1, 1).expand(3, H, W)
cam_pos_valid = cam_pos[valid]
color_valid = color[valid]
opacity_valid = opacity[valid]
scale_valid = scale[valid]
# Project to pixel coords
px = -focal * cam_pos_valid[:, 0] / cam_pos_valid[:, 2] + W / 2 # (M,)
py = -focal * cam_pos_valid[:, 1] / cam_pos_valid[:, 2] + H / 2
    # Sort by depth (front to back)
    depths = -cam_pos_valid[:, 2]  # positive distance along the view axis; larger = farther
    sort_idx = depths.argsort()    # ascending => nearest first
px = px[sort_idx]
py = py[sort_idx]
color_sorted = color_valid[sort_idx]
opacity_sorted = opacity_valid[sort_idx]
scale_sorted = scale_valid[sort_idx]
depth_sorted = depths[sort_idx]
# Projected 2D scale (approximate)
# Use average of X,Y scale projected by focal/depth
s2d = (scale_sorted[:, :2].mean(dim=-1) * focal / depth_sorted).clamp(min=0.5) # (M,)
# ── Rasterize via 2D Gaussian evaluation ──
# Create pixel grid
yy, xx = torch.meshgrid(
torch.arange(H, dtype=torch.float32, device=device),
torch.arange(W, dtype=torch.float32, device=device),
indexing='ij'
) # (H, W) each
    # Accumulate splat contributions; the background is composited after the
    # loop with the remaining transmittance (seeding the image with bg_color
    # would double-count it under opaque splats)
    image = torch.zeros(3, H, W, device=device)
    transmittance = torch.ones(H, W, device=device)
# Splat each Gaussian (front to back alpha compositing)
M = px.shape[0]
# Batch for efficiency: process in chunks
chunk_size = min(M, 512)
for start in range(0, M, chunk_size):
end = min(start + chunk_size, M)
chunk_px = px[start:end] # (C,)
chunk_py = py[start:end]
chunk_s = s2d[start:end] # (C,)
chunk_color = color_sorted[start:end] # (C, 3)
chunk_alpha = opacity_sorted[start:end] # (C, 1)
# Distance from each pixel to each Gaussian center
dx = xx.unsqueeze(0) - chunk_px.view(-1, 1, 1) # (C, H, W)
dy = yy.unsqueeze(0) - chunk_py.view(-1, 1, 1)
# 2D Gaussian evaluation
inv_s2 = 1.0 / (chunk_s.view(-1, 1, 1) ** 2 + 1e-6)
power = -0.5 * (dx ** 2 + dy ** 2) * inv_s2 # (C, H, W)
        # Gaussian contribution; cut off the far tail instead of clamping the
        # exponent (exp() underflows to 0 gracefully, whereas clamping power at
        # min=-8 would floor every splat at exp(-8) and haze the whole image)
        gauss = torch.exp(power) * (power > -8.0).float()
        alpha = chunk_alpha.view(-1, 1, 1) * gauss  # (C, H, W)
# Alpha composite per Gaussian in chunk (sequential β€” maintains depth order)
for i in range(end - start):
a = alpha[i] * transmittance # (H, W)
c = chunk_color[i] # (3,)
image += c.view(3, 1, 1) * a.unsqueeze(0) # (3, H, W)
transmittance = transmittance * (1 - alpha[i])
    # Composite the background with whatever transmittance remains
    image = image + bg_color.view(3, 1, 1) * transmittance.unsqueeze(0)
    return image.clamp(0, 1)
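
# Loss sketch (an assumption about how the head is trained, not code from the
# training loop): L1 between a predicted and a ground-truth render from the
# same camera gives gradients through this rasterizer.
# pred_gaussians / gt_gaussians are hypothetical (N, 14) tensors.
#   cam = torch.tensor([0.0, 1.5, 4.0]); target = torch.zeros(3)
#   img_pred = render_gaussians_2d(pred_gaussians, cam, target)
#   img_gt = render_gaussians_2d(gt_gaussians, cam, target).detach()
#   loss = (img_pred - img_gt).abs().mean()
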
def multi_view_render(gaussians: torch.Tensor, n_views: int = 4,
image_size: int = 64, orbit_radius: float = 4.0,
fov: float = 60.0) -> torch.Tensor:
"""
Render Gaussians from multiple viewpoints on an orbit.
Args:
gaussians: (N, 14)
n_views: number of views
image_size: H = W
orbit_radius: camera distance from origin
fov: field of view
Returns:
images: (n_views, 3, H, W)
"""
device = gaussians.device
images = []
target = torch.zeros(3, device=device)
for i in range(n_views):
angle = 2 * math.pi * i / n_views
# Camera on orbit (XZ plane, Y slightly elevated)
cam_pos = torch.tensor([
orbit_radius * math.cos(angle),
orbit_radius * 0.4, # Slight elevation
orbit_radius * math.sin(angle),
], device=device)
img = render_gaussians_2d(gaussians, cam_pos, target,
image_size=image_size, fov=fov)
images.append(img)
return torch.stack(images, dim=0)
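

if __name__ == "__main__":
    # Smoke test (sketch): generate a random scene and render an orbit of views.
    scene = generate_scene(n_objects=2, n_gaussians_per_object=128)
    print(f"{scene['n_gaussians']} Gaussians: {scene['description']}")
    views = multi_view_render(scene['gaussians'], n_views=4, image_size=64)
    print(f"views: {tuple(views.shape)}")  # expected (4, 3, 64, 64)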