from typing import Optional, Dict, Any

import hashlib
import json
import math
import os
import pickle
import warnings
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from functools import lru_cache
from pathlib import Path

import duckdb
import numpy as np
from dotenv import load_dotenv

load_dotenv()

# Resolve the cache database location. Stripping the hf:// prefix leaves a
# relative local file path (e.g. Fred808/helium/storage.json).
DB_URL = os.getenv('HELIUM_DB_URL', 'hf://datasets/Fred808/helium/storage.json')
DB_FILE = Path(DB_URL.replace('hf://datasets/', ''))

HF_TOKEN = os.getenv("HF_TOKEN")


class ActivationType(Enum):
    """Supported activation function types"""
    RELU = "relu"
    GELU = "gelu"
    TANH = "tanh"
    SIGMOID = "sigmoid"
    SWISH = "swish"
    MISH = "mish"
    RELU6 = "relu6"
    ELU = "elu"
    SELU = "selu"
    LEAKY_RELU = "leaky_relu"


@dataclass
class ActivationConfig:
    """Configuration for activation functions"""
    type: ActivationType
    dtype: np.dtype = np.float32
    inplace: bool = False
    alpha: float = 0.01        # negative slope for Leaky ReLU / alpha for ELU
    approximate: bool = True   # use the tanh approximation for GELU
    cache_size: int = 1024

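# Example (illustrative): a float32 GELU configuration with the tanh
# approximation and a smaller cache:
#
#     cfg = ActivationConfig(type=ActivationType.GELU, cache_size=512)

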
class DuckDBCache:
    """Database-backed cache for activation function computations"""

    def __init__(self, size: int = 1024):
        self.size = size
        self._connect_db()
        self._init_tables()

    def _connect_db(self):
        """Connect to DuckDB database"""
        # Ensure the parent directory exists before connecting; DuckDB creates
        # the file but not intermediate directories.
        DB_FILE.parent.mkdir(parents=True, exist_ok=True)
        self.conn = duckdb.connect(str(DB_FILE))

    def _init_tables(self):
        """Initialize database tables"""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS activation_cache (
                key VARCHAR PRIMARY KEY,
                value BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)
        # key is the PRIMARY KEY, which already provides an index for lookups,
        # so no separate index is needed.

    def _compute_key(self, data: np.ndarray, activation_type: str) -> str:
        """Compute cache key based on input data and activation type"""
        hasher = hashlib.sha256()
        # Hash shape and dtype as well as the raw bytes so arrays with the
        # same buffer but different layouts do not collide.
        hasher.update(str(data.shape).encode())
        hasher.update(str(data.dtype).encode())
        hasher.update(data.tobytes())
        hasher.update(activation_type.encode())
        return hasher.hexdigest()

    def clear(self):
        """Clear all cached computations"""
        self.conn.execute("DELETE FROM activation_cache")

    def get(self, data: np.ndarray, activation_type: str) -> Optional[np.ndarray]:
        """Get a cached computation result, or None on a miss"""
        key = self._compute_key(data, activation_type)

        result = self.conn.execute("""
            SELECT value FROM activation_cache
            WHERE key = ?
        """, [key]).fetchone()

        if result is None:
            return None

        # Touch the row so LRU eviction keeps recently used entries.
        self.conn.execute("""
            UPDATE activation_cache
            SET last_accessed = CURRENT_TIMESTAMP
            WHERE key = ?
        """, [key])

        try:
            return pickle.loads(result[0])
        except Exception as e:
            warnings.warn(f"Failed to deserialize cached value: {e}")
            return None

    def set(self, data: np.ndarray, activation_type: str, value: np.ndarray):
        """Cache computation result"""
        key = self._compute_key(data, activation_type)

        metadata = {
            'shape': data.shape,
            'dtype': str(data.dtype),
            'activation_type': activation_type,
            'timestamp': datetime.now().isoformat()
        }

        try:
            value_blob = pickle.dumps(value)
        except Exception as e:
            warnings.warn(f"Failed to serialize value for caching: {e}")
            return

        # Evict least-recently-used rows until there is room for the new entry.
        count = self.conn.execute(
            "SELECT COUNT(*) FROM activation_cache"
        ).fetchone()[0]
        if count >= self.size:
            self.conn.execute("""
                DELETE FROM activation_cache
                WHERE key IN (
                    SELECT key FROM activation_cache
                    ORDER BY last_accessed ASC NULLS FIRST
                    LIMIT ?
                )
            """, [count - self.size + 1])

        # Initialize last_accessed so eviction and cleanup see a real timestamp.
        self.conn.execute("""
            INSERT OR REPLACE INTO activation_cache (key, value, metadata, last_accessed)
            VALUES (?, ?, ?, CURRENT_TIMESTAMP)
        """, [key, value_blob, json.dumps(metadata)])

    def cleanup_old_entries(self, max_age_days: int = 30):
        """Remove entries older than specified days"""
        # DuckDB has no DATEADD; subtract a scaled interval instead.
        self.conn.execute("""
            DELETE FROM activation_cache
            WHERE last_accessed < CURRENT_TIMESTAMP - INTERVAL 1 DAY * ?
        """, [max_age_days])

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics"""
        stats = self.conn.execute("""
            SELECT
                COUNT(*) AS total_entries,
                SUM(OCTET_LENGTH(value)) AS total_size_bytes,
                MIN(created_at) AS oldest_entry,
                MAX(last_accessed) AS last_accessed
            FROM activation_cache
        """).fetchone()

        return {
            'total_entries': stats[0],
            'total_size_mb': stats[1] / (1024 * 1024) if stats[1] else 0,
            'oldest_entry': stats[2],
            'last_accessed': stats[3],
        }

    def __del__(self):
        """Close database connection on cleanup"""
        if hasattr(self, 'conn'):
            self.conn.close()


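# Example (illustrative) of using the cache directly; assumes the DuckDB file
# derived from DB_URL is writable:
#
#     cache = DuckDBCache(size=256)
#     x = np.linspace(-3, 3, 16, dtype=np.float32)
#     if cache.get(x, "relu") is None:
#         cache.set(x, "relu", np.maximum(x, 0))
#     print(cache.get_stats())

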
class Activation:
    """
    Optimized activation function implementation with support for:
    - Hardware acceleration
    - Mixed precision
    - Memory optimization
    - Computation caching
    - Fused operations
    """

    def __init__(
        self,
        config: ActivationConfig,
        driver=None
    ):
        """
        Initialize activation function.

        Args:
            config: Activation configuration
            driver: Optional hardware driver for optimized computation
        """
        self.config = config
        self.driver = driver
        self.cache = DuckDBCache(config.cache_size)
        self._setup_implementation()

    def _setup_implementation(self):
        """Setup the appropriate implementation based on configuration"""
        implementations = {
            ActivationType.RELU: self._relu,
            ActivationType.GELU: self._gelu,
            ActivationType.TANH: self._tanh,
            ActivationType.SIGMOID: self._sigmoid,
            ActivationType.SWISH: self._swish,
            ActivationType.MISH: self._mish,
            ActivationType.RELU6: self._relu6,
            ActivationType.ELU: self._elu,
            ActivationType.SELU: self._selu,
            ActivationType.LEAKY_RELU: self._leaky_relu,
        }
        self._impl = implementations[self.config.type]

    @staticmethod
    @lru_cache(maxsize=128)
    def _calculate_constants(dtype: np.dtype) -> Dict[str, float]:
        """Calculate and cache constants used in activation functions"""
        # Cast every constant to the requested dtype for consistency.
        cast = np.dtype(dtype).type
        return {
            'sqrt_2_pi': cast(np.sqrt(2 / np.pi)),
            'alpha_gelu': cast(0.044715),
            'selu_alpha': cast(1.6732632423543772),
            'selu_scale': cast(1.0507009873554805),
        }

    def _validate_input(self, x: np.ndarray):
        """Validate input tensor"""
        if not isinstance(x, np.ndarray):
            raise TypeError(f"Expected numpy.ndarray, got {type(x)}")

    def _prepare_input(self, x: np.ndarray) -> np.ndarray:
        """Prepare input for computation"""
        if x.dtype != self.config.dtype:
            # astype returns a fresh array, so in-place ops are safe on it.
            x = x.astype(self.config.dtype)
        # When inplace is set, the implementations write into this buffer via
        # their out= arguments; otherwise they allocate a new output array.
        return x

    def _try_driver_implementation(
        self,
        x: np.ndarray,
        func_name: str
    ) -> Optional[np.ndarray]:
        """Try to use driver implementation if available"""
        if self.driver and hasattr(self.driver, func_name):
            return getattr(self.driver, func_name)(x)
        return None

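    # Illustrative driver stub (names assumed, not part of any real driver
    # API): any object exposing methods named after the activations can be
    # passed as `driver` to override the NumPy paths, e.g.
    #
    #     class NumpyDriver:
    #         def relu(self, x):
    #             return np.maximum(x, 0)
    #
    #     act = Activation(ActivationConfig(type=ActivationType.RELU),
    #                      driver=NumpyDriver())
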
    def _relu(self, x: np.ndarray) -> np.ndarray:
        """Optimized ReLU implementation"""
        result = self._try_driver_implementation(x, 'relu')
        if result is not None:
            return result
        return np.maximum(x, 0, out=x if self.config.inplace else None)

    def _gelu(self, x: np.ndarray) -> np.ndarray:
        """Optimized GELU implementation"""
        result = self._try_driver_implementation(x, 'gelu')
        if result is not None:
            return result

        constants = self._calculate_constants(x.dtype)
        if self.config.approximate:
            # tanh approximation:
            # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            inner = x + constants['alpha_gelu'] * np.power(x, 3)
            inner *= constants['sqrt_2_pi']
            return 0.5 * x * (1 + np.tanh(inner))
        # Exact form. math.erf only accepts scalars, so vectorize it for arrays.
        erf = np.vectorize(math.erf)
        return 0.5 * x * (1 + erf(x / np.sqrt(2)))

    def _tanh(self, x: np.ndarray) -> np.ndarray:
        """Optimized tanh implementation"""
        result = self._try_driver_implementation(x, 'tanh')
        if result is not None:
            return result
        return np.tanh(x, out=x if self.config.inplace else None)

    def _sigmoid(self, x: np.ndarray) -> np.ndarray:
        """Optimized sigmoid implementation"""
        result = self._try_driver_implementation(x, 'sigmoid')
        if result is not None:
            return result
        if self.config.inplace:
            # Chain out= ops so the result lands in the caller's buffer.
            np.exp(np.negative(x, out=x), out=x)
            x += 1
            return np.reciprocal(x, out=x)
        return 1 / (1 + np.exp(-x))

    def _swish(self, x: np.ndarray) -> np.ndarray:
        """Optimized Swish implementation (x * sigmoid(x))"""
        result = self._try_driver_implementation(x, 'swish')
        if result is not None:
            return result
        # Compute the sigmoid directly (never in place) so x is not mutated
        # before the product is taken.
        return x * (1.0 / (1.0 + np.exp(-x)))

    def _mish(self, x: np.ndarray) -> np.ndarray:
        """Optimized Mish implementation (x * tanh(softplus(x)))"""
        result = self._try_driver_implementation(x, 'mish')
        if result is not None:
            return result
        # logaddexp(0, x) == log(1 + exp(x)) without overflowing for large x.
        return x * np.tanh(np.logaddexp(0.0, x))

    def _relu6(self, x: np.ndarray) -> np.ndarray:
        """ReLU6 implementation (min(max(0, x), 6))"""
        result = self._try_driver_implementation(x, 'relu6')
        if result is not None:
            return result
        return np.clip(x, 0, 6, out=x if self.config.inplace else None)

    def _elu(self, x: np.ndarray) -> np.ndarray:
        """ELU implementation"""
        result = self._try_driver_implementation(x, 'elu')
        if result is not None:
            return result
        # expm1 computes exp(x) - 1 with better accuracy near zero.
        return np.where(x > 0, x, self.config.alpha * np.expm1(x))

    def _selu(self, x: np.ndarray) -> np.ndarray:
        """SELU implementation"""
        result = self._try_driver_implementation(x, 'selu')
        if result is not None:
            return result
        constants = self._calculate_constants(x.dtype)
        return constants['selu_scale'] * np.where(
            x > 0,
            x,
            constants['selu_alpha'] * np.expm1(x)
        )

    def _leaky_relu(self, x: np.ndarray) -> np.ndarray:
        """Leaky ReLU implementation"""
        result = self._try_driver_implementation(x, 'leaky_relu')
        if result is not None:
            return result
        return np.where(x > 0, x, self.config.alpha * x)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Apply activation function to input tensor.

        Args:
            x: Input tensor

        Returns:
            Output tensor after activation
        """
        self._validate_input(x)
        x = self._prepare_input(x)
        if self.config.inplace:
            # In-place runs mutate x, which would desynchronize the cache keys
            # between get() and set(), so skip caching here.
            return self._impl(x)
        cached = self.cache.get(x, self.config.type.value)
        if cached is not None:
            return cached
        result = self._impl(x)
        self.cache.set(x, self.config.type.value, result)
        return result


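# Example (illustrative): applying a configured activation to a batch.
#
#     act = Activation(ActivationConfig(type=ActivationType.SWISH))
#     y = act(np.random.randn(4, 8).astype(np.float32))

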
def relu(x: np.ndarray, driver=None) -> np.ndarray:
    """Legacy ReLU interface"""
    config = ActivationConfig(type=ActivationType.RELU, dtype=x.dtype)
    return Activation(config, driver)(x)


def gelu(x: np.ndarray, driver=None, approximate: bool = True) -> np.ndarray:
    """Legacy GELU interface"""
    config = ActivationConfig(type=ActivationType.GELU, dtype=x.dtype, approximate=approximate)
    return Activation(config, driver)(x)


def tanh(x: np.ndarray, driver=None) -> np.ndarray:
    """Legacy tanh interface"""
    config = ActivationConfig(type=ActivationType.TANH, dtype=x.dtype)
    return Activation(config, driver)(x)


def sigmoid(x: np.ndarray, driver=None) -> np.ndarray:
    """Legacy sigmoid interface"""
    config = ActivationConfig(type=ActivationType.SIGMOID, dtype=x.dtype)
    return Activation(config, driver)(x)


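# Minimal smoke test (illustrative): exercises each legacy wrapper on a small
# float32 batch. Assumes the DuckDB cache location derived from DB_URL is
# writable.
if __name__ == "__main__":
    x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)
    print("relu:   ", relu(x))
    print("gelu:   ", gelu(x))
    print("tanh:   ", tanh(x))
    print("sigmoid:", sigmoid(x))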