""" WrinkleBrane Optimizations Module Advanced optimizations for improved performance and fidelity. """ from __future__ import annotations import math from typing import Optional, Tuple import torch import numpy as np from scipy.linalg import qr from .codes import normalize_columns, hadamard_codes, dct_codes from .write_ops import store_pairs def compute_adaptive_alphas( patterns: torch.Tensor, C: torch.Tensor, keys: torch.Tensor, energy_weight: float = 1.0, orthogonality_weight: float = 0.5, min_alpha: float = 0.1, max_alpha: float = 3.0 ) -> torch.Tensor: """Compute optimal alpha values for each pattern based on energy and orthogonality. Parameters ---------- patterns : torch.Tensor Input patterns with shape [T, H, W] C : torch.Tensor Codebook tensor with shape [L, K] keys : torch.Tensor Key indices with shape [T] energy_weight : float Weight for energy normalization (default: 1.0) orthogonality_weight : float Weight for orthogonality compensation (default: 0.5) min_alpha : float Minimum alpha value (default: 0.1) max_alpha : float Maximum alpha value (default: 3.0) Returns ------- torch.Tensor Optimized alpha values with shape [T] """ T = len(patterns) device = patterns.device dtype = patterns.dtype alphas = torch.ones(T, device=device, dtype=dtype) for i, key in enumerate(keys): # 1. Energy normalization - normalize by pattern RMS pattern_rms = torch.sqrt(torch.mean(patterns[i] ** 2)).clamp_min(1e-6) reference_rms = 0.5 # Target RMS level energy_factor = reference_rms / pattern_rms if energy_weight > 0 else 1.0 # 2. Orthogonality compensation - less orthogonal codes need smaller alphas if orthogonality_weight > 0 and key < C.shape[1]: code_vec = C[:, key] # Compute maximum correlation with other codes other_codes = torch.cat([C[:, :key], C[:, key+1:]], dim=1) if C.shape[1] > 1 else C[:, :0] if other_codes.numel() > 0: correlations = torch.abs(code_vec @ other_codes) max_correlation = correlations.max() if correlations.numel() > 0 else torch.tensor(0.0) orthogonality_factor = 1.0 / (1.0 + orthogonality_weight * max_correlation) else: orthogonality_factor = 1.0 else: orthogonality_factor = 1.0 # Combine factors alpha = energy_factor * orthogonality_factor alphas[i] = torch.clamp(alpha, min_alpha, max_alpha) return alphas def generate_extended_codes( L: int, K: int, method: str = "auto", existing_patterns: Optional[torch.Tensor] = None, device: Optional[torch.device] = None ) -> torch.Tensor: """Generate orthogonal codes with support for K > L. 
def generate_extended_codes(
    L: int,
    K: int,
    method: str = "auto",
    existing_patterns: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Generate orthogonal codes with support for K > L.

    Parameters
    ----------
    L : int
        Code length (number of layers).
    K : int
        Number of codes needed.
    method : str
        Code generation method: "auto", "hadamard", "dct", "gram_schmidt",
        or "random_ortho".
    existing_patterns : torch.Tensor, optional
        Existing patterns to optimize codes for, shape [K, H, W].
    device : torch.device, optional
        Device to place tensors on.

    Returns
    -------
    torch.Tensor
        Code matrix with shape [L, K].
    """
    if device is None:
        device = torch.device("cpu")

    # For K <= L, exactly orthogonal codes are available.
    if K <= L:
        if method == "auto":
            # Hadamard codes need K within the largest power of two <= L;
            # otherwise fall back to DCT codes.
            if K <= 2 ** int(math.log2(L)):
                return hadamard_codes(L, K).to(device)
            return dct_codes(L, K).to(device)
        elif method == "hadamard":
            return hadamard_codes(L, K).to(device)
        elif method == "dct":
            return dct_codes(L, K).to(device)

    # For K > L we need an overcomplete code set.
    if method == "gram_schmidt" or (method == "auto" and K > L):
        return generate_gram_schmidt_codes(L, K, existing_patterns, device)
    elif method == "random_ortho":
        return generate_random_orthogonal_codes(L, K, device)

    # Fallback: random normalized codes for K <= L, hybrid approach otherwise.
    if K <= L:
        torch.manual_seed(42)  # reproducible
        C = torch.randn(L, K, device=device)
        return normalize_columns(C)
    return generate_gram_schmidt_codes(L, K, existing_patterns, device)


def generate_gram_schmidt_codes(
    L: int,
    K: int,
    existing_patterns: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Generate codes using Gram-Schmidt orthogonalization.

    This method can produce more codes than layers (K > L) by extending an
    orthonormal basis with extra directions. The extra columns cannot be
    mutually orthogonal, so when `existing_patterns` is given they are chosen
    to preserve pattern information.
    """
    if device is None:
        device = torch.device("cpu")

    # Start from the best available orthogonal codes up to L.
    if K <= L:
        base_codes = hadamard_codes(L, min(K, L)).to(device)
        if K == base_codes.shape[1]:
            return base_codes
    else:
        base_codes = hadamard_codes(L, L).to(device)

    # If more codes are needed, append pattern-informed or random vectors.
    if K > base_codes.shape[1]:
        additional_needed = K - base_codes.shape[1]

        if existing_patterns is not None:
            # Use principal directions of the patterns as extra code vectors.
            patterns_flat = existing_patterns.view(existing_patterns.shape[0], -1)
            U, _, _ = torch.linalg.svd(patterns_flat.T, full_matrices=False)
            additional_vectors = U[:L, :additional_needed].to(device)
        else:
            torch.manual_seed(42)
            additional_vectors = torch.randn(L, additional_needed, device=device)

        all_vectors = torch.cat([base_codes, additional_vectors], dim=1)
    else:
        all_vectors = base_codes

    # Orthogonalize via QR for numerical stability. QR yields at most L
    # orthonormal columns in an L-dimensional space, so for K > L the
    # remaining columns are merely normalized.
    n_ortho = min(L, all_vectors.shape[1])
    Q, R = torch.linalg.qr(all_vectors[:, :n_ortho])

    # Ensure a positive diagonal of R for sign consistency.
    signs = torch.sign(torch.diagonal(R))
    signs[signs == 0] = 1
    Q = Q * signs.unsqueeze(0)

    if all_vectors.shape[1] > n_ortho:
        extra = normalize_columns(all_vectors[:, n_ortho:])
        Q = torch.cat([Q, extra], dim=1)

    return normalize_columns(Q[:, :K])


def generate_random_orthogonal_codes(
    L: int, K: int, device: Optional[torch.device] = None
) -> torch.Tensor:
    """Generate random orthogonal codes using QR decomposition."""
    if device is None:
        device = torch.device("cpu")

    torch.manual_seed(42)  # reproducible

    if K <= L:
        # Orthogonalize a random matrix.
        A = torch.randn(L, K, device=device)
        Q, _ = torch.linalg.qr(A)
        return normalize_columns(Q)

    # For K > L, generate L orthogonal vectors, then append random ones
    # (which cannot be orthogonal to the first L).
    A = torch.randn(L, L, device=device)
    Q, _ = torch.linalg.qr(A)
    additional = torch.randn(L, K - L, device=device)
    all_codes = torch.cat([Q, additional], dim=1)
    return normalize_columns(all_codes)
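
# Illustrative sketch: code quality degrades gracefully past K = L. With
# K <= L the Gram matrix C^T C is (near) identity; with K > L off-diagonal
# crosstalk terms are unavoidable. The sizes here are arbitrary demo choices.
def _demo_code_coherence() -> tuple[torch.Tensor, torch.Tensor]:
    L = 32
    C_exact = generate_extended_codes(L, 32)  # K == L: orthonormal basis
    C_over = generate_extended_codes(L, 48)   # K > L: overcomplete set

    def max_crosstalk(C: torch.Tensor) -> torch.Tensor:
        # Largest deviation of the Gram matrix from the identity.
        G = C.T @ C
        return (G - torch.eye(G.shape[0])).abs().max()

    # Expect ~0 for C_exact and a clearly nonzero value for C_over.
    return max_crosstalk(C_exact), max_crosstalk(C_over)
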
class HierarchicalMembraneBank:
    """Multi-level membrane bank for better capacity scaling."""

    def __init__(
        self,
        L: int,
        H: int,
        W: int,
        levels: int = 3,
        level_ratios: Optional[list[float]] = None,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ):
        """Initialize hierarchical membrane bank.

        Parameters
        ----------
        L, H, W : int
            Base dimensions.
        levels : int
            Number of hierarchy levels.
        level_ratios : list[float], optional
            Fraction of L allocated to each level. If None, uses a geometric
            series.
        device : torch.device, optional
            Device for tensors.
        dtype : torch.dtype
            Data type for tensors.
        """
        self.levels = levels
        self.H = H
        self.W = W
        self.device = device
        self.dtype = dtype

        if level_ratios is None:
            # Geometric series 1/2, 1/4, 1/8, ..., normalized to sum to 1.
            level_ratios = [0.5 ** (i + 1) for i in range(levels)]
            level_ratios = [r / sum(level_ratios) for r in level_ratios]
        self.level_ratios = level_ratios

        # Create a membrane bank for each level.
        from .membrane_bank import MembraneBank

        self.banks = []
        remaining_L = L
        for i, ratio in enumerate(level_ratios):
            level_L = max(1, int(L * ratio))
            if i == levels - 1:  # last level gets whatever remains
                level_L = max(1, remaining_L)
            bank = MembraneBank(level_L, H, W, device=device, dtype=dtype)
            self.banks.append(bank)
            remaining_L -= level_L

        self.total_L = sum(bank.L for bank in self.banks)

    def allocate(self, B: int) -> None:
        """Allocate all level banks for batch size B."""
        for bank in self.banks:
            bank.allocate(B)

    def store_hierarchical(
        self,
        patterns: torch.Tensor,
        keys: torch.Tensor,
        level_assignment: Optional[torch.Tensor] = None,
    ) -> None:
        """Store patterns in their assigned hierarchy levels.

        Parameters
        ----------
        patterns : torch.Tensor
            Patterns to store, shape [T, H, W].
        keys : torch.Tensor
            Key indices, shape [T].
        level_assignment : torch.Tensor, optional
            Level assignment for each pattern. If None, patterns are
            auto-assigned based on their spectral complexity.
        """
        if level_assignment is None:
            level_assignment = self._auto_assign_levels(patterns)

        # Store patterns level by level.
        for level in range(self.levels):
            level_mask = level_assignment == level
            if level_mask.any():
                level_patterns = patterns[level_mask]
                level_keys = keys[level_mask]

                # Generate codes sized to the largest key used at this level.
                bank = self.banks[level]
                C = generate_extended_codes(
                    bank.L, level_keys.max().item() + 1, device=self.device
                )

                # Write the superposition delta into the level's membranes.
                alphas = compute_adaptive_alphas(level_patterns, C, level_keys)
                M = store_pairs(bank.read(), C, level_keys, level_patterns, alphas)
                bank.write(M - bank.read())

    def _auto_assign_levels(self, patterns: torch.Tensor) -> torch.Tensor:
        """Assign patterns to hierarchy levels by spectral complexity."""
        from .metrics import spectral_entropy_2d

        entropies = torch.tensor([spectral_entropy_2d(p) for p in patterns])

        # Sort by entropy (most complex first) and split evenly across levels.
        sorted_indices = torch.argsort(entropies, descending=True)
        level_assignment = torch.zeros(len(patterns), dtype=torch.long)

        patterns_per_level = len(patterns) // self.levels
        for level in range(self.levels):
            start_idx = level * patterns_per_level
            end_idx = start_idx + patterns_per_level if level < self.levels - 1 else len(patterns)
            for idx in sorted_indices[start_idx:end_idx]:
                level_assignment[idx] = level

        return level_assignment
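
# Illustrative usage sketch for the hierarchical bank. The dimensions, explicit
# level_ratios, and batch size are demo assumptions, and the sketch relies on
# .membrane_bank.MembraneBank and .metrics.spectral_entropy_2d behaving as they
# are used above.
def _demo_hierarchical_store() -> HierarchicalMembraneBank:
    torch.manual_seed(0)
    hb = HierarchicalMembraneBank(L=16, H=16, W=16, levels=3,
                                  level_ratios=[0.5, 0.25, 0.25])
    hb.allocate(B=1)                       # allocate every level for batch 1
    patterns = torch.rand(6, 16, 16)       # six patterns, [T, H, W]
    keys = torch.arange(6)
    hb.store_hierarchical(patterns, keys)  # levels auto-assigned by entropy
    return hb
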
""" if level_assignment is None: level_assignment = self._auto_assign_levels(patterns) # Store patterns in assigned levels for level in range(self.levels): level_mask = level_assignment == level if level_mask.any(): level_patterns = patterns[level_mask] level_keys = keys[level_mask] # Generate codes for this level bank = self.banks[level] C = generate_extended_codes(bank.L, level_keys.max().item() + 1, device=self.device) # Store patterns alphas = compute_adaptive_alphas(level_patterns, C, level_keys) M = store_pairs(bank.read(), C, level_keys, level_patterns, alphas) bank.write(M - bank.read()) def _auto_assign_levels(self, patterns: torch.Tensor) -> torch.Tensor: """Automatically assign patterns to hierarchy levels based on complexity.""" from .metrics import spectral_entropy_2d entropies = torch.tensor([spectral_entropy_2d(p) for p in patterns]) # Sort by entropy and assign to levels sorted_indices = torch.argsort(entropies, descending=True) level_assignment = torch.zeros(len(patterns), dtype=torch.long) patterns_per_level = len(patterns) // self.levels for level in range(self.levels): start_idx = level * patterns_per_level end_idx = start_idx + patterns_per_level if level < self.levels - 1 else len(patterns) for idx in sorted_indices[start_idx:end_idx]: level_assignment[idx] = level return level_assignment def optimized_store_pairs( M: torch.Tensor, C: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, alphas: Optional[torch.Tensor] = None, adaptive_alphas: bool = True, sparsity_threshold: float = 0.01 ) -> torch.Tensor: """Enhanced store_pairs with optimizations. Parameters ---------- M : torch.Tensor Current membranes with shape [B, L, H, W] C : torch.Tensor Codebook tensor with shape [L, K] keys : torch.Tensor Key indices with shape [T] values : torch.Tensor Values to store with shape [T, H, W] alphas : torch.Tensor, optional Manual alpha values. If None and adaptive_alphas=True, computes automatically. adaptive_alphas : bool Whether to use adaptive alpha scaling sparsity_threshold : float Threshold for sparse pattern detection Returns ------- torch.Tensor Updated membrane tensor """ if alphas is None and adaptive_alphas: alphas = compute_adaptive_alphas(values, C, keys) elif alphas is None: alphas = torch.ones(len(keys), device=values.device, dtype=values.dtype) # Check for sparse patterns pattern_norms = torch.norm(values.view(len(values), -1), dim=1) sparse_mask = pattern_norms < sparsity_threshold if sparse_mask.any() and not sparse_mask.all(): # Mixed sparse/dense - handle separately dense_mask = ~sparse_mask result = M.clone() if dense_mask.any(): dense_result = store_pairs( result, C, keys[dense_mask], values[dense_mask], alphas[dense_mask] ) result = dense_result if sparse_mask.any(): sparse_result = _sparse_store_pairs( result, C, keys[sparse_mask], values[sparse_mask], alphas[sparse_mask] ) result = sparse_result return result else: # All dense or all sparse - use single method if sparse_mask.all(): return _sparse_store_pairs(M, C, keys, values, alphas) else: return store_pairs(M, C, keys, values, alphas) def _sparse_store_pairs( M: torch.Tensor, C: torch.Tensor, keys: torch.Tensor, values: torch.Tensor, alphas: torch.Tensor ) -> torch.Tensor: """Sparse implementation of store_pairs for sparse patterns.""" # For now, use regular implementation - could be optimized with sparse tensors return store_pairs(M, C, keys, values, alphas)