# Source: INV/virtual_gpu_driver/src/graphics/pipeline_manager.py
# (uploaded by Fred808 in "Upload 256 files", commit 7a0c684)
"""
Modern graphics pipeline state manager for Virtual GPU
"""
from typing import Dict, List, Optional, Union
import numpy as np
from .pipeline_db import PipelineStateDB
from .graphics_types import (
ShaderType, PrimitiveType, GraphicsPipelineState,
Viewport, Scissor, RasterizationState, DepthState,
StencilState, BlendState, ColorMask, VertexAttribute,
ShaderResource
)
class GraphicsPipelineManager:
    """Track the current graphics pipeline state and persist pipelines by hash.

    Pipelines are serialized to plain dicts and identified by the SHA-256
    hex digest of their canonical (key-sorted) JSON form. They are stored in
    a ``PipelineStateDB``. ``bind_pipeline`` rebuilds a
    ``GraphicsPipelineState`` from the stored dict and makes it the current
    state. The ``set_*`` methods then replace individual parts of that
    current state.
    """

    def __init__(self, driver, db_path: Optional[str] = None):
        """Create a pipeline manager.

        Args:
            driver: Driver object. It is stored on ``self.driver`` but not
                read by any method of this class.
            db_path: Location handed to ``PipelineStateDB``. When omitted
                (or empty), ``":memory:"`` is used. That is presumably an
                in-memory database; confirm against ``PipelineStateDB``.
        """
        self.driver = driver
        self.current_state = GraphicsPipelineState()
        self.db = PipelineStateDB(db_path or ":memory:")

    def create_pipeline(self, state: GraphicsPipelineState) -> str:
        """Store a pipeline with the given state and return its id.

        The id is the SHA-256 hex digest of the serialized state. Because
        serialization is deterministic, identical states yield the same id.
        The serialized dict is written to the database under that id.

        Raises:
            TypeError: if some value in the state is not JSON-serializable.
        """
        # Serialize once and reuse the same dict for hashing and storage.
        state_dict = self._serialize_state(state)
        state_hash = self._hash_state_dict(state_dict)
        self.db.store_pipeline(state_hash, state_dict)
        return state_hash

    def bind_pipeline(self, pipeline_hash: str):
        """Make the pipeline stored under ``pipeline_hash`` the current state.

        The state is fully rebuilt before it is assigned. If reconstruction
        fails, the previously bound state is left untouched.

        Raises:
            ValueError: if the database returns nothing for ``pipeline_hash``.
        """
        state_dict = self.db.get_pipeline(pipeline_hash)
        if not state_dict:
            raise ValueError(f"Invalid pipeline hash: {pipeline_hash}")
        self.current_state = self._deserialize_state(state_dict)

    def set_viewport(self, viewport: Viewport):
        """Set viewport state"""
        self.current_state.viewport = viewport

    def set_scissor(self, scissor: Scissor):
        """Set scissor state"""
        self.current_state.scissor = scissor

    def set_vertex_attributes(self, attributes: List[VertexAttribute]):
        """Set vertex input attributes"""
        self.current_state.vertex_attributes = attributes

    def set_shader_resources(self, resources: List[ShaderResource]):
        """Set shader resource bindings"""
        self.current_state.shader_resources = resources

    def set_rasterization_state(self, state: RasterizationState):
        """Set rasterization state"""
        self.current_state.rasterization = state

    def set_depth_state(self, state: DepthState):
        """Set depth state"""
        self.current_state.depth = state

    def set_stencil_state(self, state: StencilState):
        """Set stencil state"""
        self.current_state.stencil = state

    def set_blend_state(self, state: BlendState):
        """Set blend state"""
        self.current_state.blend = state

    def set_color_mask(self, mask: ColorMask):
        """Set color write mask"""
        self.current_state.color_mask = mask

    @staticmethod
    def _serialize_state(state: GraphicsPipelineState) -> dict:
        """Convert a pipeline state into a plain dict for hashing and storage.

        Enum members are replaced by their ``.value``. The sub-states are
        converted with ``_asdict()``. ``viewport`` and ``scissor`` may be
        ``None``.
        """
        return {
            "shaders": {k.value: v for k, v in state.shader_stages.items()},
            "vertex_attributes": [attr._asdict() for attr in state.vertex_attributes],
            "shader_resources": [res._asdict() for res in state.shader_resources],
            "viewport": state.viewport._asdict() if state.viewport else None,
            "scissor": state.scissor._asdict() if state.scissor else None,
            "rasterization": state.rasterization._asdict(),
            "depth": state.depth._asdict(),
            "stencil": state.stencil._asdict(),
            "blend": state.blend._asdict(),
            "color_mask": state.color_mask._asdict(),
            "primitive_type": state.primitive_type.value,
            "patch_control_points": state.patch_control_points,
        }

    @staticmethod
    def _deserialize_state(state_dict: dict) -> GraphicsPipelineState:
        """Rebuild a pipeline state from a dict produced by ``_serialize_state``."""
        state = GraphicsPipelineState()
        state.shader_stages = {
            ShaderType(k): v for k, v in state_dict['shaders'].items()
        }
        state.vertex_attributes = [
            VertexAttribute(**attr) for attr in state_dict['vertex_attributes']
        ]
        state.shader_resources = [
            ShaderResource(**res) for res in state_dict['shader_resources']
        ]
        state.viewport = (
            Viewport(**state_dict['viewport']) if state_dict['viewport'] else None
        )
        state.scissor = (
            Scissor(**state_dict['scissor']) if state_dict['scissor'] else None
        )
        state.rasterization = RasterizationState(**state_dict['rasterization'])
        state.depth = DepthState(**state_dict['depth'])
        state.stencil = StencilState(**state_dict['stencil'])
        state.blend = BlendState(**state_dict['blend'])
        state.color_mask = ColorMask(**state_dict['color_mask'])
        state.primitive_type = PrimitiveType(state_dict['primitive_type'])
        state.patch_control_points = state_dict['patch_control_points']
        return state

    @staticmethod
    def _hash_state_dict(state_dict: dict) -> str:
        """Return the SHA-256 hex digest of ``state_dict`` rendered as JSON.

        Raises:
            TypeError: if ``state_dict`` contains a non-JSON-serializable value.
        """
        import hashlib
        import json
        # sort_keys makes the digest independent of dict insertion order.
        state_str = json.dumps(state_dict, sort_keys=True)
        return hashlib.sha256(state_str.encode()).hexdigest()

    def _hash_pipeline_state(self, state: GraphicsPipelineState) -> str:
        """Create a unique hash for a pipeline state (same id as ``create_pipeline``)."""
        return self._hash_state_dict(self._serialize_state(state))

    def get_current_state(self) -> GraphicsPipelineState:
        """Return a copy of the current pipeline state.

        The copy is deep, so mutating it (e.g. appending to its attribute
        list) does not affect the manager. Use the ``set_*`` methods or
        ``bind_pipeline`` to change the current state.
        """
        import copy
        return copy.deepcopy(self.current_state)

    def validate_state(self) -> List[str]:
        """Validate the current pipeline state.

        Returns:
            A list of human-readable error messages; empty when no problem
            was detected.
        """
        errors = []
        # A vertex shader stage is mandatory.
        if ShaderType.VERTEX not in self.current_state.shader_stages:
            errors.append("Missing vertex shader")
        # TODO: Validate vertex attributes against shader reflection data.
        # Patch primitives need a tessellation control stage.
        # NOTE(review): the message says "shaders"; an evaluation stage may
        # also be required -- confirm against ShaderType and the target API.
        if (self.current_state.primitive_type == PrimitiveType.PATCHES and
                ShaderType.TESSELLATION_CONTROL not in self.current_state.shader_stages):
            errors.append("Tessellation shaders required for patch primitives")
        return errors