"""
app.py
------
FastAPI server exposing the OpenEnv HTTP interface.

Endpoints:
  POST /reset             – start a new episode
  POST /step              – take one action
  GET  /state             – inspect internal state (debugging)
  GET  /tasks             – list available tasks
  GET  /health            – liveness probe
  GET  /action_space      – action space description for a task
  GET  /observation_space – observation space description

Sessions are keyed by the `session_id` query parameter (any string; a UUID
is a good choice). If omitted, "default" is used (fine for sequential
single-agent runs). Sessions live in process memory and are only ever
replaced by a new /reset, never removed.
"""

from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Query
from pydantic import BaseModel, Field

from env.schemas import Action, ActionType, TaskInfo
from tasks.task1 import Task1Environment
from tasks.task2 import Task2Environment
from tasks.task3 import Task3Environment

# ─────────────────────────────────────────────────────────────────────────────
# App
# ─────────────────────────────────────────────────────────────────────────────

app = FastAPI(
    title="Smart Contract Audit RL Environment",
    description=(
        "OpenEnv-compliant reinforcement learning environment for smart contract "
        "security analysis. Train and evaluate agents on real-world Solidity audit tasks."
    ),
    version="1.2.0",
)

# ─────────────────────────────────────────────────────────────────────────────
# Session management
# ─────────────────────────────────────────────────────────────────────────────

# session_id -> task environment instance (one of the TASK_ENV_MAP classes).
# Entries are overwritten by /reset and never deleted.
_sessions: Dict[str, Any] = {}
DEFAULT_SESSION = "default"

TASK_ENV_MAP = {
    "task1_vuln_detection": Task1Environment,
    "task2_property_discovery": Task2Environment,
    "task3_rule_checker": Task3Environment,
}


def _create_env(task_id: str):
    """Instantiate a fresh environment for *task_id*.

    Raises:
        HTTPException: 400 if *task_id* is not a key of TASK_ENV_MAP.
    """
    cls = TASK_ENV_MAP.get(task_id)
    if cls is None:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task_id '{task_id}'. Available: {list(TASK_ENV_MAP)}",
        )
    return cls()


def _get_env(session_id: str):
    """Return the environment registered for *session_id*.

    Raises:
        HTTPException: 400 if /reset has not been called for this session.
    """
    env = _sessions.get(session_id)
    if env is None:
        raise HTTPException(
            status_code=400,
            detail=f"No active session '{session_id}'. Call /reset first.",
        )
    return env


# ─────────────────────────────────────────────────────────────────────────────
# Request bodies
# ─────────────────────────────────────────────────────────────────────────────

class ResetRequest(BaseModel):
    """JSON body of POST /reset."""

    task_id: str = "task1_vuln_detection"
    seed: Optional[int] = None


class StepRequest(BaseModel):
    """JSON body of POST /step."""

    # Raw action name; validated against ActionType inside the /step handler.
    action_type: str
    # default_factory gives each request its own dict explicitly.
    params: dict = Field(default_factory=dict)


# ─────────────────────────────────────────────────────────────────────────────
# Routes
# ─────────────────────────────────────────────────────────────────────────────

@app.get("/health")
def health():
    """Liveness probe."""
    # Read the version from the app object so it cannot drift from the
    # version= declared on the FastAPI constructor.
    return {"status": "ok", "version": app.version}


@app.get("/tasks")
def list_tasks():
    """List all tasks with their status."""
    tasks = [
        TaskInfo(
            task_id="task1_vuln_detection",
            name="Targeted Vulnerability Detection",
            difficulty="medium",
            description="Given a Solidity contract, identify the vulnerable function and describe the vulnerability type in 2-3 words.",
            status="active",
        ),
        TaskInfo(
            task_id="task2_property_discovery",
            name="Property Discovery",
            difficulty="hard",
            description="Given a Solidity function, write the natural-language property that describes its correct behaviour.",
            status="active",
        ),
        TaskInfo(
            task_id="task3_rule_checker",
            name="Rule Checker",
            difficulty="easy",
            description="Given a property in English and a Solidity contract, identify which function violates that property.",
            status="active",
        ),
    ]
    return {"tasks": [t.model_dump() for t in tasks]}


@app.post("/reset")
def reset(
    body: ResetRequest,
    session_id: str = Query(default=DEFAULT_SESSION),
):
    """Reset the environment and start a new episode.

    A new environment for ``body.task_id`` is created and reset with
    ``body.seed``. It replaces any existing environment for ``session_id``
    only after the reset succeeds.

    Raises:
        HTTPException: 400 for an unknown task_id.
    """
    env = _create_env(body.task_id)
    result = env.reset(seed=body.seed)
    # Register only after a successful reset, so an exception here neither
    # discards the session's previous env nor leaves an un-reset one behind.
    _sessions[session_id] = env
    return result.model_dump()


