Spaces:
Sleeping
Sleeping
| """ | |
| server/main.py β Production FastAPI application (v0.5). | |
| Endpoints: | |
| POST /reset β Start new episode (returns session_id + full observation) | |
| POST /step β Take action (query | apply | rollback) | |
| GET /state/{session_id} β Current observation | |
| GET /trajectory/{session_id} β Full episode trace with all rewards and effects | |
| GET /health β Health check + version | |
| GET /metrics β Session counts + config | |
| GET /docs β Swagger UI (auto-generated) | |
| """ | |
from typing import Final, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, field_validator

from server.environment import DataCentricEnvironment, _registry
from server.session_manager import session_manager
from server.config import cfg
from server.logger import get_logger, log_event
logger = get_logger("api")

# The ASGI application. FastAPI serves interactive API docs at /docs
# (Swagger UI) and /redoc (ReDoc); the description below appears on both.
app = FastAPI(
    title="DataCentric-Env",
    version=cfg.ENV_VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
    description=(
        "RL environment: an LLM acts as a data engineer. Given a real, "
        "messy tabular dataset (UCI Adult, Pima Diabetes, German Credit, etc.), "
        "the agent queries specialist agents for recommendations and applies "
        "them to fix the data until the frozen classifier hits the accuracy "
        "target. All scores compared against published academic baselines.\n\n"
        "**New in v0.5:** Rollback action, episode reasoning trace, "
        "feature importance, regression explanations, benchmark comparisons."
    ),
)

# Fully permissive CORS: any origin may call any method with any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
@app.on_event("startup")
async def startup_event():
    """Warm up the dataset registry once, when the application starts.

    The intent is that the first POST /reset does not pay the dataset-loading
    cost. Without the ``on_event("startup")`` registration above, FastAPI would
    never call this function.
    """
    # NOTE(review): an earlier docstring said the datasets are loaded "in a
    # background thread". That depends on _registry.warmup()'s implementation,
    # which is not visible here. If warmup() is synchronous and slow, it blocks
    # the event loop until it returns — confirm.
    _registry.warmup()
# Every action name accepted by POST /step (checked by ActionRequest).
# A frozenset keeps this constant immutable, so request validation cannot be
# altered accidentally at runtime; membership tests and sorted() behave as before.
VALID_ACTIONS: Final = frozenset({
    "query_cleaner", "query_augmenter", "query_balancer",
    "query_validator", "query_analyst", "apply", "rollback",
})
# ── Request models ──────────────────────────────────────────────────────────────
class ResetRequest(BaseModel):
    """Optional request body for POST /reset; both fields may be omitted."""

    # "easy", "medium", "hard", or None; forwarded to env.reset(difficulty=...).
    difficulty: Optional[str] = None
    # Forwarded to env.reset(seed=...); None when not supplied.
    seed: Optional[int] = None

    # The decorators are required: without @field_validator pydantic never runs
    # this method, and any difficulty string would be accepted.
    @field_validator("difficulty")
    @classmethod
    def validate_difficulty(cls, v):
        """Reject difficulties other than 'easy', 'medium', or 'hard' (None passes)."""
        if v is not None and v not in ("easy", "medium", "hard"):
            raise ValueError("difficulty must be 'easy', 'medium', or 'hard'")
        return v
class ActionRequest(BaseModel):
    """Request body for POST /step."""

    # Session handle returned by POST /reset.
    session_id: str
    # Must be one of VALID_ACTIONS (enforced by validate_action).
    action: str
    # Recommendation ID; the "apply" action uses it to pick a recommendation.
    rec_id: Optional[str] = None
    # 0, 1, or None (enforced by validate_target_class); forwarded when set.
    target_class: Optional[int] = None

    # Without @field_validator pydantic never invokes these methods, so unknown
    # actions and out-of-range target classes would reach the environment.
    @field_validator("action")
    @classmethod
    def validate_action(cls, v):
        """Reject action names that are not in VALID_ACTIONS."""
        if v not in VALID_ACTIONS:
            raise ValueError(f"Invalid action '{v}'. Valid: {sorted(VALID_ACTIONS)}")
        return v

    @field_validator("target_class")
    @classmethod
    def validate_target_class(cls, v):
        """Reject target classes other than 0 or 1 (None passes)."""
        if v is not None and v not in (0, 1):
            raise ValueError("target_class must be 0 or 1")
        return v
