Labexperiment / server /causal_world.py
Sbhimraj's picture
Add application file
aab0192
Raw
History Blame Contribute Delete
17.3 kB
"""
server/causal_world.py -- The hidden causal world.
Every episode a new CausalGraph is generated. The agent never sees this
object -- it can only probe it via experiments.
Design:
- Variables are nodes; directed edges carry one of 8 rule types.
- Multi-parent interaction rules make some effects depend on >1 cause.
- Hidden confounders can inject noise that is shared by all observations
  returned from a single query.
- Gaussian noise is added to every observation.
- The graph is a DAG (no cycles) so causality is well-defined.
- Domains: system_alpha | system_beta | system_gamma | system_delta
Each domain provides a different narrative prompt but uses abstract
variable names (Greek letters, V1/V2/V3...) to prevent LLM agents
from leveraging pretrained real-world knowledge instead of reasoning
from experimental evidence.
"""
from __future__ import annotations
import math
import random
from dataclasses import dataclass, field
from typing import Any, Optional
import numpy as np
import networkx as nx
# Pools of abstract variable names. generate_world picks one pool per episode
# and samples that episode's variables from it. The names are deliberately
# meaningless, so an LLM agent cannot lean on real-world priors (see the
# module docstring).
ABSTRACT_VAR_POOLS: list[list[str]] = [
    ["Alpha", "Beta", "Gamma", "Delta", "Epsilon"],
    ["Zeta", "Eta", "Theta", "Iota", "Kappa"],
    ["V1", "V2", "V3", "V4", "V5"],
    ["Rho", "Sigma", "Tau", "Upsilon", "Phi"],
    ["Mu", "Nu", "Xi", "Omicron", "Pi"],
    ["Quant_A", "Quant_B", "Quant_C", "Quant_D", "Quant_E"],
]
# Per-domain narrative metadata.
#   "context": narrative prompt text for the domain (presumably shown to the
#              agent by the caller -- not used in this module).
#   "unit":    unit label that generate_world assigns to every variable.
DOMAIN_LABELS: dict[str, dict[str, str]] = {
    "system_alpha": {
        "context": "You are studying an unknown dynamical system. Variables have hidden causal relationships you must discover through experiments.",
        "unit": "units",
    },
    "system_beta": {
        "context": "You are investigating a black-box system with interacting quantities. Design experiments to uncover the governing equations.",
        "unit": "units",
    },
    "system_gamma": {
        "context": "You are analysing an opaque process with measurable outputs. Run controlled experiments to determine how variables influence each other.",
        "unit": "units",
    },
    "system_delta": {
        "context": "You are probing a simulated environment with coupled variables. The underlying rules are unknown -- discover them.",
        "unit": "units",
    },
}
# Valid domain keys; generate_world draws from this list when no valid
# domain is requested.
DOMAINS = list(DOMAIN_LABELS.keys())
# Single-parent rule families understood by CausalRule.evaluate.
# NOTE: _random_rule samples this list with a positional `weights=` list, so
# reordering or extending it requires updating those weights as well.
RULE_TYPES = [
    "linear",
    "threshold",
    "inverse",
    "quadratic",
    "exponential",
    "logarithmic",
    "saturating",
    "piecewise_linear",
]
@dataclass
class CausalRule:
    """One directed edge ``cause -> effect`` of the hidden causal graph.

    ``rule_type`` names the functional family (one of RULE_TYPES) and
    ``params`` holds its coefficients; any coefficient that is missing falls
    back to a per-family default inside ``evaluate``. ``description`` is the
    human-readable formula used in ground-truth summaries.
    """

    cause: str
    effect: str
    rule_type: str
    params: dict[str, float] = field(default_factory=dict)
    description: str = ""

    def evaluate(self, x: float) -> float:
        """Map a cause value *x* to the (noiseless) effect value.

        Returns ``nan`` where the family is undefined: ``inverse`` when
        ``|x| < 1e-9`` and ``logarithmic`` when ``x <= 0``. An unrecognised
        ``rule_type`` yields ``0.0``.
        """
        kind = self.rule_type
        p = self.params.get

        if kind == "linear":
            return p("a", 1.0) * x + p("b", 0.0)

        if kind == "threshold":
            cutoff = p("threshold", 5.0)
            high, low = p("high", 10.0), p("low", 2.0)
            return high if x > cutoff else low

        if kind == "inverse":
            numerator = p("a", 10.0)
            return float("nan") if abs(x) < 1e-9 else numerator / x

        if kind == "quadratic":
            a, b, c = p("a", 0.5), p("b", 0.0), p("c", 0.0)
            return a * x * x + b * x + c

        if kind == "exponential":
            scale, rate = p("a", 1.0), p("k", 0.3)
            # Clamp the exponent to [-20, 20] so exp() cannot overflow.
            exponent = max(-20.0, min(20.0, rate * x))
            return scale * math.exp(exponent)

        if kind == "logarithmic":
            a, b = p("a", 3.0), p("b", 0.0)
            return float("nan") if x <= 0 else a * math.log(x) + b

        if kind == "saturating":
            v_max, k_m = p("v_max", 10.0), p("k_m", 3.0)
            # Michaelis-Menten-style curve; negative inputs are floored to 0.
            return 0.0 if x < 0 else v_max * x / (k_m + x)

        if kind == "piecewise_linear":
            knot, b = p("knot", 5.0), p("b", 0.0)
            a1, a2 = p("a1", 1.0), p("a2", -0.5)
            if x <= knot:
                return a1 * x + b
            # Continue from the value at the knot with the second slope.
            return (a1 * knot + b) + a2 * (x - knot)

        return 0.0
