# Source: NuWave repository, nuwave/splat_engine.py
# Commit 4c0cf4e ("Initial commit") by Executor-Tyrant-Framework
"""
Gaussian Splat Weight Decomposition Engine
==========================================
Decomposes dense weight matrices into collections of Gaussian splats.
Each splat has: position (mu), spread (sigma), amplitude (alpha).
The splats ARE the weight representation. Inference runs through them,
not through the original dense matrix. This is physics-based compression:
resolution concentrates where the weight landscape has structure.
For BitNet ternary weights {-1, 0, 1}, splats fit naturally:
- Positive splats over +1 clusters
- Negative splats over -1 clusters
- Zero regions need no splats (free compression)
# ---- Changelog ----
# [2026-04-05] Claude Code (Opus 4.6) — Initial implementation
# What: Gaussian splat decomposition + reconstruction + fitting
# Why: PoC for physics-based weight compression with Lenia dynamics
# -------------------
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import numpy as np
# Module logger for optional progress messages (named "uniai.splat", not __name__).
logger = logging.getLogger("uniai.splat")
@dataclass
class SplatConfig:
    """Tunable knobs for decomposing a dense matrix into Gaussian splats.

    Attributes:
        splat_ratio: Splats allotted per matrix element, i.e. the compression
            control measured in splat count (0.1 -> one splat per ten weights).
            Each splat stores 4 float32 values (16 bytes) versus 4 bytes per
            dense float32 weight, so 0.1 gives roughly 2.5x byte compression.
        min_splats: Lower bound on the splat count for a single layer.
        max_splats: Intended cap on the per-layer splat count (memory safety).
        init_sigma: Spread assigned to every splat before fitting.
        fit_iterations: Maximum number of gradient-descent steps used to fit
            the splats to the dense weights.
        fit_lr: Learning rate of the fitting optimizer.
        fit_tolerance: Stop fitting early once reconstruction MSE drops below
            this value.
    """

    splat_ratio: float = 0.1
    min_splats: int = 32
    max_splats: int = 8192
    init_sigma: float = 1.0
    fit_iterations: int = 200
    fit_lr: float = 0.01
    fit_tolerance: float = 1e-4
class GaussianSplats:
    """
    A collection of Gaussian splats representing a weight matrix.

    Each splat i has:
    - mu_i: (2,) position in weight space (row, col)
    - sigma_i: scalar spread (isotropic for Phase 0)
    - alpha_i: scalar amplitude (positive or negative)

    The reconstructed weight at position (r, c) is:
        W(r,c) = sum_i alpha_i * exp(-||[r,c] - mu_i||^2 / (2 * sigma_i^2))
    (a 1e-8 term is added to the denominator to avoid division by zero).
    """

    def __init__(self, n_splats: int, rows: int, cols: int,
                 device: Optional[torch.device] = None):
        self.n_splats = n_splats
        self.rows = rows
        self.cols = cols
        self.device = device or torch.device('cpu')
        # Splat parameters — these are what Lenia will operate on
        self.mu = torch.zeros(n_splats, 2, device=self.device)   # positions (row, col)
        self.sigma = torch.ones(n_splats, device=self.device)    # spreads
        self.alpha = torch.zeros(n_splats, device=self.device)   # amplitudes

    def reconstruct(self, chunk_size: int = 512) -> torch.Tensor:
        """Reconstruct the dense (rows, cols) weight matrix from the splats.

        Because each splat is an isotropic Gaussian, it factorizes over rows and
        columns: exp(-(dr + dc) / v) = exp(-dr / v) * exp(-dc / v). A chunk of
        splats therefore contributes (G_rows * alpha) @ G_cols^T, where G_rows is
        (rows, chunk) and G_cols is (cols, chunk). Splats are processed
        ``chunk_size`` at a time, so temporaries are O((rows + cols) * chunk_size)
        on top of the (rows, cols) output.

        Args:
            chunk_size: splats processed per matmul; values < 1 are treated as 1.

        Returns:
            (rows, cols) tensor on ``self.device``.
        """
        chunk_size = max(1, int(chunk_size))
        W = torch.zeros(self.rows, self.cols, device=self.device)
        row_coords = torch.arange(self.rows, dtype=torch.float32, device=self.device)
        col_coords = torch.arange(self.cols, dtype=torch.float32, device=self.device)

        for start in range(0, self.n_splats, chunk_size):
            end = min(start + chunk_size, self.n_splats)
            chunk_mu = self.mu[start:end]                            # (chunk, 2)
            var = 2 * self.sigma[start:end] ** 2 + 1e-8              # (chunk,)
            dr = (row_coords.unsqueeze(1) - chunk_mu[:, 0].unsqueeze(0)) ** 2  # (rows, chunk)
            dc = (col_coords.unsqueeze(1) - chunk_mu[:, 1].unsqueeze(0)) ** 2  # (cols, chunk)
            weighted_rows = torch.exp(-dr / var) * self.alpha[start:end]       # (rows, chunk)
            g_cols = torch.exp(-dc / var).to(weighted_rows.dtype)              # (cols, chunk)
            W += weighted_rows @ g_cols.T                            # (rows, cols)
        return W

    def reconstruct_fast(self) -> torch.Tensor:
        """Fully vectorized reconstruction — faster but uses more memory.

        Builds a (rows*cols, n_splats) tensor, so it falls back to the chunked
        ``reconstruct`` when rows * cols * n_splats exceeds 50M elements.
        """
        if self.rows * self.cols * self.n_splats > 50_000_000:
            return self.reconstruct()

        # All positions as a (rows*cols, 2) grid
        row_coords = torch.arange(self.rows, dtype=torch.float32, device=self.device)
        col_coords = torch.arange(self.cols, dtype=torch.float32, device=self.device)
        rr, cc = torch.meshgrid(row_coords, col_coords, indexing='ij')
        positions = torch.stack([rr.flatten(), cc.flatten()], dim=1)  # (R*C, 2)

        # Squared distances from every position to every splat
        diff = positions.unsqueeze(1) - self.mu.unsqueeze(0)  # (R*C, N, 2)
        dist_sq = (diff ** 2).sum(dim=2)                      # (R*C, N)

        var = 2 * self.sigma.unsqueeze(0) ** 2 + 1e-8         # (1, N)
        gaussians = torch.exp(-dist_sq / var)                 # (R*C, N)

        W_flat = (gaussians * self.alpha.unsqueeze(0)).sum(dim=1)  # (R*C,)
        return W_flat.reshape(self.rows, self.cols)

    def memory_bytes(self) -> int:
        """Estimate memory of the splat representation, assuming float32.

        mu: n*2*4, sigma: n*4, alpha: n*4 -> n*16 bytes.
        """
        return self.n_splats * 16

    def compression_ratio(self) -> float:
        """Compression ratio vs a dense float32 matrix of the same shape."""
        dense_bytes = self.rows * self.cols * 4  # float32
        return dense_bytes / max(self.memory_bytes(), 1)

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Export splat parameters (cloned) and the matrix shape for persistence."""
        return {
            'mu': self.mu.clone(),
            'sigma': self.sigma.clone(),
            'alpha': self.alpha.clone(),
            'rows': torch.tensor(self.rows),
            'cols': torch.tensor(self.cols),
        }

    @classmethod
    def from_state_dict(cls, d: Dict[str, torch.Tensor]) -> 'GaussianSplats':
        """Restore from a dict produced by ``state_dict``.

        The restored object's ``device`` is taken from the saved ``mu`` tensor,
        so ``reconstruct`` builds its coordinate grids on the same device as
        the parameters (previously it always assumed CPU).
        """
        rows, cols = int(d['rows'].item()), int(d['cols'].item())
        mu = d['mu']
        splats = cls(mu.shape[0], rows, cols, device=mu.device)
        splats.mu = mu
        splats.sigma = d['sigma']
        splats.alpha = d['alpha']
        return splats
