|
|
"""
|
|
|
Enhanced Shader System for Virtual GPU
|
|
|
Integrates with DuckDB for shader management and state tracking
|
|
|
"""
|
|
|
|
|
|
import duckdb
|
|
|
import json
|
|
|
import time
|
|
|
from typing import Dict, List, Optional, Union
|
|
|
from enum import Enum
|
|
|
import logging
|
|
|
import hashlib
|
|
|
from config import get_db_url, get_hf_token_cached
|
|
|
from huggingface_hub import HfApi, HfFileSystem
|
|
|
|
|
|
class ShaderType(Enum):
    """Pipeline stage that a shader targets.

    Each member's value is the lowercase stage name; that string is what
    gets persisted as a shader's ``type`` and what ``ShaderType(...)``
    accepts when converting back from a string.
    """

    # Per-vertex transformation stage.
    VERTEX = "vertex"
    # Per-fragment (pixel) shading stage.
    FRAGMENT = "fragment"
    # General-purpose compute stage.
    COMPUTE = "compute"
    # Primitive-level geometry stage.
    GEOMETRY = "geometry"
|
|
|
|
|
|
class ShaderProgram:
    """In-memory shader program: a set of attached shader stages together
    with uniform/attribute values and the outcome of the last link."""

    def __init__(self):
        # Attached shader objects, kept in attach order without duplicates.
        self.shaders = []
        # uniform name -> value, as given to set_uniform().
        self.uniforms = {}
        # attribute name -> value, as given to set_attribute().
        self.attributes = {}
        # True only after a successful link() with no later attach/detach.
        self.is_linked = False
        # Message from the most recent failed link(), otherwise None.
        self.link_error = None

    def attach_shader(self, shader):
        """Attach *shader*; a no-op if it is already attached.

        Attaching a new shader invalidates any previous link.
        """
        if shader in self.shaders:
            return
        self.shaders.append(shader)
        self.is_linked = False

    def detach_shader(self, shader):
        """Detach *shader*; a no-op if it is not attached.

        Removing a shader invalidates any previous link.
        """
        if shader not in self.shaders:
            return
        self.shaders.remove(shader)
        self.is_linked = False

    def link(self):
        """Validate the attached stages and mark the program as linked.

        A program needs at least one vertex and one fragment shader, detected
        through each attached shader's ``type`` attribute.

        Returns:
            True on success. On any failure returns False and stores the
            error text in ``link_error`` instead of raising.
        """
        try:
            present = set()
            for attached in self.shaders:
                present.add(attached.type)
            for required, label in (
                (ShaderType.VERTEX, "vertex"),
                (ShaderType.FRAGMENT, "fragment"),
            ):
                if required not in present:
                    raise ValueError(f"Shader program must have a {label} shader")
        except Exception as exc:
            # Link failures are reported via the return value, not raised.
            self.is_linked = False
            self.link_error = str(exc)
            return False
        self.is_linked = True
        self.link_error = None
        return True

    def use(self):
        """Activate this program.

        Raises:
            RuntimeError: If the program has not been linked successfully.
        """
        if self.is_linked:
            return
        raise RuntimeError("Cannot use unlinked shader program")

    def set_uniform(self, name: str, value: Union[float, int, list]):
        """Record *value* for the uniform called *name* (overwrites)."""
        self.uniforms.update({name: value})

    def set_attribute(self, name: str, value: Union[float, int, list]):
        """Record *value* for the attribute called *name* (overwrites)."""
        self.attributes.update({name: value})
|
|
|
|
|
|
class ShaderError(Exception):
    """Raised by ShaderSystem when a shader or shader-program operation fails."""
