# Source: INV/virtual_gpu_driver/src/graphics/shader_compiler.py
# Uploaded by Fred808 ("Upload 256 files", commit 7a0c684, verified).
from enum import Enum
import re
from typing import Dict, List, Optional, Union
import hashlib
import duckdb
import json
from config import get_db_url, get_hf_token_cached
import logging
class ShaderType(Enum):
    """Shader stages accepted by the compiler, keyed by their string value."""

    VERTEX = "vertex"
    FRAGMENT = "fragment"
    COMPUTE = "compute"
class ShaderError(Exception):
    """Raised when a shader cannot be parsed, compiled or linked."""
class ShaderCompilerDB:
    """DuckDB-backed storage for shaders, programs, registers and instructions.

    Constructing an instance connects to the database derived from ``db_url``
    and creates the ``shaders``, ``programs``, ``registers`` and
    ``instructions`` tables when they do not exist yet.
    """

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, db_url: Optional[str] = None):
        """Connect to the database and create the compiler tables.

        Args:
            db_url: Database location; ``DB_URL`` is used when omitted or empty.

        Raises:
            RuntimeError: If no connection could be set up within
                ``max_retries`` attempts.
        """
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        # Defined up front so a failed first attempt leaves a well-defined state.
        self.conn = None
        self._connect_with_retries()

    def _connect_with_retries(self):
        """Connect and create the tables, retrying up to ``max_retries`` times.

        Raises:
            RuntimeError: After the last failed attempt; the triggering
                exception is attached as ``__cause__``.
        """
        for attempt in range(self.max_retries):
            try:
                self._init_db_connection()
                self._setup_database()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    ) from e
                logging.warning(
                    "Database connection attempt %d failed, retrying...", attempt + 1
                )

    def _remote_db_path(self) -> str:
        """Translate ``self.db_url`` into the ``s3://datasets-cached/...`` path.

        The URL is split on its first four slashes and the last three pieces
        are appended to the bucket name.

        Raises:
            ValueError: If the URL has fewer than three path components after
                the ``scheme://`` prefix.
        """
        parts = self.db_url.split('/', 4)
        if len(parts) != 5:
            raise ValueError(f"Malformed database URL: {self.db_url!r}")
        # NOTE(review): for the default "hf://datasets/<owner>/<repo>/<file>"
        # URL the literal "datasets" segment ends up in the S3 key as well.
        # Kept as-is because the bucket layout is not visible here - confirm.
        return "s3://datasets-cached/" + '/'.join(parts[2:])

    def _init_db_connection(self):
        """Open the DuckDB connection and configure S3/httpfs access."""
        # Release a previous (possibly broken) handle before replacing it.
        stale = getattr(self, "conn", None)
        if stale is not None:
            try:
                stale.close()
            except Exception:
                logging.debug("Ignoring error while closing stale connection", exc_info=True)
            self.conn = None

        # A subclass/instance may pin a token via HF_TOKEN; otherwise use the
        # cached token helper imported from config.
        token = getattr(self, "HF_TOKEN", None) or get_hf_token_cached()

        self.conn = duckdb.connect(self._remote_db_path())
        self.conn.execute("""
            INSTALL httpfs;
            LOAD httpfs;
            SET s3_region='us-east-1';
            SET s3_endpoint='s3.us-east-1.amazonaws.com';
            SET s3_url_style='path';
            SET s3_access_key_id='none';
            SET s3_secret_access_key=?;
        """, [token])

    def ensure_connection(self):
        """Probe the connection with ``SELECT 1`` and reconnect if it fails.

        Raises:
            RuntimeError: If reconnecting fails ``max_retries`` times.
        """
        try:
            self.conn.execute("SELECT 1")
        except Exception:
            # Exception (not a bare except) so Ctrl-C / SystemExit propagate.
            logging.warning("Database connection lost, attempting to reconnect...")
            self._connect_with_retries()

    def _setup_database(self):
        """Create the shader compiler tables if they are missing.

        Only invoked right after a successful ``_init_db_connection``, so the
        connection is not probed again here; probing would re-enter
        ``_connect_with_retries`` and recurse without bound on failure.
        """
        # Shader storage
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS shaders (
                id VARCHAR PRIMARY KEY,
                type VARCHAR,
                source TEXT,
                variables JSON,
                instructions JSON
            )
        """)
        # Program storage (linked shaders)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS programs (
                id VARCHAR PRIMARY KEY,
                vertex_shader_id VARCHAR,
                fragment_shader_id VARCHAR,
                uniforms JSON,
                attributes JSON,
                varyings JSON,
                linked BOOLEAN
            )
        """)
        # Register allocation table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS registers (
                shader_id VARCHAR,
                var_name VARCHAR,
                register_name VARCHAR,
                PRIMARY KEY (shader_id, var_name)
            )
        """)
        # Instructions table with optimization metadata
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS instructions (
                id INTEGER PRIMARY KEY,
                shader_id VARCHAR,
                opcode VARCHAR,
                args JSON,
                result VARCHAR,
                is_dead BOOLEAN DEFAULT FALSE,
                depends_on JSON
            )
        """)
        self.conn.commit()