def compute_n_splats(rows: int, cols: int, config: SplatConfig) -> int:
    """Determine how many splats to use for a given matrix size.

    The count is ``int(rows * cols * config.splat_ratio)`` clamped to
    ``[config.min_splats, config.max_splats]``. ``max_splats`` is the
    memory-safety cap, so it wins if a config sets ``min_splats > max_splats``.
    """
    requested = int(rows * cols * config.splat_ratio)
    return min(config.max_splats, max(config.min_splats, requested))
def initialize_splats_from_ternary(
    weight: torch.Tensor,
    n_splats: int,
    config: SplatConfig,
) -> GaussianSplats:
    """Initialize splats from a ternary {-1, 0, 1} weight matrix.

    Strategy: place splats directly on non-zero weights; zeros need no
    representation.

    1. Find all non-zero positions.
    2. If there are at most ``n_splats`` of them, use every one; otherwise
       choose ``n_splats`` of them with ``_farthest_point_sample`` for spatial
       coverage.
    3. Set each used splat's alpha to the weight value at its position.
    4. Set every splat's sigma to ``config.init_sigma``.

    Unused splats (fewer non-zeros than ``n_splats``, or an all-zero matrix)
    keep position (0, 0) and alpha 0, so they contribute nothing to the
    reconstruction.

    Args:
        weight: 2-D matrix (unpacked as ``rows, cols = weight.shape``).
        n_splats: number of splats to allocate.
        config: supplies ``init_sigma``.

    Returns:
        A ``GaussianSplats`` on ``weight.device``.
    """
    rows, cols = weight.shape
    splats = GaussianSplats(n_splats, rows, cols, device=weight.device)
    splats.sigma[:] = config.init_sigma

    # Integer (row, col) indices of every non-zero weight: (K, 2), long.
    nonzero_idx = (weight != 0).nonzero(as_tuple=False)
    k = nonzero_idx.shape[0]
    if k == 0:
        return splats  # all zeros — nothing to represent

    if k > n_splats:
        # More candidates than splats: keep a spatially spread subset.
        chosen = _farthest_point_sample(nonzero_idx.float(), n_splats)
        nonzero_idx = nonzero_idx[chosen]
        k = n_splats

    splats.mu[:k] = nonzero_idx.float()
    # Vectorized gather of the weight value under each splat (cast to alpha's dtype,
    # since ternary weights may be stored as an integer tensor).
    splats.alpha[:k] = weight[nonzero_idx[:, 0], nonzero_idx[:, 1]].to(splats.alpha.dtype)
    return splats
