Spaces:
Sleeping
Sleeping
"""
server/causal_world.py -- The hidden causal world.

Every episode a new CausalWorld is generated. The agent never sees this
object -- it can only probe it via experiments.

Design:
  - Variables are nodes; directed edges carry one of 8 single-parent rule
    types (see RULE_TYPES).
  - Multi-parent interaction rules make some effects depend jointly on two
    causes.
  - Hidden confounders can inject extra noise that is shared by all readings
    taken within a single query.
  - Gaussian noise is added to every observation.
  - The graph is a DAG (no cycles) so causality is well-defined.
  - Domains: system_alpha | system_beta | system_gamma | system_delta
    Each domain provides a different narrative prompt but uses abstract
    variable names (Greek letters, V1/V2/V3...) to prevent LLM agents
    from leveraging pretrained real-world knowledge instead of reasoning
    from experimental evidence.
"""
| from __future__ import annotations | |
| import math | |
| import random | |
| from dataclasses import dataclass, field | |
| from typing import Any, Optional | |
| import numpy as np | |
| import networkx as nx | |
# Pools of deliberately meaningless variable names. generate_world draws one
# pool per world, so the names carry no real-world meaning an LLM agent could
# lean on; relationships have to be inferred from experiments.
ABSTRACT_VAR_POOLS: list[list[str]] = [
    "Alpha Beta Gamma Delta Epsilon".split(),
    "Zeta Eta Theta Iota Kappa".split(),
    [f"V{index}" for index in range(1, 6)],
    "Rho Sigma Tau Upsilon Phi".split(),
    "Mu Nu Xi Omicron Pi".split(),
    [f"Quant_{letter}" for letter in "ABCDE"],
]
# Narrative framing shown for each domain. Only the wording differs between
# domains; every variable is reported in the same neutral unit.
_DOMAIN_CONTEXTS: dict[str, str] = {
    "system_alpha": (
        "You are studying an unknown dynamical system. Variables have hidden "
        "causal relationships you must discover through experiments."
    ),
    "system_beta": (
        "You are investigating a black-box system with interacting quantities. "
        "Design experiments to uncover the governing equations."
    ),
    "system_gamma": (
        "You are analysing an opaque process with measurable outputs. Run "
        "controlled experiments to determine how variables influence each other."
    ),
    "system_delta": (
        "You are probing a simulated environment with coupled variables. "
        "The underlying rules are unknown -- discover them."
    ),
}
DOMAIN_LABELS: dict[str, dict] = {
    name: {"context": context, "unit": "units"}
    for name, context in _DOMAIN_CONTEXTS.items()
}
# Domain names in declaration order.
DOMAINS = [*DOMAIN_LABELS]
# Single-parent mechanism families an edge can carry. Order matters:
# _random_rule pairs this list positionally with its sampling weights.
RULE_TYPES = (
    "linear threshold inverse quadratic exponential "
    "logarithmic saturating piecewise_linear"
).split()
@dataclass
class CausalRule:
    """A single edge rule in the causal graph: ``effect = f(cause)``.

    ``rule_type`` selects the functional form and ``params`` supplies its
    coefficients; any coefficient missing from ``params`` falls back to the
    default hard-coded in :meth:`evaluate`. ``description`` holds a
    human-readable form of the equation.
    """

    cause: str
    effect: str
    rule_type: str
    params: dict[str, float] = field(default_factory=dict)
    description: str = ""

    def evaluate(self, x: float) -> float:
        """Return the noiseless effect value for cause value ``x``.

        Returns NaN where the rule is undefined for ``x``: ``inverse`` when
        ``|x| < 1e-9``, ``logarithmic`` when ``x <= 0``, and ``saturating``
        when its denominator ``k_m + x`` vanishes. An unrecognised
        ``rule_type`` yields 0.0.
        """
        if self.rule_type == "linear":
            a = self.params.get("a", 1.0)
            b = self.params.get("b", 0.0)
            return a * x + b
        elif self.rule_type == "threshold":
            threshold = self.params.get("threshold", 5.0)
            high = self.params.get("high", 10.0)
            low = self.params.get("low", 2.0)
            return high if x > threshold else low
        elif self.rule_type == "inverse":
            a = self.params.get("a", 10.0)
            if abs(x) < 1e-9:
                return float("nan")
            return a / x
        elif self.rule_type == "quadratic":
            a = self.params.get("a", 0.5)
            b = self.params.get("b", 0.0)
            c = self.params.get("c", 0.0)
            return a * x * x + b * x + c
        elif self.rule_type == "exponential":
            a = self.params.get("a", 1.0)
            k = self.params.get("k", 0.3)
            # Clamp the exponent so math.exp cannot overflow.
            x_clamped = max(-20.0, min(20.0, k * x))
            return a * math.exp(x_clamped)
        elif self.rule_type == "logarithmic":
            a = self.params.get("a", 3.0)
            b = self.params.get("b", 0.0)
            if x <= 0:
                return float("nan")
            return a * math.log(x) + b
        elif self.rule_type == "saturating":
            v_max = self.params.get("v_max", 10.0)
            k_m = self.params.get("k_m", 3.0)
            if x < 0:
                return 0.0
            denom = k_m + x
            if abs(denom) < 1e-9:
                # Only reachable with k_m <= 0; report "undefined" like the
                # inverse rule instead of raising ZeroDivisionError.
                return float("nan")
            return v_max * x / denom
        elif self.rule_type == "piecewise_linear":
            knot = self.params.get("knot", 5.0)
            a1 = self.params.get("a1", 1.0)
            a2 = self.params.get("a2", -0.5)
            b = self.params.get("b", 0.0)
            if x <= knot:
                return a1 * x + b
            # Continuous at the knot: the second segment starts where the
            # first one ends and continues with slope a2.
            y_knot = a1 * knot + b
            return y_knot + a2 * (x - knot)
        return 0.0