class Instruction:
    """One virtual GPU instruction: ``result = opcode arg, arg, ...``."""

    def __init__(self, opcode: str, args: List[str], result: Optional[str] = None):
        """Create an instruction.

        Args:
            opcode: Operation name, e.g. ``"add"`` or ``"mov"``.
            args: Operand names/literals; copied so that later changes to the
                caller's list do not leak into the instruction.
            result: Destination name, or ``None`` when nothing is produced.
        """
        self.opcode = opcode
        self.args = list(args)
        self.result = result
        self.id = None  # Will be set when stored in DB

    def __str__(self):
        """Render as ``"<result> = <opcode> <arg>, <arg>"`` (no prefix without result)."""
        result_str = f"{self.result} = " if self.result else ""
        # str() each operand so numeric literals do not break the join.
        return f"{result_str}{self.opcode} {', '.join(map(str, self.args))}"

    def __repr__(self):
        return (
            f"{type(self).__name__}({self.opcode!r}, {self.args!r}, "
            f"result={self.result!r})"
        )

    def to_dict(self):
        """Return a JSON-serialisable dict with ``opcode``, ``args`` and ``result``."""
        return {
            "opcode": self.opcode,
            "args": self.args,
            "result": self.result
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild an instruction from ``to_dict`` output.

        ``result`` is optional in the mapping, mirroring the constructor.

        Raises:
            KeyError: If ``opcode`` or ``args`` is missing.
        """
        return cls(data["opcode"], data["args"], data.get("result"))
class Variable:
    """A shader variable together with its interface direction and register."""

    def __init__(self, name: str, var_type: str, is_input: bool = False, is_output: bool = False):
        """Create a variable.

        Args:
            name: Identifier as written in the shader source.
            var_type: Declared type name, e.g. ``"vec4"``.
            is_input: Whether the variable is a stage input.
            is_output: Whether the variable is a stage output.
        """
        self.name = name
        self.type = var_type
        self.is_input = is_input
        self.is_output = is_output
        self.register = None  # Will be assigned during register allocation

    def __repr__(self):
        return (
            f"{type(self).__name__}({self.name!r}, {self.type!r}, "
            f"is_input={self.is_input!r}, is_output={self.is_output!r}, "
            f"register={self.register!r})"
        )

    def to_dict(self):
        """Return a JSON-serialisable dict describing the variable."""
        return {
            "name": self.name,
            "type": self.type,
            "is_input": self.is_input,
            "is_output": self.is_output,
            "register": self.register
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild a variable from ``to_dict`` output.

        ``is_input``, ``is_output`` and ``register`` are optional and default
        to the constructor defaults; unknown keys are ignored.

        Raises:
            KeyError: If ``name`` or ``type`` is missing.
        """
        var = cls(
            data["name"],
            data["type"],
            data.get("is_input", False),
            data.get("is_output", False),
        )
        var.register = data.get("register")
        return var
class ShaderCompiler:
    """Compile a small GLSL-like subset into virtual GPU instructions.

    Parsing and optimisation run on in-memory ``Instruction`` / ``Variable``
    objects; the results are then persisted through ``self.db``.

    Known limits of the parser: only the body of ``void main()`` up to its
    first ``}`` is compiled (no nested blocks), and statements must be
    assignments or calls.
    """

    # Binary operators grouped by precedence, loosest-binding group first.
    _PRECEDENCE_LEVELS = (('+', '-'), ('*', '/'))
    _FOLDABLE_OPCODES = ('add', 'sub', 'mul', 'div')
    _CONTROL_KEYWORDS = ('if', 'else', 'for', 'while', 'do', 'switch', 'return')
    _NUMBER_RE = re.compile(r'^[+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?$')
    _INTEGER_RE = re.compile(r'^[+-]?\d+$')
    # Matches text ending in "<number>e", i.e. a sign that follows belongs to
    # a float exponent ("1e-5") and is not a binary operator.
    _EXPONENT_PREFIX_RE = re.compile(r'(?<![\w.])(?:\d+\.?\d*|\.\d+)[eE]$')
    _LVALUE_RE = re.compile(r'^(?:[A-Za-z_]\w*\s+)*[A-Za-z_]\w*(?:\.\w+|\[[^\]]*\])*$')
    _DECLARATION_RE = re.compile(
        r'\b(in|out|uniform|attribute|varying)\s+(\w+)\s+(\w+)\s*;'
    )
    _MAIN_RE = re.compile(r'void\s+main\s*\(\s*(?:void\s*)?\)\s*\{([^}]*)\}')
    _COMMENT_RE = re.compile(r'//[^\n]*|/\*.*?\*/', re.DOTALL)

    def __init__(self, db_url: Optional[str] = None):
        """Create a compiler bound to a ``ShaderCompilerDB``.

        Args:
            db_url: Forwarded to ``ShaderCompilerDB``.
        """
        # Initialize database
        self.db = ShaderCompilerDB(db_url)
        self.temp_counter = 0
        # Initialize operation mappings
        self.vector_ops = {
            '+': 'add', '-': 'sub', '*': 'mul', '/': 'div',
            'dot': 'dot', 'cross': 'cross', 'normalize': 'normalize'
        }
        # Built-in functions and their implementations
        self.built_ins = {
            'texture2D': self._compile_texture2D,
            'normalize': self._compile_normalize,
            'dot': self._compile_dot,
            'mix': self._compile_mix,
            'clamp': self._compile_clamp
        }
        # Current compilation context
        self.current_shader_id = None
        self.variables = {}          # name -> Variable of the shader being compiled
        self.instructions = []       # instruction stream in program order
        self._uniform_names = set()  # names declared with the uniform qualifier
        self._dead = set()           # indices into self.instructions found dead

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def compile_shader(self, shader_source: str, shader_type: Union[str, "ShaderType"]) -> dict:
        """Compile a shader from source code into virtual GPU instructions.

        Args:
            shader_source: Shader source text.
            shader_type: A ``ShaderType`` or its string value.

        Returns:
            Dict with ``id``, ``type``, ``source``, ``variables`` and
            ``instructions`` (the live instructions after optimisation), as
            read back from the ``shaders`` table.

        Raises:
            ValueError: If ``shader_type`` is a string that names no stage.
            ShaderError: If parsing, optimisation or persistence fails.
        """
        if isinstance(shader_type, str):
            shader_type = ShaderType(shader_type)
        try:
            # Generate shader ID and set as current
            self.current_shader_id = self._generate_shader_id(shader_source)
            self.temp_counter = 0
            variables = self._parse_interface_variables(shader_source, shader_type)
            instructions = self._parse_main_function(shader_source)
            # Optimise in memory first so the database only sees final state.
            self._optimize_instructions()

            self.db.ensure_connection()
            self._store_instructions(instructions)
            self._allocate_registers(variables)
            # OR REPLACE: the id is derived from the source, so compiling the
            # same source again must overwrite instead of violating the key.
            self.db.conn.execute("""
                INSERT OR REPLACE INTO shaders (id, type, source, variables, instructions)
                VALUES (?, ?, ?, ?, ?)
            """, (
                self.current_shader_id,
                shader_type.value,
                shader_source,
                json.dumps(self._serialize_variables(variables)),
                json.dumps([instr.to_dict() for instr in self._live_instructions()])
            ))
            compiled_program = self._fetch_shader(self.current_shader_id)
            self.db.conn.commit()
            return compiled_program
        except Exception as e:
            raise ShaderError(f"Compilation failed: {str(e)}") from e

    def link_program(self, vertex_shader: dict, fragment_shader: dict) -> dict:
        """Link vertex and fragment shaders into a complete program.

        Args:
            vertex_shader: Result of ``compile_shader`` for a vertex shader.
            fragment_shader: Result of ``compile_shader`` for a fragment shader.

        Returns:
            Dict with ``id``, ``vertex_shader``, ``fragment_shader``,
            ``uniforms``, ``attributes``, ``varyings`` and ``linked``. Each
            shader entry carries ``id``, ``type``, ``source``, ``variables``
            and ``instructions`` so the result passes ``validate_program``.

        Raises:
            ShaderError: On wrong shader types, incompatible interfaces, or
                when a shader is not present in the database.
        """
        # Verify shader types
        if vertex_shader['type'] != 'vertex' or fragment_shader['type'] != 'fragment':
            raise ShaderError("Invalid shader types for linking")
        # Check interface compatibility
        self._verify_interface_compatibility(vertex_shader, fragment_shader)
        program_id = self._generate_program_id(vertex_shader, fragment_shader)

        self.db.ensure_connection()
        # Load both stages before writing so a missing shader cannot leave a
        # dangling program row behind.
        stored_vertex = self._fetch_shader(vertex_shader['id'])
        stored_fragment = self._fetch_shader(fragment_shader['id'])

        uniforms = self._collect_uniforms(vertex_shader, fragment_shader)
        attributes = self._collect_attributes(vertex_shader)
        varyings = self._collect_varyings(vertex_shader, fragment_shader)
        # OR REPLACE: relinking the same pair yields the same program id.
        self.db.conn.execute("""
            INSERT OR REPLACE INTO programs (
                id, vertex_shader_id, fragment_shader_id,
                uniforms, attributes, varyings, linked
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (
            program_id,
            vertex_shader['id'],
            fragment_shader['id'],
            json.dumps(uniforms),
            json.dumps(attributes),
            json.dumps(varyings),
            True
        ))
        self.db.conn.commit()
        return {
            "id": program_id,
            "vertex_shader": stored_vertex,
            "fragment_shader": stored_fragment,
            "uniforms": uniforms,
            "attributes": attributes,
            "varyings": varyings,
            "linked": True
        }

    def validate_program(self, program: dict) -> bool:
        """Validate a linked program.

        Returns ``True`` when the program has both shader entries and a
        ``linked`` key, each shader has ``type``, ``instructions`` and
        ``variables``, and every fragment input is a vertex output.
        Malformed input yields ``False`` rather than an exception.
        """
        try:
            if not all(k in program for k in ('vertex_shader', 'fragment_shader', 'linked')):
                return False
            for shader in (program['vertex_shader'], program['fragment_shader']):
                if not all(k in shader for k in ('type', 'instructions', 'variables')):
                    return False
            vertex_outputs = {
                name for name, var in program['vertex_shader']['variables'].items()
                if var['is_output']
            }
            fragment_inputs = {
                name for name, var in program['fragment_shader']['variables'].items()
                if var['is_input']
            }
            return fragment_inputs.issubset(vertex_outputs)
        except (AttributeError, KeyError, TypeError):
            return False

    # ------------------------------------------------------------------
    # Parsing
    # ------------------------------------------------------------------
    def _strip_comments(self, source: str) -> str:
        """Replace ``//`` and ``/* */`` comments with a space."""
        return self._COMMENT_RE.sub(' ', source)

    def _parse_interface_variables(self, source: str, shader_type: "ShaderType"):
        """Parse interface declarations and return ``{name: Variable}``.

        Recognises ``in``/``out``/``uniform`` and the legacy
        ``attribute``/``varying`` qualifiers in ``<qualifier> <type> <name>;``
        form. ``varying`` is an input for fragment shaders and an output
        otherwise. ``gl_Position`` / ``gl_FragColor`` are added for vertex /
        fragment shaders. The mapping is also stored on ``self.variables``.
        """
        declarations = self._DECLARATION_RE.findall(self._strip_comments(source))
        input_qualifiers = {'in', 'attribute'}
        output_qualifiers = {'out'}
        if shader_type == ShaderType.FRAGMENT:
            input_qualifiers.add('varying')
        else:
            output_qualifiers.add('varying')

        variables = {}
        uniform_names = set()
        # Inputs first, then outputs, so an output declaration wins on a clash.
        for qualifier, var_type, var_name in declarations:
            if qualifier in input_qualifiers:
                variables[var_name] = Variable(var_name, var_type, is_input=True)
        for qualifier, var_type, var_name in declarations:
            if qualifier in output_qualifiers:
                variables[var_name] = Variable(var_name, var_type, is_output=True)
        for qualifier, var_type, var_name in declarations:
            if qualifier == 'uniform' and var_name not in variables:
                variables[var_name] = Variable(var_name, var_type)
                uniform_names.add(var_name)

        # Add built-in variables based on shader type
        if shader_type == ShaderType.VERTEX:
            variables['gl_Position'] = Variable('gl_Position', 'vec4', is_output=True)
        elif shader_type == ShaderType.FRAGMENT:
            variables['gl_FragColor'] = Variable('gl_FragColor', 'vec4', is_output=True)

        self.variables = variables
        self._uniform_names = uniform_names
        return variables

    def _parse_main_function(self, source: str):
        """Compile the body of ``void main()`` and return the instruction list.

        The returned list is ``self.instructions``, reset at the start.

        Raises:
            ShaderError: If no ``main`` function is found or a statement is
                unsupported.
        """
        main_match = self._MAIN_RE.search(self._strip_comments(source))
        if not main_match:
            raise ShaderError("No main function found")
        self.instructions = []
        self._dead = set()
        for raw_statement in main_match.group(1).split(';'):
            statement = raw_statement.strip()
            if statement:
                self._compile_statement(statement)
        return self.instructions

    def _find_assignment(self, statement: str):
        """Return the index of the assignment ``=`` in ``statement`` or ``None``.

        ``=`` inside parentheses and the ``=`` of ``==``, ``!=``, ``<=`` and
        ``>=`` are ignored.
        """
        depth = 0
        for index, char in enumerate(statement):
            if char in '([':
                depth += 1
            elif char in ')]':
                depth -= 1
            elif char == '=' and depth == 0:
                previous = statement[index - 1] if index else ''
                following = statement[index + 1:index + 2]
                if (previous and previous in '=!<>') or following == '=':
                    continue
                return index
        return None

    def _compile_statement(self, statement: str):
        """Compile a single statement into instructions.

        Supports ``[type] target = expr``, the compound forms ``+=``, ``-=``,
        ``*=`` and ``/=``, and bare function calls.

        Raises:
            ShaderError: For any other statement shape.
        """
        statement = statement.strip()
        position = self._find_assignment(statement)
        if position is not None:
            target = statement[:position].strip()
            expr = statement[position + 1:].strip()
            compound_op = None
            if target and target[-1] in '+-*/':
                compound_op = target[-1]
                target = target[:-1].strip()
            tokens = target.split()
            # Reject control flow and non-lvalue targets instead of silently
            # compiling e.g. "if (c) x = 1" as an unconditional assignment.
            if (not expr or not tokens or tokens[0] in self._CONTROL_KEYWORDS
                    or not self._LVALUE_RE.match(target)):
                raise ShaderError(f"Unsupported statement: {statement}")
            target = tokens[-1]  # drop a leading type in "vec4 color = ..."
            value = self._compile_expression(expr)
            if compound_op is not None:
                combined = self._new_temp()
                self.instructions.append(
                    Instruction(self.vector_ops[compound_op], [target, value], combined)
                )
                value = combined
            self.instructions.append(Instruction('mov', [value], target))
            return
        # Handle function calls
        if '(' in statement:
            self._compile_expression(statement)
            return
        raise ShaderError(f"Unsupported statement: {statement}")

    def _is_binary_operator(self, expr: str, index: int) -> bool:
        """Tell whether the operator at ``index`` has a left operand."""
        before = expr[:index].rstrip()
        if not before or before[-1] in '+-*/(,[=':
            return False  # unary sign
        if expr[index] in '+-' and self._EXPONENT_PREFIX_RE.search(expr[:index]):
            return False  # exponent sign of a literal such as 1e-5
        return True

    def _find_top_level_operator(self, expr: str, operators):
        """Return the index of the rightmost binary operator outside brackets.

        Choosing the rightmost occurrence makes same-precedence chains
        left-associative. Returns ``None`` when no operator qualifies.
        """
        depth = 0
        for index in range(len(expr) - 1, -1, -1):
            char = expr[index]
            if char in ')]':
                depth += 1
            elif char in '([':
                depth -= 1
            elif depth == 0 and char in operators and self._is_binary_operator(expr, index):
                return index
        return None

    def _matching_paren(self, expr: str, start: int) -> int:
        """Return the index of the ``)`` closing the ``(`` at ``start``.

        Raises:
            ShaderError: If the parenthesis is never closed.
        """
        depth = 0
        for index in range(start, len(expr)):
            if expr[index] == '(':
                depth += 1
            elif expr[index] == ')':
                depth -= 1
                if depth == 0:
                    return index
        raise ShaderError("Mismatched parentheses")

    def _strip_outer_parens(self, expr: str) -> str:
        """Remove parentheses that enclose the whole expression."""
        expr = expr.strip()
        while expr.startswith('(') and self._matching_paren(expr, 0) == len(expr) - 1:
            expr = expr[1:-1].strip()
        return expr

    def _compile_expression(self, expr: str) -> str:
        """Compile an expression and return the name holding its value.

        Handles, in order: binary ``+ - * /`` outside brackets (with
        precedence), unary sign, a swizzle applied to a call/group, function
        calls, and finally plain variables/literals, which are returned
        unchanged. Calls to functions that are not in ``self.built_ins``
        (e.g. ``vec4(...)``) emit an instruction whose opcode is the function
        name, like the built-ins do.

        Raises:
            ShaderError: On empty sub-expressions or mismatched parentheses.
        """
        expr = self._strip_outer_parens(expr)
        if not expr:
            raise ShaderError("Empty expression")

        # Handle basic arithmetic
        for operators in self._PRECEDENCE_LEVELS:
            position = self._find_top_level_operator(expr, operators)
            if position is not None:
                left_reg = self._compile_expression(expr[:position])
                right_reg = self._compile_expression(expr[position + 1:])
                result = self._new_temp()
                self.instructions.append(Instruction(
                    self.vector_ops[expr[position]],
                    [left_reg, right_reg],
                    result
                ))
                return result

        # Unary sign in front of a non-literal operand
        if expr[0] in '+-' and not self._NUMBER_RE.match(expr):
            operand = self._compile_expression(expr[1:])
            if expr[0] == '+':
                return operand
            result = self._new_temp()
            self.instructions.append(Instruction('mul', [operand, '-1.0'], result))
            return result

        # Swizzle applied to a call or parenthesised group: f(x).rgb
        swizzled = re.match(r'^(.*\))\s*(\.\w+)$', expr, re.DOTALL)
        if swizzled:
            return self._compile_expression(swizzled.group(1)) + swizzled.group(2)

        # Function call spanning the whole expression
        call = re.match(r'^(\w+)\s*\(', expr)
        if call and self._matching_paren(expr, call.end() - 1) == len(expr) - 1:
            name = call.group(1)
            if name in self.built_ins:
                return self.built_ins[name](expr)
            operands = [self._compile_expression(arg) for arg in self._extract_args(expr)]
            result = self._new_temp()
            self.instructions.append(Instruction(name, operands, result))
            return result

        # Must be a variable or literal
        return expr

    def _compile_builtin_call(self, expr: str, name: str, arity: int) -> str:
        """Compile ``name(args...)`` into one instruction and return its temp.

        Raises:
            ShaderError: If the call does not have exactly ``arity`` arguments.
        """
        args = self._extract_args(expr)
        if len(args) != arity:
            plural = "" if arity == 1 else "s"
            raise ShaderError(f"{name} requires {arity} argument{plural}")
        operands = [self._compile_expression(arg) for arg in args]
        result = self._new_temp()
        self.instructions.append(Instruction(name, operands, result))
        return result

    def _compile_texture2D(self, expr: str) -> str:
        """Compile a texture2D builtin function call."""
        return self._compile_builtin_call(expr, 'texture2D', 2)

    def _compile_normalize(self, expr: str) -> str:
        """Compile a normalize builtin function call."""
        return self._compile_builtin_call(expr, 'normalize', 1)

    def _compile_dot(self, expr: str) -> str:
        """Compile a dot builtin function call."""
        return self._compile_builtin_call(expr, 'dot', 2)

    def _compile_mix(self, expr: str) -> str:
        """Compile a mix builtin function call."""
        return self._compile_builtin_call(expr, 'mix', 3)

    def _compile_clamp(self, expr: str) -> str:
        """Compile a clamp builtin function call."""
        return self._compile_builtin_call(expr, 'clamp', 3)

    def _new_temp(self) -> str:
        """Generate a new temporary variable name."""
        self.temp_counter += 1
        return f"temp_{self.temp_counter}"

    def _extract_parenthesized(self, expr: str) -> str:
        """Return the text between the first ``(`` and its matching ``)``.

        Raises:
            ShaderError: If there is no ``(`` or it is never closed.
        """
        start = expr.find('(')
        if start < 0:
            raise ShaderError(f"Expected '(' in expression: {expr}")
        return expr[start + 1:self._matching_paren(expr, start)]

    def _split_top_level(self, text: str, separator: str) -> List[str]:
        """Split ``text`` on ``separator`` occurrences outside brackets."""
        parts = []
        depth = 0
        start = 0
        for index, char in enumerate(text):
            if char in '([':
                depth += 1
            elif char in ')]':
                depth -= 1
            elif char == separator and depth == 0:
                parts.append(text[start:index])
                start = index + 1
        parts.append(text[start:])
        return parts

    def _extract_args(self, expr: str) -> List[str]:
        """Return the arguments of a call; commas in nested calls do not split."""
        args_str = self._extract_parenthesized(expr)
        if not args_str.strip():
            return []
        return [arg.strip() for arg in self._split_top_level(args_str, ',')]

    # ------------------------------------------------------------------
    # Optimisation (operates on self.instructions)
    # ------------------------------------------------------------------
    def _base_name(self, operand) -> str:
        """Return ``operand`` without swizzle or index suffix (``a.xy`` -> ``a``)."""
        return str(operand).split('.', 1)[0].split('[', 1)[0].strip()

    def _live_instructions(self) -> list:
        """Return the instructions not marked dead, in program order."""
        return [
            instr for index, instr in enumerate(self.instructions)
            if index not in self._dead
        ]

    def _optimize_instructions(self):
        """Perform basic optimizations on the instruction stream."""
        # Constant folding
        self._fold_constants()
        # Dead code elimination
        self._eliminate_dead_code()
        # Common subexpression elimination
        self._eliminate_common_subexpressions()

    def _fold_binary(self, opcode: str, left, right):
        """Evaluate ``left <opcode> right`` for numeric literals.

        Returns the literal text of the result, or ``None`` when an operand
        is not a numeric literal, on division by zero, on non-finite results,
        and for integer division (left to the runtime).
        """
        if not (isinstance(left, str) and isinstance(right, str)):
            return None
        if not (self._NUMBER_RE.match(left) and self._NUMBER_RE.match(right)):
            return None
        integral = bool(self._INTEGER_RE.match(left) and self._INTEGER_RE.match(right))
        if integral:
            if opcode == 'div':
                return None
            a, b = int(left), int(right)
        else:
            a, b = float(left), float(right)
        if opcode == 'add':
            value = a + b
        elif opcode == 'sub':
            value = a - b
        elif opcode == 'mul':
            value = a * b
        elif opcode == 'div':
            if b == 0:
                return None
            value = a / b
        else:
            return None
        if integral:
            return str(value)
        if value != value or value in (float('inf'), float('-inf')):
            return None
        return repr(value)

    def _fold_constants(self):
        """Replace arithmetic on two numeric literals with a ``mov`` of the value."""
        for instr in self.instructions:
            if instr.opcode not in self._FOLDABLE_OPCODES or len(instr.args) != 2:
                continue
            folded = self._fold_binary(instr.opcode, instr.args[0], instr.args[1])
            if folded is not None:
                instr.opcode = 'mov'
                instr.args = [folded]

    def _eliminate_dead_code(self):
        """Mark instructions that cannot influence an output variable as dead.

        Starts from the output variables in ``self.variables`` and repeatedly
        adds the operands of every instruction whose result is needed. Dead
        instruction indices are stored in ``self._dead``. Without any output
        variable nothing is marked, since liveness cannot be determined.
        """
        self._dead = set()
        used_vars = {name for name, var in self.variables.items() if var.is_output}
        if not used_vars:
            return
        changed = True
        while changed:
            changed = False
            for instr in self.instructions:
                if instr.result is None or self._base_name(instr.result) not in used_vars:
                    continue
                for arg in instr.args:
                    name = self._base_name(arg)
                    if name not in used_vars:
                        used_vars.add(name)
                        changed = True
        self._dead = {
            index for index, instr in enumerate(self.instructions)
            if instr.result is None or self._base_name(instr.result) not in used_vars
        }

    def _eliminate_common_subexpressions(self):
        """Turn recomputations of an available expression into a ``mov``.

        An expression stays available only until one of its operands or the
        variable holding its value is written again.
        """
        available = {}  # (opcode, args) -> name currently holding that value
        for instr in self.instructions:
            if instr.result is None:
                continue
            key = (instr.opcode, tuple(instr.args))
            reads = {self._base_name(arg) for arg in instr.args}
            holder = available.get(key) if instr.opcode != 'mov' else None
            if holder is not None and holder != instr.result:
                instr.opcode = 'mov'
                instr.args = [holder]
            written = self._base_name(instr.result)
            available = {
                known: name for known, name in available.items()
                if self._base_name(name) != written
                and all(self._base_name(arg) != written for arg in known[1])
            }
            # An instruction reading its own destination (a = a + b) does not
            # leave "a + b" available afterwards.
            if instr.opcode != 'mov' and written not in reads:
                available[key] = instr.result

    # ------------------------------------------------------------------
    # Persistence helpers
    # ------------------------------------------------------------------
    def _serialize_variables(self, variables: "Dict[str, Variable]") -> dict:
        """Return ``{name: dict}`` for storage; uniforms get ``qualifiers``."""
        serialized = {}
        for name, var in variables.items():
            entry = var.to_dict()
            if name in self._uniform_names:
                entry['qualifiers'] = ['uniform']
            serialized[name] = entry
        return serialized

    def _fetch_shader(self, shader_id: str) -> dict:
        """Load a compiled shader from the ``shaders`` table.

        Raises:
            ShaderError: If no row with ``shader_id`` exists.
        """
        row = self.db.conn.execute("""
            SELECT type, source, variables, instructions
            FROM shaders WHERE id = ?
        """, [shader_id]).fetchone()
        if row is None:
            raise ShaderError(f"Shader '{shader_id}' not found in database")
        return {
            "id": shader_id,
            "type": row[0],
            "source": row[1],
            "variables": json.loads(row[2]),
            "instructions": json.loads(row[3])
        }

    def _store_instructions(self, instructions):
        """Replace the stored instruction rows of the current shader.

        ``instructions`` is expected in ``self.instructions`` order; the
        ``is_dead`` flag of each row comes from ``self._dead``. Row ids are
        assigned explicitly (and written to ``Instruction.id``) because the
        ``instructions.id`` column has no default.
        """
        conn = self.db.conn
        # Read the maximum before deleting so new ids never reuse old ones.
        row = conn.execute("SELECT COALESCE(MAX(id), 0) FROM instructions").fetchone()
        next_id = int(row[0]) + 1
        conn.execute(
            "DELETE FROM instructions WHERE shader_id = ?", [self.current_shader_id]
        )
        for index, instr in enumerate(instructions):
            instr.id = next_id + index
            conn.execute("""
                INSERT INTO instructions (id, shader_id, opcode, args, result, is_dead)
                VALUES (?, ?, ?, ?, ?, ?)
            """, [
                instr.id,
                self.current_shader_id,
                instr.opcode,
                json.dumps(instr.args),
                instr.result,
                index in self._dead
            ])

    def _allocate_registers(self, variables: "Dict[str, Variable]"):
        """Allocate hardware registers to variables.

        Registers already stored for the current shader are reused. Input and
        output variables are served first, then the results of the live
        instructions; each name receives at most one register.
        """
        self.db.ensure_connection()
        rows = self.db.conn.execute("""
            SELECT var_name, register_name FROM registers
            WHERE shader_id = ?
        """, [self.current_shader_id]).fetchall()
        assigned = {row[0]: row[1] for row in rows}
        used_registers = set(assigned.values())
        fresh = []

        def register_for(name):
            # Track every allocation so a name that is both an I/O variable
            # and an instruction result is inserted only once.
            if name not in assigned:
                reg = self._find_free_register(used_registers)
                used_registers.add(reg)
                assigned[name] = reg
                fresh.append((self.current_shader_id, name, reg))
            return assigned[name]

        # Allocate input/output variables first
        for var in variables.values():
            if var.is_input or var.is_output:
                var.register = register_for(var.name)
        # Allocate temporaries
        for instr in self._live_instructions():
            if instr.result is not None:
                register_for(self._base_name(instr.result))

        for row in fresh:
            self.db.conn.execute("""
                INSERT INTO registers (shader_id, var_name, register_name)
                VALUES (?, ?, ?)
            """, row)
        self.db.conn.commit()

    def _find_free_register(self, used_registers: set) -> str:
        """Return the lowest-numbered ``r<i>`` that is not in ``used_registers``."""
        i = 0
        while f"r{i}" in used_registers:
            i += 1
        return f"r{i}"

    def _generate_shader_id(self, source: str) -> str:
        """Return ``shader_`` plus the first 8 hex digits of the source's MD5.

        The id depends on the source text only, not on the shader type.
        """
        return f"shader_{hashlib.md5(source.encode()).hexdigest()[:8]}"

    def _generate_program_id(self, vertex_shader: dict, fragment_shader: dict) -> str:
        """Return ``program_`` plus 8 hex digits of the MD5 of both shader ids."""
        combined = vertex_shader['id'] + fragment_shader['id']
        return f"program_{hashlib.md5(combined.encode()).hexdigest()[:8]}"

    # ------------------------------------------------------------------
    # Linking helpers
    # ------------------------------------------------------------------
    def _verify_interface_compatibility(self, vertex_shader: dict, fragment_shader: dict):
        """Verify that shader interfaces are compatible.

        Raises:
            ShaderError: If a fragment input has no vertex output of the same
                name, or the two differ in type.
        """
        vertex_outputs = {
            name: var for name, var in vertex_shader['variables'].items()
            if var['is_output']
        }
        for name, var in fragment_shader['variables'].items():
            if not var['is_input']:
                continue
            if name not in vertex_outputs:
                raise ShaderError(f"Fragment shader input '{name}' has no matching vertex output")
            if vertex_outputs[name]['type'] != var['type']:
                raise ShaderError(f"Type mismatch for varying '{name}'")

    def _collect_uniforms(self, vertex_shader: dict, fragment_shader: dict) -> dict:
        """Collect variables whose ``qualifiers`` contain ``uniform`` from both shaders."""
        uniforms = {}
        for shader in (vertex_shader, fragment_shader):
            for name, var in shader['variables'].items():
                if 'uniform' in (var.get('qualifiers') or []):
                    uniforms[name] = var
        return uniforms

    def _collect_attributes(self, vertex_shader: dict) -> dict:
        """Collect vertex attributes (the vertex shader's input variables)."""
        return {
            name: var for name, var in vertex_shader['variables'].items()
            if var['is_input']
        }

    def _collect_varyings(self, vertex_shader: dict, fragment_shader: dict) -> dict:
        """Collect vertex outputs that the fragment shader declares as inputs."""
        fragment_vars = fragment_shader['variables']
        return {
            name: var for name, var in vertex_shader['variables'].items()
            if var['is_output']
            and name in fragment_vars
            and fragment_vars[name].get('is_input')
        }