|
|
|
|
|
|
class ShaderSystem:
    """DuckDB-backed registry for shaders, shader programs, uniforms and
    per-execution-unit statistics.

    State lives in four tables created by ``_setup_database``: ``shaders``,
    ``shader_programs``, ``uniform_values`` and ``execution_state``. A
    program's link status is a JSON document in
    ``shader_programs.state_json``. It is read and rewritten in Python (see
    ``_update_program_state``), so ``linked`` is always stored as a real
    JSON boolean.
    """

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    # Optional explicit credential for the s3/httpfs connection. When left as
    # None, get_hf_token_cached() from config is used instead.
    HF_TOKEN: Optional[str] = None

    def __init__(self, hal, db_url: Optional[str] = None):
        """Store the HAL handle and open the backing database.

        Args:
            hal: Hardware abstraction layer object. Stored as ``self.hal``
                and not otherwise used by this class.
            db_url: ``hf://`` URL of the backing store. Defaults to
                ``DB_URL``.

        Raises:
            RuntimeError: If connecting or creating the schema fails on every
                one of ``max_retries`` attempts.
        """
        self.hal = hal
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        self._connect_with_retries()

    # ------------------------------------------------------------------
    # Connection / schema
    # ------------------------------------------------------------------

    def _connect_with_retries(self):
        """Open the connection and create the schema, retrying on failure.

        Retries up to ``self.max_retries`` times, sleeping 1 s between
        attempts. A connection from a failed attempt is closed before the
        next one so retries do not leak handles.

        Raises:
            RuntimeError: After the last failed attempt, chained to the last
                underlying error.
        """
        last_error: Optional[Exception] = None
        for attempt in range(self.max_retries):
            conn = None
            try:
                conn = self._init_db_connection()
                self.conn = conn
                self._setup_database()
                return
            except Exception as e:
                last_error = e
                if conn is not None:
                    try:
                        conn.close()
                    except Exception:
                        logging.getLogger(__name__).debug(
                            "Ignoring error while closing failed connection", exc_info=True
                        )
                # No point sleeping after the final attempt.
                if attempt < self.max_retries - 1:
                    time.sleep(1)
        raise RuntimeError(
            f"Failed to initialize database after {self.max_retries} attempts: {str(last_error)}"
        ) from last_error

    def _init_db_connection(self) -> duckdb.DuckDBPyConnection:
        """Open a DuckDB connection configured for S3-style access.

        Returns:
            A connection with ``httpfs`` loaded and S3 credentials set from
            the HuggingFace token.

        Raises:
            RuntimeError: If no HuggingFace token is available.
        """
        # NOTE(review): split('/', 4) on "hf://datasets/<owner>/<repo>/<file>"
        # yields ["hf:", "", "datasets", "<owner>", "<repo>/<file>"], so the
        # names below are shifted relative to their meaning. The resulting
        # path is kept unchanged; confirm the intended layout.
        _, _, owner, dataset, db_file = self.db_url.split('/', 4)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        # The original referenced an undefined self.HF_TOKEN.
        # NOTE(review): assumes get_hf_token_cached() takes no arguments and
        # returns the token string. Confirm against config.
        token = self.HF_TOKEN or get_hf_token_cached()
        if not token:
            raise RuntimeError("No HuggingFace token available for database access")
        # SET does not take bind parameters, so escape quotes for the SQL literal.
        escaped_token = str(token).replace("'", "''")

        # NOTE(review): the s3:// database path is opened before httpfs is
        # loaded. Confirm DuckDB can open this path at connect time.
        conn = duckdb.connect(db_path)
        try:
            conn.execute("INSTALL httpfs;")
            conn.execute("LOAD httpfs;")
            conn.execute("SET s3_endpoint='s3.us-east-1.amazonaws.com';")
            conn.execute("SET s3_use_ssl=true;")
            conn.execute("SET s3_url_style='path';")
            conn.execute(f"SET s3_access_key_id='{escaped_token}';")
            conn.execute(f"SET s3_secret_access_key='{escaped_token}';")
        except Exception:
            conn.close()
            raise
        return conn

    def _setup_database(self):
        """Create the shader-management tables if they do not already exist."""
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS shader_programs (
                program_id VARCHAR PRIMARY KEY,
                name VARCHAR,
                vertex_shader_id VARCHAR,
                fragment_shader_id VARCHAR,
                geometry_shader_id VARCHAR,
                compute_shader_id VARCHAR,
                uniforms JSON,
                state_json JSON
            )
        """)

        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS shaders (
                shader_id VARCHAR PRIMARY KEY,
                type VARCHAR,
                source TEXT,
                compiled_code JSON,
                metadata JSON
            )
        """)

        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS uniform_values (
                program_id VARCHAR,
                uniform_name VARCHAR,
                value_type VARCHAR,
                value JSON,
                PRIMARY KEY (program_id, uniform_name)
            )
        """)

        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS execution_state (
                program_id VARCHAR,
                chip_id INTEGER,
                sm_id INTEGER,
                unit_id INTEGER,
                execution_count INTEGER DEFAULT 0,
                last_execution TIMESTAMP,
                performance_stats JSON,
                PRIMARY KEY (program_id, chip_id, sm_id, unit_id)
            )
        """)

        self.conn.commit()

    # ------------------------------------------------------------------
    # Program-state helpers
    # ------------------------------------------------------------------

    def _program_exists(self, program_id: str) -> bool:
        """Return True if a row for *program_id* exists in shader_programs."""
        row = self.conn.execute(
            "SELECT 1 FROM shader_programs WHERE program_id = ?", [program_id]
        ).fetchone()
        return row is not None

    def _load_program_state(self, program_id: str) -> Dict:
        """Return the decoded ``state_json`` of a program ({} if NULL).

        Raises:
            ShaderError: If the program does not exist.
        """
        row = self.conn.execute(
            "SELECT state_json FROM shader_programs WHERE program_id = ?", [program_id]
        ).fetchone()
        if not row:
            raise ShaderError(f"Program {program_id} not found")
        return json.loads(row[0]) if row[0] else {}

    def _update_program_state(self, program_id: str, updates: Dict, remove=()) -> None:
        """Merge *updates* into a program's state_json and drop keys in *remove*.

        The change is written but not committed; callers commit.

        Raises:
            ShaderError: If the program does not exist.
        """
        state = self._load_program_state(program_id)
        state.update(updates)
        for key in remove:
            state.pop(key, None)
        self.conn.execute(
            "UPDATE shader_programs SET state_json = ? WHERE program_id = ?",
            [json.dumps(state), program_id],
        )

    @staticmethod
    def _is_linked(state: Dict) -> bool:
        """Interpret the ``linked`` flag of a program state document.

        Legacy string values ("true"/"false") are accepted, so that "false"
        is not mistaken for a truthy value.
        """
        linked = state.get("linked")
        if isinstance(linked, str):
            return linked.strip().lower() == "true"
        return bool(linked)

    def _record_execution(self, program_id: str, stats: Dict,
                          chip_id: int = 0, sm_id: int = 0, unit_id: int = 0) -> None:
        """Upsert one execution record for a program on a given unit.

        Inserts a row with ``execution_count = 1``, or increments the count of
        an existing row. Either way it stamps ``last_execution`` and replaces
        ``performance_stats`` with *stats*. Does not commit.
        """
        self.conn.execute("""
            INSERT INTO execution_state (
                program_id, chip_id, sm_id, unit_id,
                execution_count, last_execution, performance_stats
            ) VALUES (?, ?, ?, ?, 1, NOW(), ?)
            ON CONFLICT (program_id, chip_id, sm_id, unit_id) DO UPDATE SET
                execution_count = execution_count + 1,
                last_execution = NOW(),
                performance_stats = excluded.performance_stats
        """, [program_id, chip_id, sm_id, unit_id, json.dumps(stats)])

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_shader(self, source: str, shader_type: Union[str, ShaderType]) -> str:
        """Compile *source* and store it as a shader.

        Shader ids are derived from the source text only, so submitting the
        same source again with the same type is idempotent and returns the
        existing id.

        Args:
            source: Shader source code.
            shader_type: A ShaderType or its string value.

        Returns:
            The shader id (``"shader_" + first 16 hex chars of md5(source)``).

        Raises:
            ValueError: If *shader_type* is not a valid ShaderType string.
            ShaderError: If identical source is already stored under a
                different type, or compilation/storage fails.
        """
        if isinstance(shader_type, str):
            shader_type = ShaderType(shader_type)

        shader_id = f"shader_{hashlib.md5(source.encode()).hexdigest()[:16]}"

        existing = self.conn.execute(
            "SELECT type FROM shaders WHERE shader_id = ?", [shader_id]
        ).fetchone()
        if existing:
            if existing[0] == shader_type.value:
                return shader_id
            raise ShaderError(
                f"Shader {shader_id} already exists with type {existing[0]!r}"
            )

        try:
            compiled_code = self._compile_shader(source, shader_type)

            self.conn.execute("""
                INSERT INTO shaders (shader_id, type, source, compiled_code, metadata)
                VALUES (?, ?, ?, ?, ?)
            """, [
                shader_id,
                shader_type.value,
                source,
                json.dumps(compiled_code),
                json.dumps({
                    # Real UTC timestamp; previously the literal string "NOW".
                    "created_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                    "version": "1.0",
                    "compiler_optimizations": True
                })
            ])

            self.conn.commit()
            return shader_id

        except Exception as e:
            raise ShaderError(f"Failed to compile shader: {str(e)}") from e

    def create_program(self, name: str) -> str:
        """Create a new, unlinked shader program.

        Returns:
            The program id (``"program_" + first 16 hex chars of md5(name)``).

        Raises:
            ShaderError: If a program with the same id (same name) exists.
        """
        program_id = f"program_{hashlib.md5(name.encode()).hexdigest()[:16]}"

        if self._program_exists(program_id):
            raise ShaderError(f"Program {name!r} already exists ({program_id})")

        self.conn.execute("""
            INSERT INTO shader_programs (
                program_id, name, state_json
            ) VALUES (?, ?, ?)
        """, [
            program_id,
            name,
            json.dumps({"status": "created", "linked": False})
        ])

        self.conn.commit()
        return program_id

    def attach_shader(self, program_id: str, shader_id: str):
        """Attach a stored shader to a program in the slot for its type.

        Any previous link is invalidated (``linked`` becomes false).

        Raises:
            ShaderError: If the shader or program does not exist, or the
                stored shader type is unknown.
        """
        shader = self.conn.execute("""
            SELECT type FROM shaders WHERE shader_id = ?
        """, [shader_id]).fetchone()

        if not shader:
            raise ShaderError(f"Shader {shader_id} not found")

        # Validating through the enum also guarantees the column name
        # interpolated below is one of the four known *_shader_id columns.
        try:
            shader_type = ShaderType(shader[0])
        except ValueError as e:
            raise ShaderError(f"Shader {shader_id} has unknown type {shader[0]!r}") from e
        column = f"{shader_type.value}_shader_id"

        state = self._load_program_state(program_id)  # raises if program missing
        state["linked"] = False

        self.conn.execute(f"""
            UPDATE shader_programs
            SET {column} = ?,
                state_json = ?
            WHERE program_id = ?
        """, [shader_id, json.dumps(state), program_id])

        self.conn.commit()

    def set_uniform(self, program_id: str, name: str, value: Union[float, int, List[float], List[int]]):
        """Insert or overwrite a uniform value for a program.

        The value is stored JSON-encoded along with its Python type name.

        Raises:
            ShaderError: If the program does not exist.
        """
        if not self._program_exists(program_id):
            raise ShaderError(f"Program {program_id} not found")

        value_type = type(value).__name__

        self.conn.execute("""
            INSERT INTO uniform_values (program_id, uniform_name, value_type, value)
            VALUES (?, ?, ?, ?)
            ON CONFLICT (program_id, uniform_name) DO UPDATE SET
                value_type = excluded.value_type,
                value = excluded.value
        """, [program_id, name, value_type, json.dumps(value)])

        self.conn.commit()

    def link_program(self, program_id: str) -> bool:
        """Verify the vertex/fragment interface and mark the program linked.

        On success, sets ``linked`` to true, records ``link_time`` (epoch
        seconds) and clears any earlier ``link_error``. On failure, sets
        ``linked`` to false and records ``link_error``.

        Returns:
            True on success.

        Raises:
            ShaderError: If the program is missing, lacks a vertex or
                fragment shader, or interface verification fails.
        """
        program = self.conn.execute("""
            SELECT vertex_shader_id, fragment_shader_id
            FROM shader_programs WHERE program_id = ?
        """, [program_id]).fetchone()

        if not program:
            raise ShaderError(f"Program {program_id} not found")

        if not program[0] or not program[1]:
            raise ShaderError("Program must have both vertex and fragment shaders")

        try:
            self._verify_shader_interface(program[0], program[1])
            self._update_program_state(
                program_id,
                {"linked": True, "link_time": time.time()},
                remove=("link_error",),
            )
            self.conn.commit()
            return True

        except Exception as e:
            self._update_program_state(
                program_id, {"linked": False, "link_error": str(e)}
            )
            self.conn.commit()
            raise ShaderError(f"Link failed: {str(e)}") from e

    def execute_program(self, program_id: str, input_data: Dict) -> Dict:
        """Run a linked program's vertex then fragment stage on *input_data*.

        Each run, successful or not, is recorded in execution_state for unit
        (chip 0, sm 0, unit 0).

        Returns:
            The fragment stage's output dict.

        Raises:
            ShaderError: If the program is missing or not linked, or any
                stage fails.
        """
        program = self.conn.execute("""
            SELECT state_json, vertex_shader_id, fragment_shader_id
            FROM shader_programs WHERE program_id = ?
        """, [program_id]).fetchone()

        if not program:
            raise ShaderError(f"Program {program_id} not found")

        state = json.loads(program[0]) if program[0] else {}
        if not self._is_linked(state):
            raise ShaderError("Program is not linked")

        try:
            vertex_shader = self.conn.execute("""
                SELECT compiled_code FROM shaders WHERE shader_id = ?
            """, [program[1]]).fetchone()

            fragment_shader = self.conn.execute("""
                SELECT compiled_code FROM shaders WHERE shader_id = ?
            """, [program[2]]).fetchone()

            if not vertex_shader or not fragment_shader:
                raise ShaderError("Program references a shader that no longer exists")

            uniforms = self.conn.execute("""
                SELECT uniform_name, value_type, value
                FROM uniform_values WHERE program_id = ?
            """, [program_id]).fetchall()

            uniform_data = {
                uniform_name: json.loads(raw_value)
                for uniform_name, _value_type, raw_value in uniforms
            }

            vertex_output = self._execute_vertex_shader(
                json.loads(vertex_shader[0]),
                input_data,
                uniform_data
            )

            final_output = self._execute_fragment_shader(
                json.loads(fragment_shader[0]),
                vertex_output,
                uniform_data
            )

            # Replace the stats on success so an earlier error status does
            # not linger.
            self._record_execution(
                program_id,
                {"status": "success", "last_execution_time": time.time()},
            )
            self.conn.commit()
            return final_output

        except Exception as e:
            try:
                self._record_execution(
                    program_id,
                    {"status": "error", "error": str(e), "last_execution_time": time.time()},
                )
                self.conn.commit()
            except Exception:
                # Don't let a bookkeeping failure mask the real execution error.
                logging.getLogger(__name__).exception(
                    "Failed to record execution error for %s", program_id
                )
            raise ShaderError(f"Execution failed: {str(e)}") from e

    def _compile_shader(self, source: str, shader_type: ShaderType) -> Dict:
        """Compile shader source into the stored intermediate representation.

        Returns a dict with ``type`` (the stage string), ``instructions``
        (from ``_parse_shader``) and ``variables`` (from
        ``_extract_variables``).
        """
        return {
            "type": shader_type.value,
            "instructions": self._parse_shader(source),
            "variables": self._extract_variables(source)
        }

    def _verify_shader_interface(self, vertex_id: str, fragment_id: str):
        """Check that every fragment input is produced by the vertex shader.

        Variables are entries of the compiled ``variables`` list. An entry
        counts as an output or input when its ``output`` or ``input`` key is
        truthy.

        Raises:
            ShaderError: If a shader is missing, or the fragment shader
                needs inputs the vertex shader does not output.
        """
        vertex = self.conn.execute("""
            SELECT compiled_code FROM shaders WHERE shader_id = ?
        """, [vertex_id]).fetchone()

        fragment = self.conn.execute("""
            SELECT compiled_code FROM shaders WHERE shader_id = ?
        """, [fragment_id]).fetchone()

        if not vertex or not fragment:
            raise ShaderError("Shader not found")

        vertex_code = json.loads(vertex[0])
        fragment_code = json.loads(fragment[0])

        vertex_outputs = {v["name"] for v in vertex_code.get("variables", []) if v.get("output")}
        fragment_inputs = {v["name"] for v in fragment_code.get("variables", []) if v.get("input")}

        missing_inputs = fragment_inputs - vertex_outputs
        if missing_inputs:
            raise ShaderError(f"Fragment shader requires inputs not provided by vertex shader: {missing_inputs}")

    def _execute_vertex_shader(self, shader_code: Dict, input_data: Dict, uniforms: Dict) -> Dict:
        """Execute compiled vertex shader code.

        Placeholder: ignores its arguments and returns a fixed origin
        position with no varyings.
        """
        return {
            "position": [0, 0, 0, 1],
            "varying": {}
        }

    def _execute_fragment_shader(self, shader_code: Dict, vertex_data: Dict, uniforms: Dict) -> Dict:
        """Execute compiled fragment shader code.

        Placeholder: ignores its arguments and returns opaque white.
        """
        return {
            "color": [1, 1, 1, 1]
        }

    def _parse_shader(self, source: str) -> List:
        """Parse shader source into instructions (placeholder: returns [])."""
        return []

    def _extract_variables(self, source: str) -> List:
        """Extract variable declarations from source (placeholder: returns [])."""
        return []

    def get_program_stats(self, program_id: str) -> Dict:
        """Summarize a program's state and recorded executions.

        Returns:
            Dict with ``name``, decoded ``state``, ``execution_units`` (number
            of distinct units that ran it), ``total_executions`` (0 if never
            run) and ``last_execution`` (None if never run).

        Raises:
            ShaderError: If the program does not exist.
        """
        result = self.conn.execute("""
            SELECT
                sp.name,
                sp.state_json,
                -- (program_id, chip_id, sm_id, unit_id) is the primary key of
                -- execution_state, so each joined row is a distinct unit;
                -- counting a non-null column avoids INTEGER || VARCHAR concat.
                COUNT(es.program_id) AS execution_units,
                COALESCE(SUM(es.execution_count), 0) AS total_executions,
                MAX(es.last_execution) AS last_execution
            FROM shader_programs sp
            LEFT JOIN execution_state es ON sp.program_id = es.program_id
            WHERE sp.program_id = ?
            GROUP BY sp.program_id, sp.name, sp.state_json
        """, [program_id]).fetchone()

        if not result:
            raise ShaderError(f"Program {program_id} not found")

        return {
            "name": result[0],
            "state": json.loads(result[1]) if result[1] else {},
            "execution_units": result[2],
            "total_executions": result[3],
            "last_execution": result[4]
        }

    def run_compute(self, *args, **kwargs):
        """Run the registered ``'compute'`` shader, if any.

        ShaderSystem itself never defines ``self.shaders``. If something
        attaches a mapping there, its ``'compute'`` entry's ``run`` is called
        with the given arguments.

        Returns:
            That call's result, or None when no compute shader is
            registered.
        """
        shaders = getattr(self, "shaders", None)
        if shaders and 'compute' in shaders:
            return shaders['compute'].run(*args, **kwargs)
        return None
