# datacentric-env / server /environment.py
# Aswini-Kumar's picture
# Upload server/environment.py with huggingface_hub
# f22507c verified
"""
server/environment.py — Session-aware environment (v0.5).
New in v0.5:
1. Rollback action — undo last apply (real data engineers do this)
2. Episode reasoning trace — running history of what the agent tried + effects
3. Feature importance — returned after every apply so agent sees what the model learned
4. Regression explanation — when accuracy drops, explains the likely cause
5. Baseline comparison — agent always knows how far ahead of majority-class predictor it is
"""
import threading
import pandas as pd
import numpy as np
from server.dataset_registry import DatasetRegistry
from server.evaluator import Evaluator
from server.reward import compute, compute_stats
from server.anti_exploit import AntiExploit, ExploitDetected
from server.config import cfg
from server.logger import get_logger, log_event
from server.specialist_agents import (
CleanerAgent, AugmenterAgent, BalancerAgent, ValidatorAgent, AnalystAgent
)
# Module-wide structured logger used by every environment instance.
logger = get_logger("environment")

# Budget units charged for each specialist query action.
QUERY_COSTS = dict(
    query_cleaner=1,
    query_augmenter=1,
    query_balancer=1,
    query_validator=2,
    query_analyst=2,
)

# Action names that are routed to the query handler.
QUERY_ACTIONS = set(QUERY_COSTS)

