File size: 10,932 Bytes
de07414 7d84930 de07414 7d84930 a2db70a de07414 7d84930 de07414 7d84930 a2db70a 7d84930 de07414 7d84930 de07414 7d84930 de07414 7d84930 de07414 1314b5a de07414 a2db70a 7d84930 a2db70a 7d84930 de07414 7d84930 de07414 7d84930 de07414 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 | import sqlite3
import os
import logging
import re
import time
import asyncio
from typing import Dict, Any, Optional
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from starlette.exceptions import HTTPException as StarletteHTTPException
from pydantic import BaseModel
import uvicorn
from .models import Action, Observation, Reward, StepResult, ResetResult
from .tasks import TASKS
# -- Security & Telemetry Setup --
class SecretMaskingFormatter(logging.Formatter):
def format(self, record):
msg = super().format(record)
msg = re.sub(r'Bearer\s+[A-Za-z0-9_\-]+', 'Bearer ***', msg)
hf_token = os.environ.get("HF_TOKEN")
if hf_token and hf_token in msg:
msg = msg.replace(hf_token, "***")
return msg
logger = logging.getLogger("agent_audit")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setFormatter(SecretMaskingFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(ch)
security = HTTPBearer()
AGENT_API_KEY = os.environ.get("AGENT_API_KEY", "test-agent-key")
class TokenBucket:
def __init__(self, capacity: int, fill_rate: float):
self.capacity = capacity
self.fill_rate = fill_rate
self.tokens = capacity
self.last_fill = time.time()
def consume(self, tokens: int = 1) -> bool:
now = time.time()
self.tokens = min(self.capacity, self.tokens + (now - self.last_fill) * self.fill_rate)
self.last_fill = now
if self.tokens >= tokens:
self.tokens -= tokens
return True
return False
rate_limiter = TokenBucket(capacity=50, fill_rate=50.0/60.0)
async def verify_auth_and_rate_limit(credentials: HTTPAuthorizationCredentials = Depends(security)):
if credentials.credentials != AGENT_API_KEY:
raise HTTPException(status_code=401, detail="Invalid API Key")
if not rate_limiter.consume(1):
raise HTTPException(status_code=429, detail="Rate limit exceeded")
return credentials.credentials
app = FastAPI(title="OpenEnv SQL Data Engineer")
db_semaphore = asyncio.Semaphore(5)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
logger.error(f"Unhandled exception: {str(exc)}")
return JSONResponse(status_code=500, content={"error": "Internal Server Error"})
@app.middleware("http")
async def security_middleware(request: Request, call_next):
# Max payload limit ~ 1MB
content_length = request.headers.get("content-length")
if content_length and int(content_length) > 1048576:
return JSONResponse(status_code=413, content={"error": "Payload Too Large"})
try:
response = await asyncio.wait_for(call_next(request), timeout=10.0)
return response
except asyncio.TimeoutError:
return JSONResponse(status_code=504, content={"error": "Gateway Timeout"})
class SQLEnvironment:
def __init__(self):
self.conn: Optional[sqlite3.Connection] = None
self.task_id = 1
self.step_count = 0
self.current_score = 0.0
self.db_path = "tmp_env.db"
def get_schema_dump(self) -> str:
if not self.conn:
return ""
try:
c = self.conn.cursor()
c.execute("SELECT type, name, sql FROM sqlite_master WHERE type='table' OR type='view'")
rows = c.fetchall()
dump = []
for t, name, sql in rows:
if name.startswith('sqlite_'): continue
dump.append(f"[{t.upper()}] {name}:\n {sql}")
return "\n".join(dump) if dump else "Database is empty."
except Exception as e:
return f"Error extracting schema: {e}"
def reset(self, task_id: int = 1) -> ResetResult:
if self.conn:
self.conn.close()
# Clean up existing temp db
if os.path.exists(self.db_path):
os.remove(self.db_path)
self.task_id = task_id
if self.task_id not in TASKS:
raise ValueError(f"Task ID {self.task_id} not found.")
self.conn = sqlite3.connect(self.db_path)
self.step_count = 0
self.current_score = 0.0
# Security: 5 second timeout on queries
self.query_start_time = 0
def progress_handler():
if time.time() - self.query_start_time > 5.0:
return 1 # Abort query
return 0
self.conn.set_progress_handler(progress_handler, 1000)
# Security: Simple Authorizer to block DROP TABLE and explicit destructive system mods
# We allow standard DDL since the agent gets asked to CREATE VIEW, but restrict DROP
def authorizer(action_code, arg1, arg2, dbname, source):
# 11 = SQLITE_DROP_TABLE, 16 = SQLITE_DROP_VIEW
# 17 = SQLITE_ATTACH (blocks ATTACH DATABASE)
if action_code in (11, 16, 17):
return sqlite3.SQLITE_DENY
return sqlite3.SQLITE_OK
self.conn.set_authorizer(authorizer)
# Initialize standard SQLite settings
self.conn.execute("PRAGMA foreign_keys = ON")
# Setup specific task data
task = TASKS[self.task_id]
task.setup_db(self.conn)
self.current_score = task.grade(self.conn) # Base score
goal_text = task.get_goal()
# Add basic info about actual task goal
instructions = f"Task Goal: {goal_text}\n"
obs = Observation(
goal=instructions,
result="Environment initialized. Schema ready.",
step=self.step_count,
last_action_error=False,
schema_dump=self.get_schema_dump()
)
return ResetResult(observation=obs, info={"task_id": self.task_id, "initial_score": self.current_score})
def step(self, action: Action) -> StepResult:
if not self.conn:
raise ValueError("Environment not initialized. Call reset() first.")
self.step_count += 1
last_action_error = False
query_result = ""
try:
c = self.conn.cursor()
query = action.action_str.strip()
# Injection Mitigations at parser level
blocked_patterns = [r"(?i)DROP\s+DATABASE", r"(?i)pg_sleep", r"(?i)randomblob", r"(?i)ATTACH\s+DATABASE"]
for p in blocked_patterns:
if re.search(p, query):
raise Exception(f"Blocked destructive command pattern detected: {p}")
if query.upper().startswith("DROP TABLE sqlite_"):
raise Exception("Cannot modify system tables.")
self.query_start_time = time.time()
c.execute(query)
if query.upper().startswith("SELECT") or query.upper().startswith("PRAGMA"):
rows = c.fetchmany(10) # limit output size for LLM observation
col_names = [description[0] for description in c.description] if c.description else []
# Format tabular output
result_str = " | ".join(col_names) + "\n"
result_str += "-" * len(result_str) + "\n"
for r in rows:
result_str += " | ".join([str(val) for val in r]) + "\n"
if len(rows) == 10:
result_str += "... (output truncated)"
query_result = result_str
else:
self.conn.commit()
query_result = f"Command executed successfully. Rowcount: {c.rowcount}"
except Exception as e:
last_action_error = True
query_result = f"SQL Error: {str(e)}"
# Run grader
task = TASKS[self.task_id]
new_score = task.grade(self.conn)
# Reward is dense: change in score + small penalty for errors
reward_value = new_score - self.current_score
if last_action_error:
# minor penalty for syntax errors
reward_value -= 0.05
self.current_score = new_score
# Episode terminates when score is 0.99
done = (self.current_score >= 0.99) or (self.step_count > 30)
obs = Observation(
goal=task.get_goal(),
result=query_result,
step=self.step_count,
last_action_error=last_action_error,
schema_dump=self.get_schema_dump() if not last_action_error else None # only dump if no error to save tokens
)
return StepResult(
observation=obs,
reward=reward_value,
done=done,
info={"current_score": self.current_score}
)
def state(self) -> Any:
# Return state as unstructured dict per standard API
return {
"task_id": self.task_id,
"step": self.step_count,
"current_score": self.current_score,
"schema_dump": self.get_schema_dump()
}
# Global instance
env_instance = SQLEnvironment()
class ResetRequest(BaseModel):
task_id: int = 1
@app.post("/reset", response_model=ResetResult)
async def reset(request: Request):
task_id = 1
try:
data = await request.json()
if "task_id" in data:
task_id = int(data["task_id"])
except:
pass # gracefully handle missing json body (curl -d '{}')
async with db_semaphore:
try:
logger.info(f"Resetting environment for task_id: {task_id}")
return env_instance.reset(task_id=task_id)
except Exception as e:
logger.error(f"Reset Error: {str(e)}")
raise HTTPException(status_code=400, detail="Error during reset")
@app.post("/step", response_model=StepResult)
async def step(action: Action, token: str = Depends(verify_auth_and_rate_limit)):
async with db_semaphore:
try:
start_t = time.time()
logger.info(f"Attempting query: {action.action_str}")
res = env_instance.step(action)
duration = time.time() - start_t
logger.info(f"Query completed in {duration:.3f}s. Error state: {res.observation.last_action_error}")
return res
except Exception as e:
logger.error(f"Step Error: {str(e)}")
raise HTTPException(status_code=400, detail="Error executing SQL step")
@app.get("/state")
async def state(token: str = Depends(verify_auth_and_rate_limit)):
async with db_semaphore:
return env_instance.state()
def main():
uvicorn.run("server.app:app", host="0.0.0.0", port=7860, reload=False)
if __name__ == "__main__":
main()
|