Aswini-Kumar commited on
Commit
def0858
Β·
verified Β·
1 Parent(s): f1de9f4

Upload server/main.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/main.py +92 -20
server/main.py CHANGED
@@ -1,16 +1,20 @@
1
  """
2
- server/main.py β€” Production FastAPI application (v0.3).
3
 
4
  Endpoints:
5
- POST /reset β€” Start new episode, returns session_id
6
- POST /step β€” Take action (requires session_id)
7
- GET /state/{session_id} β€” Get current observation
8
- GET /health β€” Health check
9
- GET /metrics β€” Session + episode metrics
 
 
10
  """
11
  from fastapi import FastAPI, HTTPException
 
12
  from pydantic import BaseModel, field_validator
13
  from typing import Optional
 
14
  from server.environment import DataCentricEnvironment
15
  from server.session_manager import session_manager
16
  from server.config import cfg
@@ -22,24 +26,36 @@ app = FastAPI(
22
  title="DataCentric-Env",
23
  version=cfg.ENV_VERSION,
24
  description=(
25
- "RL environment: LLM acts as data engineer. "
26
- "Query specialist agents for recommendations, apply them to fix a dataset, "
27
- "hit the accuracy target."
 
 
 
 
28
  ),
29
  docs_url="/docs",
30
  redoc_url="/redoc",
31
  )
32
 
 
 
 
 
 
 
 
33
  VALID_ACTIONS = {
34
  "query_cleaner", "query_augmenter", "query_balancer",
35
- "query_validator", "query_analyst", "apply",
36
  }
37
 
38
 
39
- # ── Request models ─────────────────────────────────────────────────────────────
40
 
41
  class ResetRequest(BaseModel):
42
  difficulty: Optional[str] = None
 
43
 
44
  @field_validator("difficulty")
45
  @classmethod
@@ -60,7 +76,7 @@ class ActionRequest(BaseModel):
60
  def validate_action(cls, v):
61
  if v not in VALID_ACTIONS:
62
  raise ValueError(
63
- f"Invalid action '{v}'. Must be one of: {sorted(VALID_ACTIONS)}"
64
  )
65
  return v
66
 
@@ -72,25 +88,50 @@ class ActionRequest(BaseModel):
72
  return v
73
 
74
 
75
- # ── Endpoints ──────────────────────────────────────────────────────────────────
76
 
77
  @app.post("/reset", summary="Start a new episode")
78
  def reset(body: ResetRequest = None):
 
 
 
 
 
 
 
 
 
79
  difficulty = body.difficulty if body else None
 
80
 
81
- # Create new session + environment
82
- session_id = "pending" # placeholder before create_session
83
  env = DataCentricEnvironment(session_id="pending", episode_count=0)
84
  session_id = session_manager.create_session(env)
85
- env.session_id = session_id # patch in the real ID
86
 
87
- obs = env.reset(difficulty=difficulty)
88
  log_event(logger, "api_reset", session_id=session_id, difficulty=obs.get("difficulty"))
89
  return obs
90
 
91
 
