Spaces:
Sleeping
Sleeping
| """ | |
server/environment.py — Session-aware environment (v0.5).
New in v0.5:
  1. Rollback action — undo the last apply (as real data engineers do)
  2. Episode reasoning trace — running history of what the agent tried and the effects
  3. Feature importance — returned after every apply so the agent sees what the model learned
  4. Regression explanation — when accuracy drops, explains the likely cause
  5. Baseline comparison — the agent always knows how far ahead of the majority-class predictor it is
| """ | |
| import threading | |
| import pandas as pd | |
| import numpy as np | |
| from server.dataset_registry import DatasetRegistry | |
| from server.evaluator import Evaluator | |
| from server.reward import compute, compute_stats | |
| from server.anti_exploit import AntiExploit, ExploitDetected | |
| from server.config import cfg | |
| from server.logger import get_logger, log_event | |
| from server.specialist_agents import ( | |
| CleanerAgent, AugmenterAgent, BalancerAgent, ValidatorAgent, AnalystAgent | |
| ) | |
logger = get_logger("environment")

# Budget charged per specialist query; the validator and analyst cost double.
QUERY_COSTS = dict(
    query_cleaner=1,
    query_augmenter=1,
    query_balancer=1,
    query_validator=2,
    query_analyst=2,
)
# Every action name routed to the query handler.
QUERY_ACTIONS = set(QUERY_COSTS)

