File size: 2,908 Bytes
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""SchemaRetriever — извлекает DDL и примеры строк из SQLite-файлов PAUQ/Spider."""

from __future__ import annotations

import sqlite3
from dataclasses import dataclass
from pathlib import Path


@dataclass
class TableInfo:
    name: str
    create_sql: str
    sample_rows: list[tuple]


class SchemaRetriever:
    """Читает структуру SQLite-БД для подачи в prompt модели."""

    def __init__(self, databases_dir: Path | str):
        self.databases_dir = Path(databases_dir)

    def db_path(self, db_id: str) -> Path:
        """В Spider/PAUQ каждая БД лежит в databases_dir/{db_id}/{db_id}.sqlite."""
        path = self.databases_dir / db_id / f"{db_id}.sqlite"
        if not path.exists():
            raise FileNotFoundError(f"Database file not found: {path}")
        return path

    def get_tables(self, db_id: str, n_sample_rows: int = 3) -> list[TableInfo]:
        """Возвращает список таблиц с CREATE-SQL и примером строк."""
        path = self.db_path(db_id)
        conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True)
        try:
            conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
            cur = conn.cursor()
            cur.execute(
                "SELECT name, sql FROM sqlite_master "
                "WHERE type='table' AND name NOT LIKE 'sqlite_%'"
            )
            rows = cur.fetchall()

            tables: list[TableInfo] = []
            for table_name, create_sql in rows:
                if not create_sql:
                    continue
                try:
                    cur.execute(f'SELECT * FROM "{table_name}" LIMIT {n_sample_rows}')
                    samples = cur.fetchall()
                except sqlite3.Error:
                    samples = []
                tables.append(
                    TableInfo(name=table_name, create_sql=create_sql.strip(), sample_rows=samples)
                )
            return tables
        finally:
            conn.close()

    def render_schema(self, db_id: str, include_samples: bool = True) -> str:
        """Текстовое представление схемы для prompt'а."""
        tables = self.get_tables(db_id)
        parts: list[str] = []
        for t in tables:
            parts.append(t.create_sql + ";")
            if include_samples and t.sample_rows:
                parts.append(f"-- Примеры строк из {t.name}:")
                for row in t.sample_rows:
                    parts.append(f"-- {row}")
            parts.append("")
        return "\n".join(parts).strip()

    def list_databases(self) -> list[str]:
        """Список доступных db_id."""
        if not self.databases_dir.exists():
            return []
        return sorted(p.name for p in self.databases_dir.iterdir() if p.is_dir())