| |
|
| | from __future__ import annotations
|
| | from dataclasses import dataclass, field
|
| | from typing import Any, Dict, List, Optional, Tuple, Set
|
| | from datetime import date, datetime
|
| | import math
|
| | import yaml
|
| | import ast
|
| |
|
| |
|
class SafeEvalError(Exception):
    """Raised when a rule expression cannot be safely parsed or evaluated."""
|
| |
|
class SafeExpr:
    """Very small evaluator for rule expressions over a dict of variables.

    Supports ``+ - * / // % **``, unary ``+``/``-``, parentheses, constants,
    variable names, comparisons (including chains such as ``a < b < c``),
    ``and``/``or``, conditional expressions (``a if c else b``), list/tuple/
    dict literals, subscripting (``d["k"]``, ``xs[0]``), and calls to ``min``,
    ``max``, ``abs`` and ``round`` with at most 2 positional args.

    Any other syntax (attribute access, keyword arguments, lambdas, slices,
    ...) is rejected with ``SafeEvalError`` before evaluation starts.
    """

    ALLOWED_FUNCS = {"min": min, "max": max, "abs": abs, "round": round}

    # Python < 3.9 wraps subscript indices in ``ast.Index``. Newer versions
    # store the index expression directly in ``Subscript.slice`` and keep
    # ``ast.Index`` only as a deprecated shim. Look it up defensively so the
    # class still loads if the shim disappears.
    _LEGACY_INDEX = getattr(ast, "Index", None)

    # Node types permitted anywhere in a parsed expression. ``ast.Num`` is
    # deliberately absent: ``ast.parse`` has produced ``ast.Constant`` since
    # Python 3.8, and ``ast.Num`` is deprecated/removed in newer versions.
    ALLOWED_NODES = (
        ast.Expression, ast.BinOp, ast.UnaryOp, ast.Name, ast.Constant,
        ast.Load, ast.Add, ast.Sub, ast.Mult, ast.Div, ast.FloorDiv, ast.Mod, ast.Pow,
        ast.USub, ast.UAdd, ast.Call, ast.Tuple, ast.Compare,
        ast.Lt, ast.Gt, ast.LtE, ast.GtE, ast.Eq, ast.NotEq,
        ast.BoolOp, ast.And, ast.Or,
        ast.IfExp, ast.Subscript, ast.Dict, ast.List,
    ) + ((_LEGACY_INDEX,) if _LEGACY_INDEX is not None else ())

    @classmethod
    def eval(cls, expr: str, variables: Dict[str, Any]) -> Any:
        """Evaluate ``expr`` against ``variables`` and return the result.

        Bare ``int``/``float``/``bool`` values (for example an unquoted number
        in a YAML rule file) are returned unchanged instead of failing to
        parse.

        Raises:
            SafeEvalError: on parse errors, disallowed syntax, unknown
                variables, disallowed functions or too many call arguments.
                Errors raised by the operations themselves (e.g.
                ``ZeroDivisionError``, ``KeyError`` on a subscript) propagate
                unchanged.
        """
        if isinstance(expr, (int, float)):
            return expr
        try:
            tree = ast.parse(expr, mode="eval")
        except Exception as e:
            raise SafeEvalError(f"Parse error: {e}") from e
        if not all(isinstance(n, cls.ALLOWED_NODES) for n in ast.walk(tree)):
            raise SafeEvalError("Disallowed syntax in expression")
        return cls._eval_node(tree.body, variables)

    @classmethod
    def _eval_node(cls, node, vars):
        """Recursively evaluate one node of an expression validated by ``eval``."""
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.Name):
            try:
                return vars[node.id]
            except KeyError:
                raise SafeEvalError(f"Unknown variable '{node.id}'") from None
        if isinstance(node, ast.UnaryOp):
            val = cls._eval_node(node.operand, vars)
            if isinstance(node.op, ast.UAdd):
                return +val
            if isinstance(node.op, ast.USub):
                return -val
            raise SafeEvalError("Unsupported unary op")
        if isinstance(node, ast.BinOp):
            l = cls._eval_node(node.left, vars)
            r = cls._eval_node(node.right, vars)
            if isinstance(node.op, ast.Add): return l + r
            if isinstance(node.op, ast.Sub): return l - r
            if isinstance(node.op, ast.Mult): return l * r
            if isinstance(node.op, ast.Div): return l / r
            if isinstance(node.op, ast.FloorDiv): return l // r
            if isinstance(node.op, ast.Mod): return l % r
            # NOTE(review): '**' is unbounded; if rule files can come from an
            # untrusted source, consider capping the exponent.
            if isinstance(node.op, ast.Pow): return l ** r
            raise SafeEvalError("Unsupported binary op")
        if isinstance(node, ast.Compare):
            # Short-circuit like Python: stop at the first failing link so
            # later comparators are not evaluated.
            cur = cls._eval_node(node.left, vars)
            for op, comparator in zip(node.ops, node.comparators):
                right = cls._eval_node(comparator, vars)
                if isinstance(op, ast.Lt): ok = cur < right
                elif isinstance(op, ast.Gt): ok = cur > right
                elif isinstance(op, ast.LtE): ok = cur <= right
                elif isinstance(op, ast.GtE): ok = cur >= right
                elif isinstance(op, ast.Eq): ok = cur == right
                elif isinstance(op, ast.NotEq): ok = cur != right
                else: raise SafeEvalError("Unsupported comparator")
                if not ok:
                    return False
                cur = right
            return True
        if isinstance(node, ast.BoolOp):
            # Short-circuit so guards like ``x != 0 and y / x > 1`` never
            # evaluate the right operand when the left already decides.
            # The result is always a bool, not the deciding operand.
            if isinstance(node.op, ast.And):
                return all(bool(cls._eval_node(v, vars)) for v in node.values)
            if isinstance(node.op, ast.Or):
                return any(bool(cls._eval_node(v, vars)) for v in node.values)
            raise SafeEvalError("Unsupported bool op")
        if isinstance(node, ast.IfExp):
            cond = cls._eval_node(node.test, vars)
            return cls._eval_node(node.body if cond else node.orelse, vars)
        if isinstance(node, ast.Call):
            if not isinstance(node.func, ast.Name):
                raise SafeEvalError("Only simple function calls allowed")
            fname = node.func.id
            if fname not in cls.ALLOWED_FUNCS:
                raise SafeEvalError(f"Function '{fname}' not allowed")
            args = [cls._eval_node(a, vars) for a in node.args]
            if len(args) > 2:
                raise SafeEvalError("Too many args")
            return cls.ALLOWED_FUNCS[fname](*args)
        if isinstance(node, (ast.List, ast.Tuple)):
            return [cls._eval_node(e, vars) for e in node.elts]
        if isinstance(node, ast.Dict):
            return {cls._eval_node(k, vars): cls._eval_node(v, vars) for k, v in zip(node.keys, node.values)}
        if isinstance(node, ast.Subscript):
            container = cls._eval_node(node.value, vars)
            index_node = node.slice
            # Unwrap only the legacy ast.Index wrapper. Do not probe for a
            # ``.value`` attribute: on 3.9+ ast.Constant and ast.Subscript
            # index nodes have one too, and unwrapping them is wrong.
            if cls._LEGACY_INDEX is not None and isinstance(index_node, cls._LEGACY_INDEX):
                index_node = index_node.value
            return container[cls._eval_node(index_node, vars)]
        raise SafeEvalError(f"Unsupported node: {type(node).__name__}")