92
- @app.post("/step", summary="Take an action in the environment")
93
  def step(body: ActionRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  env = session_manager.get_env(body.session_id)
95
  if env is None:
96
  raise HTTPException(
@@ -105,28 +146,58 @@ def step(body: ActionRequest):
105
  action_dict["target_class"] = body.target_class
106
 
107
  result = env.step(action_dict)
 
 
 
 
 
108
  session_manager.increment_steps(body.session_id)
109
  return result
110
 
111
 
112
  @app.get("/state/{session_id}", summary="Get current observation")
113
  def state(session_id: str):
 
114
  env = session_manager.get_env(session_id)
115
  if env is None:
116
  raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.")
117
  return env.state()
118
 
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  @app.get("/health", summary="Health check")
121
  def health():
122
  return {
123
  "status": "ok",
124
  "version": cfg.ENV_VERSION,
125
  "active_sessions": session_manager.metrics()["active_sessions"],
 
 
 
 
 
 
 
126
  }
127
 
128
 
129
- @app.get("/metrics", summary="Episode and session metrics")
130
  def metrics():
131
  return {
132
  "version": cfg.ENV_VERSION,
@@ -134,8 +205,9 @@ def metrics():
134
  "max_budget": cfg.MAX_BUDGET,
135
  "max_concurrent_sessions": cfg.MAX_CONCURRENT_SESSIONS,
136
  "session_ttl_seconds": cfg.SESSION_TTL_SECONDS,
137
- "golden_row_count": cfg.GOLDEN_ROW_COUNT,
138
  "max_same_action_streak": cfg.MAX_SAME_ACTION_STREAK,
 
 
139
  },
140
  "sessions": session_manager.metrics(),
141
  }
 
1
  """
2
+ server/main.py β€” Production FastAPI application (v0.5).
3
 
4
  Endpoints:
5
+ POST /reset β€” Start new episode (returns session_id + full observation)
6
+ POST /step β€” Take action (query | apply | rollback)
7
+ GET /state/{session_id} β€” Current observation
8
+ GET /trajectory/{session_id} β€” Full episode trace with all rewards and effects
9
+ GET /health β€” Health check + version
10
+ GET /metrics β€” Session counts + config
11
+ GET /docs β€” Swagger UI (auto-generated)
12
  """
13
  from fastapi import FastAPI, HTTPException
14
+ from fastapi.middleware.cors import CORSMiddleware
15
  from pydantic import BaseModel, field_validator
16
  from typing import Optional
17
+
18
  from server.environment import DataCentricEnvironment
19
  from server.session_manager import session_manager
20
  from server.config import cfg
 
26
  title="DataCentric-Env",
27
  version=cfg.ENV_VERSION,
28
  description=(
29
+ "RL environment: an LLM acts as a data engineer. "
30
+ "Given a real, messy tabular dataset (UCI Adult, Pima Diabetes, German Credit, etc.), "
31
+ "the agent queries specialist agents for recommendations and applies them to fix the data "
32
+ "until the frozen classifier hits the accuracy target. "
33
+ "All scores compared against published academic baselines.\n\n"
34
+ "**New in v0.5:** Rollback action, episode reasoning trace, feature importance, "
35
+ "regression explanations, benchmark comparisons."
36
  ),
37
  docs_url="/docs",
38
  redoc_url="/redoc",
39
  )
40
 
41
+ app.add_middleware(
42
+ CORSMiddleware,
43
+ allow_origins=["*"],
44
+ allow_methods=["*"],
45
+ allow_headers=["*"],
46
+ )
47
+
48
  VALID_ACTIONS = {
49
  "query_cleaner", "query_augmenter", "query_balancer",
50
+ "query_validator", "query_analyst", "apply", "rollback",
51
  }
52
 
53
 
54
+ # ── Request models ──────────────────────────────────────────────────────────────
55
 
56
  class ResetRequest(BaseModel):
57
  difficulty: Optional[str] = None
58
+ seed: Optional[int] = None
59
 
60
  @field_validator("difficulty")
61
  @classmethod
 
76
  def validate_action(cls, v):
77
  if v not in VALID_ACTIONS:
78
  raise ValueError(
79
+ f"Invalid action '{v}'. Valid: {sorted(VALID_ACTIONS)}"
80
  )
81
  return v
82
 
 
88
  return v
89
 
90
 
91
+ # ── Endpoints ───────────────────────────────────────────────────────────────────
92
 
93
  @app.post("/reset", summary="Start a new episode")
94
  def reset(body: ResetRequest = None):
95
+ """
96
+ Creates a new episode. Returns a `session_id` + full observation.
97
+
98
+ The observation includes:
99
+ - Real dataset name, domain, and known quality issues
100
+ - Current accuracy vs target vs published baseline vs majority-class baseline
101
+ - Dataset statistics (missing %, balance ratio)
102
+ - Available actions
103
+ """
104
  difficulty = body.difficulty if body else None
105
+ seed = body.seed if body else None
106
 
 
 
107
  env = DataCentricEnvironment(session_id="pending", episode_count=0)
108
  session_id = session_manager.create_session(env)
