| use crate::ast::{Literal, Node}; |
| use crate::constants::DEFAULT_FUNCTIONS; |
| use crate::context::{EvalContext, FunctionProvider, ValueProvider}; |
| use crate::value::{Number, Value}; |
| use thiserror::Error; |
|
|
| #[derive(Debug, Error)] |
| pub enum EvalError { |
| #[error("Missing value: {0}")] |
| MissingValue(String), |
|
|
| #[error("Missing function: {0}")] |
| MissingFunction(String), |
| #[error("Wrong type for function call")] |
| TypeError, |
| } |
|
|
| impl Node { |
| pub fn eval<V: ValueProvider, F: FunctionProvider>(&self, context: &EvalContext<V, F>) -> Result<Value, EvalError> { |
| match self { |
| Node::Lit(lit) => match lit { |
| Literal::Float(num) => Ok(Value::from_f64(*num)), |
| Literal::Complex(num) => Ok(Value::Number(Number::Complex(*num))), |
| }, |
|
|
| Node::BinOp { lhs, op, rhs } => match (lhs.eval(context)?, rhs.eval(context)?) { |
| (Value::Number(lhs), Value::Number(rhs)) => Ok(Value::Number(lhs.binary_op(*op, rhs))), |
| }, |
| Node::UnaryOp { expr, op } => match expr.eval(context)? { |
| Value::Number(num) => Ok(Value::Number(num.unary_op(*op))), |
| }, |
| Node::Var(name) => context.get_value(name).ok_or_else(|| EvalError::MissingValue(name.clone())), |
| Node::FnCall { name, expr } => { |
| let values = expr.iter().map(|expr| expr.eval(context)).collect::<Result<Vec<Value>, EvalError>>()?; |
| if let Some(function) = DEFAULT_FUNCTIONS.get(&name.as_str()) { |
| function(&values).ok_or(EvalError::TypeError) |
| } else if let Some(val) = context.run_function(name, &values) { |
| Ok(val) |
| } else { |
| context.get_value(name).ok_or_else(|| EvalError::MissingFunction(name.to_string())) |
| } |
| } |
| } |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use crate::ast::{BinaryOp, Literal, Node, UnaryOp};
    use crate::context::{EvalContext, ValueMap};
    use crate::value::Value;

    /// Boxes a float literal so it can serve as an operand of an operator node.
    fn float(x: f64) -> Box<Node> {
        Box::new(Node::Lit(Literal::Float(x)))
    }

    /// Builds the node `lhs <op> rhs` from two float literals.
    fn binary(lhs: f64, op: BinaryOp, rhs: f64) -> Node {
        Node::BinOp {
            lhs: float(lhs),
            op,
            rhs: float(rhs),
        }
    }

    /// Builds the node `<op> operand` from a float literal.
    fn unary(op: UnaryOp, operand: f64) -> Node {
        Node::UnaryOp {
            expr: float(operand),
            op,
        }
    }

    /// Declares one `#[test]` per `name: expected => node` entry. Each test
    /// evaluates `node` against a default context and asserts it equals
    /// `expected`.
    macro_rules! eval_tests {
        ($($test_name:ident: $want:expr_2021 => $node:expr_2021),* $(,)?) => {
            $(
                #[test]
                fn $test_name() {
                    let got = $node.eval(&EvalContext::default()).unwrap();
                    assert_eq!(got, $want);
                }
            )*
        };
    }

    eval_tests! {
        test_addition: Value::from_f64(7.0) => binary(3.0, BinaryOp::Add, 4.0),
        test_subtraction: Value::from_f64(1.0) => binary(5.0, BinaryOp::Sub, 4.0),
        test_multiplication: Value::from_f64(12.0) => binary(3.0, BinaryOp::Mul, 4.0),
        test_division: Value::from_f64(2.5) => binary(5.0, BinaryOp::Div, 2.0),
        test_negation: Value::from_f64(-3.0) => unary(UnaryOp::Neg, 3.0),
        test_sqrt: Value::from_f64(2.0) => unary(UnaryOp::Sqrt, 4.0),
        test_power: Value::from_f64(8.0) => binary(2.0, BinaryOp::Pow, 3.0),
    }
}
|
|