# wrinklebrane/optimizations.py
# NOTE: the following lines are Hugging Face Hub page residue captured with the
# source (uploader "WCNegentropy", commit dc2b9f3 verified,
# "Updated with scientifically rigorous documentation"); kept as comments so
# the module remains valid Python.
"""
WrinkleBrane Optimizations Module
Advanced optimizations for improved performance and fidelity.
"""
from __future__ import annotations
import math
from typing import Optional, Tuple
import torch
import numpy as np
from scipy.linalg import qr
from .codes import normalize_columns, hadamard_codes, dct_codes
from .write_ops import store_pairs
def compute_adaptive_alphas(
    patterns: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    energy_weight: float = 1.0,
    orthogonality_weight: float = 0.5,
    min_alpha: float = 0.1,
    max_alpha: float = 3.0
) -> torch.Tensor:
    """Compute optimal alpha values for each pattern based on energy and orthogonality.

    Parameters
    ----------
    patterns : torch.Tensor
        Input patterns with shape [T, H, W]
    C : torch.Tensor
        Codebook tensor with shape [L, K]
    keys : torch.Tensor
        Key indices with shape [T]
    energy_weight : float
        Weight for energy normalization; <= 0 disables the energy factor
        (default: 1.0)
    orthogonality_weight : float
        Weight for orthogonality compensation; <= 0 disables it (default: 0.5)
    min_alpha : float
        Minimum alpha value (default: 0.1)
    max_alpha : float
        Maximum alpha value (default: 3.0)

    Returns
    -------
    torch.Tensor
        Optimized alpha values with shape [T], clamped to [min_alpha, max_alpha]
    """
    T = len(patterns)
    device = patterns.device
    dtype = patterns.dtype
    alphas = torch.ones(T, device=device, dtype=dtype)
    for i, key in enumerate(keys):
        # 1. Energy normalization - scale so each pattern lands near a common RMS.
        pattern_rms = torch.sqrt(torch.mean(patterns[i] ** 2)).clamp_min(1e-6)
        reference_rms = 0.5  # Target RMS level
        energy_factor = reference_rms / pattern_rms if energy_weight > 0 else 1.0
        # 2. Orthogonality compensation - codes that correlate with others get
        #    a smaller alpha to limit crosstalk.
        orthogonality_factor = 1.0
        if orthogonality_weight > 0 and key < C.shape[1]:
            code_vec = C[:, key]
            # All codes except this one; empty [L, 0] slice when K == 1.
            other_codes = torch.cat([C[:, :key], C[:, key + 1:]], dim=1) if C.shape[1] > 1 else C[:, :0]
            if other_codes.numel() > 0:
                # Worst-case absolute correlation with any other code.
                max_correlation = torch.abs(code_vec @ other_codes).max()
                orthogonality_factor = 1.0 / (1.0 + orthogonality_weight * max_correlation)
        # Combine factors. as_tensor guards the case where both factors are
        # plain Python floats (both weights disabled) - torch.clamp requires a
        # Tensor and previously raised a TypeError there.
        alpha = torch.as_tensor(energy_factor * orthogonality_factor, device=device, dtype=dtype)
        alphas[i] = torch.clamp(alpha, min_alpha, max_alpha)
    return alphas
def generate_extended_codes(
    L: int,
    K: int,
    method: str = "auto",
    existing_patterns: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None
) -> torch.Tensor:
    """Generate orthogonal codes with support for K > L.

    Parameters
    ----------
    L : int
        Code length (number of layers)
    K : int
        Number of codes needed
    method : str
        Code generation method: "auto", "hadamard", "dct", "gram_schmidt",
        "random_ortho"
    existing_patterns : torch.Tensor, optional
        Existing patterns to optimize codes for, shape [K, H, W]
    device : torch.device, optional
        Device to place tensors on

    Returns
    -------
    torch.Tensor
        Code matrix with shape [L, K]
    """
    target = torch.device("cpu") if device is None else device
    fits = K <= L
    # K <= L admits a fully orthogonal codebook.
    if fits and method == "auto":
        # Hadamard constructions want a power-of-two budget; otherwise use DCT.
        hadamard_budget = 2 ** int(math.log2(L))
        builder = hadamard_codes if K <= hadamard_budget else dct_codes
        return builder(L, K).to(target)
    if fits and method == "hadamard":
        return hadamard_codes(L, K).to(target)
    if fits and method == "dct":
        return dct_codes(L, K).to(target)
    # Overcomplete (K > L) or explicitly requested constructions.
    if method == "gram_schmidt" or (method == "auto" and not fits):
        return generate_gram_schmidt_codes(L, K, existing_patterns, target)
    if method == "random_ortho":
        return generate_random_orthogonal_codes(L, K, target)
    # Unknown method names fall back to reproducible random columns when they
    # fit, and to the hybrid Gram-Schmidt construction otherwise.
    if fits:
        torch.manual_seed(42)  # Reproducible
        return normalize_columns(torch.randn(L, K, device=target))
    return generate_gram_schmidt_codes(L, K, existing_patterns, target)
