""" |
|
|
WrinkleBrane Optimizations Module |
|
|
Advanced optimizations for improved performance and fidelity. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations

import math
from typing import Optional

import torch

from .codes import normalize_columns, hadamard_codes, dct_codes
from .write_ops import store_pairs
|
|
|
|
|
|
|
|
def compute_adaptive_alphas(
    patterns: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    energy_weight: float = 1.0,
    orthogonality_weight: float = 0.5,
    min_alpha: float = 0.1,
    max_alpha: float = 3.0,
) -> torch.Tensor:
"""Compute optimal alpha values for each pattern based on energy and orthogonality. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
patterns : torch.Tensor |
|
|
Input patterns with shape [T, H, W] |
|
|
C : torch.Tensor |
|
|
Codebook tensor with shape [L, K] |
|
|
keys : torch.Tensor |
|
|
Key indices with shape [T] |
|
|
energy_weight : float |
|
|
Weight for energy normalization (default: 1.0) |
|
|
orthogonality_weight : float |
|
|
Weight for orthogonality compensation (default: 0.5) |
|
|
min_alpha : float |
|
|
Minimum alpha value (default: 0.1) |
|
|
max_alpha : float |
|
|
Maximum alpha value (default: 3.0) |
|
|
|
|
|
Returns |
|
|
------- |
|
|
torch.Tensor |
|
|
Optimized alpha values with shape [T] |
|
|
""" |
|
|
    T = len(patterns)
    device = patterns.device
    dtype = patterns.dtype

    alphas = torch.ones(T, device=device, dtype=dtype)

    for i, key in enumerate(keys):
        key = int(key)

        # Energy normalization: scale each pattern toward a common
        # reference RMS. Note that energy_weight currently acts as an
        # on/off gate rather than a continuous weight.
        pattern_rms = torch.sqrt(torch.mean(patterns[i] ** 2)).clamp_min(1e-6)
        reference_rms = 0.5
        if energy_weight > 0:
            energy_factor = reference_rms / pattern_rms
        else:
            energy_factor = torch.tensor(1.0, device=device, dtype=dtype)

        # Orthogonality compensation: damp alpha when this key's code
        # correlates strongly with any other code column.
        if orthogonality_weight > 0 and key < C.shape[1]:
            code_vec = C[:, key]
            other_codes = torch.cat([C[:, :key], C[:, key + 1:]], dim=1) if C.shape[1] > 1 else C[:, :0]
            if other_codes.numel() > 0:
                max_correlation = torch.abs(code_vec @ other_codes).max()
                orthogonality_factor = 1.0 / (1.0 + orthogonality_weight * max_correlation)
            else:
                orthogonality_factor = 1.0
        else:
            orthogonality_factor = 1.0

        alpha = energy_factor * orthogonality_factor
        alphas[i] = torch.clamp(alpha, min_alpha, max_alpha)

    return alphas
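
# Example (illustrative sketch, not part of the library API): computing
# per-pattern alphas for a small bank. Shapes follow the docstring above.
#
#   patterns = torch.rand(4, 8, 8)                 # [T, H, W]
#   C = hadamard_codes(16, 4)                      # [L, K]
#   alphas = compute_adaptive_alphas(patterns, C, keys=torch.arange(4))
#   assert alphas.shape == (4,)
#   assert ((alphas >= 0.1) & (alphas <= 3.0)).all()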
|
|
|
|
|
|
|
|
def generate_extended_codes(
    L: int,
    K: int,
    method: str = "auto",
    existing_patterns: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
"""Generate orthogonal codes with support for K > L. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
L : int |
|
|
Code length (number of layers) |
|
|
K : int |
|
|
Number of codes needed |
|
|
method : str |
|
|
Code generation method: "auto", "hadamard", "dct", "gram_schmidt", "random_ortho" |
|
|
existing_patterns : torch.Tensor, optional |
|
|
Existing patterns to optimize codes for, shape [K, H, W] |
|
|
device : torch.device, optional |
|
|
Device to place tensors on |
|
|
|
|
|
Returns |
|
|
------- |
|
|
torch.Tensor |
|
|
Code matrix with shape [L, K] |
|
|
""" |
|
|
    if device is None:
        device = torch.device("cpu")

    if K <= L:
        if method == "auto":
            # Prefer Hadamard codes while K fits within the largest
            # power-of-two block; otherwise fall back to DCT codes.
            if K <= 2 ** int(math.log2(L)):
                return hadamard_codes(L, K).to(device)
            else:
                return dct_codes(L, K).to(device)
        elif method == "hadamard":
            return hadamard_codes(L, K).to(device)
        elif method == "dct":
            return dct_codes(L, K).to(device)

    if method == "gram_schmidt" or (method == "auto" and K > L):
        return generate_gram_schmidt_codes(L, K, existing_patterns, device)
    elif method == "random_ortho":
        return generate_random_orthogonal_codes(L, K, device)

    # Fallback for unrecognized methods.
    if K <= L:
        torch.manual_seed(42)  # fixed seed keeps the codebook reproducible
        C = torch.randn(L, K, device=device)
        return normalize_columns(C)
    else:
        return generate_gram_schmidt_codes(L, K, existing_patterns, device)
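
# Example (illustrative sketch): with K > L no codebook can be fully
# orthogonal, so some cross-talk between code pairs is unavoidable.
#
#   C = generate_extended_codes(L=32, K=48, method="auto")
#   gram = C.T @ C                                 # [K, K] correlation matrix
#   off_diag = gram - torch.diag(torch.diag(gram))
#   print(off_diag.abs().max())                    # strictly > 0 since K > L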
|
|
|
|
|
|
|
|
def generate_gram_schmidt_codes(
    L: int,
    K: int,
    existing_patterns: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """Generate codes using Gram-Schmidt (QR) orthogonalization.

    This method can generate more codes than layers (K > L). The first L
    codes form an orthonormal basis; any further codes can only be
    normalized, not orthogonalized, since R^L admits at most L mutually
    orthogonal vectors. When existing_patterns is given, the extra codes
    are seeded from the patterns' principal directions.
    """
    if device is None:
        device = torch.device("cpu")

    # Start from Hadamard codes, which are already orthogonal.
    if K <= L:
        base_codes = hadamard_codes(L, min(K, L)).to(device)
        if K == base_codes.shape[1]:
            return base_codes
    else:
        base_codes = hadamard_codes(L, L).to(device)

    if K > base_codes.shape[1]:
        additional_needed = K - base_codes.shape[1]

        if existing_patterns is not None:
            # Seed the extra codes with the top left-singular vectors of
            # the patterns (assumes H * W >= L and enough patterns to
            # supply additional_needed directions).
            patterns_flat = existing_patterns.view(existing_patterns.shape[0], -1)
            U, _, _ = torch.linalg.svd(patterns_flat.T, full_matrices=False)
            additional_vectors = U[:L, :additional_needed].to(device)
        else:
            torch.manual_seed(42)  # fixed seed keeps the codebook reproducible
            additional_vectors = torch.randn(L, additional_needed, device=device)

        all_vectors = torch.cat([base_codes, additional_vectors], dim=1)
    else:
        all_vectors = base_codes

    # Reduced QR orthogonalizes at most L columns.
    Q, R = torch.linalg.qr(all_vectors)

    # Fix column signs so the decomposition is deterministic.
    signs = torch.sign(torch.diag(R))
    signs[signs == 0] = 1
    Q = Q * signs.unsqueeze(0)

    if K > Q.shape[1]:
        # K > L: append the remaining columns merely normalized, since no
        # further orthogonal directions exist in R^L.
        Q = torch.cat([Q, all_vectors[:, Q.shape[1]:K]], dim=1)

    return normalize_columns(Q[:, :K])
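
# Example (illustrative sketch): for K <= L the returned codes are
# orthonormal, so the Gram matrix is the identity.
#
#   C = generate_gram_schmidt_codes(L=16, K=16)
#   assert torch.allclose(C.T @ C, torch.eye(16), atol=1e-5)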
|
|
|
|
|
|
|
|
def generate_random_orthogonal_codes(L: int, K: int, device: Optional[torch.device] = None) -> torch.Tensor:
    """Generate random orthogonal codes using QR decomposition."""
    if device is None:
        device = torch.device("cpu")

    torch.manual_seed(42)  # fixed seed keeps the codebook reproducible

    if K <= L:
        # Reduced QR of a random Gaussian matrix yields K orthonormal columns.
        A = torch.randn(L, K, device=device)
        Q, _ = torch.linalg.qr(A)
        return normalize_columns(Q)
    else:
        # Orthogonalize the first L columns; the remaining K - L columns
        # cannot be orthogonal in R^L and are only normalized.
        A = torch.randn(L, L, device=device)
        Q, _ = torch.linalg.qr(A)

        additional = torch.randn(L, K - L, device=device)
        all_codes = torch.cat([Q, additional], dim=1)
        return normalize_columns(all_codes)
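
# Example (illustrative sketch): this generator works for any code length,
# with no power-of-two constraint on L.
#
#   C = generate_random_orthogonal_codes(L=20, K=12)
#   assert C.shape == (20, 12)
#   assert torch.allclose(C.T @ C, torch.eye(12), atol=1e-5)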
|
|
|
|
|
|
|
|
class HierarchicalMembraneBank:
    """Multi-level membrane bank for better capacity scaling."""
|
|
|
|
|
    def __init__(
        self,
        L: int,
        H: int,
        W: int,
        levels: int = 3,
        level_ratios: Optional[list[float]] = None,
        device: Optional[torch.device] = None,
        dtype: torch.dtype = torch.float32,
    ):
|
|
"""Initialize hierarchical membrane bank. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
L, H, W : int |
|
|
Base dimensions |
|
|
levels : int |
|
|
Number of hierarchy levels |
|
|
level_ratios : list[float], optional |
|
|
Fraction of L allocated to each level. If None, uses geometric series. |
|
|
device : torch.device, optional |
|
|
Device for tensors |
|
|
dtype : torch.dtype |
|
|
Data type for tensors |
|
|
""" |
|
|
        self.levels = levels
        self.H = H
        self.W = W
        self.device = device
        self.dtype = dtype

        if level_ratios is None:
            # Geometric series 1/2, 1/4, 1/8, ... renormalized to sum to 1.
            level_ratios = [0.5 ** (i + 1) for i in range(levels)]
            level_ratios = [r / sum(level_ratios) for r in level_ratios]

        self.level_ratios = level_ratios

        # Imported locally, presumably to avoid a circular import.
        from .membrane_bank import MembraneBank

        self.banks = []
        remaining_L = L

        for i, ratio in enumerate(level_ratios):
            level_L = max(1, int(L * ratio))
            # The last level absorbs any layers left over by rounding.
            if i == levels - 1:
                level_L = remaining_L

            bank = MembraneBank(level_L, H, W, device=device, dtype=dtype)
            self.banks.append(bank)
            remaining_L -= level_L

        self.total_L = sum(bank.L for bank in self.banks)
|
|
|
|
|
    def allocate(self, B: int) -> None:
        """Allocate all level banks for batch size B."""
        for bank in self.banks:
            bank.allocate(B)
|
|
|
|
|
    def store_hierarchical(
        self,
        patterns: torch.Tensor,
        keys: torch.Tensor,
        level_assignment: Optional[torch.Tensor] = None,
    ) -> None:
        """Store patterns in appropriate hierarchy levels.

        Parameters
        ----------
        patterns : torch.Tensor
            Patterns to store, shape [T, H, W]
        keys : torch.Tensor
            Key indices, shape [T]
        level_assignment : torch.Tensor, optional
            Level assignment for each pattern. If None, patterns are
            auto-assigned based on spectral complexity.
        """
|
|
        if level_assignment is None:
            level_assignment = self._auto_assign_levels(patterns)

        for level in range(self.levels):
            level_mask = level_assignment == level
            if level_mask.any():
                level_patterns = patterns[level_mask]
                level_keys = keys[level_mask]

                # Build a codebook large enough for the highest key used
                # at this level.
                bank = self.banks[level]
                C = generate_extended_codes(bank.L, int(level_keys.max().item()) + 1, device=self.device)

                # Superpose the new patterns and write back the delta.
                alphas = compute_adaptive_alphas(level_patterns, C, level_keys)
                M = store_pairs(bank.read(), C, level_keys, level_patterns, alphas)
                bank.write(M - bank.read())
|
|
|
|
|
    def _auto_assign_levels(self, patterns: torch.Tensor) -> torch.Tensor:
        """Automatically assign patterns to hierarchy levels based on complexity."""
        from .metrics import spectral_entropy_2d

        entropies = torch.tensor([spectral_entropy_2d(p) for p in patterns])

        # Highest-entropy (most complex) patterns go to level 0, which has
        # the largest layer budget.
        sorted_indices = torch.argsort(entropies, descending=True)
        level_assignment = torch.zeros(len(patterns), dtype=torch.long)

        patterns_per_level = len(patterns) // self.levels
        for level in range(self.levels):
            start_idx = level * patterns_per_level
            end_idx = start_idx + patterns_per_level if level < self.levels - 1 else len(patterns)
            level_assignment[sorted_indices[start_idx:end_idx]] = level

        return level_assignment
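
# Example (illustrative sketch, assuming MembraneBank exposes L, read(),
# write() and allocate() as used above): storing eight patterns, with the
# most complex ones routed to the widest level.
#
#   hbank = HierarchicalMembraneBank(L=64, H=32, W=32, levels=3)
#   hbank.allocate(B=1)
#   hbank.store_hierarchical(torch.rand(8, 32, 32), keys=torch.arange(8))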
|
|
|
|
|
|
|
|
def optimized_store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: Optional[torch.Tensor] = None,
    adaptive_alphas: bool = True,
    sparsity_threshold: float = 0.01,
) -> torch.Tensor:
|
|
"""Enhanced store_pairs with optimizations. |
|
|
|
|
|
Parameters |
|
|
---------- |
|
|
M : torch.Tensor |
|
|
Current membranes with shape [B, L, H, W] |
|
|
C : torch.Tensor |
|
|
Codebook tensor with shape [L, K] |
|
|
keys : torch.Tensor |
|
|
Key indices with shape [T] |
|
|
values : torch.Tensor |
|
|
Values to store with shape [T, H, W] |
|
|
alphas : torch.Tensor, optional |
|
|
Manual alpha values. If None and adaptive_alphas=True, computes automatically. |
|
|
adaptive_alphas : bool |
|
|
Whether to use adaptive alpha scaling |
|
|
sparsity_threshold : float |
|
|
Threshold for sparse pattern detection |
|
|
|
|
|
Returns |
|
|
------- |
|
|
torch.Tensor |
|
|
Updated membrane tensor |
|
|
""" |
|
|
    if alphas is None and adaptive_alphas:
        alphas = compute_adaptive_alphas(values, C, keys)
    elif alphas is None:
        alphas = torch.ones(len(keys), device=values.device, dtype=values.dtype)

    # Route patterns by energy: low-norm patterns take the sparse path.
    pattern_norms = torch.norm(values.view(len(values), -1), dim=1)
    sparse_mask = pattern_norms < sparsity_threshold

    if sparse_mask.any() and not sparse_mask.all():
        # Mixed batch: store dense and sparse patterns separately.
        dense_mask = ~sparse_mask
        result = M.clone()

        if dense_mask.any():
            result = store_pairs(
                result, C, keys[dense_mask], values[dense_mask], alphas[dense_mask]
            )

        if sparse_mask.any():
            result = _sparse_store_pairs(
                result, C, keys[sparse_mask], values[sparse_mask], alphas[sparse_mask]
            )

        return result
    else:
        # Uniform batch: a single path handles everything.
        if sparse_mask.all():
            return _sparse_store_pairs(M, C, keys, values, alphas)
        else:
            return store_pairs(M, C, keys, values, alphas)
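
# Example (illustrative sketch): adaptive alphas are computed automatically
# when none are passed in.
#
#   B, L, H, W, T = 1, 16, 8, 8, 4
#   M = torch.zeros(B, L, H, W)
#   C = hadamard_codes(L, T)
#   M = optimized_store_pairs(M, C, torch.arange(T), torch.rand(T, H, W))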
|
|
|
|
|
|
|
|
def _sparse_store_pairs(
    M: torch.Tensor,
    C: torch.Tensor,
    keys: torch.Tensor,
    values: torch.Tensor,
    alphas: torch.Tensor,
) -> torch.Tensor:
    """Store low-energy patterns.

    Placeholder: currently delegates to the dense store_pairs. A dedicated
    sparse kernel could skip near-zero entries in the outer products.
    """
    return store_pairs(M, C, keys, values, alphas)
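
# Minimal smoke test (illustrative sketch; run as a module, e.g.
# `python -m <package>.optimizations`, so the relative imports resolve).
# Verifies that storage runs end to end without shape errors.
if __name__ == "__main__":
    L, H, W, T = 16, 8, 8, 4
    M = torch.zeros(1, L, H, W)
    C = generate_extended_codes(L, T)
    M = optimized_store_pairs(M, C, torch.arange(T), torch.rand(T, H, W))
    print("stored membranes:", tuple(M.shape))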