Spaces:
Sleeping
Sleeping
File size: 4,919 Bytes
e207f41 454d146 105e019 454d146 105e019 e207f41 454d146 e207f41 c1bc4eb e207f41 454d146 e207f41 c1bc4eb 454d146 c1bc4eb 454d146 e207f41 454d146 e207f41 454d146 e207f41 454d146 e207f41 454d146 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
from __future__ import annotations
import json
import os
import sqlite3
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional
@dataclass
class SpiderItem:
    """A single Spider example paired with the SQLite file it should run against.

    Fields are positional in this order: ``db_id``, ``question``, ``gold_sql``,
    ``db_path``.
    """

    db_id: str  # Spider database id (folder name under database/)
    question: str  # natural-language question text
    gold_sql: str  # reference SQL query text
    db_path: str  # absolute path to the sqlite file
# ---------- helpers ----------
def _candidate_roots(env_root: Optional[str]) -> List[Path]:
    """
    Return resolved candidate Spider root directories, in search order.

    Tolerates common layouts:
      1. ``env_root`` itself (only when it is non-empty), followed by
         ``env_root/spider`` (for when it points at the parent directory).
      2. ``<cwd>/data/spider`` and ``<cwd>/data/spider/spider`` (for when the
         repo was cloned into data/spider/spider).

    Duplicates are dropped by comparing each path's string form; the first
    occurrence wins, so the order above is preserved.
    """
    ordered: List[Path] = []
    if env_root:
        env_dir = Path(env_root).expanduser().resolve()
        ordered.extend([env_dir, (env_dir / "spider").resolve()])

    # Project-local defaults, relative to the current working directory.
    local_spider = Path.cwd().resolve() / "data" / "spider"
    ordered.extend([local_spider.resolve(), (local_spider / "spider").resolve()])

    first_seen: dict = {}
    for candidate in ordered:
        first_seen.setdefault(str(candidate), candidate)
    return list(first_seen.values())
def _resolve_split_json(root: Path, split: str) -> Path:
    """
    Map a split name to its Spider JSON file and return the resolved path
    under ``root``.

    Accepted split names:
      - "dev"                     -> dev.json
      - "train" / "train_spider"  -> train_spider.json
      - "train_others"            -> train_others.json

    Raises:
        ValueError: if ``split`` is not one of the names above. Unknown names
            are rejected instead of falling back to train_spider.json, so a
            typo or an unsupported split (e.g. "test") cannot silently load
            training data.
    """
    split_files = {
        "dev": "dev.json",
        "train": "train_spider.json",
        "train_spider": "train_spider.json",
        "train_others": "train_others.json",
    }
    try:
        fname = split_files[split]
    except KeyError:
        raise ValueError(
            f"Unknown Spider split {split!r}; expected one of {sorted(split_files)}"
        ) from None
    return (root / fname).resolve()
def _resolve_database_dir(root: Path) -> Path:
    """Return the resolved ``database/`` directory located directly under ``root``."""
    database_dir = root.joinpath("database")
    return database_dir.resolve()
def _ensure_exists(path: Path, kind: str) -> None:
    """
    Raise ``FileNotFoundError`` unless ``path`` exists.

    ``kind`` is a human-readable label used as the prefix of the error
    message, e.g. "Split file not found: <path>".
    """
    if path.exists():
        return
    message = f"{kind} not found: {path}"
    raise FileNotFoundError(message)
