kratadata's picture
Upload folder via script
0f8b3a0 verified
from __future__ import annotations

import inspect
from dataclasses import dataclass
from typing import Awaitable, Callable, Iterable

from .enums import DataPortState, NodeKind, NodeState
from .ports import PortSchema, PortState
@dataclass(slots=True)
class NodeType:
"""Static node type metadata."""
kind: NodeKind
display_name: str
inputs: list[PortSchema]
outputs: list[PortSchema]
class NodeInstance:
    """A concrete node built from a :class:`NodeType`.

    Holds the node's layout (position and size), per-instance port state for
    its inputs and outputs, its execution state, and an optional callback run
    by :meth:`process`.
    """

    def __init__(
        self,
        node_id: str,
        node_type: NodeType,
        auto_process: bool = False,
        x: float = 0.0,
        y: float = 0.0,
        width: float = 250,
        height: float = 200,
        inputs: dict[str, PortState] | None = None,
        outputs: dict[str, PortState] | None = None,
        on_process: Callable[[NodeInstance], None | Awaitable[None]] | None = None,
        state: NodeState = NodeState.IDLE,
    ) -> None:
        """Create a node instance.

        Args:
            node_id: Identifier of this instance.
            node_type: Type metadata. Its ``inputs``/``outputs`` schemas are
                used to build port state when none is supplied.
            auto_process: Stored on the instance; not consulted by this class
                itself.
            x: Horizontal position.
            y: Vertical position.
            width: Width of the node.
            height: Height of the node.
            inputs: Pre-built input port states keyed by port name. When
                ``None``, one fresh ``PortState`` is created per schema in
                ``node_type.inputs``. A dict that is passed explicitly (even
                an empty one) is used as-is, without copying.
            outputs: Same as ``inputs``, for ``node_type.outputs``.
            on_process: Callback invoked by :meth:`process` with this node.
                It may be a plain callable or one that returns an awaitable,
                such as an ``async def`` function.
            state: Initial node state.
        """
        self.node_id = node_id
        self.node_type = node_type
        self.auto_process = auto_process
        self.x = x
        self.y = y
        self.width = width
        self.height = height
        self.state = state
        # Only build ports from the type's schemas when the caller gave none.
        self.inputs = (
            inputs if inputs is not None else self._build_ports(node_type.inputs)
        )
        self.outputs = (
            outputs if outputs is not None else self._build_ports(node_type.outputs)
        )
        self.on_process = on_process

    @staticmethod
    def _build_ports(schemas: Iterable[PortSchema]) -> dict[str, PortState]:
        """Create one fresh ``PortState`` per schema, keyed by schema name."""
        return {schema.name: PortState(schema) for schema in schemas}

    def __post_init__(self) -> None:
        """Fill in empty port dicts from the node type's schemas.

        NOTE: this class is not a dataclass, so Python never calls this hook
        automatically. It is kept only for callers that may invoke it
        explicitly. Unlike ``__init__``, it treats an *empty* dict as missing.
        """
        if not self.inputs:
            self.inputs = self._build_ports(self.node_type.inputs)
        if not self.outputs:
            self.outputs = self._build_ports(self.node_type.outputs)

    def all_inputs(self) -> Iterable[PortState]:
        """Return a live view of the input port states."""
        return self.inputs.values()

    def all_outputs(self) -> Iterable[PortState]:
        """Return a live view of the output port states."""
        return self.outputs.values()

    def mark_dirty(self) -> None:
        """Set the node back to IDLE and flag every output as DIRTY.

        Output values are left untouched. Use :meth:`reset_outputs` to also
        clear them.
        """
        self.state = NodeState.IDLE
        for port in self.all_outputs():
            port.state = DataPortState.DIRTY

    async def process(self) -> None:
        """Run the ``on_process`` callback with this node, if one is set.

        The callback is called once. If it returns an awaitable, that
        awaitable is awaited. This covers ``async def`` functions as well as
        wrappers such as ``functools.partial`` around one, or objects with an
        ``async __call__``. ``inspect.iscoroutinefunction`` would not detect
        those wrappers, so their coroutines would never run.
        """
        if self.on_process is None:
            return
        result = self.on_process(self)
        if inspect.isawaitable(result):
            await result

    def reset_node(self) -> None:
        """Currently a no-op.

        NOTE(review): presumably a hook for subclasses or future use; confirm
        the intent.
        """

    def reset_outputs(self) -> None:
        """Set the node back to IDLE and clear every output port.

        Each output's ``value`` becomes ``None`` and its state becomes DIRTY.
        """
        self.state = NodeState.IDLE
        for port in self.all_outputs():
            port.value = None
            port.state = DataPortState.DIRTY

    def has_dirty_outputs(self, connected_only: bool = False) -> bool:
        """Return True if any output port is not CLEAN.

        Args:
            connected_only: When True, ports whose ``is_connected`` is falsy
                are ignored. ``is_connected`` is only read in this mode.
        """
        return any(
            port.state != DataPortState.CLEAN
            for port in self.all_outputs()
            if not connected_only or port.is_connected
        )