"""Shader compiler that lowers GLSL-like source into virtual GPU instructions.

Compiled shaders, linked programs, register assignments and the instruction
stream are persisted through a DuckDB connection (see ``ShaderCompilerDB``).
"""

import hashlib
import json
import logging
import re
from enum import Enum
from typing import Dict, List, Optional, Union

import duckdb

from config import get_db_url, get_hf_token_cached


class ShaderType(Enum):
    """Shader stages accepted by ``ShaderCompiler.compile_shader``."""

    VERTEX = "vertex"
    FRAGMENT = "fragment"
    COMPUTE = "compute"


class ShaderError(Exception):
    """Raised when a shader cannot be compiled or linked."""


class ShaderCompilerDB:
    """Owns the DuckDB connection and the tables used by the compiler."""

    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, db_url: Optional[str] = None):
        """Initialize shader compiler database.

        Args:
            db_url: Location of the database; defaults to ``DB_URL``.

        Raises:
            RuntimeError: If no connection could be established after
                ``max_retries`` attempts.
        """
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        # _init_db_connection reads self.HF_TOKEN, which was never defined.
        # A subclass-provided value is respected; otherwise fetch it.
        # NOTE(review): assumes get_hf_token_cached() takes no arguments and
        # returns the token string -- confirm against config.py.
        if getattr(self, "HF_TOKEN", None) is None:
            self.HF_TOKEN = get_hf_token_cached()
        self._connect_with_retries()

    def _connect_with_retries(self):
        """Establish database connection with retry logic.

        Raises:
            RuntimeError: When the last attempt fails; the underlying
                exception is chained as ``__cause__``.
        """
        for attempt in range(self.max_retries):
            try:
                self._init_db_connection()
                self._setup_database()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    ) from e
                logging.warning(f"Database connection attempt {attempt + 1} failed, retrying...")

    def _init_db_connection(self):
        """Initialize database connection with HuggingFace configuration."""
        # Convert HF URL to S3 path and connect directly.
        # NOTE(review): for the default DB_URL, split('/', 4) yields
        # ['hf:', '', 'datasets', 'Fred808', 'helium/storage.json'], so
        # `owner` is "datasets" and `dataset` is "Fred808". Confirm the
        # resulting S3 path is the intended one.
        _, _, owner, dataset, db_file = self.db_url.split('/', 4)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        # Connect directly to remote database.
        # NOTE(review): the connection is opened on an s3:// path before
        # httpfs is installed/loaded below, and the SET statement is
        # parameterised -- verify both against the DuckDB version in use.
        self.conn = duckdb.connect(db_path)
        self.conn.execute("""
            INSTALL httpfs;
            LOAD httpfs;
            SET s3_region='us-east-1';
            SET s3_endpoint='s3.us-east-1.amazonaws.com';
            SET s3_url_style='path';
            SET s3_access_key_id='none';
            SET s3_secret_access_key=?;
        """, [self.HF_TOKEN])

    def ensure_connection(self):
        """Ensure database connection is active and valid."""
        try:
            self.conn.execute("SELECT 1")
        except Exception:
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit
            # still propagate. A missing `conn` attribute also lands here.
            logging.warning("Database connection lost, attempting to reconnect...")
            self._connect_with_retries()

    def _setup_database(self):
        """Set up shader compiler tables."""
        self.ensure_connection()

        # Shader storage
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS shaders (
                id VARCHAR PRIMARY KEY,
                type VARCHAR,
                source TEXT,
                variables JSON,
                instructions JSON
            )
        """)

        # Program storage (linked shaders)
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS programs (
                id VARCHAR PRIMARY KEY,
                vertex_shader_id VARCHAR,
                fragment_shader_id VARCHAR,
                uniforms JSON,
                attributes JSON,
                varyings JSON,
                linked BOOLEAN
            )
        """)

        # Register allocation table
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS registers (
                shader_id VARCHAR,
                var_name VARCHAR,
                register_name VARCHAR,
                PRIMARY KEY (shader_id, var_name)
            )
        """)

        # Instructions table with optimization metadata
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS instructions (
                id INTEGER PRIMARY KEY,
                shader_id VARCHAR,
                opcode VARCHAR,
                args JSON,
                result VARCHAR,
                is_dead BOOLEAN DEFAULT FALSE,
                depends_on JSON
            )
        """)

        self.conn.commit()