|
| |
|
| |
|
@dataclass
class AuthorityRef:
    """Citation of the source document that backs a rule.

    Only ``doc`` is required; every finer-grained locator defaults to None.
    """

    # Required: the cited document.
    doc: str
    # Optional locators inside the document, from coarse to fine.
    section: Optional[str] = field(default=None)
    subsection: Optional[str] = field(default=None)
    page: Optional[str] = field(default=None)
    url_anchor: Optional[str] = field(default=None)
|
| |
|
@dataclass
class RuleAtom:
    """A single, self-contained tax rule as loaded from the rule catalog.

    ``formula_type`` selects how ``parameters`` are interpreted by the engine,
    ``output`` names the value the rule produces, and the rule only applies on
    dates inside the inclusive window ``[effective_from, effective_to]``.
    """

    id: str
    title: str
    description: str
    tax_type: str
    jurisdiction_level: str
    formula_type: str
    inputs: List[str]
    output: str
    # Formula-specific settings (expressions, rates, bands, guard, rounding flag).
    parameters: Dict[str, Any] = field(default_factory=dict)
    # e.g. {"applies_after": [<rule id>, ...]}; used when ordering rules.
    ordering_constraints: Dict[str, List[str]] = field(default_factory=dict)
    # Inclusive effective window. Each side may be a "YYYY-MM-DD" string,
    # a date/datetime object, or None for an open-ended side.
    effective_from: str = "1900-01-01"
    effective_to: Optional[str] = None
    authority: List[AuthorityRef] = field(default_factory=list)
    notes: Optional[str] = None
    # Rules whose status is "deprecated" are skipped by catalog selection.
    status: str = "approved"

    @staticmethod
    def _as_date(value: Any, default: date) -> date:
        """Normalise an effective-date value to a plain ``date``.

        ``None`` maps to ``default``. A ``datetime`` (which PyYAML's
        ``safe_load`` yields for timestamps with a time part) is truncated to
        its date, so it can be compared with plain dates. ``YYYY-MM-DD``
        strings are parsed.

        Raises:
            ValueError: if a string is not in ``YYYY-MM-DD`` form.
            TypeError: for any other value type.
        """
        if value is None:
            return default
        # Check datetime before date: datetime is a subclass of date.
        if isinstance(value, datetime):
            return value.date()
        if isinstance(value, date):
            return value
        if isinstance(value, str):
            return datetime.strptime(value, "%Y-%m-%d").date()
        raise TypeError(f"Unsupported effective date value: {value!r}")

    def is_active_on(self, on_date: date) -> bool:
        """Return True if ``on_date`` lies inside the effective window.

        Both bounds are inclusive; a missing bound leaves that side open.
        ``on_date`` may also be a ``datetime``, in which case only its date
        part is compared.
        """
        start = self._as_date(self.effective_from, date.min)
        end = self._as_date(self.effective_to, date.max)
        day = on_date.date() if isinstance(on_date, datetime) else on_date
        return start <= day <= end