# Single dataset registry shared by all sessions in this process.
_registry = DatasetRegistry()
class DataCentricEnvironment:
    """Session-scoped environment in which an agent improves a training set.

    The agent spends a limited budget querying specialist agents for
    recommendations, applies those recommendations to the training
    DataFrame, and may roll back recent applies. ``reset``, ``step`` and
    ``state`` serialize access through a per-instance lock.

    Budget accounting as implemented here: queries cost ``QUERY_COSTS``,
    a rollback costs 1, a blocked exploit costs 1, and an apply does not
    deduct budget.
    """

    # Maximum number of rollback actions allowed in one episode.
    MAX_ROLLBACKS = 3
    # Number of pre-apply snapshots retained for rollback.
    MAX_ROLLBACK_SNAPSHOTS = 3
    # Largest fraction of training rows a single apply may delete.
    MAX_DELETION_FRACTION = 0.10

    def __init__(self, session_id: str, episode_count: int = 0):
        """Create the specialist agents and initialise an empty episode.

        Args:
            session_id: Identifier echoed in observations, errors and logs.
            episode_count: Episodes already played; drives the curriculum.
        """
        self.session_id = session_id
        self._episode_count = episode_count
        self.agents = {
            "cleaner": CleanerAgent(),
            "augmenter": AugmenterAgent(),
            "balancer": BalancerAgent(),
            "validator": ValidatorAgent(),
            "analyst": AnalystAgent(),
        }
        self.anti_exploit = AntiExploit()
        self._lock = threading.Lock()
        self._reset_state()

    def _reset_state(self):
        """Reset every per-episode attribute to its initial value."""
        self.train_df: pd.DataFrame = None
        self.holdout_df: pd.DataFrame = None
        self.domain_metadata: dict = {}
        self.evaluator: Evaluator = None
        self.target_accuracy: float = None
        self.initial_row_count: int = 0
        self.baseline_accuracy: float = 0.0  # majority-class predictor on holdout
        self.starting_accuracy: float = 0.0  # accuracy before ANY agent action
        self.budget: int = cfg.MAX_BUDGET
        self.current_accuracy: float = 0.0
        self.episode_step: int = 0
        self.done: bool = False
        self.difficulty: str = "easy"
        self.pending_recs: dict = {}
        self.applied_rec_ids: set = set()
        self.last_query_result: dict = {}
        self.last_feature_importance: dict = {}
        self.anti_exploit.reset()
        self.accuracy_history: list = []
        self.reward_history: list = []
        # Rollback: stack of (df_snapshot, accuracy) — most recent states only
        self._state_stack: list[tuple] = []
        # Reasoning trace: running log of every step
        self._episode_trace: list[dict] = []

    # ── Public API ────────────────────────────────────────────────────────

    def reset(self, difficulty: str = None, seed: int = None) -> dict:
        """Start a new episode and return its first observation.

        Args:
            difficulty: Dataset difficulty; when falsy the curriculum picks
                one from the episode count.
            seed: Passed through to the dataset registry.

        Returns:
            The observation dict for the freshly loaded dataset.
        """
        with self._lock:
            self._episode_count += 1
            self._reset_state()
            self.difficulty = difficulty or self._curriculum_difficulty()
            self.train_df, self.holdout_df, self.domain_metadata = _registry.get(
                difficulty=self.difficulty, seed=seed
            )
            self.initial_row_count = len(self.train_df)
            self.evaluator = Evaluator(self.holdout_df)
            pub_baseline = self.domain_metadata.get("published_baseline", 0.80)
            # Target is 97% of the published baseline for the dataset.
            self.target_accuracy = round(pub_baseline * 0.97, 4)
            self.baseline_accuracy = self.evaluator.baseline_accuracy()
            self.current_accuracy = self.evaluator.evaluate(self._clean_df(self.train_df))
            self.starting_accuracy = self.current_accuracy
            self.accuracy_history.append(self.current_accuracy)
            self._episode_trace.append({
                "step": 0,
                "type": "reset",
                "dataset": self.domain_metadata.get("display_name"),
                "accuracy": round(self.current_accuracy, 4),
                "baseline_accuracy": self.baseline_accuracy,
                "target_accuracy": self.target_accuracy,
            })
            log_event(logger, "episode_reset",
                      session_id=self.session_id,
                      dataset=self.domain_metadata.get("display_name"),
                      difficulty=self.difficulty,
                      initial_accuracy=round(self.current_accuracy, 4),
                      target_accuracy=self.target_accuracy,
                      baseline_accuracy=self.baseline_accuracy,
                      published_baseline=pub_baseline,
                      n_train=len(self.train_df),
                      n_holdout=len(self.holdout_df))
            return self._observation()

    def step(self, action: dict) -> dict:
        """Execute one agent action and return the step result.

        Args:
            action: Dict with an ``"action"`` key naming a query action,
                ``"apply"`` or ``"rollback"``, plus action-specific fields.

        Returns:
            A step-result dict, or an error dict (``"error"`` and
            ``"session_id"``) when the action cannot be executed.
        """
        with self._lock:
            if self.done:
                return self._error("Episode done. Call /reset.")
            if self.train_df is None:
                return self._error("Not initialized. Call /reset first.")
            if not isinstance(action, dict):
                # A malformed payload must not crash the session.
                return self._error("Invalid action: expected an object with an 'action' field.")

            # Rollback action — no anti-exploit check needed
            action_type = action.get("action", "")
            if action_type == "rollback":
                return self._handle_rollback()

            try:
                self.anti_exploit.check(
                    action=action,
                    budget_remaining=self.budget,
                    pending_recs=self.pending_recs,
                    applied_rec_ids=self.applied_rec_ids,
                )
            except ExploitDetected as e:
                log_event(logger, "exploit_detected", session_id=self.session_id,
                          rule=e.rule, detail=e.detail)
                # A blocked exploit still consumes a step and one budget unit.
                self.episode_step += 1
                self.budget = max(0, self.budget - 1)
                self.done = self.budget <= 0
                self._episode_trace.append({
                    "step": self.episode_step,
                    "type": "exploit_blocked",
                    "rule": e.rule,
                    "detail": e.detail,
                })
                return {
                    "observation": self._observation(),
                    "reward": 0.001,
                    "done": self.done,
                    "exploit_detected": True,
                    "error": f"[{e.rule}] {e.detail}",
                    "info": {"episode_step": self.episode_step, "budget_remaining": self.budget},
                }

            if action_type in QUERY_ACTIONS:
                return self._handle_query(action_type, action)
            if action_type == "apply":
                return self._handle_apply(action)
            # Sorted so the message is deterministic across runs.
            valid_actions = sorted(QUERY_ACTIONS) + ['apply', 'rollback']
            return self._error(f"Unknown action '{action_type}'. Valid: {valid_actions}")

    def state(self) -> dict:
        """Return the current observation without advancing the episode."""
        with self._lock:
            return self._observation()

    # ── Rollback ──────────────────────────────────────────────────────────

    def _rollbacks_used(self) -> int:
        """Count rollback entries recorded in the episode trace."""
        return sum(1 for e in self._episode_trace if e["type"] == "rollback")

    def _handle_rollback(self) -> dict:
        """Undo the last apply operation. Costs 1 budget; capped per episode.

        Returns:
            A step-result dict with ``"rollback": True``, or an error dict
            when the rollback cap is reached or no snapshot is available.
        """
        rollbacks_used = self._rollbacks_used()
        if rollbacks_used >= self.MAX_ROLLBACKS:
            return self._error(f"Maximum {self.MAX_ROLLBACKS} rollbacks per episode reached.")
        if not self._state_stack:
            return self._error("Nothing to roll back. No apply operations have been made yet.")

        prev_df, prev_accuracy = self._state_stack.pop()
        self.train_df = prev_df
        self.current_accuracy = prev_accuracy
        self.accuracy_history.append(self.current_accuracy)
        self.budget = max(0, self.budget - 1)
        self.episode_step += 1
        self.done = self.budget <= 0
        self._episode_trace.append({
            "step": self.episode_step,
            "type": "rollback",
            "accuracy_after_rollback": round(self.current_accuracy, 4),
            "note": "Last apply undone. Dataset restored to previous state.",
        })
        log_event(logger, "rollback", session_id=self.session_id,
                  accuracy_after=round(self.current_accuracy, 4))
        return {
            "observation": self._observation(),
            "reward": 0.3,  # small penalty for indecision, but not zero
            "done": self.done,
            "rollback": True,
            "accuracy_after_rollback": round(self.current_accuracy, 4),
            "info": {
                "episode_step": self.episode_step,
                "budget_remaining": self.budget,
                # rollbacks_used was counted before this rollback was traced.
                "rollbacks_remaining": self.MAX_ROLLBACKS - rollbacks_used - 1,
                "note": "Dataset restored to state before last apply.",
            },
        }

    # ── Query handler ─────────────────────────────────────────────────────

    def _handle_query(self, action_type: str, action: dict) -> dict:
        """Query one specialist agent and register its recommendations.

        Args:
            action_type: A key of ``QUERY_COSTS`` (``"query_<agent>"``).
            action: The full action dict; ``"target_class"`` is forwarded
                to the augmenter.

        Returns:
            A step-result dict containing the raw query result and the ids
            of the newly pending recommendations.
        """
        cost = QUERY_COSTS[action_type]
        agent_name = action_type.replace("query_", "")
        prev_stats = compute_stats(self.train_df)
        clean = self._clean_df(self.train_df)
        meta = self.domain_metadata

        agent = self.agents[agent_name]
        if agent_name == "augmenter":
            # Only the augmenter takes the extra target-class argument.
            result = agent.query(clean, action.get("target_class"), meta)
        else:
            result = agent.query(clean, meta)

        new_rec_ids = []
        for rec in result.get("recommendations", []):
            rid = rec["id"]
            self.pending_recs[rid] = {"rec": rec, "agent": result.get("agent", "unknown")}
            new_rec_ids.append(rid)

        self.last_query_result = result
        self.budget = max(0, self.budget - cost)
        self.episode_step += 1
        new_stats = compute_stats(self.train_df)
        # A query does not re-evaluate the model, so accuracy is unchanged.
        reward, decomp = compute(
            prev_accuracy=self.current_accuracy,
            new_accuracy=self.current_accuracy,
            prev_stats=prev_stats,
            new_stats=new_stats,
            action=action,
            steps_taken=self.episode_step,
            max_steps=cfg.MAX_BUDGET,
            budget_remaining=self.budget,
            target_accuracy=self.target_accuracy,
            step_type="query",
            n_recs_returned=len(new_rec_ids),
        )
        self.reward_history.append(reward)
        self.done = self.budget <= 0

        self._episode_trace.append({
            "step": self.episode_step,
            "type": "query",
            "agent": agent_name,
            "n_recs": len(new_rec_ids),
            "budget_cost": cost,
            "budget_remaining": self.budget,
            "reward": reward,
            "rec_ids": new_rec_ids,
        })
        log_event(logger, "query_step", session_id=self.session_id,
                  action=action_type, n_recs=len(new_rec_ids),
                  budget=self.budget, reward=reward)
        return {
            "observation": self._observation(),
            "reward": reward,
            "reward_decomposition": decomp,
            "done": self.done,
            "query_result": result,
            "new_recommendation_ids": new_rec_ids,
            "info": {
                "action_type": "query",
                "agent_queried": agent_name,
                "budget_cost": cost,
                "budget_remaining": self.budget,
                "n_recommendations": len(new_rec_ids),
                "episode_step": self.episode_step,
                "domain": self.domain_metadata.get("display_name"),
            },
        }

    # ── Apply handler ─────────────────────────────────────────────────────

    def _run_agent_apply(self, agent_name: str, rec: dict, meta: dict) -> tuple:
        """Run an agent's ``apply`` in a worker thread bounded by the step timeout.

        Returns:
            ``(df, log, error)``. ``error`` is ``None`` on success; otherwise
            it is the message to report and ``df``/``log`` are ``None``.
        """
        outcome: dict = {}

        def _worker():
            try:
                clean = self._clean_df(self.train_df)
                df_out, log_msg = self.agents[agent_name].apply(clean, rec, meta)
                outcome["df"] = df_out
                outcome["log"] = log_msg
            except Exception as exc:  # thread boundary: report, never crash the step
                outcome["error"] = str(exc)

        worker = threading.Thread(target=_worker, daemon=True)
        worker.start()
        worker.join(timeout=cfg.STEP_TIMEOUT_SECONDS)
        if worker.is_alive():
            return None, None, "Apply timed out."
        if "error" in outcome:
            return None, None, f"Apply error: {outcome['error']}"
        if not isinstance(outcome.get("df"), pd.DataFrame):
            return None, None, "Apply error: agent did not return a DataFrame."
        return outcome["df"], outcome["log"], None

    def _handle_apply(self, action: dict) -> dict:
        """Apply a pending recommendation, re-evaluate, and score the step.

        Args:
            action: Action dict whose ``"rec_id"`` names a pending
                recommendation.

        Returns:
            A step-result dict, or an error dict when the id is unknown,
            the agent fails or times out, or the result would delete more
            than ``MAX_DELETION_FRACTION`` of the training rows. Failed
            applies leave the dataset and the rollback history unchanged.
        """
        rec_id = action.get("rec_id", "")
        try:
            entry = self.pending_recs.get(rec_id)
        except TypeError:  # unhashable rec_id, e.g. a list
            entry = None
        if entry is None:
            return self._error(
                f"Unknown recommendation id '{rec_id}'. "
                f"Query an agent first and apply one of the pending recommendation ids."
            )

        agent_name = entry["agent"]
        rec = entry["rec"]
        prev_accuracy = self.current_accuracy
        prev_stats = compute_stats(self.train_df)
        prev_rows = len(self.train_df)
        meta = self.domain_metadata

        # Copy the state BEFORE applying, but only push it once the apply is
        # accepted, so a failed apply never evicts an older snapshot.
        snapshot = (self.train_df.copy(), prev_accuracy)

        new_df, tool_log, failure = self._run_agent_apply(agent_name, rec, meta)
        if failure is not None:
            return self._error(failure)

        # Data integrity constraint: cannot delete more than the allowed fraction of rows
        new_rows = len(new_df)
        deletion_pct = max(0, (prev_rows - new_rows) / max(prev_rows, 1))
        if deletion_pct > self.MAX_DELETION_FRACTION:
            return self._error(
                f"Data integrity violation: would delete {deletion_pct:.1%} of training rows "
                f"(limit: {self.MAX_DELETION_FRACTION:.0%}). Use targeted imputation instead of drop_rows."
            )

        self._state_stack.append(snapshot)
        if len(self._state_stack) > self.MAX_ROLLBACK_SNAPSHOTS:
            del self._state_stack[0]  # keep only the most recent snapshots

        self.train_df = new_df
        self.applied_rec_ids.add(rec_id)
        self.episode_step += 1

        # Full evaluation with feature importance + regression explanation
        eval_result = self.evaluator.evaluate_with_details(
            self._clean_df(self.train_df), prev_accuracy
        )
        self.current_accuracy = eval_result["accuracy"]
        self.last_feature_importance = eval_result.get("feature_importance", {})
        regression_explanation = eval_result.get("regression_explanation")
        self.accuracy_history.append(self.current_accuracy)

        new_stats = compute_stats(self.train_df)
        reward, decomp = compute(
            prev_accuracy=prev_accuracy,
            new_accuracy=self.current_accuracy,
            prev_stats=prev_stats,
            new_stats=new_stats,
            action=action,
            steps_taken=self.episode_step,
            max_steps=cfg.MAX_BUDGET,
            budget_remaining=self.budget,
            target_accuracy=self.target_accuracy,
            step_type="apply",
        )
        self.reward_history.append(reward)
        success = self.current_accuracy >= self.target_accuracy
        self.done = success or (self.budget <= 0)

        acc_delta = round(self.current_accuracy - prev_accuracy, 4)
        if acc_delta > 0.001:
            effect = "improved"
        elif acc_delta < -0.001:
            effect = "hurt"
        else:
            effect = "neutral"

        self._episode_trace.append({
            "step": self.episode_step,
            "type": "apply",
            "agent": agent_name,
            "rec_type": rec.get("type", "?"),
            "rec_id": rec_id,
            "accuracy_before": round(prev_accuracy, 4),
            "accuracy_after": round(self.current_accuracy, 4),
            "accuracy_delta": acc_delta,
            "effect": effect,
            "reward": reward,
            "rows_before": prev_rows,
            "rows_after": new_rows,
        })
        log_event(logger, "apply_step", session_id=self.session_id,
                  rec_id=rec_id, agent=agent_name,
                  prev_acc=round(prev_accuracy, 4),
                  new_acc=round(self.current_accuracy, 4),
                  target=self.target_accuracy,
                  reward=reward,
                  success=success)

        response = {
            "observation": self._observation(),
            "reward": reward,
            "reward_decomposition": decomp,
            "done": self.done,
            "tool_log": tool_log,
            "feature_importance": self.last_feature_importance,
            "info": {
                "action_type": "apply",
                "rec_id": rec_id,
                "agent": agent_name,
                "rec_type": rec.get("type", "?"),
                "prev_accuracy": round(prev_accuracy, 4),
                "new_accuracy": round(self.current_accuracy, 4),
                "accuracy_delta": acc_delta,
                "target_accuracy": self.target_accuracy,
                "published_baseline": self.domain_metadata.get("published_baseline"),
                "improvement_over_start": round(self.current_accuracy - self.starting_accuracy, 4),
                "improvement_over_majority_baseline": round(self.current_accuracy - self.baseline_accuracy, 4),
                "budget_remaining": self.budget,
                "episode_step": self.episode_step,
                "success": success,
                "rollbacks_available": max(0, self.MAX_ROLLBACKS - self._rollbacks_used()),
                "data_integrity": {
                    "rows_before": prev_rows,
                    "rows_after": new_rows,
                    "deletion_pct": round(deletion_pct, 4),
                },
            },
        }
        # Only include the regression explanation when the evaluator supplied one
        if regression_explanation:
            response["regression_explanation"] = regression_explanation
        return response

    # ── Helpers ───────────────────────────────────────────────────────────

    def _clean_df(self, df):
        """Return ``df`` without columns whose name starts with an underscore.

        Non-string column labels are kept. When nothing needs dropping the
        input object itself is returned (no copy).
        """
        drop_cols = [c for c in df.columns if isinstance(c, str) and c.startswith("_")]
        return df.drop(columns=drop_cols) if drop_cols else df

    def _observation(self) -> dict:
        """Build the observation dict describing the current episode state.

        Safe to call before the first ``reset``: dataset statistics are then
        zero and ``accuracy_gap`` is ``None`` because no target is set yet.
        """
        stats = compute_stats(self.train_df) if self.train_df is not None else {}
        pending_summary = {
            rid: {
                "agent": entry["agent"],
                "type": entry["rec"].get("type", "?"),
                "priority": entry["rec"].get("priority", "?"),
                "reason": entry["rec"].get("reason", ""),
                "domain_informed": entry["rec"].get("domain_informed", False),
            }
            for rid, entry in self.pending_recs.items()
            if rid not in self.applied_rec_ids
        }
        meta = self.domain_metadata
        # Compact trace — last 5 steps for context without overwhelming the prompt
        recent_trace = self._episode_trace[-5:] if self._episode_trace else []

        if self.target_accuracy is None:
            accuracy_gap = None  # no target until reset() has loaded a dataset
        else:
            accuracy_gap = round(max(0, self.target_accuracy - self.current_accuracy), 4)

        return {
            "session_id": self.session_id,
            # What the agent is working on
            "dataset": {
                "name": meta.get("display_name", "Unknown"),
                "domain": meta.get("domain", "generic"),
                "description": meta.get("description", ""),
                "known_issues": meta.get("known_issues", []),
                "published_baseline": meta.get("published_baseline"),
            },
            # Current state
            "current_accuracy": round(self.current_accuracy, 4),
            "target_accuracy": self.target_accuracy,
            "accuracy_gap": accuracy_gap,
            "budget_remaining": self.budget,
            "difficulty": self.difficulty,
            # Comparisons — what does this number actually mean?
            "benchmarks": {
                "majority_class_baseline": self.baseline_accuracy,
                "starting_accuracy": round(self.starting_accuracy, 4),
                "improvement_over_start": round(self.current_accuracy - self.starting_accuracy, 4),
                "improvement_over_baseline": round(self.current_accuracy - self.baseline_accuracy, 4),
                "published_baseline": meta.get("published_baseline"),
            },
            "dataset_stats": {
                "n_train_rows": len(self.train_df) if self.train_df is not None else 0,
                "n_holdout_rows": len(self.holdout_df) if self.holdout_df is not None else 0,
                "n_cols": len(self.train_df.columns) if self.train_df is not None else 0,
                "missing_pct": round(stats.get("missing_pct", 0), 4),
                "balance_ratio": round(stats.get("balance_ratio", 0), 4),
            },
            # Feature importance from last evaluation
            "feature_importance": self.last_feature_importance,
            # Episodic memory — what has the agent tried so far?
            "episode_trace": recent_trace,
            "pending_recommendations": pending_summary,
            "last_query_result": self.last_query_result,
            "available_actions": (
                "query_cleaner | query_augmenter | query_balancer | "
                "query_validator (cost 2) | query_analyst (cost 2) | "
                f"apply {{rec_id}} | rollback (undo last apply, max {self.MAX_ROLLBACKS}/episode)"
            ),
        }

    def _error(self, msg: str) -> dict:
        """Return the standard error payload for this session."""
        return {"error": msg, "session_id": self.session_id}

    def _curriculum_difficulty(self) -> str:
        """Pick a difficulty from the episode count and configured thresholds."""
        if self._episode_count < cfg.CURRICULUM_MEDIUM_AFTER:
            return "easy"
        if self._episode_count < cfg.CURRICULUM_HARD_AFTER:
            return "medium"
        return "hard"

    def episode_summary(self) -> dict:
        """Summarise the episode: accuracy/reward histories and the full trace."""
        return {
            "session_id": self.session_id,
            "episode_count": self._episode_count,
            "accuracy_history": [round(a, 4) for a in self.accuracy_history],
            "reward_history": [round(r, 4) for r in self.reward_history],
            # max(..., 1) avoids division by zero when no reward was recorded.
            "mean_reward": round(sum(self.reward_history) / max(len(self.reward_history), 1), 4),
            "full_trace": self._episode_trace,
        }