@dataclass
class InteractionRule:
    """
    A multi-parent rule: effect = f(cause1, cause2).

    These cannot be discovered by varying one variable at a time --
    the agent must realise two parents jointly determine the effect.
    """

    cause1: str
    cause2: str
    effect: str
    interaction_type: str  # "additive", "multiplicative", "min", "max"
    params: dict[str, float] = field(default_factory=dict)
    description: str = ""

    def evaluate(self, x1: float, x2: float) -> float:
        """Return the noiseless effect for parent values ``x1`` and ``x2``.

        Missing coefficients fall back to the defaults below. An
        unrecognised ``interaction_type`` yields 0.0.
        """
        if self.interaction_type == "additive":
            a = self.params.get("a", 1.0)
            b = self.params.get("b", 1.0)
            c = self.params.get("c", 0.0)
            return a * x1 + b * x2 + c
        elif self.interaction_type == "multiplicative":
            a = self.params.get("a", 0.5)
            return a * x1 * x2
        elif self.interaction_type == "min":
            return min(x1, x2)
        elif self.interaction_type == "max":
            return max(x1, x2)
        return 0.0
@dataclass
class CausalWorld:
    """
    The hidden world the agent must discover.

    Contains variables, single-parent rules, multi-parent interaction rules,
    and optional hidden confounders. Observations are produced by the
    ``query_*`` methods, each of which returns noisy readings.
    """

    domain: str
    variables: list[str]
    units: dict[str, str]
    rules: list[CausalRule]
    default_values: dict[str, float]
    rng: np.random.Generator
    interactions: list[InteractionRule] = field(default_factory=list)
    confounder_sigma: float = 0.0

    @staticmethod
    def _evaluation_order(parents: dict[str, list[str]]) -> list[str]:
        """Order the keys of ``parents`` so every cause precedes its effect.

        ``parents`` maps each node that has incoming mechanisms to its causes;
        nodes without an entry are fixed inputs and are not returned. Worlds
        are meant to be DAGs; should a cycle ever appear it is broken
        arbitrarily instead of recursing forever.
        """
        order: list[str] = []
        done: set[str] = set()
        active: set[str] = set()

        def visit(node: str) -> None:
            if node in done or node in active:
                return
            active.add(node)
            for cause in parents.get(node, ()):
                visit(cause)
            active.discard(node)
            done.add(node)
            if node in parents:
                order.append(node)

        for node in parents:
            visit(node)
        return order

    def _compute_value(
        self, target: str, interventions: Optional[dict[str, float]] = None
    ) -> float:
        """Compute the true (noiseless) value of target given interventions.

        Starts from ``default_values``, pins every intervened variable to its
        given value (its incoming mechanisms are cut), then evaluates the
        remaining mechanisms in causal (topological) order so every effect is
        computed from its parents' final values.

        Per effect, precedence follows list order: single-parent rules are
        applied in ``self.rules`` order and the last non-NaN result wins (a
        NaN result keeps the previous value); interaction rules are applied
        afterwards and override them. NOTE(review): when several single-parent
        rules target the same effect, only one of them determines its value.

        Returns 0.0 if ``target`` never receives a value.
        """
        fixed = interventions or {}
        vals = dict(self.default_values)
        vals.update(fixed)

        # Group the mechanisms feeding each non-intervened effect.
        rules_into: dict[str, list[CausalRule]] = {}
        for rule in self.rules:
            if rule.effect not in fixed:
                rules_into.setdefault(rule.effect, []).append(rule)
        inters_into: dict[str, list[InteractionRule]] = {}
        for inter in self.interactions:
            if inter.effect not in fixed:
                inters_into.setdefault(inter.effect, []).append(inter)

        parents: dict[str, list[str]] = {}
        for effect, effect_rules in rules_into.items():
            parents.setdefault(effect, []).extend(r.cause for r in effect_rules)
        for effect, effect_inters in inters_into.items():
            for inter in effect_inters:
                parents.setdefault(effect, []).extend((inter.cause1, inter.cause2))

        # Evaluating in causal order (rather than list order) prevents an
        # effect from being computed from a parent value that a later rule
        # overwrites.
        for node in self._evaluation_order(parents):
            for rule in rules_into.get(node, ()):
                if rule.cause in vals:
                    result = rule.evaluate(vals[rule.cause])
                    if not math.isnan(result):
                        vals[node] = result
            for inter in inters_into.get(node, ()):
                if inter.cause1 in vals and inter.cause2 in vals:
                    vals[node] = inter.evaluate(vals[inter.cause1], vals[inter.cause2])
        return vals.get(target, 0.0)

    def _confounder_noise(self) -> float:
        """Hidden confounder adds correlated noise the agent can't explain."""
        if self.confounder_sigma <= 0:
            return 0.0
        return float(self.rng.normal(0, self.confounder_sigma))

    def query_intervention(
        self, cause: str, value: float, effect: str, sigma: float
    ) -> float:
        """Noisy reading of ``effect`` with ``cause`` forced to ``value``.

        Adds Gaussian noise with standard deviation ``sigma`` plus a fresh
        confounder draw.
        """
        true_val = self._compute_value(effect, {cause: value})
        return true_val + self.rng.normal(0, sigma) + self._confounder_noise()

    def query_correlation(
        self,
        cause: str,
        control_range: list[float],
        effect: str,
        sigma: float,
    ) -> list[tuple[float, float]]:
        """Sweep ``cause`` over an evenly spaced grid and read ``effect``.

        ``control_range`` is ``[lo, hi, n]``; missing entries default to
        lo=1.0, hi=10.0, n=5 (n is truncated to int). One confounder draw is
        shared by every point of the sweep, while independent Gaussian noise
        is added per point. Returns ``(x, y_noisy)`` pairs rounded to 4 dp.
        """
        lo = control_range[0] if len(control_range) > 0 else 1.0
        hi = control_range[1] if len(control_range) > 1 else 10.0
        n = int(control_range[2]) if len(control_range) > 2 else 5
        xs = np.linspace(lo, hi, n)
        conf = self._confounder_noise()
        pairs = []
        for x in xs:
            y = self._compute_value(effect, {cause: float(x)})
            y_noisy = y + self.rng.normal(0, sigma) + conf
            pairs.append((round(float(x), 4), round(float(y_noisy), 4)))
        return pairs

    def query_counterfactual(
        self, cause: str, delta: float, effect: str, sigma: float
    ) -> dict[str, Any]:
        """Compare ``effect`` at the cause's default value and at default + delta.

        The baseline x falls back to 5.0 when ``cause`` has no default. Both
        noisy readings share one confounder draw. ``direction`` is derived
        from the noiseless values, not the noisy readings.
        """
        baseline_x = self.default_values.get(cause, 5.0)
        cf_x = baseline_x + delta
        baseline_y = self._compute_value(effect, {cause: baseline_x})
        cf_y = self._compute_value(effect, {cause: cf_x})
        conf = self._confounder_noise()
        baseline_y_noisy = baseline_y + self.rng.normal(0, sigma) + conf
        cf_y_noisy = cf_y + self.rng.normal(0, sigma) + conf
        direction = "increases" if cf_y > baseline_y else "decreases" if cf_y < baseline_y else "unchanged"
        return {
            "baseline_x": round(baseline_x, 4),
            "baseline_y_noisy": round(float(baseline_y_noisy), 4),
            "counterfactual_x": round(cf_x, 4),
            "counterfactual_y_noisy": round(float(cf_y_noisy), 4),
            "direction": direction,
        }

    def query_passive(self, target: str, sigma: float) -> float:
        """Noisy reading of ``target`` with no intervention."""
        true_val = self._compute_value(target)
        return true_val + self.rng.normal(0, sigma) + self._confounder_noise()

    def ground_truth_summary(self) -> str:
        """Render the hidden mechanisms as text.

        One line per rule and interaction (their descriptions), one line per
        variable that no mechanism targets (with its default value), and a
        note when a confounder is active.
        """
        lines = [f"Domain: {self.domain}"]
        effects_covered: set[str] = set()
        for rule in self.rules:
            lines.append(f" {rule.description}")
            effects_covered.add(rule.effect)
        for inter in self.interactions:
            lines.append(f" {inter.description}")
            effects_covered.add(inter.effect)
        for v in self.variables:
            if v not in effects_covered:
                lines.append(f" {v} = {self.default_values.get(v, '?')} (root)")
        if self.confounder_sigma > 0:
            lines.append(f" [hidden confounder with sigma={self.confounder_sigma:.2f}]")
        return "\n".join(lines)
