from typing import Optional, Tuple, Dict
import numpy as np
from dataclasses import dataclass
from enum import Enum
import hashlib
from functools import lru_cache

from .core.db_manager import HeliumDBManager
from virtual_gpu_driver.src.ai.tensor_types import Tensor


class NormType(Enum):
    """Supported normalization types"""
    BATCH = "batch"
    LAYER = "layer"
    GROUP = "group"
    INSTANCE = "instance"
    RMS = "rms"


def normalize(input: Tensor,
              mean: Optional[Tensor] = None,
              variance: Optional[Tensor] = None,
              weight: Optional[Tensor] = None,
              bias: Optional[Tensor] = None,
              eps: float = 1e-5) -> Tuple[Tensor, Tensor, Tensor]:
    """
    Normalizes the input using mean and variance.
    If mean/variance are not provided, they are computed from the input.

    Args:
        input: Input tensor
        mean: Optional pre-computed mean
        variance: Optional pre-computed variance
        weight: Optional scale parameter
        bias: Optional bias parameter
        eps: Small constant for numerical stability

    Returns:
        Tuple of (normalized tensor, mean, variance)
    """
    # Compute mean and variance if not provided
    if mean is None or variance is None:
        # Reduce over every axis except the last (feature) dimension
        axes = tuple(range(input.ndim - 1))
        mean = input.mean(axis=axes, keepdims=True)
        variance = input.var(axis=axes, keepdims=True)

    # Normalize
    denom = (variance + eps).sqrt()
    normalized = (input - mean) / denom

    # Apply scale and bias if provided
    if weight is not None:
        normalized = normalized * weight
    if bias is not None:
        normalized = normalized + bias

    return normalized, mean, variance
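
# For intuition, the computation above in plain NumPy. The Tensor type from
# tensor_types is assumed to expose a method-style API (e.g. `.sqrt()`), so
# this sketch substitutes np.sqrt for plain ndarrays:
#
#     x = np.random.randn(8, 16).astype(np.float32)   # (batch, features)
#     mu = x.mean(axis=0, keepdims=True)              # all axes but the last
#     var = x.var(axis=0, keepdims=True)
#     x_norm = (x - mu) / np.sqrt(var + 1e-5)
#     # each feature column now has ~zero mean and ~unit variance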
@dataclass
class NormConfig:
    """Configuration for normalization layers"""
    norm_type: NormType
    num_features: int
    eps: float = 1e-5
    momentum: float = 0.1
    affine: bool = True
    num_groups: int = 32  # For group norm
    track_running_stats: bool = True
    dtype: np.dtype = np.float32
    use_cache: bool = True


class NormalizationCache:
    """Cache manager for normalization computations"""

    def __init__(self):
        self.db = HeliumDBManager.get_instance()
        self.running_means: Dict[str, np.ndarray] = {}
        self.running_vars: Dict[str, np.ndarray] = {}

    def _compute_key(self, x: np.ndarray, norm_type: NormType) -> str:
        """Compute cache key for input"""
        hasher = hashlib.sha256()
        hasher.update(x.tobytes())
        hasher.update(norm_type.value.encode())
        return hasher.hexdigest()

    def get(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Get cached computation"""
        return self.db.get_activation(key)

    def set(self, key: str, value: Dict[str, np.ndarray], metadata: Dict):
        """Cache computation"""
        self.db.set_activation(key, value, metadata)

    def update_running_stats(
        self,
        key: str,
        mean: np.ndarray,
        var: np.ndarray,
        momentum: float
    ):
        """Update running statistics with an exponential moving average"""
        if key in self.running_means:
            self.running_means[key] = (
                (1 - momentum) * self.running_means[key] + momentum * mean
            )
            self.running_vars[key] = (
                (1 - momentum) * self.running_vars[key] + momentum * var
            )
        else:
            self.running_means[key] = mean
            self.running_vars[key] = var
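
# The running-stats update above is a standard exponential moving average:
#
#     running = (1 - momentum) * running + momentum * batch_stat
#
# With the default momentum of 0.1, each new batch contributes 10% and the
# history decays geometrically. A minimal sketch of the convergence behaviour
# (plain NumPy, independent of the cache class):
#
#     running = np.zeros(4)
#     for _ in range(50):
#         batch_mean = np.ones(4)          # pretend every batch has mean 1.0
#         running = 0.9 * running + 0.1 * batch_mean
#     # running is now ~1.0 (within (1 - 0.1)**50 ~ 0.5% of the target)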
class Normalization:
    """
    Unified normalization implementation with support for:

    - Multiple normalization types
    - Hardware acceleration
    - Mixed precision
    - Computation caching
    - Running statistics tracking
    """

    def __init__(
        self,
        config: NormConfig,
        driver=None
    ):
        """Initialize normalization layer"""
        self.config = config
        self.driver = driver
        self.cache = NormalizationCache()

        # Initialize learnable parameters if needed
        if config.affine:
            self.gamma = np.ones(config.num_features, dtype=config.dtype)
            self.beta = np.zeros(config.num_features, dtype=config.dtype)
        else:
            self.gamma = None
            self.beta = None

    @staticmethod
    @lru_cache(maxsize=128)
    def _get_reshape_dims(
        input_shape: Tuple[int, ...],
        num_features: int,
        norm_type: NormType
    ) -> Tuple[Tuple[int, ...], Tuple[int, ...]]:
        """Get parameter broadcast shape and reduction axes"""
        ndim = len(input_shape)
        if norm_type == NormType.BATCH:
            # Per-channel stats across the batch and spatial dims
            param_shape = (1, num_features) + (1,) * (ndim - 2)
            reduction_axes = (0,) + tuple(range(2, ndim))
        elif norm_type in (NormType.LAYER, NormType.RMS):
            # Per-sample stats over the trailing feature dimension
            param_shape = (1,) * (ndim - 1) + (num_features,)
            reduction_axes = (ndim - 1,)
        else:  # GROUP, INSTANCE
            # Per-sample, per-channel (or per-group) stats over spatial dims
            param_shape = (1, num_features) + (1,) * (ndim - 2)
            reduction_axes = (2,) + tuple(range(3, ndim))
        return param_shape, reduction_axes

    def _check_input(self, x: np.ndarray):
        """Validate input tensor"""
        if x.ndim < 2:
            raise ValueError(f"Expected at least 2D input, got shape {x.shape}")
        if self.config.norm_type in (NormType.BATCH, NormType.GROUP, NormType.INSTANCE):
            if x.shape[1] != self.config.num_features:
                raise ValueError(
                    f"Expected {self.config.num_features} features, got {x.shape[1]}"
                )
        elif x.shape[-1] != self.config.num_features:
            raise ValueError(
                f"Expected {self.config.num_features} features, got {x.shape[-1]}"
            )

    def _compute_stats(
        self,
        x: np.ndarray,
        reduction_axes: Tuple[int, ...]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute mean and variance, using the driver when it offers reductions"""
        if self.driver and hasattr(self.driver, 'reduce_mean'):
            mean = self.driver.reduce_mean(x, axis=reduction_axes, keepdims=True)
            var = self.driver.reduce_var(x, axis=reduction_axes, keepdims=True)
        else:
            mean = np.mean(x, axis=reduction_axes, keepdims=True)
            var = np.var(x, axis=reduction_axes, keepdims=True)
        return mean, var

    def normalize(
        self,
        x: np.ndarray,
        training: bool = True
    ) -> np.ndarray:
        """Apply normalization to the input tensor"""
        self._check_input(x)

        # Check the cache in inference mode
        cache_key = None
        if self.config.use_cache and not training:
            cache_key = self.cache._compute_key(x, self.config.norm_type)
            cached = self.cache.get(cache_key)
            if cached is not None:
                return cached['output']

        # Get reshaping dimensions
        param_shape, reduction_axes = self._get_reshape_dims(
            x.shape, self.config.num_features, self.config.norm_type
        )

        # Special handling for group norm: fold channels into groups
        orig_shape = x.shape
        if self.config.norm_type == NormType.GROUP:
            groups = self.config.num_groups
            N, C = x.shape[:2]
            if C % groups != 0:
                raise ValueError(
                    f"num_features ({C}) must be divisible by num_groups ({groups})"
                )
            x = x.reshape(N, groups, C // groups, *x.shape[2:])
            reduction_axes = (2,) + tuple(range(3, x.ndim))

        if self.config.norm_type == NormType.RMS:
            # RMS norm rescales by the root mean square; no mean subtraction
            var = np.mean(np.square(x), axis=reduction_axes, keepdims=True)
            mean = np.zeros_like(var)
            x_norm = x / np.sqrt(var + self.config.eps)
        else:
            # Compute statistics and normalize
            mean, var = self._compute_stats(x, reduction_axes)
            x_norm = (x - mean) / np.sqrt(var + self.config.eps)

        # Update running statistics during training
        if training and self.config.track_running_stats:
            self.cache.update_running_stats(
                str(id(self)), mean, var, self.config.momentum
            )

        # Reshape back if group norm
        if self.config.norm_type == NormType.GROUP:
            x_norm = x_norm.reshape(orig_shape)

        # Apply affine transform if needed
        if self.config.affine:
            gamma = self.gamma.reshape(param_shape)
            beta = self.beta.reshape(param_shape)
            out = gamma * x_norm + beta
        else:
            out = x_norm

        # Cache result in inference mode
        if self.config.use_cache and not training:
            self.cache.set(
                cache_key,
                {
                    'output': out,
                    'mean': mean,
                    'var': var
                },
                {
                    'shape': orig_shape,
                    'dtype': str(x.dtype),
                    'norm_type': self.config.norm_type.value
                }
            )

        return out

    @classmethod
    def batch_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Convenience method for batch normalization"""
        config = NormConfig(
            norm_type=NormType.BATCH,
            num_features=num_features or x.shape[1],
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def layer_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Convenience method for layer normalization"""
        config = NormConfig(
            norm_type=NormType.LAYER,
            num_features=num_features or x.shape[-1],
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def group_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        num_groups: int = 32,
        **kwargs
    ) -> np.ndarray:
        """Convenience method for group normalization"""
        config = NormConfig(
            norm_type=NormType.GROUP,
            num_features=num_features or x.shape[1],
            num_groups=num_groups,
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def instance_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Convenience method for instance normalization"""
        config = NormConfig(
            norm_type=NormType.INSTANCE,
            num_features=num_features or x.shape[1],
            **kwargs
        )
        return cls(config).normalize(x)

    @classmethod
    def rms_norm(
        cls,
        x: np.ndarray,
        num_features: Optional[int] = None,
        **kwargs
    ) -> np.ndarray:
        """Convenience method for RMS normalization"""
        config = NormConfig(
            norm_type=NormType.RMS,
            num_features=num_features or x.shape[-1],
            track_running_stats=False,  # RMS norm doesn't use running stats
            **kwargs
        )
        return cls(config).normalize(x)
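

# A minimal smoke test for the convenience constructors. This sketch assumes
# HeliumDBManager.get_instance() is importable and functional in the current
# environment, since Normalization always builds a NormalizationCache; because
# of the relative import above, run it as a module from the project root.
if __name__ == "__main__":
    x = np.random.randn(4, 32, 8, 8).astype(np.float32)    # (N, C, H, W)

    # Batch norm: per-channel statistics over batch and spatial dims
    bn = Normalization.batch_norm(x)
    print("batch norm:", bn.shape, bn.mean(), bn.std())

    # Group norm: 32 channels folded into 8 groups of 4
    gn = Normalization.group_norm(x, num_groups=8)
    print("group norm:", gn.shape)

    # Layer / RMS norm operate on the trailing feature dimension
    tokens = np.random.randn(4, 16, 64).astype(np.float32)  # (N, T, D)
    ln = Normalization.layer_norm(tokens)
    rms = Normalization.rms_norm(tokens)
    print("layer norm:", ln.shape, "rms norm:", rms.shape)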