109
+ env.session_id = session_id
110
 
111
+ obs = env.reset(difficulty=difficulty, seed=seed)
112
  log_event(logger, "api_reset", session_id=session_id, difficulty=obs.get("difficulty"))
113
  return obs
114
 
115
 
116
+ @app.post("/step", summary="Take an action")
117
  def step(body: ActionRequest):
118
+ """
119
+ Take one action in the environment.
120
+
121
+ **Query actions** (cost 1-2 budget, return recommendations):
122
+ - `query_cleaner` β€” missing value analysis, domain-aware (knows zeros=missing in medical)
123
+ - `query_augmenter` β€” minority class synthesis via SMOTE-like interpolation
124
+ - `query_balancer` β€” class resampling (oversample or undersample, explains tradeoff)
125
+ - `query_validator` (cost 2) β€” duplicate + outlier detection with domain-appropriate thresholds
126
+ - `query_analyst` (cost 2) β€” holistic diagnosis + prioritized action plan + published baseline reference
127
+
128
+ **Apply action** (cost 0 budget, modifies dataset):
129
+ - `apply` with `rec_id` β€” apply a specific recommendation
130
+ - Returns: feature importance, regression explanation (if accuracy dropped), benchmark comparison
131
+
132
+ **Rollback action** (cost 1 budget, max 3/episode):
133
+ - `rollback` β€” undo last apply, restore previous dataset state
134
+ """
135
  env = session_manager.get_env(body.session_id)
136
  if env is None:
137
  raise HTTPException(
 
146
  action_dict["target_class"] = body.target_class
147
 
148
  result = env.step(action_dict)
149
+
150
+ if "error" in result and "exploit" not in str(result):
151
+ # Log non-exploit errors as warnings (not 500s β€” always return JSON)
152
+ log_event(logger, "step_error", session_id=body.session_id, error=result["error"])
153
+
154
  session_manager.increment_steps(body.session_id)
155
  return result
156
 
157
 
158
  @app.get("/state/{session_id}", summary="Get current observation")
159
  def state(session_id: str):
160
+ """Current full observation including episode trace and benchmarks."""
161
  env = session_manager.get_env(session_id)
162
  if env is None:
163
  raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.")
164
  return env.state()
165
 
166
 
167
+ @app.get("/trajectory/{session_id}", summary="Full episode trajectory")
168
+ def trajectory(session_id: str):
169
+ """
170
+ Returns the complete episode trace: every query, apply, rollback, and exploit event,
171
+ with reward decompositions and accuracy deltas.
172
+
173
+ Useful for:
174
+ - Offline reward model training
175
+ - Debugging why the agent made a particular decision
176
+ - Comparing strategies across episodes
177
+ """
178
+ env = session_manager.get_env(session_id)
179
+ if env is None:
180
+ raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found.")
181
+ return env.episode_summary()
182
+
183
+
184
  @app.get("/health", summary="Health check")
185
  def health():
186
  return {
187
  "status": "ok",
188
  "version": cfg.ENV_VERSION,
189
  "active_sessions": session_manager.metrics()["active_sessions"],
190
+ "real_datasets": [
191
+ "UCI Adult Census",
192
+ "Pima Indians Diabetes",
193
+ "Wisconsin Breast Cancer",
194
+ "German Credit Risk",
195
+ "Cleveland Heart Disease",
196
+ ],
197
  }
198
 
199
 
200
+ @app.get("/metrics", summary="Server metrics")
201
  def metrics():
202
  return {
203
  "version": cfg.ENV_VERSION,
 
205
  "max_budget": cfg.MAX_BUDGET,
206
  "max_concurrent_sessions": cfg.MAX_CONCURRENT_SESSIONS,
207
  "session_ttl_seconds": cfg.SESSION_TTL_SECONDS,
 
208
  "max_same_action_streak": cfg.MAX_SAME_ACTION_STREAK,
209
+ "max_row_deletion_pct": 0.10,
210
+ "max_rollbacks_per_episode": 3,
211
  },
212
  "sessions": session_manager.metrics(),
213
  }