import hashlib
import json
import warnings
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from typing import Optional, Union, Tuple, List, Dict

import numpy as np

from .core.db_manager import HeliumDBManager
from virtual_gpu_driver.src.ai.tensor_types import Tensor, Device, DType


class NormType(Enum):
    """Supported normalization types"""
    BATCH = "batch"
    LAYER = "layer"
    GROUP = "group"
    INSTANCE = "instance"
    RMS = "rms"

def normalize(input: Tensor,
              mean: Optional[Tensor] = None,
              variance: Optional[Tensor] = None,
              weight: Optional[Tensor] = None,
              bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Normalize the input using mean and variance. If mean/variance are not
    provided, they are computed from the input over all axes except the last.

    Args:
        input: Input tensor
        mean: Optional pre-computed mean
        variance: Optional pre-computed variance
        weight: Optional scale parameter
        bias: Optional bias parameter
        eps: Small constant for numerical stability

    Returns:
        Tuple of (normalized tensor, mean, variance)
    """
    if mean is None or variance is None:
        # Reduce over every axis except the last (feature) axis.
        axes = tuple(range(input.ndim - 1))
        mean = input.mean(axis=axes, keepdims=True)
        variance = input.var(axis=axes, keepdims=True)

    denom = (variance + eps).sqrt()
    normalized = (input - mean) / denom

    if weight is not None:
        normalized = normalized * weight
    if bias is not None:
        normalized = normalized + bias

    return normalized, mean, variance

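# Illustrative usage of normalize() (a sketch, not part of the module API):
# it assumes the Tensor type wraps a NumPy array and exposes the
# .mean/.var/.sqrt methods used above; the constructor call below is a
# hypothetical example, not a documented tensor_types call.
#
#     x = Tensor(np.random.randn(8, 16).astype(np.float32))
#     y, mu, var = normalize(x)                       # stats computed from x
#     y2, _, _ = normalize(x, mean=mu, variance=var)  # reuse pre-computed stats
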
@dataclass
class NormConfig:
    """Configuration for normalization layers"""
    norm_type: NormType
    num_features: int
    eps: float = 1e-5
    momentum: float = 0.1
    affine: bool = True
    num_groups: int = 32
    track_running_stats: bool = True
    dtype: np.dtype = np.float32
    use_cache: bool = True

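# Example configuration (illustrative values only):
#
#     cfg = NormConfig(norm_type=NormType.LAYER, num_features=768,
#                      eps=1e-6, affine=True, track_running_stats=False)
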
class NormalizationCache:
    """Cache manager for normalization computations"""

    def __init__(self):
        self.db = HeliumDBManager.get_instance()
        self.running_means: Dict[str, np.ndarray] = {}
        self.running_vars: Dict[str, np.ndarray] = {}

    def _compute_key(self, x: np.ndarray, norm_type: NormType) -> str:
        """Compute cache key for input"""
        hasher = hashlib.sha256()
        hasher.update(x.tobytes())
        hasher.update(norm_type.value.encode())
        return hasher.hexdigest()

    def get(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Get cached computation"""
        return self.db.get_activation(key)

    def set(self, key: str, value: Dict[str, np.ndarray], metadata: Dict):
        """Cache computation"""
        self.db.set_activation(key, value, metadata)

    def update_running_stats(
        self,
        key: str,
        mean: np.ndarray,
        var: np.ndarray,
        momentum: float
    ):
        """Update running statistics"""
        if key in self.running_means:
            self.running_means[key] = (
                (1 - momentum) * self.running_means[key] +
                momentum * mean
            )
            self.running_vars[key] = (
                (1 - momentum) * self.running_vars[key] +
                momentum * var
            )
        else:
            self.running_means[key] = mean
            self.running_vars[key] = var

class Normalization:
    """
    Unified normalization implementation with support for:
    - Multiple normalization types
    - Hardware acceleration
    - Mixed precision
    - Computation caching
    - Running statistics tracking
    """

    def __init__(
        self,
        config: NormConfig,
        driver=None
    ):
        """Initialize normalization layer"""
        self.config = config
        self.driver = driver
        self.cache = NormalizationCache()

        if config.affine:
            self.gamma = np.ones(config.num_features, dtype=config.dtype)
            self.beta = np.zeros(config.num_features, dtype=config.dtype)
        else:
            self.gamma = None
            self.beta = None

    @staticmethod
    @lru_cache(maxsize=128)
    def _get_reshape_dims(
        input_shape: Tuple[int, ...],
        num_features: int,
        norm_type: NormType
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
        """Get parameter broadcast shape and reduction axes for a norm type"""
        ndim = len(input_shape)
        if norm_type == NormType.BATCH:
            # Per-channel statistics over the batch and spatial axes.
            param_shape = (1, num_features) + (1,) * (ndim - 2)
            reduction_axes = (0,) + tuple(range(2, ndim))
        elif norm_type in (NormType.LAYER, NormType.RMS):
            # Per-sample statistics over the trailing feature axis.
            param_shape = (1,) * (ndim - 1) + (num_features,)
            reduction_axes = (ndim - 1,)
        else:
            # Instance/group norm: per-sample, per-channel statistics over the
            # spatial axes (group norm re-derives its axes after the group
            # reshape in normalize()).
            param_shape = (1, num_features) + (1,) * (ndim - 2)
            reduction_axes = (2,) + tuple(range(3, ndim))
        return param_shape, reduction_axes

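    # Illustrative results of _get_reshape_dims (not executed here):
    #   BATCH on (N, C, H, W):    param_shape (1, C, 1, 1), reduce over (0, 2, 3)
    #   LAYER/RMS on (N, T, D):   param_shape (1, 1, D),    reduce over the last axis
    #   INSTANCE on (N, C, H, W): param_shape (1, C, 1, 1), reduce over (2, 3)
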
    def _check_input(self, x: np.ndarray):
        """Validate input tensor"""
        if x.ndim < 2:
            raise ValueError(f"Expected at least 2D input, got shape {x.shape}")

        if self.config.norm_type in [NormType.BATCH, NormType.GROUP]:
            if x.shape[1] != self.config.num_features:
                raise ValueError(
                    f"Expected {self.config.num_features} features, got {x.shape[1]}"
                )

        if self.config.norm_type == NormType.GROUP:
            # Group norm reshapes channels into (groups, channels_per_group),
            # so the channel count must divide evenly.
            if self.config.num_features % self.config.num_groups != 0:
                raise ValueError(
                    f"num_features ({self.config.num_features}) must be divisible "
                    f"by num_groups ({self.config.num_groups})"
                )

    def _compute_stats(
        self,
        x: np.ndarray,
        reduction_axes: Tuple[int, ...]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute mean and variance"""
        if (self.driver and hasattr(self.driver, 'reduce_mean')
                and hasattr(self.driver, 'reduce_var')):
            mean = self.driver.reduce_mean(x, axis=reduction_axes, keepdims=True)
            var = self.driver.reduce_var(x, axis=reduction_axes, keepdims=True)
        else:
            mean = np.mean(x, axis=reduction_axes, keepdims=True)
            var = np.var(x, axis=reduction_axes, keepdims=True)
        return mean, var

    def normalize(
        self,
        x: np.ndarray,
        training: bool = True
    ) -> np.ndarray:
        """
        Apply normalization to the input tensor.
        """
        self._check_input(x)

        # In eval mode, return a cached result for an identical input if available.
        if self.config.use_cache and not training:
            cache_key = self.cache._compute_key(x, self.config.norm_type)
            cached = self.cache.get(cache_key)
            if cached is not None:
                return cached['output']

        param_shape, reduction_axes = self._get_reshape_dims(
            x.shape,
            self.config.num_features,
            self.config.norm_type
        )

        # Group norm: fold channels into (groups, channels_per_group) and
        # reduce within each group.
        if self.config.norm_type == NormType.GROUP:
            groups = self.config.num_groups
            N, C = x.shape[:2]
            x = x.reshape(N, groups, C // groups, *x.shape[2:])
            reduction_axes = (2,) + tuple(range(3, x.ndim))

        if self.config.norm_type == NormType.RMS:
            # RMS norm subtracts no mean; 'var' holds the mean of squares.
            mean = np.zeros((1,) * x.ndim, dtype=x.dtype)
            var = np.mean(np.square(x), axis=reduction_axes, keepdims=True)
        else:
            mean, var = self._compute_stats(x, reduction_axes)

        if training and self.config.track_running_stats:
            self.cache.update_running_stats(
                str(id(self)),
                mean,
                var,
                self.config.momentum
            )

        x_norm = (x - mean) / np.sqrt(var + self.config.eps)

        # Undo the group reshape before applying the affine parameters.
        if self.config.norm_type == NormType.GROUP:
            x_norm = x_norm.reshape(N, C, *x.shape[3:])

        if self.config.affine:
            gamma = self.gamma.reshape(param_shape)
            beta = self.beta.reshape(param_shape)
            out = gamma * x_norm + beta
        else:
            out = x_norm

        if self.config.use_cache and not training:
            self.cache.set(
                cache_key,
                {
                    'output': out,
                    'mean': mean,
                    'var': var
                },
                {
                    'shape': x.shape,
                    'dtype': str(x.dtype),
                    'norm_type': self.config.norm_type.value
                }
            )

        return out

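    # Illustrative usage of normalize() (a sketch; it assumes the Helium DB
    # backend used by NormalizationCache is available in this environment):
    #
    #     norm = Normalization(NormConfig(norm_type=NormType.BATCH, num_features=16))
    #     y_train = norm.normalize(x, training=True)   # updates running stats
    #     y_eval = norm.normalize(x, training=False)   # may hit the activation cache
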
    @classmethod
    def batch_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Convenience method for batch normalization"""
        config = NormConfig(
            norm_type=NormType.BATCH,
            num_features=num_features or x.shape[1],
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def layer_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Convenience method for layer normalization"""
        config = NormConfig(
            norm_type=NormType.LAYER,
            num_features=num_features or x.shape[-1],
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def group_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        num_groups: int = 32,
        **kwargs
    ) -> np.ndarray:
        """Convenience method for group normalization"""
        config = NormConfig(
            norm_type=NormType.GROUP,
            num_features=num_features or x.shape[1],
            num_groups=num_groups,
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def instance_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Convenience method for instance normalization"""
        config = NormConfig(
            norm_type=NormType.INSTANCE,
            num_features=num_features or x.shape[1],
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def rms_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Convenience method for RMS normalization"""
        config = NormConfig(
            norm_type=NormType.RMS,
            num_features=num_features or x.shape[-1],
            track_running_stats=False,
            **kwargs
        )
        return cls(config).normalize(x)
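

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not a test suite). It assumes the package
    # imports above resolve and that HeliumDBManager.get_instance() works in
    # this environment; the convenience classmethods default to training=True,
    # so no activation-cache entries are written here.
    rng = np.random.default_rng(0)

    # Batch norm over the channels of an NCHW activation.
    x_img = rng.standard_normal((4, 16, 8, 8)).astype(np.float32)
    y_bn = Normalization.batch_norm(x_img)
    print("batch_norm:", x_img.shape, "->", y_bn.shape)

    # Layer norm over the trailing feature axis of a (batch, time, features) tensor.
    x_seq = rng.standard_normal((4, 10, 32)).astype(np.float32)
    y_ln = Normalization.layer_norm(x_seq)
    print("layer_norm:", x_seq.shape, "->", y_ln.shape)

    # Group norm with 4 groups over 16 channels.
    y_gn = Normalization.group_norm(x_img, num_groups=4)
    print("group_norm:", x_img.shape, "->", y_gn.shape)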