File size: 9,054 Bytes
d57737f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7738e45
 
 
 
d57737f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""
Data loading and caching for Agent World Model environments.

Downloads data from HuggingFace (Snowflake/AgentWorldModel-1K) on first use
and caches it locally. Provides typed accessors for scenarios, tasks, env code,
DB schemas, sample data, and verifiers.
"""

import json
import logging
import os
import re
import threading
from typing import Any

from huggingface_hub import hf_hub_download

logger = logging.getLogger(__name__)

# Hugging Face Hub location of the AWM dataset.
HF_REPO_ID: str = "Snowflake/AgentWorldModel-1K"
HF_REPO_TYPE: str = "dataset"

# Every JSONL file the loader reads; _ensure_downloaded fetches whichever of
# these are missing from the cache directory.
DATA_FILES: list[str] = [
    "gen_scenario.jsonl",  # scenario "name" / "description"
    "gen_tasks.jsonl",  # per-scenario "tasks" list
    "gen_db.jsonl",  # per-scenario "db_schema"
    "gen_sample.jsonl",  # per-scenario "sample_data"
    "gen_envs.jsonl",  # per-scenario "full_code"
    "gen_verifier.jsonl",  # verifiers for mode "sql"
    "gen_verifier.pure_code.jsonl",  # verifiers for mode "code"
]


def _default_cache_dir() -> str:
    return os.environ.get("AWM_DATA_DIR", os.path.expanduser("~/.cache/openenv/awm"))


def normalize_scenario_name(scenario: str) -> str:
    """Return the canonical lookup key for a scenario name.

    The name is lower-cased, every run of characters outside ``[a-z0-9]``
    (underscores included) collapses to a single ``_``, and leading/trailing
    underscores are removed, e.g. ``"  E-Commerce__Store!"`` ->
    ``"e_commerce_store"``.
    """
    collapsed = re.sub(r"[^a-z0-9]+", "_", scenario.lower())
    return collapsed.strip("_")


def _load_jsonl(path: str) -> list[dict]:
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f.readlines() if line.strip()]


def _ensure_downloaded(cache_dir: str) -> None:
    """Fetch from the Hugging Face Hub any ``DATA_FILES`` not yet in *cache_dir*.

    *cache_dir* is created if it does not exist. Presence is judged purely by
    ``os.path.exists`` on each expected path, so a file that already exists is
    never re-fetched or validated. Nothing is logged when no file is missing.
    """
    os.makedirs(cache_dir, exist_ok=True)

    absent = [
        name for name in DATA_FILES if not os.path.exists(f"{cache_dir}/{name}")
    ]
    if absent:
        logger.info(f"Downloading AWM data from {HF_REPO_ID} to {cache_dir} ...")
        for name in absent:
            logger.info(f"  Downloading {name} ...")
            # NOTE(review): assumes local_dir lands the file at
            # f"{cache_dir}/{name}" (the path checked above) — confirm
            # against the huggingface_hub docs.
            hf_hub_download(
                repo_id=HF_REPO_ID,
                repo_type=HF_REPO_TYPE,
                filename=name,
                local_dir=cache_dir,
            )
        logger.info("AWM data download complete.")


class AWMDataLoader:
    """
    Lazy-loading accessor for Agent World Model data files.
    """

    def __init__(self, cache_dir: str | None = None):
        self._cache_dir = cache_dir or _default_cache_dir()
        self._downloaded = False
        self._lock = threading.Lock()

        self._scenarios: dict[str, dict] | None = None
        self._tasks: dict[str, dict] | None = None
        self._envs: dict[str, dict] | None = None
        self._db_schemas: dict[str, dict] | None = None
        self._samples: dict[str, dict] | None = None
        self._verifiers: dict[str, dict[str, list[dict]]] | None = None

    def _ensure_data(self) -> None:
        if not self._downloaded:
            _ensure_downloaded(self._cache_dir)
            self._downloaded = True

    def _build_scenarios(self) -> dict[str, dict]:
        if self._scenarios is None:
            with self._lock:
                if self._scenarios is None:
                    self._ensure_data()
                    raw = _load_jsonl(f"{self._cache_dir}/gen_scenario.jsonl")
                    result: dict[str, dict] = {}
                    for item in raw:
                        key = normalize_scenario_name(item["name"])
                        result[key] = item
                    self._scenarios = result
        return self._scenarios

    def _build_tasks(self) -> dict[str, dict]:
        if self._tasks is None:
            with self._lock:
                if self._tasks is None:
                    self._ensure_data()
                    raw = _load_jsonl(f"{self._cache_dir}/gen_tasks.jsonl")
                    result: dict[str, dict] = {}
                    for item in raw:
                        key = normalize_scenario_name(item["scenario"])
                        result[key] = item
                    self._tasks = result
        return self._tasks

    def _build_envs(self) -> dict[str, dict]:
        if self._envs is None:
            with self._lock:
                if self._envs is None:
                    self._ensure_data()
                    raw = _load_jsonl(f"{self._cache_dir}/gen_envs.jsonl")
                    result: dict[str, dict] = {}
                    for item in raw:
                        key = normalize_scenario_name(item["scenario"])
                        result[key] = item
                    self._envs = result
        return self._envs

    def _build_db_schemas(self) -> dict[str, dict]:
        if self._db_schemas is None:
            with self._lock:
                if self._db_schemas is None:
                    self._ensure_data()
                    raw = _load_jsonl(f"{self._cache_dir}/gen_db.jsonl")
                    result: dict[str, dict] = {}
                    for item in raw:
                        key = normalize_scenario_name(item["scenario"])
                        result[key] = item
                    self._db_schemas = result
        return self._db_schemas

    def _build_samples(self) -> dict[str, dict]:
        if self._samples is None:
            with self._lock:
                if self._samples is None:
                    self._ensure_data()
                    raw = _load_jsonl(f"{self._cache_dir}/gen_sample.jsonl")
                    result: dict[str, dict] = {}
                    for item in raw:
                        key = normalize_scenario_name(item["scenario"])
                        result[key] = item
                    self._samples = result
        return self._samples

    def _build_verifiers(self, verifier_mode: str = "sql") -> dict[str, list[dict]]:
        if verifier_mode not in {"sql", "code"}:
            raise ValueError(
                f"Invalid verifier mode: {verifier_mode!r}, must be 'sql' or 'code'"
            )

        with self._lock:
            if self._verifiers is None:
                self._verifiers = {}

            if verifier_mode not in self._verifiers:
                self._ensure_data()
                if verifier_mode == "sql":
                    raw = _load_jsonl(f"{self._cache_dir}/gen_verifier.jsonl")
                elif verifier_mode == "code":
                    raw = _load_jsonl(f"{self._cache_dir}/gen_verifier.pure_code.jsonl")

                result: dict[str, list[dict]] = {}
                for item in raw:
                    key = normalize_scenario_name(item["scenario"])
                    if key not in result:
                        result[key] = []
                    result[key].append(item)
                self._verifiers[verifier_mode] = result

        return self._verifiers[verifier_mode]

    def list_scenarios(self) -> list[dict[str, Any]]:
        """Return all scenario names, descriptions, and tasks."""
        scenarios = self._build_scenarios()
        tasks = self._build_tasks()
        result = []
        for key, scenario in scenarios.items():
            task_item = tasks.get(key, {})
            task_list = task_item.get("tasks", [])
            result.append(
                {
                    "name": key,
                    "description": scenario.get("description", ""),
                    "num_tasks": len(task_list),
                    "tasks": task_list,
                }
            )
        return result

    def get_env_code(self, scenario: str) -> str:
        """Return the full_code for a scenario."""
        key = normalize_scenario_name(scenario)
        envs = self._build_envs()
        if key not in envs:
            raise ValueError(
                f"Scenario '{scenario}' (normalized: '{key}') not found in gen_envs.jsonl"
            )
        return envs[key]["full_code"]

    def get_db_schema(self, scenario: str) -> dict:
        """Return the db_schema dict for a scenario."""
        key = normalize_scenario_name(scenario)
        schemas = self._build_db_schemas()
        if key not in schemas:
            raise ValueError(f"Scenario '{scenario}' not found in gen_db.jsonl")
        return schemas[key]["db_schema"]

    def get_sample_data(self, scenario: str) -> Any:
        """Return the sample_data for a scenario."""
        key = normalize_scenario_name(scenario)
        samples = self._build_samples()
        if key not in samples:
            raise ValueError(f"Scenario '{scenario}' not found in gen_sample.jsonl")
        return samples[key]["sample_data"]

    def get_tasks(self, scenario: str) -> list[str]:
        """Return the task list for a scenario."""
        key = normalize_scenario_name(scenario)
        tasks = self._build_tasks()
        if key not in tasks:
            return []
        return tasks[key].get("tasks", [])

    def get_verifier(
        self, scenario: str, task_idx: int, verifier_mode: str = "sql"
    ) -> dict | None:
        """Return the verifier entry for a specific scenario + task_idx."""
        key = normalize_scenario_name(scenario)
        verifiers = self._build_verifiers(verifier_mode)
        entries = verifiers.get(key, [])
        for entry in entries:
            if entry.get("task_idx") == task_idx:
                return entry
        return None

    def scenario_exists(self, scenario: str) -> bool:
        key = normalize_scenario_name(scenario)
        return key in self._build_envs()