|
|
from enum import Enum
|
|
|
import re
|
|
|
from typing import Dict, List, Optional, Union
|
|
|
import hashlib
|
|
|
import duckdb
|
|
|
import json
|
|
|
from config import get_db_url, get_hf_token_cached
|
|
|
import logging
|
|
|
|
|
|
class ShaderType(Enum):
    """Pipeline stage a shader targets.

    The string values match the strings accepted by
    ShaderCompiler.compile_shader and stored in the `shaders.type` column.
    """

    VERTEX = "vertex"

    FRAGMENT = "fragment"

    COMPUTE = "compute"
|
|
|
|
|
|
class ShaderError(Exception):
    """Raised when shader compilation, linking, or validation fails."""
|
|
|
|
|
|
class ShaderCompilerDB:
    """Persistence layer for compiled shaders, backed by DuckDB over httpfs.

    Maps an "hf://datasets/<owner>/<dataset>/<file>" URL to an S3-style path,
    connects with retries, and creates the shader/program/register/instruction
    tables on first use.
    """

    # Default location of the shader database inside a HuggingFace dataset.
    DB_URL = "hf://datasets/Fred808/helium/storage.json"

    def __init__(self, db_url: Optional[str] = None):
        """Initialize shader compiler database.

        Args:
            db_url: Optional hf:// URL of the database file; defaults to DB_URL.

        Raises:
            RuntimeError: if the connection cannot be established after retries.
        """
        self.db_url = db_url or self.DB_URL
        self.max_retries = 3
        self._connect_with_retries()

    def _connect_with_retries(self):
        """Establish database connection with retry logic."""
        for attempt in range(self.max_retries):
            try:
                self._init_db_connection()
                self._setup_database()
                return
            except Exception as e:
                if attempt == self.max_retries - 1:
                    # BUG FIX: chain the cause so the root failure survives.
                    raise RuntimeError(
                        f"Failed to initialize database after {self.max_retries} attempts: {str(e)}"
                    ) from e
                logging.warning(f"Database connection attempt {attempt + 1} failed, retrying...")

    def _init_db_connection(self):
        """Initialize database connection with HuggingFace configuration."""
        # BUG FIX: the previous split('/', 4) shifted every field by one
        # (owner became "datasets" and db_file became "helium/storage.json").
        # Parse "hf://datasets/<owner>/<dataset>/<file>" explicitly.
        path = self.db_url[len("hf://"):] if self.db_url.startswith("hf://") else self.db_url
        _, owner, dataset, db_file = path.split('/', 3)
        db_path = f"s3://datasets-cached/{owner}/{dataset}/{db_file}"

        self.conn = duckdb.connect(db_path)
        # BUG FIX: the original bound self.HF_TOKEN, an attribute that was
        # never defined anywhere; use the cached token from config instead.
        self.conn.execute("""
            INSTALL httpfs;
            LOAD httpfs;
            SET s3_region='us-east-1';
            SET s3_endpoint='s3.us-east-1.amazonaws.com';
            SET s3_url_style='path';
            SET s3_access_key_id='none';
            SET s3_secret_access_key=?;
        """, [get_hf_token_cached()])

    def ensure_connection(self):
        """Ensure database connection is active and valid."""
        try:
            self.conn.execute("SELECT 1")
        except Exception:
            # BUG FIX: narrowed from a bare `except:` which also swallowed
            # KeyboardInterrupt/SystemExit.
            logging.warning("Database connection lost, attempting to reconnect...")
            self._connect_with_retries()

    def _setup_database(self):
        """Set up shader compiler tables (idempotent)."""
        self.ensure_connection()

        # Compiled shaders keyed by content-hash id.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS shaders (
                id VARCHAR PRIMARY KEY,
                type VARCHAR,
                source TEXT,
                variables JSON,
                instructions JSON
            )
        """)

        # Linked vertex+fragment shader pairs.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS programs (
                id VARCHAR PRIMARY KEY,
                vertex_shader_id VARCHAR,
                fragment_shader_id VARCHAR,
                uniforms JSON,
                attributes JSON,
                varyings JSON,
                linked BOOLEAN
            )
        """)

        # Register allocation per shader variable.
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS registers (
                shader_id VARCHAR,
                var_name VARCHAR,
                register_name VARCHAR,
                PRIMARY KEY (shader_id, var_name)
            )
        """)

        # BUG FIX: callers insert into `instructions` without supplying an id,
        # which violated the bare INTEGER PRIMARY KEY; give the column a
        # sequence-backed default so such inserts succeed.
        self.conn.execute("CREATE SEQUENCE IF NOT EXISTS instruction_id_seq")
        self.conn.execute("""
            CREATE TABLE IF NOT EXISTS instructions (
                id INTEGER PRIMARY KEY DEFAULT nextval('instruction_id_seq'),
                shader_id VARCHAR,
                opcode VARCHAR,
                args JSON,
                result VARCHAR,
                is_dead BOOLEAN DEFAULT FALSE,
                depends_on JSON
            )
        """)

        self.conn.commit()
|
|
|
|
|
|
class Instruction:
    """A single virtual-GPU instruction: opcode, argument list, optional result."""

    def __init__(self, opcode: str, args: List[str], result: Optional[str] = None):
        self.opcode = opcode
        self.args = args
        self.result = result
        # Database row id; remains None until the instruction is persisted.
        self.id = None

    def __str__(self):
        if self.result:
            prefix = f"{self.result} = "
        else:
            prefix = ""
        operands = ', '.join(self.args)
        return f"{prefix}{self.opcode} {operands}"

    def to_dict(self):
        """Serialize to a JSON-compatible dict (id is intentionally omitted)."""
        return dict(opcode=self.opcode, args=self.args, result=self.result)

    @classmethod
    def from_dict(cls, data):
        """Rebuild an Instruction from to_dict() output."""
        return cls(data["opcode"], data["args"], data["result"])
|
|
|
|
|
|
class Variable:
    """A shader interface variable (input/output) and its allocated register."""

    def __init__(self, name: str, var_type: str, is_input: bool = False, is_output: bool = False):
        self.name = name
        self.type = var_type
        self.is_input = is_input
        self.is_output = is_output
        # Register stays None until _allocate_registers assigns one.
        self.register = None

    def to_dict(self):
        """Serialize to a JSON-compatible dict."""
        return dict(
            name=self.name,
            type=self.type,
            is_input=self.is_input,
            is_output=self.is_output,
            register=self.register,
        )

    @classmethod
    def from_dict(cls, data):
        """Rebuild a Variable (including its register) from to_dict() output."""
        instance = cls(data["name"], data["type"], data["is_input"], data["is_output"])
        instance.register = data["register"]
        return instance
|
|
|
|
|
|
class ShaderCompiler:
|
|
|
def __init__(self, db_url: Optional[str] = None):
|
|
|
|
|
|
self.db = ShaderCompilerDB(db_url)
|
|
|
self.temp_counter = 0
|
|
|
|
|
|
|
|
|
self.vector_ops = {
|
|
|
'+': 'add', '-': 'sub', '*': 'mul', '/': 'div',
|
|
|
'dot': 'dot', 'cross': 'cross', 'normalize': 'normalize'
|
|
|
}
|
|
|
|
|
|
|
|
|
self.built_ins = {
|
|
|
'texture2D': self._compile_texture2D,
|
|
|
'normalize': self._compile_normalize,
|
|
|
'dot': self._compile_dot,
|
|
|
'mix': self._compile_mix,
|
|
|
'clamp': self._compile_clamp
|
|
|
}
|
|
|
|
|
|
|
|
|
self.current_shader_id = None
|
|
|
|
|
|
def compile_shader(self, shader_source: str, shader_type: Union[str, ShaderType]) -> dict:
|
|
|
"""Compile a shader from source code into virtual GPU instructions."""
|
|
|
if isinstance(shader_type, str):
|
|
|
shader_type = ShaderType(shader_type)
|
|
|
|
|
|
try:
|
|
|
|
|
|
self.current_shader_id = self._generate_shader_id(shader_source)
|
|
|
self.temp_counter = 0
|
|
|
|
|
|
|
|
|
variables = self._parse_interface_variables(shader_source, shader_type)
|
|
|
|
|
|
|
|
|
instructions = self._parse_main_function(shader_source)
|
|
|
|
|
|
|
|
|
self._store_instructions(instructions)
|
|
|
|
|
|
|
|
|
self._optimize_instructions()
|
|
|
|
|
|
|
|
|
self._allocate_registers(variables)
|
|
|
|
|
|
|
|
|
self.db.conn.execute("""
|
|
|
INSERT INTO shaders (id, type, source, variables, instructions)
|
|
|
VALUES (?, ?, ?, ?, ?)
|
|
|
""", (
|
|
|
self.current_shader_id,
|
|
|
shader_type.value,
|
|
|
shader_source,
|
|
|
json.dumps({name: var.to_dict() for name, var in variables.items()}),
|
|
|
json.dumps([instr.to_dict() for instr in instructions])
|
|
|
))
|
|
|
|
|
|
|
|
|
result = self.db.conn.execute("""
|
|
|
SELECT type, source, variables, instructions
|
|
|
FROM shaders WHERE id = ?
|
|
|
""", [self.current_shader_id]).fetchone()
|
|
|
|
|
|
compiled_program = {
|
|
|
"id": self.current_shader_id,
|
|
|
"type": result[0],
|
|
|
"source": result[1],
|
|
|
"variables": json.loads(result[2]),
|
|
|
"instructions": json.loads(result[3])
|
|
|
}
|
|
|
|
|
|
self.db.conn.commit()
|
|
|
return compiled_program
|
|
|
|
|
|
except Exception as e:
|
|
|
raise ShaderError(f"Compilation failed: {str(e)}")
|
|
|
|
|
|
def _parse_interface_variables(self, source: str, shader_type: ShaderType):
|
|
|
"""Parse input and output variable declarations."""
|
|
|
|
|
|
input_pattern = r'in\s+(\w+)\s+(\w+)\s*;'
|
|
|
output_pattern = r'out\s+(\w+)\s+(\w+)\s*;'
|
|
|
|
|
|
for match in re.finditer(input_pattern, source):
|
|
|
var_type, var_name = match.groups()
|
|
|
self.variables[var_name] = Variable(var_name, var_type, is_input=True)
|
|
|
|
|
|
for match in re.finditer(output_pattern, source):
|
|
|
var_type, var_name = match.groups()
|
|
|
self.variables[var_name] = Variable(var_name, var_type, is_output=True)
|
|
|
|
|
|
|
|
|
if shader_type == ShaderType.VERTEX:
|
|
|
self.variables['gl_Position'] = Variable('gl_Position', 'vec4', is_output=True)
|
|
|
elif shader_type == ShaderType.FRAGMENT:
|
|
|
self.variables['gl_FragColor'] = Variable('gl_FragColor', 'vec4', is_output=True)
|
|
|
|
|
|
def _parse_main_function(self, source: str):
|
|
|
"""Parse and compile the main function body."""
|
|
|
|
|
|
main_pattern = r'void\s+main\s*\(\s*\)\s*{([^}]*)}'
|
|
|
main_match = re.search(main_pattern, source)
|
|
|
if not main_match:
|
|
|
raise ShaderError("No main function found")
|
|
|
|
|
|
main_body = main_match.group(1)
|
|
|
|
|
|
|
|
|
statements = [s.strip() for s in main_body.split(';') if s.strip()]
|
|
|
|
|
|
|
|
|
for stmt in statements:
|
|
|
self._compile_statement(stmt)
|
|
|
|
|
|
def _compile_statement(self, statement: str):
|
|
|
"""Compile a single statement into instructions."""
|
|
|
|
|
|
if '=' in statement:
|
|
|
target, expr = [s.strip() for s in statement.split('=')]
|
|
|
result = self._compile_expression(expr)
|
|
|
self.instructions.append(Instruction('mov', [result], target))
|
|
|
return
|
|
|
|
|
|
|
|
|
if '(' in statement:
|
|
|
self._compile_expression(statement)
|
|
|
return
|
|
|
|
|
|
raise ShaderError(f"Unsupported statement: {statement}")
|
|
|
|
|
|
    def _compile_expression(self, expr: str) -> str:
        """Compile an expression and return the register/variable containing the result.

        Handles, in order: built-in function calls, plain parenthesized
        sub-expressions, a single binary operator, and finally a bare operand
        (returned unchanged as a leaf).
        """
        if '(' in expr:
            # Built-in detection is a substring test, so e.g. "dot" inside a
            # longer identifier would also match. NOTE(review): confirm that
            # identifiers never contain built-in names.
            if any(builtin in expr for builtin in self.built_ins):
                for builtin, compiler in self.built_ins.items():
                    if builtin in expr:
                        return compiler(expr)

            # Plain parentheses: recursively compile the inner expression.
            inner = self._extract_parenthesized(expr)
            result = self._compile_expression(inner)
            return result

        # Binary operators, tried in dict insertion order — there is no
        # operator precedence here. NOTE(review): expr.split(op) yields more
        # than two pieces when the operator occurs twice (e.g. "a-b-c"),
        # which crashes the 2-way unpack; confirm inputs are single-operator.
        for op in self.vector_ops:
            if op in expr:
                left, right = [s.strip() for s in expr.split(op)]
                left_reg = self._compile_expression(left)
                right_reg = self._compile_expression(right)
                result = self._new_temp()
                self.instructions.append(Instruction(
                    self.vector_ops[op],
                    [left_reg, right_reg],
                    result
                ))
                return result

        # Leaf: a variable name or literal; returned as-is.
        return expr
|
|
|
|
|
|
def _compile_texture2D(self, expr: str) -> str:
|
|
|
"""Compile a texture2D builtin function call."""
|
|
|
args = self._extract_args(expr)
|
|
|
if len(args) != 2:
|
|
|
raise ShaderError("texture2D requires 2 arguments")
|
|
|
|
|
|
sampler = self._compile_expression(args[0])
|
|
|
coords = self._compile_expression(args[1])
|
|
|
result = self._new_temp()
|
|
|
|
|
|
self.instructions.append(Instruction(
|
|
|
'texture2D',
|
|
|
[sampler, coords],
|
|
|
result
|
|
|
))
|
|
|
return result
|
|
|
|
|
|
def _optimize_instructions(self):
|
|
|
"""Perform basic optimizations on the instruction stream."""
|
|
|
|
|
|
self._fold_constants()
|
|
|
|
|
|
|
|
|
self._eliminate_dead_code()
|
|
|
|
|
|
|
|
|
self._eliminate_common_subexpressions()
|
|
|
|
|
|
|
|
|
self.db.conn.commit()
|
|
|
|
|
|
    def _allocate_registers(self, variables: Dict[str, Variable]):
        """Allocate hardware registers to variables.

        Reuses any assignments already persisted for this shader (stable
        across reruns), then hands fresh registers ("r0", "r1", ...) first to
        unassigned interface variables and then to every distinct instruction
        result (temporaries).
        """
        used_registers = set()

        # Load previously allocated registers for this shader.
        result = self.db.conn.execute("""
            SELECT var_name, register_name FROM registers
            WHERE shader_id = ?
        """, [self.current_shader_id]).fetchall()

        existing_registers = {r[0]: r[1] for r in result}
        used_registers.update(existing_registers.values())

        # Interface (in/out) variables first.
        for var in variables.values():
            if var.is_input or var.is_output:
                if var.name not in existing_registers:
                    reg = self._find_free_register(used_registers)
                    var.register = reg
                    used_registers.add(reg)

                    self.db.ensure_connection()
                    self.db.conn.execute("""
                        INSERT INTO registers (shader_id, var_name, register_name)
                        VALUES (?, ?, ?)
                    """, (self.current_shader_id, var.name, reg))
                else:
                    # Keep the persisted assignment.
                    var.register = existing_registers[var.name]

        # Then each distinct instruction result (temporary variable).
        self.db.ensure_connection()
        result = self.db.conn.execute("""
            SELECT DISTINCT result FROM instructions
            WHERE shader_id = ? AND result IS NOT NULL
        """, [self.current_shader_id]).fetchall()

        for (temp_var,) in result:
            if temp_var not in existing_registers:
                reg = self._find_free_register(used_registers)

                self.db.ensure_connection()
                self.db.conn.execute("""
                    INSERT INTO registers (shader_id, var_name, register_name)
                    VALUES (?, ?, ?)
                """, (self.current_shader_id, temp_var, reg))
                used_registers.add(reg)

        self.db.conn.commit()
|
|
|
|
|
|
    def link_program(self, vertex_shader: dict, fragment_shader: dict) -> dict:
        """Link vertex and fragment shaders into a complete program.

        Args:
            vertex_shader: compiled-shader dict as returned by compile_shader.
            fragment_shader: compiled-shader dict as returned by compile_shader.

        Returns:
            A program dict combining both shaders plus uniform/attribute/
            varying tables, re-read from the database.

        Raises:
            ShaderError: on wrong shader types or incompatible interfaces.
        """
        if vertex_shader['type'] != 'vertex' or fragment_shader['type'] != 'fragment':
            raise ShaderError("Invalid shader types for linking")

        # Every fragment input must be produced (same name and type) by the
        # vertex stage; raises ShaderError otherwise.
        self._verify_interface_compatibility(vertex_shader, fragment_shader)

        # Program id is a content hash of both shader ids, so relinking the
        # same pair yields the same id.
        program_id = self._generate_program_id(vertex_shader, fragment_shader)

        # NOTE(review): relinking the same pair re-inserts the same primary
        # key — confirm duplicate links cannot happen, or use an upsert.
        self.db.ensure_connection()
        self.db.conn.execute("""
            INSERT INTO programs (
                id, vertex_shader_id, fragment_shader_id,
                uniforms, attributes, varyings, linked
            ) VALUES (?, ?, ?, ?, ?, ?, ?)
        """, (
            program_id,
            vertex_shader['id'],
            fragment_shader['id'],
            json.dumps(self._collect_uniforms(vertex_shader, fragment_shader)),
            json.dumps(self._collect_attributes(vertex_shader)),
            json.dumps(self._collect_varyings(vertex_shader, fragment_shader)),
            True
        ))

        # Positional indexes below assume p.* expands to the 7 columns of
        # `programs` in declaration order, followed by the 4 aliased shader
        # columns (indexes 7..10).
        result = self.db.conn.execute("""
            SELECT p.*,
                   vs.source as vertex_source,
                   vs.instructions as vertex_instructions,
                   fs.source as fragment_source,
                   fs.instructions as fragment_instructions
            FROM programs p
            JOIN shaders vs ON p.vertex_shader_id = vs.id
            JOIN shaders fs ON p.fragment_shader_id = fs.id
            WHERE p.id = ?
        """, [program_id]).fetchone()

        program = {
            "id": result[0],
            "vertex_shader": {
                "id": result[1],
                "source": result[7],
                "instructions": json.loads(result[8])
            },
            "fragment_shader": {
                "id": result[2],
                "source": result[9],
                "instructions": json.loads(result[10])
            },
            "uniforms": json.loads(result[3]),
            "attributes": json.loads(result[4]),
            "varyings": json.loads(result[5]),
            "linked": result[6]
        }

        self.db.conn.commit()
        return program
|
|
|
|
|
|
def validate_program(self, program: dict) -> bool:
|
|
|
"""Validate a linked program."""
|
|
|
try:
|
|
|
|
|
|
if not all(k in program for k in ['vertex_shader', 'fragment_shader', 'linked']):
|
|
|
return False
|
|
|
|
|
|
|
|
|
for shader in [program['vertex_shader'], program['fragment_shader']]:
|
|
|
if not all(k in shader for k in ['type', 'instructions', 'variables']):
|
|
|
return False
|
|
|
|
|
|
|
|
|
vertex_outputs = {
|
|
|
name for name, var in program['vertex_shader']['variables'].items()
|
|
|
if var['is_output']
|
|
|
}
|
|
|
fragment_inputs = {
|
|
|
name for name, var in program['fragment_shader']['variables'].items()
|
|
|
if var['is_input']
|
|
|
}
|
|
|
|
|
|
if not fragment_inputs.issubset(vertex_outputs):
|
|
|
return False
|
|
|
|
|
|
return True
|
|
|
|
|
|
except Exception:
|
|
|
return False
|
|
|
|
|
|
def _new_temp(self) -> str:
|
|
|
"""Generate a new temporary variable name."""
|
|
|
self.temp_counter += 1
|
|
|
return f"temp_{self.temp_counter}"
|
|
|
|
|
|
def _extract_parenthesized(self, expr: str) -> str:
|
|
|
"""Extract content between outermost parentheses."""
|
|
|
start = expr.index('(')
|
|
|
count = 1
|
|
|
for i, c in enumerate(expr[start + 1:], start + 1):
|
|
|
if c == '(':
|
|
|
count += 1
|
|
|
elif c == ')':
|
|
|
count -= 1
|
|
|
if count == 0:
|
|
|
return expr[start + 1:i]
|
|
|
raise ShaderError("Mismatched parentheses")
|
|
|
|
|
|
def _extract_args(self, expr: str) -> List[str]:
|
|
|
"""Extract function arguments."""
|
|
|
args_str = self._extract_parenthesized(expr)
|
|
|
return [arg.strip() for arg in args_str.split(',')]
|
|
|
|
|
|
def _generate_shader_id(self, source: str) -> str:
|
|
|
"""Generate a unique shader ID."""
|
|
|
return f"shader_{hashlib.md5(source.encode()).hexdigest()[:8]}"
|
|
|
|
|
|
def _generate_program_id(self, vertex_shader: dict, fragment_shader: dict) -> str:
|
|
|
"""Generate a unique program ID."""
|
|
|
combined = vertex_shader['id'] + fragment_shader['id']
|
|
|
return f"program_{hashlib.md5(combined.encode()).hexdigest()[:8]}"
|
|
|
|
|
|
def _compile_normalize(self, expr: str) -> str:
|
|
|
args = self._extract_args(expr)
|
|
|
if len(args) != 1:
|
|
|
raise ShaderError("normalize requires 1 argument")
|
|
|
vec = self._compile_expression(args[0])
|
|
|
result = self._new_temp()
|
|
|
self.instructions.append(Instruction('normalize', [vec], result))
|
|
|
return result
|
|
|
|
|
|
def _compile_dot(self, expr: str) -> str:
|
|
|
args = self._extract_args(expr)
|
|
|
if len(args) != 2:
|
|
|
raise ShaderError("dot requires 2 arguments")
|
|
|
vec1 = self._compile_expression(args[0])
|
|
|
vec2 = self._compile_expression(args[1])
|
|
|
result = self._new_temp()
|
|
|
self.instructions.append(Instruction('dot', [vec1, vec2], result))
|
|
|
return result
|
|
|
|
|
|
def _compile_mix(self, expr: str) -> str:
|
|
|
args = self._extract_args(expr)
|
|
|
if len(args) != 3:
|
|
|
raise ShaderError("mix requires 3 arguments")
|
|
|
x = self._compile_expression(args[0])
|
|
|
y = self._compile_expression(args[1])
|
|
|
a = self._compile_expression(args[2])
|
|
|
result = self._new_temp()
|
|
|
self.instructions.append(Instruction('mix', [x, y, a], result))
|
|
|
return result
|
|
|
|
|
|
def _compile_clamp(self, expr: str) -> str:
|
|
|
args = self._extract_args(expr)
|
|
|
if len(args) != 3:
|
|
|
raise ShaderError("clamp requires 3 arguments")
|
|
|
x = self._compile_expression(args[0])
|
|
|
min_val = self._compile_expression(args[1])
|
|
|
max_val = self._compile_expression(args[2])
|
|
|
result = self._new_temp()
|
|
|
self.instructions.append(Instruction('clamp', [x, min_val, max_val], result))
|
|
|
return result
|
|
|
|
|
|
    def _fold_constants(self):
        """Perform constant folding optimization.

        NOTE(review): SQL LIKE has no character-class syntax, so the pattern
        '%[0-9]+%' matches the *literal* text "[0-9]+", not digits — this
        predicate almost certainly never matches real operands (a regex
        function such as regexp_matches would be needed). Also note the pass
        only marks instructions dead; it never substitutes the folded
        constant into dependents. Both points need confirming against the
        intended optimizer design before changing the SQL.
        """
        self.db.conn.execute("""
            UPDATE instructions
            SET is_dead = TRUE
            WHERE shader_id = ? AND opcode IN ('add', 'sub', 'mul', 'div')
            AND args[0] LIKE '%[0-9]+%' AND args[1] LIKE '%[0-9]+%'
        """, [self.current_shader_id])
|
|
|
|
|
|
    def _eliminate_dead_code(self):
        """Perform dead code elimination.

        Seeds liveness from the shader's output variables, transitively marks
        every argument of a live instruction as live, then flags instructions
        with non-live results as dead.

        NOTE(review): several constructs look incompatible with DuckDB and
        its parameter binding — confirm against a live database:
        * json_array_elements_text is a PostgreSQL function (DuckDB would use
          json_each / json_extract);
        * "IN ?" bound to a Python tuple is not a supported placeholder form,
          and an empty used_vars tuple would produce invalid SQL.
        """
        # Liveness roots: variables flagged is_output in the shader metadata.
        outputs = self.db.conn.execute("""
            SELECT var_name FROM registers r
            JOIN shaders s ON r.shader_id = s.id
            JOIN json_each(s.variables) v ON v.key = r.var_name
            WHERE s.id = ? AND json_extract(v.value, '$.is_output') = true
        """, [self.current_shader_id]).fetchall()

        # Fixed-point iteration: pull in args of instructions whose result is live.
        used_vars = set(r[0] for r in outputs)
        while True:
            new_used = self.db.conn.execute("""
                SELECT DISTINCT a.value::VARCHAR
                FROM instructions i,
                json_array_elements_text(i.args) a
                WHERE i.shader_id = ?
                AND i.result IN ?
                AND NOT i.is_dead
                AND a.value NOT IN ?
            """, [self.current_shader_id, tuple(used_vars), tuple(used_vars)]).fetchall()

            if not new_used:
                break
            used_vars.update(r[0] for r in new_used)

        # Kill everything whose result is not live. NOTE(review): this also
        # kills result-less instructions, i.e. bare side-effect calls.
        self.db.conn.execute("""
            UPDATE instructions
            SET is_dead = TRUE
            WHERE shader_id = ?
            AND (result IS NULL OR result NOT IN ?)
        """, [self.current_shader_id, tuple(used_vars)])
|
|
|
|
|
|
    def _eliminate_common_subexpressions(self):
        """Perform common subexpression elimination.

        First marks all but the first instruction of each identical
        (opcode, args) group as dead, then rewires the duplicates' results to
        the surviving result via inserted 'mov' instructions.

        NOTE(review): points to confirm against a live database:
        * duplicate_ids[2:] relies on 1-based array slicing semantics;
        * the second query groups WITHOUT the is_dead filter used by the
          first, so the two passes may disagree on group membership;
        * the 'mov' INSERT omits `id`, the table's INTEGER PRIMARY KEY —
          this fails unless the column has a default.
        """
        # Pass 1: keep MIN(id) per duplicate group, mark the rest dead.
        self.db.conn.execute("""
            WITH expr_groups AS (
                SELECT opcode, args, MIN(id) as first_id,
                array_agg(id) as duplicate_ids
                FROM instructions
                WHERE shader_id = ? AND NOT is_dead
                GROUP BY opcode, args
                HAVING COUNT(*) > 1
            )
            UPDATE instructions i
            SET is_dead = TRUE
            WHERE id IN (
                SELECT unnest(duplicate_ids[2:])
                FROM expr_groups
            )
            AND shader_id = ?
        """, [self.current_shader_id, self.current_shader_id])

        # Pass 2: fetch each group's surviving id and all result names.
        duplicates = self.db.conn.execute("""
            WITH expr_groups AS (
                SELECT opcode, args, result,
                MIN(id) as first_id,
                array_agg(id) as duplicate_ids,
                array_agg(result) as results
                FROM instructions
                WHERE shader_id = ?
                GROUP BY opcode, args
                HAVING COUNT(*) > 1
            )
            SELECT first_id, results
            FROM expr_groups
        """, [self.current_shader_id]).fetchall()

        # Replace each eliminated duplicate with "dup_result = mov original".
        for first_id, results in duplicates:
            results = json.loads(results)
            original_result = results[0]
            for result in results[1:]:
                self.db.conn.execute("""
                    INSERT INTO instructions (shader_id, opcode, args, result)
                    VALUES (?, 'mov', ?, ?)
                """, [self.current_shader_id, json.dumps([original_result]), result])
|
|
|
|
|
|
def _find_free_register(self, used_registers: set) -> str:
|
|
|
"""Find an unused register."""
|
|
|
i = 0
|
|
|
while f"r{i}" in used_registers:
|
|
|
i += 1
|
|
|
return f"r{i}"
|
|
|
|
|
|
def _verify_interface_compatibility(self, vertex_shader: dict, fragment_shader: dict):
|
|
|
"""Verify that shader interfaces are compatible."""
|
|
|
vertex_outputs = {
|
|
|
name: var for name, var in vertex_shader['variables'].items()
|
|
|
if var['is_output']
|
|
|
}
|
|
|
fragment_inputs = {
|
|
|
name: var for name, var in fragment_shader['variables'].items()
|
|
|
if var['is_input']
|
|
|
}
|
|
|
|
|
|
|
|
|
for name, var in fragment_inputs.items():
|
|
|
if name not in vertex_outputs:
|
|
|
raise ShaderError(f"Fragment shader input '{name}' has no matching vertex output")
|
|
|
if vertex_outputs[name]['type'] != var['type']:
|
|
|
raise ShaderError(f"Type mismatch for varying '{name}'")
|
|
|
|
|
|
def _collect_uniforms(self, vertex_shader: dict, fragment_shader: dict) -> dict:
|
|
|
"""Collect all uniform variables from both shaders."""
|
|
|
uniforms = {}
|
|
|
for shader in [vertex_shader, fragment_shader]:
|
|
|
for name, var in shader['variables'].items():
|
|
|
if 'uniform' in var.get('qualifiers', []):
|
|
|
uniforms[name] = var
|
|
|
return uniforms
|
|
|
|
|
|
def _collect_attributes(self, vertex_shader: dict) -> dict:
|
|
|
"""Collect vertex attributes."""
|
|
|
return {
|
|
|
name: var for name, var in vertex_shader['variables'].items()
|
|
|
if var['is_input']
|
|
|
}
|
|
|
|
|
|
def _collect_varyings(self, vertex_shader: dict, fragment_shader: dict) -> dict:
|
|
|
"""Collect varying variables (vertex outputs / fragment inputs)."""
|
|
|
return {
|
|
|
name: var for name, var in vertex_shader['variables'].items()
|
|
|
if var['is_output'] and name in fragment_shader['variables']
|
|
|
}
|
|
|
|
|
|
|