@dataclass
class InteractionRule:
    """Two-parent rule ``effect = f(cause1, cause2)``.

    Such effects cannot be recovered by varying one variable at a time --
    the agent has to realise that both parents jointly set the effect.
    ``interaction_type`` selects ``f``; ``params`` supplies its
    coefficients (defaults apply when a key is missing).
    """

    cause1: str
    cause2: str
    effect: str
    interaction_type: str  # "additive", "multiplicative", "min", "max"
    params: dict[str, float] = field(default_factory=dict)
    description: str = ""

    def evaluate(self, x1: float, x2: float) -> float:
        """Combine the two parent values; unknown types yield ``0.0``."""
        kind = self.interaction_type
        if kind == "min":
            return min(x1, x2)
        if kind == "max":
            return max(x1, x2)

        coeff = self.params.get
        if kind == "multiplicative":
            return coeff("a", 0.5) * x1 * x2
        if kind == "additive":
            return coeff("a", 1.0) * x1 + coeff("b", 1.0) * x2 + coeff("c", 0.0)
        return 0.0
@dataclass
class CausalWorld:
    """
    The hidden world the agent must discover.

    Contains variables, single-parent rules, multi-parent interaction rules,
    and an optional hidden confounder. The ``query_*`` methods return noisy
    observations; ``ground_truth_summary`` renders the hidden rules as text.

    Attributes:
        domain: Domain key (printed by ``ground_truth_summary``).
        variables: Variable names in display order (generate_world passes a
            shuffled copy).
        units: Unit label for each variable.
        rules: Single-parent edge rules. When several rules target the same
            effect they are applied in list order, so the last one whose
            result is not NaN determines the value.
        default_values: Baseline value of every variable. Used for roots that
            are not intervened on, and kept for an effect whose rules all
            return NaN.
        rng: NumPy generator supplying all observation and confounder noise.
        interactions: Two-parent rules. They are applied after the
            single-parent rules targeting the same effect, so a non-NaN
            interaction result overrides them.
        confounder_sigma: Standard deviation of the hidden confounder noise;
            ``0`` disables it.
    """

    domain: str
    variables: list[str]
    units: dict[str, str]
    rules: list[CausalRule]
    default_values: dict[str, float]
    rng: np.random.Generator
    interactions: list[InteractionRule] = field(default_factory=list)
    confounder_sigma: float = 0.0

    def _evaluation_order(self) -> list[str]:
        """Return every rule/interaction effect, parents before children.

        Roots (variables no rule targets) are omitted: their values come
        straight from ``default_values`` or from an intervention. If the
        edges contain a cycle, the nodes that cannot be ordered are appended
        in first-seen order so evaluation still proceeds best-effort.
        """
        parents: dict[str, set[str]] = {}
        for rule in self.rules:
            parents.setdefault(rule.effect, set()).add(rule.cause)
        for inter in self.interactions:
            parents.setdefault(inter.effect, set()).update((inter.cause1, inter.cause2))

        order: list[str] = []
        pending = list(parents)
        while pending:
            # A node is ready once none of its parents still awaits evaluation.
            ready = [v for v in pending if parents[v].isdisjoint(pending)]
            if not ready:
                # Cycle: no remaining node is free of pending parents.
                order.extend(pending)
                break
            order.extend(ready)
            ready_set = set(ready)
            pending = [v for v in pending if v not in ready_set]
        return order

    def _compute_value(
        self, target: str, interventions: Optional[dict[str, float]] = None
    ) -> float:
        """Compute the true (noiseless) value of *target* given interventions.

        Intervened variables are clamped to the given value (their incoming
        rules are ignored). Every other effect is recomputed in topological
        order. This matters because generate_world appends skip edges
        (e.g. A->C) after the chain edges (e.g. C->D): evaluating in plain
        list order would compute D from a value of C that is overwritten
        afterwards. A rule or interaction returning NaN leaves the effect's
        previous value in place.

        Returns ``0.0`` if *target* is not a known variable.
        """
        clamped = interventions or {}
        vals = dict(self.default_values)
        vals.update(clamped)
        for node in self._evaluation_order():
            if node in clamped:
                continue  # do(node = value): incoming edges are cut
            for rule in self.rules:
                if rule.effect != node or rule.cause not in vals:
                    continue
                result = rule.evaluate(vals[rule.cause])
                if not math.isnan(result):
                    vals[node] = result
            for inter in self.interactions:
                if inter.effect != node:
                    continue
                if inter.cause1 in vals and inter.cause2 in vals:
                    result = inter.evaluate(vals[inter.cause1], vals[inter.cause2])
                    if not math.isnan(result):
                        vals[node] = result
        return vals.get(target, 0.0)

    def _confounder_noise(self) -> float:
        """Hidden confounder adds correlated noise the agent can't explain.

        Returns one draw from N(0, confounder_sigma), or 0.0 when disabled.
        """
        if self.confounder_sigma <= 0:
            return 0.0
        return float(self.rng.normal(0, self.confounder_sigma))

    def query_intervention(
        self, cause: str, value: float, effect: str, sigma: float
    ) -> float:
        """Observe *effect* once under ``do(cause = value)``.

        The result is the true value plus Gaussian noise with std *sigma*
        plus one confounder draw.
        """
        true_val = self._compute_value(effect, {cause: value})
        return true_val + self.rng.normal(0, sigma) + self._confounder_noise()

    def query_correlation(
        self,
        cause: str,
        control_range: list[float],
        effect: str,
        sigma: float,
    ) -> list[tuple[float, float]]:
        """Sweep *cause* and observe *effect* at each point.

        ``control_range`` is ``[lo, hi, n]`` (defaults 1.0, 10.0, 5 for
        missing entries); the cause takes ``np.linspace(lo, hi, n)``. Each
        point gets independent Gaussian noise (std *sigma*), but a single
        confounder draw is shared by the whole sweep.

        Returns:
            ``(x, y_noisy)`` pairs, both rounded to 4 decimals.
        """
        lo = control_range[0] if len(control_range) > 0 else 1.0
        hi = control_range[1] if len(control_range) > 1 else 10.0
        n = int(control_range[2]) if len(control_range) > 2 else 5
        xs = np.linspace(lo, hi, n)
        conf = self._confounder_noise()
        pairs = []
        for x in xs:
            y = self._compute_value(effect, {cause: float(x)})
            y_noisy = y + self.rng.normal(0, sigma) + conf
            pairs.append((round(float(x), 4), round(float(y_noisy), 4)))
        return pairs

    def query_counterfactual(
        self, cause: str, delta: float, effect: str, sigma: float
    ) -> dict[str, Any]:
        """Compare *effect* at the cause's baseline vs. baseline + *delta*.

        The baseline is ``default_values[cause]`` (5.0 if missing). Both
        observations get independent Gaussian noise but share one confounder
        draw.

        NOTE(review): ``direction`` compares the noiseless values, so it
        reports the true sign of the change even when noise would mask it --
        confirm this is intended.
        """
        baseline_x = self.default_values.get(cause, 5.0)
        cf_x = baseline_x + delta
        baseline_y = self._compute_value(effect, {cause: baseline_x})
        cf_y = self._compute_value(effect, {cause: cf_x})
        conf = self._confounder_noise()
        baseline_y_noisy = baseline_y + self.rng.normal(0, sigma) + conf
        cf_y_noisy = cf_y + self.rng.normal(0, sigma) + conf
        direction = "increases" if cf_y > baseline_y else "decreases" if cf_y < baseline_y else "unchanged"
        return {
            "baseline_x": round(baseline_x, 4),
            "baseline_y_noisy": round(float(baseline_y_noisy), 4),
            "counterfactual_x": round(cf_x, 4),
            "counterfactual_y_noisy": round(float(cf_y_noisy), 4),
            "direction": direction,
        }

    def query_passive(self, target: str, sigma: float) -> float:
        """Observe *target* with no intervention (noise + confounder added)."""
        true_val = self._compute_value(target)
        return true_val + self.rng.normal(0, sigma) + self._confounder_noise()

    def ground_truth_summary(self) -> str:
        """Render the hidden structure as multi-line text.

        Includes one line per rule/interaction description, one line per
        root (a variable no rule targets) with its default value, and the
        confounder sigma when it is enabled.
        """
        lines = [f"Domain: {self.domain}"]
        effects_covered: set[str] = set()
        for rule in self.rules:
            lines.append(f" {rule.description}")
            effects_covered.add(rule.effect)
        for inter in self.interactions:
            lines.append(f" {inter.description}")
            effects_covered.add(inter.effect)
        for v in self.variables:
            if v not in effects_covered:
                lines.append(f" {v} = {self.default_values.get(v, '?')} (root)")
        if self.confounder_sigma > 0:
            lines.append(f" [hidden confounder with sigma={self.confounder_sigma:.2f}]")
        return "\n".join(lines)
