| """
|
| Database manager for Helium components using DuckDB
|
| """
|
| from typing import Optional, Dict, Any, Union
|
| import os
|
| import duckdb
|
| import json
|
| import pickle
|
| import numpy as np
|
| from pathlib import Path
|
| from datetime import datetime
|
| import hashlib
|
| from dotenv import load_dotenv
|
| import warnings
|
|
|
|
|
| HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
|
|
|
| load_dotenv()
|
|
|
class HeliumDBManager:
    """Centralized database manager for Helium components"""

    _instance = None

    @classmethod
    def get_instance(cls):
        """Singleton pattern to ensure one database connection"""
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance
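
    # Illustrative usage: every call returns the same manager, so all Helium
    # components share a single DuckDB connection:
    #     db = HeliumDBManager.get_instance()
    #     db is HeliumDBManager.get_instance()  # -> True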

    def __init__(self):
        """Initialize database connection and tables"""
        self.db_url = os.getenv('HELIUM_DB_URL', 'hf://datasets/Fred808/helium/storage.json')
        # Strip the hf://datasets/ prefix so the remainder can be used as a
        # local file path.
        self.db_file = Path(self.db_url.replace('hf://datasets/', ''))
        self._connect_db()
        self._init_tables()

    def _connect_db(self):
        """Connect to DuckDB database"""
        self.conn = duckdb.connect(str(self.db_file))

    def _init_tables(self):
        """Initialize all required tables"""
        # Arrays are pickled into BLOB columns; metadata is stored as JSON.
        # Activation cache table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS activation_cache (
                key VARCHAR PRIMARY KEY,
                value BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)

        # Layer normalization cache table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS layer_norm_cache (
                key VARCHAR PRIMARY KEY,
                mean BLOB,
                var BLOB,
                normalized BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)

        # Encoder state cache table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS encoder_cache (
                key VARCHAR PRIMARY KEY,
                key_states BLOB,
                value_states BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)

        # Decoder state cache table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS decoder_cache (
                key VARCHAR PRIMARY KEY,
                self_attn_states BLOB,
                cross_attn_states BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)

        # Create indices for faster lookups
        for table in ['activation_cache', 'layer_norm_cache', 'encoder_cache', 'decoder_cache']:
            self.conn.execute(f"""
                CREATE INDEX IF NOT EXISTS idx_{table}_key
                ON {table}(key)
            """)

    def _compute_key(self, data: Union[np.ndarray, bytes], component_type: str, extra_data: str = "") -> str:
        """Compute cache key based on input data and component type"""
        hasher = hashlib.sha256()
        if isinstance(data, np.ndarray):
            hasher.update(data.tobytes())
        else:
            hasher.update(data)
        hasher.update(component_type.encode())
        if extra_data:
            hasher.update(extra_data.encode())
        return hasher.hexdigest()
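
    # Illustrative example: identical inputs hash to the same key, so a
    # repeated call with the same tensor hits the same cache row:
    #     key = db._compute_key(np.zeros((2, 2), dtype=np.float32),
    #                           "activation", extra_data="layer_0")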

    def get_activation(self, key: str) -> Optional[np.ndarray]:
        """Get cached activation result"""
        result = self.conn.execute("""
            SELECT value, metadata FROM activation_cache
            WHERE key = ?
        """, [key]).fetchone()

        if result:
            self._update_access_time('activation_cache', key)
            return pickle.loads(result[0])
        return None

    def set_activation(self, key: str, value: np.ndarray, metadata: Dict[str, Any]):
        """Cache activation result"""
        self.conn.execute("""
            INSERT OR REPLACE INTO activation_cache (key, value, metadata)
            VALUES (?, ?, ?)
        """, [key, pickle.dumps(value), json.dumps(metadata)])

    def get_layer_norm(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Get cached layer normalization result"""
        result = self.conn.execute("""
            SELECT mean, var, normalized, metadata
            FROM layer_norm_cache
            WHERE key = ?
        """, [key]).fetchone()

        if result:
            self._update_access_time('layer_norm_cache', key)
            return {
                'mean': pickle.loads(result[0]),
                'var': pickle.loads(result[1]),
                'normalized': pickle.loads(result[2])
            }
        return None

    def set_layer_norm(self, key: str, mean: np.ndarray, var: np.ndarray,
                       normalized: np.ndarray, metadata: Dict[str, Any]):
        """Cache layer normalization result"""
        self.conn.execute("""
            INSERT OR REPLACE INTO layer_norm_cache
            (key, mean, var, normalized, metadata)
            VALUES (?, ?, ?, ?, ?)
        """, [
            key,
            pickle.dumps(mean),
            pickle.dumps(var),
            pickle.dumps(normalized),
            json.dumps(metadata)
        ])

    def get_encoder_state(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Get cached encoder state"""
        result = self.conn.execute("""
            SELECT key_states, value_states, metadata
            FROM encoder_cache
            WHERE key = ?
        """, [key]).fetchone()

        if result:
            self._update_access_time('encoder_cache', key)
            return {
                'key_states': pickle.loads(result[0]),
                'value_states': pickle.loads(result[1])
            }
        return None

    def set_encoder_state(self, key: str, key_states: np.ndarray,
                          value_states: np.ndarray, metadata: Dict[str, Any]):
        """Cache encoder state"""
        self.conn.execute("""
            INSERT OR REPLACE INTO encoder_cache
            (key, key_states, value_states, metadata)
            VALUES (?, ?, ?, ?)
        """, [
            key,
            pickle.dumps(key_states),
            pickle.dumps(value_states),
            json.dumps(metadata)
        ])

    def get_decoder_state(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Get cached decoder state"""
        result = self.conn.execute("""
            SELECT self_attn_states, cross_attn_states, metadata
            FROM decoder_cache
            WHERE key = ?
        """, [key]).fetchone()

        if result:
            self._update_access_time('decoder_cache', key)
            return {
                'self_attn_states': pickle.loads(result[0]),
                'cross_attn_states': pickle.loads(result[1])
            }
        return None

    def set_decoder_state(self, key: str, self_attn_states: np.ndarray,
                          cross_attn_states: np.ndarray, metadata: Dict[str, Any]):
        """Cache decoder state"""
        self.conn.execute("""
            INSERT OR REPLACE INTO decoder_cache
            (key, self_attn_states, cross_attn_states, metadata)
            VALUES (?, ?, ?, ?)
        """, [
            key,
            pickle.dumps(self_attn_states),
            pickle.dumps(cross_attn_states),
            json.dumps(metadata)
        ])

    def _update_access_time(self, table: str, key: str):
        """Update last accessed timestamp"""
        self.conn.execute(f"""
            UPDATE {table}
            SET last_accessed = CURRENT_TIMESTAMP
            WHERE key = ?
        """, [key])

    def cleanup_old_entries(self, max_age_days: int = 30):
        """Remove entries older than specified days from all tables"""
        for table in ['activation_cache', 'layer_norm_cache', 'encoder_cache', 'decoder_cache']:
            # DuckDB has no DATEADD; subtract an interval built with to_days()
            # instead. Entries that were written but never read back have a
            # NULL last_accessed, so fall back to created_at for them.
            self.conn.execute(f"""
                DELETE FROM {table}
                WHERE COALESCE(last_accessed, created_at)
                      < CURRENT_TIMESTAMP - to_days(?)
            """, [max_age_days])

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics for all tables"""
        # Each table stores its payload in different BLOB columns, so the
        # size expression has to be chosen per table.
        size_exprs = {
            'activation_cache': 'OCTET_LENGTH(value)',
            'layer_norm_cache': 'OCTET_LENGTH(mean) + OCTET_LENGTH(var) + OCTET_LENGTH(normalized)',
            'encoder_cache': 'OCTET_LENGTH(key_states) + OCTET_LENGTH(value_states)',
            'decoder_cache': 'OCTET_LENGTH(self_attn_states) + OCTET_LENGTH(cross_attn_states)',
        }
        stats = {}
        for table, size_expr in size_exprs.items():
            table_stats = self.conn.execute(f"""
                SELECT
                    COUNT(*) AS total_entries,
                    SUM({size_expr}) AS total_size_bytes,
                    MIN(created_at) AS oldest_entry,
                    MAX(last_accessed) AS last_accessed
                FROM {table}
            """).fetchone()

            stats[table] = {
                'total_entries': table_stats[0],
                'total_size_mb': table_stats[1] / (1024 * 1024) if table_stats[1] else 0,
                'oldest_entry': table_stats[2],
                'last_accessed': table_stats[3],
            }
        return stats

    def __del__(self):
        """Close database connection on cleanup"""
        if hasattr(self, 'conn'):
            self.conn.close()
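

if __name__ == "__main__":
    # Minimal usage sketch (the array shape and metadata values are
    # illustrative only): round-trip one activation tensor through the
    # cache, then print per-table statistics.
    db = HeliumDBManager.get_instance()

    activations = np.random.rand(4, 8).astype(np.float32)
    key = db._compute_key(activations, "activation", extra_data="layer_0")

    cached = db.get_activation(key)
    if cached is None:
        # Cache miss: store the freshly computed activations with metadata.
        db.set_activation(key, activations, {"layer": 0, "shape": list(activations.shape)})
        cached = db.get_activation(key)

    assert np.array_equal(cached, activations)
    print(db.get_stats())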