import asyncio
import hmac
import logging
import os
import re
import sqlite3
import time
from typing import Dict, Any, Optional

import uvicorn
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from starlette.exceptions import HTTPException as StarletteHTTPException

from .models import Action, Observation, Reward, StepResult, ResetResult
from .tasks import TASKS
|
|
| |
class SecretMaskingFormatter(logging.Formatter):
    """Formatter that redacts credentials from the fully rendered log line.

    Masking is applied after the parent formatter has produced the final
    text, so secrets are removed regardless of whether they arrived via the
    message, its arguments, or an attached traceback.
    """

    # Matches the whole credential after "Bearer ", using the token68
    # alphabet from RFC 6750 (letters, digits, "-._~+/" and trailing "=").
    # A narrower class would mask only the first segment of a JWT or a
    # base64 token and leave the remainder in the log.
    _BEARER_RE = re.compile(r'Bearer\s+[A-Za-z0-9\-._~+/]+=*')

    def format(self, record):
        """Return the formatted record with bearer tokens and HF_TOKEN masked.

        Args:
            record: The ``logging.LogRecord`` to render.

        Returns:
            The rendered line with every ``Bearer <token>`` replaced by
            ``Bearer ***`` and every occurrence of the current ``HF_TOKEN``
            environment value replaced by ``***``.
        """
        msg = super().format(record)
        msg = self._BEARER_RE.sub('Bearer ***', msg)
        # Read at call time so a token set after start-up is still masked.
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            msg = msg.replace(hf_token, "***")
        return msg
|
|
# Audit logger: INFO and above go to a stream handler whose formatter
# masks credentials before the line is written.
logger = logging.getLogger("agent_audit")
logger.setLevel(logging.INFO)

ch = logging.StreamHandler()
ch.setFormatter(
    SecretMaskingFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
)
logger.addHandler(ch)
|
|
# Bearer-credential scheme object, injected into handlers via Depends().
security = HTTPBearer()

# Key that clients must present. Falls back to a fixed, well-known value
# when the AGENT_API_KEY environment variable is not set.
AGENT_API_KEY = os.getenv("AGENT_API_KEY", "test-agent-key")
|
|
class TokenBucket:
    """Token-bucket rate limiter.

    The bucket starts full with ``capacity`` tokens and is topped up lazily,
    on each ``consume`` call, at ``fill_rate`` tokens per second. The level
    never exceeds ``capacity``.
    """

    def __init__(self, capacity: int, fill_rate: float):
        """Create a full bucket.

        Args:
            capacity: Maximum number of tokens the bucket can hold.
            fill_rate: Tokens added per second of elapsed time.
        """
        self.capacity = capacity
        self.fill_rate = fill_rate
        self.tokens = capacity
        # Monotonic clock: wall-clock adjustments (NTP, manual changes) must
        # not add or remove tokens.
        self.last_fill = time.monotonic()

    def consume(self, tokens: int = 1) -> bool:
        """Try to take ``tokens`` from the bucket.

        Args:
            tokens: Number of tokens requested.

        Returns:
            True if the bucket held enough tokens (they are deducted),
            False otherwise (the level is left at its refilled value).
        """
        now = time.monotonic()
        # Clamp at zero so a timestamp ahead of "now" can never drain the
        # bucket through a negative refill.
        elapsed = max(0.0, now - self.last_fill)
        self.tokens = min(self.capacity, self.tokens + elapsed * self.fill_rate)
        self.last_fill = now
        if self.tokens >= tokens:
            self.tokens -= tokens
            return True
        return False
