| use crate::document::{InlineRust, value}; |
| use crate::document::{NodeId, OriginalLocation}; |
| pub use graphene_core::registry::*; |
| use graphene_core::*; |
| use rustc_hash::FxHashMap; |
| use std::borrow::Cow; |
| use std::collections::{HashMap, HashSet}; |
| use std::fmt::Debug; |
| use std::hash::Hash; |
|
|
/// A flat, "prototype" representation of a node network: a list of nodes keyed by [`NodeId`],
/// plus the ids of the network's input nodes and of its output node.
#[derive(Debug, Default, PartialEq, Clone, Hash, Eq, serde::Serialize, serde::Deserialize)]
pub struct ProtoNetwork {
	/// Ids of the nodes that act as inputs of this network.
	/// `reorder_ids` drops every entry that is not reachable from `output`.
	pub inputs: Vec<NodeId>,
	/// Id of the node producing the network's result; topological sorting starts from this node.
	pub output: NodeId,
	/// All nodes paired with their ids. After `reorder_ids` the ids equal the positions in this
	/// list (`0..len`) and the list is in topological order.
	pub nodes: Vec<(NodeId, ProtoNode)>,
}
|
|
| impl core::fmt::Display for ProtoNetwork { |
| fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { |
| f.write_str("Proto Network with nodes: ")?; |
| fn write_node(f: &mut core::fmt::Formatter<'_>, network: &ProtoNetwork, id: NodeId, indent: usize) -> core::fmt::Result { |
| f.write_str(&"\t".repeat(indent))?; |
| let Some((_, node)) = network.nodes.iter().find(|(node_id, _)| *node_id == id) else { |
| return f.write_str("{{Unknown Node}}"); |
| }; |
| f.write_str("Node: ")?; |
| f.write_str(&node.identifier.name)?; |
|
|
| f.write_str("\n")?; |
| f.write_str(&"\t".repeat(indent))?; |
| f.write_str("{\n")?; |
|
|
| f.write_str(&"\t".repeat(indent + 1))?; |
| f.write_str("Input: ")?; |
| match &node.input { |
| ProtoNodeInput::None => f.write_str("None")?, |
| ProtoNodeInput::ManualComposition(ty) => f.write_fmt(format_args!("Manual Composition (type = {ty:?})"))?, |
| ProtoNodeInput::Node(_) => f.write_str("Node")?, |
| ProtoNodeInput::NodeLambda(_) => f.write_str("Lambda Node")?, |
| } |
| f.write_str("\n")?; |
|
|
| match &node.construction_args { |
| ConstructionArgs::Value(value) => { |
| f.write_str(&"\t".repeat(indent + 1))?; |
| f.write_fmt(format_args!("Value construction argument: {value:?}"))? |
| } |
| ConstructionArgs::Nodes(nodes) => { |
| for id in nodes { |
| write_node(f, network, id.0, indent + 1)?; |
| } |
| } |
| ConstructionArgs::Inline(inline) => { |
| f.write_str(&"\t".repeat(indent + 1))?; |
| f.write_fmt(format_args!("Inline construction argument: {inline:?}"))? |
| } |
| } |
| f.write_str(&"\t".repeat(indent))?; |
| f.write_str("}\n")?; |
| Ok(()) |
| } |
|
|
| let id = self.output; |
| write_node(f, self, id, 0) |
| } |
| } |
|
|
/// The arguments a proto node is constructed with.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum ConstructionArgs {
	/// A constant value. `TypingContext::infer` types such a node as `Context -> Future<value type>`.
	Value(MemoHash<value::TaggedValue>),
	/// References to other nodes, as `(node id, is_lambda)` pairs. Entries flagged `true` are lambda
	/// references: `ProtoNode::map_ids` can be told to skip them, and `ProtoNetwork::resolve_inputs`
	/// uses one to reference the node being composed.
	Nodes(Vec<(NodeId, bool)>),
	/// An inline Rust expression (`expr`) together with its type (`ty`).
	Inline(InlineRust),
}
|
|
| impl Eq for ConstructionArgs {} |
|
|
| impl PartialEq for ConstructionArgs { |
| fn eq(&self, other: &Self) -> bool { |
| match (&self, &other) { |
| (Self::Nodes(n1), Self::Nodes(n2)) => n1 == n2, |
| (Self::Value(v1), Self::Value(v2)) => v1 == v2, |
| _ => { |
| use std::hash::Hasher; |
| let hash = |input: &Self| { |
| let mut hasher = rustc_hash::FxHasher::default(); |
| input.hash(&mut hasher); |
| hasher.finish() |
| }; |
| hash(self) == hash(other) |
| } |
| } |
| } |
| } |
|
|
| impl Hash for ConstructionArgs { |
| fn hash<H: std::hash::Hasher>(&self, state: &mut H) { |
| core::mem::discriminant(self).hash(state); |
| match self { |
| Self::Nodes(nodes) => { |
| for node in nodes { |
| node.hash(state); |
| } |
| } |
| Self::Value(value) => value.hash(state), |
| Self::Inline(inline) => inline.hash(state), |
| } |
| } |
| } |
|
|
| impl ConstructionArgs { |
| |
| pub fn new_function_args(&self) -> Vec<String> { |
| match self { |
| ConstructionArgs::Nodes(nodes) => nodes.iter().map(|(n, _)| format!("n{:0x}", n.0)).collect(), |
| ConstructionArgs::Value(value) => vec![value.to_primitive_string()], |
| ConstructionArgs::Inline(inline) => vec![inline.expr.clone()], |
| } |
| } |
| } |
|
|
/// A single node of a [`ProtoNetwork`].
#[derive(Debug, Clone, PartialEq, Hash, Eq, serde::Serialize, serde::Deserialize)]
pub struct ProtoNode {
	/// Arguments the node is constructed with: a value, referenced node ids, or inline Rust.
	pub construction_args: ConstructionArgs,
	/// Where the node's primary input / call argument comes from.
	pub input: ProtoNodeInput,
	/// Identifier used to look up the node's implementations in the typing registry.
	pub identifier: ProtoNodeIdentifier,
	/// Location of the node in the original network. Its `path` is reported in `GraphError`s, and its
	/// input mapping is used to translate error input indices.
	pub original_location: OriginalLocation,
	/// When `true`, `stable_node_id` also hashes `original_location.path`, so otherwise identical
	/// nodes get distinct ids.
	pub skip_deduplication: bool,
}
|
|
| impl Default for ProtoNode { |
| fn default() -> Self { |
| Self { |
| identifier: ProtoNodeIdentifier::new("graphene_core::ops::IdentityNode"), |
| construction_args: ConstructionArgs::Value(value::TaggedValue::U32(0).into()), |
| input: ProtoNodeInput::None, |
| original_location: OriginalLocation::default(), |
| skip_deduplication: false, |
| } |
| } |
| } |
|
|
| |
/// Where a proto node's primary input (its call argument) comes from.
#[derive(Debug, PartialEq, Eq, Clone, Hash, serde::Serialize, serde::Deserialize)]
pub enum ProtoNodeInput {
	/// No input; type inference uses `()` as the call argument type.
	None,
	/// The node is called directly with an argument of the given type, which type inference uses
	/// as the call argument type. `ProtoNode::value` creates value nodes with `Context` here.
	ManualComposition(Type),
	/// The output of the referenced node feeds this node. `ProtoNetwork::resolve_inputs` rewrites
	/// these by inserting a `graphene_core::structural::ComposeNode`. Type inference uses the
	/// referenced node's return type as the call argument type.
	Node(NodeId),
	/// Like `Node`, but the reference is a lambda (presumably the node is passed rather than
	/// evaluated — confirm). `resolve_inputs` does not compose it, and `ProtoNode::map_ids` can be
	/// told to leave it untouched.
	NodeLambda(NodeId),
}
|
|
| impl ProtoNode { |
| |
| |
| pub fn stable_node_id(&self) -> Option<NodeId> { |
| use std::hash::Hasher; |
| let mut hasher = rustc_hash::FxHasher::default(); |
|
|
| self.identifier.name.hash(&mut hasher); |
| self.construction_args.hash(&mut hasher); |
| if self.skip_deduplication { |
| self.original_location.path.hash(&mut hasher); |
| } |
|
|
| std::mem::discriminant(&self.input).hash(&mut hasher); |
| match self.input { |
| ProtoNodeInput::None => (), |
| ProtoNodeInput::ManualComposition(ref ty) => { |
| ty.hash(&mut hasher); |
| } |
| ProtoNodeInput::Node(id) => (id, false).hash(&mut hasher), |
| ProtoNodeInput::NodeLambda(id) => (id, true).hash(&mut hasher), |
| }; |
|
|
| Some(NodeId(hasher.finish())) |
| } |
|
|
| |
| pub fn value(value: ConstructionArgs, path: Vec<NodeId>) -> Self { |
| let inputs_exposed = match &value { |
| ConstructionArgs::Nodes(nodes) => nodes.len() + 1, |
| _ => 2, |
| }; |
| Self { |
| identifier: ProtoNodeIdentifier::new("graphene_core::value::ClonedNode"), |
| construction_args: value, |
| input: ProtoNodeInput::ManualComposition(concrete!(Context)), |
| original_location: OriginalLocation { |
| path: Some(path), |
| inputs_exposed: vec![false; inputs_exposed], |
| ..Default::default() |
| }, |
| skip_deduplication: false, |
| } |
| } |
|
|
| |
| |
| pub fn map_ids(&mut self, f: impl Fn(NodeId) -> NodeId, skip_lambdas: bool) { |
| match self.input { |
| ProtoNodeInput::Node(id) => self.input = ProtoNodeInput::Node(f(id)), |
| ProtoNodeInput::NodeLambda(id) => { |
| if !skip_lambdas { |
| self.input = ProtoNodeInput::NodeLambda(f(id)) |
| } |
| } |
| _ => (), |
| } |
|
|
| if let ConstructionArgs::Nodes(ids) = &mut self.construction_args { |
| ids.iter_mut().filter(|(_, lambda)| !(skip_lambdas && *lambda)).for_each(|(id, _)| *id = f(*id)); |
| } |
| } |
|
|
| pub fn unwrap_construction_nodes(&self) -> Vec<(NodeId, bool)> { |
| match &self.construction_args { |
| ConstructionArgs::Nodes(nodes) => nodes.clone(), |
| _ => panic!("tried to unwrap nodes from non node construction args \n node: {self:#?}"), |
| } |
| } |
| } |
|
|
/// Per-node DFS bookkeeping used by `ProtoNetwork::topological_sort`.
#[derive(Clone, Copy, PartialEq)]
enum NodeState {
	/// Not reached yet.
	Unvisited,
	/// On the current DFS path while its dependencies are processed; reaching it again means a cycle.
	Visiting,
	/// Fully processed and already appended to the sorted output.
	Visited,
}
|
|
| impl ProtoNetwork { |
| fn check_ref(&self, ref_id: &NodeId, id: &NodeId) { |
| debug_assert!( |
| self.nodes.iter().any(|(check_id, _)| check_id == ref_id), |
| "Node id:{id} has a reference which uses node id:{ref_id} which doesn't exist in network {self:#?}" |
| ); |
| } |
|
|
| #[cfg(debug_assertions)] |
| pub fn example() -> (Self, NodeId, ProtoNode) { |
| let node_id = NodeId(1); |
| let proto_node = ProtoNode::default(); |
| let proto_network = ProtoNetwork { |
| inputs: vec![node_id], |
| output: node_id, |
| nodes: vec![(node_id, proto_node.clone())], |
| }; |
| (proto_network, node_id, proto_node) |
| } |
|
|
| |
| pub fn collect_outwards_edges(&self) -> HashMap<NodeId, Vec<NodeId>> { |
| let mut edges: HashMap<NodeId, Vec<NodeId>> = HashMap::new(); |
| for (id, node) in &self.nodes { |
| match &node.input { |
| ProtoNodeInput::Node(ref_id) | ProtoNodeInput::NodeLambda(ref_id) => { |
| self.check_ref(ref_id, id); |
| edges.entry(*ref_id).or_default().push(*id) |
| } |
| _ => (), |
| } |
|
|
| if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args { |
| for (ref_id, _) in ref_nodes { |
| self.check_ref(ref_id, id); |
| edges.entry(*ref_id).or_default().push(*id) |
| } |
| } |
| } |
| edges |
| } |
|
|
| |
| |
| pub fn generate_stable_node_ids(&mut self) { |
| debug_assert!(self.is_topologically_sorted()); |
| let outwards_edges = self.collect_outwards_edges(); |
|
|
| for index in 0..self.nodes.len() { |
| let Some(sni) = self.nodes[index].1.stable_node_id() else { |
| panic!("failed to generate stable node id for node {:#?}", self.nodes[index].1); |
| }; |
| self.replace_node_id(&outwards_edges, NodeId(index as u64), sni, false); |
| self.nodes[index].0 = sni; |
| } |
| } |
|
|
| |
| |
| pub fn collect_inwards_edges(&self) -> HashMap<NodeId, Vec<NodeId>> { |
| let mut edges: HashMap<NodeId, Vec<NodeId>> = HashMap::new(); |
| for (id, node) in &self.nodes { |
| match &node.input { |
| ProtoNodeInput::Node(ref_id) | ProtoNodeInput::NodeLambda(ref_id) => { |
| self.check_ref(ref_id, id); |
| edges.entry(*id).or_default().push(*ref_id) |
| } |
| _ => (), |
| } |
|
|
| if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args { |
| for (ref_id, _) in ref_nodes { |
| self.check_ref(ref_id, id); |
| edges.entry(*id).or_default().push(*ref_id) |
| } |
| } |
| } |
| edges |
| } |
|
|
| fn collect_inwards_edges_with_mapping(&self) -> (Vec<Vec<usize>>, FxHashMap<NodeId, usize>) { |
| let id_map: FxHashMap<_, _> = self.nodes.iter().enumerate().map(|(idx, (id, _))| (*id, idx)).collect(); |
|
|
| |
| let mut inwards_edges = vec![Vec::new(); self.nodes.len()]; |
| for (node_id, node) in &self.nodes { |
| let node_index = id_map[node_id]; |
| match &node.input { |
| ProtoNodeInput::Node(ref_id) | ProtoNodeInput::NodeLambda(ref_id) => { |
| self.check_ref(ref_id, &NodeId(node_index as u64)); |
| inwards_edges[node_index].push(id_map[ref_id]); |
| } |
| _ => {} |
| } |
|
|
| if let ConstructionArgs::Nodes(ref_nodes) = &node.construction_args { |
| for (ref_id, _) in ref_nodes { |
| self.check_ref(ref_id, &NodeId(node_index as u64)); |
| inwards_edges[node_index].push(id_map[ref_id]); |
| } |
| } |
| } |
|
|
| (inwards_edges, id_map) |
| } |
|
|
| |
| pub fn resolve_inputs(&mut self) -> Result<(), String> { |
| |
| self.reorder_ids()?; |
|
|
| let max_id = self.nodes.len() as u64 - 1; |
|
|
| |
| let outwards_edges = self.collect_outwards_edges(); |
|
|
| |
| for node_id in 0..=max_id { |
| let node_id = NodeId(node_id); |
|
|
| let (_, node) = &mut self.nodes[node_id.0 as usize]; |
|
|
| if let ProtoNodeInput::Node(input_node_id) = node.input { |
| |
| let compose_node_id = NodeId(self.nodes.len() as u64); |
|
|
| let (_, input_node_id_proto) = &self.nodes[input_node_id.0 as usize]; |
|
|
| let input = input_node_id_proto.input.clone(); |
|
|
| let mut path = input_node_id_proto.original_location.path.clone(); |
| if let Some(path) = &mut path { |
| path.push(node_id); |
| } |
|
|
| self.nodes.push(( |
| compose_node_id, |
| ProtoNode { |
| identifier: ProtoNodeIdentifier::new("graphene_core::structural::ComposeNode"), |
| construction_args: ConstructionArgs::Nodes(vec![(input_node_id, false), (node_id, true)]), |
| input, |
| original_location: OriginalLocation { path, ..Default::default() }, |
| skip_deduplication: false, |
| }, |
| )); |
|
|
| self.replace_node_id(&outwards_edges, node_id, compose_node_id, true); |
| } |
| } |
| self.reorder_ids()?; |
| Ok(()) |
| } |
|
|
| |
| fn replace_node_id(&mut self, outwards_edges: &HashMap<NodeId, Vec<NodeId>>, node_id: NodeId, compose_node_id: NodeId, skip_lambdas: bool) { |
| |
| if let Some(referring_nodes) = outwards_edges.get(&node_id) { |
| for &referring_node_id in referring_nodes { |
| let (_, referring_node) = &mut self.nodes[referring_node_id.0 as usize]; |
| referring_node.map_ids(|id| if id == node_id { compose_node_id } else { id }, skip_lambdas) |
| } |
| } |
|
|
| if self.output == node_id { |
| self.output = compose_node_id; |
| } |
|
|
| self.inputs.iter_mut().for_each(|id| { |
| if *id == node_id { |
| *id = compose_node_id; |
| } |
| }); |
| } |
|
|
| |
| |
| pub fn topological_sort(&self) -> Result<(Vec<NodeId>, FxHashMap<NodeId, usize>), String> { |
| let (inwards_edges, id_map) = self.collect_inwards_edges_with_mapping(); |
| let mut sorted = Vec::with_capacity(self.nodes.len()); |
| let mut stack = vec![id_map[&self.output]]; |
| let mut state = vec![NodeState::Unvisited; self.nodes.len()]; |
|
|
| while let Some(&node_index) = stack.last() { |
| match state[node_index] { |
| NodeState::Unvisited => { |
| state[node_index] = NodeState::Visiting; |
| for &dep_index in inwards_edges[node_index].iter().rev() { |
| match state[dep_index] { |
| NodeState::Visiting => { |
| return Err(format!("Cycle detected involving node {}", self.nodes[dep_index].0)); |
| } |
| NodeState::Unvisited => { |
| stack.push(dep_index); |
| } |
| NodeState::Visited => {} |
| } |
| } |
| } |
| NodeState::Visiting => { |
| stack.pop(); |
| state[node_index] = NodeState::Visited; |
| sorted.push(NodeId(node_index as u64)); |
| } |
| NodeState::Visited => { |
| stack.pop(); |
| } |
| } |
| } |
|
|
| Ok((sorted, id_map)) |
| } |
|
|
| fn is_topologically_sorted(&self) -> bool { |
| let mut visited = HashSet::new(); |
|
|
| let inwards_edges = self.collect_inwards_edges(); |
| for (id, _) in &self.nodes { |
| for &dependency in inwards_edges.get(id).unwrap_or(&Vec::new()) { |
| if !visited.contains(&dependency) { |
| dbg!(id, dependency); |
| dbg!(&visited); |
| dbg!(&self.nodes); |
| return false; |
| } |
| } |
| visited.insert(*id); |
| } |
| true |
| } |
|
|
| |
| fn reorder_ids(&mut self) -> Result<(), String> { |
| let (order, _id_map) = self.topological_sort()?; |
|
|
| |
| |
|
|
| |
| let new_positions: FxHashMap<_, _> = order.iter().enumerate().map(|(pos, id)| (self.nodes[id.0 as usize].0, pos)).collect(); |
| |
|
|
| |
|
|
| let mut new_nodes = Vec::with_capacity(order.len()); |
| for (index, &id) in order.iter().enumerate() { |
| let mut node = std::mem::take(&mut self.nodes[id.0 as usize].1); |
| |
| node.map_ids(|id| NodeId(*new_positions.get(&id).expect("node not found in lookup table") as u64), false); |
| new_nodes.push((NodeId(index as u64), node)); |
| } |
|
|
| |
| |
| |
| |
|
|
| |
| self.nodes = new_nodes; |
| self.inputs = self.inputs.iter().filter_map(|id| new_positions.get(id).map(|x| NodeId(*x as u64))).collect(); |
| self.output = NodeId(*new_positions.get(&self.output).unwrap() as u64); |
|
|
| assert_eq!(order.len(), self.nodes.len()); |
| Ok(()) |
| } |
| } |
/// The kinds of errors produced while type-checking a proto network.
#[derive(Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum GraphErrorType {
	/// A node referenced by `ConstructionArgs::Nodes` has no inferred type yet.
	NodeNotFound(NodeId),
	/// The node referenced by the primary `Node`/`NodeLambda` input has no inferred type yet.
	InputNodeNotFound(NodeId),
	/// Construction input `index` is a function type whose output is still a generic.
	UnexpectedGenerics { index: usize, inputs: Vec<Type> },
	/// The registry has no implementations for the node's identifier.
	NoImplementations,
	/// No constructor is available for the node (not constructed within this file).
	NoConstructor,
	/// No implementation accepts the given types. `error_inputs` lists `(input index, (found, expected))`
	/// mismatches for the implementations with the fewest mismatches.
	InvalidImplementations { inputs: String, error_inputs: Vec<Vec<(usize, (Type, Type))>> },
	/// More than one implementation matches; `valid` lists the candidates.
	MultipleImplementations { inputs: String, valid: Vec<NodeIOTypes> },
}
| impl Debug for GraphErrorType { |
| |
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| match self { |
| GraphErrorType::NodeNotFound(id) => write!(f, "Input node {id} is not present in the typing context"), |
| GraphErrorType::InputNodeNotFound(id) => write!(f, "Input node {id} is not present in the typing context"), |
| GraphErrorType::UnexpectedGenerics { index, inputs } => write!(f, "Generic inputs should not exist but found at {index}: {inputs:?}"), |
| GraphErrorType::NoImplementations => write!(f, "No implementations found"), |
| GraphErrorType::NoConstructor => write!(f, "No construct found for node"), |
| GraphErrorType::InvalidImplementations { inputs, error_inputs } => { |
| let format_error = |(index, (found, expected)): &(usize, (Type, Type))| { |
| let index = index + 1; |
| format!( |
| "\ |
| • Input {index}:\n\ |
| …found: {found}\n\ |
| …expected: {expected}\ |
| " |
| ) |
| }; |
| let format_error_list = |errors: &Vec<(usize, (Type, Type))>| errors.iter().map(format_error).collect::<Vec<_>>().join("\n"); |
| let mut errors = error_inputs.iter().map(format_error_list).collect::<Vec<_>>(); |
| errors.sort(); |
| let errors = errors.join("\n"); |
| let incompatibility = if errors.chars().filter(|&c| c == '•').count() == 1 { |
| "This input type is incompatible:" |
| } else { |
| "These input types are incompatible:" |
| }; |
|
|
| write!( |
| f, |
| "\ |
| {incompatibility}\n\ |
| {errors}\n\ |
| \n\ |
| The node is currently receiving all of the following input types:\n\ |
| {inputs}\n\ |
| This is not a supported arrangement of types for the node.\ |
| " |
| ) |
| } |
| GraphErrorType::MultipleImplementations { inputs, valid } => write!(f, "Multiple implementations found ({inputs}):\n{valid:#?}"), |
| } |
| } |
| } |
/// A type-checking error attached to the proto node that caused it.
#[derive(Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct GraphError {
	/// The node's `original_location.path`, or empty when the node has none.
	pub node_path: Vec<NodeId>,
	/// Identifier name of the offending node.
	pub identifier: Cow<'static, str>,
	/// What went wrong.
	pub error: GraphErrorType,
}
| impl GraphError { |
| pub fn new(node: &ProtoNode, text: impl Into<GraphErrorType>) -> Self { |
| Self { |
| node_path: node.original_location.path.clone().unwrap_or_default(), |
| identifier: node.identifier.name.clone(), |
| error: text.into(), |
| } |
| } |
| } |
| impl Debug for GraphError { |
| fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |
| f.debug_struct("NodeGraphError") |
| .field("path", &self.node_path.iter().map(|id| id.0).collect::<Vec<_>>()) |
| .field("identifier", &self.identifier.to_string()) |
| .field("error", &self.error) |
| .finish() |
| } |
| } |
/// The error list returned by the type-inference functions.
pub type GraphErrors = Vec<GraphError>;
|
|
| |
/// Caches, per node id, the inferred I/O types and the chosen constructor, resolved against a
/// registry of implementations keyed by node identifier and I/O signature.
#[derive(Default, Clone, dyn_any::DynAny)]
pub struct TypingContext {
	/// Registry of available implementations: identifier → (I/O signature → constructor).
	lookup: Cow<'static, HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>>>,
	/// Inferred (generic-substituted) I/O types per node id.
	inferred: HashMap<NodeId, NodeIOTypes>,
	/// Constructor chosen per node id. `infer` records none for `ConstructionArgs::Value` nodes.
	constructor: HashMap<NodeId, NodeConstructor>,
}
|
|
impl TypingContext {
	/// Creates a typing context over the given static implementation registry, with no inferred
	/// types or constructors yet.
	pub fn new(lookup: &'static HashMap<ProtoNodeIdentifier, HashMap<NodeIOTypes, NodeConstructor>>) -> Self {
		Self {
			lookup: Cow::Borrowed(lookup),
			..Default::default()
		}
	}