|
|
|
|
|
|
class ShaderManager:
    """In-memory registry of ShaderProgram objects keyed by integer ids."""

    # Stage names accepted by run_stage. Each dispatches to the program's
    # ``run_<stage>`` method.
    _STAGES = ("vertex", "fragment", "geometry", "compute")

    def __init__(self):
        # program id -> program object
        self.programs = {}
        # Next id handed out by create_program (ids start at 1).
        self.next_id = 1

    def create_program(self):
        """Create an empty ShaderProgram and return its new integer id."""
        pid = self.next_id
        self.next_id += 1
        self.programs[pid] = ShaderProgram()
        return pid

    def get_program(self, pid):
        """Return the program registered under *pid* (KeyError if unknown)."""
        return self.programs[pid]

    def attach_shader(self, pid, shader):
        """Attach *shader* to program *pid* (KeyError if *pid* is unknown)."""
        self.programs[pid].attach_shader(shader)

    def link_program(self, pid):
        """Link program *pid* and return its link() result."""
        return self.programs[pid].link()

    def run_stage(self, pid, stage, *args, **kwargs):
        """Run one pipeline stage of program *pid*.

        Calls ``program.run_<stage>(*args, **kwargs)`` and returns its
        result.

        Raises:
            KeyError: If *pid* is unknown.
            ValueError: If *stage* is not one of vertex, fragment, geometry,
                compute.
            NotImplementedError: If the program has no ``run_<stage>``
                method. Previously this surfaced as an opaque AttributeError.
        """
        prog = self.programs[pid]
        if stage not in self._STAGES:
            raise ValueError(f"Unknown shader stage: {stage}")
        runner = getattr(prog, f"run_{stage}", None)
        if runner is None:
            raise NotImplementedError(
                f"{type(prog).__name__} does not implement the {stage} stage "
                f"(no run_{stage} method)"
            )
        return runner(*args, **kwargs)
|
|
|
|