File size: 8,050 Bytes
e8db72f
def0858
e8db72f
 
def0858
 
 
 
 
 
 
e8db72f
 
def0858
e8db72f
 
def0858
d8b5f3c
e8db72f
 
 
 
 
b7e4429
343deae
 
e8db72f
343deae
def0858
 
 
 
 
 
 
343deae
e8db72f
 
343deae
b7e4429
def0858
 
 
 
 
 
 
d8b5f3c
 
 
 
 
 
 
e8db72f
 
def0858
e8db72f
 
 
def0858
b7e4429
 
 
def0858
b7e4429
e8db72f
 
 
 
 
 
 
b7e4429
343deae
e8db72f
 
 
 
 
 
 
 
 
d8b5f3c
e8db72f
 
 
 
 
 
 
 
343deae
 
def0858
e8db72f
 
b7e4429
def0858
d8b5f3c
def0858
 
d8b5f3c
 
 
 
 
 
def0858
b7e4429
def0858
b7e4429
e8db72f
 
def0858
e8db72f
def0858
e8db72f
 
b7e4429
e8db72f
def0858
343deae
def0858
 
 
 
d8b5f3c
 
 
 
 
def0858
d8b5f3c
 
 
def0858
 
d8b5f3c
def0858
e8db72f
 
 
 
 
 
 
 
 
 
 
 
b7e4429
e8db72f
def0858
 
 
 
e8db72f
 
b7e4429
e8db72f
 
 
d8b5f3c
e8db72f
 
 
b7e4429
 
 
def0858
 
 
d8b5f3c
def0858
 
 
d8b5f3c
 
def0858
 
 
 
 
 
 
e8db72f
b7e4429
e8db72f
 
 
 
def0858
d8b5f3c
def0858
d8b5f3c
def0858
 
 
e8db72f
 
 
def0858
e8db72f
 
 
 
 
 
 
 
def0858
 
e8db72f
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
"""
server/main.py — Production FastAPI application (v0.5).

Endpoints:
  POST /reset                   — Start new episode (returns session_id + full observation)
  POST /step                    — Take action (query | apply | rollback)
  GET  /state/{session_id}      — Current observation
  GET  /trajectory/{session_id} — Full episode trace with all rewards and effects
  GET  /health                  — Health check + version
  GET  /metrics                 — Session counts + config
  GET  /docs                    — Swagger UI (auto-generated)
"""
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, field_validator
from typing import Optional

from server.environment import DataCentricEnvironment, _registry
from server.session_manager import session_manager
from server.config import cfg
from server.logger import get_logger, log_event

# Module-level logger for API events; passed to log_event() by the endpoints.
logger = get_logger("api")

# The FastAPI application object. `description` is rendered in the auto-generated
# Swagger UI (docs_url="/docs") and ReDoc (redoc_url="/redoc") pages.
app = FastAPI(
    title="DataCentric-Env",
    version=cfg.ENV_VERSION,
    description=(
        "RL environment: an LLM acts as a data engineer. "
        "Given a real, messy tabular dataset (UCI Adult, Pima Diabetes, German Credit, etc.), "
        "the agent queries specialist agents for recommendations and applies them to fix the data "
        "until the frozen classifier hits the accuracy target. "
        "All scores compared against published academic baselines.\n\n"
        "**New in v0.5:** Rollback action, episode reasoning trace, feature importance, "
        "regression explanations, benchmark comparisons."
    ),
    docs_url="/docs",
    redoc_url="/redoc",
)

# Fully permissive CORS: every origin, method and header is allowed.
# allow_credentials is not passed, so the CORSMiddleware default applies.
# NOTE(review): narrow allow_origins if this service is exposed beyond trusted clients.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.on_event("startup")
async def startup_event():
    """Warm the dataset registry once when the application starts.

    Calls ``_registry.warmup()`` so datasets are (intended to be) pre-loaded before
    the first /reset request. Whether warmup happens in a background thread or
    blocks startup is decided inside ``_registry.warmup`` — not visible here;
    confirm there.
    """
    # NOTE(review): `on_event` is deprecated in recent FastAPI releases in favour of
    # a `lifespan` context manager passed to FastAPI(...); consider migrating.
    _registry.warmup()


# Every action verb accepted by POST /step: one query per specialist agent,
# plus the two verbs that change the dataset state ("apply" and "rollback").
VALID_ACTIONS = {
    *(f"query_{agent}" for agent in ("cleaner", "augmenter", "balancer", "validator", "analyst")),
    "apply",
    "rollback",
}


# ── Request models ──────────────────────────────────────────────────────────────

class ResetRequest(BaseModel):
    # Optional body for POST /reset. Both fields may be omitted; None is passed
    # straight through to the environment's reset().
    difficulty: Optional[str] = None
    seed: Optional[int] = None

    @field_validator("difficulty")
    @classmethod
    def validate_difficulty(cls, value):
        # Accept "not supplied" or one of the three known tiers; reject anything else.
        if value is None or value in ("easy", "medium", "hard"):
            return value
        raise ValueError("difficulty must be 'easy', 'medium', or 'hard'")


class ActionRequest(BaseModel):
    # Body of POST /step. `rec_id` is used by the "apply" action; `target_class`
    # is optional and, when present, restricted to the labels 0 and 1.
    session_id: str
    action: str
    rec_id: Optional[str] = None
    target_class: Optional[int] = None

    @field_validator("action")
    @classmethod
    def validate_action(cls, value):
        # Reject unknown verbs here so the environment only ever sees valid actions.
        if value in VALID_ACTIONS:
            return value
        raise ValueError(f"Invalid action '{value}'. Valid: {sorted(VALID_ACTIONS)}")

    @field_validator("target_class")
    @classmethod
    def validate_target_class(cls, value):
        # None means the client did not supply a class.
        if value is None or value in (0, 1):
            return value
        raise ValueError("target_class must be 0 or 1")