	/// Infers types for every node of `network`, in storage order, stopping at the first failing node.
	///
	/// `infer` requires referenced nodes to be inferred already, so the network is expected to be
	/// topologically sorted — TODO confirm all callers guarantee this.
	pub fn update(&mut self, network: &ProtoNetwork) -> Result<(), GraphErrors> {
		for (id, node) in network.nodes.iter() {
			self.infer(*id, node)?;
		}

		Ok(())
	}

	/// Forgets the inferred types and constructor of `node_id`, returning the removed types.
	pub fn remove_inference(&mut self, node_id: NodeId) -> Option<NodeIOTypes> {
		self.constructor.remove(&node_id);
		self.inferred.remove(&node_id)
	}

	/// Returns the constructor chosen for `node_id`, if one was recorded.
	pub fn constructor(&self, node_id: NodeId) -> Option<NodeConstructor> {
		self.constructor.get(&node_id).copied()
	}

	/// Returns the inferred I/O types of `node_id`, if it has been inferred.
	pub fn type_of(&self, node_id: NodeId) -> Option<&NodeIOTypes> {
		self.inferred.get(&node_id)
	}

	/// Infers the I/O types of `node` (stored under `node_id`) and records the matching constructor.
	///
	/// Nodes referenced by `node` must already be inferred. Results are cached, so re-inferring a
	/// node returns the cached types.
	///
	/// # Errors
	/// Returns an error when any of the following holds:
	/// - a referenced node is unknown;
	/// - no implementations exist;
	/// - an input is an unresolved generic function;
	/// - no implementation matches;
	/// - several implementations match.
	///
	/// # Panics
	/// Panics if a `Value` node has an input other than `None` or `ManualComposition(Context)`.
	pub fn infer(&mut self, node_id: NodeId, node: &ProtoNode) -> Result<NodeIOTypes, GraphErrors> {
		// Cached result from an earlier inference.
		if let Some(inferred) = self.inferred.get(&node_id) {
			return Ok(inferred.clone());
		}

		// Types of the construction inputs.
		let inputs = match node.construction_args {
			// Value nodes short-circuit: `Context -> Future<value type>`, with no constructor recorded.
			ConstructionArgs::Value(ref v) => {
				assert!(matches!(node.input, ProtoNodeInput::None) || matches!(node.input, ProtoNodeInput::ManualComposition(ref x) if x == &concrete!(Context)));
				let types = NodeIOTypes::new(concrete!(Context), Type::Future(Box::new(v.ty())), vec![]);
				self.inferred.insert(node_id, types.clone());
				return Ok(types);
			}
			// Each referenced node contributes `NodeIOTypes::ty()` (presumably its full function type — confirm).
			ConstructionArgs::Nodes(ref nodes) => nodes
				.iter()
				.map(|(id, _)| {
					self.inferred
						.get(id)
						.ok_or_else(|| vec![GraphError::new(node, GraphErrorType::NodeNotFound(*id))])
						.map(|node| node.ty())
				})
				.collect::<Result<Vec<Type>, GraphErrors>>()?,
			ConstructionArgs::Inline(ref inline) => vec![inline.ty.clone()],
		};

		// Type of the call argument: `()` without input, the given type for manual composition,
		// otherwise the referenced node's return type.
		let primary_input_or_call_argument = match node.input {
			ProtoNodeInput::None => concrete!(()),
			ProtoNodeInput::ManualComposition(ref ty) => ty.clone(),
			ProtoNodeInput::Node(id) | ProtoNodeInput::NodeLambda(id) => {
				let input = self.inferred.get(&id).ok_or_else(|| vec![GraphError::new(node, GraphErrorType::InputNodeNotFound(id))])?;
				input.return_value.clone()
			}
		};
		// Shifts the input numbering used in error messages below.
		let using_manual_composition = matches!(node.input, ProtoNodeInput::ManualComposition(_) | ProtoNodeInput::None);
		let impls = self.lookup.get(&node.identifier).ok_or_else(|| vec![GraphError::new(node, GraphErrorType::NoImplementations)])?;

		// Reject construction inputs that are functions with a still-generic output.
		if let Some(index) = inputs.iter().position(|p| {
			matches!(p,
				Type::Fn(_, b) if matches!(b.as_ref(), Type::Generic(_)))
		}) {
			return Err(vec![GraphError::new(node, GraphErrorType::UnexpectedGenerics { index, inputs })]);
		}

		/// Returns whether two types are compatible: equal concrete types, compatible futures,
		/// compatible function types, or a generic on either side.
		fn valid_type(from: &Type, to: &Type) -> bool {
			match (from, to) {
				// Concrete types must match exactly.
				(Type::Concrete(type1), Type::Concrete(type2)) => type1 == type2,
				// Futures are compared by their inner types.
				(Type::Future(type1), Type::Future(type2)) => valid_type(type1, type2),
				// Function types are compared part-wise. The outputs are passed in swapped order
				// (presumably to express variance). Since every other arm is symmetric, the swap does
				// not change the result.
				(Type::Fn(in1, out1), Type::Fn(in2, out2)) => valid_type(out2, out1) && valid_type(in1, in2),
				// A generic on either side is accepted; it is bound later by `check_generic`.
				(Type::Generic(_), _) | (_, Type::Generic(_)) => true,
				// Any other combination is incompatible.
				_ => false,
			}
		}

		// Candidate signatures whose call argument and inputs are compatible. Note that `zip`
		// stops at the shorter list, so arity is not checked here.
		let valid_output_types = impls
			.keys()
			.filter(|node_io| valid_type(&node_io.call_argument, &primary_input_or_call_argument) && inputs.iter().zip(node_io.inputs.iter()).all(|(p1, p2)| valid_type(p1, p2)))
			.collect::<Vec<_>>();

		// For each candidate, try to bind its generics. On success, keep the substituted signature
		// together with the original one (the original is the registry key for the constructor).
		let substitution_results = valid_output_types
			.iter()
			.map(|node_io| {
				let generics_lookup: Result<HashMap<_, _>, _> = collect_generics(node_io)
					.iter()
					.map(|generic| check_generic(node_io, &primary_input_or_call_argument, &inputs, generic).map(|x| (generic.to_string(), x)))
					.collect();

				generics_lookup.map(|generics_lookup| {
					let orig_node_io = (*node_io).clone();
					let mut new_node_io = orig_node_io.clone();
					replace_generics(&mut new_node_io, &generics_lookup);
					(new_node_io, orig_node_io)
				})
			})
			.collect::<Vec<_>>();

		// Candidates whose generics could all be bound.
		let valid_impls = substitution_results.iter().filter_map(|result| result.as_ref().ok()).collect::<Vec<_>>();

		match valid_impls.as_slice() {
			// No match: report the mismatches of the implementation(s) with the fewest mismatching inputs.
			[] => {
				let mut best_errors = usize::MAX;
				let mut error_inputs = Vec::new();
				for node_io in impls.keys() {
					let current_errors = [&primary_input_or_call_argument]
						.into_iter()
						.chain(&inputs)
						.cloned()
						.zip([&node_io.call_argument].into_iter().chain(&node_io.inputs).cloned())
						.enumerate()
						.filter(|(_, (p1, p2))| !valid_type(p1, p2))
						.map(|(index, ty)| {
							// Translate to the input index of the original location when available.
							let i = node.original_location.inputs(index).min_by_key(|s| s.node.len()).map(|s| s.index).unwrap_or(index);
							let i = if using_manual_composition { i } else { i + 1 };
							(i, ty)
						})
						.collect::<Vec<_>>();
					if current_errors.len() < best_errors {
						best_errors = current_errors.len();
						error_inputs.clear();
					}
					if current_errors.len() <= best_errors {
						error_inputs.push(current_errors);
					}
				}
				// List the received input types. The call argument (index 0) is omitted for manually
				// composed nodes.
				let inputs = [&primary_input_or_call_argument]
					.into_iter()
					.chain(&inputs)
					.enumerate()
					.filter_map(|(i, t)| {
						let i = if using_manual_composition { i } else { i + 1 };
						if i == 0 { None } else { Some(format!("• Input {i}: {t}")) }
					})
					.collect::<Vec<_>>()
					.join("\n");
				Err(vec![GraphError::new(node, GraphErrorType::InvalidImplementations { inputs, error_inputs })])
			}
			// Exactly one match: record the substituted types and the constructor registered under
			// the original signature.
			[(node_io, org_nio)] => {
				let node_io = node_io.clone();

				self.inferred.insert(node_id, node_io.clone());
				self.constructor.insert(node_id, impls[org_nio]);
				Ok(node_io)
			}
			// Two matches that differ in call argument: prefer the one whose call argument is `()`.
			[first, second] => {
				if first.0.call_argument != second.0.call_argument {
					for (node_io, orig_nio) in [first, second] {
						if node_io.call_argument != concrete!(()) {
							continue;
						}

						self.inferred.insert(node_id, node_io.clone());
						self.constructor.insert(node_id, impls[orig_nio]);
						return Ok(node_io.clone());
					}
				}
				let inputs = [&primary_input_or_call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
				let valid = valid_output_types.into_iter().cloned().collect();
				Err(vec![GraphError::new(node, GraphErrorType::MultipleImplementations { inputs, valid })])
			}
			// Three or more matches are always ambiguous.
			_ => {
				let inputs = [&primary_input_or_call_argument].into_iter().chain(&inputs).map(|t| t.to_string()).collect::<Vec<_>>().join(", ");
				let valid = valid_output_types.into_iter().cloned().collect();
				Err(vec![GraphError::new(node, GraphErrorType::MultipleImplementations { inputs, valid })])
			}
		}
	}
}
|
|
| |
| fn collect_generics(types: &NodeIOTypes) -> Vec<Cow<'static, str>> { |
| let inputs = [&types.call_argument].into_iter().chain(types.inputs.iter().map(|x| x.nested_type())); |
| let mut generics = inputs |
| .filter_map(|t| match t { |
| Type::Generic(out) => Some(out.clone()), |
| _ => None, |
| }) |
| .collect::<Vec<_>>(); |
| if let Type::Generic(out) = &types.return_value { |
| generics.push(out.clone()); |
| } |
| generics.dedup(); |
| generics |
| } |
|
|
| |
| fn check_generic(types: &NodeIOTypes, input: &Type, parameters: &[Type], generic: &str) -> Result<Type, String> { |
| let inputs = [(Some(&types.call_argument), Some(input))] |
| .into_iter() |
| .chain(types.inputs.iter().map(|x| x.fn_input()).zip(parameters.iter().map(|x| x.fn_input()))) |
| .chain(types.inputs.iter().map(|x| x.fn_output()).zip(parameters.iter().map(|x| x.fn_output()))); |
| let concrete_inputs = inputs.filter(|(ni, _)| matches!(ni, Some(Type::Generic(input)) if generic == input)); |
| let mut outputs = concrete_inputs.flat_map(|(_, out)| out); |
| let out_ty = outputs |
| .next() |
| .ok_or_else(|| format!("Generic output type {generic} is not dependent on input {input:?} or parameters {parameters:?}",))?; |
| if outputs.any(|ty| ty != out_ty) { |
| return Err(format!("Generic output type {generic} is dependent on multiple inputs or parameters",)); |
| } |
| Ok(out_ty.clone()) |
| } |
|
|
| |
| fn replace_generics(types: &mut NodeIOTypes, lookup: &HashMap<String, Type>) { |
| let replace = |ty: &Type| { |
| let Type::Generic(ident) = ty else { |
| return None; |
| }; |
| lookup.get(ident.as_ref()).cloned() |
| }; |
| types.call_argument.replace_nested(replace); |
| types.return_value.replace_nested(replace); |
| for input in &mut types.inputs { |
| input.replace_nested(replace); |
| } |
| } |
|
|
#[cfg(test)]
mod test {
	use super::*;
	use crate::proto::{ConstructionArgs, ProtoNetwork, ProtoNode, ProtoNodeInput};

