"""
environment.py (Task 1 – Targeted Vulnerability Detection)
------------------------------------------------------------
Full OpenEnv-compliant environment.
Episode flow:
1. reset() selects a random (contract, vulnerable_function) pair.
2. The agent receives an Observation with the contract description.
3. The agent uses actions to explore the contract (each costs a small penalty).
4. When the agent submits, the Grader scores the answer and the episode ends.
"""
from __future__ import annotations
from math import floor, log2
import random
from typing import Any, Dict, List, Optional, Set
from data.data_loader import load_contracts, sample_episode
from env.base_env import BaseEnv
from env.schemas import (
Action,
ActionType,
Observation,
Reward,
ResetResult,
StateResult,
StepResult,
)
from server.tasks.task1 import actions
from .grader import Task1Grader
# Task identifier reported in reset info, observations and state results.
TASK_ID = "task1_vuln_detection"

# Action types this task accepts. The same seven entries appear as keys of
# the handler table in Task1Environment._dispatch.
AVAILABLE_ACTIONS: List[ActionType] = [
    ActionType.LIST_FUNCTIONS,
    ActionType.GET_FUNCTION_CODE,
    ActionType.GET_FUNCTION_SUMMARY,
    ActionType.GET_FILE_METADATA,
    ActionType.GET_STATE_VARIABLE,
    ActionType.GET_CALL_GRAPH,
    ActionType.SUBMIT,
]
class Task1Environment(BaseEnv):
    """Task 1: Targeted Vulnerability Detection.

    Each episode targets one (contract, function) pair chosen by
    ``sample_episode``. ``step`` forwards every action to a handler in the
    ``actions`` module, adds the returned reward value to a running total and
    records a truncated copy of the result in the query history.

    Handlers are called with this environment instance as their first
    argument, so they can read and modify the episode state held here.
    """

    def __init__(self, contracts_path: Optional[str] = None) -> None:
        """Load the contracts and initialise empty episode state.

        Args:
            contracts_path: Passed to ``load_contracts`` when truthy; when
                ``None`` or empty, ``load_contracts`` is called without
                arguments.
        """
        self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
        self._rng = random.Random()
        self._max_steps: int = 40

        # Episode state (initialised by reset)
        self._contract: Dict[str, Any] = {}
        self._target_fn: Dict[str, Any] = {}
        self._grader: Optional[Task1Grader] = None
        self._step_count: int = 0
        # NOTE: the attribute name keeps its original spelling because the
        # handlers receive this instance and may read it.
        self._cummulative_cost: float = 0.0
        self._done: bool = False
        self._query_history: List[str] = []
        self._seen_queries: Set[str] = set()

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    def reset(self, seed: Optional[int] = None) -> ResetResult:
        """Start a new episode by sampling a random vulnerable function.

        Args:
            seed: When not ``None``, re-seeds the environment's RNG before
                sampling.

        Returns:
            A ``ResetResult`` holding the initial observation and an info
            dict with the task id.
        """
        if seed is not None:
            self._rng.seed(seed)

        self._contract, self._target_fn = sample_episode(self._contracts, self._rng)

        # log2(0) raises ValueError; clamping the count to 1 makes an empty
        # function list yield n = 0 instead of crashing the reset.
        num_functions = max(1, len(self._contract["functions"]))
        self._grader = Task1Grader(
            target_function=self._target_fn["name"],
            vulnerability_issue=self._target_fn["vulnerability_details"]["issue"],
            n=floor(log2(num_functions)),
        )

        self._step_count = 0
        self._cummulative_cost = 0.0
        self._done = False
        self._query_history = []
        self._seen_queries = set()

        obs = self._build_observation(
            last_action=None,
            last_result=(
                f"New episode started. Contract: {self._contract['contract_name']}. "
                f"Use 'list_functions' to explore the contract."
            ),
        )
        return ResetResult(observation=obs, info={"task_id": TASK_ID})

    def step(self, action: Action) -> StepResult:
        """Execute one agent action.

        Args:
            action: The action to run; its ``action_type`` selects the
                handler and its ``params`` are passed through to it.

        Returns:
            A ``StepResult`` with the new observation, the reward returned by
            the handler, the done flag, and an info dict carrying the step
            number and the running reward total.

        Raises:
            RuntimeError: If the episode is already done, or if
                ``_max_steps`` actions have already been executed.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
        # The counter is incremented after this check, so `>=` is what caps
        # an episode at exactly `_max_steps` actions (`>` allowed one extra).
        if self._step_count >= self._max_steps:
            raise RuntimeError("Exceeded maximum number of steps allowed. Call reset() to start a new episode.")

        self._step_count += 1
        result_text, reward = self._dispatch(action)
        self._cummulative_cost += reward.value
        # Only the first 200 characters of the result are kept in history.
        self._query_history.append(f"[{action.action_type}] → {result_text[:200]}")

        obs = self._build_observation(
            last_action=action.action_type,
            last_result=result_text,
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={
                "step": self._step_count,
                "cumulative_reward": self._cummulative_cost,
            },
        )

    def state(self) -> StateResult:
        """Return a snapshot of the current episode state.

        The snapshot includes the target function name and a copy of the
        query history, so mutating the returned list does not affect the
        environment.
        """
        return StateResult(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            target_function=self._target_fn.get("name", ""),
            step_count=self._step_count,
            cumulative_reward=self._cummulative_cost,
            done=self._done,
            query_history=list(self._query_history),
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_observation(
        self,
        last_action: Optional[str],
        last_result: str,
    ) -> Observation:
        """Assemble the observation returned to the agent.

        Args:
            last_action: The action type just executed, or ``None`` right
                after a reset.
            last_result: Text describing the outcome of that action.

        Returns:
            An ``Observation`` with the contract name, the last action and
            its result, the done flag, and an ``extra`` dict holding the
            Solidity version (empty string when absent) and a submit hint.
        """
        return Observation(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            last_action=last_action,
            last_action_result=last_result,
            done=self._done,
            extra={
                "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
                "hint": (
                    "Identify the vulnerable function and its issue. "
                    "Submit with action_type='submit', params={'function_name': '...', "
                    "'vulnerability_type': '...'}"
                ),
            },
        )

    def _query_key(self, action_type: str, params: Dict[str, Any]) -> str:
        """Build a hashable key for repeated-query detection.

        Parameter items are sorted so the key does not depend on the
        insertion order of ``params``.
        """
        return f"{action_type}:{sorted(params.items())}"

    def _is_repeated(self, key: str) -> bool:
        """Return True if ``key`` was seen before; otherwise record it.

        The first call with a given key returns ``False`` and stores the key;
        later calls with the same key return ``True``.
        """
        if key in self._seen_queries:
            return True
        self._seen_queries.add(key)
        return False

    def _dispatch(self, action: Action) -> tuple[str, Reward]:
        """Route ``action`` to its handler in the ``actions`` module.

        Known action types are called as ``handler(self, qkey, params)``;
        any other type goes to ``actions.unknown_action``, which additionally
        receives the unrecognised action type.

        Returns:
            Whatever the handler returns; ``step`` unpacks it as a
            ``(result_text, reward)`` pair.
        """
        at = action.action_type
        params = action.params
        qkey = self._query_key(at, params)

        # Mapping from ActionType to handler function
        handlers = {
            ActionType.LIST_FUNCTIONS: actions.list_functions,
            ActionType.GET_FUNCTION_CODE: actions.get_function_code,
            ActionType.GET_FUNCTION_SUMMARY: actions.get_function_summary,
            ActionType.GET_FILE_METADATA: actions.get_file_metadata,
            ActionType.GET_STATE_VARIABLE: actions.get_state_variable,
            ActionType.GET_CALL_GRAPH: actions.get_call_graph,
            ActionType.SUBMIT: actions.submit,
        }

        handler = handlers.get(at)
        if handler is None:
            return actions.unknown_action(self, qkey, params, at)
        return handler(self, qkey, params)