def _random_rule(
    cause: str, effect: str, rng: random.Random
) -> CausalRule:
    """Draw a random single-parent rule for the edge ``cause -> effect``.

    The family is sampled from RULE_TYPES with fixed weights (linear is the
    most likely), then its coefficients are drawn from family-specific
    ranges and rounded to two decimals. The returned rule carries a
    human-readable equation as its description.

    All randomness comes from ``rng``. The draw order is part of the
    seed -> world mapping and must not be rearranged.
    """

    def draw(lo: float, hi: float) -> float:
        # Uniform value in [lo, hi], rounded to 2 decimals.
        return round(rng.uniform(lo, hi), 2)

    def signed_draw(lo: float, hi: float) -> float:
        # Uniform magnitude in [lo, hi] times a random sign (magnitude drawn
        # first), rounded to 2 decimals.
        return round(rng.uniform(lo, hi) * rng.choice([-1, 1]), 2)

    family = rng.choices(
        RULE_TYPES,
        weights=[0.30, 0.15, 0.10, 0.12, 0.08, 0.08, 0.10, 0.07],
    )[0]

    if family == "linear":
        a = signed_draw(0.5, 3.5)
        b = draw(-5.0, 5.0)
        params = {"a": a, "b": b}
        op = "-" if b < 0 else "+"
        desc = f"{effect} = {a} * {cause} {op} {abs(b)}"
    elif family == "threshold":
        threshold = draw(3.0, 8.0)
        high = draw(6.0, 12.0)
        low = draw(0.5, 4.0)
        params = {"threshold": threshold, "high": high, "low": low}
        desc = f"{effect} = {high} if {cause} > {threshold} else {low}"
    elif family == "inverse":
        a = draw(5.0, 30.0)
        params = {"a": a}
        desc = f"{effect} = {a} / {cause}"
    elif family == "quadratic":
        a = signed_draw(0.1, 1.0)
        b = draw(-2.0, 2.0)
        c = draw(-3.0, 3.0)
        params = {"a": a, "b": b, "c": c}
        desc = f"{effect} = {a}*{cause}^2 + {b}*{cause} + {c}"
    elif family == "exponential":
        a = draw(0.5, 3.0)
        k = signed_draw(0.1, 0.5)
        params = {"a": a, "k": k}
        desc = f"{effect} = {a} * exp({k} * {cause})"
    elif family == "logarithmic":
        a = signed_draw(1.0, 5.0)
        b = draw(-3.0, 3.0)
        params = {"a": a, "b": b}
        op = "-" if b < 0 else "+"
        desc = f"{effect} = {a} * ln({cause}) {op} {abs(b)}"
    elif family == "saturating":
        v_max = draw(5.0, 15.0)
        k_m = draw(1.0, 6.0)
        params = {"v_max": v_max, "k_m": k_m}
        desc = f"{effect} = {v_max} * {cause} / ({k_m} + {cause})"
    else:
        # piecewise_linear, which also serves as the catch-all family.
        family = "piecewise_linear"
        knot = draw(3.0, 7.0)
        a1 = signed_draw(0.5, 3.0)
        a2 = signed_draw(0.5, 3.0)
        b = draw(-3.0, 3.0)
        params = {"knot": knot, "a1": a1, "a2": a2, "b": b}
        desc = (
            f"{effect} = {a1}*{cause} + {b} (if {cause} <= {knot}), "
            f"then slope changes to {a2}"
        )

    return CausalRule(
        cause=cause,
        effect=effect,
        rule_type=family,
        params=params,
        description=desc,
    )
