WrinkleBrane / Wrinkle /WrinkleBrane-Research /06_learnable_codes /src /wrinklebrane /optimizations.py
| """ | |
| WrinkleBrane Optimizations Module | |
| Advanced optimizations for improved performance and fidelity. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from typing import Optional, Tuple | |
| import torch | |
| import numpy as np | |
| from scipy.linalg import qr | |
| from .codes import normalize_columns, hadamard_codes, dct_codes | |
| from .write_ops import store_pairs | |
def compute_adaptive_alphas(
    patterns: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    energy_weight: float = 1.0,
    orthogonality_weight: float = 0.5,
    min_alpha: float = 0.1,
    max_alpha: float = 3.0
) -> torch.Tensor:
    """Compute a per-pattern write gain from pattern energy and code overlap.

    Each alpha is the product of two factors, clamped to
    ``[min_alpha, max_alpha]``:

    * an energy factor ``0.5 / rms(pattern)`` that rescales every pattern
      towards a common RMS of 0.5, and
    * an orthogonality factor ``1 / (1 + orthogonality_weight * c)`` where
      ``c`` is the largest absolute inner product between the pattern's code
      column and any *other* column of ``C``.

    Parameters
    ----------
    patterns : torch.Tensor
        Input patterns with shape [T, ...] (documented use is [T, H, W]).
    C : torch.Tensor
        Codebook tensor with shape [L, K].
    keys : torch.Tensor
        Key indices with shape [T]. Keys outside ``[0, K)`` receive no
        orthogonality compensation (factor 1.0).
    energy_weight : float
        Switch for energy normalization: any value > 0 enables it, otherwise
        the energy factor is 1.0. The magnitude is not used.
    orthogonality_weight : float
        Strength of the orthogonality penalty; values <= 0 disable it.
    min_alpha : float
        Lower clamp for the returned alphas (default: 0.1).
    max_alpha : float
        Upper clamp for the returned alphas (default: 3.0).

    Returns
    -------
    torch.Tensor
        Alpha values with shape [T], on ``patterns.device`` with
        ``patterns.dtype``.

    Raises
    ------
    ValueError
        If ``keys`` does not contain exactly one entry per pattern.
    """
    reference_rms = 0.5  # Target RMS level

    T = len(patterns)
    device = patterns.device
    dtype = patterns.dtype

    key_idx = torch.as_tensor(keys, device=device).reshape(-1).long()
    if key_idx.numel() != T:
        raise ValueError(
            f"keys must have one entry per pattern: got {key_idx.numel()} keys "
            f"for {T} patterns"
        )

    # Alphas stay a tensor throughout, so the final clamp never receives a
    # plain Python float (torch.clamp rejects that with a TypeError).
    alphas = torch.ones(T, device=device, dtype=dtype)
    if T == 0:
        return alphas

    # 1. Energy normalization - scale each pattern towards the reference RMS.
    if energy_weight > 0:
        pattern_rms = patterns.reshape(T, -1).pow(2).mean(dim=1).sqrt().clamp_min(1e-6)
        alphas = alphas * (reference_rms / pattern_rms)

    # 2. Orthogonality compensation - codes that overlap with other codes get
    #    smaller alphas. With a single code there is nothing to compare to.
    num_codes = C.shape[1]
    if orthogonality_weight > 0 and num_codes > 1:
        code_keys = key_idx.to(C.device)
        in_range = (code_keys >= 0) & (code_keys < num_codes)
        # Clamp only so the gather below is valid; out-of-range rows are
        # replaced by a neutral factor afterwards.
        safe_keys = code_keys.clamp(0, num_codes - 1)

        overlap = (C[:, safe_keys].transpose(0, 1) @ C).abs()  # [T, K]
        # Exclude each code's inner product with itself.
        column_ids = torch.arange(num_codes, device=C.device).unsqueeze(0)
        is_self = safe_keys.unsqueeze(1) == column_ids
        max_overlap = overlap.masked_fill(is_self, 0.0).max(dim=1).values

        factor = 1.0 / (1.0 + orthogonality_weight * max_overlap)
        factor = torch.where(in_range, factor, torch.ones_like(factor))
        alphas = alphas * factor.to(device=device, dtype=dtype)

    return torch.clamp(alphas, min_alpha, max_alpha)
def generate_extended_codes(
    L: int,
    K: int,
    method: str = "auto",
    existing_patterns: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None
) -> torch.Tensor:
    """Generate a code matrix, including the overcomplete case K > L.

    Dispatch rules:

    * ``K <= L`` and ``method`` is ``"auto"``: Hadamard codes when ``K`` does
      not exceed the largest power of two <= ``L``, otherwise DCT codes.
    * ``K <= L`` and ``method`` is ``"hadamard"`` / ``"dct"``: that family.
    * ``method == "random_ortho"``: ``generate_random_orthogonal_codes``.
    * ``method == "gram_schmidt"``, or any method when ``K > L``:
      ``generate_gram_schmidt_codes``.
    * Anything else (unrecognised ``method`` with ``K <= L``): seeded random
      columns passed through ``normalize_columns``.

    Parameters
    ----------
    L : int
        Code length (number of layers)
    K : int
        Number of codes needed
    method : str
        Code generation method: "auto", "hadamard", "dct", "gram_schmidt",
        "random_ortho"
    existing_patterns : torch.Tensor, optional
        Existing patterns forwarded to the Gram-Schmidt generator,
        shape [K, H, W]
    device : torch.device, optional
        Device to place tensors on (defaults to CPU)

    Returns
    -------
    torch.Tensor
        Code matrix with shape [L, K]

    Raises
    ------
    ValueError
        If ``method == "auto"``, ``K <= L`` and ``L`` is not positive.
    """
    if device is None:
        device = torch.device("cpu")

    # For K <= L, use the structured orthogonal families directly.
    if K <= L:
        if method == "auto":
            if L < 1:
                raise ValueError(f"L must be a positive integer, got {L}")
            # Largest power of two <= L, computed with exact integer math
            # rather than a float log2.
            largest_pow2 = 1 << (int(L).bit_length() - 1)
            if K <= largest_pow2:
                return hadamard_codes(L, K).to(device)
            return dct_codes(L, K).to(device)
        if method == "hadamard":
            return hadamard_codes(L, K).to(device)
        if method == "dct":
            return dct_codes(L, K).to(device)

    if method == "random_ortho":
        return generate_random_orthogonal_codes(L, K, device)

    # Every remaining K > L request, and an explicit "gram_schmidt" request,
    # goes through the Gram-Schmidt generator.
    if method == "gram_schmidt" or K > L:
        return generate_gram_schmidt_codes(L, K, existing_patterns, device)

    # Fallback for an unrecognised method with K <= L: reproducible random
    # codes. A private generator keeps the draw deterministic without
    # reseeding the process-wide RNG.
    generator = torch.Generator(device=device).manual_seed(42)
    C = torch.randn(L, K, generator=generator, device=device)
    return normalize_columns(C)
def generate_gram_schmidt_codes(
    L: int,
    K: int,
    existing_patterns: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None
) -> torch.Tensor:
    """Generate K codes of length L, orthogonalised as far as possible.

    The construction starts from ``hadamard_codes(L, min(K, L))`` and, when
    more columns are required, appends extra directions taken either from an
    SVD of ``existing_patterns`` or from a seeded random draw. The stack is
    orthogonalised with a QR decomposition.

    At most ``L`` vectors of length ``L`` can be mutually orthogonal, so for
    ``K > L`` only the first ``L`` columns are orthonormal; the remaining
    ``K - L`` columns are appended as-is and only normalised. The result
    therefore always has ``K`` columns.

    Parameters
    ----------
    L : int
        Code length (number of layers)
    K : int
        Number of codes needed
    existing_patterns : torch.Tensor, optional
        Patterns with shape [N, ...] used to derive the extra directions.
    device : torch.device, optional
        Device to place tensors on (defaults to CPU)

    Returns
    -------
    torch.Tensor
        Code matrix with shape [L, K], passed through ``normalize_columns``
        (or the raw Hadamard codes when they already provide K columns).
    """
    if device is None:
        device = torch.device("cpu")

    # Start with the structured orthogonal codes, at most L of them.
    base_codes = hadamard_codes(L, min(K, L)).to(device)
    if K == base_codes.shape[1]:
        return base_codes

    additional_needed = K - base_codes.shape[1]
    if additional_needed > 0:
        # Private generator: reproducible without reseeding the global RNG.
        generator = torch.Generator(device=device).manual_seed(42)
        extra_parts = []

        if existing_patterns is not None and existing_patterns.numel() > 0:
            # Left singular vectors of the flattened patterns, truncated to
            # the first L rows to match the code length.
            # NOTE(review): truncated rows are no longer unit-norm or
            # orthogonal; the QR / normalisation below compensates.
            patterns_flat = existing_patterns.reshape(existing_patterns.shape[0], -1)
            U, _, _ = torch.svd(patterns_flat.T)
            if U.shape[0] >= L:
                extra_parts.append(
                    U[:L, :additional_needed].to(device=device, dtype=base_codes.dtype)
                )

        # Fill whatever the patterns could not provide with random vectors so
        # the stacked matrix always has exactly K columns.
        found = sum(part.shape[1] for part in extra_parts)
        if found < additional_needed:
            filler = torch.randn(
                L, additional_needed - found, generator=generator, device=device
            )
            extra_parts.append(filler.to(base_codes.dtype))

        all_vectors = torch.cat([base_codes, *extra_parts], dim=1)
    else:
        all_vectors = base_codes

    # Gram-Schmidt orthogonalization using QR decomposition for stability.
    # Reduced QR returns min(L, K) orthonormal columns.
    Q, R = torch.linalg.qr(all_vectors)

    # Ensure positive diagonal for consistency
    signs = torch.sign(torch.diag(R))
    signs[signs == 0] = 1
    Q = Q * signs.unsqueeze(0)

    # Columns beyond the rank limit cannot be orthogonalised; keep them so the
    # caller still gets one code per key.
    surplus = all_vectors[:, Q.shape[1]:]
    if surplus.shape[1] > 0:
        Q = torch.cat([Q, surplus], dim=1)

    return normalize_columns(Q[:, :K])
def generate_random_orthogonal_codes(L: int, K: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """Generate reproducible random codes, orthogonalised via QR where possible.

    The first ``min(K, L)`` columns come from the QR decomposition of a seeded
    Gaussian matrix. When ``K > L`` the remaining ``K - L`` columns are plain
    seeded Gaussian vectors; they are NOT orthogonal to the first ``L``.
    All columns are passed through ``normalize_columns``.

    Parameters
    ----------
    L : int
        Code length (number of layers)
    K : int
        Number of codes needed
    device : torch.device, optional
        Device to place tensors on (defaults to CPU)

    Returns
    -------
    torch.Tensor
        Code matrix with shape [L, K]
    """
    if device is None:
        device = torch.device("cpu")

    # A private generator seeded with 42 keeps the output reproducible
    # without resetting the process-wide RNG state.
    generator = torch.Generator(device=device).manual_seed(42)

    # Orthogonalise as many columns as the code length allows.
    A = torch.randn(L, min(K, L), generator=generator, device=device)
    codes, _ = torch.linalg.qr(A)

    if K > L:
        # Add random additional vectors (not orthogonal to first L)
        additional = torch.randn(L, K - L, generator=generator, device=device)
        codes = torch.cat([codes, additional], dim=1)

    return normalize_columns(codes)
class HierarchicalMembraneBank:
    """Multi-level membrane bank for better capacity scaling.

    The ``L`` available layers are split across several ``MembraneBank``
    instances ("levels"). Patterns are routed to a level either explicitly or
    by spectral entropy, with the highest-entropy patterns going to level 0.
    """

    def __init__(
        self,
        L: int,
        H: int,
        W: int,
        levels: int = 3,
        level_ratios: Optional[list[float]] = None,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32
    ):
        """Initialize hierarchical membrane bank.

        Parameters
        ----------
        L, H, W : int
            Base dimensions
        levels : int
            Number of hierarchy levels. Ignored when ``level_ratios`` is
            given, in which case the number of ratios is used.
        level_ratios : list[float], optional
            Fraction of L allocated to each level. If None, uses a normalised
            geometric series (1/2, 1/4, 1/8, ...). Custom ratios are used as
            given; the last level always receives the remaining layers.
        device : torch.device, optional
            Device for tensors
        dtype : torch.dtype
            Data type for tensors

        Raises
        ------
        ValueError
            If there is no level to create, or if ``L`` is too small to give
            every level at least one layer.
        """
        if level_ratios is None:
            if levels < 1:
                raise ValueError(f"levels must be at least 1, got {levels}")
            # Geometric series: 1/2, 1/4, 1/8, ... normalised to sum to 1.
            raw_ratios = [0.5 ** (i + 1) for i in range(levels)]
            ratio_sum = sum(raw_ratios)
            level_ratios = [r / ratio_sum for r in raw_ratios]
        else:
            level_ratios = list(level_ratios)
            if not level_ratios:
                raise ValueError("level_ratios must contain at least one ratio")
            # One bank is built per ratio, so the ratio count defines the
            # number of levels and keeps self.levels == len(self.banks).
            levels = len(level_ratios)

        if L < levels:
            raise ValueError(
                f"L={L} is too small to give each of the {levels} levels a layer"
            )

        self.levels = levels
        self.H = H
        self.W = W
        self.device = device
        self.dtype = dtype
        self.level_ratios = level_ratios

        # Create membrane banks for each level
        from .membrane_bank import MembraneBank
        self.banks = [
            MembraneBank(level_L, H, W, device=device, dtype=dtype)
            for level_L in self._split_layers(L, level_ratios)
        ]
        self.total_L = sum(bank.L for bank in self.banks)

    @staticmethod
    def _split_layers(L: int, ratios: list[float]) -> list[int]:
        """Split L layers across levels: >= 1 layer each, sizes summing to L."""
        sizes = []
        remaining = L
        for i, ratio in enumerate(ratios):
            levels_after = len(ratios) - 1 - i
            if levels_after == 0:
                # Last level gets remaining
                size = remaining
            else:
                size = max(1, int(L * ratio))
                # Reserve one layer for every level still to come.
                size = min(size, remaining - levels_after)
            sizes.append(size)
            remaining -= size
        return sizes

    @staticmethod
    def _levels_by_rank(entropies: torch.Tensor, levels: int) -> torch.Tensor:
        """Map each entropy to a level index; highest entropies get level 0."""
        count = entropies.numel()
        order = torch.argsort(entropies, descending=True)
        # Equal-sized groups by rank; the last level absorbs the remainder.
        # With fewer patterns than levels, each pattern gets its own level.
        per_level = max(1, count // levels)
        level_of_rank = torch.clamp(
            torch.arange(count, device=entropies.device) // per_level,
            max=levels - 1,
        )
        assignment = torch.zeros(count, dtype=torch.long, device=entropies.device)
        assignment[order] = level_of_rank
        return assignment

    def allocate(self, B: int) -> None:
        """Allocate all level banks for batch size B."""
        for bank in self.banks:
            bank.allocate(B)

    def store_hierarchical(
        self,
        patterns: torch.Tensor,
        keys: torch.Tensor,
        level_assignment: Optional[torch.Tensor] = None
    ) -> None:
        """Store patterns in appropriate hierarchy levels.

        For every level with at least one assigned pattern, a codebook with
        ``max(key) + 1`` columns is generated for that level's bank, adaptive
        alphas are computed, and the difference between the stored result and
        the bank's current contents is passed to ``bank.write``.

        Parameters
        ----------
        patterns : torch.Tensor
            Patterns to store, shape [T, H, W]
        keys : torch.Tensor
            Key indices, shape [T]
        level_assignment : torch.Tensor, optional
            Level index for each pattern, shape [T]. If None, auto-assign
            based on pattern complexity. Entries that match no level index
            are not stored.
        """
        if level_assignment is None:
            level_assignment = self._auto_assign_levels(patterns)

        for level, bank in enumerate(self.banks):
            level_mask = level_assignment == level
            if not level_mask.any():
                continue

            # The mask may live on another device than the data it indexes.
            level_patterns = patterns[level_mask.to(patterns.device)]
            level_keys = keys[level_mask.to(keys.device)]

            # Generate codes for this level; keys index columns directly, so
            # the codebook needs max(key) + 1 columns.
            num_codes = int(level_keys.max().item()) + 1
            C = generate_extended_codes(bank.L, num_codes, device=self.device)

            # Store patterns
            alphas = compute_adaptive_alphas(level_patterns, C, level_keys)
            M = store_pairs(bank.read(), C, level_keys, level_patterns, alphas)
            # NOTE(review): bank.write presumably applies a delta - confirm
            # against MembraneBank.
            bank.write(M - bank.read())

    def _auto_assign_levels(self, patterns: torch.Tensor) -> torch.Tensor:
        """Automatically assign patterns to hierarchy levels based on complexity.

        Patterns are ranked by ``spectral_entropy_2d`` in descending order and
        split into ``self.levels`` groups by rank.

        Returns
        -------
        torch.Tensor
            Long tensor of level indices with shape [T].
        """
        from .metrics import spectral_entropy_2d
        entropies = torch.tensor([spectral_entropy_2d(p) for p in patterns])
        return self._levels_by_rank(entropies, self.levels)
def optimized_store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: Optional[torch.Tensor] = None,
    adaptive_alphas: bool = True,
    sparsity_threshold: float = 0.01
) -> torch.Tensor:
    """Enhanced store_pairs with optimizations.

    Chooses alphas (given, adaptive, or all ones), then routes each value to
    ``store_pairs`` or ``_sparse_store_pairs`` depending on whether its L2
    norm is below ``sparsity_threshold``.

    Parameters
    ----------
    M : torch.Tensor
        Current membranes with shape [B, L, H, W]
    C : torch.Tensor
        Codebook tensor with shape [L, K]
    keys : torch.Tensor
        Key indices with shape [T]
    values : torch.Tensor
        Values to store with shape [T, H, W]
    alphas : torch.Tensor, optional
        Manual alpha values. If None and adaptive_alphas=True, computes
        automatically; if None and adaptive_alphas=False, uses ones.
    adaptive_alphas : bool
        Whether to use adaptive alpha scaling
    sparsity_threshold : float
        Values whose flattened L2 norm is below this are treated as sparse.

    Returns
    -------
    torch.Tensor
        Updated membrane tensor
    """
    if alphas is None:
        if adaptive_alphas:
            alphas = compute_adaptive_alphas(values, C, keys)
        else:
            alphas = torch.ones(len(keys), device=values.device, dtype=values.dtype)

    num_values = len(values)
    if num_values == 0:
        # Nothing to classify; flattening an empty batch with -1 is ambiguous.
        return store_pairs(M, C, keys, values, alphas)

    # Check for sparse patterns. reshape (unlike view) also accepts
    # non-contiguous inputs such as transposed or strided slices.
    pattern_norms = torch.norm(values.reshape(num_values, -1), dim=1)
    sparse_mask = pattern_norms < sparsity_threshold

    # All dense or all sparse - use single method
    if sparse_mask.all():
        return _sparse_store_pairs(M, C, keys, values, alphas)
    if not sparse_mask.any():
        return store_pairs(M, C, keys, values, alphas)

    # Mixed sparse/dense - handle separately, dense values first. Both groups
    # are guaranteed non-empty here.
    dense_mask = ~sparse_mask
    result = M.clone()
    result = store_pairs(
        result, C, keys[dense_mask], values[dense_mask], alphas[dense_mask]
    )
    result = _sparse_store_pairs(
        result, C, keys[sparse_mask], values[sparse_mask], alphas[sparse_mask]
    )
    return result
def _sparse_store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: torch.Tensor
) -> torch.Tensor:
    """Store low-norm ("sparse") patterns into the membranes.

    Placeholder for a sparse-tensor implementation: it currently forwards
    every argument unchanged to the regular ``store_pairs`` and returns its
    result.
    """
    # No sparse-specific path exists yet, so delegate to the dense routine.
    updated = store_pairs(M, C, keys, values, alphas)
    return updated