def generate_gram_schmidt_codes(
    L: int,
    K: int,
    existing_patterns: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None
) -> torch.Tensor:
    """Generate codes using QR-based Gram-Schmidt orthogonalization.

    Produces an [L, K] code matrix. For K <= L the columns form an orthonormal
    set. For K > L only the first L columns can be mutually orthogonal; the
    remaining K - L columns are unit-normalized extra directions, seeded from
    the patterns' principal directions when `existing_patterns` is given and
    from reproducible random noise otherwise.

    Parameters
    ----------
    L : int
        Code length (number of layers).
    K : int
        Number of codes needed (may exceed L).
    existing_patterns : torch.Tensor, optional
        Patterns with shape [T, H, W] used to derive extra code directions.
    device : torch.device, optional
        Device to place tensors on.

    Returns
    -------
    torch.Tensor
        Code matrix with shape [L, K].
    """
    if device is None:
        device = torch.device("cpu")
    # Start with the best orthogonal codes up to L layers.
    if K <= L:
        base_codes = hadamard_codes(L, min(K, L)).to(device)
        if K == base_codes.shape[1]:
            return base_codes
    else:
        base_codes = hadamard_codes(L, L).to(device)
    if K > base_codes.shape[1]:
        additional_needed = K - base_codes.shape[1]
        if existing_patterns is not None:
            # Use the pattern set's principal directions for extra columns.
            patterns_flat = existing_patterns.reshape(existing_patterns.shape[0], -1)
            # torch.svd is deprecated; torch.linalg.svd is the supported API.
            U, _, _ = torch.linalg.svd(patterns_flat.T, full_matrices=False)
            U = U.to(device)
            rows = min(L, U.shape[0])
            cols = min(additional_needed, U.shape[1])
            additional_vectors = torch.zeros(L, additional_needed, device=device)
            additional_vectors[:rows, :cols] = U[:rows, :cols]
            if cols < additional_needed:
                # SVD cannot supply enough directions (fewer patterns than
                # needed columns); fill the remainder reproducibly.
                torch.manual_seed(42)
                additional_vectors[:, cols:] = torch.randn(
                    L, additional_needed - cols, device=device
                )
        else:
            # Random additional vectors, reproducible.
            torch.manual_seed(42)
            additional_vectors = torch.randn(L, additional_needed, device=device)
        all_vectors = torch.cat([base_codes, additional_vectors], dim=1)
    else:
        all_vectors = base_codes
    # Orthogonalize only the first min(L, K) columns. Reduced QR returns Q with
    # min(rows, cols) columns, so running it on an [L, K > L] matrix previously
    # dropped the extra columns and silently returned fewer than K codes.
    n_ortho = min(L, all_vectors.shape[1])
    Q, R = torch.linalg.qr(all_vectors[:, :n_ortho])
    # Fix QR's sign ambiguity so the result is deterministic.
    signs = torch.sign(torch.diagonal(R))
    signs[signs == 0] = 1
    Q = Q * signs.unsqueeze(0)
    if all_vectors.shape[1] > n_ortho:
        # Append the non-orthogonal surplus columns; normalize_columns below
        # gives every column unit norm.
        Q = torch.cat([Q, all_vectors[:, n_ortho:]], dim=1)
    return normalize_columns(Q[:, :K])
