# Labexperiment / server / hypothesis_lab_environment.py
# Uploaded by Sbhimraj -- commit aab0192 ("Add application file"), 13.2 kB
"""
server/hypothesis_lab_environment.py -- OpenEnv Environment implementation.
Implements the OpenEnv server-side Environment interface:
reset() -> initial observation
step() -> execute one agent action, return observation
state -> return episode metadata (no hidden info leaked)
This class is what the FastAPI server wraps via create_app().
"""
from __future__ import annotations
import random
from typing import Any, Optional
from uuid import uuid4
try:
    from openenv.core.env_server.interfaces import Environment
except ImportError:
    # OpenEnv could not be imported: define a local abstract stand-in with
    # the same three members this module implements (reset, step, state).
    from abc import ABC, abstractmethod

    class Environment(ABC):  # type: ignore[no-redef]
        """Abstract fallback base class used when OpenEnv is unavailable."""

        def __init__(self, **kwargs: Any):
            """Accept arbitrary keyword arguments and ignore them."""

        @abstractmethod
        def reset(self, **kwargs: Any) -> Any:
            """Start a new episode; concrete subclasses must override."""

        @abstractmethod
        def step(self, action: Any, **kwargs: Any) -> Any:
            """Apply one ``action``; concrete subclasses must override."""

        @property
        @abstractmethod
        def state(self) -> Any:
            """Episode state; concrete subclasses must override."""
try:
from models import (
ActionType,
ExperimentType,
HypLabAction,
HypLabObservation,
HypLabState,
NoiseLevelTag,
)
except ImportError:
from ..models import (
ActionType,
ExperimentType,
HypLabAction,
HypLabObservation,
HypLabState,
NoiseLevelTag,
)
from .causal_world import CausalWorld, generate_world
from .rubric import InfoGainTracker, RubricResult, score_hypothesis
# --- Per-difficulty episode parameters, all keyed by NoiseLevelTag ---------

# Noise sigma handed to every CausalWorld query during the episode.
NOISE_SCHEDULE: dict[NoiseLevelTag, float] = {
    NoiseLevelTag.LOW: 0.05,
    NoiseLevelTag.MEDIUM: 0.20,
    NoiseLevelTag.HIGH: 0.50,
}

# Number of experiment steps the agent starts the episode with.
BUDGET_SCHEDULE: dict[NoiseLevelTag, int] = {
    NoiseLevelTag.LOW: 12,
    NoiseLevelTag.MEDIUM: 10,
    NoiseLevelTag.HIGH: 8,
}

# Value passed as ``n_variables`` to generate_world().
N_VARIABLES_SCHEDULE: dict[NoiseLevelTag, int] = {
    NoiseLevelTag.LOW: 2,
    NoiseLevelTag.MEDIUM: 3,
    NoiseLevelTag.HIGH: 4,
}

