Spaces:
Running
Running
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Callable, Iterable | |
| from .enums import DataPortState, NodeKind, NodeState | |
| from .ports import PortSchema, PortState | |
class NodeType:
    """Static node type metadata.

    Declares, as class-level annotations, the fields that describe a type
    of node: its ``kind``, a ``display_name``, and the port schemas for its
    inputs and outputs.
    """

    # NOTE(review): only annotations are declared here -- no values are
    # assigned and no ``__init__`` is defined in this class. ``dataclass`` is
    # imported at the top of the file but no decorator is visible on this
    # class; confirm how instances are meant to be populated.
    kind: NodeKind  # category of the node
    display_name: str  # presumably the label shown to users -- confirm
    inputs: list[PortSchema]  # schemas of the input ports
    outputs: list[PortSchema]  # schemas of the output ports
class NodeInstance:
    """A concrete node created from a :class:`NodeType`.

    Holds the node's identity, geometry, processing state, and one
    ``PortState`` per input/output port keyed by port name.

    Attributes:
        node_id: Identifier of this node.
        node_type: Static metadata the node was created from.
        auto_process: Flag stored as given; not interpreted by this class.
        x, y, width, height: Geometry values stored as given
            (presumably canvas position and size -- confirm with callers).
        state: Current ``NodeState`` of the node.
        inputs: Input port states keyed by port name.
        outputs: Output port states keyed by port name.
        on_process: Optional callback run by :meth:`process`.
    """

    def __init__(
        self,
        node_id: str,
        node_type: NodeType,
        auto_process: bool = False,
        x: float = 0.0,
        y: float = 0.0,
        width: float = 250,
        height: float = 200,
        inputs: dict[str, PortState] | None = None,
        outputs: dict[str, PortState] | None = None,
        on_process: Callable[[NodeInstance], None] | None = None,
        state: NodeState = NodeState.IDLE,
    ) -> None:
        """Initialise the node.

        Args:
            node_id: Identifier of this node.
            node_type: Static metadata; its port schemas are used to build
                any port mapping that is not supplied.
            auto_process: Stored on the instance unchanged.
            x: Stored on the instance unchanged.
            y: Stored on the instance unchanged.
            width: Stored on the instance unchanged.
            height: Stored on the instance unchanged.
            inputs: Pre-built input port states. When ``None``, one
                ``PortState`` per schema in ``node_type.inputs`` is created.
                An explicitly passed empty dict is kept as-is.
            outputs: Pre-built output port states. When ``None``, one
                ``PortState`` per schema in ``node_type.outputs`` is created.
                An explicitly passed empty dict is kept as-is.
            on_process: Callback invoked with this node by :meth:`process`.
                May be synchronous or return an awaitable.
            state: Initial node state.
        """
        self.node_id = node_id
        self.node_type = node_type
        self.auto_process = auto_process
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.state = state

        # If not provided, build the port mappings from the node_type schemas.
        if inputs is None:
            self.inputs = {p.name: PortState(p) for p in self.node_type.inputs}
        else:
            self.inputs = inputs
        if outputs is None:
            self.outputs = {p.name: PortState(p) for p in self.node_type.outputs}
        else:
            self.outputs = outputs

        self.on_process = on_process

    def __post_init__(self) -> None:
        """Rebuild empty port mappings from the node type's schemas.

        NOTE(review): ``__post_init__`` is only called automatically on
        dataclasses. This class defines its own ``__init__`` and nothing in
        it calls this method, so it runs only when invoked explicitly. It is
        kept for backward compatibility. Unlike ``__init__``, it treats an
        empty mapping (not just ``None``) as "missing".
        """
        if not self.inputs:
            self.inputs = {p.name: PortState(p) for p in self.node_type.inputs}
        if not self.outputs:
            self.outputs = {p.name: PortState(p) for p in self.node_type.outputs}

    def all_inputs(self) -> Iterable[PortState]:
        """Return a view over every input port state."""
        return self.inputs.values()

    def all_outputs(self) -> Iterable[PortState]:
        """Return a view over every output port state."""
        return self.outputs.values()

    def mark_dirty(self) -> None:
        """Set the node to ``IDLE`` and flag every output port ``DIRTY``.

        Output values are left untouched; see :meth:`reset_outputs` to also
        clear them.
        """
        self.state = NodeState.IDLE
        for port in self.all_outputs():
            port.state = DataPortState.DIRTY

    async def process(self) -> None:
        """Run the ``on_process`` callback, if one is set.

        The callback is called with this node. If what it returns is
        awaitable, it is awaited. Checking the *result* rather than the
        callback itself means callables that merely return a coroutine
        (lambdas, wrappers, objects with an async ``__call__``) are awaited
        too instead of leaving a never-awaited coroutine behind.
        """
        callback = self.on_process
        if not callback:
            return

        import inspect

        result = callback(self)
        if inspect.isawaitable(result):
            await result

    def reset_node(self) -> None:
        """Do nothing; currently a no-op placeholder."""
        pass

    def reset_outputs(self) -> None:
        """Set the node to ``IDLE``, clear every output value, mark it ``DIRTY``."""
        self.state = NodeState.IDLE
        for port in self.all_outputs():
            port.value = None
            port.state = DataPortState.DIRTY

    def has_dirty_outputs(self, connected_only: bool = False) -> bool:
        """Return whether any output port is not ``CLEAN``.

        Args:
            connected_only: When true, ports whose ``is_connected`` is falsy
                are ignored.

        Returns:
            ``True`` if at least one considered output port has a state
            other than ``DataPortState.CLEAN``; ``False`` otherwise
            (including when there are no output ports).
        """
        return any(
            port.state != DataPortState.CLEAN
            for port in self.all_outputs()
            # ``is_connected`` is only consulted when filtering is requested.
            if not connected_only or port.is_connected
        )