File size: 2,860 Bytes
08c19c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fccda7
08c19c7
 
 
 
 
 
 
8fccda7
 
 
 
 
 
08c19c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
data_loader.py
--------------
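Loads smart contract data from JSON files.
Each contract is returned as the plain dict decoded from JSON; helper
functions look up functions and state variables by name (case-insensitive
linear scans) and collect the vulnerable functions that Task 1 samples from.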
"""

import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple


# Directory containing this module; the default dataset files live beside it.
DATA_DIR = os.path.dirname(__file__)
DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
# NOTE: the misspelled name is kept for backward compatibility with existing
# importers; new code should use the correctly spelled alias below.
DEFAULT_VUNERABILITIES_FILE = os.path.join(DATA_DIR, "vulnerabilities.json")
DEFAULT_VULNERABILITIES_FILE = DEFAULT_VUNERABILITIES_FILE

def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
    """
    Load and return all contracts from the JSON dataset at ``path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    required to use) so the result does not depend on the platform's
    locale-default encoding.

    The decoded top-level JSON value is returned as-is; it is expected
    (not validated here) to be a list of contract dicts.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file is not valid UTF-8 or not valid JSON
            (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def load_vulnerabilities(path: str = DEFAULT_VUNERABILITIES_FILE) -> List[Dict[str, Any]]:
    """
    Load and return all vulnerability entries from the JSON dataset at ``path``.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's locale-default encoding.

    The decoded top-level JSON value is returned as-is; it is expected
    (not validated here) to be a list of vulnerability dicts.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the file is not valid UTF-8 or not valid JSON
            (``json.JSONDecodeError`` / ``UnicodeDecodeError``).
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_all_vulnerable_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """
    Collect every (contract, function) pair whose function is flagged vulnerable.

    A function counts as vulnerable when its ``"vulnerable"`` value is truthy;
    a missing key means not vulnerable, and a contract with no ``"functions"``
    key contributes nothing. Pairs keep dataset order and reference the
    original dicts (no copies). Task 1 builds its episode pool from this list.
    """
    return [
        (contract, fn)
        for contract in contracts
        for fn in contract.get("functions", [])
        if fn.get("vulnerable", False)
    ]


def sample_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Pick one (contract, vulnerable_function) pair uniformly at random.

    Pass a seeded ``random.Random`` as ``rng`` for reproducible sampling;
    when omitted, a fresh ``random.Random()`` is created for this call.

    Raises:
        ValueError: if no contract has a vulnerable function.
    """
    pool = get_all_vulnerable_entries(contracts)
    if not pool:
        raise ValueError("No vulnerable functions found in dataset.")
    chooser = random.Random() if rng is None else rng
    return chooser.choice(pool)


def get_function_by_name(
    contract: Dict[str, Any], name: str
) -> Optional[Dict[str, Any]]:
    """
    Find a function in ``contract`` by name, ignoring case.

    Returns the first matching entry of ``contract["functions"]``, or None
    when nothing matches or the contract has no ``"functions"`` key.
    """
    matches = (
        fn
        for fn in contract.get("functions", [])
        if fn["name"].lower() == name.lower()
    )
    return next(matches, None)


def get_state_variable_by_name(
    contract: Dict[str, Any], name: str
) -> Optional[Dict[str, Any]]:
    """
    Find a state variable in ``contract`` by name, ignoring case.

    Returns the first matching entry of ``contract["state_variables"]``, or
    None when nothing matches or the contract has no ``"state_variables"`` key.
    """
    candidates = contract.get("state_variables", [])
    hit = (sv for sv in candidates if sv["name"].lower() == name.lower())
    return next(hit, None)


def list_function_names(contract: Dict[str, Any]) -> List[str]:
    """Return the ``"name"`` of every function in ``contract``, in dataset order.

    A contract without a ``"functions"`` key yields an empty list.
    """
    functions = contract.get("functions", [])
    return [entry["name"] for entry in functions]


def list_state_variable_names(contract: Dict[str, Any]) -> List[str]:
    """Return the ``"name"`` of every state variable in ``contract``, in dataset order.

    A contract without a ``"state_variables"`` key yields an empty list.
    """
    state_vars = contract.get("state_variables", [])
    return [var["name"] for var in state_vars]