# Module-level dataset registry; every environment instance draws episodes from it.
_registry = DatasetRegistry()
class DataCentricEnvironment:
    """Per-session episode state and step logic for the data-centric environment.

    The agent improves a training set by querying specialist agents for
    recommendations (costs budget), applying recommendations (the result is
    re-scored by the Evaluator built on the holdout set) and optionally rolling
    back the last apply. ``reset``, ``step`` and ``state`` serialise access
    through a per-instance lock.
    """

    # Maximum number of rollback actions allowed per episode.
    _MAX_ROLLBACKS = 3
    # Number of pre-apply snapshots retained for rollback.
    _MAX_SNAPSHOTS = 3

    def __init__(self, session_id: str, episode_count: int = 0):
        """Create the environment for ``session_id``.

        Args:
            session_id: Identifier echoed in every observation and error.
            episode_count: Episodes already played; drives curriculum difficulty.
        """
        self.session_id = session_id
        self._episode_count = episode_count
        self.agents = {
            "cleaner": CleanerAgent(),
            "augmenter": AugmenterAgent(),
            "balancer": BalancerAgent(),
            "validator": ValidatorAgent(),
            "analyst": AnalystAgent(),
        }
        self.anti_exploit = AntiExploit()
        self._lock = threading.Lock()
        self._reset_state()

    def _reset_state(self):
        """Clear all per-episode state (data, scores, budget, history, trace)."""
        self.train_df: pd.DataFrame = None
        self.holdout_df: pd.DataFrame = None
        self.domain_metadata: dict = {}
        self.evaluator: Evaluator = None
        self.target_accuracy: float = None
        self.initial_row_count: int = 0
        self.baseline_accuracy: float = 0.0  # majority-class predictor on holdout
        self.starting_accuracy: float = 0.0  # accuracy before ANY agent action
        self.budget: int = cfg.MAX_BUDGET
        self.current_accuracy: float = 0.0
        self.episode_step: int = 0
        self.done: bool = False
        self.difficulty: str = "easy"
        self.pending_recs: dict = {}
        self.applied_rec_ids: set = set()
        self.last_query_result: dict = {}
        self.last_feature_importance: dict = {}
        self.anti_exploit.reset()
        self.accuracy_history: list = []
        self.reward_history: list = []
        # Rollback: stack of (train_df copy, accuracy, feature_importance)
        # captured just before each successful apply — last few states only.
        self._state_stack: list[tuple] = []
        # Reasoning trace: running log of every step.
        self._episode_trace: list[dict] = []

    # ── Public API ─────────────────────────────────────────────────────────────

    def reset(self, difficulty: str = None, seed: int = None) -> dict:
        """Start a new episode and return the initial observation.

        Args:
            difficulty: Explicit difficulty; defaults to the curriculum level
                derived from the episode count.
            seed: Forwarded to the dataset registry.
        """
        with self._lock:
            self._episode_count += 1
            self._reset_state()
            self.difficulty = difficulty or self._curriculum_difficulty()
            self.train_df, self.holdout_df, self.domain_metadata = _registry.get(
                difficulty=self.difficulty, seed=seed
            )
            self.initial_row_count = len(self.train_df)
            self.evaluator = Evaluator(self.holdout_df)
            pub_baseline = self.domain_metadata.get("published_baseline", 0.80)
            # Success threshold: 97% of the published baseline accuracy.
            self.target_accuracy = round(pub_baseline * 0.97, 4)
            self.baseline_accuracy = self.evaluator.baseline_accuracy()
            self.current_accuracy = self.evaluator.evaluate(self._clean_df(self.train_df))
            self.starting_accuracy = self.current_accuracy
            self.accuracy_history.append(self.current_accuracy)
            self._episode_trace.append({
                "step": 0,
                "type": "reset",
                "dataset": self.domain_metadata.get("display_name"),
                "accuracy": round(self.current_accuracy, 4),
                "baseline_accuracy": self.baseline_accuracy,
                "target_accuracy": self.target_accuracy,
            })
            log_event(logger, "episode_reset",
                      session_id=self.session_id,
                      dataset=self.domain_metadata.get("display_name"),
                      difficulty=self.difficulty,
                      initial_accuracy=round(self.current_accuracy, 4),
                      target_accuracy=self.target_accuracy,
                      baseline_accuracy=self.baseline_accuracy,
                      published_baseline=pub_baseline,
                      n_train=len(self.train_df),
                      n_holdout=len(self.holdout_df))
            return self._observation()

    def step(self, action: dict) -> dict:
        """Dispatch one agent action and return the step result.

        ``action["action"]`` selects the handler: ``"rollback"``, one of
        ``QUERY_ACTIONS`` or ``"apply"``. Every action except rollback first
        goes through ``anti_exploit.check``; a blocked action still costs one
        budget unit and one step, and returns a near-zero reward.
        """
        with self._lock:
            if self.done:
                return self._error("Episode done. Call /reset.")
            if self.train_df is None:
                return self._error("Not initialized. Call /reset first.")
            action_type = action.get("action", "")
            # Rollback bypasses the anti-exploit check.
            if action_type == "rollback":
                return self._handle_rollback()
            try:
                self.anti_exploit.check(
                    action=action,
                    budget_remaining=self.budget,
                    pending_recs=self.pending_recs,
                    applied_rec_ids=self.applied_rec_ids,
                )
            except ExploitDetected as e:
                log_event(logger, "exploit_detected", session_id=self.session_id,
                          rule=e.rule, detail=e.detail)
                self.episode_step += 1
                self.budget = max(0, self.budget - 1)
                self.done = self.budget <= 0
                self._episode_trace.append({
                    "step": self.episode_step,
                    "type": "exploit_blocked",
                    "rule": e.rule,
                    "detail": e.detail,
                })
                return {
                    "observation": self._observation(),
                    "reward": 0.001,
                    "done": self.done,
                    "exploit_detected": True,
                    "error": f"[{e.rule}] {e.detail}",
                    "info": {"episode_step": self.episode_step, "budget_remaining": self.budget},
                }
            if action_type in QUERY_ACTIONS:
                return self._handle_query(action_type, action)
            elif action_type == "apply":
                return self._handle_apply(action)
            else:
                return self._error(f"Unknown action '{action_type}'. Valid: {list(QUERY_ACTIONS) + ['apply', 'rollback']}")

    def state(self) -> dict:
        """Return the current observation without advancing the episode."""
        with self._lock:
            return self._observation()

    # ── Rollback ───────────────────────────────────────────────────────────────

    def _rollbacks_used(self) -> int:
        """Count the rollback entries recorded in this episode's trace."""
        return sum(1 for entry in self._episode_trace if entry["type"] == "rollback")

    def _handle_rollback(self) -> dict:
        """Undo the last successful apply. Costs 1 budget; max 3 rollbacks per episode.

        Restores the training set, accuracy and feature importance captured
        just before that apply. The undone recommendation stays in
        ``applied_rec_ids``, so it is not listed again among pending
        recommendations. The fixed rollback reward is not appended to
        ``reward_history``.
        """
        rollbacks_used = self._rollbacks_used()
        if rollbacks_used >= self._MAX_ROLLBACKS:
            return self._error(f"Maximum {self._MAX_ROLLBACKS} rollbacks per episode reached.")
        if not self._state_stack:
            return self._error("Nothing to roll back. No apply operations have been made yet.")
        prev_df, prev_accuracy, prev_importance = self._state_stack.pop()
        self.train_df = prev_df
        self.current_accuracy = prev_accuracy
        # Keep feature importance consistent with the restored dataset.
        self.last_feature_importance = prev_importance
        self.accuracy_history.append(self.current_accuracy)
        self.budget = max(0, self.budget - 1)
        self.episode_step += 1
        self.done = self.budget <= 0
        self._episode_trace.append({
            "step": self.episode_step,
            "type": "rollback",
            "accuracy_after_rollback": round(self.current_accuracy, 4),
            "note": "Last apply undone. Dataset restored to previous state.",
        })
        log_event(logger, "rollback", session_id=self.session_id,
                  accuracy_after=round(self.current_accuracy, 4))
        return {
            "observation": self._observation(),
            "reward": 0.3,  # small penalty for indecision, but not zero
            "done": self.done,
            "rollback": True,
            "accuracy_after_rollback": round(self.current_accuracy, 4),
            "info": {
                "episode_step": self.episode_step,
                "budget_remaining": self.budget,
                "rollbacks_remaining": self._MAX_ROLLBACKS - rollbacks_used - 1,
                "note": "Dataset restored to state before last apply.",
            },
        }

    # ── Query handler ──────────────────────────────────────────────────────────

    def _handle_query(self, action_type: str, action: dict) -> dict:
        """Ask a specialist agent for recommendations and register them as pending.

        Deducts ``QUERY_COSTS[action_type]`` from the budget. The dataset and
        accuracy are unchanged; the reward comes from ``compute`` with
        ``step_type="query"``.
        """
        cost = QUERY_COSTS[action_type]
        prev_stats = compute_stats(self.train_df)
        clean = self._clean_df(self.train_df)
        meta = self.domain_metadata
        if action_type == "query_cleaner":
            result = self.agents["cleaner"].query(clean, meta)
        elif action_type == "query_augmenter":
            result = self.agents["augmenter"].query(clean, action.get("target_class"), meta)
        elif action_type == "query_balancer":
            result = self.agents["balancer"].query(clean, meta)
        elif action_type == "query_validator":
            result = self.agents["validator"].query(clean, meta)
        elif action_type == "query_analyst":
            result = self.agents["analyst"].query(clean, meta)
        else:
            result = {}
        new_rec_ids = []
        for rec in result.get("recommendations", []):
            rid = rec["id"]
            self.pending_recs[rid] = {"rec": rec, "agent": result.get("agent", "unknown")}
            new_rec_ids.append(rid)
        self.last_query_result = result
        self.budget = max(0, self.budget - cost)
        self.episode_step += 1
        new_stats = compute_stats(self.train_df)
        reward, decomp = compute(
            prev_accuracy=self.current_accuracy,
            new_accuracy=self.current_accuracy,
            prev_stats=prev_stats,
            new_stats=new_stats,
            action=action,
            steps_taken=self.episode_step,
            max_steps=cfg.MAX_BUDGET,
            budget_remaining=self.budget,
            target_accuracy=self.target_accuracy,
            step_type="query",
            n_recs_returned=len(new_rec_ids),
        )
        self.reward_history.append(reward)
        self.done = self.budget <= 0
        agent_name = action_type.replace("query_", "")
        self._episode_trace.append({
            "step": self.episode_step,
            "type": "query",
            "agent": agent_name,
            "n_recs": len(new_rec_ids),
            "budget_cost": cost,
            "budget_remaining": self.budget,
            "reward": reward,
            "rec_ids": new_rec_ids,
        })
        log_event(logger, "query_step", session_id=self.session_id,
                  action=action_type, n_recs=len(new_rec_ids),
                  budget=self.budget, reward=reward)
        return {
            "observation": self._observation(),
            "reward": reward,
            "reward_decomposition": decomp,
            "done": self.done,
            "query_result": result,
            "new_recommendation_ids": new_rec_ids,
            "info": {
                "action_type": "query",
                "agent_queried": agent_name,
                "budget_cost": cost,
                "budget_remaining": self.budget,
                "n_recommendations": len(new_rec_ids),
                "episode_step": self.episode_step,
                "domain": self.domain_metadata.get("display_name"),
            },
        }

    # ── Apply handler ──────────────────────────────────────────────────────────

    def _handle_apply(self, action: dict) -> dict:
        """Apply a pending recommendation to the training set and re-evaluate.

        The owning agent's ``apply`` runs in a daemon thread bounded by
        ``cfg.STEP_TIMEOUT_SECONDS``. On an unknown ``rec_id``, timeout, agent
        error or data-integrity violation (more than 10% of rows deleted), the
        training set is left as it was before the call and an error dict is
        returned. No step is counted and no budget is deducted in those cases.
        On success the pre-apply state is pushed onto the rollback stack.

        Applying does not itself deduct budget; ``done`` becomes True when the
        target accuracy is reached or the budget is already exhausted.
        """
        rec_id = action.get("rec_id", "")
        entry = self.pending_recs.get(rec_id)
        if entry is None:
            # anti_exploit.check presumably rejects unknown ids already; this
            # guard keeps a missed case from surfacing as a bare KeyError.
            return self._error(f"Unknown recommendation id '{rec_id}'.")
        agent_name = entry["agent"]
        rec = entry["rec"]
        prev_accuracy = self.current_accuracy
        prev_stats = compute_stats(self.train_df)
        prev_rows = len(self.train_df)
        meta = self.domain_metadata
        # Snapshot BEFORE applying. When there are no "_"-prefixed columns,
        # _clean_df returns self.train_df itself, so the agent could mutate it
        # in place; an independent copy lets every failure path restore state.
        snapshot_df = self.train_df.copy()
        snapshot_importance = self.last_feature_importance

        result_holder: dict = {}
        error_holder: dict = {}

        def _run():
            try:
                clean = self._clean_df(self.train_df)
                df_out, log_msg = self.agents[agent_name].apply(clean, rec, meta)
                result_holder["df"] = df_out
                result_holder["log"] = log_msg
            except Exception as e:
                error_holder["error"] = str(e)

        t = threading.Thread(target=_run, daemon=True)
        t.start()
        t.join(timeout=cfg.STEP_TIMEOUT_SECONDS)
        if t.is_alive():
            # The worker may still be mutating the old frame; detach from it.
            self.train_df = snapshot_df
            return self._error("Apply timed out.")
        if "error" in error_holder:
            self.train_df = snapshot_df
            return self._error(f"Apply error: {error_holder['error']}")
        new_df = result_holder["df"]
        tool_log = result_holder["log"]
        # Data integrity constraint: cannot delete more than 10% of rows.
        new_rows = len(new_df)
        deletion_pct = max(0, (prev_rows - new_rows) / max(prev_rows, 1))
        if deletion_pct > 0.10:
            self.train_df = snapshot_df
            return self._error(
                f"Data integrity violation: would delete {deletion_pct:.1%} of training rows "
                f"(limit: 10%). Use targeted imputation instead of drop_rows."
            )
        # Success: record the pre-apply state for rollback, keeping only the newest few.
        self._state_stack.append((snapshot_df, prev_accuracy, snapshot_importance))
        del self._state_stack[:-self._MAX_SNAPSHOTS]
        self.train_df = new_df
        self.applied_rec_ids.add(rec_id)
        self.episode_step += 1
        # Full evaluation with feature importance + regression explanation.
        eval_result = self.evaluator.evaluate_with_details(
            self._clean_df(self.train_df), prev_accuracy
        )
        self.current_accuracy = eval_result["accuracy"]
        self.last_feature_importance = eval_result.get("feature_importance", {})
        regression_explanation = eval_result.get("regression_explanation")
        self.accuracy_history.append(self.current_accuracy)
        new_stats = compute_stats(self.train_df)
        reward, decomp = compute(
            prev_accuracy=prev_accuracy,
            new_accuracy=self.current_accuracy,
            prev_stats=prev_stats,
            new_stats=new_stats,
            action=action,
            steps_taken=self.episode_step,
            max_steps=cfg.MAX_BUDGET,
            budget_remaining=self.budget,
            target_accuracy=self.target_accuracy,
            step_type="apply",
        )
        self.reward_history.append(reward)
        self.done = (self.current_accuracy >= self.target_accuracy) or (self.budget <= 0)
        acc_delta = round(self.current_accuracy - prev_accuracy, 4)
        self._episode_trace.append({
            "step": self.episode_step,
            "type": "apply",
            "agent": agent_name,
            "rec_type": rec.get("type", "?"),
            "rec_id": rec_id,
            "accuracy_before": round(prev_accuracy, 4),
            "accuracy_after": round(self.current_accuracy, 4),
            "accuracy_delta": acc_delta,
            "effect": "improved" if acc_delta > 0.001 else ("hurt" if acc_delta < -0.001 else "neutral"),
            "reward": reward,
            "rows_before": prev_rows,
            "rows_after": new_rows,
        })
        log_event(logger, "apply_step", session_id=self.session_id,
                  rec_id=rec_id, agent=agent_name,
                  prev_acc=round(prev_accuracy, 4),
                  new_acc=round(self.current_accuracy, 4),
                  target=self.target_accuracy,
                  reward=reward,
                  success=self.current_accuracy >= self.target_accuracy)
        response = {
            "observation": self._observation(),
            "reward": reward,
            "reward_decomposition": decomp,
            "done": self.done,
            "tool_log": tool_log,
            "feature_importance": self.last_feature_importance,
            "info": {
                "action_type": "apply",
                "rec_id": rec_id,
                "agent": agent_name,
                "rec_type": rec.get("type", "?"),
                "prev_accuracy": round(prev_accuracy, 4),
                "new_accuracy": round(self.current_accuracy, 4),
                "accuracy_delta": acc_delta,
                "target_accuracy": self.target_accuracy,
                "published_baseline": self.domain_metadata.get("published_baseline"),
                "improvement_over_start": round(self.current_accuracy - self.starting_accuracy, 4),
                "improvement_over_majority_baseline": round(self.current_accuracy - self.baseline_accuracy, 4),
                "budget_remaining": self.budget,
                "episode_step": self.episode_step,
                "success": self.current_accuracy >= self.target_accuracy,
                "rollbacks_available": max(0, self._MAX_ROLLBACKS - self._rollbacks_used()),
                "data_integrity": {
                    "rows_before": prev_rows,
                    "rows_after": new_rows,
                    "deletion_pct": round(deletion_pct, 4),
                },
            },
        }
        # Only include the regression explanation when the evaluator produced one.
        if regression_explanation:
            response["regression_explanation"] = regression_explanation
        return response

    # ── Helpers ────────────────────────────────────────────────────────────────

    def _clean_df(self, df):
        """Return ``df`` without "_"-prefixed columns (the same object if none exist)."""
        drop_cols = [c for c in df.columns if c.startswith("_")]
        return df.drop(columns=drop_cols) if drop_cols else df

    def _observation(self) -> dict:
        """Build the observation dict describing the current episode state.

        Safe to call before ``reset``: with no dataset or target yet,
        ``accuracy_gap`` is ``None`` and the dataset stats are zeroed.
        """
        stats = compute_stats(self.train_df) if self.train_df is not None else {}
        pending_summary = {
            rid: {
                "agent": entry["agent"],
                "type": entry["rec"].get("type", "?"),
                "priority": entry["rec"].get("priority", "?"),
                "reason": entry["rec"].get("reason", ""),
                "domain_informed": entry["rec"].get("domain_informed", False),
            }
            for rid, entry in self.pending_recs.items()
            if rid not in self.applied_rec_ids
        }
        meta = self.domain_metadata
        # Compact trace — last 5 steps for context without overwhelming the prompt.
        recent_trace = self._episode_trace[-5:] if self._episode_trace else []
        # target_accuracy is None until reset() runs; avoid None - float.
        if self.target_accuracy is None:
            accuracy_gap = None
        else:
            accuracy_gap = round(max(0, self.target_accuracy - self.current_accuracy), 4)
        return {
            "session_id": self.session_id,
            # What the agent is working on
            "dataset": {
                "name": meta.get("display_name", "Unknown"),
                "domain": meta.get("domain", "generic"),
                "description": meta.get("description", ""),
                "known_issues": meta.get("known_issues", []),
                "published_baseline": meta.get("published_baseline"),
            },
            # Current state
            "current_accuracy": round(self.current_accuracy, 4),
            "target_accuracy": self.target_accuracy,
            "accuracy_gap": accuracy_gap,
            "budget_remaining": self.budget,
            "difficulty": self.difficulty,
            # Comparisons — what does this number actually mean?
            "benchmarks": {
                "majority_class_baseline": self.baseline_accuracy,
                "starting_accuracy": round(self.starting_accuracy, 4),
                "improvement_over_start": round(self.current_accuracy - self.starting_accuracy, 4),
                "improvement_over_baseline": round(self.current_accuracy - self.baseline_accuracy, 4),
                "published_baseline": meta.get("published_baseline"),
            },
            "dataset_stats": {
                "n_train_rows": len(self.train_df) if self.train_df is not None else 0,
                "n_holdout_rows": len(self.holdout_df) if self.holdout_df is not None else 0,
                "n_cols": len(self.train_df.columns) if self.train_df is not None else 0,
                "missing_pct": round(stats.get("missing_pct", 0), 4),
                "balance_ratio": round(stats.get("balance_ratio", 0), 4),
            },
            # Feature importance from the last evaluation
            "feature_importance": self.last_feature_importance,
            # Episodic memory — what has the agent tried so far?
            "episode_trace": recent_trace,
            "pending_recommendations": pending_summary,
            "last_query_result": self.last_query_result,
            "available_actions": (
                "query_cleaner | query_augmenter | query_balancer | "
                "query_validator (cost 2) | query_analyst (cost 2) | "
                "apply {rec_id} | rollback (undo last apply, max 3/episode)"
            ),
        }

    def _error(self, msg: str) -> dict:
        """Wrap ``msg`` in the standard error response dict."""
        return {"error": msg, "session_id": self.session_id}

    def _curriculum_difficulty(self) -> str:
        """Map the episode count to "easy", "medium" or "hard" via cfg thresholds."""
        if self._episode_count < cfg.CURRICULUM_MEDIUM_AFTER:
            return "easy"
        elif self._episode_count < cfg.CURRICULUM_HARD_AFTER:
            return "medium"
        return "hard"

    def episode_summary(self) -> dict:
        """Return accuracy/reward histories, mean reward and the full step trace."""
        return {
            "session_id": self.session_id,
            "episode_count": self._episode_count,
            "accuracy_history": [round(a, 4) for a in self.accuracy_history],
            "reward_history": [round(r, 4) for r in self.reward_history],
            "mean_reward": round(sum(self.reward_history) / max(len(self.reward_history), 1), 4),
            "full_trace": self._episode_trace,
        }