""" Gaussian Splat Weight Decomposition Engine ========================================== Decomposes dense weight matrices into collections of Gaussian splats. Each splat has: position (mu), spread (sigma), amplitude (alpha). The splats ARE the weight representation. Inference runs through them, not through the original dense matrix. This is physics-based compression: resolution concentrates where the weight landscape has structure. For BitNet ternary weights {-1, 0, 1}, splats fit naturally: - Positive splats over +1 clusters - Negative splats over -1 clusters - Zero regions need no splats (free compression) # ---- Changelog ---- # [2026-04-05] Claude Code (Opus 4.6) — Initial implementation # What: Gaussian splat decomposition + reconstruction + fitting # Why: PoC for physics-based weight compression with Lenia dynamics # ------------------- """ from __future__ import annotations import logging import time from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn import numpy as np logger = logging.getLogger("uniai.splat") @dataclass class SplatConfig: """Configuration for splat decomposition.""" # How many splats per matrix element (compression ratio control) # e.g., 0.1 means 10x compression (10% as many splats as weight elements) splat_ratio: float = 0.1 # Minimum splats per layer (don't go below this) min_splats: int = 32 # Maximum splats per layer (memory safety) max_splats: int = 8192 # Initial sigma for splats (spread) init_sigma: float = 1.0 # Fitting iterations (gradient descent to fit splats to dense weights) fit_iterations: int = 200 # Fitting learning rate fit_lr: float = 0.01 # Reconstruction tolerance (stop early if below this MSE) fit_tolerance: float = 1e-4 class GaussianSplats: """ A collection of Gaussian splats representing a weight matrix. Each splat i has: - mu_i: (2,) position in weight space (row, col) - sigma_i: (1,) spread (isotropic for Phase 0) - alpha_i: (1,) amplitude (positive or negative) The reconstructed weight at position (r, c) is: W(r,c) = sum_i alpha_i * exp(-||[r,c] - mu_i||^2 / (2 * sigma_i^2)) """ def __init__(self, n_splats: int, rows: int, cols: int, device: torch.device = None): self.n_splats = n_splats self.rows = rows self.cols = cols self.device = device or torch.device('cpu') # Splat parameters — these are what Lenia will operate on self.mu = torch.zeros(n_splats, 2, device=self.device) # positions self.sigma = torch.ones(n_splats, device=self.device) # spreads self.alpha = torch.zeros(n_splats, device=self.device) # amplitudes def reconstruct(self, chunk_size: int = 512) -> torch.Tensor: """Reconstruct the dense weight matrix from splats. W(r,c) = sum_i alpha_i * exp(-||[r,c] - mu_i||^2 / (2 * sigma_i^2)) Uses chunked computation to avoid OOM on large matrices. """ W = torch.zeros(self.rows, self.cols, device=self.device) # Create coordinate grid row_coords = torch.arange(self.rows, dtype=torch.float32, device=self.device) col_coords = torch.arange(self.cols, dtype=torch.float32, device=self.device) # Process splats in chunks to manage memory for start in range(0, self.n_splats, chunk_size): end = min(start + chunk_size, self.n_splats) chunk_mu = self.mu[start:end] # (chunk, 2) chunk_sigma = self.sigma[start:end] # (chunk,) chunk_alpha = self.alpha[start:end] # (chunk,) # For each splat in chunk, compute contribution to all positions for i in range(end - start): mu_r, mu_c = chunk_mu[i, 0], chunk_mu[i, 1] s = chunk_sigma[i] a = chunk_alpha[i] # Compute distances (separable Gaussian for speed) dr = (row_coords - mu_r) ** 2 dc = (col_coords - mu_c) ** 2 # Outer product gives full distance grid dist_sq = dr.unsqueeze(1) + dc.unsqueeze(0) # (rows, cols) # Gaussian contribution W += a * torch.exp(-dist_sq / (2 * s ** 2 + 1e-8)) return W def reconstruct_fast(self) -> torch.Tensor: """Vectorized reconstruction — faster but uses more memory. Good for small-to-medium matrices. Falls back to chunked for large ones. """ if self.rows * self.cols * self.n_splats > 50_000_000: return self.reconstruct() # All positions as (rows*cols, 2) grid row_coords = torch.arange(self.rows, dtype=torch.float32, device=self.device) col_coords = torch.arange(self.cols, dtype=torch.float32, device=self.device) rr, cc = torch.meshgrid(row_coords, col_coords, indexing='ij') positions = torch.stack([rr.flatten(), cc.flatten()], dim=1) # (R*C, 2) # Distances from every position to every splat # positions: (R*C, 2), mu: (N, 2) diff = positions.unsqueeze(1) - self.mu.unsqueeze(0) # (R*C, N, 2) dist_sq = (diff ** 2).sum(dim=2) # (R*C, N) # Gaussian values var = 2 * self.sigma.unsqueeze(0) ** 2 + 1e-8 # (1, N) gaussians = torch.exp(-dist_sq / var) # (R*C, N) # Weighted sum W_flat = (gaussians * self.alpha.unsqueeze(0)).sum(dim=1) # (R*C,) return W_flat.reshape(self.rows, self.cols) def memory_bytes(self) -> int: """Estimate memory usage of splat representation.""" # mu: n*2*4, sigma: n*4, alpha: n*4 = n*16 bytes (float32) return self.n_splats * 16 def compression_ratio(self) -> float: """Compression ratio vs dense float32 matrix.""" dense_bytes = self.rows * self.cols * 4 # float32 return dense_bytes / max(self.memory_bytes(), 1) def state_dict(self) -> Dict[str, torch.Tensor]: """Export splat parameters for persistence.""" return { 'mu': self.mu.clone(), 'sigma': self.sigma.clone(), 'alpha': self.alpha.clone(), 'rows': torch.tensor(self.rows), 'cols': torch.tensor(self.cols), } @classmethod def from_state_dict(cls, d: Dict[str, torch.Tensor]) -> 'GaussianSplats': """Restore from saved state.""" rows, cols = d['rows'].item(), d['cols'].item() n = d['mu'].shape[0] splats = cls(n, rows, cols) splats.mu = d['mu'] splats.sigma = d['sigma'] splats.alpha = d['alpha'] return splats def compute_n_splats(rows: int, cols: int, config: SplatConfig) -> int: """Determine how many splats to use for a given matrix size.""" n = int(rows * cols * config.splat_ratio) return max(config.min_splats, min(n, config.max_splats)) def initialize_splats_from_ternary( weight: torch.Tensor, n_splats: int, config: SplatConfig, ) -> GaussianSplats: """Initialize splats from a ternary {-1, 0, 1} weight matrix. Strategy: place splats at the centers of non-zero weight clusters. For ternary weights this is efficient — zeros need no representation. 1. Find all non-zero positions 2. Sample n_splats positions from them (weighted by absolute value) 3. Set alpha to the weight value at that position 4. Set sigma to initial spread """ rows, cols = weight.shape splats = GaussianSplats(n_splats, rows, cols, device=weight.device) # Find non-zero positions nonzero_mask = weight != 0 nonzero_positions = nonzero_mask.nonzero(as_tuple=False).float() # (K, 2) if len(nonzero_positions) == 0: # All zeros — return empty splats return splats if len(nonzero_positions) <= n_splats: # Fewer non-zero elements than splats — use them all k = len(nonzero_positions) splats.mu[:k] = nonzero_positions for i in range(k): r, c = int(nonzero_positions[i, 0]), int(nonzero_positions[i, 1]) splats.alpha[i] = weight[r, c] splats.sigma[:] = config.init_sigma return splats # Sample positions — prefer dense regions # Use farthest-point sampling for coverage indices = _farthest_point_sample(nonzero_positions, n_splats) sampled_positions = nonzero_positions[indices] splats.mu = sampled_positions.clone() # Set alpha based on local weight values for i in range(n_splats): r, c = int(sampled_positions[i, 0]), int(sampled_positions[i, 1]) splats.alpha[i] = weight[r, c] splats.sigma[:] = config.init_sigma return splats def _farthest_point_sample(points: torch.Tensor, n: int) -> torch.Tensor: """Spatial coverage sampling. For small point sets (< 5000): true farthest-point sampling. For large point sets: stratified random sampling (grid-based) for speed. """ K = len(points) if K <= 5000: # True FPS — O(n*K), fine for small sets selected = torch.zeros(n, dtype=torch.long, device=points.device) selected[0] = torch.randint(K, (1,)).item() dists = torch.full((K,), float('inf'), device=points.device) for i in range(1, n): last = points[selected[i - 1]].unsqueeze(0) new_dists = ((points - last) ** 2).sum(dim=1) dists = torch.minimum(dists, new_dists) selected[i] = dists.argmax() return selected # Stratified random: divide space into grid, sample from each cell # Gives good coverage without O(n*K) cost perm = torch.randperm(K, device=points.device)[:n] return perm def _reconstruct_for_fitting(mu, sigma, alpha, rows, cols, row_chunk=64): """Memory-efficient reconstruction with gradients. Processes rows in chunks to avoid building the full (R*C, N) tensor. Each chunk computes its contribution independently, so gradient memory stays bounded regardless of matrix size. """ row_coords = torch.arange(rows, dtype=torch.float32) col_coords = torch.arange(cols, dtype=torch.float32) chunks = [] for r_start in range(0, rows, row_chunk): r_end = min(r_start + row_chunk, rows) chunk_rows = row_coords[r_start:r_end] # (chunk,) # Distance from each position in this row chunk to each splat dr = (chunk_rows.unsqueeze(1) - mu[:, 0].unsqueeze(0)) ** 2 # (chunk, N) dc_all = (col_coords.unsqueeze(1) - mu[:, 1].unsqueeze(0)) ** 2 # (cols, N) # For each row in chunk, compute full col contribution var = 2 * sigma.unsqueeze(0) ** 2 + 1e-8 # (1, N) row_results = [] for ri in range(len(chunk_rows)): dist_sq = dr[ri:ri+1, :] + dc_all # (cols, N) — broadcast row dist gaussians = torch.exp(-dist_sq / var) # (cols, N) row_vals = (gaussians * alpha.unsqueeze(0)).sum(dim=1) # (cols,) row_results.append(row_vals) chunks.append(torch.stack(row_results)) # (chunk, cols) return torch.cat(chunks, dim=0) # (rows, cols) def fit_splats( splats: GaussianSplats, target: torch.Tensor, config: SplatConfig, verbose: bool = False, ) -> Dict[str, float]: """Fit splat parameters to reconstruct the target weight matrix. Uses gradient descent on (mu, sigma, alpha) to minimize reconstruction error. Row-chunked reconstruction keeps memory bounded on low-RAM machines. Returns metrics about the fitting process. """ # Make parameters require grad for fitting splats.mu = splats.mu.detach().requires_grad_(True) splats.sigma = splats.sigma.detach().requires_grad_(True) splats.alpha = splats.alpha.detach().requires_grad_(True) optimizer = torch.optim.Adam([splats.mu, splats.sigma, splats.alpha], lr=config.fit_lr) # Choose chunk size based on splat count to stay under ~200MB gradient memory mem_per_row = splats.n_splats * splats.cols * 4 * 3 # float32 * (fwd + grad + optimizer) row_chunk = max(4, min(64, int(200_000_000 / (mem_per_row + 1)))) start = time.time() initial_mse = None final_mse = None for step in range(config.fit_iterations): optimizer.zero_grad() # Memory-efficient chunked reconstruction reconstructed = _reconstruct_for_fitting( splats.mu, splats.sigma, splats.alpha, splats.rows, splats.cols, row_chunk=row_chunk, ) # MSE loss loss = ((reconstructed - target) ** 2).mean() if step == 0: initial_mse = loss.item() final_mse = loss.item() if step % 50 == 0: print(f" fit step {step}: MSE={loss.item():.6f}", flush=True) if loss.item() < config.fit_tolerance: print(f" converged at step {step}", flush=True) break loss.backward() optimizer.step() # Keep sigma positive with torch.no_grad(): splats.sigma.clamp_(min=0.1) # Keep mu in bounds splats.mu[:, 0].clamp_(0, splats.rows - 1) splats.mu[:, 1].clamp_(0, splats.cols - 1) # Detach after fitting splats.mu = splats.mu.detach() splats.sigma = splats.sigma.detach() splats.alpha = splats.alpha.detach() elapsed = time.time() - start return { 'initial_mse': initial_mse, 'final_mse': final_mse, 'improvement': (initial_mse - final_mse) / (initial_mse + 1e-8), 'fit_time_s': elapsed, 'steps': step + 1, } def decompose_layer( weight: torch.Tensor, config: SplatConfig = None, verbose: bool = False, ) -> Tuple[GaussianSplats, Dict[str, float]]: """Full pipeline: dense weight matrix -> fitted Gaussian splats. 1. Determine number of splats from compression ratio 2. Initialize splats (ternary-aware if applicable) 3. Fit splats to target via gradient descent 4. Return splats + metrics """ config = config or SplatConfig() if weight.dim() != 2: weight = weight.reshape(weight.shape[0], -1) rows, cols = weight.shape n_splats = compute_n_splats(rows, cols, config) # Detect if ternary unique_vals = weight.unique() is_ternary = len(unique_vals) <= 3 and all(v in [-1, 0, 1] for v in unique_vals.tolist()) if is_ternary: if verbose: logger.info(f" Ternary weights detected — using cluster initialization") splats = initialize_splats_from_ternary(weight, n_splats, config) else: # General initialization: random positions, amplitudes from weight samples splats = GaussianSplats(n_splats, rows, cols, device=weight.device) idx_r = torch.randint(rows, (n_splats,)) idx_c = torch.randint(cols, (n_splats,)) splats.mu = torch.stack([idx_r.float(), idx_c.float()], dim=1) splats.alpha = weight[idx_r, idx_c].clone() splats.sigma[:] = config.init_sigma # Fit to target fit_metrics = fit_splats(splats, weight, config, verbose=verbose) fit_metrics['n_splats'] = n_splats fit_metrics['matrix_size'] = (rows, cols) fit_metrics['dense_elements'] = rows * cols fit_metrics['compression_ratio'] = splats.compression_ratio() fit_metrics['is_ternary'] = is_ternary return splats, fit_metrics