""" server/causal_world.py -- The hidden causal world. Every episode a new CausalGraph is generated. The agent never sees this object -- it can only probe it via experiments. Design: - Variables are nodes; directed edges carry one of 8+ rule types. - Multi-parent interaction rules make some effects depend on >1 cause. - Hidden confounders can inject correlated noise across variables. - Gaussian noise is added to every observation. - The graph is a DAG (no cycles) so causality is well-defined. - Domains: system_alpha | system_beta | system_gamma | system_delta Each domain provides a different narrative prompt but uses abstract variable names (Greek letters, V1/V2/V3...) to prevent LLM agents from leveraging pretrained real-world knowledge instead of reasoning from experimental evidence. """ from __future__ import annotations import math import random from dataclasses import dataclass, field from typing import Any, Optional import numpy as np import networkx as nx ABSTRACT_VAR_POOLS: list[list[str]] = [ ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"], ["Zeta", "Eta", "Theta", "Iota", "Kappa"], ["V1", "V2", "V3", "V4", "V5"], ["Rho", "Sigma", "Tau", "Upsilon", "Phi"], ["Mu", "Nu", "Xi", "Omicron", "Pi"], ["Quant_A", "Quant_B", "Quant_C", "Quant_D", "Quant_E"], ] DOMAIN_LABELS: dict[str, dict] = { "system_alpha": { "context": "You are studying an unknown dynamical system. Variables have hidden causal relationships you must discover through experiments.", "unit": "units", }, "system_beta": { "context": "You are investigating a black-box system with interacting quantities. Design experiments to uncover the governing equations.", "unit": "units", }, "system_gamma": { "context": "You are analysing an opaque process with measurable outputs. Run controlled experiments to determine how variables influence each other.", "unit": "units", }, "system_delta": { "context": "You are probing a simulated environment with coupled variables. The underlying rules are unknown -- discover them.", "unit": "units", }, } DOMAINS = list(DOMAIN_LABELS.keys()) RULE_TYPES = [ "linear", "threshold", "inverse", "quadratic", "exponential", "logarithmic", "saturating", "piecewise_linear", ] @dataclass class CausalRule: """A single edge rule in the causal graph.""" cause: str effect: str rule_type: str params: dict[str, float] = field(default_factory=dict) description: str = "" def evaluate(self, x: float) -> float: if self.rule_type == "linear": a = self.params.get("a", 1.0) b = self.params.get("b", 0.0) return a * x + b elif self.rule_type == "threshold": threshold = self.params.get("threshold", 5.0) high = self.params.get("high", 10.0) low = self.params.get("low", 2.0) return high if x > threshold else low elif self.rule_type == "inverse": a = self.params.get("a", 10.0) if abs(x) < 1e-9: return float("nan") return a / x elif self.rule_type == "quadratic": a = self.params.get("a", 0.5) b = self.params.get("b", 0.0) c = self.params.get("c", 0.0) return a * x * x + b * x + c elif self.rule_type == "exponential": a = self.params.get("a", 1.0) k = self.params.get("k", 0.3) x_clamped = max(-20.0, min(20.0, k * x)) return a * math.exp(x_clamped) elif self.rule_type == "logarithmic": a = self.params.get("a", 3.0) b = self.params.get("b", 0.0) if x <= 0: return float("nan") return a * math.log(x) + b elif self.rule_type == "saturating": v_max = self.params.get("v_max", 10.0) k_m = self.params.get("k_m", 3.0) if x < 0: return 0.0 return v_max * x / (k_m + x) elif self.rule_type == "piecewise_linear": knot = self.params.get("knot", 5.0) a1 = self.params.get("a1", 1.0) a2 = self.params.get("a2", -0.5) b = self.params.get("b", 0.0) if x <= knot: return a1 * x + b else: y_knot = a1 * knot + b return y_knot + a2 * (x - knot) return 0.0 @dataclass class InteractionRule: """ A multi-parent rule: effect = f(cause1, cause2). These cannot be discovered by varying one variable at a time -- the agent must realise two parents jointly determine the effect. """ cause1: str cause2: str effect: str interaction_type: str # "additive", "multiplicative", "min", "max" params: dict[str, float] = field(default_factory=dict) description: str = "" def evaluate(self, x1: float, x2: float) -> float: if self.interaction_type == "additive": a = self.params.get("a", 1.0) b = self.params.get("b", 1.0) c = self.params.get("c", 0.0) return a * x1 + b * x2 + c elif self.interaction_type == "multiplicative": a = self.params.get("a", 0.5) return a * x1 * x2 elif self.interaction_type == "min": return min(x1, x2) elif self.interaction_type == "max": return max(x1, x2) return 0.0 @dataclass class CausalWorld: """ The hidden world the agent must discover. Contains variables, single-parent rules, multi-parent interaction rules, and optional hidden confounders. """ domain: str variables: list[str] units: dict[str, str] rules: list[CausalRule] default_values: dict[str, float] rng: np.random.Generator interactions: list[InteractionRule] = field(default_factory=list) confounder_sigma: float = 0.0 def _compute_value( self, target: str, interventions: Optional[dict[str, float]] = None ) -> float: """Compute the true (noiseless) value of target given interventions.""" vals = dict(self.default_values) if interventions: vals.update(interventions) for rule in self.rules: if rule.effect in (interventions or {}): continue if rule.cause in vals: result = rule.evaluate(vals[rule.cause]) if not math.isnan(result): vals[rule.effect] = result for inter in self.interactions: if inter.effect in (interventions or {}): continue if inter.cause1 in vals and inter.cause2 in vals: result = inter.evaluate(vals[inter.cause1], vals[inter.cause2]) vals[inter.effect] = result return vals.get(target, 0.0) def _confounder_noise(self) -> float: """Hidden confounder adds correlated noise the agent can't explain.""" if self.confounder_sigma <= 0: return 0.0 return float(self.rng.normal(0, self.confounder_sigma)) def query_intervention( self, cause: str, value: float, effect: str, sigma: float ) -> float: true_val = self._compute_value(effect, {cause: value}) return true_val + self.rng.normal(0, sigma) + self._confounder_noise() def query_correlation( self, cause: str, control_range: list[float], effect: str, sigma: float, ) -> list[tuple[float, float]]: lo = control_range[0] if len(control_range) > 0 else 1.0 hi = control_range[1] if len(control_range) > 1 else 10.0 n = int(control_range[2]) if len(control_range) > 2 else 5 xs = np.linspace(lo, hi, n) conf = self._confounder_noise() pairs = [] for x in xs: y = self._compute_value(effect, {cause: float(x)}) y_noisy = y + self.rng.normal(0, sigma) + conf pairs.append((round(float(x), 4), round(float(y_noisy), 4))) return pairs def query_counterfactual( self, cause: str, delta: float, effect: str, sigma: float ) -> dict[str, Any]: baseline_x = self.default_values.get(cause, 5.0) cf_x = baseline_x + delta baseline_y = self._compute_value(effect, {cause: baseline_x}) cf_y = self._compute_value(effect, {cause: cf_x}) conf = self._confounder_noise() baseline_y_noisy = baseline_y + self.rng.normal(0, sigma) + conf cf_y_noisy = cf_y + self.rng.normal(0, sigma) + conf direction = "increases" if cf_y > baseline_y else "decreases" if cf_y < baseline_y else "unchanged" return { "baseline_x": round(baseline_x, 4), "baseline_y_noisy": round(float(baseline_y_noisy), 4), "counterfactual_x": round(cf_x, 4), "counterfactual_y_noisy": round(float(cf_y_noisy), 4), "direction": direction, } def query_passive(self, target: str, sigma: float) -> float: true_val = self._compute_value(target) return true_val + self.rng.normal(0, sigma) + self._confounder_noise() def ground_truth_summary(self) -> str: lines = [f"Domain: {self.domain}"] effects_covered: set[str] = set() for rule in self.rules: lines.append(f" {rule.description}") effects_covered.add(rule.effect) for inter in self.interactions: lines.append(f" {inter.description}") effects_covered.add(inter.effect) for v in self.variables: if v not in effects_covered: lines.append(f" {v} = {self.default_values.get(v, '?')} (root)") if self.confounder_sigma > 0: lines.append(f" [hidden confounder with sigma={self.confounder_sigma:.2f}]") return "\n".join(lines) def _random_rule( cause: str, effect: str, rng: random.Random ) -> CausalRule: """Generate a random causal rule for one edge.""" rule_type = rng.choices( RULE_TYPES, weights=[0.30, 0.15, 0.10, 0.12, 0.08, 0.08, 0.10, 0.07], )[0] if rule_type == "linear": a = round(rng.uniform(0.5, 3.5) * rng.choice([-1, 1]), 2) b = round(rng.uniform(-5.0, 5.0), 2) sign = "+" if b >= 0 else "-" desc = f"{effect} = {a} * {cause} {sign} {abs(b)}" return CausalRule(cause, effect, "linear", {"a": a, "b": b}, desc) elif rule_type == "threshold": threshold = round(rng.uniform(3.0, 8.0), 2) high = round(rng.uniform(6.0, 12.0), 2) low = round(rng.uniform(0.5, 4.0), 2) desc = f"{effect} = {high} if {cause} > {threshold} else {low}" return CausalRule( cause, effect, "threshold", {"threshold": threshold, "high": high, "low": low}, desc, ) elif rule_type == "inverse": a = round(rng.uniform(5.0, 30.0), 2) desc = f"{effect} = {a} / {cause}" return CausalRule(cause, effect, "inverse", {"a": a}, desc) elif rule_type == "quadratic": a = round(rng.uniform(0.1, 1.0) * rng.choice([-1, 1]), 2) b = round(rng.uniform(-2.0, 2.0), 2) c = round(rng.uniform(-3.0, 3.0), 2) desc = f"{effect} = {a}*{cause}^2 + {b}*{cause} + {c}" return CausalRule(cause, effect, "quadratic", {"a": a, "b": b, "c": c}, desc) elif rule_type == "exponential": a = round(rng.uniform(0.5, 3.0), 2) k = round(rng.uniform(0.1, 0.5) * rng.choice([-1, 1]), 2) desc = f"{effect} = {a} * exp({k} * {cause})" return CausalRule(cause, effect, "exponential", {"a": a, "k": k}, desc) elif rule_type == "logarithmic": a = round(rng.uniform(1.0, 5.0) * rng.choice([-1, 1]), 2) b = round(rng.uniform(-3.0, 3.0), 2) sign = "+" if b >= 0 else "-" desc = f"{effect} = {a} * ln({cause}) {sign} {abs(b)}" return CausalRule(cause, effect, "logarithmic", {"a": a, "b": b}, desc) elif rule_type == "saturating": v_max = round(rng.uniform(5.0, 15.0), 2) k_m = round(rng.uniform(1.0, 6.0), 2) desc = f"{effect} = {v_max} * {cause} / ({k_m} + {cause})" return CausalRule(cause, effect, "saturating", {"v_max": v_max, "k_m": k_m}, desc) else: # piecewise_linear knot = round(rng.uniform(3.0, 7.0), 2) a1 = round(rng.uniform(0.5, 3.0) * rng.choice([-1, 1]), 2) a2 = round(rng.uniform(0.5, 3.0) * rng.choice([-1, 1]), 2) b = round(rng.uniform(-3.0, 3.0), 2) desc = ( f"{effect} = {a1}*{cause} + {b} (if {cause} <= {knot}), " f"then slope changes to {a2}" ) return CausalRule( cause, effect, "piecewise_linear", {"knot": knot, "a1": a1, "a2": a2, "b": b}, desc, ) def _random_interaction( cause1: str, cause2: str, effect: str, rng: random.Random ) -> InteractionRule: """Generate a random interaction rule where effect depends on two parents.""" itype = rng.choices( ["additive", "multiplicative", "min", "max"], weights=[0.35, 0.35, 0.15, 0.15], )[0] if itype == "additive": a = round(rng.uniform(0.5, 2.0) * rng.choice([-1, 1]), 2) b = round(rng.uniform(0.5, 2.0) * rng.choice([-1, 1]), 2) c = round(rng.uniform(-2.0, 2.0), 2) desc = f"{effect} = {a}*{cause1} + {b}*{cause2} + {c}" return InteractionRule(cause1, cause2, effect, itype, {"a": a, "b": b, "c": c}, desc) elif itype == "multiplicative": a = round(rng.uniform(0.1, 0.8), 2) desc = f"{effect} = {a} * {cause1} * {cause2}" return InteractionRule(cause1, cause2, effect, itype, {"a": a}, desc) elif itype == "min": desc = f"{effect} = min({cause1}, {cause2})" return InteractionRule(cause1, cause2, effect, itype, {}, desc) else: desc = f"{effect} = max({cause1}, {cause2})" return InteractionRule(cause1, cause2, effect, itype, {}, desc) def generate_world( n_variables: int = 3, domain: Optional[str] = None, seed: Optional[int] = None, ) -> CausalWorld: """ Generate a fresh hidden causal world. The world may contain: - Single-parent rules (8 types: linear, threshold, inverse, quadratic, exponential, logarithmic, saturating, piecewise_linear) - Multi-parent interaction rules (additive, multiplicative, min, max) - Hidden confounders that add unexplainable correlated noise Args: n_variables: How many variables (2-5). domain: One of "system_alpha", "system_beta", "system_gamma", "system_delta", or None for random. seed: Random seed for reproducibility. Returns: A CausalWorld instance ready for agent probing. """ py_rng = random.Random(seed) np_rng = np.random.default_rng(seed) if domain is None or domain not in DOMAIN_LABELS: domain = py_rng.choice(DOMAINS) labels = DOMAIN_LABELS[domain] var_pool = py_rng.choice(ABSTRACT_VAR_POOLS) n = min(n_variables, len(var_pool)) chosen_vars = py_rng.sample(var_pool, n) unit_label = labels.get("unit", "units") units = {v: unit_label for v in chosen_vars} rules: list[CausalRule] = [] for i in range(len(chosen_vars) - 1): parent = chosen_vars[i] child = chosen_vars[i + 1] rules.append(_random_rule(parent, child, py_rng)) for i, parent in enumerate(chosen_vars): for j, child in enumerate(chosen_vars): if j <= i + 1: continue if py_rng.random() < 0.30: rules.append(_random_rule(parent, child, py_rng)) interactions: list[InteractionRule] = [] if n >= 3 and py_rng.random() < 0.40: roots_and_mid = [v for v in chosen_vars[:n - 1]] if len(roots_and_mid) >= 2: c1, c2 = py_rng.sample(roots_and_mid, 2) target = chosen_vars[-1] interaction = _random_interaction(c1, c2, target, py_rng) interactions.append(interaction) rules = [r for r in rules if r.effect != target] confounder_sigma = 0.0 if n >= 3 and py_rng.random() < 0.30: confounder_sigma = round(py_rng.uniform(0.05, 0.25), 3) all_effects = {r.effect for r in rules} | {i.effect for i in interactions} default_values: dict[str, float] = {} for v in chosen_vars: is_root = v not in all_effects default_values[v] = round(py_rng.uniform(2.0, 10.0), 2) if is_root else 0.0 display_order = list(chosen_vars) py_rng.shuffle(display_order) world = CausalWorld( domain=domain, variables=display_order, units=units, rules=rules, default_values=default_values, rng=np_rng, interactions=interactions, confounder_sigma=confounder_sigma, ) for v in chosen_vars: if default_values[v] == 0.0: computed = world._compute_value(v) if not math.isnan(computed): default_values[v] = round(computed, 4) else: default_values[v] = round(py_rng.uniform(2.0, 10.0), 2) return world