Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import pathlib | |
| import sqlite3 | |
| from dataclasses import dataclass | |
| from typing import List, Optional | |
| import os | |
# Root directory of the Spider dataset; override with the SPIDER_ROOT
# environment variable, otherwise the relative path "data/spider" is used.
SPIDER_ROOT = pathlib.Path(os.environ.get("SPIDER_ROOT", "data/spider"))
@dataclass
class SpiderItem:
    """A single Spider example paired with the SQLite database it targets.

    Attributes:
        db_id: Identifier of the database the example refers to.
        question: Natural-language question text of the example.
        gold_sql: SQL query string stored with the example.
        db_path: Filesystem path of the SQLite file for ``db_id``.
    """

    db_id: str
    question: str
    gold_sql: str
    db_path: pathlib.Path
def load_spider_sqlite(
    split: str = "dev", limit: Optional[int] = None
) -> List[SpiderItem]:
    """Load examples from a Spider split with the path of each example's database.

    Args:
        split: Which split file to read under ``SPIDER_ROOT``: ``"dev"``
            (``dev.json``) or ``"train"`` (``train_spider.json``).
        limit: Maximum number of examples to take from the start of the
            file. A falsy value (``None`` or ``0``) loads every example.

    Returns:
        One ``SpiderItem`` per loaded example, in file order.

    Raises:
        KeyError: If ``split`` is not a known split name.
        RuntimeError: If the split file cannot be read or parsed as JSON.
        FileNotFoundError: If the SQLite file for an example's database
            does not exist.
    """
    split_files = {"dev": "dev.json", "train": "train_spider.json"}
    try:
        fn = split_files[split]
    except KeyError:
        # Same exception type as a plain dict lookup, but with a usable message.
        raise KeyError(
            f"Unknown Spider split {split!r}; expected one of {sorted(split_files)}"
        ) from None

    json_path = SPIDER_ROOT / fn
    try:
        items = json.loads(json_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        raise RuntimeError(
            f"Failed to read Spider split file: {json_path} ({e})"
        ) from e

    # Many examples share one database, so check each database file only once.
    verified_dbs: set[str] = set()
    out: list[SpiderItem] = []
    # `limit or len(items)`: a falsy limit means "take everything".
    for ex in items[: (limit or len(items))]:
        db_id = ex["db_id"]
        db_path = SPIDER_ROOT / "database" / db_id / f"{db_id}.sqlite"
        if db_id not in verified_dbs:
            if not db_path.exists():
                raise FileNotFoundError(f"Missing SQLite DB for {db_id}: {db_path}")
            verified_dbs.add(db_id)
        out.append(
            SpiderItem(
                db_id=db_id,
                question=ex["question"],
                gold_sql=ex["query"],
                db_path=db_path,
            )
        )
    return out
def open_readonly_connection(
    db_path: pathlib.Path, timeout: float = 5.0
) -> sqlite3.Connection:
    """Open a SQLite database in read-only mode.

    Args:
        db_path: Path of the SQLite database file. Relative paths are
            interpreted against the current working directory.
        timeout: Passed to ``sqlite3.connect`` as its ``timeout`` argument.

    Returns:
        A connection opened with ``mode=ro`` whose ``row_factory`` is
        ``sqlite3.Row``, so columns can be accessed by name.

    Raises:
        sqlite3.OperationalError: If the database file cannot be opened.
    """
    # Build the URI with as_uri() so characters that are special in a URI
    # ('#', '?', '%', spaces) are percent-encoded rather than being parsed by
    # SQLite as a fragment, query string or escape sequence. as_uri() needs
    # an absolute path.
    uri = f"{pathlib.Path(db_path).absolute().as_uri()}?mode=ro"
    conn = sqlite3.connect(uri, uri=True, timeout=timeout)
    conn.row_factory = sqlite3.Row
    return conn