|
|
"""
|
|
|
Database manager for graphics pipeline state persistence
|
|
|
"""
|
|
|
import duckdb
|
|
|
import json
|
|
|
import time
|
|
|
import logging
|
|
|
from typing import Dict, List, Optional
|
|
|
from pathlib import Path
|
|
|
from huggingface_hub import HfApi, HfFileSystem
|
|
|
from config import get_hf_token_cached
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PipelineStateDB:
    """Persist graphics pipeline states in DuckDB.

    Three tables are used: ``pipeline_states`` (one row per pipeline hash, most
    state stored as JSON text), ``resource_bindings`` (per-binding-point data
    keyed by pipeline hash) and ``cache_stats`` (hit counter and last-use time
    per pipeline hash). The connection is opened, with retries, in ``__init__``.
    """

    # NOTE(review): this default points at a ``.json`` file on the Hugging Face
    # Hub, but ``duckdb.connect`` expects a DuckDB database path and is called
    # before httpfs is loaded; remote hf:// locations are also not writable.
    # Confirm the intended storage location before relying on this default.
    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    # Columns of ``pipeline_states`` that hold JSON-encoded values.
    _JSON_COLUMNS = (
        'shader_stages', 'vertex_attributes', 'shader_resources',
        'viewport', 'scissor', 'rasterization', 'depth',
        'stencil', 'blend', 'color_mask',
    )

    def __init__(self, db_path: Optional[str] = None, max_retries: int = 3):
        """Open the database and make sure the schema exists.

        Args:
            db_path: Location passed to ``duckdb.connect``; defaults to ``DB_URL``.
            max_retries: Number of connection attempts; at least one is always made.

        Raises:
            RuntimeError: If every connection attempt fails.
        """
        self.db_path = db_path or self.DB_URL
        self.max_retries = max_retries
        self.conn = None
        self._connect_with_retries()

    @staticmethod
    def _safe_close(conn):
        """Best-effort close of *conn*; errors from an already-broken connection are logged and ignored."""
        if conn is None:
            return
        try:
            conn.close()
        except Exception:
            logging.debug("Ignoring error while closing database connection", exc_info=True)

    def _connect_with_retries(self):
        """Establish the connection and create tables, retrying on failure.

        Waits one second between attempts.

        Raises:
            RuntimeError: If all attempts fail; the last error is chained as the cause.
        """
        attempts = max(1, self.max_retries)
        for attempt in range(1, attempts + 1):
            conn = None
            try:
                conn = self._init_db_connection()
                # _init_tables works on self.conn, so publish the connection first.
                self.conn = conn
                self._init_tables()
                return
            except Exception as e:
                # Don't leak a half-initialized connection into the next attempt.
                self._safe_close(conn)
                self.conn = None
                if attempt == attempts:
                    raise RuntimeError(
                        f"Failed to initialize database after {attempts} attempts: {str(e)}"
                    ) from e
                logging.warning(
                    "Database initialization attempt %d/%d failed: %s", attempt, attempts, e
                )
                time.sleep(1)

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Open a DuckDB connection and configure httpfs for Hugging Face access.

        Credentials are set only when ``get_hf_token_cached()`` returns a truthy
        value. The connection is closed again if any configuration step fails.
        """
        con = duckdb.connect(self.db_path)
        try:
            con.execute("INSTALL httpfs;")
            con.execute("LOAD httpfs;")
            # NOTE(review): these S3 settings target hf.co. Recent DuckDB versions
            # document ``CREATE SECRET (TYPE HUGGINGFACE, TOKEN ...)`` for hf://
            # paths instead; confirm which mechanism the deployed DuckDB supports.
            con.execute("SET s3_endpoint='hf.co';")
            con.execute("SET s3_use_ssl=true;")
            con.execute("SET s3_url_style='path';")

            # presumably returns the token string, or a falsy value when unset — confirm
            token = get_hf_token_cached()
            if token:
                # SET does not take bind parameters, so escape single quotes by hand.
                escaped = str(token).replace("'", "''")
                con.execute(f"SET s3_access_key_id='{escaped}';")
                con.execute(f"SET s3_secret_access_key='{escaped}';")
            else:
                logging.warning("No Hugging Face token available; skipping credential setup")
        except Exception:
            self._safe_close(con)
            raise
        return con

    def _init_tables(self):
        """Create the pipeline, binding and cache-stat tables if they are missing."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS pipeline_states (
                hash VARCHAR PRIMARY KEY,
                shader_stages JSON,
                vertex_attributes JSON,
                shader_resources JSON,
                viewport JSON,
                scissor JSON,
                rasterization JSON,
                depth JSON,
                stencil JSON,
                blend JSON,
                color_mask JSON,
                primitive_type VARCHAR,
                patch_control_points INTEGER,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS resource_bindings (
                pipeline_hash VARCHAR,
                binding_point INTEGER,
                resource_type VARCHAR,
                resource_data JSON,
                FOREIGN KEY (pipeline_hash) REFERENCES pipeline_states(hash),
                PRIMARY KEY (pipeline_hash, binding_point)
            )
        """)

        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS cache_stats (
                pipeline_hash VARCHAR PRIMARY KEY,
                hit_count INTEGER DEFAULT 0,
                last_used TIMESTAMP,
                FOREIGN KEY (pipeline_hash) REFERENCES pipeline_states(hash)
            )
        """)

    def ensure_connection(self):
        """Check the connection with ``SELECT 1``; reconnect if it is missing or broken.

        Raises:
            RuntimeError: If reconnecting fails (see ``_connect_with_retries``).
        """
        conn = getattr(self, "conn", None)
        if conn is not None:
            try:
                conn.execute("SELECT 1")
                return
            except Exception as exc:
                logging.debug("Database health check failed: %s", exc)
        logging.warning("Database connection lost, attempting to reconnect...")
        self._safe_close(conn)
        self._connect_with_retries()

    def close(self):
        """Close the underlying connection, if one is open."""
        self._safe_close(getattr(self, "conn", None))
        self.conn = None

    def store_pipeline(self, hash: str, state_dict: Dict):
        """Insert or update (upsert) the pipeline state stored under *hash*.

        Args:
            hash: Primary key of the pipeline state.
            state_dict: Must contain ``shaders`` (stored in ``shader_stages``),
                ``vertex_attributes``, ``shader_resources``, ``viewport``,
                ``scissor``, ``rasterization``, ``depth``, ``stencil``, ``blend``
                and ``color_mask`` (all JSON-encoded before storage), plus
                ``primitive_type`` and ``patch_control_points`` (stored as-is).

        Raises:
            KeyError: If a required key is missing from *state_dict*.
            TypeError: If a value cannot be JSON-encoded.
        """
        self.ensure_connection()

        # Order must match the column list in the INSERT below.
        values = [
            hash,
            json.dumps(state_dict['shaders']),
            json.dumps(state_dict['vertex_attributes']),
            json.dumps(state_dict['shader_resources']),
            json.dumps(state_dict['viewport']),
            json.dumps(state_dict['scissor']),
            json.dumps(state_dict['rasterization']),
            json.dumps(state_dict['depth']),
            json.dumps(state_dict['stencil']),
            json.dumps(state_dict['blend']),
            json.dumps(state_dict['color_mask']),
            state_dict['primitive_type'],
            state_dict['patch_control_points'],
        ]

        self.conn.execute("""
            INSERT INTO pipeline_states
            (hash, shader_stages, vertex_attributes, shader_resources,
             viewport, scissor, rasterization, depth, stencil, blend,
             color_mask, primitive_type, patch_control_points)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(hash) DO UPDATE SET
                shader_stages=excluded.shader_stages,
                vertex_attributes=excluded.vertex_attributes,
                shader_resources=excluded.shader_resources,
                viewport=excluded.viewport,
                scissor=excluded.scissor,
                rasterization=excluded.rasterization,
                depth=excluded.depth,
                stencil=excluded.stencil,
                blend=excluded.blend,
                color_mask=excluded.color_mask,
                primitive_type=excluded.primitive_type,
                patch_control_points=excluded.patch_control_points
        """, values)

    def get_pipeline(self, hash: str) -> Optional[Dict]:
        """Fetch the pipeline state stored under *hash* and record a cache hit.

        Returns:
            A dict keyed by ``pipeline_states`` column name, with the JSON
            columns decoded, or ``None`` if no row exists for *hash*.
        """
        self.ensure_connection()

        cursor = self.conn.execute("""
            SELECT * FROM pipeline_states WHERE hash = ?
        """, [hash])
        row = cursor.fetchone()
        if row is None:
            return None

        # Rows are plain tuples; read the column names now, before the next
        # execute() replaces the cursor description.
        columns = [col[0] for col in cursor.description]
        state = dict(zip(columns, row))

        self.conn.execute("""
            INSERT INTO cache_stats (pipeline_hash, hit_count, last_used)
            VALUES (?, 1, CURRENT_TIMESTAMP)
            ON CONFLICT(pipeline_hash) DO UPDATE SET
                hit_count = cache_stats.hit_count + 1,
                last_used = CURRENT_TIMESTAMP
        """, [hash])

        for column in self._JSON_COLUMNS:
            value = state.get(column)
            # Only decode non-empty text; NULLs are left as None.
            if isinstance(value, str) and value:
                state[column] = json.loads(value)

        return state

    def prune_cache(self, max_size: int = 1000):
        """Delete least-recently-used pipeline states so at most *max_size* remain.

        Recency is ``cache_stats.last_used``; pipelines that were never fetched
        fall back to their ``created_at`` time. Dependent ``cache_stats`` and
        ``resource_bindings`` rows are deleted first so the foreign keys hold.

        Raises:
            ValueError: If *max_size* is negative.
        """
        if max_size < 0:
            raise ValueError(f"max_size must be non-negative, got {max_size}")
        self.ensure_connection()

        total = self.conn.execute("SELECT COUNT(*) FROM pipeline_states").fetchone()[0]
        excess = total - max_size
        if excess <= 0:
            return

        # excess is a Python int computed above, so formatting it in is safe.
        victims = [row[0] for row in self.conn.execute(f"""
            SELECT ps.hash
            FROM pipeline_states ps
            LEFT JOIN cache_stats cs ON ps.hash = cs.pipeline_hash
            ORDER BY COALESCE(cs.last_used, ps.created_at) ASC
            LIMIT {int(excess)}
        """).fetchall()]
        if not victims:
            return

        placeholders = ", ".join("?" for _ in victims)
        # Children before parent so no foreign key is left dangling.
        for table, column in (
            ("cache_stats", "pipeline_hash"),
            ("resource_bindings", "pipeline_hash"),
            ("pipeline_states", "hash"),
        ):
            self.conn.execute(
                f"DELETE FROM {table} WHERE {column} IN ({placeholders})", victims
            )

    def get_cache_stats(self) -> List[tuple]:
        """Return ``(hash, hit_count, last_used, created_at)`` rows ordered by hit count, descending.

        Pipelines that were never fetched have ``None`` for ``hit_count`` and
        ``last_used``; where those rows sort follows DuckDB's default null ordering.
        """
        self.ensure_connection()
        return self.conn.execute("""
            SELECT ps.hash, cs.hit_count, cs.last_used,
                   ps.created_at
            FROM pipeline_states ps
            LEFT JOIN cache_stats cs ON ps.hash = cs.pipeline_hash
            ORDER BY cs.hit_count DESC
        """).fetchall()
|
|
|
|