def _random_interaction(
    cause1: str, cause2: str, effect: str, rng: random.Random
) -> InteractionRule:
    """Draw a random two-parent rule ``effect = f(cause1, cause2)``.

    Additive and multiplicative forms each have weight 0.35; min and max
    each have 0.15. Coefficients are rounded to two decimals, and min/max
    carry no coefficients at all.
    """
    kind = rng.choices(
        ["additive", "multiplicative", "min", "max"],
        weights=[0.35, 0.35, 0.15, 0.15],
    )[0]

    params: dict[str, float] = {}
    if kind == "additive":
        a = round(rng.uniform(0.5, 2.0) * rng.choice([-1, 1]), 2)
        b = round(rng.uniform(0.5, 2.0) * rng.choice([-1, 1]), 2)
        c = round(rng.uniform(-2.0, 2.0), 2)
        params = {"a": a, "b": b, "c": c}
        desc = f"{effect} = {a}*{cause1} + {b}*{cause2} + {c}"
    elif kind == "multiplicative":
        a = round(rng.uniform(0.1, 0.8), 2)
        params = {"a": a}
        desc = f"{effect} = {a} * {cause1} * {cause2}"
    else:
        # "min" or "max": the kind name doubles as the function name.
        desc = f"{effect} = {kind}({cause1}, {cause2})"

    return InteractionRule(
        cause1=cause1,
        cause2=cause2,
        effect=effect,
        interaction_type=kind,
        params=params,
        description=desc,
    )