def _farthest_point_sample(points: torch.Tensor, n: int) -> torch.Tensor:
"""Spatial coverage sampling.
For small point sets (< 5000): true farthest-point sampling.
For large point sets: stratified random sampling (grid-based) for speed.
"""
K = len(points)
if K <= 5000:
# True FPS — O(n*K), fine for small sets
selected = torch.zeros(n, dtype=torch.long, device=points.device)
selected[0] = torch.randint(K, (1,)).item()
dists = torch.full((K,), float('inf'), device=points.device)
for i in range(1, n):
last = points[selected[i - 1]].unsqueeze(0)
new_dists = ((points - last) ** 2).sum(dim=1)
dists = torch.minimum(dists, new_dists)
selected[i] = dists.argmax()
return selected
# Stratified random: divide space into grid, sample from each cell
# Gives good coverage without O(n*K) cost
perm = torch.randperm(K, device=points.device)[:n]
return perm
def _reconstruct_for_fitting(mu, sigma, alpha, rows, cols, row_chunk=64):
"""Memory-efficient reconstruction with gradients.
Processes rows in chunks to avoid building the full (R*C, N) tensor.
Each chunk computes its contribution independently, so gradient memory
stays bounded regardless of matrix size.
"""
row_coords = torch.arange(rows, dtype=torch.float32)
col_coords = torch.arange(cols, dtype=torch.float32)
chunks = []
for r_start in range(0, rows, row_chunk):
r_end = min(r_start + row_chunk, rows)
chunk_rows = row_coords[r_start:r_end] # (chunk,)
# Distance from each position in this row chunk to each splat
dr = (chunk_rows.unsqueeze(1) - mu[:, 0].unsqueeze(0)) ** 2 # (chunk, N)
dc_all = (col_coords.unsqueeze(1) - mu[:, 1].unsqueeze(0)) ** 2 # (cols, N)
# For each row in chunk, compute full col contribution
var = 2 * sigma.unsqueeze(0) ** 2 + 1e-8 # (1, N)
row_results = []
for ri in range(len(chunk_rows)):
dist_sq = dr[ri:ri+1, :] + dc_all # (cols, N) — broadcast row dist
gaussians = torch.exp(-dist_sq / var) # (cols, N)
row_vals = (gaussians * alpha.unsqueeze(0)).sum(dim=1) # (cols,)
row_results.append(row_vals)
chunks.append(torch.stack(row_results)) # (chunk, cols)
return torch.cat(chunks, dim=0) # (rows, cols)
def fit_splats(
    splats: GaussianSplats,
    target: torch.Tensor,
    config: SplatConfig,
    verbose: bool = False,
) -> Dict[str, float]:
    """Fit splat parameters to reconstruct the target weight matrix.

    Runs Adam on (mu, sigma, alpha) to minimize the mean squared error between
    the splat reconstruction and ``target``. It runs for at most
    ``config.fit_iterations`` steps and stops early once the MSE drops below
    ``config.fit_tolerance``. After every update, sigma is clamped to >= 0.1
    and mu to the matrix bounds. When fitting ends, or if it raises, the
    parameters on ``splats`` are replaced with detached tensors.

    Args:
        splats: splats to fit; their parameter tensors are replaced in place.
        target: dense matrix to approximate, expected to match
            (splats.rows, splats.cols). It is detached, so it never receives
            gradients.
        config: supplies ``fit_iterations``, ``fit_lr`` and ``fit_tolerance``.
        verbose: log progress every 50 steps and on convergence.

    Returns:
        Dict with:
        - ``initial_mse``: MSE before any update.
        - ``final_mse``: MSE of the parameters left on ``splats``.
        - ``improvement``: relative MSE reduction.
        - ``fit_time_s``: elapsed time in seconds.
        - ``steps``: loop iterations executed.
    """
    target = target.detach()
    # Make parameters require grad for fitting
    splats.mu = splats.mu.detach().requires_grad_(True)
    splats.sigma = splats.sigma.detach().requires_grad_(True)
    splats.alpha = splats.alpha.detach().requires_grad_(True)
    optimizer = torch.optim.Adam([splats.mu, splats.sigma, splats.alpha], lr=config.fit_lr)

    # Choose chunk size based on splat count to stay under ~200MB gradient memory
    mem_per_row = splats.n_splats * splats.cols * 4 * 3  # float32 * (fwd + grad + optimizer)
    row_chunk = max(4, min(64, int(200_000_000 / (mem_per_row + 1))))

    def mse_loss() -> torch.Tensor:
        reconstructed = _reconstruct_for_fitting(
            splats.mu, splats.sigma, splats.alpha,
            splats.rows, splats.cols, row_chunk=row_chunk,
        )
        return ((reconstructed - target) ** 2).mean()

    start = time.time()
    initial_mse: Optional[float] = None
    final_mse: Optional[float] = None
    steps = 0
    converged = False
    try:
        for step in range(config.fit_iterations):
            optimizer.zero_grad()
            loss = mse_loss()
            mse = loss.item()  # one host sync per step
            steps = step + 1
            if initial_mse is None:
                initial_mse = mse
            final_mse = mse

            if verbose and step % 50 == 0:
                logger.info(" fit step %d: MSE=%.6f", step, mse)
            if mse < config.fit_tolerance:
                converged = True
                if verbose:
                    logger.info(" converged at step %d", step)
                break

            loss.backward()
            optimizer.step()
            with torch.no_grad():
                # Keep sigma positive
                splats.sigma.clamp_(min=0.1)
                # Keep mu in bounds
                splats.mu[:, 0].clamp_(0, splats.rows - 1)
                splats.mu[:, 1].clamp_(0, splats.cols - 1)

        if not converged:
            # The last optimizer step (or zero iterations) left parameters whose
            # MSE was never measured; report the error of what we actually return.
            with torch.no_grad():
                final_mse = mse_loss().item()
            if initial_mse is None:
                initial_mse = final_mse
    finally:
        # Detach after fitting, even if it failed part-way
        splats.mu = splats.mu.detach()
        splats.sigma = splats.sigma.detach()
        splats.alpha = splats.alpha.detach()

    elapsed = time.time() - start
    return {
        'initial_mse': initial_mse,
        'final_mse': final_mse,
        'improvement': (initial_mse - final_mse) / (initial_mse + 1e-8),
        'fit_time_s': elapsed,
        'steps': steps,
    }
