"""
data_loader.py
--------------
Loads and queries smart contract data from JSON files.

Each contract is a dict parsed from the JSON dataset. The (contract, function)
pairs whose function is flagged ``vulnerable`` form the episode pool used by
Task 1. Lookups by name are simple linear scans over the parsed lists.
"""

import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple

# Resolve next to this file; absolute so a later os.chdir() cannot break the
# default paths (os.path.dirname(__file__) may be relative or even "").
DATA_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
DEFAULT_VULNERABILITIES_FILE = os.path.join(DATA_DIR, "vulnerabilities.json")
# Backward-compatible alias for the original (misspelled) constant name.
DEFAULT_VUNERABILITIES_FILE = DEFAULT_VULNERABILITIES_FILE


def _load_json(path: str) -> Any:
    """Parse and return the JSON document stored at *path*, decoded as UTF-8."""
    # Explicit encoding: JSON is UTF-8 by spec, while the platform default
    # (e.g. cp1252 on Windows) would mis-decode or reject non-ASCII text.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
    """Load and return all contracts from the JSON dataset at *path*."""
    return _load_json(path)


def load_vulnerabilities(
    path: str = DEFAULT_VULNERABILITIES_FILE,
) -> List[Dict[str, Any]]:
    """Load and return all vulnerability entries from the JSON dataset at *path*."""
    return _load_json(path)


def get_all_vulnerable_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """
    Return a flat list of (contract, function) pairs whose function is vulnerable.

    A function counts as vulnerable when its ``"vulnerable"`` key is truthy;
    contracts without a ``"functions"`` key contribute nothing. Order follows
    the input: contracts first, then functions within each contract.
    Used by Task 1 to populate the episode pool.
    """
    return [
        (contract, fn)
        for contract in contracts
        for fn in contract.get("functions", [])
        if fn.get("vulnerable", False)
    ]


def sample_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Randomly select one (contract, vulnerable_function) pair.

    Args:
        contracts: Parsed contract dicts, e.g. from ``load_contracts()``.
        rng: Random source; pass a seeded ``random.Random`` for reproducible
            episodes. A fresh unseeded ``random.Random()`` is used when None.

    Returns:
        The contract dict and the target function dict.

    Raises:
        ValueError: If no contract has a vulnerable function.
    """
    if rng is None:
        rng = random.Random()
    entries = get_all_vulnerable_entries(contracts)
    if not entries:
        raise ValueError("No vulnerable functions found in dataset.")
    return rng.choice(entries)


def _find_by_name(
    items: List[Dict[str, Any]], name: str
) -> Optional[Dict[str, Any]]:
    """Return the first item whose ``"name"`` matches *name* case-insensitively."""
    target = name.lower()  # hoisted: computed once, not per item
    for item in items:
        if item["name"].lower() == target:
            return item
    return None


def get_function_by_name(
    contract: Dict[str, Any], name: str
) -> Optional[Dict[str, Any]]:
    """Case-insensitive function lookup within a contract; None if absent."""
    return _find_by_name(contract.get("functions", []), name)


def get_state_variable_by_name(
    contract: Dict[str, Any], name: str
) -> Optional[Dict[str, Any]]:
    """Case-insensitive state variable lookup within a contract; None if absent."""
    return _find_by_name(contract.get("state_variables", []), name)


def list_function_names(contract: Dict[str, Any]) -> List[str]:
    """Return all function names in the contract, in declaration order."""
    return [fn["name"] for fn in contract.get("functions", [])]


def list_state_variable_names(contract: Dict[str, Any]) -> List[str]:
    """Return all state variable names in the contract, in declaration order."""
    return [sv["name"] for sv in contract.get("state_variables", [])]