def generate_random_orthogonal_codes(L: int, K: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """Generate random orthogonal codes using QR decomposition.

    Parameters
    ----------
    L : int
        Code length (number of layers).
    K : int
        Number of codes. For K > L only the first L columns are mutually
        orthogonal; the rest are random unit-normalized directions.
    device : torch.device, optional
        Device to place tensors on.

    Returns
    -------
    torch.Tensor
        Code matrix with shape [L, K] with unit-norm columns.
    """
    if device is None:
        device = torch.device("cpu")
    # Reproducible via a local generator. The previous torch.manual_seed(42)
    # silently reset the caller's *global* RNG state as a side effect.
    gen = torch.Generator(device=device)
    gen.manual_seed(42)
    if K <= L:
        # Random matrix -> orthonormal columns via reduced QR.
        A = torch.randn(L, K, generator=gen, device=device)
        Q, _ = torch.linalg.qr(A)
        return normalize_columns(Q)
    # For K > L an L-dimensional space admits at most L orthogonal vectors, so
    # the remaining K - L columns are random and only unit-normalized.
    A = torch.randn(L, L, generator=gen, device=device)
    Q, _ = torch.linalg.qr(A)
    additional = torch.randn(L, K - L, generator=gen, device=device)
    all_codes = torch.cat([Q, additional], dim=1)
    return normalize_columns(all_codes)
class HierarchicalMembraneBank:
    """Multi-level membrane bank for better capacity scaling.

    Wraps one MembraneBank per hierarchy level. The per-level layer counts are
    fractions of the base ``L`` (``level_ratios``); the last level absorbs
    whatever layers remain after integer truncation, so the totals sum to L.
    """
    def __init__(
        self,
        L: int,
        H: int,
        W: int,
        levels: int = 3,
        level_ratios: Optional[list[float]] = None,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32
    ):
        """Initialize hierarchical membrane bank.

        Parameters
        ----------
        L, H, W : int
            Base dimensions
        levels : int
            Number of hierarchy levels
        level_ratios : list[float], optional
            Fraction of L allocated to each level. If None, uses geometric series.
        device : torch.device, optional
            Device for tensors
        dtype : torch.dtype
            Data type for tensors
        """
        self.levels = levels
        self.H = H
        self.W = W
        self.device = device
        self.dtype = dtype
        if level_ratios is None:
            # Geometric series: 1/2, 1/4, 1/8, ... rescaled to sum to 1.
            level_ratios = [0.5 ** (i + 1) for i in range(levels)]
            level_ratios = [r / sum(level_ratios) for r in level_ratios]  # Normalize
        self.level_ratios = level_ratios
        # Create membrane banks for each level
        from .membrane_bank import MembraneBank
        self.banks = []
        remaining_L = L
        for i, ratio in enumerate(level_ratios):
            # Truncation can over- or under-shoot; every level gets >= 1 layer.
            level_L = max(1, int(L * ratio))
            if i == levels - 1:  # Last level gets remaining
                # NOTE(review): if earlier levels over-allocate, remaining_L can
                # be <= 0 here - verify MembraneBank rejects non-positive L.
                level_L = remaining_L
            bank = MembraneBank(level_L, H, W, device=device, dtype=dtype)
            self.banks.append(bank)
            remaining_L -= level_L
        # Actual total layers across levels (assumes MembraneBank exposes .L).
        self.total_L = sum(bank.L for bank in self.banks)
    def allocate(self, B: int) -> None:
        """Allocate all level banks for batch size B."""
        for bank in self.banks:
            bank.allocate(B)
    def store_hierarchical(
        self,
        patterns: torch.Tensor,
        keys: torch.Tensor,
        level_assignment: Optional[torch.Tensor] = None
    ) -> None:
        """Store patterns in appropriate hierarchy levels.

        Parameters
        ----------
        patterns : torch.Tensor
            Patterns to store, shape [T, H, W]
        keys : torch.Tensor
            Key indices, shape [T]
        level_assignment : torch.Tensor, optional
            Level assignment for each pattern. If None, auto-assign based on pattern complexity.
        """
        if level_assignment is None:
            level_assignment = self._auto_assign_levels(patterns)
        # Store each level's subset of patterns into that level's bank.
        for level in range(self.levels):
            level_mask = level_assignment == level
            if level_mask.any():
                level_patterns = patterns[level_mask]
                level_keys = keys[level_mask]
                # Generate codes for this level; the codebook is sized by the
                # largest key routed here (max key + 1 columns).
                bank = self.banks[level]
                C = generate_extended_codes(bank.L, level_keys.max().item() + 1, device=self.device)
                # Store patterns with per-pattern adaptive write gains.
                alphas = compute_adaptive_alphas(level_patterns, C, level_keys)
                M = store_pairs(bank.read(), C, level_keys, level_patterns, alphas)
                # NOTE(review): assumes bank.write() applies an additive delta
                # on top of the current state - confirm against MembraneBank.
                bank.write(M - bank.read())
    def _auto_assign_levels(self, patterns: torch.Tensor) -> torch.Tensor:
        """Automatically assign patterns to hierarchy levels based on complexity."""
        from .metrics import spectral_entropy_2d
        entropies = torch.tensor([spectral_entropy_2d(p) for p in patterns])
        # Rank patterns by spectral entropy: the most complex (highest entropy)
        # go to level 0, which has the largest layer budget.
        sorted_indices = torch.argsort(entropies, descending=True)
        level_assignment = torch.zeros(len(patterns), dtype=torch.long)
        patterns_per_level = len(patterns) // self.levels
        for level in range(self.levels):
            start_idx = level * patterns_per_level
            # The last level takes the remainder when T is not divisible.
            end_idx = start_idx + patterns_per_level if level < self.levels - 1 else len(patterns)
            for idx in sorted_indices[start_idx:end_idx]:
                level_assignment[idx] = level
        return level_assignment
def optimized_store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: Optional[torch.Tensor] = None,
    adaptive_alphas: bool = True,
    sparsity_threshold: float = 0.01
) -> torch.Tensor:
    """Enhanced store_pairs with adaptive alphas and sparse-pattern routing.

    Parameters
    ----------
    M : torch.Tensor
        Current membranes with shape [B, L, H, W]
    C : torch.Tensor
        Codebook tensor with shape [L, K]
    keys : torch.Tensor
        Key indices with shape [T]
    values : torch.Tensor
        Values to store with shape [T, H, W]
    alphas : torch.Tensor, optional
        Manual alpha values. If None and adaptive_alphas=True, computes automatically.
    adaptive_alphas : bool
        Whether to use adaptive alpha scaling
    sparsity_threshold : float
        Threshold for sparse pattern detection

    Returns
    -------
    torch.Tensor
        Updated membrane tensor
    """
    # Resolve write gains: caller-supplied > adaptive > uniform ones.
    if alphas is None:
        if adaptive_alphas:
            alphas = compute_adaptive_alphas(values, C, keys)
        else:
            alphas = torch.ones(len(keys), device=values.device, dtype=values.dtype)
    # A pattern counts as "sparse" when its flattened L2 norm falls below the
    # threshold.
    flat_values = values.view(len(values), -1)
    sparse_mask = torch.norm(flat_values, dim=1) < sparsity_threshold
    any_sparse = bool(sparse_mask.any())
    all_sparse = bool(sparse_mask.all())
    if any_sparse and not all_sparse:
        # Mixed batch: write the dense subset first, then the sparse subset.
        dense_mask = ~sparse_mask
        result = M.clone()
        result = store_pairs(result, C, keys[dense_mask], values[dense_mask], alphas[dense_mask])
        result = _sparse_store_pairs(result, C, keys[sparse_mask], values[sparse_mask], alphas[sparse_mask])
        return result
    # Homogeneous batch: one code path handles everything.
    if all_sparse:
        return _sparse_store_pairs(M, C, keys, values, alphas)
    return store_pairs(M, C, keys, values, alphas)
def _sparse_store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: torch.Tensor
) -> torch.Tensor:
    """Sparse-pattern variant of store_pairs.

    Currently a placeholder that delegates to the dense implementation; exists
    so callers (optimized_store_pairs) already route sparse batches through a
    dedicated entry point that can later use sparse tensors.

    Parameters mirror store_pairs: M [B, L, H, W], C [L, K], keys [T],
    values [T, H, W], alphas [T]. Returns the updated membrane tensor.
    """
    # For now, use regular implementation - could be optimized with sparse tensors
    return store_pairs(M, C, keys, values, alphas)