"""
Database manager for graphics pipeline state persistence
"""

import duckdb
import json
import time
import logging
from typing import Dict, List, Optional
from pathlib import Path
from huggingface_hub import HfApi, HfFileSystem
from config import get_hf_token_cached  # Initialize token from .env

logger = logging.getLogger(__name__)


class PipelineStateDB:
    """DuckDB-backed store for pipeline states, resource bindings and cache stats.

    On connect, three tables are created if missing (see ``_init_tables``):

    * ``pipeline_states``: one row per pipeline hash; complex fields are
      stored as JSON text.
    * ``resource_bindings``: keyed by (pipeline_hash, binding_point), with a
      foreign key to ``pipeline_states.hash``.
    * ``cache_stats``: per-pipeline hit counter and last-use timestamp, with
      a foreign key to ``pipeline_states.hash``.
    """

    # NOTE(review): ``duckdb.connect`` opens this path as a DuckDB database
    # file. A ``.json`` file on the Hugging Face Hub is not a DuckDB database,
    # and remote databases can at best be attached read-only, so the
    # CREATE TABLE statements in ``_init_tables`` will likely fail with this
    # default. Confirm the intended storage backend (e.g. a local file that
    # is synced to the Hub separately).
    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    # Access token for remote storage. ``None`` means "resolve via
    # get_hf_token_cached() in __init__"; subclasses may override it.
    HF_TOKEN: Optional[str] = None

    # Columns that store_pipeline JSON-encodes and get_pipeline decodes.
    _JSON_COLUMNS = (
        'shader_stages', 'vertex_attributes', 'shader_resources', 'viewport',
        'scissor', 'rasterization', 'depth', 'stencil', 'blend', 'color_mask',
    )

    def __init__(self, db_path: Optional[str] = None, max_retries: int = 3,
                 hf_token: Optional[str] = None):
        """Open (or create) the database and ensure the schema exists.

        Args:
            db_path: DuckDB database path; defaults to ``DB_URL``.
            max_retries: number of connection attempts (at least one is made).
            hf_token: access token; when omitted, falls back to the class-level
                ``HF_TOKEN`` and then to ``get_hf_token_cached()``.

        Raises:
            RuntimeError: if the connection cannot be established.
        """
        self.db_path = db_path or self.DB_URL
        self.max_retries = max_retries
        # _init_db_connection reads self.HF_TOKEN, so it must be set before
        # connecting (previously it was never assigned at all).
        if hf_token is None:
            # assumes get_hf_token_cached() takes no arguments and returns the
            # token string or None -- TODO confirm against config
            hf_token = self.HF_TOKEN or get_hf_token_cached()
        self.HF_TOKEN = hf_token
        self._connect_with_retries()

    @staticmethod
    def _close_quietly(conn: duckdb.DuckDBPyConnection) -> None:
        """Close *conn*, ignoring DuckDB errors (best-effort cleanup)."""
        try:
            conn.close()
        except duckdb.Error:
            pass

    def _connect_with_retries(self):
        """Establish the database connection and schema, retrying on failure.

        Makes ``max(1, max_retries)`` attempts, sleeping one second between
        them. A connection opened by a failed attempt is closed before the
        next attempt.

        Raises:
            RuntimeError: if every attempt fails; the last error is chained
                as ``__cause__``.
        """
        attempts = max(1, self.max_retries)
        last_error: Optional[Exception] = None
        for attempt in range(1, attempts + 1):
            conn = None
            try:
                conn = self._init_db_connection()
                self.conn = conn
                self._init_tables()
                return
            except Exception as e:
                last_error = e
                logger.warning("Database init attempt %d/%d failed: %s",
                               attempt, attempts, e)
                if conn is not None:
                    self._close_quietly(conn)
                if attempt < attempts:
                    time.sleep(1)
        raise RuntimeError(
            f"Failed to initialize database after {attempts} attempts: {str(last_error)}"
        ) from last_error

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Open a DuckDB connection and configure httpfs access.

        The connection is closed again if any configuration statement fails.
        """
        con = duckdb.connect(self.db_path)
        try:
            # Configure HuggingFace access
            con.execute("INSTALL httpfs;")
            con.execute("LOAD httpfs;")
            # NOTE(review): DuckDB's hf:// reader authenticates through a
            # Hugging Face secret (CREATE SECRET ... TYPE HUGGINGFACE) rather
            # than these s3_* settings -- confirm they are still needed.
            con.execute("SET s3_endpoint='hf.co';")
            con.execute("SET s3_use_ssl=true;")
            con.execute("SET s3_url_style='path';")
            if self.HF_TOKEN:
                # The token is interpolated into a SQL string literal, so
                # double any single quotes to keep the statement well-formed.
                token = str(self.HF_TOKEN).replace("'", "''")
                con.execute(f"SET s3_access_key_id='{token}';")
                con.execute(f"SET s3_secret_access_key='{token}';")
        except Exception:
            self._close_quietly(con)
            raise
        return con

    def _init_tables(self):
        """Create the pipeline_states, resource_bindings and cache_stats tables if missing."""
        # Pipeline state table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS pipeline_states (
                hash VARCHAR PRIMARY KEY,
                shader_stages JSON,
                vertex_attributes JSON,
                shader_resources JSON,
                viewport JSON,
                scissor JSON,
                rasterization JSON,
                depth JSON,
                stencil JSON,
                blend JSON,
                color_mask JSON,
                primitive_type VARCHAR,
                patch_control_points INTEGER,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Resource bindings table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS resource_bindings (
                pipeline_hash VARCHAR,
                binding_point INTEGER,
                resource_type VARCHAR,
                resource_data JSON,
                FOREIGN KEY (pipeline_hash) REFERENCES pipeline_states(hash),
                PRIMARY KEY (pipeline_hash, binding_point)
            )
        """)

        # Cache statistics
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS cache_stats (
                pipeline_hash VARCHAR PRIMARY KEY,
                hit_count INTEGER DEFAULT 0,
                last_used TIMESTAMP,
                FOREIGN KEY (pipeline_hash) REFERENCES pipeline_states(hash)
            )
        """)

    def ensure_connection(self):
        """Ensure the database connection is usable, reconnecting if it is not.

        Raises:
            RuntimeError: if reconnection fails (from ``_connect_with_retries``).
        """
        try:
            self.conn.execute("SELECT 1")
        except (duckdb.Error, AttributeError):
            # duckdb.Error covers closed/broken connections; AttributeError
            # covers the case where no connection was ever stored.
            logger.warning("Database connection lost, attempting to reconnect...")
            self._connect_with_retries()

    def store_pipeline(self, hash: str, state_dict: Dict):
        """Insert or update (upsert) a pipeline state keyed by *hash*.

        Args:
            hash: pipeline identifier (primary key).
            state_dict: must contain the keys 'shaders', 'vertex_attributes',
                'shader_resources', 'viewport', 'scissor', 'rasterization',
                'depth', 'stencil', 'blend', 'color_mask' (each JSON-encoded;
                'shaders' is stored in the shader_stages column),
                'primitive_type' and 'patch_control_points'.

        Raises:
            KeyError: if a required key is missing from *state_dict*.
        """
        self.ensure_connection()

        # Convert complex objects to JSON; key order matches the column list below.
        state = {
            'hash': hash,
            'shader_stages': json.dumps(state_dict['shaders']),
            'vertex_attributes': json.dumps(state_dict['vertex_attributes']),
            'shader_resources': json.dumps(state_dict['shader_resources']),
            'viewport': json.dumps(state_dict['viewport']),
            'scissor': json.dumps(state_dict['scissor']),
            'rasterization': json.dumps(state_dict['rasterization']),
            'depth': json.dumps(state_dict['depth']),
            'stencil': json.dumps(state_dict['stencil']),
            'blend': json.dumps(state_dict['blend']),
            'color_mask': json.dumps(state_dict['color_mask']),
            'primitive_type': state_dict['primitive_type'],
            'patch_control_points': state_dict['patch_control_points'],
        }

        # Insert/update pipeline state
        self.conn.execute("""
            INSERT INTO pipeline_states
            (hash, shader_stages, vertex_attributes, shader_resources,
             viewport, scissor, rasterization, depth, stencil, blend,
             color_mask, primitive_type, patch_control_points)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            ON CONFLICT(hash) DO UPDATE SET
                shader_stages=excluded.shader_stages,
                vertex_attributes=excluded.vertex_attributes,
                shader_resources=excluded.shader_resources,
                viewport=excluded.viewport,
                scissor=excluded.scissor,
                rasterization=excluded.rasterization,
                depth=excluded.depth,
                stencil=excluded.stencil,
                blend=excluded.blend,
                color_mask=excluded.color_mask,
                primitive_type=excluded.primitive_type,
                patch_control_points=excluded.patch_control_points
        """, list(state.values()))

    def get_pipeline(self, hash: str) -> Optional[Dict]:
        """Fetch a pipeline state by *hash* and record a cache hit.

        Returns:
            A dict keyed by pipeline_states column name, with the JSON
            columns decoded back into Python objects, or None if no row
            matches.
        """
        self.ensure_connection()

        cursor = self.conn.execute("""
            SELECT * FROM pipeline_states WHERE hash = ?
        """, [hash])
        row = cursor.fetchone()
        if row is None:
            return None
        # fetchone() returns a plain tuple, so column names come from the
        # cursor description. Capture them now: the cache_stats statement
        # below runs on the same connection and replaces the description.
        columns = [col[0] for col in cursor.description]

        # Update cache statistics
        self.conn.execute("""
            INSERT INTO cache_stats (pipeline_hash, hit_count, last_used)
            VALUES (?, 1, CURRENT_TIMESTAMP)
            ON CONFLICT(pipeline_hash) DO UPDATE SET
                hit_count = cache_stats.hit_count + 1,
                last_used = CURRENT_TIMESTAMP
        """, [hash])

        # Convert JSON back to Python objects
        state = dict(zip(columns, row))
        for column in self._JSON_COLUMNS:
            value = state.get(column)
            if isinstance(value, str) and value:
                state[column] = json.loads(value)
        return state

    def prune_cache(self, max_size: int = 1000):
        """Evict least recently used pipeline states until at most *max_size* remain.

        Recency is ``cache_stats.last_used``, falling back to ``created_at``
        for pipelines that have never been fetched. Dependent rows in
        cache_stats and resource_bindings are deleted first, because both
        tables hold foreign keys to pipeline_states.
        """
        self.ensure_connection()

        (total,) = self.conn.execute(
            "SELECT COUNT(*) FROM pipeline_states").fetchone()
        excess = total - max(max_size, 0)
        if excess <= 0:
            # Under capacity: nothing to evict (previously this produced a
            # negative LIMIT).
            return

        victims = self.conn.execute("""
            SELECT ps.hash
            FROM pipeline_states ps
            LEFT JOIN cache_stats cs ON ps.hash = cs.pipeline_hash
            ORDER BY COALESCE(cs.last_used, ps.created_at) ASC, ps.hash ASC
        """).fetchmany(excess)
        params = [[victim[0]] for victim in victims]
        if not params:
            return

        # Children before parent so the foreign-key constraints hold.
        self.conn.executemany(
            "DELETE FROM cache_stats WHERE pipeline_hash = ?", params)
        self.conn.executemany(
            "DELETE FROM resource_bindings WHERE pipeline_hash = ?", params)
        self.conn.executemany(
            "DELETE FROM pipeline_states WHERE hash = ?", params)

    def get_cache_stats(self) -> List[tuple]:
        """Return cache usage statistics for every stored pipeline.

        Returns:
            A list of (hash, hit_count, last_used, created_at) tuples ordered
            by hit_count descending. hit_count and last_used are None for
            pipelines with no cache_stats row (never fetched via get_pipeline).
        """
        self.ensure_connection()
        return self.conn.execute("""
            SELECT
                ps.hash,
                cs.hit_count,
                cs.last_used,
                ps.created_at
            FROM pipeline_states ps
            LEFT JOIN cache_stats cs ON ps.hash = cs.pipeline_hash
            ORDER BY cs.hit_count DESC
        """).fetchall()