| use crate::node_registry; |
| use dyn_any::StaticType; |
| use graph_craft::Type; |
| use graph_craft::document::NodeId; |
| use graph_craft::document::value::{TaggedValue, UpcastAsRefNode, UpcastNode}; |
| use graph_craft::graphene_compiler::Executor; |
| use graph_craft::proto::{ConstructionArgs, GraphError, LocalFuture, NodeContainer, ProtoNetwork, ProtoNode, SharedNodeContainer, TypeErasedBox, TypingContext}; |
| use graph_craft::proto::{GraphErrorType, GraphErrors}; |
| use std::collections::{HashMap, HashSet}; |
| use std::error::Error; |
| use std::sync::Arc; |
|
|
| |
| #[derive(Clone)] |
| pub struct DynamicExecutor { |
| output: NodeId, |
| |
| tree: BorrowTree, |
| |
| typing_context: TypingContext, |
| |
| orphaned_nodes: HashSet<NodeId>, |
| } |
|
|
| impl Default for DynamicExecutor { |
| fn default() -> Self { |
| Self { |
| output: Default::default(), |
| tree: Default::default(), |
| typing_context: TypingContext::new(&node_registry::NODE_REGISTRY), |
| orphaned_nodes: HashSet::new(), |
| } |
| } |
| } |
|
|
| #[derive(PartialEq, Clone, Debug, Default, serde::Serialize, serde::Deserialize)] |
| pub struct NodeTypes { |
| pub inputs: Vec<Type>, |
| pub output: Type, |
| } |
|
|
| #[derive(PartialEq, Clone, Debug, Default, serde::Serialize, serde::Deserialize)] |
| pub struct ResolvedDocumentNodeTypes { |
| pub types: HashMap<Vec<NodeId>, NodeTypes>, |
| } |
|
|
| type Path = Box<[NodeId]>; |
|
|
| #[derive(PartialEq, Clone, Debug, Default, serde::Serialize, serde::Deserialize)] |
| pub struct ResolvedDocumentNodeTypesDelta { |
| pub add: Vec<(Path, NodeTypes)>, |
| pub remove: Vec<Path>, |
| } |
|
|
| impl DynamicExecutor { |
| pub async fn new(proto_network: ProtoNetwork) -> Result<Self, GraphErrors> { |
| let mut typing_context = TypingContext::new(&node_registry::NODE_REGISTRY); |
| typing_context.update(&proto_network)?; |
| let output = proto_network.output; |
| let tree = BorrowTree::new(proto_network, &typing_context).await?; |
|
|
| Ok(Self { |
| tree, |
| output, |
| typing_context, |
| orphaned_nodes: HashSet::new(), |
| }) |
| } |
|
|
| |
| #[cfg_attr(debug_assertions, inline(never))] |
| pub async fn update(&mut self, proto_network: ProtoNetwork) -> Result<ResolvedDocumentNodeTypesDelta, GraphErrors> { |
| self.output = proto_network.output; |
| self.typing_context.update(&proto_network)?; |
| let (add, orphaned) = self.tree.update(proto_network, &self.typing_context).await?; |
| let old_to_remove = core::mem::replace(&mut self.orphaned_nodes, orphaned); |
| let mut remove = Vec::with_capacity(old_to_remove.len() - self.orphaned_nodes.len().min(old_to_remove.len())); |
| for node_id in old_to_remove { |
| if self.orphaned_nodes.contains(&node_id) { |
| let path = self.tree.free_node(node_id); |
| self.typing_context.remove_inference(node_id); |
| if let Some(path) = path { |
| remove.push(path); |
| } |
| } |
| } |
| let add = self.document_node_types(add.into_iter()).collect(); |
| Ok(ResolvedDocumentNodeTypesDelta { add, remove }) |
| } |
|
|
| |
| pub fn introspect(&self, node_path: &[NodeId]) -> Result<Arc<dyn std::any::Any + Send + Sync + 'static>, IntrospectError> { |
| self.tree.introspect(node_path) |
| } |
|
|
| pub fn input_type(&self) -> Option<Type> { |
| self.typing_context.type_of(self.output).map(|node_io| node_io.call_argument.clone()) |
| } |
|
|
| pub fn tree(&self) -> &BorrowTree { |
| &self.tree |
| } |
|
|
| pub fn output(&self) -> NodeId { |
| self.output |
| } |
|
|
| pub fn output_type(&self) -> Option<Type> { |
| self.typing_context.type_of(self.output).map(|node_io| node_io.return_value.clone()) |
| } |
|
|
| pub fn document_node_types<'a>(&'a self, nodes: impl Iterator<Item = Path> + 'a) -> impl Iterator<Item = (Path, NodeTypes)> + 'a { |
| nodes.flat_map(|id| self.tree.source_map().get(&id).map(|(_, b)| (id, b.clone()))) |
| |
| |
| } |
| } |
|
|
| impl<I> Executor<I, TaggedValue> for &DynamicExecutor |
| where |
| I: StaticType + 'static + Send + Sync + std::panic::UnwindSafe, |
| { |
| fn execute(&self, input: I) -> LocalFuture<'_, Result<TaggedValue, Box<dyn Error>>> { |
| Box::pin(async move { |
| use futures::FutureExt; |
|
|
| let result = self.tree.eval_tagged_value(self.output, input); |
| let wrapped_result = std::panic::AssertUnwindSafe(result).catch_unwind().await; |
|
|
| match wrapped_result { |
| Ok(result) => result.map_err(|e| e.into()), |
| Err(e) => { |
| Box::leak(e); |
| Err("Node graph execution panicked".into()) |
| } |
| } |
| }) |
| } |
| } |
/// Empty struct; it has no fields and nothing else in this file refers to it.
pub struct InputMapping {}
|
|
| #[derive(Debug, Clone, PartialEq, Eq, Hash)] |
| pub enum IntrospectError { |
| PathNotFound(Vec<NodeId>), |
| ProtoNodeNotFound(NodeId), |
| NoData, |
| RuntimeNotReady, |
| } |
|
|
| impl std::fmt::Display for IntrospectError { |
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| match self { |
| IntrospectError::PathNotFound(path) => write!(f, "Path not found: {:?}", path), |
| IntrospectError::ProtoNodeNotFound(id) => write!(f, "ProtoNode not found: {:?}", id), |
| IntrospectError::NoData => write!(f, "No data found for this node"), |
| IntrospectError::RuntimeNotReady => write!(f, "Node runtime is not ready"), |
| } |
| } |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| #[derive(Default, Clone)] |
| pub struct BorrowTree { |
| |
| nodes: HashMap<NodeId, (SharedNodeContainer, Path)>, |
| |
| source_map: HashMap<Path, (NodeId, NodeTypes)>, |
| } |
|
|
| impl BorrowTree { |
| pub async fn new(proto_network: ProtoNetwork, typing_context: &TypingContext) -> Result<BorrowTree, GraphErrors> { |
| let mut nodes = BorrowTree::default(); |
| for (id, node) in proto_network.nodes { |
| nodes.push_node(id, node, typing_context).await? |
| } |
| Ok(nodes) |
| } |
|
|
| |
| pub async fn update(&mut self, proto_network: ProtoNetwork, typing_context: &TypingContext) -> Result<(Vec<Path>, HashSet<NodeId>), GraphErrors> { |
| let mut old_nodes: HashSet<_> = self.nodes.keys().copied().collect(); |
| let mut new_nodes: Vec<_> = Vec::new(); |
| |
| for (id, node) in proto_network.nodes { |
| if !self.nodes.contains_key(&id) { |
| new_nodes.push(node.original_location.path.clone().unwrap_or_default().into()); |
| self.push_node(id, node, typing_context).await?; |
| } else if self.update_source_map(id, typing_context, &node) { |
| new_nodes.push(node.original_location.path.clone().unwrap_or_default().into()); |
| } |
| old_nodes.remove(&id); |
| } |
| Ok((new_nodes, old_nodes)) |
| } |
|
|
| fn node_deps(&self, nodes: &[NodeId]) -> Vec<SharedNodeContainer> { |
| nodes.iter().map(|node| self.nodes.get(node).unwrap().0.clone()).collect() |
| } |
|
|
| fn store_node(&mut self, node: SharedNodeContainer, id: NodeId, path: Path) { |
| self.nodes.insert(id, (node, path)); |
| } |
|
|
| |
| pub fn introspect(&self, node_path: &[NodeId]) -> Result<Arc<dyn std::any::Any + Send + Sync + 'static>, IntrospectError> { |
| let (id, _) = self.source_map.get(node_path).ok_or_else(|| IntrospectError::PathNotFound(node_path.to_vec()))?; |
| let (node, _path) = self.nodes.get(id).ok_or(IntrospectError::ProtoNodeNotFound(*id))?; |
| node.serialize().ok_or(IntrospectError::NoData) |
| } |
|
|
| pub fn get(&self, id: NodeId) -> Option<SharedNodeContainer> { |
| self.nodes.get(&id).map(|(node, _)| node.clone()) |
| } |
|
|
| |
| pub async fn eval<'i, I, O>(&'i self, id: NodeId, input: I) -> Option<O> |
| where |
| I: StaticType + 'i + Send + Sync, |
| O: StaticType + 'i, |
| { |
| let (node, _path) = self.nodes.get(&id).cloned()?; |
| let output = node.eval(Box::new(input)); |
| dyn_any::downcast::<O>(output.await).ok().map(|o| *o) |
| } |
| |
| |
| pub async fn eval_tagged_value<I>(&self, id: NodeId, input: I) -> Result<TaggedValue, String> |
| where |
| I: StaticType + 'static + Send + Sync, |
| { |
| let (node, _path) = self.nodes.get(&id).cloned().ok_or("Output node not found in executor")?; |
| let output = node.eval(Box::new(input)); |
| TaggedValue::try_from_any(output.await) |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| pub fn free_node(&mut self, id: NodeId) -> Option<Path> { |
| let (_, path) = self.nodes.remove(&id)?; |
| if self.source_map.get(&path)?.0 == id { |
| self.source_map.remove(&path); |
| return Some(path); |
| } |
| None |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| fn update_source_map(&mut self, id: NodeId, typing_context: &TypingContext, proto_node: &ProtoNode) -> bool { |
| let Some(node_io) = typing_context.type_of(id) else { |
| log::warn!("did not find type"); |
| return false; |
| }; |
| let inputs = [&node_io.call_argument].into_iter().chain(&node_io.inputs).cloned().collect(); |
|
|
| let node_path = &proto_node.original_location.path.as_ref().unwrap_or(const { &vec![] }); |
|
|
| let entry = self.source_map.entry(node_path.to_vec().into()).or_default(); |
|
|
| let update = ( |
| id, |
| NodeTypes { |
| inputs, |
| output: node_io.return_value.clone(), |
| }, |
| ); |
| let modified = *entry != update; |
| *entry = update; |
| modified |
| } |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| async fn push_node(&mut self, id: NodeId, proto_node: ProtoNode, typing_context: &TypingContext) -> Result<(), GraphErrors> { |
| self.update_source_map(id, typing_context, &proto_node); |
| let path = proto_node.original_location.path.clone().unwrap_or_default(); |
|
|
| match &proto_node.construction_args { |
| ConstructionArgs::Value(value) => { |
| let node = if let TaggedValue::EditorApi(api) = &**value { |
| let editor_api = UpcastAsRefNode::new(api.clone()); |
| let node = Box::new(editor_api) as TypeErasedBox<'_>; |
| NodeContainer::new(node) |
| } else { |
| let upcasted = UpcastNode::new(value.to_owned()); |
| let node = Box::new(upcasted) as TypeErasedBox<'_>; |
| NodeContainer::new(node) |
| }; |
| self.store_node(node, id, path.into()); |
| } |
| ConstructionArgs::Inline(_) => unimplemented!("Inline nodes are not supported yet"), |
| ConstructionArgs::Nodes(ids) => { |
| let ids: Vec<_> = ids.iter().map(|(id, _)| *id).collect(); |
| let construction_nodes = self.node_deps(&ids); |
| let constructor = typing_context.constructor(id).ok_or_else(|| vec![GraphError::new(&proto_node, GraphErrorType::NoConstructor)])?; |
| let node = constructor(construction_nodes).await; |
| let node = NodeContainer::new(node); |
| self.store_node(node, id, path.into()); |
| } |
| }; |
| Ok(()) |
| } |
|
|
| |
| pub fn source_map(&self) -> &HashMap<Path, (NodeId, NodeTypes)> { |
| &self.source_map |
| } |
| } |
|
|
#[cfg(test)]
mod test {
    use super::*;
    use graph_craft::document::value::TaggedValue;

    /// Pushes a single `u32` value node into an empty tree and checks that evaluating it with a unit
    /// input yields the stored value.
    #[test]
    fn push_node_sync() {
        let typing_context = TypingContext::default();
        let value_node = ProtoNode::value(ConstructionArgs::Value(TaggedValue::U32(2u32).into()), vec![]);

        let mut tree = BorrowTree::default();
        futures::executor::block_on(tree.push_node(NodeId(0), value_node, &typing_context)).unwrap();

        // The node has to be retrievable under the id it was pushed with.
        let _stored = tree.get(NodeId(0)).unwrap();

        let evaluated = futures::executor::block_on(tree.eval(NodeId(0), ()));
        assert_eq!(evaluated, Some(2u32));
    }
}
|
|