"""
environment.py (Task 1 – Targeted Vulnerability Detection)
------------------------------------------------------------
Full OpenEnv-compliant environment.

Episode flow:
  1. reset() selects a random (contract, vulnerable_function) pair.
  2. The agent receives an Observation with the contract description.
  3. The agent uses actions to explore the contract (each costs a small penalty).
  4. When the agent submits, the Grader scores the answer and the episode ends.
     If the step budget runs out before a submission, the episode also ends.
"""

from __future__ import annotations

from math import floor, log2
import random
from typing import Any, Dict, List, Optional, Set

from data.data_loader import load_contracts, sample_episode
from env.base_env import BaseEnv
from env.schemas import (
    Action,
    ActionType,
    Observation,
    Reward,
    ResetResult,
    StateResult,
    StepResult,
)
from server.tasks.task1 import actions
from .grader import Task1Grader

TASK_ID = "task1_vuln_detection"

# Default maximum number of step() calls per episode.
DEFAULT_MAX_STEPS = 40

AVAILABLE_ACTIONS = [
    ActionType.LIST_FUNCTIONS,
    ActionType.GET_FUNCTION_CODE,
    ActionType.GET_FUNCTION_SUMMARY,
    ActionType.GET_FILE_METADATA,
    ActionType.GET_STATE_VARIABLE,
    ActionType.GET_CALL_GRAPH,
    ActionType.SUBMIT,
]


