nl2sql-copilot / benchmarks /spider_loader.py
Melika Kheirieh
style: format code with ruff
105e019
raw
history blame
1.49 kB
from __future__ import annotations
import json
import pathlib
import sqlite3
from dataclasses import dataclass
from typing import List, Optional
import os
SPIDER_ROOT = pathlib.Path(os.getenv("SPIDER_ROOT", "data/spider"))
@dataclass
class SpiderItem:
db_id: str
question: str
gold_sql: str
db_path: pathlib.Path
def load_spider_sqlite(
split: str = "dev", limit: Optional[int] = None
) -> List[SpiderItem]:
fn = {"dev": "dev.json", "train": "train_spider.json"}[split]
json_path = SPIDER_ROOT / fn
try:
items = json.loads(json_path.read_text(encoding="utf-8"))
except Exception as e:
raise RuntimeError(f"Failed to read Spider split file: {json_path} ({e})")
out: list[SpiderItem] = []
for ex in items[: (limit or len(items))]:
db_id = ex["db_id"]
db_path = SPIDER_ROOT / "database" / db_id / f"{db_id}.sqlite"
if not db_path.exists():
raise FileNotFoundError(f"Missing SQLite DB for {db_id}: {db_path}")
out.append(
SpiderItem(
db_id=db_id,
question=ex["question"],
gold_sql=ex["query"],
db_path=db_path,
)
)
return out
def open_readonly_connection(
db_path: pathlib.Path, timeout: float = 5.0
) -> sqlite3.Connection:
uri = f"file:{db_path}?mode=ro&uri=true"
conn = sqlite3.connect(uri, uri=True, timeout=timeout)
conn.row_factory = sqlite3.Row
return conn