def decompose_layer(
    weight: torch.Tensor,
    config: Optional[SplatConfig] = None,
    verbose: bool = False,
) -> Tuple[GaussianSplats, Dict[str, float]]:
    """Full pipeline: dense weight matrix -> fitted Gaussian splats.

    1. Determine the number of splats from the compression ratio.
    2. Initialize splats: cluster-based if every value is in {-1, 0, 1},
       otherwise random positions with amplitudes sampled from the weights.
    3. Fit splats to the target via gradient descent.
    4. Return the splats and metrics.

    Args:
        weight: weight tensor. Non-2-D inputs are reshaped to
            (weight.shape[0], -1). It is detached first, so fitting never
            backpropagates into the caller's tensor (e.g. an ``nn.Parameter``).
        config: decomposition settings; defaults to ``SplatConfig()``.
        verbose: forwarded to ``fit_splats`` and enables info logging.

    Returns:
        (splats, metrics). ``metrics`` is the ``fit_splats`` result plus
        n_splats, matrix_size, dense_elements, compression_ratio and is_ternary.
    """
    config = config or SplatConfig()
    # Never fit against a grad-tracking tensor: loss.backward() would otherwise
    # accumulate gradients into (and build a graph through) the caller's weights.
    weight = weight.detach()

    if weight.dim() != 2:
        weight = weight.reshape(weight.shape[0], -1)
    rows, cols = weight.shape
    n_splats = compute_n_splats(rows, cols, config)

    # Detect if ternary
    unique_vals = weight.unique()
    is_ternary = len(unique_vals) <= 3 and all(v in [-1, 0, 1] for v in unique_vals.tolist())

    if is_ternary:
        if verbose:
            logger.info(" Ternary weights detected — using cluster initialization")
        splats = initialize_splats_from_ternary(weight, n_splats, config)
    else:
        # General initialization: random positions, amplitudes from weight samples.
        # Indices are drawn on weight's device so mu/alpha/target all share one device.
        splats = GaussianSplats(n_splats, rows, cols, device=weight.device)
        idx_r = torch.randint(rows, (n_splats,), device=weight.device)
        idx_c = torch.randint(cols, (n_splats,), device=weight.device)
        splats.mu = torch.stack([idx_r.float(), idx_c.float()], dim=1)
        splats.alpha = weight[idx_r, idx_c].clone()
        splats.sigma[:] = config.init_sigma

    # Fit to target
    fit_metrics = fit_splats(splats, weight, config, verbose=verbose)
    fit_metrics['n_splats'] = n_splats
    fit_metrics['matrix_size'] = (rows, cols)
    fit_metrics['dense_elements'] = rows * cols
    fit_metrics['compression_ratio'] = splats.compression_ratio()
    fit_metrics['is_ternary'] = is_ternary
    return splats, fit_metrics