Spaces:
Sleeping
Sleeping
| """ | |
| Gaussian Splat Weight Decomposition Engine | |
| ========================================== | |
| Decomposes dense weight matrices into collections of Gaussian splats. | |
| Each splat has: position (mu), spread (sigma), amplitude (alpha). | |
| The splats ARE the weight representation. Inference runs through them, | |
| not through the original dense matrix. This is physics-based compression: | |
| resolution concentrates where the weight landscape has structure. | |
| For BitNet ternary weights {-1, 0, 1}, splats fit naturally: | |
| - Positive splats over +1 clusters | |
| - Negative splats over -1 clusters | |
| - Zero regions need no splats (free compression) | |
| # ---- Changelog ---- | |
| # [2026-04-05] Claude Code (Opus 4.6) — Initial implementation | |
| # What: Gaussian splat decomposition + reconstruction + fitting | |
| # Why: PoC for physics-based weight compression with Lenia dynamics | |
| # ------------------- | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import time | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
# Module-level logger; every record from this module goes to "uniai.splat".
logger: logging.Logger = logging.getLogger("uniai.splat")
@dataclass
class SplatConfig:
    """Configuration for splat decomposition.

    A dataclass, so individual fields can be overridden at construction,
    e.g. ``SplatConfig(splat_ratio=0.05, fit_iterations=500)``.
    """

    # Splats per matrix element (splat-count control).
    # e.g. 0.1 means 10% as many splats as weight elements. Each splat costs
    # 16 bytes (mu 2x4 + sigma 4 + alpha 4) vs 4 bytes per dense float32
    # element, so 0.1 yields ~2.5x compression, not 10x.
    splat_ratio: float = 0.1
    # Minimum splats per layer (don't go below this)
    min_splats: int = 32
    # Maximum splats per layer (memory safety)
    max_splats: int = 8192
    # Initial sigma (spread) for splats, in row/col index units
    init_sigma: float = 1.0
    # Fitting iterations (gradient descent to fit splats to dense weights)
    fit_iterations: int = 200
    # Fitting learning rate (Adam)
    fit_lr: float = 0.01
    # Reconstruction tolerance (stop early once MSE drops below this)
    fit_tolerance: float = 1e-4
class GaussianSplats:
    """
    A collection of Gaussian splats representing a weight matrix.

    Each splat i has:
      - mu_i:    (2,) position in weight space (row, col)
      - sigma_i: scalar spread (isotropic for Phase 0)
      - alpha_i: scalar amplitude (positive or negative)

    The reconstructed weight at position (r, c) is:
      W(r,c) = sum_i alpha_i * exp(-||[r,c] - mu_i||^2 / (2 * sigma_i^2))

    Parameters are stored as tensors ``mu`` (N, 2), ``sigma`` (N,) and
    ``alpha`` (N,), created as float32 on ``device``.
    """

    def __init__(self, n_splats: int, rows: int, cols: int, device: Optional[torch.device] = None):
        self.n_splats = n_splats
        self.rows = rows
        self.cols = cols
        self.device = device if device is not None else torch.device('cpu')
        # Splat parameters — these are what Lenia will operate on
        self.mu = torch.zeros(n_splats, 2, device=self.device)  # positions (row, col)
        self.sigma = torch.ones(n_splats, device=self.device)   # spreads
        self.alpha = torch.zeros(n_splats, device=self.device)  # amplitudes

    def reconstruct(self, chunk_size: int = 512) -> torch.Tensor:
        """Reconstruct the dense (rows, cols) weight matrix from splats.

        W(r,c) = sum_i alpha_i * exp(-||[r,c] - mu_i||^2 / (2 * sigma_i^2))

        Splats are processed ``chunk_size`` at a time. The isotropic Gaussian
        is separable, exp(-(dr + dc)/v) = exp(-dr/v) * exp(-dc/v), so each
        chunk contributes ``(G_rows * alpha) @ G_cols^T``. Here G_rows is
        (rows, chunk) and G_cols is (cols, chunk). Peak extra memory is
        O((rows + cols) * chunk_size) instead of a per-splat Python loop.
        """
        W = torch.zeros(self.rows, self.cols, device=self.device)
        row_coords = torch.arange(self.rows, dtype=torch.float32, device=self.device)
        col_coords = torch.arange(self.cols, dtype=torch.float32, device=self.device)

        for start in range(0, self.n_splats, chunk_size):
            end = min(start + chunk_size, self.n_splats)
            mu = self.mu[start:end]                         # (chunk, 2)
            var = 2 * self.sigma[start:end] ** 2 + 1e-8     # (chunk,)
            alpha = self.alpha[start:end]                   # (chunk,)
            g_rows = torch.exp(-(row_coords.unsqueeze(1) - mu[:, 0].unsqueeze(0)) ** 2 / var)  # (rows, chunk)
            g_cols = torch.exp(-(col_coords.unsqueeze(1) - mu[:, 1].unsqueeze(0)) ** 2 / var)  # (cols, chunk)
            W += (g_rows * alpha) @ g_cols.T
        return W

    def reconstruct_fast(self) -> torch.Tensor:
        """Fully vectorized reconstruction — builds an (R*C, N) tensor.

        Good for small-to-medium matrices. Falls back to ``reconstruct`` when
        rows * cols * n_splats exceeds 50M elements.
        """
        if self.rows * self.cols * self.n_splats > 50_000_000:
            return self.reconstruct()

        # All positions as an (R*C, 2) grid
        row_coords = torch.arange(self.rows, dtype=torch.float32, device=self.device)
        col_coords = torch.arange(self.cols, dtype=torch.float32, device=self.device)
        rr, cc = torch.meshgrid(row_coords, col_coords, indexing='ij')
        positions = torch.stack([rr.flatten(), cc.flatten()], dim=1)  # (R*C, 2)

        # Squared distance from every position to every splat
        diff = positions.unsqueeze(1) - self.mu.unsqueeze(0)  # (R*C, N, 2)
        dist_sq = (diff ** 2).sum(dim=2)                       # (R*C, N)

        var = 2 * self.sigma.unsqueeze(0) ** 2 + 1e-8  # (1, N)
        gaussians = torch.exp(-dist_sq / var)           # (R*C, N)

        W_flat = (gaussians * self.alpha.unsqueeze(0)).sum(dim=1)  # (R*C,)
        return W_flat.reshape(self.rows, self.cols)

    def memory_bytes(self) -> int:
        """Estimate memory usage of the splat representation in bytes."""
        # mu: n*2*4, sigma: n*4, alpha: n*4 = n*16 bytes (float32)
        return self.n_splats * 16

    def compression_ratio(self) -> float:
        """Compression ratio vs a dense float32 matrix of the same shape."""
        dense_bytes = self.rows * self.cols * 4  # float32
        return dense_bytes / max(self.memory_bytes(), 1)

    def state_dict(self) -> Dict[str, torch.Tensor]:
        """Export splat parameters (cloned) plus matrix shape for persistence."""
        return {
            'mu': self.mu.clone(),
            'sigma': self.sigma.clone(),
            'alpha': self.alpha.clone(),
            'rows': torch.tensor(self.rows),
            'cols': torch.tensor(self.cols),
        }

    @classmethod
    def from_state_dict(cls, d: Dict[str, torch.Tensor]) -> 'GaussianSplats':
        """Restore splats from a dict produced by ``state_dict``.

        The restored object's ``device`` is taken from the saved ``mu`` tensor,
        so ``reconstruct`` allocates on the same device as the parameters.
        """
        rows, cols = int(d['rows'].item()), int(d['cols'].item())
        mu = d['mu']
        splats = cls(mu.shape[0], rows, cols, device=mu.device)
        splats.mu = mu
        splats.sigma = d['sigma']
        splats.alpha = d['alpha']
        return splats
def compute_n_splats(rows: int, cols: int, config: SplatConfig) -> int:
    """Return the splat budget for a ``rows x cols`` matrix.

    The budget is ``int(rows * cols * config.splat_ratio)``. It is first capped
    at ``config.max_splats`` and then raised to at least ``config.min_splats``,
    so ``min_splats`` wins if the two bounds conflict.
    """
    requested = int(rows * cols * config.splat_ratio)
    capped = min(requested, config.max_splats)
    return capped if capped > config.min_splats else config.min_splats
def initialize_splats_from_ternary(
    weight: torch.Tensor,
    n_splats: int,
    config: SplatConfig,
) -> GaussianSplats:
    """Initialize splats from a ternary {-1, 0, 1} weight matrix.

    Strategy: place splats on non-zero weights; zeros need no representation.

    1. Find all non-zero positions.
    2. If there are at most ``n_splats`` of them, use every one. Unused splats
       stay at the origin with zero amplitude. Otherwise choose ``n_splats``
       of them with ``_farthest_point_sample`` for spatial coverage.
    3. Set each used splat's alpha to the weight value at its position.
    4. Set every sigma to ``config.init_sigma``.

    Args:
        weight: 2-D weight matrix; its shape gives (rows, cols).
        n_splats: number of splats to allocate.
        config: supplies ``init_sigma``.

    Returns:
        A ``GaussianSplats`` on ``weight.device``.
    """
    rows, cols = weight.shape
    splats = GaussianSplats(n_splats, rows, cols, device=weight.device)
    splats.sigma[:] = config.init_sigma

    # Read weight values without autograd history (weight may be a Parameter).
    values = weight.detach()
    nonzero_idx = (values != 0).nonzero(as_tuple=False)  # (K, 2) integer indices

    if nonzero_idx.shape[0] <= n_splats:
        # Fewer non-zero elements than splats (possibly none): use them all.
        chosen = nonzero_idx
    else:
        # Spatial-coverage subset of the non-zero positions.
        picked = _farthest_point_sample(nonzero_idx.float(), n_splats)
        chosen = nonzero_idx[picked]

    m = chosen.shape[0]
    splats.mu[:m] = chosen.to(splats.mu.dtype)
    splats.alpha[:m] = values[chosen[:, 0], chosen[:, 1]]
    return splats
| def _farthest_point_sample(points: torch.Tensor, n: int) -> torch.Tensor: | |
| """Spatial coverage sampling. | |
| For small point sets (< 5000): true farthest-point sampling. | |
| For large point sets: stratified random sampling (grid-based) for speed. | |
| """ | |
| K = len(points) | |
| if K <= 5000: | |
| # True FPS — O(n*K), fine for small sets | |
| selected = torch.zeros(n, dtype=torch.long, device=points.device) | |
| selected[0] = torch.randint(K, (1,)).item() | |
| dists = torch.full((K,), float('inf'), device=points.device) | |
| for i in range(1, n): | |
| last = points[selected[i - 1]].unsqueeze(0) | |
| new_dists = ((points - last) ** 2).sum(dim=1) | |
| dists = torch.minimum(dists, new_dists) | |
| selected[i] = dists.argmax() | |
| return selected | |
| # Stratified random: divide space into grid, sample from each cell | |
| # Gives good coverage without O(n*K) cost | |
| perm = torch.randperm(K, device=points.device)[:n] | |
| return perm | |
| def _reconstruct_for_fitting(mu, sigma, alpha, rows, cols, row_chunk=64): | |
| """Memory-efficient reconstruction with gradients. | |
| Processes rows in chunks to avoid building the full (R*C, N) tensor. | |
| Each chunk computes its contribution independently, so gradient memory | |
| stays bounded regardless of matrix size. | |
| """ | |
| row_coords = torch.arange(rows, dtype=torch.float32) | |
| col_coords = torch.arange(cols, dtype=torch.float32) | |
| chunks = [] | |
| for r_start in range(0, rows, row_chunk): | |
| r_end = min(r_start + row_chunk, rows) | |
| chunk_rows = row_coords[r_start:r_end] # (chunk,) | |
| # Distance from each position in this row chunk to each splat | |
| dr = (chunk_rows.unsqueeze(1) - mu[:, 0].unsqueeze(0)) ** 2 # (chunk, N) | |
| dc_all = (col_coords.unsqueeze(1) - mu[:, 1].unsqueeze(0)) ** 2 # (cols, N) | |
| # For each row in chunk, compute full col contribution | |
| var = 2 * sigma.unsqueeze(0) ** 2 + 1e-8 # (1, N) | |
| row_results = [] | |
| for ri in range(len(chunk_rows)): | |
| dist_sq = dr[ri:ri+1, :] + dc_all # (cols, N) — broadcast row dist | |
| gaussians = torch.exp(-dist_sq / var) # (cols, N) | |
| row_vals = (gaussians * alpha.unsqueeze(0)).sum(dim=1) # (cols,) | |
| row_results.append(row_vals) | |
| chunks.append(torch.stack(row_results)) # (chunk, cols) | |
| return torch.cat(chunks, dim=0) # (rows, cols) | |
def fit_splats(
    splats: GaussianSplats,
    target: torch.Tensor,
    config: SplatConfig,
    verbose: bool = False,
) -> Dict[str, float]:
    """Fit splat parameters (mu, sigma, alpha) to reconstruct ``target``.

    Runs Adam on the MSE between ``_reconstruct_for_fitting(...)`` and
    ``target`` for up to ``config.fit_iterations`` steps. Fitting stops early
    once the MSE drops below ``config.fit_tolerance``. After each step, sigma
    is clamped to >= 0.1 and mu to the matrix bounds.

    ``splats`` is updated in place. Its parameters are always left detached
    (no grad) on return, even if an exception interrupts fitting.

    Args:
        splats: splats to fit (modified in place).
        target: (rows, cols) dense matrix to approximate. It is detached and
            cast to the splats' device/dtype, so a live ``nn.Parameter`` never
            receives gradients from the fit.
        config: supplies ``fit_iterations``, ``fit_lr`` and ``fit_tolerance``.
        verbose: if True, log progress every 50 steps and on convergence via
            the module logger.

    Returns:
        Dict with these keys:
        - ``initial_mse``: MSE before fitting.
        - ``final_mse``: MSE of the parameters actually left in ``splats``.
        - ``improvement``: relative MSE reduction.
        - ``fit_time_s``: wall-clock fitting time.
        - ``steps``: number of loss evaluations in the loop.
    """
    # A target that requires grad would otherwise accumulate .grad during backward().
    target = target.detach().to(device=splats.mu.device, dtype=splats.mu.dtype)

    # Make parameters require grad for fitting
    splats.mu = splats.mu.detach().requires_grad_(True)
    splats.sigma = splats.sigma.detach().requires_grad_(True)
    splats.alpha = splats.alpha.detach().requires_grad_(True)
    optimizer = torch.optim.Adam([splats.mu, splats.sigma, splats.alpha], lr=config.fit_lr)

    # Row-chunk heuristic targeting ~200MB: float32 * (fwd + grad + optimizer)
    mem_per_row = splats.n_splats * splats.cols * 4 * 3
    row_chunk = max(4, min(64, int(200_000_000 / (mem_per_row + 1))))

    start = time.time()
    initial_mse: Optional[float] = None
    final_mse: Optional[float] = None
    steps = 0
    # True when optimizer.step() ran after the last recorded loss, i.e. the
    # recorded MSE no longer describes the current parameters.
    params_moved = False
    try:
        for step in range(config.fit_iterations):
            optimizer.zero_grad()
            reconstructed = _reconstruct_for_fitting(
                splats.mu, splats.sigma, splats.alpha,
                splats.rows, splats.cols, row_chunk=row_chunk,
            )
            loss = ((reconstructed - target) ** 2).mean()
            loss_val = loss.item()  # single host sync per step
            steps = step + 1
            if initial_mse is None:
                initial_mse = loss_val
            final_mse = loss_val
            params_moved = False

            if verbose and step % 50 == 0:
                logger.info("  fit step %d: MSE=%.6f", step, loss_val)
            if loss_val < config.fit_tolerance:
                if verbose:
                    logger.info("  converged at step %d", step)
                break

            loss.backward()
            optimizer.step()
            params_moved = True
            with torch.no_grad():
                # Keep sigma positive
                splats.sigma.clamp_(min=0.1)
                # Keep mu in bounds
                splats.mu[:, 0].clamp_(0, splats.rows - 1)
                splats.mu[:, 1].clamp_(0, splats.cols - 1)
    finally:
        # Detach after fitting (also on error) so splats are plain tensors again
        splats.mu = splats.mu.detach()
        splats.sigma = splats.sigma.detach()
        splats.alpha = splats.alpha.detach()

    # Report the MSE of the parameters actually returned. This covers the
    # last optimizer step and the fit_iterations == 0 case.
    if final_mse is None or params_moved:
        with torch.no_grad():
            reconstructed = _reconstruct_for_fitting(
                splats.mu, splats.sigma, splats.alpha,
                splats.rows, splats.cols, row_chunk=row_chunk,
            )
            final_mse = ((reconstructed - target) ** 2).mean().item()
        if initial_mse is None:
            initial_mse = final_mse

    elapsed = time.time() - start
    return {
        'initial_mse': initial_mse,
        'final_mse': final_mse,
        'improvement': (initial_mse - final_mse) / (initial_mse + 1e-8),
        'fit_time_s': elapsed,
        'steps': steps,
    }
def decompose_layer(
    weight: torch.Tensor,
    config: SplatConfig = None,
    verbose: bool = False,
) -> Tuple[GaussianSplats, Dict[str, float]]:
    """Full pipeline: dense weight matrix -> fitted Gaussian splats.

    1. Flatten to 2-D via ``weight.reshape(weight.shape[0], -1)`` if needed.
    2. Determine the number of splats with ``compute_n_splats``.
    3. Initialize splats:
       - If every value is in {-1, 0, 1}, use cluster initialization.
       - Otherwise use random positions, with amplitudes copied from the
         weights at those positions.
    4. Fit splats to the weights with ``fit_splats``.

    The input is detached, and it is cast to float32 after ternary detection,
    for two reasons. First, a live ``nn.Parameter`` gets no gradients or
    autograd history from the fit. Second, integer or half-precision weights
    still yield float32 splat parameters that Adam can optimize.

    Returns:
        (splats, metrics). ``metrics`` holds the ``fit_splats`` results plus
        ``n_splats``, ``matrix_size``, ``dense_elements``,
        ``compression_ratio`` and ``is_ternary``.
    """
    if config is None:
        config = SplatConfig()

    weight = weight.detach()
    if weight.dim() != 2:
        weight = weight.reshape(weight.shape[0], -1)
    rows, cols = weight.shape
    n_splats = compute_n_splats(rows, cols, config)

    # Detect ternary on the original values (before any dtype cast).
    unique_vals = weight.unique()
    is_ternary = len(unique_vals) <= 3 and all(v in (-1, 0, 1) for v in unique_vals.tolist())

    weight = weight.to(torch.float32)

    if is_ternary:
        if verbose:
            logger.info("  Ternary weights detected — using cluster initialization")
        splats = initialize_splats_from_ternary(weight, n_splats, config)
    else:
        # General initialization: random positions, amplitudes from weight samples.
        # Indices are drawn on weight.device so mu lives with sigma/alpha.
        splats = GaussianSplats(n_splats, rows, cols, device=weight.device)
        idx_r = torch.randint(rows, (n_splats,), device=weight.device)
        idx_c = torch.randint(cols, (n_splats,), device=weight.device)
        splats.mu = torch.stack([idx_r.float(), idx_c.float()], dim=1)
        splats.alpha = weight[idx_r, idx_c].clone()
        splats.sigma[:] = config.init_sigma

    # Fit to target
    fit_metrics = fit_splats(splats, weight, config, verbose=verbose)

    fit_metrics['n_splats'] = n_splats
    fit_metrics['matrix_size'] = (rows, cols)
    fit_metrics['dense_elements'] = rows * cols
    fit_metrics['compression_ratio'] = splats.compression_ratio()
    fit_metrics['is_ternary'] = is_ternary
    return splats, fit_metrics