# Source: velai/dataflow/registry.py (uploaded by cansik via script, commit d868fac)
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Callable
from .enums import NodeKind
from .types import DataType
from .nodes_base import NodeType, NodeInstance
@dataclass(slots=True)
class Registry:
datatypes: dict[DataType, DataType] = field(default_factory=dict)
node_types: dict[NodeKind, NodeType] = field(default_factory=dict)
factories: dict[NodeKind, Callable[[str], NodeInstance]] = field(default_factory=dict)
def register_datatype(self, dt: DataType) -> None:
self.datatypes[dt] = dt
def register_node_type(self, t: NodeType, factory: Callable[[str], NodeInstance]) -> None:
self.node_types[t.kind] = t
self.factories[t.kind] = factory
def create(self, type_ref: NodeType | NodeKind | str, node_id: str) -> NodeInstance:
"""Create node by NodeType, NodeKind, or name. Prefer NodeType or NodeKind."""
if isinstance(type_ref, NodeType):
kind = type_ref.kind
elif isinstance(type_ref, NodeKind):
kind = type_ref
elif isinstance(type_ref, str):
# last resort lookup by display or kind value
match = next((k for k, t in self.node_types.items()
if t.display_name == type_ref or k.value == type_ref), None)
if match is None:
raise KeyError(f"Unknown node type name {type_ref}")
kind = match
else:
raise TypeError("Unsupported type_ref")
return self.factories[kind](node_id)