# ── Endpoints ───────────────────────────────────────────────────────────────────

@app.post("/reset", summary="Start a new episode")
def reset(body: ResetRequest = None):
    """
    Start a fresh episode on a real dataset. Returns `session_id` + full observation.

    The request body is optional; missing fields are passed to the environment as None.

    The observation contains:
    - Dataset name, domain, and its documented known quality issues
    - Current accuracy alongside the target, published benchmark, and majority-class baseline
    - Dataset statistics (missing %, class balance ratio)
    - Feature importance (empty until the first apply)
    - Episode trace (empty at start)
    - Pending recommendations (empty until the first query)
    """
    difficulty, seed = (body.difficulty, body.seed) if body else (None, None)

    env = DataCentricEnvironment(session_id="pending", episode_count=0)
    # The session manager hands out the real id on registration; write it back
    # onto the environment before the episode starts.
    sid = session_manager.create_session(env)
    env.session_id = sid

    observation = env.reset(difficulty=difficulty, seed=seed)
    log_event(logger, "api_reset", session_id=sid, difficulty=observation.get("difficulty"))
    return observation


@app.post("/step", summary="Take an action")
def step(body: ActionRequest):
    """
    Take one action in the environment.

    **Query actions** (cost 1-2 budget, return recommendations):
    - `query_cleaner` (cost 1) — missing value + zero-as-missing analysis, domain-aware
    - `query_augmenter` (cost 1) — minority class synthesis via SMOTE-like interpolation
    - `query_balancer` (cost 1) — class resampling with explicit tradeoff explanation
    - `query_validator` (cost 2) — duplicate + outlier detection (conservative IQR for medical)
    - `query_analyst` (cost 2) — holistic diagnosis + prioritized plan + published baseline

    **Apply action** (modifies dataset, no budget cost):
    - `apply` with `rec_id` — apply a recommendation by its ID from any previous query
    - Response includes: feature importance (LogReg coefs), regression explanation if accuracy drops

    **Rollback action** (cost 1 budget, max 3/episode):
    - `rollback` — undo the last apply and restore the previous dataset state
    """
    env = session_manager.get_env(body.session_id)
    if env is None:
        raise HTTPException(
            status_code=404,
            detail=f"Session '{body.session_id}' not found or expired. Call /reset first."
        )

    # Build the environment's action dict, forwarding optional fields only when
    # supplied (an empty rec_id string is treated as "not supplied").
    action = {"action": body.action}
    if body.rec_id:
        action["rec_id"] = body.rec_id
    if body.target_class is not None:
        action["target_class"] = body.target_class

    outcome = env.step(action)

    # Log environment-reported errors unless the word "exploit" appears anywhere in
    # the stringified result.
    # NOTE(review): presumably exploit detections are expected penalties rather than
    # failures — confirm. str(outcome) serialises the whole result for this check.
    failed = "error" in outcome
    if failed and "exploit" not in str(outcome):
        log_event(logger, "step_error", session_id=body.session_id, error=outcome["error"])

    session_manager.increment_steps(body.session_id)
    return outcome


@app.get("/state/{session_id}", summary="Get current observation")
def state(session_id: str):
    """Return the session's current full observation: episode trace, benchmarks, and feature importance."""
    env = session_manager.get_env(session_id)
    if env is not None:
        return env.state()
    # Unknown or expired session id.
    raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.")


@app.get("/trajectory/{session_id}", summary="Full episode trajectory")
def trajectory(session_id: str):
    """
    Complete episode trace — every step with reward, accuracy delta, and effect label.

    Typical uses:
    - Offline reward model training
    - Debugging agent decisions
    - Comparing strategy effectiveness across episodes
    """
    env = session_manager.get_env(session_id)
    if env is not None:
        return env.episode_summary()
    # Unknown or expired session id.
    raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.")


@app.get("/health", summary="Health check")
def health():
    """Report liveness, the environment version, the active session count, and the bundled datasets."""
    active = session_manager.metrics()["active_sessions"]
    datasets = [
        "UCI Adult Census Income",
        "Pima Indians Diabetes",
        "Wisconsin Breast Cancer Diagnostic",
        "German Credit Risk",
        "Cleveland Heart Disease",
    ]
    return {
        "status": "ok",
        "version": cfg.ENV_VERSION,
        "active_sessions": active,
        "real_datasets": datasets,
    }


@app.get("/metrics", summary="Server metrics")
def metrics():
    """Report the environment version, its effective limits, and session-manager statistics."""
    limits = {
        "max_budget": cfg.MAX_BUDGET,
        "max_concurrent_sessions": cfg.MAX_CONCURRENT_SESSIONS,
        "session_ttl_seconds": cfg.SESSION_TTL_SECONDS,
        "max_same_action_streak": cfg.MAX_SAME_ACTION_STREAK,
        # NOTE(review): the next two are hard-coded here rather than read from cfg;
        # keep them in sync with the limits the environment actually enforces.
        "max_row_deletion_pct": 0.10,
        "max_rollbacks_per_episode": 3,
    }
    return {
        "version": cfg.ENV_VERSION,
        "config": limits,
        "sessions": session_manager.metrics(),
    }