File size: 10,932 Bytes
de07414
 
7d84930
 
 
 
de07414
7d84930
 
 
a2db70a
de07414
 
 
 
 
 
7d84930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de07414
7d84930
 
a2db70a
 
 
 
 
7d84930
 
 
 
 
 
 
 
 
 
 
 
de07414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d84930
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
de07414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d84930
 
 
 
 
 
 
de07414
 
 
7d84930
de07414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1314b5a
 
de07414
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a2db70a
 
 
 
 
 
 
 
 
7d84930
 
a2db70a
 
7d84930
 
 
de07414
 
7d84930
 
 
 
 
 
 
 
 
 
 
 
de07414
 
7d84930
 
 
de07414
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import sqlite3
import os
import logging
import re
import time
import asyncio
from typing import Dict, Any, Optional
from fastapi import FastAPI, HTTPException, Request, Depends
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from starlette.exceptions import HTTPException as StarletteHTTPException
from pydantic import BaseModel
import uvicorn

from .models import Action, Observation, Reward, StepResult, ResetResult
from .tasks import TASKS

# -- Security & Telemetry Setup --
class SecretMaskingFormatter(logging.Formatter):
    """Log formatter that redacts bearer tokens and the HF_TOKEN secret."""

    # Pattern for "Bearer <token>" credentials; compiled once at class load.
    _BEARER_RE = re.compile(r'Bearer\s+[A-Za-z0-9_\-]+')

    def format(self, record):
        """Render the record, then scrub any credentials from the output."""
        rendered = self._BEARER_RE.sub('Bearer ***', super().format(record))
        secret = os.environ.get("HF_TOKEN")
        if secret:
            # str.replace is a no-op when the secret is absent.
            rendered = rendered.replace(secret, "***")
        return rendered

