|
|
"""
|
|
|
Modern graphics pipeline state manager for Virtual GPU
|
|
|
"""
|
|
|
from typing import Dict, List, Optional, Union
|
|
|
import numpy as np
|
|
|
|
|
|
from .pipeline_db import PipelineStateDB
|
|
|
from .graphics_types import (
|
|
|
ShaderType, PrimitiveType, GraphicsPipelineState,
|
|
|
Viewport, Scissor, RasterizationState, DepthState,
|
|
|
StencilState, BlendState, ColorMask, VertexAttribute,
|
|
|
ShaderResource
|
|
|
)
|
|
|
|
|
|
class GraphicsPipelineManager:
    """Track the active graphics pipeline state and persist pipeline objects.

    Pipelines are content-addressed: ``create_pipeline`` converts a
    ``GraphicsPipelineState`` into a plain dict, stores it in a
    ``PipelineStateDB`` under the SHA-256 digest of its sorted-key JSON
    encoding, and returns that digest. ``bind_pipeline`` loads a stored dict,
    rebuilds the state object from it and makes it the current state. The
    ``set_*`` methods mutate the current state in place.
    """

    def __init__(self, driver, db_path: Optional[str] = None):
        """Create a manager.

        Args:
            driver: Backend driver object, stored as ``self.driver`` (not
                referenced by the methods of this class).
            db_path: Location handed to ``PipelineStateDB``. When omitted or
                empty, the string ``":memory:"`` is passed instead
                (presumably an in-memory database -- confirm against
                ``PipelineStateDB``).
        """
        self.driver = driver
        self.current_state = GraphicsPipelineState()
        self.db = PipelineStateDB(db_path or ":memory:")

    def create_pipeline(self, state: GraphicsPipelineState) -> str:
        """Create a new pipeline with given state.

        Serializes *state*, stores the resulting dict in the pipeline
        database keyed by its content hash, and returns that hash. Equal
        states always produce the same hash.

        Returns:
            Hex SHA-256 digest identifying the pipeline; pass it to
            ``bind_pipeline``.
        """
        state_dict = self._serialize_state(state)
        # Hash the dict we already built instead of serializing the state a
        # second time; the digest is identical to _hash_pipeline_state(state).
        state_hash = self._hash_state_dict(state_dict)
        self.db.store_pipeline(state_hash, state_dict)
        return state_hash

    def bind_pipeline(self, pipeline_hash: str) -> None:
        """Bind pipeline for rendering.

        Loads the dict stored under *pipeline_hash*, rebuilds a
        ``GraphicsPipelineState`` from it and makes that the current state.
        The current state is only replaced once reconstruction succeeds.

        Raises:
            ValueError: if the database returns nothing (or an empty entry)
                for *pipeline_hash*.
        """
        state_dict = self.db.get_pipeline(pipeline_hash)
        if not state_dict:
            raise ValueError(f"Invalid pipeline hash: {pipeline_hash}")
        self.current_state = self._deserialize_state(state_dict)

    def set_viewport(self, viewport: Viewport) -> None:
        """Set viewport state"""
        self.current_state.viewport = viewport

    def set_scissor(self, scissor: Scissor) -> None:
        """Set scissor state"""
        self.current_state.scissor = scissor

    def set_vertex_attributes(self, attributes: List[VertexAttribute]) -> None:
        """Set vertex input attributes (the list is stored, not copied)."""
        self.current_state.vertex_attributes = attributes

    def set_shader_resources(self, resources: List[ShaderResource]) -> None:
        """Set shader resource bindings (the list is stored, not copied)."""
        self.current_state.shader_resources = resources

    def set_rasterization_state(self, state: RasterizationState) -> None:
        """Set rasterization state"""
        self.current_state.rasterization = state

    def set_depth_state(self, state: DepthState) -> None:
        """Set depth state"""
        self.current_state.depth = state

    def set_stencil_state(self, state: StencilState) -> None:
        """Set stencil state"""
        self.current_state.stencil = state

    def set_blend_state(self, state: BlendState) -> None:
        """Set blend state"""
        self.current_state.blend = state

    def set_color_mask(self, mask: ColorMask) -> None:
        """Set color write mask"""
        self.current_state.color_mask = mask

    @staticmethod
    def _serialize_state(state: GraphicsPipelineState) -> dict:
        """Convert *state* into the plain dict that is hashed and stored.

        Enum members are replaced by their ``.value``, sub-state objects by
        the dict returned from their ``_asdict()`` method, and an unset
        (falsy) viewport/scissor by ``None``.
        """
        return {
            "shaders": {k.value: v for k, v in state.shader_stages.items()},
            "vertex_attributes": [attr._asdict() for attr in state.vertex_attributes],
            "shader_resources": [res._asdict() for res in state.shader_resources],
            "viewport": state.viewport._asdict() if state.viewport else None,
            "scissor": state.scissor._asdict() if state.scissor else None,
            "rasterization": state.rasterization._asdict(),
            "depth": state.depth._asdict(),
            "stencil": state.stencil._asdict(),
            "blend": state.blend._asdict(),
            "color_mask": state.color_mask._asdict(),
            "primitive_type": state.primitive_type.value,
            "patch_control_points": state.patch_control_points,
        }

    @staticmethod
    def _deserialize_state(state_dict: dict) -> GraphicsPipelineState:
        """Rebuild a ``GraphicsPipelineState`` from a ``_serialize_state`` dict."""
        state = GraphicsPipelineState()
        state.shader_stages = {ShaderType(k): v for k, v in state_dict['shaders'].items()}
        state.vertex_attributes = [VertexAttribute(**attr) for attr in state_dict['vertex_attributes']]
        state.shader_resources = [ShaderResource(**res) for res in state_dict['shader_resources']]
        state.viewport = Viewport(**state_dict['viewport']) if state_dict['viewport'] else None
        state.scissor = Scissor(**state_dict['scissor']) if state_dict['scissor'] else None
        state.rasterization = RasterizationState(**state_dict['rasterization'])
        state.depth = DepthState(**state_dict['depth'])
        state.stencil = StencilState(**state_dict['stencil'])
        state.blend = BlendState(**state_dict['blend'])
        state.color_mask = ColorMask(**state_dict['color_mask'])
        state.primitive_type = PrimitiveType(state_dict['primitive_type'])
        state.patch_control_points = state_dict['patch_control_points']
        return state

    @staticmethod
    def _hash_state_dict(state_dict: dict) -> str:
        """Return the hex SHA-256 of *state_dict*'s sorted-key JSON encoding.

        Sorting keys makes the digest independent of dict insertion order.
        Raises ``TypeError`` (from ``json.dumps``) if any value is not
        JSON-serializable.
        """
        import hashlib
        import json

        state_str = json.dumps(state_dict, sort_keys=True)
        return hashlib.sha256(state_str.encode()).hexdigest()

    def _hash_pipeline_state(self, state: GraphicsPipelineState) -> str:
        """Create unique hash for pipeline state.

        Equivalent to hashing the dict produced by ``_serialize_state``.
        """
        return self._hash_state_dict(self._serialize_state(state))

    def get_current_state(self) -> GraphicsPipelineState:
        """Get copy of current pipeline state.

        Returns a deep copy, so mutating the result (including its lists and
        shader dict) does not affect the manager's current state. Use the
        ``set_*`` methods or ``bind_pipeline`` to change the bound state.
        """
        import copy

        return copy.deepcopy(self.current_state)

    def validate_state(self) -> List[str]:
        """Validate current pipeline state.

        Returns:
            A list of human-readable error messages; empty when no problem
            was detected.
        """
        errors = []

        if ShaderType.VERTEX not in self.current_state.shader_stages:
            errors.append("Missing vertex shader")

        # NOTE(review): only the tessellation *control* stage is checked here,
        # although the message mentions tessellation shaders in general --
        # confirm whether an evaluation-stage check is also intended.
        if (self.current_state.primitive_type == PrimitiveType.PATCHES and
                ShaderType.TESSELLATION_CONTROL not in self.current_state.shader_stages):
            errors.append("Tessellation shaders required for patch primitives")

        return errors
|
|
|
|