| """SchemaRetriever — извлекает DDL и примеры строк из SQLite-файлов PAUQ/Spider.""" |
|
|
| from __future__ import annotations |
|
|
| import sqlite3 |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
|
|
@dataclass()
class TableInfo:
    """Description of one table read from a SQLite database.

    Instances are built by ``SchemaRetriever.get_tables``.
    """

    # Table name as listed in ``sqlite_master``.
    name: str
    # The table's ``CREATE`` statement, whitespace-stripped.
    create_sql: str
    # A few leading rows of the table; empty when sampling was skipped or failed.
    sample_rows: list[tuple]
|
|
|
|
class SchemaRetriever:
    """Read the structure of Spider/PAUQ SQLite databases for a model prompt.

    Every database is expected at ``databases_dir/{db_id}/{db_id}.sqlite``.
    """

    def __init__(self, databases_dir: Path | str):
        """
        Args:
            databases_dir: Root directory holding one sub-directory per db_id.
        """
        self.databases_dir = Path(databases_dir)

    def db_path(self, db_id: str) -> Path:
        """Return the SQLite file for *db_id*.

        In Spider/PAUQ every database lives at
        ``databases_dir/{db_id}/{db_id}.sqlite``.

        Raises:
            FileNotFoundError: if that path is not an existing regular file.
        """
        path = self.databases_dir / db_id / f"{db_id}.sqlite"
        # is_file() rather than exists(): a directory with this name would
        # otherwise be accepted here and only fail inside sqlite3.connect().
        if not path.is_file():
            raise FileNotFoundError(f"Database file not found: {path}")
        return path

    @staticmethod
    def _connect_readonly(path: Path) -> sqlite3.Connection:
        """Open *path* read-only, decoding TEXT values leniently."""
        # Path.as_uri() percent-encodes characters such as '?', '#', '%' and
        # spaces, which would otherwise be misparsed inside a SQLite URI
        # filename (e.g. '#' or '?' would truncate the path).
        uri = path.resolve().as_uri() + "?mode=ro"
        conn = sqlite3.connect(uri, uri=True)
        # Some databases contain bytes that are not valid UTF-8; replace them
        # instead of failing while fetching rows.
        conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
        return conn

    @staticmethod
    def _list_tables(cur: sqlite3.Cursor) -> list[tuple[str, str | None]]:
        """Return ``(name, sql)`` for every table except SQLite's internal ones."""
        # '_' is a LIKE wildcard, so it must be escaped; otherwise user tables
        # such as "sqlitex" would be filtered out together with sqlite_sequence.
        cur.execute(
            "SELECT name, sql FROM sqlite_master "
            "WHERE type='table' AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'"
        )
        return cur.fetchall()

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Quote *name* as an SQL identifier, doubling embedded double quotes."""
        return '"' + name.replace('"', '""') + '"'

    @classmethod
    def _sample_rows(
        cls, cur: sqlite3.Cursor, table_name: str, n_sample_rows: int
    ) -> list[tuple]:
        """Return up to *n_sample_rows* rows of *table_name* (best effort).

        A non-positive count yields ``[]``. Without this check a negative
        value would become ``LIMIT -1``, which SQLite treats as "no limit".
        Errors while reading the table also yield ``[]``. For example, a
        virtual table whose module is unavailable cannot be selected from.
        """
        if n_sample_rows <= 0:
            return []
        try:
            cur.execute(
                f"SELECT * FROM {cls._quote_identifier(table_name)} LIMIT ?",
                (n_sample_rows,),
            )
            return cur.fetchall()
        except sqlite3.Error:
            return []

    def get_tables(self, db_id: str, n_sample_rows: int = 3) -> list[TableInfo]:
        """Return every user table of *db_id* with its CREATE SQL and sample rows.

        Tables whose stored SQL is empty are skipped. Tables are returned in
        the order ``sqlite_master`` yields them.

        Args:
            db_id: Database identifier (see ``db_path``).
            n_sample_rows: Rows to sample per table; ``<= 0`` disables sampling.

        Raises:
            FileNotFoundError: if the database file does not exist.
            sqlite3.Error: if the database cannot be opened or its schema read.
        """
        path = self.db_path(db_id)
        conn = self._connect_readonly(path)
        try:
            cur = conn.cursor()
            tables: list[TableInfo] = []
            for table_name, create_sql in self._list_tables(cur):
                if not create_sql:
                    continue
                samples = self._sample_rows(cur, table_name, n_sample_rows)
                tables.append(
                    TableInfo(
                        name=table_name,
                        create_sql=create_sql.strip(),
                        sample_rows=samples,
                    )
                )
            return tables
        finally:
            conn.close()

    def render_schema(
        self, db_id: str, include_samples: bool = True, n_sample_rows: int = 3
    ) -> str:
        """Render the schema of *db_id* as text for a prompt.

        Each table contributes its CREATE statement terminated by ``;``.
        When *include_samples* is true, the table's sample rows follow as
        ``--`` comment lines. Tables are separated by a blank line.

        Args:
            db_id: Database identifier (see ``db_path``).
            include_samples: Whether to append sample rows as comments.
            n_sample_rows: Rows to sample per table when *include_samples* is true.
        """
        # Do not touch table data at all when samples will not be rendered.
        tables = self.get_tables(db_id, n_sample_rows if include_samples else 0)
        parts: list[str] = []
        for t in tables:
            parts.append(t.create_sql + ";")
            if include_samples and t.sample_rows:
                parts.append(f"-- Примеры строк из {t.name}:")
                parts.extend(f"-- {row}" for row in t.sample_rows)
            parts.append("")
        return "\n".join(parts).strip()

    def list_databases(self) -> list[str]:
        """Return the sorted db_ids that ``db_path`` can resolve.

        Only sub-directories containing ``{name}/{name}.sqlite`` are reported,
        so stray folders (e.g. archive leftovers) are not listed. A missing or
        non-directory ``databases_dir`` yields ``[]``.
        """
        if not self.databases_dir.is_dir():
            return []
        return sorted(
            p.name
            for p in self.databases_dir.iterdir()
            if (p / f"{p.name}.sqlite").is_file()
        )
|
|