def _random_rule(
    cause: str, effect: str, rng: random.Random
) -> CausalRule:
    """Draw a random single-parent rule for the edge ``cause -> effect``.

    The rule family is sampled from RULE_TYPES with fixed weights, then its
    coefficients are drawn and rounded to 2 decimals. The sequence of draws
    from *rng* determines what every seed produces, so the draws below keep
    a fixed order.
    """

    def draw(lo: float, hi: float) -> float:
        # Uniform coefficient in [lo, hi], rounded to 2 decimals.
        return round(rng.uniform(lo, hi), 2)

    def draw_signed(lo: float, hi: float) -> float:
        # Magnitude in [lo, hi] with a random sign (magnitude drawn first).
        magnitude = rng.uniform(lo, hi)
        return round(magnitude * rng.choice([-1, 1]), 2)

    (kind,) = rng.choices(
        RULE_TYPES,
        weights=[0.30, 0.15, 0.10, 0.12, 0.08, 0.08, 0.10, 0.07],
    )

    if kind == "linear":
        a = draw_signed(0.5, 3.5)
        b = draw(-5.0, 5.0)
        sign = "-" if b < 0 else "+"
        return CausalRule(
            cause, effect, "linear", {"a": a, "b": b},
            f"{effect} = {a} * {cause} {sign} {abs(b)}",
        )

    if kind == "threshold":
        threshold = draw(3.0, 8.0)
        high = draw(6.0, 12.0)
        low = draw(0.5, 4.0)
        return CausalRule(
            cause, effect, "threshold",
            {"threshold": threshold, "high": high, "low": low},
            f"{effect} = {high} if {cause} > {threshold} else {low}",
        )

    if kind == "inverse":
        a = draw(5.0, 30.0)
        return CausalRule(cause, effect, "inverse", {"a": a}, f"{effect} = {a} / {cause}")

    if kind == "quadratic":
        a = draw_signed(0.1, 1.0)
        b = draw(-2.0, 2.0)
        c = draw(-3.0, 3.0)
        return CausalRule(
            cause, effect, "quadratic", {"a": a, "b": b, "c": c},
            f"{effect} = {a}*{cause}^2 + {b}*{cause} + {c}",
        )

    if kind == "exponential":
        a = draw(0.5, 3.0)
        k = draw_signed(0.1, 0.5)
        return CausalRule(
            cause, effect, "exponential", {"a": a, "k": k},
            f"{effect} = {a} * exp({k} * {cause})",
        )

    if kind == "logarithmic":
        a = draw_signed(1.0, 5.0)
        b = draw(-3.0, 3.0)
        sign = "-" if b < 0 else "+"
        return CausalRule(
            cause, effect, "logarithmic", {"a": a, "b": b},
            f"{effect} = {a} * ln({cause}) {sign} {abs(b)}",
        )

    if kind == "saturating":
        v_max = draw(5.0, 15.0)
        k_m = draw(1.0, 6.0)
        return CausalRule(
            cause, effect, "saturating", {"v_max": v_max, "k_m": k_m},
            f"{effect} = {v_max} * {cause} / ({k_m} + {cause})",
        )

    # Remaining family: piecewise_linear.
    knot = draw(3.0, 7.0)
    a1 = draw_signed(0.5, 3.0)
    a2 = draw_signed(0.5, 3.0)
    b = draw(-3.0, 3.0)
    description = (
        f"{effect} = {a1}*{cause} + {b} (if {cause} <= {knot}), "
        f"then slope changes to {a2}"
    )
    return CausalRule(
        cause, effect, "piecewise_linear",
        {"knot": knot, "a1": a1, "a2": a2, "b": b}, description,
    )