|
|
# Single module-level bucket: holds up to 50 tokens and refills at
# 50 tokens per 60 seconds.
rate_limiter = TokenBucket(capacity=50, fill_rate=50 / 60)
|
|
async def verify_auth_and_rate_limit(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Authenticate the bearer credential, then charge one rate-limit token.

    Args:
        credentials: Bearer credentials resolved by the ``security`` scheme.

    Returns:
        The presented credential string when it is accepted.

    Raises:
        HTTPException: 401 when the credential does not equal
            ``AGENT_API_KEY``; 429 when ``rate_limiter`` has no token left.
    """
    # Constant-time comparison so response timing does not reveal how many
    # leading characters of a guess were correct. Comparing bytes avoids the
    # TypeError that compare_digest raises for non-ASCII str arguments.
    presented = credentials.credentials.encode("utf-8")
    expected = AGENT_API_KEY.encode("utf-8")
    if not hmac.compare_digest(presented, expected):
        raise HTTPException(status_code=401, detail="Invalid API Key")
    # Only authenticated requests are charged against the bucket.
    if not rate_limiter.consume(1):
        raise HTTPException(status_code=429, detail="Rate limit exceeded")
    return credentials.credentials
|
|
# Application object that the route, middleware and exception-handler
# decorators in this module attach to.
app = FastAPI(title="OpenEnv SQL Data Engineer")

# At most five coroutines may be inside an `async with db_semaphore` section
# at the same time.
db_semaphore = asyncio.Semaphore(value=5)
|
|
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Log an otherwise unhandled exception and answer with a generic 500.

    Args:
        request: The request being processed when the exception escaped.
        exc: The exception instance.

    Returns:
        A JSON 500 response whose body carries no exception detail.
    """
    # Record the traceback and the failing route; the message text alone is
    # rarely enough to locate the fault. Only the path is logged, not the
    # query string.
    logger.error(
        "Unhandled exception on %s %s: %s",
        request.method,
        request.url.path,
        exc,
        exc_info=exc,
    )
    return JSONResponse(status_code=500, content={"error": "Internal Server Error"})
|
|
@app.middleware("http")
async def security_middleware(request: Request, call_next):
    """Reject oversized or malformed requests and bound handler run time.

    Args:
        request: The incoming request.
        call_next: Callable that runs the rest of the stack for ``request``.

    Returns:
        400 when the Content-Length header is not a non-negative integer,
        413 when it declares more than 1 MiB, 504 when the downstream call
        does not finish within 10 seconds, otherwise the downstream response.
    """
    max_body_bytes = 1048576  # 1 MiB
    request_timeout_seconds = 10.0

    # Only the declared length is inspected; a request that carries no
    # Content-Length header is not size-checked here.
    content_length = request.headers.get("content-length")
    if content_length:
        try:
            declared_length = int(content_length)
        except ValueError:
            declared_length = -1
        if declared_length < 0:
            # A malformed header is a client error, not a server fault.
            return JSONResponse(status_code=400, content={"error": "Invalid Content-Length"})
        if declared_length > max_body_bytes:
            return JSONResponse(status_code=413, content={"error": "Payload Too Large"})

    try:
        return await asyncio.wait_for(call_next(request), timeout=request_timeout_seconds)
    except asyncio.TimeoutError:
        return JSONResponse(status_code=504, content={"error": "Gateway Timeout"})
|
|
class SQLEnvironment:
    """One SQLite-backed episode of a task taken from ``TASKS``.

    The environment owns a single connection to the file at ``db_path``.
    ``reset`` deletes and recreates that file, lets the task populate it and
    records the starting score; ``step`` runs one SQL string and re-grades.
    Long-running statements are interrupted by a progress handler, and
    DROP TABLE / DROP TRIGGER / DROP VIEW / ATTACH are refused by an
    authorizer installed on the connection.
    """

    # Time budget, in seconds, granted each time the query timer is armed.
    QUERY_TIMEOUT_SECONDS = 5.0
    # Number of SQLite VM instructions between progress-handler callbacks.
    PROGRESS_CHECK_INTERVAL = 1000
    # Maximum number of rows echoed back for a SELECT / PRAGMA.
    MAX_RESULT_ROWS = 10
    # The episode is flagged done once step_count exceeds this value.
    MAX_STEPS = 30

    # Textual deny-list applied to the raw query before execution. The
    # pattern text is echoed in the error message, so it is kept verbatim.
    _BLOCKED_PATTERNS = tuple(
        re.compile(p)
        for p in (
            r"(?i)DROP\s+DATABASE",
            r"(?i)pg_sleep",
            r"(?i)randomblob",
            r"(?i)ATTACH\s+DATABASE",
        )
    )

    # Authorizer action codes that are always refused. ATTACH is listed here
    # because the regex above only matches the "ATTACH DATABASE" spelling,
    # while SQLite also accepts "ATTACH <file> AS <name>".
    _DENIED_ACTIONS = frozenset(
        {
            sqlite3.SQLITE_DROP_TABLE,    # 11
            sqlite3.SQLITE_DROP_TRIGGER,  # 16
            sqlite3.SQLITE_DROP_VIEW,     # 17
            sqlite3.SQLITE_ATTACH,        # 24
        }
    )

    def __init__(self):
        """Create an uninitialised environment; ``reset`` opens the database."""
        self.conn: Optional[sqlite3.Connection] = None
        self.task_id = 1
        self.step_count = 0
        self.current_score = 0.0
        self.db_path = "tmp_env.db"
        # Monotonic timestamp from which the current time budget is measured.
        self.query_start_time = 0.0

    def _arm_query_timer(self) -> None:
        """Start a fresh time budget for the work that is about to run."""
        self.query_start_time = time.monotonic()

    def _progress_handler(self) -> int:
        """Return 1 (interrupt) once the current budget is exhausted, else 0."""
        elapsed = time.monotonic() - self.query_start_time
        return 1 if elapsed > self.QUERY_TIMEOUT_SECONDS else 0

    def _authorizer(self, action_code, arg1, arg2, dbname, source):
        """Refuse the actions in ``_DENIED_ACTIONS`` and allow everything else."""
        if action_code in self._DENIED_ACTIONS:
            return sqlite3.SQLITE_DENY
        return sqlite3.SQLITE_OK

    def _check_blocked(self, query: str) -> None:
        """Raise ValueError when ``query`` matches the textual deny-list."""
        for pattern in self._BLOCKED_PATTERNS:
            if pattern.search(query):
                raise ValueError(
                    f"Blocked destructive command pattern detected: {pattern.pattern}"
                )
        # Compare against an upper-case prefix: the query has been
        # upper-cased, so a lower-case "sqlite_" could never match.
        if query.upper().startswith("DROP TABLE SQLITE_"):
            raise ValueError("Cannot modify system tables.")

    def _format_rows(self, cursor) -> str:
        """Render up to ``MAX_RESULT_ROWS`` rows of ``cursor`` as text."""
        # Fetch one extra row so that "truncated" is reported only when more
        # rows really exist, not when the result has exactly the limit.
        rows = cursor.fetchmany(self.MAX_RESULT_ROWS + 1)
        truncated = len(rows) > self.MAX_RESULT_ROWS
        col_names = [d[0] for d in cursor.description] if cursor.description else []

        header = " | ".join(col_names) + "\n"
        parts = [header, "-" * len(header) + "\n"]
        parts.extend(
            " | ".join(str(val) for val in row) + "\n"
            for row in rows[: self.MAX_RESULT_ROWS]
        )
        if truncated:
            parts.append("... (output truncated)")
        return "".join(parts)

    def get_schema_dump(self) -> str:
        """Describe every user table and view in the current database.

        Returns:
            ``""`` when no connection is open, ``"Database is empty."`` when
            there are no user tables or views, an ``"Error extracting
            schema: ..."`` string when the lookup fails, otherwise one
            ``[TYPE] name:`` entry per object followed by its SQL.
        """
        if not self.conn:
            return ""
        try:
            # Fresh budget: a deadline left over from an earlier statement
            # must not interrupt this lookup.
            self._arm_query_timer()
            c = self.conn.cursor()
            c.execute("SELECT type, name, sql FROM sqlite_master WHERE type='table' OR type='view'")
            dump = [
                f"[{t.upper()}] {name}:\n {sql}"
                for t, name, sql in c.fetchall()
                if not name.startswith('sqlite_')
            ]
            return "\n".join(dump) if dump else "Database is empty."
        except Exception as e:
            return f"Error extracting schema: {e}"

    def reset(self, task_id: int = 1) -> ResetResult:
        """Recreate the database for ``task_id`` and return the first observation.

        Args:
            task_id: Key into ``TASKS`` selecting the task to load.

        Returns:
            A ``ResetResult`` carrying the initial observation plus the task
            id and the initial score in ``info``.

        Raises:
            ValueError: If ``task_id`` is not a key of ``TASKS``. The current
                episode is left untouched in that case.
        """
        # Validate before tearing anything down, so a bad id cannot leave the
        # environment holding a closed connection and an unknown task id.
        if task_id not in TASKS:
            raise ValueError(f"Task ID {task_id} not found.")

        if self.conn:
            self.conn.close()
            self.conn = None

        try:
            os.remove(self.db_path)
        except FileNotFoundError:
            pass

        self.task_id = task_id
        self.conn = sqlite3.connect(self.db_path)
        self.step_count = 0
        self.current_score = 0.0

        self._arm_query_timer()
        self.conn.set_progress_handler(self._progress_handler, self.PROGRESS_CHECK_INTERVAL)
        self.conn.set_authorizer(self._authorizer)
        self.conn.execute("PRAGMA foreign_keys = ON")

        task = TASKS[self.task_id]
        # Setup and grading each get their own budget; with an unarmed timer
        # any statement longer than one progress interval would be aborted.
        self._arm_query_timer()
        task.setup_db(self.conn)
        self._arm_query_timer()
        self.current_score = task.grade(self.conn)

        goal_text = task.get_goal()
        instructions = f"Task Goal: {goal_text}\n"

        obs = Observation(
            goal=instructions,
            result="Environment initialized. Schema ready.",
            step=self.step_count,
            last_action_error=False,
            schema_dump=self.get_schema_dump()
        )
        return ResetResult(observation=obs, info={"task_id": self.task_id, "initial_score": self.current_score})

    def step(self, action: Action) -> StepResult:
        """Execute ``action.action_str`` as SQL and score the resulting state.

        Queries starting with SELECT or PRAGMA return a small text table;
        anything else is committed and reports its rowcount. Any failure is
        reported in the observation instead of being raised.

        Args:
            action: Object whose ``action_str`` attribute holds the SQL text.

        Returns:
            A ``StepResult`` whose reward is the change in task score, minus
            0.05 when the action failed, and whose ``done`` flag is set when
            the score reaches 0.99 or the step count exceeds ``MAX_STEPS``.

        Raises:
            ValueError: If ``reset`` has not been called yet.
        """
        if not self.conn:
            raise ValueError("Environment not initialized. Call reset() first.")

        self.step_count += 1
        last_action_error = False
        query_result = ""

        try:
            c = self.conn.cursor()
            query = action.action_str.strip()
            self._check_blocked(query)

            self._arm_query_timer()
            c.execute(query)

            if query.upper().startswith(("SELECT", "PRAGMA")):
                query_result = self._format_rows(c)
            else:
                self.conn.commit()
                query_result = f"Command executed successfully. Rowcount: {c.rowcount}"

        except Exception as e:
            last_action_error = True
            query_result = f"SQL Error: {str(e)}"

        task = TASKS[self.task_id]
        # Grade with a fresh budget: the agent's query may have used most of
        # its own, or been rejected before the timer was armed at all.
        self._arm_query_timer()
        new_score = task.grade(self.conn)

        reward_value = new_score - self.current_score
        if last_action_error:
            # Fixed penalty for an action that raised.
            reward_value -= 0.05

        self.current_score = new_score

        done = (self.current_score >= 0.99) or (self.step_count > self.MAX_STEPS)

        obs = Observation(
            goal=task.get_goal(),
            result=query_result,
            step=self.step_count,
            last_action_error=last_action_error,
            schema_dump=self.get_schema_dump() if not last_action_error else None
        )

        return StepResult(
            observation=obs,
            reward=reward_value,
            done=done,
            info={"current_score": self.current_score}
        )

    def state(self) -> Any:
        """Return a dict snapshot: task id, step count, score and schema dump."""
        return {
            "task_id": self.task_id,
            "step": self.step_count,
            "current_score": self.current_score,
            "schema_dump": self.get_schema_dump()
        }
|
|
| |
# The one environment object used by every route handler in this module.
env_instance = SQLEnvironment()
|
|
class ResetRequest(BaseModel):
    """Reset payload with a single integer field, ``task_id``, defaulting to 1."""

    task_id: int = 1
|
|
@app.post("/reset", response_model=ResetResult)
async def reset(request: Request):
    """Reset the shared environment to the task named in the request body.

    The body is optional. When it is missing, is not a JSON object, or has
    no usable ``task_id``, task 1 is loaded.

    Args:
        request: The incoming request; its body is parsed as JSON here.

    Returns:
        The ``ResetResult`` produced by ``env_instance.reset``.

    Raises:
        HTTPException: 400 when the environment reset fails.
    """
    task_id = 1
    try:
        data = await request.json()
    except ValueError:
        # Empty, non-JSON or undecodable body: keep the default task.
        data = None

    if isinstance(data, dict) and "task_id" in data:
        try:
            task_id = int(data["task_id"])
        except (TypeError, ValueError, OverflowError):
            logger.warning("Ignoring unusable task_id in reset request; using task_id=1")

    async with db_semaphore:
        try:
            logger.info("Resetting environment for task_id: %s", task_id)
            return env_instance.reset(task_id=task_id)
        except Exception as e:
            logger.exception("Reset Error: %s", e)
            raise HTTPException(status_code=400, detail="Error during reset") from e
|
|
@app.post("/step", response_model=StepResult)
async def step(action: Action, token: str = Depends(verify_auth_and_rate_limit)):
    """Run one agent action against the shared environment.

    Args:
        action: The action whose ``action_str`` is executed as SQL.
        token: Credential returned by the auth / rate-limit dependency.

    Returns:
        The ``StepResult`` produced by ``env_instance.step``.

    Raises:
        HTTPException: 400 when the environment step raises.
    """
    async with db_semaphore:
        try:
            started = time.perf_counter()
            # %r escapes newlines and control characters so agent-supplied
            # text cannot forge extra log lines; the precision caps the
            # logged length at 500 characters.
            logger.info("Attempting query: %.500r", action.action_str)
            res = env_instance.step(action)
            duration = time.perf_counter() - started
            logger.info(
                "Query completed in %.3fs. Error state: %s",
                duration,
                res.observation.last_action_error,
            )
            return res
        except Exception as e:
            logger.exception("Step Error: %s", e)
            raise HTTPException(status_code=400, detail="Error executing SQL step") from e
|
|
@app.get("/state")
async def state(token: str = Depends(verify_auth_and_rate_limit)):
    """Return the shared environment's state snapshot.

    Args:
        token: Credential returned by the auth / rate-limit dependency.

    Returns:
        Whatever ``env_instance.state()`` returns, read while holding
        ``db_semaphore``.
    """
    async with db_semaphore:
        snapshot = env_instance.state()
    return snapshot
|
|
def main():
    """Serve ``server.app:app`` with uvicorn on 0.0.0.0:7860, reload disabled."""
    uvicorn.run(
        "server.app:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
    )


# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|