	/// Maps the index-based ids returned by `topological_sort` back to the network's own node ids.
	fn resolve_sorted_ids(network: &ProtoNetwork, sorted: &[NodeId]) -> Vec<NodeId> {
		sorted.iter().map(|index| network.nodes[index.0 as usize].0).collect()
	}

	/// The ids of the network's nodes, in storage order.
	fn node_ids(network: &ProtoNetwork) -> Vec<NodeId> {
		network.nodes.iter().map(|(id, _)| *id).collect()
	}

	#[test]
	fn topological_sort() {
		let construction_network = test_network();
		let (sorted, _) = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
		let sorted = resolve_sorted_ids(&construction_network, &sorted);
		println!("{sorted:#?}");
		assert_eq!(sorted, vec![NodeId(14), NodeId(10), NodeId(11), NodeId(1)]);
	}

	#[test]
	fn topological_sort_with_cycles() {
		// The two nodes feed each other, so sorting must fail.
		let result = test_network_with_cycles().topological_sort();
		assert!(result.is_err())
	}

	#[test]
	fn id_reordering() {
		let mut construction_network = test_network();
		construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
		let (sorted, _) = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
		let sorted = resolve_sorted_ids(&construction_network, &sorted);
		println!("nodes: {:#?}", construction_network.nodes);
		assert_eq!(sorted, vec![NodeId(0), NodeId(1), NodeId(2), NodeId(3)]);
		let ids = node_ids(&construction_network);
		println!("{ids:#?}");
		println!("nodes: {:#?}", construction_network.nodes);
		assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
		assert_eq!(ids, vec![NodeId(0), NodeId(1), NodeId(2), NodeId(3)]);
	}