def _random_interaction(
    cause1: str, cause2: str, effect: str, rng: random.Random
) -> InteractionRule:
    """Draw a random two-parent rule ``effect = f(cause1, cause2)``.

    The interaction type is drawn first, then its coefficients (rounded to
    2 decimals). The draws from *rng* keep a fixed order so that seeded
    worlds stay reproducible.
    """

    def signed_coeff() -> float:
        # Magnitude in [0.5, 2.0] with a random sign (magnitude drawn first).
        magnitude = rng.uniform(0.5, 2.0)
        return round(magnitude * rng.choice([-1, 1]), 2)

    (itype,) = rng.choices(
        ["additive", "multiplicative", "min", "max"],
        weights=[0.35, 0.35, 0.15, 0.15],
    )

    if itype == "additive":
        a = signed_coeff()
        b = signed_coeff()
        c = round(rng.uniform(-2.0, 2.0), 2)
        return InteractionRule(
            cause1, cause2, effect, itype, {"a": a, "b": b, "c": c},
            f"{effect} = {a}*{cause1} + {b}*{cause2} + {c}",
        )

    if itype == "multiplicative":
        a = round(rng.uniform(0.1, 0.8), 2)
        return InteractionRule(
            cause1, cause2, effect, itype, {"a": a},
            f"{effect} = {a} * {cause1} * {cause2}",
        )

    # Parameter-free min/max; any other type is described as max.
    func = "min" if itype == "min" else "max"
    return InteractionRule(
        cause1, cause2, effect, itype, {},
        f"{effect} = {func}({cause1}, {cause2})",
    )
