# helium/normalization.py
from typing import Optional, Tuple, Dict
import hashlib
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache

import numpy as np

from .core.db_manager import HeliumDBManager
from virtual_gpu_driver.src.ai.tensor_types import Tensor

class NormType(Enum):
"""Supported normalization types"""
BATCH = "batch"
LAYER = "layer"
GROUP = "group"
INSTANCE = "instance"
RMS = "rms"
def normalize(input: Tensor,
mean: Optional[Tensor] = None,
variance: Optional[Tensor] = None,
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5) -> Tuple[Tensor, Tensor, Tensor]:
"""
Normalizes the input using mean and variance. If mean/variance not provided,
they are computed from the input.
Args:
input: Input tensor
mean: Optional pre-computed mean
variance: Optional pre-computed variance
weight: Optional scale parameter
bias: Optional bias parameter
eps: Small constant for numerical stability
Returns:
Tuple of (normalized tensor, mean, variance)
"""
# Calculate mean and variance if not provided
if mean is None or variance is None:
        # Compute per-feature stats over all axes except the last (feature) axis,
        # so the returned mean/variance can be reused as precomputed stats
axes = tuple(range(input.ndim - 1))
mean = input.mean(axis=axes, keepdims=True)
variance = input.var(axis=axes, keepdims=True)
# Normalize
denom = (variance + eps).sqrt()
normalized = (input - mean) / denom
# Apply scale and bias if provided
if weight is not None:
normalized = normalized * weight
if bias is not None:
normalized = normalized + bias
return normalized, mean, variance
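
# Example usage (sketch): assumes `Tensor` can wrap a NumPy array and exposes
# the mean/var/sqrt methods used above; adjust construction to the real API.
#
#     t = Tensor(np.random.randn(4, 8).astype(np.float32))
#     out, mu, var = normalize(t)                        # per-feature stats from the input
#     out2, _, _ = normalize(t, mean=mu, variance=var)   # reuse precomputed stats
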
@dataclass
class NormConfig:
"""Configuration for normalization layers"""
norm_type: NormType
num_features: int
eps: float = 1e-5
momentum: float = 0.1
affine: bool = True
num_groups: int = 32 # For group norm
track_running_stats: bool = True
dtype: np.dtype = np.float32
use_cache: bool = True
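
    # Example (illustrative values): a group-norm config for 64-channel
    # feature maps split into 16 groups:
    #     NormConfig(norm_type=NormType.GROUP, num_features=64, num_groups=16)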
class NormalizationCache:
"""Cache manager for normalization computations"""
def __init__(self):
self.db = HeliumDBManager.get_instance()
self.running_means: Dict[str, np.ndarray] = {}
self.running_vars: Dict[str, np.ndarray] = {}
def _compute_key(self, x: np.ndarray, norm_type: NormType) -> str:
"""Compute cache key for input"""
hasher = hashlib.sha256()
hasher.update(x.tobytes())
hasher.update(norm_type.value.encode())
return hasher.hexdigest()
def get(self, key: str) -> Optional[Dict[str, np.ndarray]]:
"""Get cached computation"""
return self.db.get_activation(key)
def set(self, key: str, value: Dict[str, np.ndarray], metadata: Dict):
"""Cache computation"""
self.db.set_activation(key, value, metadata)
def update_running_stats(
self,
key: str,
mean: np.ndarray,
var: np.ndarray,
momentum: float
):
"""Update running statistics"""
if key in self.running_means:
self.running_means[key] = (
(1 - momentum) * self.running_means[key] +
momentum * mean
)
self.running_vars[key] = (
(1 - momentum) * self.running_vars[key] +
momentum * var
)
else:
self.running_means[key] = mean
self.running_vars[key] = var
class Normalization:
"""
Unified normalization implementation with support for:
- Multiple normalization types
- Hardware acceleration
- Mixed precision
- Computation caching
- Running statistics tracking
"""
def __init__(
self,
config: NormConfig,
driver = None
):
"""Initialize normalization layer"""
self.config = config
self.driver = driver
self.cache = NormalizationCache()
# Initialize learnable parameters if needed
if config.affine:
self.gamma = np.ones(config.num_features, dtype=config.dtype)
self.beta = np.zeros(config.num_features, dtype=config.dtype)
else:
self.gamma = None
self.beta = None
@staticmethod
@lru_cache(maxsize=128)
def _get_reshape_dims(
input_shape: Tuple[int, ...],
num_features: int,
norm_type: NormType
) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
"""Get reshaping dimensions for parameters"""
ndim = len(input_shape)
if norm_type == NormType.BATCH:
param_shape = (1, num_features) + (1,) * (ndim - 2)
reduction_axes = (0,) + tuple(range(2, ndim))
        elif norm_type in (NormType.LAYER, NormType.RMS):
            # Layer/RMS norm: per-sample statistics over the feature (last) axis
            param_shape = (1,) * (ndim - 1) + (num_features,)
            reduction_axes = (ndim - 1,)
        else:  # GROUP, INSTANCE
            param_shape = (1, num_features) + (1,) * (ndim - 2)
            reduction_axes = (2,) + tuple(range(3, ndim))
return param_shape, reduction_axes
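
    # Example: for input_shape=(8, 32, 16, 16), BATCH norm yields
    # param_shape=(1, 32, 1, 1) and reduction_axes=(0, 2, 3); LAYER norm on
    # (4, 10, 64) yields param_shape=(1, 1, 64) and reduction_axes=(2,).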
    def _check_input(self, x: np.ndarray):
        """Validate input tensor"""
        if x.ndim < 2:
            raise ValueError(f"Expected at least 2D input, got shape {x.shape}")
        # Channel-first layouts carry features on axis 1; layer/RMS norm uses the last axis
        axis = -1 if self.config.norm_type in (NormType.LAYER, NormType.RMS) else 1
        if x.shape[axis] != self.config.num_features:
            raise ValueError(
                f"Expected {self.config.num_features} features, got {x.shape[axis]}"
            )
        if (self.config.norm_type == NormType.GROUP
                and self.config.num_features % self.config.num_groups != 0):
            raise ValueError(
                f"num_features ({self.config.num_features}) must be divisible "
                f"by num_groups ({self.config.num_groups})"
            )
def _compute_stats(
self,
x: np.ndarray,
reduction_axes: Tuple[int, ...]
) -> Tuple[np.ndarray, np.ndarray]:
"""Compute mean and variance"""
        if (self.driver is not None
                and hasattr(self.driver, 'reduce_mean')
                and hasattr(self.driver, 'reduce_var')):
            # Offload both reductions to the hardware driver when available
            mean = self.driver.reduce_mean(x, axis=reduction_axes, keepdims=True)
            var = self.driver.reduce_var(x, axis=reduction_axes, keepdims=True)
        else:
            mean = np.mean(x, axis=reduction_axes, keepdims=True)
            var = np.var(x, axis=reduction_axes, keepdims=True)
return mean, var
def normalize(
self,
x: np.ndarray,
training: bool = True
) -> np.ndarray:
"""
Apply normalization to input tensor
"""
self._check_input(x)
# Get cache key and check cache
if self.config.use_cache and not training:
cache_key = self.cache._compute_key(x, self.config.norm_type)
cached = self.cache.get(cache_key)
if cached is not None:
return cached['output']
# Get reshaping dimensions
param_shape, reduction_axes = self._get_reshape_dims(
x.shape,
self.config.num_features,
self.config.norm_type
)
        # Special handling for group norm: fold channels into (groups, C // groups)
        if self.config.norm_type == NormType.GROUP:
            groups = self.config.num_groups
            N, C = x.shape[:2]
            spatial = x.shape[2:]
            x = x.reshape(N, groups, C // groups, *spatial)
            reduction_axes = (2,) + tuple(range(3, x.ndim))
        # Compute statistics; at inference, batch norm reuses the tracked
        # running statistics when they are available
        stats_key = str(id(self))
        if (not training
                and self.config.norm_type == NormType.BATCH
                and self.config.track_running_stats
                and stats_key in self.cache.running_means):
            mean = self.cache.running_means[stats_key]
            var = self.cache.running_vars[stats_key]
        elif self.config.norm_type == NormType.RMS:
            # RMS norm: no mean subtraction; the "variance" is the mean square
            mean = np.zeros((1,) * x.ndim, dtype=x.dtype)
            var = np.mean(np.square(x), axis=reduction_axes, keepdims=True)
        else:
            mean, var = self._compute_stats(x, reduction_axes)
        # Update running statistics during training
        if training and self.config.track_running_stats:
            self.cache.update_running_stats(stats_key, mean, var, self.config.momentum)
        # Normalize
        x_norm = (x - mean) / np.sqrt(var + self.config.eps)
        # Reshape back if group norm
        if self.config.norm_type == NormType.GROUP:
            x_norm = x_norm.reshape(N, C, *spatial)
# Apply affine transform if needed
if self.config.affine:
gamma = self.gamma.reshape(param_shape)
beta = self.beta.reshape(param_shape)
out = gamma * x_norm + beta
else:
out = x_norm
# Cache result if needed
if self.config.use_cache and not training:
self.cache.set(
cache_key,
{
'output': out,
'mean': mean,
'var': var
},
{
                    'shape': out.shape,
                    'dtype': str(out.dtype),
'norm_type': self.config.norm_type.value
}
)
return out
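
    # Stateful usage (sketch): keep one instance across steps so the running
    # statistics accumulate, then pass training=False to reuse them:
    #
    #     bn = Normalization(NormConfig(NormType.BATCH, num_features=32, use_cache=False))
    #     for batch in batches:                        # hypothetical iterable
    #         y = bn.normalize(batch, training=True)
    #     y_eval = bn.normalize(x_eval, training=False)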
@classmethod
def batch_norm(
cls,
x: np.ndarray,
num_features: Optional[int] = None,
**kwargs
) -> np.ndarray:
"""Convenience method for batch normalization"""
config = NormConfig(
norm_type=NormType.BATCH,
num_features=num_features or x.shape[1],
**kwargs
)
return cls(config).normalize(x)
@classmethod
def layer_norm(
cls,
x: np.ndarray,
num_features: Optional[int] = None,
**kwargs
) -> np.ndarray:
"""Convenience method for layer normalization"""
config = NormConfig(
norm_type=NormType.LAYER,
num_features=num_features or x.shape[-1],
**kwargs
)
return cls(config).normalize(x)
@classmethod
def group_norm(
cls,
x: np.ndarray,
num_features: Optional[int] = None,
num_groups: int = 32,
**kwargs
) -> np.ndarray:
"""Convenience method for group normalization"""
config = NormConfig(
norm_type=NormType.GROUP,
num_features=num_features or x.shape[1],
num_groups=num_groups,
**kwargs
)
return cls(config).normalize(x)
@classmethod
def instance_norm(
cls,
x: np.ndarray,
num_features: Optional[int] = None,
**kwargs
) -> np.ndarray:
"""Convenience method for instance normalization"""
config = NormConfig(
norm_type=NormType.INSTANCE,
num_features=num_features or x.shape[1],
**kwargs
)
return cls(config).normalize(x)
@classmethod
def rms_norm(
cls,
x: np.ndarray,
num_features: Optional[int] = None,
**kwargs
) -> np.ndarray:
"""Convenience method for RMS normalization"""
config = NormConfig(
norm_type=NormType.RMS,
num_features=num_features or x.shape[-1],
track_running_stats=False, # RMS norm doesn't use running stats
**kwargs
)
return cls(config).normalize(x)
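

# Minimal smoke test (sketch). Assumes HeliumDBManager.get_instance() can be
# called without prior setup; the DB-backed cache path is not exercised here
# because the convenience methods run with training=True.
if __name__ == "__main__":
    x = np.random.randn(8, 32, 16, 16).astype(np.float32)
    print(Normalization.batch_norm(x).shape)                # (8, 32, 16, 16)
    print(Normalization.instance_norm(x).shape)             # (8, 32, 16, 16)
    print(Normalization.group_norm(x, num_groups=8).shape)  # (8, 32, 16, 16)
    seq = np.random.randn(4, 10, 64).astype(np.float32)
    print(Normalization.layer_norm(seq).shape)              # (4, 10, 64)
    print(Normalization.rms_norm(seq).shape)                # (4, 10, 64)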