	#[test]
	fn id_reordering_idempotent() {
		let mut construction_network = test_network();
		for _ in 0..2 {
			construction_network.reorder_ids().expect("Error when calling 'reorder_ids' on 'construction_network.");
		}
		let (sorted, _) = construction_network.topological_sort().expect("Error when calling 'topological_sort' on 'construction_network.");
		assert_eq!(sorted, vec![NodeId(0), NodeId(1), NodeId(2), NodeId(3)]);
		let ids = node_ids(&construction_network);
		println!("{ids:#?}");
		assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
		assert_eq!(ids, vec![NodeId(0), NodeId(1), NodeId(2), NodeId(3)]);
	}

	#[test]
	fn input_resolution() {
		let mut construction_network = test_network();
		construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network.");
		println!("{construction_network:#?}");
		assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
		assert_eq!(construction_network.nodes.len(), 6);
		assert_eq!(construction_network.nodes[5].1.construction_args, ConstructionArgs::Nodes(vec![(NodeId(3), false), (NodeId(4), true)]));
	}

	#[test]
	fn stable_node_id_generation() {
		let mut construction_network = test_network();
		construction_network.resolve_inputs().expect("Error when calling 'resolve_inputs' on 'construction_network.");
		construction_network.generate_stable_node_ids();
		assert_eq!(construction_network.nodes[0].1.identifier.name.as_ref(), "value");
		let expected = vec![
			NodeId(16997244687192517417),
			NodeId(12226224850522777131),
			NodeId(9162113827627229771),
			NodeId(12793582657066318419),
			NodeId(16945623684036608820),
			NodeId(2640415155091892458),
		];
		assert_eq!(node_ids(&construction_network), expected);
	}

