# helium/activations.py
from typing import Optional, Dict, Any
import numpy as np
from enum import Enum
from dataclasses import dataclass
import warnings
from functools import lru_cache
import math
import os
import duckdb
import json
from pathlib import Path
from dotenv import load_dotenv
import hashlib
import pickle
from datetime import datetime
# Load environment variables
load_dotenv()
# Get database URL from environment
DB_URL = os.getenv('HELIUM_DB_URL', 'hf://datasets/Fred808/helium/storage.json')
DB_FILE = Path(DB_URL.replace('hf://datasets/', ''))
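# Example (sketch, hypothetical path): override the storage location via the
# environment before this module is imported. The 'hf://datasets/' prefix is
# stripped, so the remainder is used as a local relative path for DuckDB:
#   os.environ['HELIUM_DB_URL'] = 'hf://datasets/my-org/my-copy/storage.json'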
# Initialize HuggingFace token from environment
HF_TOKEN = os.getenv("HF_TOKEN")
class ActivationType(Enum):
"""Supported activation function types"""
RELU = "relu"
GELU = "gelu"
TANH = "tanh"
SIGMOID = "sigmoid"
SWISH = "swish"
MISH = "mish"
RELU6 = "relu6"
ELU = "elu"
SELU = "selu"
LEAKY_RELU = "leaky_relu"
@dataclass
class ActivationConfig:
"""Configuration for activation functions"""
type: ActivationType
dtype: np.dtype = np.float32
inplace: bool = False
alpha: float = 0.01 # For LeakyReLU, ELU
approximate: bool = True # For GELU
cache_size: int = 1024 # For lookup table optimization
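# Example (sketch): a float32 GELU configuration using the fast tanh
# approximation; any field not listed keeps its default above:
#   config = ActivationConfig(type=ActivationType.GELU, dtype=np.float32,
#                             approximate=True)
#   act = Activation(config)   # see the Activation class below
#   y = act(np.random.randn(4).astype(np.float32))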
class DuckDBCache:
"""Database-backed cache for activation function computations"""
def __init__(self, size: int = 1024):
self.size = size
self._connect_db()
self._init_tables()
    def _connect_db(self):
        """Connect to the DuckDB database, creating parent directories if needed"""
        # The resolved path may be relative (e.g. 'Fred808/helium/storage.json'),
        # so make sure its parent directory exists before connecting
        DB_FILE.parent.mkdir(parents=True, exist_ok=True)
        self.conn = duckdb.connect(str(DB_FILE))
def _init_tables(self):
"""Initialize database tables"""
self.conn.execute("""
CREATE TABLE IF NOT EXISTS activation_cache (
key VARCHAR PRIMARY KEY,
value BLOB,
metadata JSON,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
last_accessed TIMESTAMP
)
""")
# Create index for faster lookups
self.conn.execute("""
CREATE INDEX IF NOT EXISTS idx_activation_cache_key
ON activation_cache(key)
""")
    def _compute_key(self, data: np.ndarray, activation_type: str) -> str:
        """Compute cache key based on input data and activation type"""
        # Hash the raw bytes plus shape and dtype: tobytes() alone would let
        # arrays with identical bytes but different shapes collide
        hasher = hashlib.sha256()
        hasher.update(data.tobytes())
        hasher.update(str(data.shape).encode())
        hasher.update(str(data.dtype).encode())
        hasher.update(activation_type.encode())
        return hasher.hexdigest()
def clear(self):
"""Clear all cached computations"""
self.conn.execute("DELETE FROM activation_cache")
def get(self, data: np.ndarray, activation_type: str) -> Optional[np.ndarray]:
"""Get cached computation result"""
key = self._compute_key(data, activation_type)
result = self.conn.execute("""
SELECT value, metadata FROM activation_cache
WHERE key = ?
""", [key]).fetchone()
        if result:
            value_blob, _metadata = result  # metadata is stored but not needed here
# Update last accessed timestamp
self.conn.execute("""
UPDATE activation_cache
SET last_accessed = CURRENT_TIMESTAMP
WHERE key = ?
""", [key])
# Deserialize the value
try:
value = pickle.loads(value_blob)
return value
except Exception as e:
warnings.warn(f"Failed to deserialize cached value: {e}")
return None
return None
def set(self, data: np.ndarray, activation_type: str, value: np.ndarray):
"""Cache computation result"""
key = self._compute_key(data, activation_type)
# Prepare metadata
metadata = {
'shape': data.shape,
'dtype': str(data.dtype),
'activation_type': activation_type,
'timestamp': datetime.now().isoformat()
}
# Serialize the value
try:
value_blob = pickle.dumps(value)
except Exception as e:
warnings.warn(f"Failed to serialize value for caching: {e}")
return
        # Evict least-recently-used entries so the cache stays within its size
        count = self.conn.execute(
            "SELECT COUNT(*) FROM activation_cache"
        ).fetchone()[0]
        overflow = count - (self.size - 1)
        if overflow > 0:
            self.conn.execute("""
                DELETE FROM activation_cache
                WHERE key IN (
                    SELECT key FROM activation_cache
                    ORDER BY COALESCE(last_accessed, created_at) ASC
                    LIMIT ?
                )
            """, [overflow])
# Insert new value
self.conn.execute("""
INSERT OR REPLACE INTO activation_cache (key, value, metadata)
VALUES (?, ?, ?)
""", [key, value_blob, json.dumps(metadata)])
    def cleanup_old_entries(self, max_age_days: int = 30):
        """Remove entries older than the specified number of days"""
        # DuckDB uses interval arithmetic rather than SQL-Server-style DATEADD;
        # fall back to created_at for entries that were never read
        self.conn.execute("""
            DELETE FROM activation_cache
            WHERE COALESCE(last_accessed, created_at)
                  < CURRENT_TIMESTAMP - to_days(?)
        """, [max_age_days])
def get_stats(self) -> Dict[str, Any]:
"""Get cache statistics"""
stats = self.conn.execute("""
SELECT
COUNT(*) as total_entries,
                SUM(OCTET_LENGTH(value)) as total_size_bytes,
MIN(created_at) as oldest_entry,
MAX(last_accessed) as last_accessed
FROM activation_cache
""").fetchone()
return {
'total_entries': stats[0],
'total_size_mb': stats[1] / (1024 * 1024) if stats[1] else 0,
'oldest_entry': stats[2],
'last_accessed': stats[3]
}
def __del__(self):
"""Close database connection on cleanup"""
if hasattr(self, 'conn'):
self.conn.close()
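# Example (sketch): exercising the cache directly. get() returns None on a
# miss; set() stores the computed result keyed by the input array and the
# activation name:
#   cache = DuckDBCache(size=256)
#   x = np.linspace(-1.0, 1.0, 8, dtype=np.float32)
#   if cache.get(x, 'relu') is None:
#       cache.set(x, 'relu', np.maximum(x, 0))
#   print(cache.get_stats())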
class Activation:
"""
Optimized activation function implementation with support for:
- Hardware acceleration
- Mixed precision
- Memory optimization
- Computation caching
- Fused operations
"""
def __init__(
self,
config: ActivationConfig,
driver = None
):
"""
Initialize activation function.
Args:
config: Activation configuration
driver: Optional hardware driver for optimized computation
"""
self.config = config
self.driver = driver
self.cache = DuckDBCache(config.cache_size)
self._setup_implementation()
def _setup_implementation(self):
"""Setup the appropriate implementation based on configuration"""
implementations = {
ActivationType.RELU: self._relu,
ActivationType.GELU: self._gelu,
ActivationType.TANH: self._tanh,
ActivationType.SIGMOID: self._sigmoid,
ActivationType.SWISH: self._swish,
ActivationType.MISH: self._mish,
ActivationType.RELU6: self._relu6,
ActivationType.ELU: self._elu,
ActivationType.SELU: self._selu,
ActivationType.LEAKY_RELU: self._leaky_relu
}
self._impl = implementations[self.config.type]
    @staticmethod
    @lru_cache(maxsize=128)
    def _calculate_constants(dtype: np.dtype) -> Dict[str, float]:
        """Calculate and cache constants used in activation functions"""
        # Cast every constant to the requested dtype instead of hard-coding
        # float32, so mixed-precision configs stay consistent
        dt = np.dtype(dtype).type
        return {
            'sqrt_2_pi': dt(np.sqrt(2 / np.pi)),
            'alpha_gelu': dt(0.044715),
            'selu_alpha': dt(1.6732632423543772),
            'selu_scale': dt(1.0507009873554805)
        }
def _validate_input(self, x: np.ndarray):
"""Validate input tensor"""
if not isinstance(x, np.ndarray):
raise TypeError(f"Expected numpy.ndarray, got {type(x)}")
    def _prepare_input(self, x: np.ndarray) -> np.ndarray:
        """Prepare input for computation"""
        if x.dtype != self.config.dtype:
            # astype() already returns a fresh copy, so in-place semantics
            # only apply when no dtype conversion is needed
            x = x.astype(self.config.dtype)
        return x
def _try_driver_implementation(
self,
x: np.ndarray,
func_name: str
) -> Optional[np.ndarray]:
"""Try to use driver implementation if available"""
if self.driver and hasattr(self.driver, func_name):
return getattr(self.driver, func_name)(x)
return None
def _relu(self, x: np.ndarray) -> np.ndarray:
"""Optimized ReLU implementation"""
result = self._try_driver_implementation(x, 'relu')
if result is not None:
return result
return np.maximum(x, 0, out=x if self.config.inplace else None)
def _gelu(self, x: np.ndarray) -> np.ndarray:
"""Optimized GELU implementation"""
result = self._try_driver_implementation(x, 'gelu')
if result is not None:
return result
constants = self._calculate_constants(x.dtype)
if self.config.approximate:
# Fast approximation
cdf = x + constants['alpha_gelu'] * np.power(x, 3)
cdf *= constants['sqrt_2_pi']
return 0.5 * x * (1 + np.tanh(cdf))
        else:
            # Exact form; math.erf is scalar-only, so vectorize it for arrays
            erf = np.vectorize(math.erf)
            return 0.5 * x * (1 + erf(x / np.sqrt(2)))
def _tanh(self, x: np.ndarray) -> np.ndarray:
"""Optimized tanh implementation"""
result = self._try_driver_implementation(x, 'tanh')
if result is not None:
return result
return np.tanh(x, out=x if self.config.inplace else None)
    def _sigmoid(self, x: np.ndarray) -> np.ndarray:
        """Optimized sigmoid implementation"""
        result = self._try_driver_implementation(x, 'sigmoid')
        if result is not None:
            return result
        if self.config.inplace:
            # Reuse x's buffer: x <- 1 / (1 + exp(-x))
            np.exp(np.negative(x, out=x), out=x)
            x += 1
            np.reciprocal(x, out=x)
            return x
        return 1 / (1 + np.exp(-x))
    def _swish(self, x: np.ndarray) -> np.ndarray:
        """Optimized Swish implementation (x * sigmoid(x))"""
        result = self._try_driver_implementation(x, 'swish')
        if result is not None:
            return result
        # Compute sigmoid directly: routing through self._sigmoid with
        # inplace=True would overwrite x before the multiplication
        return x * (1 / (1 + np.exp(-x)))
def _mish(self, x: np.ndarray) -> np.ndarray:
"""Optimized Mish implementation (x * tanh(softplus(x)))"""
result = self._try_driver_implementation(x, 'mish')
if result is not None:
return result
        # logaddexp(0, x) is a numerically stable softplus (log(1 + exp(x)))
        return x * np.tanh(np.logaddexp(0, x))
def _relu6(self, x: np.ndarray) -> np.ndarray:
"""ReLU6 implementation (min(max(0, x), 6))"""
result = self._try_driver_implementation(x, 'relu6')
if result is not None:
return result
return np.clip(x, 0, 6, out=x if self.config.inplace else None)
def _elu(self, x: np.ndarray) -> np.ndarray:
"""ELU implementation"""
result = self._try_driver_implementation(x, 'elu')
if result is not None:
return result
        # expm1 is more accurate than exp(x) - 1 for small |x|
        return np.where(x > 0, x, self.config.alpha * np.expm1(x))
def _selu(self, x: np.ndarray) -> np.ndarray:
"""SELU implementation"""
result = self._try_driver_implementation(x, 'selu')
if result is not None:
return result
constants = self._calculate_constants(x.dtype)
        return constants['selu_scale'] * np.where(
            x > 0,
            x,
            constants['selu_alpha'] * np.expm1(x)
        )
def _leaky_relu(self, x: np.ndarray) -> np.ndarray:
"""Leaky ReLU implementation"""
result = self._try_driver_implementation(x, 'leaky_relu')
if result is not None:
return result
return np.where(x > 0, x, self.config.alpha * x)
    def __call__(self, x: np.ndarray) -> np.ndarray:
        """
        Apply activation function to input tensor.
        Args:
            x: Input tensor
        Returns:
            Output tensor after activation
        """
        self._validate_input(x)
        x = self._prepare_input(x)
        # Check the DuckDB-backed cache before recomputing
        cached = self.cache.get(x, self.config.type.value)
        if cached is not None:
            return cached
        result = self._impl(x)
        self.cache.set(x, self.config.type.value, result)
        return result
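# Example (sketch, hypothetical driver): any object exposing a method named
# after the activation (relu, gelu, ...) that takes and returns an ndarray can
# serve as a hardware driver; _try_driver_implementation dispatches by name:
#   class MyDriver:
#       def relu(self, x: np.ndarray) -> np.ndarray:
#           return np.maximum(x, 0)   # stand-in for an accelerated kernel
#   act = Activation(ActivationConfig(type=ActivationType.RELU), driver=MyDriver())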
# Legacy function interfaces for backward compatibility. Note that each call
# constructs a fresh Activation (and with it a DuckDB connection), so hot
# paths should build one Activation instance and reuse it.
def relu(x: np.ndarray, driver=None) -> np.ndarray:
"""Legacy ReLU interface"""
config = ActivationConfig(type=ActivationType.RELU, dtype=x.dtype)
return Activation(config, driver)(x)
def gelu(x: np.ndarray, driver=None, approximate: bool = True) -> np.ndarray:
"""Legacy GELU interface"""
config = ActivationConfig(type=ActivationType.GELU, dtype=x.dtype, approximate=approximate)
return Activation(config, driver)(x)
def tanh(x: np.ndarray, driver=None) -> np.ndarray:
"""Legacy tanh interface"""
config = ActivationConfig(type=ActivationType.TANH, dtype=x.dtype)
return Activation(config, driver)(x)
def sigmoid(x: np.ndarray, driver=None) -> np.ndarray:
"""Legacy sigmoid interface"""
config = ActivationConfig(type=ActivationType.SIGMOID, dtype=x.dtype)
return Activation(config, driver)(x)
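# Minimal self-test sketch (assumes a writable working directory for the
# DuckDB file resolved from HELIUM_DB_URL):
if __name__ == "__main__":
    x = np.linspace(-3.0, 3.0, 7, dtype=np.float32)
    print("relu:   ", relu(x))
    print("gelu:   ", gelu(x))
    print("tanh:   ", tanh(x))
    print("sigmoid:", sigmoid(x))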