"""
Modern graphics pipeline state manager for Virtual GPU
"""

import copy
import hashlib
import json
from typing import Dict, List, Optional, Union

import numpy as np

from .pipeline_db import PipelineStateDB
from .graphics_types import (
    ShaderType, PrimitiveType, GraphicsPipelineState,
    Viewport, Scissor, RasterizationState, DepthState,
    StencilState, BlendState, ColorMask, VertexAttribute,
    ShaderResource
)


class GraphicsPipelineManager:
    """Track the current graphics pipeline state and persist pipelines.

    Pipelines are stored in a ``PipelineStateDB`` keyed by a SHA-256 hash of
    their serialized state; ``bind_pipeline`` reloads a stored pipeline into
    ``current_state``.
    """

    def __init__(self, driver, db_path: Optional[str] = None):
        """Create a manager.

        Args:
            driver: Driver object, stored as ``self.driver``.
            db_path: Path handed to ``PipelineStateDB``. When falsy,
                ``":memory:"`` is used instead.
        """
        self.driver = driver
        self.current_state = GraphicsPipelineState()
        # Use persistent DB when a path is given, otherwise in-memory
        self.db = PipelineStateDB(db_path or ":memory:")

    @staticmethod
    def _state_to_dict(state: GraphicsPipelineState) -> dict:
        """Serialize ``state`` into a plain dict.

        This single function produces both the DB payload and the hash input,
        so the stored data and its cache key cannot drift apart.
        """
        return {
            "shaders": {k.value: v for k, v in state.shader_stages.items()},
            "vertex_attributes": [attr._asdict() for attr in state.vertex_attributes],
            "shader_resources": [res._asdict() for res in state.shader_resources],
            "viewport": state.viewport._asdict() if state.viewport else None,
            "scissor": state.scissor._asdict() if state.scissor else None,
            "rasterization": state.rasterization._asdict(),
            "depth": state.depth._asdict(),
            "stencil": state.stencil._asdict(),
            "blend": state.blend._asdict(),
            "color_mask": state.color_mask._asdict(),
            "primitive_type": state.primitive_type.value,
            "patch_control_points": state.patch_control_points,
        }

    @staticmethod
    def _state_from_dict(state_dict: dict) -> GraphicsPipelineState:
        """Rebuild a ``GraphicsPipelineState`` from ``_state_to_dict`` output."""
        state = GraphicsPipelineState()
        # NOTE(review): if PipelineStateDB round-trips through JSON and
        # ShaderType values are not strings, these keys come back as str and
        # ShaderType(k) would fail — confirm against the DB implementation.
        state.shader_stages = {
            ShaderType(k): v for k, v in state_dict['shaders'].items()
        }
        state.vertex_attributes = [
            VertexAttribute(**attr) for attr in state_dict['vertex_attributes']
        ]
        state.shader_resources = [
            ShaderResource(**res) for res in state_dict['shader_resources']
        ]
        state.viewport = (
            Viewport(**state_dict['viewport']) if state_dict['viewport'] else None
        )
        state.scissor = (
            Scissor(**state_dict['scissor']) if state_dict['scissor'] else None
        )
        state.rasterization = RasterizationState(**state_dict['rasterization'])
        state.depth = DepthState(**state_dict['depth'])
        state.stencil = StencilState(**state_dict['stencil'])
        state.blend = BlendState(**state_dict['blend'])
        state.color_mask = ColorMask(**state_dict['color_mask'])
        state.primitive_type = PrimitiveType(state_dict['primitive_type'])
        state.patch_control_points = state_dict['patch_control_points']
        return state

    def create_pipeline(self, state: GraphicsPipelineState) -> str:
        """Create a new pipeline with given state.

        Serializes ``state``, stores it in the DB under its content hash and
        returns that hash for later use with ``bind_pipeline``.
        """
        # Hash pipeline state for caching
        state_hash = self._hash_pipeline_state(state)
        state_dict = self._state_to_dict(state)
        self.db.store_pipeline(state_hash, state_dict)
        return state_hash

    def bind_pipeline(self, pipeline_hash: str):
        """Bind pipeline for rendering.

        Loads the pipeline stored under ``pipeline_hash`` and makes it the
        current state.

        Raises:
            ValueError: If the DB returns no (or an empty) entry for the hash.
        """
        state_dict = self.db.get_pipeline(pipeline_hash)
        if not state_dict:
            raise ValueError(f"Invalid pipeline hash: {pipeline_hash}")
        self.current_state = self._state_from_dict(state_dict)

    def set_viewport(self, viewport: Viewport):
        """Set viewport state"""
        self.current_state.viewport = viewport

    def set_scissor(self, scissor: Scissor):
        """Set scissor state"""
        self.current_state.scissor = scissor

    def set_vertex_attributes(self, attributes: List[VertexAttribute]):
        """Set vertex input attributes"""
        self.current_state.vertex_attributes = attributes

    def set_shader_resources(self, resources: List[ShaderResource]):
        """Set shader resource bindings"""
        self.current_state.shader_resources = resources

    def set_rasterization_state(self, state: RasterizationState):
        """Set rasterization state"""
        self.current_state.rasterization = state

    def set_depth_state(self, state: DepthState):
        """Set depth state"""
        self.current_state.depth = state

    def set_stencil_state(self, state: StencilState):
        """Set stencil state"""
        self.current_state.stencil = state

    def set_blend_state(self, state: BlendState):
        """Set blend state"""
        self.current_state.blend = state

    def set_color_mask(self, mask: ColorMask):
        """Set color write mask"""
        self.current_state.color_mask = mask

    def _hash_pipeline_state(self, state: GraphicsPipelineState) -> str:
        """Create unique hash for pipeline state.

        Returns the hex SHA-256 digest of the key-sorted JSON encoding of
        ``_state_to_dict(state)``; sorting makes the hash independent of
        dict insertion order.
        """
        state_str = json.dumps(self._state_to_dict(state), sort_keys=True)
        return hashlib.sha256(state_str.encode()).hexdigest()

    def get_current_state(self) -> GraphicsPipelineState:
        """Get copy of current pipeline state.

        Returns a deep copy so callers cannot mutate the bound state by
        accident; use the ``set_*`` methods to change it.
        """
        return copy.deepcopy(self.current_state)

    def validate_state(self) -> List[str]:
        """Validate current pipeline state.

        Returns:
            A list of human-readable error messages; empty when no problems
            were found.
        """
        errors = []

        # Check required shaders
        if ShaderType.VERTEX not in self.current_state.shader_stages:
            errors.append("Missing vertex shader")

        # Check vertex attributes match shader inputs
        # TODO: Add validation against shader reflection data

        # Check primitive type compatibility: patch primitives need a
        # tessellation control stage.
        if (self.current_state.primitive_type == PrimitiveType.PATCHES and
                ShaderType.TESSELLATION_CONTROL not in self.current_state.shader_stages):
            errors.append("Tessellation shaders required for patch primitives")

        return errors