def generate_world(
    n_variables: int = 3,
    domain: Optional[str] = None,
    seed: Optional[int] = None,
) -> CausalWorld:
    """
    Generate a fresh hidden causal world.

    The world may contain:
      - Single-parent rules (8 types: linear, threshold, inverse, quadratic,
        exponential, logarithmic, saturating, piecewise_linear)
      - Multi-parent interaction rules (additive, multiplicative, min, max)
      - Hidden confounders that add unexplainable correlated noise

    Structure: the sampled variables form a chain (each causes the next).
    Every "skip" edge i -> j (j > i + 1) is added with probability 0.30.
    With 3+ variables there is a 0.40 chance that an interaction of two
    earlier variables replaces every single-parent rule into the last
    variable, and a 0.30 chance of a confounder with sigma in [0.05, 0.25].
    Edges only point from earlier to later sampled variables, so the graph
    is acyclic.

    Args:
        n_variables: How many variables; capped at the pool size (5).
        domain: One of "system_alpha", "system_beta", "system_gamma",
            "system_delta". None or an unknown value picks one at random.
        seed: Random seed for reproducibility (seeds both the Python and
            NumPy generators).

    Returns:
        A CausalWorld instance ready for agent probing.
    """
    # NOTE: every py_rng draw below is part of the seed -> world mapping;
    # reordering draws changes the world produced for a given seed.
    py_rng = random.Random(seed)
    np_rng = np.random.default_rng(seed)

    if domain is None or domain not in DOMAIN_LABELS:
        domain = py_rng.choice(DOMAINS)
    labels = DOMAIN_LABELS[domain]

    var_pool = py_rng.choice(ABSTRACT_VAR_POOLS)
    n = min(n_variables, len(var_pool))
    # The sampled order doubles as the causal order of the variables.
    chosen_vars = py_rng.sample(var_pool, n)
    unit_label = labels.get("unit", "units")
    units = {v: unit_label for v in chosen_vars}

    rules: list[CausalRule] = []
    # Backbone chain: each variable causes the next one.
    for parent, child in zip(chosen_vars, chosen_vars[1:]):
        rules.append(_random_rule(parent, child, py_rng))
    # Optional skip edges i -> j with j > i + 1.
    for i, parent in enumerate(chosen_vars):
        for child in chosen_vars[i + 2:]:
            if py_rng.random() < 0.30:
                rules.append(_random_rule(parent, child, py_rng))

    interactions: list[InteractionRule] = []
    if n >= 3 and py_rng.random() < 0.40:
        # Two distinct non-final variables jointly drive the final one
        # (n >= 3 guarantees at least two candidates).
        c1, c2 = py_rng.sample(chosen_vars[: n - 1], 2)
        target = chosen_vars[-1]
        interactions.append(_random_interaction(c1, c2, target, py_rng))
        # The interaction replaces any single-parent rule into the target.
        rules = [r for r in rules if r.effect != target]

    confounder_sigma = 0.0
    if n >= 3 and py_rng.random() < 0.30:
        confounder_sigma = round(py_rng.uniform(0.05, 0.25), 3)

    all_effects = {r.effect for r in rules} | {inter.effect for inter in interactions}
    default_values: dict[str, float] = {}
    for v in chosen_vars:
        # Roots get a random baseline; effects get a placeholder filled below.
        is_root = v not in all_effects
        default_values[v] = round(py_rng.uniform(2.0, 10.0), 2) if is_root else 0.0

    display_order = list(chosen_vars)
    py_rng.shuffle(display_order)

    world = CausalWorld(
        domain=domain,
        variables=display_order,
        units=units,
        rules=rules,
        default_values=default_values,
        rng=np_rng,
        interactions=interactions,
        confounder_sigma=confounder_sigma,
    )

    # world.default_values is this same dict, so filling in the effects'
    # baselines here updates the world in place. chosen_vars is in causal
    # order, so upstream effects are settled before downstream ones.
    for v in chosen_vars:
        if v not in all_effects:
            continue
        computed = world._compute_value(v)
        if not math.isnan(computed):
            default_values[v] = round(computed, 4)
        else:
            # NOTE(review): _compute_value keeps the prior value when a rule
            # returns NaN, so this random fallback rarely (if ever) runs.
            default_values[v] = round(py_rng.uniform(2.0, 10.0), 2)
    return world