File size: 12,472 Bytes
30cf758
 
 
 
 
 
 
 
6518b31
30cf758
029f9cf
6518b31
30cf758
6518b31
029f9cf
30cf758
029f9cf
30cf758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
029f9cf
 
 
 
30cf758
 
029f9cf
f4ae3f3
 
029f9cf
 
 
 
 
30cf758
 
 
 
029f9cf
 
 
 
30cf758
 
 
 
 
 
 
 
029f9cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30cf758
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6518b31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d061422
 
 
 
 
 
6518b31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d061422
 
 
6518b31
 
d061422
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6518b31
 
 
 
 
 
 
 
 
 
30cf758
 
 
 
 
 
 
 
 
 
 
 
 
 
029f9cf
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
"""
FastAPI server exposing the OpenEnv HTTP API.
Endpoints: POST /reset, POST /step, GET /state
Also includes: GET /tasks (list available tasks), GET /health
"""
import asyncio
import time
import statistics
from typing import Dict, Optional, List, Any
from contextlib import asynccontextmanager
from pathlib import Path
import sqlite3

from fastapi import FastAPI, HTTPException, Header, Body
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from .models import SQLDebugAction, SQLDebugObservation, EpisodeState
from .env import SQLDebugEnv, TASKS


# Session management: one env instance per session.
# A session id comes from the X-Session-Id request header; handlers fall back
# to "default" when the header is absent.
# For HF Space: allow up to 64 concurrent sessions. get_or_create_session
# closes and evicts the first-inserted session once this many exist.
MAX_SESSIONS = 64
# session id -> env. Dict insertion order is what "oldest" means on eviction.
_sessions: Dict[str, SQLDebugEnv] = {}
# Held by the code paths that create or replace entries in _sessions.
_session_lock = asyncio.Lock()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: no startup work; close every session on shutdown.

    The cleanup runs in a ``finally`` so it also happens if an exception is
    raised into the context while the app is running. The registry is emptied
    as well. If the same app object is started again (e.g. several test
    clients in one process), no request can then pick up an env that has
    already been closed.
    """
    try:
        yield
    finally:
        # Detach the envs from the registry first, then release each one.
        closing_envs = list(_sessions.values())
        _sessions.clear()
        for env in closing_envs:
            env.close()


# The ASGI application. `lifespan` (above) closes every session on shutdown.
app = FastAPI(
    title="SQL Debug Environment",
    description="OpenEnv-compliant SQL query debugging environment for RL agent training.",
    version="0.1.0",
    lifespan=lifespan
)

# Fully permissive CORS: any origin, method and header may call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Bundled static assets are served at /static only when the directory exists,
# so the app still starts without it. The same directory is also handed to
# mount_gradio at the bottom of this module.
_static_dir = Path(__file__).resolve().parent / "static"
if _static_dir.is_dir():
    app.mount("/static", StaticFiles(directory=str(_static_dir)), name="static")


@app.get("/")
async def space_home():
    """Hugging Face Space opens here — HTML demo first; Gradio lives at /gradio/."""
    return RedirectResponse(url="/demo", status_code=302)


@app.get("/api/info")
async def api_info():
    """Machine-readable index (JSON clients that used to hit `/`)."""
    return {
        "name": "sql-debug-env",
        "status": "ok",
        "message": "Use /health, /tasks, /reset, /step, /state, /benchmark",
        "demo": "/demo",
        "demo_page": "/server/demo_page.html",
        "gradio": "/gradio",
        "info": "/api/info",
    }


@app.get("/favicon.ico", status_code=204)
async def favicon():
    return None


_DEMO_PAGE_PATH = Path(__file__).resolve().parent / "demo_page.html"


def _read_demo_page_html() -> str:
    """Load the Space demo HTML from disk (next to this module).

    The file is read on every call. When it is not a regular file, a small
    placeholder page is returned instead. The same happens if the file
    disappears between the existence check and the read, which would
    otherwise surface as a 500 from the calling route.

    Returns:
        The page's HTML text (decoded as UTF-8), or the placeholder HTML.
    """
    if _DEMO_PAGE_PATH.is_file():
        try:
            return _DEMO_PAGE_PATH.read_text(encoding="utf-8")
        except FileNotFoundError:
            # Removed after the is_file() check: fall through to the placeholder.
            pass
    return (
        "<!doctype html><html><body style='font-family:sans-serif;padding:2rem'>"
        "<p><strong>demo_page.html</strong> is missing next to <code>main.py</code>.</p></body></html>"
    )


@app.get("/demo", response_class=HTMLResponse)
async def demo_page():
    """Submission-ready demo + proof page."""
    return _read_demo_page_html()


@app.get("/server/demo_page.html", response_class=HTMLResponse)
async def demo_page_repo_path():
    """Same page as /demo — URL matches the repo path for HF Space links and bookmarks."""
    return _read_demo_page_html()


# JSON body accepted by POST /reset.
class ResetRequest(BaseModel):
    # Task to load. /reset also falls back to "easy_syntax_fix" when this is null.
    task_id: Optional[str] = "easy_syntax_fix"


# JSON body accepted by POST /step and POST /step_with_review: wraps one action.
class StepRequest(BaseModel):
    action: SQLDebugAction


async def get_or_create_session(session_id: str, task_id: str = "easy_syntax_fix") -> SQLDebugEnv:
    """Return the env bound to *session_id*, creating it on first use.

    An existing session is returned unchanged; ``task_id`` only matters when
    a new env is built. Before creating a session, if MAX_SESSIONS are
    already held, the first-inserted entry of ``_sessions`` is closed and
    removed.
    """
    async with _session_lock:
        existing = _sessions.get(session_id)
        if existing is not None:
            return existing
        if len(_sessions) >= MAX_SESSIONS:
            # Evict oldest session
            victim = next(iter(_sessions))
            _sessions[victim].close()
            del _sessions[victim]
        fresh = SQLDebugEnv(task_id=task_id)
        _sessions[session_id] = fresh
        return fresh


@app.get("/health")
async def health():
    return {"status": "ok", "sessions_active": len(_sessions)}


@app.get("/tasks")
async def list_tasks():
    """List all available tasks with metadata."""
    return {
        "tasks": [task.to_dict() for task in TASKS.values()]
    }


def _stats(values: list[float]) -> Dict[str, float]:
    """Summarise latency samples (milliseconds) for the /benchmark payload.

    Args:
        values: non-empty list of timings in milliseconds.

    Returns:
        ``avg_ms`` (mean), ``p50_ms`` (median) and ``p95_ms`` (nearest-rank
        95th percentile), each rounded to 3 decimals, plus the sample count ``n``.

    Raises:
        statistics.StatisticsError: if *values* is empty.
    """
    ordered = sorted(values)
    n = len(ordered)
    # Nearest-rank percentile: the ceil(0.95 * n)-th smallest sample (1-based).
    # Integer ceil avoids float rounding in 0.95 * n and, unlike int(n * 0.95),
    # does not under-report when 0.95 * n is fractional (e.g. n=2 or n=10).
    p95_rank = (95 * n + 99) // 100
    p95_idx = max(0, p95_rank - 1)
    return {
        "avg_ms": round(statistics.mean(ordered), 3),
        "p50_ms": round(statistics.median(ordered), 3),
        "p95_ms": round(ordered[p95_idx], 3),
        "n": n,
    }


@app.get("/benchmark")
async def benchmark(runs: int = 20):
    """
    Real-time benchmark endpoint (fresh measurements on every call).
    Safe to call from dashboards/web pages for live verification.
    """
    runs = max(1, min(runs, 100))

    health_times: list[float] = []
    tasks_times: list[float] = []
    reset_times: list[float] = []
    step_times: list[float] = []

    bench_env = SQLDebugEnv(task_id="easy_syntax_fix")
    try:
        for _ in range(runs):
            t0 = time.perf_counter()
            _ = {"status": "ok", "sessions_active": len(_sessions)}
            health_times.append((time.perf_counter() - t0) * 1000)

            t0 = time.perf_counter()
            _ = [task.to_dict() for task in TASKS.values()]
            tasks_times.append((time.perf_counter() - t0) * 1000)

            t0 = time.perf_counter()
            await bench_env.reset()
            reset_times.append((time.perf_counter() - t0) * 1000)

            t0 = time.perf_counter()
            await bench_env.step(SQLDebugAction(action_type="inspect_schema"))
            step_times.append((time.perf_counter() - t0) * 1000)
    finally:
        bench_env.close()

    return {
        "benchmark": {
            "runs": runs,
            "task_id": "easy_syntax_fix",
            "timestamp_epoch_ms": int(time.time() * 1000),
            "results": {
                "health": _stats(health_times),
                "tasks": _stats(tasks_times),
                "reset": _stats(reset_times),
                "step_inspect_schema": _stats(step_times),
            },
        }
    }


@app.post("/reset")
async def reset(
    request: ResetRequest = ResetRequest(),
    x_session_id: Optional[str] = Header(default=None)
):
    """
    Reset the environment for a new episode.

    Returns initial observation with task description and broken query.
    """
    session_id = x_session_id or "default"
    task_id = request.task_id or "easy_syntax_fix"

    if task_id not in TASKS:
        raise HTTPException(status_code=400, detail=f"Unknown task_id: {task_id}. Valid: {list(TASKS.keys())}")

    # Always create fresh env on reset
    async with _session_lock:
        if session_id in _sessions:
            _sessions[session_id].close()
        _sessions[session_id] = SQLDebugEnv(task_id=task_id)

    env = _sessions[session_id]
    observation, info = await env.reset()

    return {
        "observation": observation.model_dump(),
        "info": info,
        "reward": None,
        "done": False
    }


@app.post("/step")
async def step(
    request: StepRequest,
    x_session_id: Optional[str] = Header(default=None)
):
    """
    Execute one action in the environment.

    Action types:
    - submit_query: Submit SQL for evaluation (requires 'query' field)
    - inspect_schema: Get table schema (free action)
    - inspect_error: Get last error message (free action)
    - inspect_sample: Get sample rows from table (requires 'table_name')
    - reset_query: Reset to original broken query (small penalty)
    """
    session_id = x_session_id or "default"

    if session_id not in _sessions:
        raise HTTPException(status_code=400, detail="Session not found. Call /reset first.")

    env = _sessions[session_id]

    try:
        observation, reward, done, info = await env.step(request.action)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except ValueError as e:
        raise HTTPException(status_code=422, detail=str(e))

    return {
        "observation": observation.model_dump(),
        "reward": reward,
        "done": done,
        "info": info
    }


@app.post("/step_with_review")
async def step_with_review(
    request: StepRequest,
    x_session_id: Optional[str] = Header(default=None)
):
    """
    Execute a step with a Reviewer Agent layer.
    If the action is a query submission, the Reviewer validates it first.
    """
    session_id = x_session_id or "default"
    if session_id not in _sessions:
        raise HTTPException(status_code=400, detail="Session not found. Call /reset first.")
    
    env = _sessions[session_id]
    action = request.action

    if action.action_type == "submit_query" and action.query:
        # Reviewer checks the query before execution
        state = env.get_state()
        review = reviewer_check(action.query, state.db_schema or {})
        
        if not review["approved"]:
            # Reviewer rejected — return feedback without executing
            # Keep reward in strict (0, 1) range for OpenEnv compatibility
            reward = 0.001
            obs = env.to_observation(
                last_action_type="review_rejected",
                error_details=f"REVIEWER REJECTION: {review['reason']}",
            )
            
            return {
                "observation": obs.model_dump(),
                "reward": reward,
                "done": False,
                "info": {"review_rejected": True, "reason": review["reason"]}
            }

    # If approved or not a query, proceed to normal step
    try:
        observation, reward, done, info = await env.step(action)
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    return {
        "observation": observation.model_dump(),
        "reward": reward,
        "done": done,
        "info": info
    }


def _quote_sqlite_identifier(name: Any) -> str:
    """Return *name* as a double-quoted SQLite identifier (embedded quotes doubled)."""
    return '"' + str(name).replace('"', '""') + '"'


def _stub_column_sql(col: Dict[str, Any]) -> str:
    """Render one schema column dict as a column definition for a stub CREATE TABLE."""
    name = col.get("name", "col")
    col_type = col.get("type", "TEXT")
    # Only an explicit "NO" (case-insensitive) in "nullable" adds NOT NULL.
    not_null = " NOT NULL" if str(col.get("nullable")).upper() == "NO" else ""
    return f"{_quote_sqlite_identifier(name)} {col_type}{not_null}"


def reviewer_check(query: str, schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Simple rule-based Reviewer Agent.

    Checks, in order:
    1. Read-only prefix: the query must start with SELECT or WITH
       (case-insensitive, ignoring surrounding whitespace).
    2. Table existence: at least one schema table name must appear as a
       substring of the query (skipped when the schema has no tables).
    3. Basic SQLite syntax: ``EXPLAIN <query>`` must succeed against an
       in-memory database holding empty stub tables built from ``schema``.

    Args:
        query: SQL text submitted by the agent.
        schema: table name -> list of column dicts. The keys "name", "type"
            and "nullable" are read when present (shape inferred from these
            lookups; confirm against the env's ``db_schema``). ``None`` is
            treated as an empty schema.

    Returns:
        ``{"approved": bool, "reason": str}``.
    """
    schema = schema or {}
    query_upper = query.upper().strip()

    # Check 1: Is it a read query?
    if not query_upper.startswith(("SELECT", "WITH")):
        return {"approved": False, "reason": "Only SELECT queries or CTEs (WITH) are allowed."}

    # Check 2: Does it reference valid tables? This is a lenient substring
    # match; references to unknown tables are caught by check 3.
    tables = list(schema.keys())
    referenced = [t for t in tables if t.upper() in query_upper]
    if not referenced and tables:
        return {"approved": False, "reason": f"Query does not reference any valid tables. Available: {tables}"}

    # Check 3: Syntax check via EXPLAIN on a lightweight schema stub.
    # Identifiers are quoted so table/column names that are SQL keywords
    # (e.g. "order", "group") do not break the stub DDL. Such a failure would
    # otherwise be blamed on a valid query. The connection is always closed.
    try:
        conn = sqlite3.connect(":memory:")
        try:
            for table_name, columns in schema.items():
                col_defs = [_stub_column_sql(col) for col in (columns or [])]
                # A table with no known columns still gets a stub, so e.g.
                # SELECT COUNT(*) FROM it does not fail with "no such table".
                cols_sql = ", ".join(col_defs) if col_defs else "id INTEGER"
                conn.execute(
                    f"CREATE TABLE IF NOT EXISTS {_quote_sqlite_identifier(table_name)} ({cols_sql})"
                )

            # We don't have the actual data here, but EXPLAIN is sufficient for
            # catching syntax errors and many semantic issues.
            conn.execute(f"EXPLAIN {query}")
        finally:
            conn.close()
    except sqlite3.OperationalError as e:
        return {"approved": False, "reason": f"Syntax error caught by Reviewer: {e}"}
    except Exception as e:
        return {"approved": False, "reason": f"Reviewer error: {e}"}

    return {"approved": True, "reason": "Query approved"}


@app.get("/state")
async def state(x_session_id: Optional[str] = Header(default=None)):
    """Return current full episode state."""
    session_id = x_session_id or "default"

    if session_id not in _sessions:
        raise HTTPException(status_code=400, detail="No active session. Call /reset first.")

    env = _sessions[session_id]
    try:
        current_state = env.get_state()
        return current_state.model_dump()
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e))


# Gradio UI on the same Space (mounted after all API routes, presumably so the
# Gradio mount cannot shadow the routes registered above — confirm).
# mount_gradio receives the app and the static directory; its return value
# replaces the module-level `app` that the server exports.
from .gradio_ui import mount_gradio

app = mount_gradio(app, _static_dir)