# Module-level audit logger: every request/query log line passes through the
# masking formatter so bearer tokens and HF_TOKEN never reach the log stream.
logger = logging.getLogger("agent_audit")
logger.setLevel(logging.INFO)
ch = logging.StreamHandler()
ch.setFormatter(SecretMaskingFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
logger.addHandler(ch)

# Bearer-token auth scheme; the expected key defaults to a test value when
# AGENT_API_KEY is not set in the environment.
security = HTTPBearer()
AGENT_API_KEY = os.environ.get("AGENT_API_KEY", "test-agent-key")

class TokenBucket:
    """Wall-clock token-bucket rate limiter (single-threaded use only)."""

    def __init__(self, capacity: int, fill_rate: float):
        # capacity: maximum tokens the bucket holds (also the starting level).
        # fill_rate: tokens regenerated per second of elapsed time.
        self.capacity = capacity
        self.fill_rate = fill_rate
        self.tokens = capacity
        self.last_fill = time.time()

    def consume(self, tokens: int = 1) -> bool:
        """Refill for elapsed time, then try to take `tokens`; True on success."""
        now = time.time()
        refilled = self.tokens + (now - self.last_fill) * self.fill_rate
        self.tokens = min(self.capacity, refilled)
        self.last_fill = now
        if self.tokens < tokens:
            return False
        self.tokens -= tokens
        return True

# Global limiter shared by all authenticated endpoints: a 50-request burst,
# refilling at 50 tokens per minute.
rate_limiter = TokenBucket(capacity=50, fill_rate=50.0/60.0)

async def verify_auth_and_rate_limit(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """FastAPI dependency: reject invalid bearer tokens (401) and enforce the
    global rate limit (429); returns the validated token string on success."""
    if credentials.credentials != AGENT_API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    if not rate_limiter.consume(1):
        raise HTTPException(status_code=429, detail="Rate limit exceeded")
    return credentials.credentials

# Application object plus a semaphore that bounds concurrent access to the
# shared environment/SQLite connection to 5 requests at a time.
app = FastAPI(title="OpenEnv SQL Data Engineer")
db_semaphore = asyncio.Semaphore(5)

@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: log the error, return an opaque 500 (no trace leak)."""
    logger.error(f"Unhandled exception: {str(exc)}")
    return JSONResponse(status_code=500, content={"error": "Internal Server Error"})

@app.middleware("http")
async def security_middleware(request: Request, call_next):
    """Reject oversized payloads (>1 MiB) and bound request handling to 10s.

    Fix: a malformed (non-integer) Content-Length header previously raised
    ValueError from int(), surfacing as a 500; it now returns a clean 400.
    """
    content_length = request.headers.get("content-length")
    if content_length:
        try:
            length = int(content_length)
        except ValueError:
            return JSONResponse(status_code=400, content={"error": "Invalid Content-Length"})
        if length > 1048576:
            return JSONResponse(status_code=413, content={"error": "Payload Too Large"})

    try:
        # Hard cap on total handling time for any single request.
        return await asyncio.wait_for(call_next(request), timeout=10.0)
    except asyncio.TimeoutError:
        return JSONResponse(status_code=504, content={"error": "Gateway Timeout"})

class SQLEnvironment:
    """Single-task SQLite RL environment.

    ``reset()`` builds a fresh task database (with query timeouts and a DDL
    authorizer), ``step()`` executes one SQL action and returns a dense
    reward equal to the change in the task's grader score, and ``state()``
    returns an unstructured snapshot.
    """

    def __init__(self):
        self.conn: Optional[sqlite3.Connection] = None   # live DB connection, None until reset()
        self.task_id = 1           # currently-active task key into TASKS
        self.step_count = 0        # steps taken since last reset
        self.current_score = 0.0   # grader score after the last step/reset
        self.db_path = "tmp_env.db"  # on-disk DB, recreated on every reset

    def get_schema_dump(self) -> str:
        """Return a human-readable dump of all user tables/views.

        Returns "" before reset(), "Database is empty." when there are no
        user objects, or an error string if introspection fails.
        """
        if not self.conn:
            return ""
        try:
            c = self.conn.cursor()
            c.execute("SELECT type, name, sql FROM sqlite_master WHERE type='table' OR type='view'")
            rows = c.fetchall()
            dump = []
            for t, name, sql in rows:
                # Hide SQLite-internal bookkeeping tables from the agent.
                if name.startswith('sqlite_'): continue
                dump.append(f"[{t.upper()}] {name}:\n  {sql}")

            return "\n".join(dump) if dump else "Database is empty."
        except Exception as e:
            return f"Error extracting schema: {e}"

    def reset(self, task_id: int = 1) -> ResetResult:
        """Tear down any previous episode and initialize the given task.

        Raises ValueError for an unknown task_id. Returns the initial
        observation plus info containing the base grader score.
        """
        if self.conn:
            self.conn.close()

        # Clean up existing temp db so each episode starts from a known state.
        if os.path.exists(self.db_path):
            os.remove(self.db_path)

        self.task_id = task_id
        if self.task_id not in TASKS:
            raise ValueError(f"Task ID {self.task_id} not found.")

        self.conn = sqlite3.connect(self.db_path)
        self.step_count = 0
        self.current_score = 0.0

        # Security: 5 second timeout on queries — the progress handler is
        # polled every 1000 VM ops and aborts the statement by returning 1.
        self.query_start_time = 0
        def progress_handler():
            if time.time() - self.query_start_time > 5.0:
                return 1 # Abort query
            return 0
        self.conn.set_progress_handler(progress_handler, 1000)

        # Security: Simple Authorizer to block DROP TABLE and explicit destructive system mods
        # We allow standard DDL since the agent gets asked to CREATE VIEW, but restrict DROP
        def authorizer(action_code, arg1, arg2, dbname, source):
            # 11 = SQLITE_DROP_TABLE, 16 = SQLITE_DROP_VIEW
            # 17 = SQLITE_ATTACH (blocks ATTACH DATABASE)
            if action_code in (11, 16, 17):
                return sqlite3.SQLITE_DENY
            return sqlite3.SQLITE_OK
        self.conn.set_authorizer(authorizer)

        # Initialize standard SQLite settings
        self.conn.execute("PRAGMA foreign_keys = ON")

        # Setup specific task data
        task = TASKS[self.task_id]
        task.setup_db(self.conn)
        self.current_score = task.grade(self.conn) # Base score

        goal_text = task.get_goal()
        # Add basic info about actual task goal
        instructions = f"Task Goal: {goal_text}\n"

        obs = Observation(
            goal=instructions,
            result="Environment initialized. Schema ready.",
            step=self.step_count,
            last_action_error=False,
            schema_dump=self.get_schema_dump()
        )
        return ResetResult(observation=obs, info={"task_id": self.task_id, "initial_score": self.current_score})

    def step(self, action: Action) -> StepResult:
        """Execute one SQL action, grade the DB, and return a dense reward.

        Reward = (new score - previous score), minus 0.05 on a SQL error.
        The episode ends when the score reaches 0.99 or after 30 steps.
        Raises ValueError if called before reset().
        """
        if not self.conn:
            raise ValueError("Environment not initialized. Call reset() first.")

        self.step_count += 1
        last_action_error = False
        query_result = ""

        try:
            c = self.conn.cursor()
            query = action.action_str.strip()

            # Injection Mitigations at parser level
            blocked_patterns = [r"(?i)DROP\s+DATABASE", r"(?i)pg_sleep", r"(?i)randomblob", r"(?i)ATTACH\s+DATABASE"]
            for p in blocked_patterns:
                if re.search(p, query):
                    raise Exception(f"Blocked destructive command pattern detected: {p}")

            # Fix: compare against the UPPERCASED literal — the original
            # compared query.upper() to "DROP TABLE sqlite_", which could
            # never match, leaving this guard dead.
            if query.upper().startswith("DROP TABLE SQLITE_"):
                raise Exception("Cannot modify system tables.")

            self.query_start_time = time.time()
            c.execute(query)

            if query.upper().startswith("SELECT") or query.upper().startswith("PRAGMA"):
                rows = c.fetchmany(10) # limit output size for LLM observation
                col_names = [description[0] for description in c.description] if c.description else []
                # Format tabular output
                result_str = " | ".join(col_names) + "\n"
                result_str += "-" * len(result_str) + "\n"
                for r in rows:
                    result_str += " | ".join([str(val) for val in r]) + "\n"
                if len(rows) == 10:
                    result_str += "... (output truncated)"
                query_result = result_str
            else:
                self.conn.commit()
                query_result = f"Command executed successfully. Rowcount: {c.rowcount}"

        except Exception as e:
            last_action_error = True
            query_result = f"SQL Error: {str(e)}"

        # Run grader
        task = TASKS[self.task_id]
        new_score = task.grade(self.conn)

        # Reward is dense: change in score + small penalty for errors
        reward_value = new_score - self.current_score

        if last_action_error:
            # minor penalty for syntax errors
            reward_value -= 0.05

        self.current_score = new_score

        # Episode terminates when score is 0.99
        done = (self.current_score >= 0.99) or (self.step_count > 30)

        obs = Observation(
            goal=task.get_goal(),
            result=query_result,
            step=self.step_count,
            last_action_error=last_action_error,
            schema_dump=self.get_schema_dump() if not last_action_error else None # only dump if no error to save tokens
        )

        return StepResult(
            observation=obs,
            reward=reward_value,
            done=done,
            info={"current_score": self.current_score}
        )

    def state(self) -> Any:
        # Return state as unstructured dict per standard API
        return {
            "task_id": self.task_id,
            "step": self.step_count,
            "current_score": self.current_score,
            "schema_dump": self.get_schema_dump()
        }

# Global instance
# Single process-wide environment instance shared by all endpoints; access is
# serialized through db_semaphore in the route handlers.
env_instance = SQLEnvironment()

class ResetRequest(BaseModel):
    # NOTE(review): this model is currently unused — the /reset handler parses
    # the raw request body itself. Confirm whether it should be wired in.
    task_id: int = 1

@app.post("/reset", response_model=ResetResult)
async def reset(request: Request):
    """Reset the environment for the requested task.

    Accepts an optional JSON body {"task_id": int}; a missing or malformed
    body falls back to task 1. Returns 400 on any reset failure.

    NOTE(review): unlike /step and /state this endpoint has no auth/rate-limit
    dependency — confirm whether unauthenticated resets are intentional.
    """
    task_id = 1
    try:
        data = await request.json()
        if "task_id" in data:
            task_id = int(data["task_id"])
    except Exception:
        # Fix: was a bare `except:`, which also swallows CancelledError and
        # SystemExit. Tolerate only ordinary failures (missing/invalid JSON,
        # e.g. curl with no -d body) and keep the default task.
        pass

    async with db_semaphore:
        try:
            logger.info(f"Resetting environment for task_id: {task_id}")
            return env_instance.reset(task_id=task_id)
        except Exception as e:
            logger.error(f"Reset Error: {str(e)}")
            raise HTTPException(status_code=400, detail="Error during reset")

@app.post("/step", response_model=StepResult)
async def step(action: Action, token: str = Depends(verify_auth_and_rate_limit)):
    """Execute one SQL action (auth + rate limited), logging timing and outcome."""
    async with db_semaphore:
        try:
            started = time.time()
            logger.info(f"Attempting query: {action.action_str}")
            outcome = env_instance.step(action)
            elapsed = time.time() - started
            logger.info(f"Query completed in {elapsed:.3f}s. Error state: {outcome.observation.last_action_error}")
            return outcome
        except Exception as e:
            # Any environment failure maps to an opaque 400 for the client.
            logger.error(f"Step Error: {str(e)}")
            raise HTTPException(status_code=400, detail="Error executing SQL step")

@app.get("/state")
async def state(token: str = Depends(verify_auth_and_rate_limit)):
    """Return the current environment snapshot (auth + rate limited)."""
    async with db_semaphore:
        return env_instance.state()

def main():
    """Launch the API server on all interfaces, port 7860 (no auto-reload)."""
    uvicorn.run("server.app:app", host="0.0.0.0", port=7860, reload=False)

if __name__ == "__main__":
    main()