# ---------- public API ----------
def load_spider_sqlite(
    *, split: str = "dev", limit: Optional[int] = None
) -> List[SpiderItem]:
    """
    Load a subset of Spider and attach absolute sqlite db paths.

    Candidate roots come from ``_candidate_roots`` and are tried in this
    order; the first one containing both the split JSON file and a
    ``database/`` directory is used:
      - $SPIDER_ROOT (if set)
      - $SPIDER_ROOT/spider (if set)
      - ./data/spider
      - ./data/spider/spider

    Args:
        split: Split name, mapped to a file by ``_resolve_split_json``
            ("dev" -> dev.json, "train" -> train_spider.json).
        limit: Maximum number of items to return; ``None`` returns every
            usable item.

    Returns:
        ``SpiderItem`` rows in file order. Rows with a missing/empty
        ``db_id``, question or gold SQL, or whose database file is missing,
        are skipped.

    Raises:
        ValueError: If ``limit`` is not ``None`` and is less than 1.
        RuntimeError: If no candidate root holds the dataset, the split file
            cannot be read or is not a JSON list, or no usable rows remain.

    Side effect: prints the chosen root and split file name to stdout.
    """
    # Validate before any I/O; the in-loop cap below relies on limit >= 1
    # (checking only after an append would otherwise let limit=0 return 1 row).
    if limit is not None and limit < 1:
        raise ValueError(f"limit must be a positive integer or None, got {limit!r}")

    roots = _candidate_roots(os.getenv("SPIDER_ROOT"))

    # Find the first root that actually contains the split file & database/.
    chosen_root: Optional[Path] = None
    json_path: Optional[Path] = None
    database_dir: Optional[Path] = None
    for root in roots:
        candidate_json = _resolve_split_json(root, split)
        candidate_db_dir = _resolve_database_dir(root)
        if candidate_json.exists() and candidate_db_dir.exists():
            chosen_root, json_path, database_dir = root, candidate_json, candidate_db_dir
            break

    if json_path is None or database_dir is None:
        debug = "\n".join(
            f"- {_resolve_split_json(r, split)} | {_resolve_database_dir(r)}"
            for r in roots
        )
        raise RuntimeError(
            "Failed to locate Spider dataset.\n"
            f"Checked candidates for split='{split}':\n{debug}\n"
            "Tip: export SPIDER_ROOT=/absolute/path/to/spider "
            "(the folder that directly contains dev.json/train_spider.json and database/)"
        )

    # Read the split. OSError covers I/O failures; ValueError covers both
    # UnicodeDecodeError and json.JSONDecodeError. Chain the original cause.
    try:
        items = json.loads(json_path.read_text(encoding="utf-8"))
    except (OSError, ValueError) as e:
        raise RuntimeError(f"Failed to read Spider split file: {json_path} ({e})") from e
    if not isinstance(items, list):
        raise RuntimeError(
            f"Spider split file {json_path} must contain a JSON list, "
            f"got {type(items).__name__}"
        )

    # Build rows with absolute sqlite paths.
    out: List[SpiderItem] = []
    for obj in items:
        if not isinstance(obj, dict):
            continue
        db_id = obj.get("db_id")
        question = obj.get("question")
        # Spider keeps the SQL text under "query". In the official files "sql"
        # is the parsed query structure (a dict), so it is only used as a
        # fallback when it is actually a string.
        gold = obj.get("query")
        if not isinstance(gold, str):
            gold = obj.get("sql")
        if not (
            isinstance(db_id, str)
            and isinstance(question, str)
            and isinstance(gold, str)
        ):
            continue
        question = question.strip()
        gold = gold.strip()
        if not (db_id and question and gold):
            continue

        # <root>/database/<db_id>/<db_id>.sqlite ; some mirrors use .db instead.
        db_file: Optional[Path] = None
        for suffix in (".sqlite", ".db"):
            candidate = (database_dir / db_id / f"{db_id}{suffix}").resolve()
            if candidate.exists():
                db_file = candidate
                break
        if db_file is None:
            # Lenient by design: skip rows whose DB file is missing
            # (raise here instead if strict behavior is preferred).
            continue

        out.append(
            SpiderItem(db_id=db_id, question=question, gold_sql=gold, db_path=str(db_file))
        )
        if limit is not None and len(out) >= limit:
            break

    if not out:
        raise RuntimeError(
            f"No usable items from {json_path} (limit={limit}). "
            "Check db files under database/<db_id>/<db_id>.sqlite"
        )

    # Small info for sanity.
    print(
        f"✔ Spider root: {chosen_root}\n"
        f"✔ Split file: {json_path.name} ({len(out)} items)"
    )
    return out
def open_readonly_connection(db_path: str) -> sqlite3.Connection:
    """
    Open the SQLite database at ``db_path`` in read-only mode (URI).

    The path is resolved to an absolute path and turned into a ``file:`` URI
    with ``Path.as_uri()``. This percent-encodes characters such as ``#``,
    ``?``, ``%`` and spaces. Interpolating the raw path instead would let
    SQLite parse them as URI syntax and open the wrong file, or fail.

    The connection is created with ``check_same_thread=False``, so it is not
    restricted to the creating thread. Per the sqlite3 docs, callers must then
    serialise access themselves.

    Raises:
        sqlite3.OperationalError: if the file cannot be opened (e.g. it does
            not exist). ``mode=ro`` never creates a database.
    """
    uri = f"{Path(db_path).resolve().as_uri()}?mode=ro"
    return sqlite3.connect(uri, uri=True, check_same_thread=False)
|