class Task1Environment(BaseEnv):
    """Task 1: Targeted Vulnerability Detection.

    Args:
        contracts_path: Optional path forwarded to ``load_contracts``; when
            omitted (or empty), ``load_contracts()`` is called with no argument.
        max_steps: Maximum number of ``step()`` calls per episode. When the
            budget is used up without the episode finishing, the episode is
            marked done on that final step.

    Raises:
        ValueError: if ``max_steps`` is smaller than 1.
    """

    def __init__(
        self,
        contracts_path: Optional[str] = None,
        max_steps: int = DEFAULT_MAX_STEPS,
    ) -> None:
        if max_steps < 1:
            raise ValueError(f"max_steps must be >= 1, got {max_steps}")
        self._contracts = (
            load_contracts(contracts_path) if contracts_path else load_contracts()
        )
        self._rng = random.Random()
        self._max_steps: int = max_steps

        # Episode state (initialised by reset)
        self._contract: Dict[str, Any] = {}
        self._target_fn: Dict[str, Any] = {}
        self._grader: Optional[Task1Grader] = None
        self._step_count: int = 0
        # Running sum of per-step reward values (reported as
        # "cumulative_reward"). The misspelled name is kept because action
        # handlers receive this instance and may access it.
        # NOTE(review): confirm no handler uses it before renaming.
        self._cummulative_cost: float = 0.0
        self._done: bool = False
        self._query_history: List[str] = []
        # Keys already queried this episode; see _query_key/_is_repeated.
        self._seen_queries: Set[str] = set()

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    def reset(self, seed: Optional[int] = None) -> ResetResult:
        """Start a new episode by sampling a random vulnerable function.

        Args:
            seed: If given, reseeds the environment RNG before sampling so the
                episode selection is reproducible.

        Returns:
            ResetResult holding the initial observation and
            ``info={"task_id": TASK_ID}``.
        """
        if seed is not None:
            self._rng.seed(seed)

        self._contract, self._target_fn = sample_episode(self._contracts, self._rng)
        # n = floor(log2(number of functions in the contract)); how n is used
        # is up to Task1Grader. log2 raises ValueError on 0 — presumably
        # sample_episode never returns a contract without functions.
        self._grader = Task1Grader(
            target_function=self._target_fn["name"],
            vulnerability_issue=self._target_fn["vulnerability_details"]["issue"],
            n=floor(log2(len(self._contract["functions"]))),
        )
        self._step_count = 0
        self._cummulative_cost = 0.0
        self._done = False
        self._query_history = []
        self._seen_queries = set()

        obs = self._build_observation(
            last_action=None,
            last_result=(
                f"New episode started. Contract: {self._contract['contract_name']}. "
                f"Use 'list_functions' to explore the contract."
            ),
        )
        return ResetResult(observation=obs, info={"task_id": TASK_ID})

    def step(self, action: Action) -> StepResult:
        """Execute one agent action.

        Args:
            action: The action to dispatch to its handler.

        Returns:
            StepResult with the new observation, the step reward, the done
            flag, and ``info`` containing the step number and cumulative reward.

        Raises:
            RuntimeError: if reset() has not been called yet, the episode is
                already done, or the step budget has been exhausted.
        """
        if self._grader is None:
            raise RuntimeError("No active episode. Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
        # ``>=`` so that at most ``max_steps`` actions run per episode.
        # Defensive: normally unreachable, since the episode is marked done
        # below as soon as the budget is used up.
        if self._step_count >= self._max_steps:
            raise RuntimeError(
                "Exceeded maximum number of steps allowed. "
                "Call reset() to start a new episode."
            )

        self._step_count += 1
        result_text, reward = self._dispatch(action)
        self._cummulative_cost += reward.value
        self._query_history.append(f"[{action.action_type}] → {result_text[:200]}")

        # Budget exhausted without the episode ending: mark it done now so the
        # returned StepResult tells the agent to stop, instead of the agent's
        # next step() call failing.
        if not self._done and self._step_count >= self._max_steps:
            self._done = True

        obs = self._build_observation(
            last_action=action.action_type,
            last_result=result_text,
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={
                "step": self._step_count,
                "cumulative_reward": self._cummulative_cost,
            },
        )

    def state(self) -> StateResult:
        """Return a snapshot of the current episode state.

        Safe to call before reset(): contract and target names then default
        to empty strings.
        """
        return StateResult(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            target_function=self._target_fn.get("name", ""),
            step_count=self._step_count,
            cumulative_reward=self._cummulative_cost,
            done=self._done,
            query_history=list(self._query_history),
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_observation(
        self,
        last_action: Optional[str],
        last_result: str,
    ) -> Observation:
        """Build the Observation sent to the agent after reset/step.

        Args:
            last_action: The action type just executed, or None after reset.
            last_result: Text result of that action.
        """
        return Observation(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            last_action=last_action,
            last_action_result=last_result,
            done=self._done,
            extra={
                "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
                "hint": (
                    "Identify the vulnerable function and its issue. "
                    "Submit with action_type='submit', params={'function_name': '...', "
                    "'vulnerability_type': '...'}"
                ),
            },
        )

    def _query_key(self, action_type: str, params: Dict[str, Any]) -> str:
        """Build a hashable key for repeated-query detection.

        Params are sorted by key so that the same parameters given in a
        different order produce the same key.
        """
        return f"{action_type}:{sorted(params.items())}"

    def _is_repeated(self, key: str) -> bool:
        """Return True if ``key`` was already seen this episode.

        Otherwise record ``key`` and return False.
        """
        if key in self._seen_queries:
            return True
        self._seen_queries.add(key)
        return False

    def _dispatch(self, action: Action) -> tuple[str, Reward]:
        """Route ``action`` to its handler in the ``actions`` module.

        Every handler is called as ``handler(self, query_key, params)`` and
        returns ``(result_text, reward)``. Unknown action types go to
        ``actions.unknown_action``, which also receives the action type.
        """
        at = action.action_type
        params = action.params
        qkey = self._query_key(at, params)

        # Mapping from ActionType to handler function
        handlers = {
            ActionType.LIST_FUNCTIONS: actions.list_functions,
            ActionType.GET_FUNCTION_CODE: actions.get_function_code,
            ActionType.GET_FUNCTION_SUMMARY: actions.get_function_summary,
            ActionType.GET_FILE_METADATA: actions.get_file_metadata,
            ActionType.GET_STATE_VARIABLE: actions.get_state_variable,
            ActionType.GET_CALL_GRAPH: actions.get_call_graph,
            ActionType.SUBMIT: actions.submit,
        }

        handler = handlers.get(at)
        if handler is None:
            return actions.unknown_action(self, qkey, params, at)
        return handler(self, qkey, params)