@app.post("/step")
def step(
    body: StepRequest,
    session_id: str = Query(default=DEFAULT_SESSION),
):
    """Apply one action and advance the episode.

    Raises:
        HTTPException: 400 for an unknown session or action_type;
            409 when the environment's ``step`` raises RuntimeError.
    """
    env = _get_env(session_id)
    try:
        action_type = ActionType(body.action_type)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown action_type '{body.action_type}'. Valid: {[a.value for a in ActionType]}",
        ) from None
    action = Action(action_type=action_type, params=body.params)
    try:
        result = env.step(action)
    except RuntimeError as e:
        raise HTTPException(status_code=409, detail=str(e)) from e
    return result.model_dump()


@app.get("/state")
def state(session_id: str = Query(default=DEFAULT_SESSION)):
    """Return internal state for debugging (not for agents).

    Raises:
        HTTPException: 400 if the session has not been reset.
    """
    return _get_env(session_id).state().model_dump()


# Static action-space descriptions, keyed by task_id. "reward" is either a
# fixed per-call cost (float) or a human-readable outcome string.
_ACTION_SPACES: Dict[str, List[Dict[str, Any]]] = {
    "task1_vuln_detection": [
        {"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
        {"type": "get_function_code", "params": {"function_name": "string"}, "reward": "+0.05 (target) / -0.10 (other)", "description": "Get full Solidity source of a function"},
        {"type": "get_function_summary", "params": {"function_name": "string"}, "reward": "+0.03 (target) / -0.05 (other)", "description": "Get NatSpec comment of a function"},
        {"type": "get_file_metadata", "params": {}, "reward": -0.04, "description": "Get contract-level metadata"},
        {"type": "get_state_variable", "params": {"variable_name": "string (optional)"}, "reward": -0.05, "description": "Get a state variable or list all"},
        {"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
        {"type": "submit", "params": {"function_name": "str", "vulnerability_type": "str"}, "reward": "+5.0 / +1.0 / -1.5", "description": "Submit answer. Ends episode."},
    ],
    "task2_property_discovery": [
        {"type": "get_function_code", "params": {}, "reward": -0.06, "description": "Read full source of the target function"},
        {"type": "get_function_natspec", "params": {}, "reward": -0.08, "description": "Read NatSpec + expected behaviour"},
        {"type": "get_file_natspec", "params": {}, "reward": -0.03, "description": "Read contract-level NatSpec"},
        {"type": "get_related_functions", "params": {}, "reward": -0.06, "description": "List caller/callee functions with summaries"},
        {"type": "get_signature", "params": {}, "reward": -0.04, "description": "Get structured I/O + expected behaviour"},
        {"type": "get_similar_rule", "params": {}, "reward": -0.20, "description": "Get a similar property from another contract"},
        {"type": "submit_property", "params": {"property": "string"}, "reward": "0.0–5.0 (scored)", "description": "Submit property. ONE attempt. Ends episode."},
    ],
    "task3_rule_checker": [
        {"type": "list_functions", "params": {}, "reward": -0.05, "description": "List all function names"},
        {"type": "get_function_metadata", "params": {"function_name": "string"}, "reward": -0.05, "description": "Get signature, visibility, params of a function"},
        {"type": "get_function_code", "params": {"function_name": "string"}, "reward": -0.10, "description": "Read full Solidity source of a function"},
        {"type": "get_state_variable", "params": {"variable_name": "string (opt)"}, "reward": -0.05, "description": "Get a state variable or list all"},
        {"type": "get_call_graph", "params": {}, "reward": -0.08, "description": "Get function call graph"},
        {"type": "get_property_specification", "params": {}, "reward": -0.03, "description": "Get formal pre/post-condition for the property"},
        {"type": "submit_function", "params": {"function_name": "string"}, "reward": "+5.0 / +1.5 / -1.5", "description": "Submit answer. ONE attempt. Ends episode."},
    ],
}


@app.get("/action_space")
def action_space(task_id: str = "task1_vuln_detection"):
    """Describe the action space for a task."""
    actions = _ACTION_SPACES.get(task_id)
    if actions is None:
        # NOTE(review): unknown task ids answer HTTP 200 with an "error" key,
        # unlike /reset which raises 400; kept for backward compatibility.
        return {"error": f"No action space defined for task '{task_id}'"}
    return {"task_id": task_id, "actions": actions}


@app.get("/observation_space")
def observation_space():
    """Describe the observation space (same for all tasks)."""
    return {
        "type": "object",
        "fields": {
            "task_id": "string – active task identifier",
            "contract_name": "string – Solidity contract name",
            "contract_description": "string – what the contract does",
            "available_actions": "list[string] – valid action types for this task",
            "last_action": "string|null – previous action type",
            "last_action_result": "string|null – human-readable result of last action",
            "step_count": "int – steps taken in this episode",
            "cumulative_reward": "float – running reward total",
            "done": "bool – True when episode is over",
            "extra": "object – task-specific hints (target_function, hint, etc.)",
        },
    }


# ─────────────────────────────────────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)