def generate_world(
    n_variables: int = 3,
    domain: Optional[str] = None,
    seed: Optional[int] = None,
) -> CausalWorld:
    """
    Generate a fresh hidden causal world.

    The world may contain:
      - Single-parent rules (8 types: linear, threshold, inverse, quadratic,
        exponential, logarithmic, saturating, piecewise_linear)
      - Multi-parent interaction rules (additive, multiplicative, min, max)
      - Hidden confounders that add unexplainable correlated noise

    Args:
        n_variables: How many variables (2-5). Larger values are capped at
            the size of the sampled name pool.
        domain: One of "system_alpha", "system_beta", "system_gamma",
            "system_delta", or None for random. An unrecognised name is also
            replaced by a random domain.
        seed: Random seed for reproducibility; seeds both the Python and the
            NumPy generator.

    Returns:
        A CausalWorld instance ready for agent probing.
    """
    py_rng = random.Random(seed)
    np_rng = np.random.default_rng(seed)

    # NOTE: the sequence of py_rng draws below defines which world a given
    # seed produces; keep their order stable.
    if domain not in DOMAIN_LABELS:  # also true for None
        domain = py_rng.choice(DOMAINS)

    var_pool = py_rng.choice(ABSTRACT_VAR_POOLS)
    n = min(n_variables, len(var_pool))
    chosen_vars = py_rng.sample(var_pool, n)
    units = dict.fromkeys(chosen_vars, DOMAIN_LABELS[domain].get("unit", "units"))

    # Backbone chain: each sampled variable drives the next one.
    rules: list[CausalRule] = [
        _random_rule(parent, child, py_rng)
        for parent, child in zip(chosen_vars, chosen_vars[1:])
    ]
    # Optional skip edges i -> j (j >= i + 2), each kept with probability 0.30.
    for i in range(n):
        for j in range(i + 2, n):
            if py_rng.random() < 0.30:
                rules.append(_random_rule(chosen_vars[i], chosen_vars[j], py_rng))

    interactions: list[InteractionRule] = []
    if n >= 3 and py_rng.random() < 0.40:
        # The last variable becomes a joint function of two earlier ones;
        # every single-parent rule into it is dropped.
        c1, c2 = py_rng.sample(chosen_vars[:-1], 2)
        target = chosen_vars[-1]
        interactions.append(_random_interaction(c1, c2, target, py_rng))
        rules = [rule for rule in rules if rule.effect != target]

    confounder_sigma = 0.0
    if n >= 3 and py_rng.random() < 0.30:
        confounder_sigma = round(py_rng.uniform(0.05, 0.25), 3)

    # Roots get a random default; effects start at 0.0 and are filled below.
    all_effects = {r.effect for r in rules} | {i.effect for i in interactions}
    default_values: dict[str, float] = {
        v: 0.0 if v in all_effects else round(py_rng.uniform(2.0, 10.0), 2)
        for v in chosen_vars
    }

    display_order = chosen_vars[:]
    py_rng.shuffle(display_order)

    world = CausalWorld(
        domain=domain,
        variables=display_order,
        units=units,
        rules=rules,
        default_values=default_values,
        rng=np_rng,
        interactions=interactions,
        confounder_sigma=confounder_sigma,
    )

    # The world holds this same dict, so these writes update it in place:
    # each effect's default becomes the value its mechanisms produce.
    for v in chosen_vars:
        if v not in all_effects:
            continue
        computed = world._compute_value(v)
        if math.isnan(computed):
            default_values[v] = round(py_rng.uniform(2.0, 10.0), 2)
        else:
            default_values[v] = round(computed, 4)
    return world