|
| |
|
| |
|
class RuleCatalog:
    """In-memory collection of rule atoms with simple filtering."""

    def __init__(self, atoms: List[RuleAtom]):
        self.atoms = atoms
        # Id -> atom lookup; when ids repeat, the last atom in ``atoms`` wins.
        self._by_id = {a.id: a for a in atoms}

    @classmethod
    def from_yaml_files(cls, paths: List[str]) -> "RuleCatalog":
        """Build a catalog from YAML rule files.

        Each file may contain a single rule mapping or a list of rule
        mappings; empty files are skipped. A rule's ``authority`` entries
        (absent or null means none) are turned into ``AuthorityRef`` objects.

        Raises:
            TypeError: if a file's top level is neither a mapping nor a list,
                if a list entry is not a mapping, or if a rule has keys the
                ``RuleAtom``/``AuthorityRef`` constructors do not accept.
        """
        atoms: List[RuleAtom] = []
        for p in paths:
            with open(p, "r", encoding="utf-8") as f:
                data = yaml.safe_load(f)
            if data is None:
                # Empty file (or only comments): nothing to load.
                continue
            if isinstance(data, dict):
                data = [data]
            if not isinstance(data, list):
                raise TypeError(
                    f"{p}: expected a rule mapping or a list of rules, "
                    f"got {type(data).__name__}"
                )
            for item in data:
                if not isinstance(item, dict):
                    raise TypeError(f"{p}: each rule must be a mapping, got {type(item).__name__}")
                # ``or []`` also covers an explicit ``authority: null``.
                auth = [AuthorityRef(**r) for r in item.get("authority") or []]
                atoms.append(RuleAtom(**{**item, "authority": auth}))
        return cls(atoms)

    def select(self, *, tax_type: str, on_date: date, jurisdiction: Optional[str] = None) -> List[RuleAtom]:
        """Return the atoms applicable to ``tax_type`` on ``on_date``.

        Filters, in order: matching ``tax_type``; matching
        ``jurisdiction_level`` when ``jurisdiction`` is given (a falsy value
        disables this filter); active on ``on_date``; status not
        "deprecated". Catalog order is preserved.
        """
        out = []
        for a in self.atoms:
            if a.tax_type != tax_type:
                continue
            if jurisdiction and a.jurisdiction_level != jurisdiction:
                continue
            if not a.is_active_on(on_date):
                continue
            if a.status == "deprecated":
                continue
            out.append(a)
        return out