	/// Builds a test node; every field not given here comes from `ProtoNode::default()`.
	fn proto_node(identifier: &'static str, input: ProtoNodeInput, construction_args: ConstructionArgs) -> ProtoNode {
		ProtoNode {
			identifier: identifier.into(),
			input,
			construction_args,
			..Default::default()
		}
	}

	/// A small acyclic network: value(14) → cons(10) → add(11) → id(1, output).
	/// id(7) also reads add(11) but is unreachable from the output.
	fn test_network() -> ProtoNetwork {
		let nodes = vec![
			(NodeId(7), proto_node("id", ProtoNodeInput::Node(NodeId(11)), ConstructionArgs::Nodes(vec![]))),
			(NodeId(1), proto_node("id", ProtoNodeInput::Node(NodeId(11)), ConstructionArgs::Nodes(vec![]))),
			(
				NodeId(10),
				proto_node("cons", ProtoNodeInput::ManualComposition(concrete!(u32)), ConstructionArgs::Nodes(vec![(NodeId(14), false)])),
			),
			(NodeId(11), proto_node("add", ProtoNodeInput::Node(NodeId(10)), ConstructionArgs::Nodes(vec![]))),
			(NodeId(14), proto_node("value", ProtoNodeInput::None, ConstructionArgs::Value(value::TaggedValue::U32(2).into()))),
		];
		ProtoNetwork {
			inputs: vec![NodeId(10)],
			output: NodeId(1),
			nodes,
		}
	}

	/// Two nodes whose inputs reference each other.
	fn test_network_with_cycles() -> ProtoNetwork {
		let nodes = vec![
			(NodeId(1), proto_node("id", ProtoNodeInput::Node(NodeId(2)), ConstructionArgs::Nodes(vec![]))),
			(NodeId(2), proto_node("id", ProtoNodeInput::Node(NodeId(1)), ConstructionArgs::Nodes(vec![]))),
		];
		ProtoNetwork {
			inputs: vec![NodeId(1)],
			output: NodeId(1),
			nodes,
		}
	}
}
|
|