class Instruction:
    """A single virtual GPU instruction: ``result = opcode args...``."""

    def __init__(self, opcode: str, args: List[str], result: Optional[str] = None):
        self.opcode = opcode
        self.args = args
        self.result = result
        self.id = None  # Will be set when stored in DB

    def __str__(self):
        result_str = f"{self.result} = " if self.result else ""
        return f"{result_str}{self.opcode} {', '.join(self.args)}"

    def to_dict(self):
        """Return a JSON-serialisable dict (the DB id is not included)."""
        return {
            "opcode": self.opcode,
            "args": self.args,
            "result": self.result
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild an instruction from the output of ``to_dict``."""
        return cls(data["opcode"], data["args"], data["result"])


class Variable:
    """A shader interface variable and its assigned register, if any."""

    def __init__(self, name: str, var_type: str, is_input: bool = False, is_output: bool = False):
        self.name = name
        self.type = var_type
        self.is_input = is_input
        self.is_output = is_output
        self.register = None  # Will be assigned during register allocation

    def to_dict(self):
        """Return a JSON-serialisable dict describing the variable."""
        return {
            "name": self.name,
            "type": self.type,
            "is_input": self.is_input,
            "is_output": self.is_output,
            "register": self.register
        }

    @classmethod
    def from_dict(cls, data):
        """Rebuild a variable from the output of ``to_dict``."""
        var = cls(data["name"], data["type"], data["is_input"], data["is_output"])
        var.register = data["register"]
        return var


class ShaderCompiler:
    """Compiles shader source into instructions and links shader pairs."""

    # Infix operators grouped by precedence, lowest-binding group first.
    _BINARY_PRECEDENCE = (('+', '-'), ('*', '/'))

    def __init__(self, db_url: Optional[str] = None):
        """Create a compiler backed by ``ShaderCompilerDB(db_url)``."""
        # Initialize database
        self.db = ShaderCompilerDB(db_url)
        self.temp_counter = 0

        # Initialize operation mappings
        self.vector_ops = {
            '+': 'add',
            '-': 'sub',
            '*': 'mul',
            '/': 'div',
            'dot': 'dot',
            'cross': 'cross',
            'normalize': 'normalize'
        }

        # Built-in functions and their implementations
        self.built_ins = {
            'texture2D': self._compile_texture2D,
            'normalize': self._compile_normalize,
            'dot': self._compile_dot,
            'mix': self._compile_mix,
            'clamp': self._compile_clamp
        }

        # Current compilation context. `instructions` and `variables` are
        # written by the parsing helpers; they were previously used without
        # ever being initialised.
        self.current_shader_id = None
        self.instructions: List[Instruction] = []
        self.variables: Dict[str, Variable] = {}

    def compile_shader(self, shader_source: str, shader_type: Union[str, ShaderType]) -> dict:
        """Compile a shader from source code into virtual GPU instructions.

        Args:
            shader_source: Shader source text containing a ``void main()``.
            shader_type: A ``ShaderType`` or its string value.

        Returns:
            Dict with ``id``, ``type``, ``source``, ``variables`` and
            ``instructions`` as read back from the ``shaders`` table.

        Raises:
            ValueError: If ``shader_type`` is a string that is not a valid
                ``ShaderType`` value.
            ShaderError: If any compilation or storage step fails.
        """
        if isinstance(shader_type, str):
            shader_type = ShaderType(shader_type)

        try:
            # Generate shader ID and set as current
            self.current_shader_id = self._generate_shader_id(shader_source)
            self.temp_counter = 0
            self.instructions = []
            self.variables = {}

            # Parse input/output variables
            variables = self._parse_interface_variables(shader_source, shader_type)

            # Parse and compile the main function
            instructions = self._parse_main_function(shader_source)

            # Store initial instructions in DB
            self._store_instructions(instructions)

            # Perform optimizations
            # NOTE(review): the passes only flag rows in the `instructions`
            # table; the list serialised below is the unoptimised one.
            self._optimize_instructions()

            # Allocate registers
            self._allocate_registers(variables)

            # Store shader in database. The id is a hash of the source, so
            # compiling the same source again must replace the row instead
            # of violating the primary key.
            self.db.conn.execute("""
                INSERT OR REPLACE INTO shaders (id, type, source, variables, instructions)
                VALUES (?, ?, ?, ?, ?)
            """, (
                self.current_shader_id,
                shader_type.value,
                shader_source,
                json.dumps({name: var.to_dict() for name, var in variables.items()}),
                json.dumps([instr.to_dict() for instr in instructions])
            ))

            # Fetch the complete compiled program
            result = self.db.conn.execute("""
                SELECT type, source, variables, instructions
                FROM shaders
                WHERE id = ?
            """, [self.current_shader_id]).fetchone()

            compiled_program = {
                "id": self.current_shader_id,
                "type": result[0],
                "source": result[1],
                "variables": json.loads(result[2]),
                "instructions": json.loads(result[3])
            }

            self.db.conn.commit()
            return compiled_program

        except Exception as e:
            raise ShaderError(f"Compilation failed: {str(e)}") from e

    def _parse_interface_variables(self, source: str, shader_type: ShaderType) -> Dict[str, Variable]:
        """Parse input and output variable declarations.

        Returns:
            Mapping of variable name to ``Variable``; also stored on
            ``self.variables``. Includes ``gl_Position`` for vertex shaders
            and ``gl_FragColor`` for fragment shaders.
        """
        # Match input/output variable declarations. The leading \b keeps the
        # keywords from matching as the tail of a longer identifier.
        input_pattern = r'\bin\s+(\w+)\s+(\w+)\s*;'
        output_pattern = r'\bout\s+(\w+)\s+(\w+)\s*;'

        variables: Dict[str, Variable] = {}

        for match in re.finditer(input_pattern, source):
            var_type, var_name = match.groups()
            variables[var_name] = Variable(var_name, var_type, is_input=True)

        for match in re.finditer(output_pattern, source):
            var_type, var_name = match.groups()
            variables[var_name] = Variable(var_name, var_type, is_output=True)

        # Add built-in variables based on shader type
        if shader_type == ShaderType.VERTEX:
            variables['gl_Position'] = Variable('gl_Position', 'vec4', is_output=True)
        elif shader_type == ShaderType.FRAGMENT:
            variables['gl_FragColor'] = Variable('gl_FragColor', 'vec4', is_output=True)

        self.variables = variables
        return variables

    def _parse_main_function(self, source: str) -> List[Instruction]:
        """Parse and compile the main function body.

        Returns:
            The compiled instruction list (also kept on ``self.instructions``).

        Raises:
            ShaderError: If no ``void main()`` is found or a statement is
                not supported.
        """
        # Extract main function body.
        # NOTE(review): the body is matched up to the first '}', so nested
        # blocks inside main() are truncated.
        main_pattern = r'void\s+main\s*\(\s*\)\s*{([^}]*)}'
        main_match = re.search(main_pattern, source)
        if not main_match:
            raise ShaderError("No main function found")

        self.instructions = []

        # Split into statements and compile each non-empty one
        for stmt in main_match.group(1).split(';'):
            stmt = stmt.strip()
            if stmt:
                self._compile_statement(stmt)

        return self.instructions

    def _compile_statement(self, statement: str):
        """Compile a single statement into instructions.

        Handles plain assignment, compound assignment (``+=`` etc.),
        declarations with an initialiser and bare function calls.

        Raises:
            ShaderError: For any other statement form.
        """
        # Find an '=' that is not part of ==, !=, <= or >=.
        assignment = re.search(r'(?<![=!<>])=(?!=)', statement)
        if assignment:
            target = statement[:assignment.start()].strip()
            expr = statement[assignment.end():].strip()

            # Compound assignment: `x op= e` means `x = x op (e)`.
            if target and target[-1] in '+-*/':
                op = target[-1]
                target = target[:-1].strip()
                expr = f"{target} {op} ({expr})"

            if not target or not expr:
                raise ShaderError(f"Unsupported statement: {statement}")

            # `vec4 color = ...` declares `color`; drop the type prefix so
            # later references to the variable match the stored result name.
            target = target.split()[-1]

            result = self._compile_expression(expr)
            self.instructions.append(Instruction('mov', [result], target))
            return

        # Handle function calls
        if '(' in statement:
            self._compile_expression(statement)
            return

        raise ShaderError(f"Unsupported statement: {statement}")

    def _compile_expression(self, expr: str) -> str:
        """Compile an expression and return the register/variable containing the result.

        Binary operators are resolved outside parentheses first (lowest
        precedence, rightmost occurrence), so nested calls and grouped
        sub-expressions keep all of their operands.

        Raises:
            ShaderError: On mismatched parentheses or bad builtin arity.
        """
        expr = expr.strip()

        # Drop parentheses that wrap the whole expression: "(a + b)" -> "a + b".
        while expr.startswith('(') and self._closing_paren_index(expr, 0) == len(expr) - 1:
            expr = expr[1:-1].strip()

        # Infix arithmetic at nesting depth zero.
        for level in self._BINARY_PRECEDENCE:
            symbols = [s for s in level if s in self.vector_ops]
            pos = self._find_binary_operator(expr, symbols)
            if pos == -1:
                continue
            left_reg = self._compile_expression(expr[:pos])
            right_reg = self._compile_expression(expr[pos + 1:])
            result = self._new_temp()
            self.instructions.append(Instruction(
                self.vector_ops[expr[pos]],
                [left_reg, right_reg],
                result
            ))
            return result

        if '(' in expr:
            # A builtin call, optionally followed by a swizzle such as ".rgb".
            call = re.match(r'(\w+)\s*\(', expr)
            if call and call.group(1) in self.built_ins:
                close = self._closing_paren_index(expr, call.end() - 1)
                suffix = expr[close + 1:].strip()
                if re.fullmatch(r'(?:\.\w+)*', suffix):
                    return self.built_ins[call.group(1)](expr[:close + 1]) + suffix

            # Fallback kept from the original implementation: first builtin
            # whose name appears anywhere, else the first parenthesised group.
            # NOTE(review): calls to non-builtin functions (e.g. vec4(...))
            # are not lowered to an instruction; only their argument text is
            # compiled. Confirm the intended handling.
            for builtin, compiler in self.built_ins.items():
                if builtin in expr:
                    return compiler(expr)
            return self._compile_expression(self._extract_parenthesized(expr))

        # Word operators written infix, e.g. "a dot b".
        for op, opcode in self.vector_ops.items():
            if not op.isalpha():
                continue
            match = re.search(rf'\b{re.escape(op)}\b', expr)
            if not match:
                continue
            left = expr[:match.start()].strip()
            right = expr[match.end():].strip()
            if not left or not right:
                continue
            left_reg = self._compile_expression(left)
            right_reg = self._compile_expression(right)
            result = self._new_temp()
            self.instructions.append(Instruction(opcode, [left_reg, right_reg], result))
            return result

        # Must be a variable or literal
        return expr

    def _find_binary_operator(self, expr: str, symbols) -> int:
        """Return the index of the rightmost binary operator at depth zero.

        An operator only counts as binary when something other than another
        operator precedes it, so unary minus ("-x", "a * -b") is skipped.
        Returns -1 when none of ``symbols`` qualifies.
        """
        depth = 0
        found = -1
        prev = ''  # last non-space character seen so far
        for i, ch in enumerate(expr):
            if ch in '([':
                depth += 1
            elif ch in ')]':
                depth -= 1
            elif depth == 0 and ch in symbols and prev and prev not in '+-*/,':
                found = i
            if not ch.isspace():
                prev = ch
        return found

    def _closing_paren_index(self, expr: str, open_index: int) -> int:
        """Return the index of the ')' matching the '(' at ``open_index``.

        Raises:
            ShaderError: If the parenthesis is never closed.
        """
        depth = 0
        for i in range(open_index, len(expr)):
            if expr[i] == '(':
                depth += 1
            elif expr[i] == ')':
                depth -= 1
                if depth == 0:
                    return i
        raise ShaderError("Mismatched parentheses")

    def _compile_texture2D(self, expr: str) -> str:
        """Compile a texture2D builtin function call."""
        args = self._extract_args(expr)
        if len(args) != 2:
            raise ShaderError("texture2D requires 2 arguments")

        sampler = self._compile_expression(args[0])
        coords = self._compile_expression(args[1])
        result = self._new_temp()

        self.instructions.append(Instruction(
            'texture2D',
            [sampler, coords],
            result
        ))
        return result

    def _next_instruction_id(self) -> int:
        """Return an unused id for the ``instructions`` table.

        The table declares ``id INTEGER PRIMARY KEY`` without a default, so
        ids have to be supplied explicitly on insert.
        """
        row = self.db.conn.execute(
            "SELECT COALESCE(MAX(id), 0) + 1 FROM instructions"
        ).fetchone()
        return row[0]

    def _store_instructions(self, instructions: List[Instruction]):
        """Persist ``instructions`` for the current shader.

        Rows left over from a previous compilation of the same shader id are
        removed first. Each instruction's ``id`` attribute is set to the id
        of its row.
        """
        self.db.ensure_connection()
        conn = self.db.conn

        # Read the id high-water mark before deleting so fresh ids never
        # collide with the rows being removed.
        next_id = self._next_instruction_id()
        conn.execute(
            "DELETE FROM instructions WHERE shader_id = ?",
            [self.current_shader_id]
        )

        for instr in instructions:
            instr.id = next_id
            next_id += 1
            conn.execute("""
                INSERT INTO instructions (id, shader_id, opcode, args, result)
                VALUES (?, ?, ?, ?, ?)
            """, [
                instr.id,
                self.current_shader_id,
                instr.opcode,
                json.dumps(instr.args),
                instr.result
            ])

    def _optimize_instructions(self):
        """Perform basic optimizations on the instruction stream."""
        # Constant folding
        self._fold_constants()

        # Dead code elimination
        self._eliminate_dead_code()

        # Common subexpression elimination
        self._eliminate_common_subexpressions()

        # Update optimized instructions in DB
        self.db.conn.commit()

    def _allocate_registers(self, variables: Dict[str, Variable]):
        """Allocate hardware registers to variables.

        Interface variables get registers first, then every distinct
        instruction result. Assignments already stored for this shader are
        reused. Sets ``register`` on each input/output ``Variable``.
        """
        used_registers = set()

        # Get existing register allocations
        result = self.db.conn.execute("""
            SELECT var_name, register_name
            FROM registers
            WHERE shader_id = ?
        """, [self.current_shader_id]).fetchall()

        existing_registers = {r[0]: r[1] for r in result}
        used_registers.update(existing_registers.values())

        # Allocate input/output variables first
        for var in variables.values():
            if var.is_input or var.is_output:
                if var.name not in existing_registers:
                    reg = self._find_free_register(used_registers)
                    var.register = reg
                    used_registers.add(reg)

                    # Store in DB
                    self.db.ensure_connection()
                    self.db.conn.execute("""
                        INSERT INTO registers (shader_id, var_name, register_name)
                        VALUES (?, ?, ?)
                    """, (self.current_shader_id, var.name, reg))
                else:
                    var.register = existing_registers[var.name]

        # Allocate temporaries
        self.db.ensure_connection()
        result = self.db.conn.execute("""
            SELECT DISTINCT result
            FROM instructions
            WHERE shader_id = ? AND result IS NOT NULL
        """, [self.current_shader_id]).fetchall()

        # Names given a register in the loop above are not in
        # `existing_registers`, so track them to avoid a duplicate insert
        # for results that are also interface variables (e.g. gl_Position).
        allocated = set(existing_registers)
        allocated.update(
            var.name for var in variables.values() if var.is_input or var.is_output
        )

        for (temp_var,) in result:
            if temp_var not in allocated:
                reg = self._find_free_register(used_registers)

                # Store in DB
                self.db.ensure_connection()
                self.db.conn.execute("""
                    INSERT INTO registers (shader_id, var_name, register_name)
                    VALUES (?, ?, ?)
                """, (self.current_shader_id, temp_var, reg))

                used_registers.add(reg)
                allocated.add(temp_var)

        self.db.conn.commit()

    def link_program(self, vertex_shader: dict, fragment_shader: dict) -> dict:
        """Link vertex and fragment shaders into a complete program.

        Args:
            vertex_shader: Result of ``compile_shader`` for a vertex shader.
            fragment_shader: Result of ``compile_shader`` for a fragment shader.

        Returns:
            Program dict with both shaders, ``uniforms``, ``attributes``,
            ``varyings`` and ``linked``.

        Raises:
            ShaderError: On wrong shader types, incompatible interfaces, or
                when either shader is missing from the ``shaders`` table.
        """
        # Verify shader types
        if vertex_shader['type'] != 'vertex' or fragment_shader['type'] != 'fragment':
            raise ShaderError("Invalid shader types for linking")

        # Check interface compatibility
        self._verify_interface_compatibility(vertex_shader, fragment_shader)

        # Generate program ID
        program_id = self._generate_program_id(vertex_shader, fragment_shader)

        # Create linked program in database. The id is derived from the two
        # shader ids, so relinking the same pair replaces the existing row.
        self.db.ensure_connection()
        self.db.conn.execute("""
            INSERT OR REPLACE INTO programs (
                id, vertex_shader_id, fragment_shader_id,
                uniforms, attributes, varyings, linked
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (
            program_id,
            vertex_shader['id'],
            fragment_shader['id'],
            json.dumps(self._collect_uniforms(vertex_shader, fragment_shader)),
            json.dumps(self._collect_attributes(vertex_shader)),
            json.dumps(self._collect_varyings(vertex_shader, fragment_shader)),
            True
        ))

        # Fetch complete program
        result = self.db.conn.execute("""
            SELECT p.*,
                   vs.source as vertex_source,
                   vs.instructions as vertex_instructions,
                   fs.source as fragment_source,
                   fs.instructions as fragment_instructions
            FROM programs p
            JOIN shaders vs ON p.vertex_shader_id = vs.id
            JOIN shaders fs ON p.fragment_shader_id = fs.id
            WHERE p.id = ?
        """, [program_id]).fetchone()

        # The inner joins drop the row when either shader was never stored.
        if result is None:
            raise ShaderError(
                "Linked program could not be loaded: both shaders must be stored before linking"
            )

        # `type` and `variables` are carried over from the inputs because
        # validate_program requires them on each shader entry.
        program = {
            "id": result[0],
            "vertex_shader": {
                "id": result[1],
                "type": vertex_shader['type'],
                "source": result[7],
                "instructions": json.loads(result[8]),
                "variables": vertex_shader['variables']
            },
            "fragment_shader": {
                "id": result[2],
                "type": fragment_shader['type'],
                "source": result[9],
                "instructions": json.loads(result[10]),
                "variables": fragment_shader['variables']
            },
            "uniforms": json.loads(result[3]),
            "attributes": json.loads(result[4]),
            "varyings": json.loads(result[5]),
            "linked": result[6]
        }

        self.db.conn.commit()
        return program

    def validate_program(self, program: dict) -> bool:
        """Validate a linked program.

        Returns:
            True when the program has both shaders with ``type``,
            ``instructions`` and ``variables``, and every fragment input is
            a vertex output; False otherwise, including on malformed input.
        """
        try:
            # Check required components
            if not all(k in program for k in ['vertex_shader', 'fragment_shader', 'linked']):
                return False

            # Verify shader validity
            for shader in [program['vertex_shader'], program['fragment_shader']]:
                if not all(k in shader for k in ['type', 'instructions', 'variables']):
                    return False

            # Check interface compatibility
            vertex_outputs = {
                name for name, var in program['vertex_shader']['variables'].items()
                if var['is_output']
            }
            fragment_inputs = {
                name for name, var in program['fragment_shader']['variables'].items()
                if var['is_input']
            }

            return fragment_inputs.issubset(vertex_outputs)

        except Exception:
            # Best-effort validator: malformed structures are simply invalid.
            return False

    def _new_temp(self) -> str:
        """Generate a new temporary variable name."""
        self.temp_counter += 1
        return f"temp_{self.temp_counter}"

    def _extract_parenthesized(self, expr: str) -> str:
        """Extract content between outermost parentheses.

        Raises:
            ValueError: If ``expr`` contains no '('.
            ShaderError: If the first '(' is never closed.
        """
        start = expr.index('(')
        return expr[start + 1:self._closing_paren_index(expr, start)]

    def _extract_args(self, expr: str) -> List[str]:
        """Extract function arguments.

        Commas inside nested parentheses or brackets do not split, so
        ``mix(a, dot(b, c), d)`` yields three arguments. An empty argument
        list yields ``[]``.
        """
        args_str = self._extract_parenthesized(expr)
        if not args_str.strip():
            return []

        args = []
        depth = 0
        start = 0
        for i, ch in enumerate(args_str):
            if ch in '([':
                depth += 1
            elif ch in ')]':
                depth -= 1
            elif ch == ',' and depth == 0:
                args.append(args_str[start:i].strip())
                start = i + 1
        args.append(args_str[start:].strip())
        return args

    def _generate_shader_id(self, source: str) -> str:
        """Generate a shader ID from the first 8 hex digits of the source's MD5."""
        return f"shader_{hashlib.md5(source.encode()).hexdigest()[:8]}"

    def _generate_program_id(self, vertex_shader: dict, fragment_shader: dict) -> str:
        """Generate a program ID from the MD5 of the two shader ids."""
        combined = vertex_shader['id'] + fragment_shader['id']
        return f"program_{hashlib.md5(combined.encode()).hexdigest()[:8]}"

    def _compile_normalize(self, expr: str) -> str:
        """Compile a normalize builtin function call."""
        args = self._extract_args(expr)
        if len(args) != 1:
            raise ShaderError("normalize requires 1 argument")

        vec = self._compile_expression(args[0])
        result = self._new_temp()
        self.instructions.append(Instruction('normalize', [vec], result))
        return result

    def _compile_dot(self, expr: str) -> str:
        """Compile a dot builtin function call."""
        args = self._extract_args(expr)
        if len(args) != 2:
            raise ShaderError("dot requires 2 arguments")

        vec1 = self._compile_expression(args[0])
        vec2 = self._compile_expression(args[1])
        result = self._new_temp()
        self.instructions.append(Instruction('dot', [vec1, vec2], result))
        return result

    def _compile_mix(self, expr: str) -> str:
        """Compile a mix builtin function call."""
        args = self._extract_args(expr)
        if len(args) != 3:
            raise ShaderError("mix requires 3 arguments")

        x = self._compile_expression(args[0])
        y = self._compile_expression(args[1])
        a = self._compile_expression(args[2])
        result = self._new_temp()
        self.instructions.append(Instruction('mix', [x, y, a], result))
        return result

    def _compile_clamp(self, expr: str) -> str:
        """Compile a clamp builtin function call."""
        args = self._extract_args(expr)
        if len(args) != 3:
            raise ShaderError("clamp requires 3 arguments")

        x = self._compile_expression(args[0])
        min_val = self._compile_expression(args[1])
        max_val = self._compile_expression(args[2])
        result = self._new_temp()
        self.instructions.append(Instruction('clamp', [x, min_val, max_val], result))
        return result

    def _fold_constants(self):
        """Flag arithmetic instructions whose first two args look constant.

        NOTE(review): this only sets ``is_dead``; no folded value is
        computed. ``LIKE`` does not interpret ``[0-9]+`` as a regular
        expression, and ``args`` is a JSON column, so ``args[0]`` may not
        address the first array element. Verify against DuckDB before
        relying on this pass.
        """
        self.db.conn.execute("""
            UPDATE instructions
            SET is_dead = TRUE
            WHERE shader_id = ?
            AND opcode IN ('add', 'sub', 'mul', 'div')
            AND args[0] LIKE '%[0-9]+%'
            AND args[1] LIKE '%[0-9]+%'
        """, [self.current_shader_id])

    def _eliminate_dead_code(self):
        """Flag instructions whose results never reach an output variable.

        NOTE(review): compile_shader runs this before the shader row and its
        register rows are inserted, so on a first compile the query below
        presumably finds no outputs and every instruction is then flagged.
        ``json_array_elements_text`` looks like a PostgreSQL function and
        ``IN ?`` is bound to a Python tuple; confirm both work on DuckDB.
        """
        # Get output variables
        outputs = self.db.conn.execute("""
            SELECT var_name
            FROM registers r
            JOIN shaders s ON r.shader_id = s.id
            JOIN json_each(s.variables) v ON v.key = r.var_name
            WHERE s.id = ? AND json_extract(v.value, '$.is_output') = true
        """, [self.current_shader_id]).fetchall()

        # Mark used variables recursively
        used_vars = set(r[0] for r in outputs)
        while True:
            new_used = self.db.conn.execute("""
                SELECT DISTINCT a.value::VARCHAR
                FROM instructions i,
                     json_array_elements_text(i.args) a
                WHERE i.shader_id = ?
                AND i.result IN ?
                AND NOT i.is_dead
                AND a.value NOT IN ?
            """, [self.current_shader_id, tuple(used_vars), tuple(used_vars)]).fetchall()

            if not new_used:
                break

            used_vars.update(r[0] for r in new_used)

        # Mark unused instructions as dead
        self.db.conn.execute("""
            UPDATE instructions
            SET is_dead = TRUE
            WHERE shader_id = ?
            AND (result IS NULL OR result NOT IN ?)
        """, [self.current_shader_id, tuple(used_vars)])

    def _eliminate_common_subexpressions(self):
        """Flag repeated (opcode, args) instructions and add moves for them.

        NOTE(review): ``array_agg`` has no ORDER BY, so the first element of
        ``duplicate_ids`` / ``results`` is assumed, not guaranteed, to belong
        to ``MIN(id)``; confirm or add an explicit ordering.
        """
        # Find duplicate expressions
        self.db.conn.execute("""
            WITH expr_groups AS (
                SELECT opcode, args, MIN(id) as first_id,
                       array_agg(id) as duplicate_ids
                FROM instructions
                WHERE shader_id = ? AND NOT is_dead
                GROUP BY opcode, args
                HAVING COUNT(*) > 1
            )
            UPDATE instructions i
            SET is_dead = TRUE
            WHERE id IN (
                SELECT unnest(duplicate_ids[2:])
                FROM expr_groups
            )
            AND shader_id = ?
        """, [self.current_shader_id, self.current_shader_id])

        # Add move instructions for the duplicates. The bare `result` column
        # was dropped from this SELECT: it was neither grouped nor
        # aggregated, and it was not used by the outer query.
        duplicates = self.db.conn.execute("""
            WITH expr_groups AS (
                SELECT opcode, args, MIN(id) as first_id,
                       array_agg(id) as duplicate_ids,
                       array_agg(result) as results
                FROM instructions
                WHERE shader_id = ?
                GROUP BY opcode, args
                HAVING COUNT(*) > 1
            )
            SELECT first_id, results
            FROM expr_groups
        """, [self.current_shader_id]).fetchall()

        for _first_id, results in duplicates:
            # The driver hands list columns back as Python lists; a JSON
            # string is still accepted for compatibility.
            if isinstance(results, str):
                results = json.loads(results)
            original_result = results[0]
            for result in results[1:]:
                # The primary key has no default, so supply the id explicitly.
                self.db.conn.execute("""
                    INSERT INTO instructions (id, shader_id, opcode, args, result)
                    VALUES (?, ?, 'mov', ?, ?)
                """, [
                    self._next_instruction_id(),
                    self.current_shader_id,
                    json.dumps([original_result]),
                    result
                ])

    def _find_free_register(self, used_registers: set) -> str:
        """Return the lowest-numbered ``r<i>`` not in ``used_registers``."""
        i = 0
        while f"r{i}" in used_registers:
            i += 1
        return f"r{i}"

    def _verify_interface_compatibility(self, vertex_shader: dict, fragment_shader: dict):
        """Verify that shader interfaces are compatible.

        Raises:
            ShaderError: If a fragment input has no vertex output of the
                same name, or their types differ.
        """
        vertex_outputs = {
            name: var for name, var in vertex_shader['variables'].items()
            if var['is_output']
        }
        fragment_inputs = {
            name: var for name, var in fragment_shader['variables'].items()
            if var['is_input']
        }

        # Check that all fragment inputs have matching vertex outputs
        for name, var in fragment_inputs.items():
            if name not in vertex_outputs:
                raise ShaderError(f"Fragment shader input '{name}' has no matching vertex output")
            if vertex_outputs[name]['type'] != var['type']:
                raise ShaderError(f"Type mismatch for varying '{name}'")

    def _collect_uniforms(self, vertex_shader: dict, fragment_shader: dict) -> dict:
        """Collect all uniform variables from both shaders.

        NOTE(review): ``Variable.to_dict`` never writes a ``qualifiers`` key
        and no ``uniform`` declarations are parsed in this file, so this
        returns an empty dict for shaders produced by ``compile_shader``.
        """
        uniforms = {}
        for shader in [vertex_shader, fragment_shader]:
            for name, var in shader['variables'].items():
                if 'uniform' in var.get('qualifiers', []):
                    uniforms[name] = var
        return uniforms

    def _collect_attributes(self, vertex_shader: dict) -> dict:
        """Collect vertex attributes (the vertex shader's inputs)."""
        return {
            name: var for name, var in vertex_shader['variables'].items()
            if var['is_input']
        }

    def _collect_varyings(self, vertex_shader: dict, fragment_shader: dict) -> dict:
        """Collect varying variables (vertex outputs / fragment inputs)."""
        return {
            name: var for name, var in vertex_shader['variables'].items()
            if var['is_output'] and name in fragment_shader['variables']
        }