# Domain names reset() chooses from when the caller does not supply one.
DOMAINS = [
    "system_alpha",
    "system_beta",
    "system_gamma",
    "system_delta",
]
class HypothesisLabEnvironment(Environment):
    """
    Scientific Hypothesis Lab -- OpenEnv Environment.

    Each episode presents the agent with a new randomised causal world.
    The agent must discover the hidden rules through experiments and
    submit a hypothesis before running out of budget.

    Lifecycle:
        * ``reset()`` builds a new world and returns the opening observation.
        * ``step()`` runs one EXPERIMENT or SUBMIT action.
        * The episode ends (``done=True``) on SUBMIT, or as soon as the
          experiment budget reaches zero.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Reward attached to every error observation.
    _ERROR_PENALTY = -0.05
    # Appended to the observation that consumes the last unit of budget.
    # NOTE(review): the episode is marked done at that moment, so the message
    # must not ask the agent to submit afterwards. If a post-budget SUBMIT is
    # actually intended, defer the ``_done`` flag instead of this wording.
    _BUDGET_EXHAUSTED_NOTE = (
        "Budget exhausted. Episode ended without a hypothesis submission."
    )

    def __init__(self, **kwargs: Any):
        """Create an environment with no active episode (``_done`` is True)."""
        super().__init__(**kwargs)
        self._episode_id: str = ""
        self._world: Optional[CausalWorld] = None
        self._tracker: Optional[InfoGainTracker] = None
        self._step_count: int = 0
        self._budget_total: int = 10
        self._budget_remaining: int = 0
        self._done: bool = True
        self._history: list[dict] = []
        self._cumulative_reward: float = 0.0
        self._noise_level: NoiseLevelTag = NoiseLevelTag.MEDIUM
        self._sigma: float = 0.20
        self._domain: str = "unknown"

    @staticmethod
    def _coerce_noise_level(raw: Any) -> Any:
        """Turn a ``noise_level`` kwarg into a ``NoiseLevelTag``.

        Non-string values are returned unchanged (the caller is expected to
        have passed a tag already). Strings are tried verbatim first and then
        stripped/lower-cased, so ``"MEDIUM"`` or ``" medium "`` are accepted.

        Raises:
            ValueError: if the string matches no ``NoiseLevelTag`` value.
        """
        if not isinstance(raw, str):
            return raw
        try:
            return NoiseLevelTag(raw)
        except ValueError:
            return NoiseLevelTag(raw.strip().lower())

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> HypLabObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Forwarded to ``generate_world``. When given, it also seeds
                the domain choice so that the same seed yields the same
                episode set-up.
            episode_id: Identifier to use; a fresh UUID4 string if omitted.
            **kwargs: Optional ``noise_level`` (tag or its string value,
                default ``"medium"``) and ``domain`` (default: one of
                ``DOMAINS``).

        Returns:
            Observation with the instructions message, the variable names and
            the full budget; ``done=False`` and ``reward=0.0``.

        Raises:
            ValueError: if ``noise_level`` is a string that names no tag.
        """
        # ``or`` also covers an explicit noise_level=None, which would
        # otherwise fail the schedule lookups below.
        noise_level = self._coerce_noise_level(kwargs.get("noise_level") or "medium")

        domain = kwargs.get("domain", None)
        if not domain:
            # A private RNG keeps seeded resets reproducible without touching
            # the global random state.
            rng = random.Random(seed) if seed is not None else random
            domain = rng.choice(DOMAINS)

        sigma = NOISE_SCHEDULE[noise_level]
        budget = BUDGET_SCHEDULE[noise_level]
        n_vars = N_VARIABLES_SCHEDULE[noise_level]

        self._world = generate_world(n_variables=n_vars, domain=domain, seed=seed)
        self._tracker = InfoGainTracker()
        self._episode_id = episode_id or str(uuid4())
        self._step_count = 0
        self._budget_total = budget
        self._budget_remaining = budget
        self._done = False
        self._history = []
        self._cumulative_reward = 0.0
        self._noise_level = noise_level
        self._sigma = sigma
        self._domain = domain

        system_msg = (
            f"New episode started. Domain: {domain.upper()}.\n"
            f"You have {n_vars} unknown variables: {', '.join(self._world.variables)}.\n"
            f"Budget: {budget} experiment steps.\n"
            f"Run experiments to discover the hidden causal rules, then SUBMIT your hypothesis.\n"
            f"Noise level: {noise_level.value}.\n\n"
            f"Available experiment types:\n"
            f" INTERVENTION -- set one variable to a value, observe another\n"
            f" CORRELATION -- sweep one variable across a range, observe another\n"
            f" COUNTERFACTUAL-- ask 'what if variable changes by delta?'\n"
            f" PASSIVE -- observe one variable in its default state\n"
            f" SUBMIT -- submit your hypothesis (ends episode)"
        )
        return HypLabObservation(
            system_message=system_msg,
            available_variables=self._world.variables,
            budget_remaining=self._budget_remaining,
            done=False,
            reward=0.0,
        )

    def step(
        self,
        action: HypLabAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> HypLabObservation:
        """Execute one agent action and return the resulting observation.

        Args:
            action: The agent's action; dispatched on ``action.action_type``.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Ignored.

        Returns:
            An error observation if no episode is active or the episode is
            already done; otherwise the result of the experiment or
            submission. Unknown action types cost one unit of budget.
        """
        if self._world is None:
            return HypLabObservation(
                system_message="Error: No active episode. Call reset() before step().",
                done=True,
                reward=-1.0,
            )
        if self._done:
            return HypLabObservation(
                system_message="Error: Episode is already done. Call reset() to start a new episode.",
                available_variables=self._world.variables,
                budget_remaining=self._budget_remaining,
                done=True,
                reward=0.0,
            )

        self._step_count += 1
        if action.action_type == ActionType.EXPERIMENT:
            return self._handle_experiment(action)
        if action.action_type == ActionType.SUBMIT:
            return self._handle_submit(action)
        return self._error_obs(
            f"Unknown action_type: {action.action_type}. Use 'experiment' or 'submit'.",
            deduct_budget=True,
        )

    @property
    def state(self) -> HypLabState:
        """Episode metadata snapshot built from the environment's counters."""
        tracker = self._tracker
        return HypLabState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            budget_total=self._budget_total,
            budget_remaining=self._budget_remaining,
            noise_level=self._noise_level,
            noise_sigma=self._sigma,
            domain=self._domain,
            n_variables=len(self._world.variables) if self._world else 0,
            # Copy so callers cannot mutate the environment's own history list.
            experiment_history=list(self._history),
            cumulative_info_gain=tracker.cumulative_gain if tracker else 0.0,
            redundant_experiment_count=tracker.redundant_count if tracker else 0,
        )

    def _handle_experiment(self, action: HypLabAction) -> HypLabObservation:
        """Run one experiment against the world and score its information gain.

        Validates the variable names, queries the world according to
        ``action.experiment_type`` (INTERVENTION when unset), records the
        experiment with the tracker, deducts one unit of budget and ends the
        episode when the budget reaches zero.

        Returns:
            The experiment observation, or an error observation for unknown
            variables (budget deducted) or an unknown experiment type.
        """
        world = self._world
        sigma = self._sigma
        tracker = self._tracker

        exp_type = action.experiment_type or ExperimentType.INTERVENTION
        cause = action.control_variable or ""
        effect = action.target_variable or ""
        all_vars = world.variables

        # PASSIVE only observes the target, so an omitted control variable is
        # acceptable there; a supplied one is still validated.
        cause_optional = exp_type == ExperimentType.PASSIVE and not cause
        if not cause_optional and cause not in all_vars:
            return self._error_obs(
                f"Unknown control variable '{cause}'. Available: {all_vars}",
                deduct_budget=True,
            )
        if effect not in all_vars:
            return self._error_obs(
                f"Unknown target variable '{effect}'. Available: {all_vars}",
                deduct_budget=True,
            )

        # ``control_used`` is what the query really received, including any
        # default substituted for a missing field.
        result_value = None
        if exp_type == ExperimentType.INTERVENTION:
            val = action.control_value if action.control_value is not None else 5.0
            control_used = val
            result_value = world.query_intervention(cause, val, effect, sigma)
            result_str = f"{effect} = {result_value:.4f} (sigma={sigma}, set {cause}={val})"
        elif exp_type == ExperimentType.CORRELATION:
            cr = action.control_range or [1.0, 10.0, 5.0]
            control_used = cr
            pairs = world.query_correlation(cause, cr, effect, sigma)
            result_value = pairs
            result_str = (
                f"Correlation sweep {cause} -> {effect}:\n"
                + "\n".join(f" {cause}={x:.2f} -> {effect}={y:.4f}" for x, y in pairs)
            )
        elif exp_type == ExperimentType.COUNTERFACTUAL:
            # ``is not None`` so that an explicit delta of 0.0 is honoured.
            delta = action.control_value if action.control_value is not None else 1.0
            control_used = delta
            cf = world.query_counterfactual(cause, delta, effect, sigma)
            result_value = cf
            result_str = (
                f"Counterfactual: if {cause} changes by {delta:+.2f}:\n"
                f" Baseline: {cause}={cf['baseline_x']:.2f} -> {effect}={cf['baseline_y_noisy']:.4f}\n"
                f" After: {cause}={cf['counterfactual_x']:.2f} -> {effect}={cf['counterfactual_y_noisy']:.4f}\n"
                f" Direction: {effect} {cf['direction']}"
            )
        elif exp_type == ExperimentType.PASSIVE:
            control_used = action.control_value
            result_value = world.query_passive(effect, sigma)
            result_str = f"Passive observation: {effect} = {result_value:.4f} (sigma={sigma})"
        else:
            return self._error_obs(f"Unknown experiment type: {exp_type}")

        info_gain, is_redundant = tracker.record_and_score(
            cause, effect, exp_type.value, result_value
        )

        self._budget_remaining -= 1
        budget_done = self._budget_remaining <= 0
        self._cumulative_reward += info_gain
        self._history.append({
            "step": self._step_count,
            "exp_type": exp_type.value,
            "cause": cause,
            "effect": effect,
            "reward": round(info_gain, 4),
            "redundant": is_redundant,
        })

        msg = f"[Step {self._step_count}] {result_str}"
        if is_redundant:
            msg += "\nRedundant experiment -- reward penalty applied."
        if budget_done:
            msg += "\n" + self._BUDGET_EXHAUSTED_NOTE
            self._done = True

        return HypLabObservation(
            system_message=msg,
            available_variables=world.variables,
            budget_remaining=self._budget_remaining,
            experiment_type_run=exp_type,
            control_variable_used=cause,
            control_value_used=control_used,
            target_variable_observed=effect,
            result_value=result_value,
            noise_sigma=sigma,
            is_redundant=is_redundant,
            info_gain_reward=round(info_gain, 4),
            reward=info_gain,
            done=self._done,
        )

    def _handle_submit(self, action: HypLabAction) -> HypLabObservation:
        """Score the submitted hypothesis, end the episode and report the rubric.

        Returns:
            A terminal observation carrying each rubric component, the total
            as ``reward`` and the ground truth reported by the rubric.
        """
        self._done = True
        rubric: RubricResult = score_hypothesis(
            hypothesis_text=action.hypothesis_text or "",
            hypothesis_equations=action.hypothesis_equations,
            confidence=action.confidence,
            world=self._world,
            budget_remaining=self._budget_remaining,
            budget_total=self._budget_total,
        )
        total_reward = rubric.total
        self._cumulative_reward += total_reward

        msg = (
            f"[Episode End -- Step {self._step_count}]\n"
            f"Hypothesis received. Evaluating against ground truth...\n\n"
            f"RUBRIC BREAKDOWN:\n"
            f" Accuracy score: {rubric.accuracy_score:+.4f}\n"
            f" Precision bonus: {rubric.precision_bonus:+.4f}\n"
            f" Calibration score: {rubric.calibration_score:+.4f}\n"
            f" Efficiency bonus: {rubric.efficiency_bonus:+.4f}\n"
            f" Contradiction penalty: {rubric.contradiction_penalty:+.4f}\n"
            f" ────────────────────────────\n"
            f" TOTAL EPISODE REWARD: {rubric.total:+.4f}\n\n"
            f"FEEDBACK: {rubric.feedback}\n\n"
            f"GROUND TRUTH:\n{rubric.ground_truth}"
        )
        return HypLabObservation(
            system_message=msg,
            available_variables=self._world.variables,
            budget_remaining=self._budget_remaining,
            accuracy_score=rubric.accuracy_score,
            precision_bonus=rubric.precision_bonus,
            calibration_score=rubric.calibration_score,
            efficiency_bonus=rubric.efficiency_bonus,
            contradiction_penalty=rubric.contradiction_penalty,
            total_episode_reward=rubric.total,
            ground_truth_revealed=rubric.ground_truth,
            reward=total_reward,
            done=True,
        )

    def _error_obs(
        self, msg: str, deduct_budget: bool = False
    ) -> HypLabObservation:
        """Build a penalised error observation.

        Args:
            msg: Human-readable description; prefixed with ``"Error: "``.
            deduct_budget: When True the error costs one unit of budget, and
                the episode ends if that brings the budget to zero (so
                repeated invalid actions cannot drive it negative).

        Returns:
            Observation with the fixed error penalty as ``reward``; ``done``
            is True only if this error exhausted the budget.
        """
        text = f"Error: {msg}"
        if deduct_budget:
            self._budget_remaining -= 1
            if self._budget_remaining <= 0:
                self._done = True
                text += "\n" + self._BUDGET_EXHAUSTED_NOTE
        # Keep the running total consistent with the reward actually returned.
        self._cumulative_reward += self._ERROR_PENALTY
        return HypLabObservation(
            system_message=text,
            available_variables=self._world.variables if self._world else [],
            budget_remaining=self._budget_remaining,
            reward=self._ERROR_PENALTY,
            done=self._done,
        )