|
| |
|
class CalculationResult:
    """Scratchpad for one engine run: named numeric values plus audit lines."""

    def __init__(self):
        # Named values; set_value stores everything as float.
        self.values: Dict[str, float] = dict()
        # Audit-trail entries, appended to by the caller.
        self.lines: List[Dict[str, Any]] = list()

    def set_value(self, key: str, val: float):
        """Store ``val`` under ``key`` as a float, replacing any earlier value."""
        coerced = float(val)
        self.values[key] = coerced

    def get(self, key: str, default: float = 0.0) -> float:
        """Return the value stored under ``key`` (else ``default``) as a float."""
        found = self.values[key] if key in self.values else default
        return float(found)
|
| |
|
class TaxEngine:
    """Evaluate the active rules of a rule catalog against a set of inputs.

    Rules are ordered by their ``applies_after`` constraints, optionally gated
    by an ``applicability_expr`` guard, and each writes its result into a
    shared CalculationResult so later rules can reference earlier outputs.
    """

    def __init__(self, catalog: RuleCatalog, rounding_mode: str = "half_up"):
        self.catalog = catalog
        # "half_up" rounds ties away from zero; any other value uses round().
        self.rounding_mode = rounding_mode

    def _toposort(self, rules: List[RuleAtom]) -> List[RuleAtom]:
        """Order ``rules`` so each runs after the rules in its ``applies_after``.

        Ids listed in ``applies_after`` that are not among ``rules`` are
        ignored. Rules that become ready at the same time keep their input
        order (Kahn's algorithm with a FIFO queue).

        Raises:
            ValueError: on duplicate rule ids or a dependency cycle.
        """
        from collections import deque

        id_map: Dict[str, RuleAtom] = {}
        for r in rules:
            if r.id in id_map:
                raise ValueError(f"Duplicate rule id in active rule set: {r.id!r}")
            id_map[r.id] = r

        # dependents[d]: rules that must wait for d (in input order).
        # indeg[r]: number of r's prerequisites not yet emitted.
        dependents: Dict[str, List[str]] = {rid: [] for rid in id_map}
        indeg: Dict[str, int] = {}
        for r in rules:
            after = r.ordering_constraints.get("applies_after") or []
            if isinstance(after, str):
                after = [after]
            deps = {d for d in after if d in id_map}
            indeg[r.id] = len(deps)
            for d in deps:
                dependents[d].append(r.id)

        queue = deque(rid for rid, deg in indeg.items() if deg == 0)
        ordered: List[RuleAtom] = []
        while queue:
            rid = queue.popleft()
            ordered.append(id_map[rid])
            for nid in dependents[rid]:
                indeg[nid] -= 1
                if indeg[nid] == 0:
                    queue.append(nid)

        if len(ordered) != len(rules):
            blocked = sorted(rid for rid, deg in indeg.items() if deg > 0)
            raise ValueError(f"Dependency cycle in applies_after among rules: {blocked}")
        return ordered

    def _round(self, x: float) -> float:
        """Round ``x`` to a whole number and return it as a float.

        ``"half_up"`` rounds ties away from zero (2.5 -> 3, -2.5 -> -3). Any
        other mode uses the built-in ``round`` (ties to even).
        """
        if self.rounding_mode == "half_up":
            # Split off the fractional part exactly instead of computing
            # ``x + 0.5``: that addition can itself round in binary (e.g.
            # 0.49999999999999994 + 0.5 == 1.0) and push values over the tie.
            magnitude = abs(x)
            whole = math.floor(magnitude)
            if magnitude - whole >= 0.5:
                whole += 1
            return float(whole) if x >= 0 else -float(whole)
        return float(round(x))

    def _evaluate_rule(self, r: RuleAtom, ctx: CalculationResult) -> Tuple[str, float, Dict[str, Any]]:
        """Compute one rule's amount from the values accumulated so far.

        Returns:
            ``(output_key, amount, details)``, where ``details`` records the
            intermediate values used (for the audit trail).

        Raises:
            ValueError: for an unknown ``formula_type``.
            SafeEvalError: when an expression cannot be evaluated.
        """
        v = ctx.values

        def ex(expr: Any) -> float:
            # Bare numbers (e.g. unquoted YAML values) are used as-is.
            if isinstance(expr, (int, float)):
                return float(expr)
            return float(SafeExpr.eval(expr, v))

        details: Dict[str, Any] = {}

        if r.formula_type == "fixed_amount":
            amt = ex(r.parameters.get("amount_expr", "0"))
        elif r.formula_type == "rate_on_base":
            base = ex(r.parameters.get("base_expr", "0"))
            rate = float(r.parameters.get("rate", 0))
            amt = base * rate
            details.update({"base": base, "rate": rate})
        elif r.formula_type == "capped_percentage":
            base = ex(r.parameters.get("base_expr", "0"))
            cap_rate = float(r.parameters.get("cap_rate", 0))
            # NOTE(review): with only one base this is base * cap_rate whenever
            # 0 <= cap_rate <= 1 and base >= 0; confirm whether a separate
            # "claimed amount" expression was intended to be capped here.
            amt = min(base, base * cap_rate)
            details.update({"base": base, "cap_rate": cap_rate})
        elif r.formula_type == "max_of_plus":
            base_opts = [ex(opt.get("expr", "0")) for opt in r.parameters.get("base_options", [])]
            plus_expr = r.parameters.get("plus_expr", "0")
            plus = ex(plus_expr) if plus_expr else 0.0
            amt = max(base_opts) + plus if base_opts else plus
            details.update({"base_options": base_opts, "plus": plus})
        elif r.formula_type == "piecewise_bands":
            taxable = ex(r.parameters.get("base_expr", "0"))
            bands = r.parameters.get("bands", [])
            remaining = taxable
            tax = 0.0
            calc_steps = []
            prev_upper = 0.0
            for b in bands:
                # A band without "up_to" (or with null) is unbounded above.
                upper = float("inf") if b.get("up_to") is None else float(b["up_to"])
                rate = float(b["rate"])
                # Portion of the base that falls inside (prev_upper, upper].
                chunk = max(0.0, min(remaining, upper - prev_upper))
                if chunk > 0:
                    part = chunk * rate
                    tax += part
                    calc_steps.append({"range": [prev_upper, upper], "chunk": chunk, "rate": rate, "tax": part})
                remaining -= chunk
                prev_upper = upper
                if remaining <= 0:
                    break
            amt = tax
            details.update({"base": taxable, "bands_applied": calc_steps})
        elif r.formula_type == "conditional_min":
            computed = ex(r.parameters.get("computed_expr", "computed_tax"))
            min_amount = ex(r.parameters.get("min_amount_expr", "0"))
            amt = max(computed, min_amount)
            details.update({"computed": computed, "minimum": min_amount})
        else:
            raise ValueError(f"Unknown formula_type: {r.formula_type}")

        amt = self._round(amt) if r.parameters.get("round", False) else amt
        return r.output, amt, details

    def run(
        self,
        *,
        tax_type: str,
        as_of: date,
        jurisdiction: Optional[str],
        inputs: Dict[str, float],
        rule_ids_whitelist: Optional[List[str]] = None
    ) -> CalculationResult:
        """Run every applicable rule and return the populated result.

        Args:
            tax_type: Only rules with this ``tax_type`` are considered.
            as_of: Date used to pick rules whose effective window contains it.
            jurisdiction: If truthy, restricts rules by ``jurisdiction_level``.
            inputs: Initial named values (coerced to float) visible to rules.
            rule_ids_whitelist: If non-empty, only these rule ids run. An
                empty list behaves like None (no filtering).

        Raises:
            SafeEvalError: if a rule's ``applicability_expr`` fails, or an
                expression inside a rule cannot be evaluated.
            ValueError: on ordering problems or an unknown formula type.
        """
        active = self.catalog.select(tax_type=tax_type, on_date=as_of, jurisdiction=jurisdiction)
        if rule_ids_whitelist:
            idset = set(rule_ids_whitelist)
            active = [r for r in active if r.id in idset]

        ordered = self._toposort(active)
        ctx = CalculationResult()

        for k, v in inputs.items():
            ctx.set_value(k, float(v))

        for r in ordered:
            guard = r.parameters.get("applicability_expr")
            if guard:
                try:
                    applies = bool(SafeExpr.eval(guard, ctx.values))
                except Exception as e:
                    raise SafeEvalError(f"Guard error in {r.id}: {e}") from e
                if not applies:
                    continue

            out_key, amount, details = self._evaluate_rule(r, ctx)
            ctx.set_value(out_key, amount)
            ctx.lines.append({
                "rule_id": r.id,
                "title": r.title,
                "amount": amount,
                "output": out_key,
                "details": details,
                # Copy each authority's fields so callers editing the result
                # lines cannot mutate the catalog's AuthorityRef objects.
                "authority": [dict(a.__dict__) for a in r.authority],
            })
        return ctx
|
| |
|