Ru2SQL / src /data /schema.py
Tyycha's picture
initial commit
8871df9
raw
history blame
2.91 kB
"""SchemaRetriever — извлекает DDL и примеры строк из SQLite-файлов PAUQ/Spider."""
from __future__ import annotations
import sqlite3
from dataclasses import dataclass
from pathlib import Path
@dataclass
class TableInfo:
name: str
create_sql: str
sample_rows: list[tuple]
class SchemaRetriever:
"""Читает структуру SQLite-БД для подачи в prompt модели."""
def __init__(self, databases_dir: Path | str):
self.databases_dir = Path(databases_dir)
def db_path(self, db_id: str) -> Path:
"""В Spider/PAUQ каждая БД лежит в databases_dir/{db_id}/{db_id}.sqlite."""
path = self.databases_dir / db_id / f"{db_id}.sqlite"
if not path.exists():
raise FileNotFoundError(f"Database file not found: {path}")
return path
def get_tables(self, db_id: str, n_sample_rows: int = 3) -> list[TableInfo]:
"""Возвращает список таблиц с CREATE-SQL и примером строк."""
path = self.db_path(db_id)
conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
try:
conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
cur = conn.cursor()
cur.execute(
"SELECT name, sql FROM sqlite_master "
"WHERE type='table' AND name NOT LIKE 'sqlite_%'"
)
rows = cur.fetchall()
tables: list[TableInfo] = []
for table_name, create_sql in rows:
if not create_sql:
continue
try:
cur.execute(f'SELECT * FROM "{table_name}" LIMIT {n_sample_rows}')
samples = cur.fetchall()
except sqlite3.Error:
samples = []
tables.append(
TableInfo(name=table_name, create_sql=create_sql.strip(), sample_rows=samples)
)
return tables
finally:
conn.close()
def render_schema(self, db_id: str, include_samples: bool = True) -> str:
"""Текстовое представление схемы для prompt'а."""
tables = self.get_tables(db_id)
parts: list[str] = []
for t in tables:
parts.append(t.create_sql + ";")
if include_samples and t.sample_rows:
parts.append(f"-- Примеры строк из {t.name}:")
for row in t.sample_rows:
parts.append(f"-- {row}")
parts.append("")
return "\n".join(parts).strip()
def list_databases(self) -> list[str]:
"""Список доступных db_id."""
if not self.databases_dir.exists():
return []
return sorted(p.name for p in self.databases_dir.iterdir() if p.is_dir())