# ── Endpoints ───────────────────────────────────────────────────────────────────
@app.post("/reset")
def reset(body: Optional[ResetRequest] = None):
    """
    Creates a new episode on a real dataset. Returns `session_id` + full observation.

    The observation includes:
    - Dataset name, domain, and documented known quality issues
    - Current accuracy vs target vs published benchmark vs majority-class baseline
    - Dataset statistics (missing %, class balance ratio)
    - Feature importance (empty until first apply)
    - Episode trace (empty at start)
    - All pending recommendations (empty until first query)
    """
    # The body is optional: a request without one uses the environment's defaults.
    difficulty = body.difficulty if body else None
    seed = body.seed if body else None

    # The real session id is only known after registration, so the environment
    # is created with a placeholder id and patched right after create_session().
    env = DataCentricEnvironment(session_id="pending", episode_count=0)
    session_id = session_manager.create_session(env)
    env.session_id = session_id

    # NOTE(review): if env.reset() raises, the session created above stays
    # registered with an un-reset environment — confirm whether that matters.
    obs = env.reset(difficulty=difficulty, seed=seed)
    log_event(logger, "api_reset", session_id=session_id, difficulty=obs.get("difficulty"))
    return obs
@app.post("/step")
def step(body: ActionRequest):
    """
    Take one action in the environment.

    **Query actions** (cost 1-2 budget, return recommendations):
    - `query_cleaner` (cost 1) — missing value + zero-as-missing analysis, domain-aware
    - `query_augmenter` (cost 1) — minority class synthesis via SMOTE-like interpolation
    - `query_balancer` (cost 1) — class resampling with explicit tradeoff explanation
    - `query_validator` (cost 2) — duplicate + outlier detection (conservative IQR for medical)
    - `query_analyst` (cost 2) — holistic diagnosis + prioritized plan + published baseline

    **Apply action** (modifies dataset, no budget cost):
    - `apply` with `rec_id` — apply a recommendation by its ID from any previous query
    - Response includes: feature importance (LogReg coefs), regression explanation if accuracy drops

    **Rollback action** (cost 1 budget, max 3/episode):
    - `rollback` — undo the last apply and restore the previous dataset state
    """
    env = session_manager.get_env(body.session_id)
    if env is None:
        raise HTTPException(
            status_code=404,
            detail=f"Session '{body.session_id}' not found or expired. Call /reset first."
        )

    # Forward only the optional fields the client actually supplied.
    # An empty-string rec_id is treated the same as an absent one.
    action_dict = {"action": body.action}
    if body.rec_id:
        action_dict["rec_id"] = body.rec_id
    if body.target_class is not None:
        action_dict["target_class"] = body.target_class

    result = env.step(action_dict)

    # Log environment-reported errors, except results whose text mentions
    # "exploit". NOTE(review): presumably those are reported through another
    # channel — confirm the intent of this filter.
    if "error" in result and "exploit" not in str(result):
        log_event(logger, "step_error", session_id=body.session_id, error=result["error"])

    # Counted for every step that reached the environment, including errors.
    session_manager.increment_steps(body.session_id)
    return result
@app.get("/state/{session_id}")
def state(session_id: str):
    """Current full observation including episode trace, benchmarks, and feature importance."""
    # 404 when the session manager has no environment for this id.
    env = session_manager.get_env(session_id)
    if env is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.")
    return env.state()
@app.get("/trajectory/{session_id}")
def trajectory(session_id: str):
    """
    Complete episode trace — every step with reward, accuracy delta, and effect label.

    Useful for:
    - Offline reward model training
    - Debugging agent decisions
    - Comparing strategy effectiveness across episodes
    """
    # 404 when the session manager has no environment for this id.
    env = session_manager.get_env(session_id)
    if env is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.")
    return env.episode_summary()
@app.get("/health")
def health():
    """Health check: static "ok" status, API version, active session count, and bundled dataset names."""
    return {
        # Always "ok": this handler performs no dependency checks.
        "status": "ok",
        "version": cfg.ENV_VERSION,
        "active_sessions": session_manager.metrics()["active_sessions"],
        # Hard-coded list. NOTE(review): keep in sync with the dataset registry.
        "real_datasets": [
            "UCI Adult Census Income",
            "Pima Indians Diabetes",
            "Wisconsin Breast Cancer Diagnostic",
            "German Credit Risk",
            "Cleveland Heart Disease",
        ],
    }
@app.get("/metrics")
def metrics():
    """Report the API version, key configuration limits, and session-manager metrics."""
    return {
        "version": cfg.ENV_VERSION,
        "config": {
            "max_budget": cfg.MAX_BUDGET,
            "max_concurrent_sessions": cfg.MAX_CONCURRENT_SESSIONS,
            "session_ttl_seconds": cfg.SESSION_TTL_SECONDS,
            "max_same_action_streak": cfg.MAX_SAME_ACTION_STREAK,
            # NOTE(review): the next two values are literals rather than cfg
            # reads, so they can drift from the limits the environment enforces.
            "max_row_deletion_pct": 0.10,
            "max_rollbacks_per_episode": 3,
        },
        "sessions": session_manager.metrics(),
    }