# helium/core/db_manager.py
# Provenance: uploaded by Fred808 ("Upload 256 files", commit 7a0c684, verified).
"""
Database manager for Helium components using DuckDB
"""
from typing import Optional, Dict, Any, Union
import os
import duckdb
import json
import pickle
import numpy as np
from pathlib import Path
from datetime import datetime
import hashlib
from dotenv import load_dotenv
import warnings
# Initialize HuggingFace token from environment
HF_TOKEN = os.getenv("HF_TOKEN")
# Load environment variables
load_dotenv()
class HeliumDBManager:
    """Centralized singleton database manager for Helium component caches.

    Persists four kinds of cache entries (activations, layer-norm results,
    encoder states, decoder states) in a single DuckDB file.  Arrays are
    stored as pickled BLOBs, keyed by a SHA-256 content hash.

    NOTE(security): cached values are deserialized with ``pickle.loads`` —
    only point this manager at database files you trust.
    """

    _instance = None  # class-level slot for the singleton instance

    # Blob column(s) of each cache table.  Used both for size accounting in
    # get_stats() (the original hard-coded LENGTH(value), which only exists
    # in activation_cache and raised for every other table) and as the
    # canonical list of tables for maintenance loops.
    _SIZE_COLUMNS = {
        'activation_cache': ['value'],
        'layer_norm_cache': ['mean', 'var', 'normalized'],
        'encoder_cache': ['key_states', 'value_states'],
        'decoder_cache': ['self_attn_states', 'cross_attn_states'],
    }

    @classmethod
    def get_instance(cls):
        """Return the shared manager, creating it on first use.

        Not thread-safe; callers that construct from multiple threads must
        synchronize externally.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        """Open the DuckDB database file and create any missing tables."""
        # BUG FIX: the original default literal was missing its closing
        # quote and parenthesis, which made the whole module a SyntaxError.
        self.db_url = os.getenv(
            'HELIUM_DB_URL', 'hf://datasets/Fred808/helium/storage.json'
        )
        # Strip the hf:// dataset prefix to obtain a local relative path.
        # NOTE(review): DuckDB does not read hf:// URLs here — the path is
        # treated as a plain local file; confirm that matches deployment.
        self.db_file = Path(self.db_url.replace('hf://datasets/', ''))
        self._connect_db()
        self._init_tables()

    def _connect_db(self):
        """Connect to the DuckDB database file (created if absent)."""
        self.conn = duckdb.connect(str(self.db_file))

    def _init_tables(self):
        """Create the four cache tables and their key indices if missing."""
        # Activation cache: one pickled array per key.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS activation_cache (
                key VARCHAR PRIMARY KEY,
                value BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)
        # Layer normalization cache: mean / variance / normalized output.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS layer_norm_cache (
                key VARCHAR PRIMARY KEY,
                mean BLOB,
                var BLOB,
                normalized BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)
        # Encoder attention state cache.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS encoder_cache (
                key VARCHAR PRIMARY KEY,
                key_states BLOB,
                value_states BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)
        # Decoder attention state cache (self- and cross-attention).
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS decoder_cache (
                key VARCHAR PRIMARY KEY,
                self_attn_states BLOB,
                cross_attn_states BLOB,
                metadata JSON,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                last_accessed TIMESTAMP
            )
        """)
        # Explicit key indices.  Redundant with the PRIMARY KEY constraint
        # but harmless; kept for parity with the original schema.
        for table in self._SIZE_COLUMNS:
            self.conn.execute(f"""
                CREATE INDEX IF NOT EXISTS idx_{table}_key
                ON {table}(key)
            """)

    def _compute_key(self, data: Union[np.ndarray, bytes], component_type: str,
                     extra_data: str = "") -> str:
        """Return a SHA-256 hex digest keying *data* for *component_type*.

        Arrays are hashed by their raw bytes; the component type (and any
        extra discriminator string) is folded in so identical payloads used
        by different components get distinct keys.
        """
        hasher = hashlib.sha256()
        if isinstance(data, np.ndarray):
            hasher.update(data.tobytes())
        else:
            hasher.update(data)
        hasher.update(component_type.encode())
        if extra_data:
            hasher.update(extra_data.encode())
        return hasher.hexdigest()

    def get_activation(self, key: str) -> Optional[np.ndarray]:
        """Return the cached activation for *key*, or None on a miss.

        A hit also bumps the row's last_accessed timestamp.  The stored
        metadata is selected but (as in the original) not returned.
        """
        result = self.conn.execute("""
            SELECT value, metadata FROM activation_cache
            WHERE key = ?
        """, [key]).fetchone()
        if result:
            self._update_access_time('activation_cache', key)
            return pickle.loads(result[0])
        return None

    def set_activation(self, key: str, value: np.ndarray, metadata: Dict[str, Any]):
        """Insert or overwrite the activation cached under *key*."""
        self.conn.execute("""
            INSERT OR REPLACE INTO activation_cache (key, value, metadata)
            VALUES (?, ?, ?)
        """, [key, pickle.dumps(value), json.dumps(metadata)])

    def get_layer_norm(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Return cached layer-norm results {'mean','var','normalized'} or None."""
        result = self.conn.execute("""
            SELECT mean, var, normalized, metadata
            FROM layer_norm_cache
            WHERE key = ?
        """, [key]).fetchone()
        if result:
            self._update_access_time('layer_norm_cache', key)
            return {
                'mean': pickle.loads(result[0]),
                'var': pickle.loads(result[1]),
                'normalized': pickle.loads(result[2]),
            }
        return None

    def set_layer_norm(self, key: str, mean: np.ndarray, var: np.ndarray,
                       normalized: np.ndarray, metadata: Dict[str, Any]):
        """Insert or overwrite a layer-norm result under *key*."""
        self.conn.execute("""
            INSERT OR REPLACE INTO layer_norm_cache
            (key, mean, var, normalized, metadata)
            VALUES (?, ?, ?, ?, ?)
        """, [
            key,
            pickle.dumps(mean),
            pickle.dumps(var),
            pickle.dumps(normalized),
            json.dumps(metadata),
        ])

    def get_encoder_state(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Return cached encoder states {'key_states','value_states'} or None."""
        result = self.conn.execute("""
            SELECT key_states, value_states, metadata
            FROM encoder_cache
            WHERE key = ?
        """, [key]).fetchone()
        if result:
            self._update_access_time('encoder_cache', key)
            return {
                'key_states': pickle.loads(result[0]),
                'value_states': pickle.loads(result[1]),
            }
        return None

    def set_encoder_state(self, key: str, key_states: np.ndarray,
                          value_states: np.ndarray, metadata: Dict[str, Any]):
        """Insert or overwrite an encoder state under *key*."""
        self.conn.execute("""
            INSERT OR REPLACE INTO encoder_cache
            (key, key_states, value_states, metadata)
            VALUES (?, ?, ?, ?)
        """, [
            key,
            pickle.dumps(key_states),
            pickle.dumps(value_states),
            json.dumps(metadata),
        ])

    def get_decoder_state(self, key: str) -> Optional[Dict[str, np.ndarray]]:
        """Return cached decoder states {'self_attn_states','cross_attn_states'} or None."""
        result = self.conn.execute("""
            SELECT self_attn_states, cross_attn_states, metadata
            FROM decoder_cache
            WHERE key = ?
        """, [key]).fetchone()
        if result:
            self._update_access_time('decoder_cache', key)
            return {
                'self_attn_states': pickle.loads(result[0]),
                'cross_attn_states': pickle.loads(result[1]),
            }
        return None

    def set_decoder_state(self, key: str, self_attn_states: np.ndarray,
                          cross_attn_states: np.ndarray, metadata: Dict[str, Any]):
        """Insert or overwrite a decoder state under *key*."""
        self.conn.execute("""
            INSERT OR REPLACE INTO decoder_cache
            (key, self_attn_states, cross_attn_states, metadata)
            VALUES (?, ?, ?, ?)
        """, [
            key,
            pickle.dumps(self_attn_states),
            pickle.dumps(cross_attn_states),
            json.dumps(metadata),
        ])

    def _update_access_time(self, table: str, key: str):
        """Stamp last_accessed = now on the given row (internal table names only)."""
        self.conn.execute(f"""
            UPDATE {table}
            SET last_accessed = CURRENT_TIMESTAMP
            WHERE key = ?
        """, [key])

    def cleanup_old_entries(self, max_age_days: int = 30):
        """Delete cache rows not accessed within *max_age_days* days.

        BUG FIXES vs. the original:
        - ``DATEADD(day, ?, ...)`` is SQL-Server syntax that DuckDB rejects;
          replaced with DuckDB's ``to_days()`` interval constructor.
        - Rows that were written but never read have NULL last_accessed and
          could never expire; their created_at is used as a fallback.
        """
        for table in self._SIZE_COLUMNS:
            self.conn.execute(f"""
                DELETE FROM {table}
                WHERE COALESCE(last_accessed, created_at)
                      < CURRENT_TIMESTAMP - to_days(?)
            """, [max_age_days])

    def get_stats(self) -> Dict[str, Any]:
        """Return per-table statistics: entry count, payload size (MB),
        oldest creation time, and most recent access time.

        BUG FIX: the original summed LENGTH(value) for every table, but only
        activation_cache has a ``value`` column — the query raised for the
        other three tables.  Sizes now sum each table's own blob columns.
        """
        stats = {}
        for table, blob_cols in self._SIZE_COLUMNS.items():
            size_expr = " + ".join(f"LENGTH({col})" for col in blob_cols)
            row = self.conn.execute(f"""
                SELECT
                    COUNT(*) as total_entries,
                    SUM({size_expr}) as total_size_bytes,
                    MIN(created_at) as oldest_entry,
                    MAX(last_accessed) as last_accessed
                FROM {table}
            """).fetchone()
            stats[table] = {
                'total_entries': row[0],
                'total_size_mb': row[1] / (1024 * 1024) if row[1] else 0,
                'oldest_entry': row[2],
                'last_accessed': row[3],
            }
        return stats

    def __del__(self):
        """Close the database connection when the manager is garbage-collected."""
        if hasattr(self, 'conn'):
            self.conn.close()