File size: 10,259 Bytes
cc2ed2f
8871df9
cc2ed2f
 
 
 
8871df9
cc2ed2f
8871df9
 
 
 
 
 
 
cc2ed2f
8871df9
 
 
 
 
cc2ed2f
 
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2ed2f
8871df9
 
 
 
 
 
 
 
 
 
 
 
cc2ed2f
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2ed2f
8871df9
 
 
 
 
 
 
 
 
cc2ed2f
 
8871df9
 
 
 
 
 
 
 
 
 
cc2ed2f
8871df9
 
cc2ed2f
 
 
 
 
 
 
 
 
 
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2ed2f
 
 
8871df9
 
 
 
 
 
 
 
 
cc2ed2f
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2ed2f
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc2ed2f
 
 
8871df9
 
 
cc2ed2f
 
 
 
 
 
 
8871df9
cc2ed2f
 
8871df9
 
 
 
 
 
 
 
 
 
 
 
cc2ed2f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
"""DbConnector — подключение к произвольной БД и чтение схемы.

Поддерживаемые типы БД:
    SQLite     — путь к файлу: "sqlite:///path/to/db.sqlite" или просто путь
    PostgreSQL — "postgresql://user:pass@host:port/dbname" (требует psycopg2)
    MySQL      — "mysql://user:pass@host:port/dbname"      (требует pymysql)

Пример:
    conn = DbConnector("sqlite:///data/demo/sales.sqlite")
    print(conn.render_schema())
    tables = conn.list_tables()
"""

from __future__ import annotations

import logging
import sqlite3
from dataclasses import dataclass, field
from pathlib import Path
from urllib.parse import urlparse

logger = logging.getLogger(__name__)


@dataclass
class ColumnInfo:
    name: str
    type: str
    nullable: bool = True
    primary_key: bool = False


@dataclass
class TableInfo:
    name: str
    columns: list[ColumnInfo] = field(default_factory=list)
    sample_rows: list[tuple] = field(default_factory=list)

    def to_ddl(self) -> str:
        """Генерирует CREATE TABLE statement из метаданных."""
        col_parts = []
        for col in self.columns:
            line = f"    {col.name} {col.type}"
            if col.primary_key:
                line += " PRIMARY KEY"
            if not col.nullable:
                line += " NOT NULL"
            col_parts.append(line)
        return f"CREATE TABLE {self.name} (\n" + ",\n".join(col_parts) + "\n);"


class DbConnector:
    """Универсальный коннектор к БД. Читает схему для подстановки в промпт."""

    def __init__(self, connection_string: str, n_sample_rows: int = 2):
        self.connection_string = self._normalize(connection_string)
        self.n_sample_rows = n_sample_rows
        self._db_type = self._detect_type(self.connection_string)

    def list_tables(self) -> list[str]:
        return [t.name for t in self._get_tables(n_sample_rows=0)]

    def get_schema(self, include_samples: bool = True) -> list[TableInfo]:
        return self._get_tables(n_sample_rows=self.n_sample_rows if include_samples else 0)

    def render_schema(self, include_samples: bool = True) -> str:
        tables = self.get_schema(include_samples=include_samples)
        parts: list[str] = []
        for t in tables:
            parts.append(t.to_ddl())
            if include_samples and t.sample_rows:
                parts.append(f"-- Примеры строк из {t.name}:")
                for row in t.sample_rows:
                    parts.append(f"--   {row}")
            parts.append("")
        return "\n".join(parts).strip()

    def test_connection(self) -> bool:
        try:
            self._get_tables(n_sample_rows=0)
            return True
        except Exception as e:  # noqa: BLE001
            logger.warning("Подключение к БД не удалось: %s", e)
            return False

    def _get_tables(self, n_sample_rows: int) -> list[TableInfo]:
        if self._db_type == "sqlite":
            return self._get_tables_sqlite(n_sample_rows)
        elif self._db_type == "postgresql":
            return self._get_tables_postgres(n_sample_rows)
        elif self._db_type == "mysql":
            return self._get_tables_mysql(n_sample_rows)
        else:
            raise ValueError(f"Неизвестный тип БД: {self._db_type}")

    def _get_tables_sqlite(self, n_sample_rows: int) -> list[TableInfo]:
        """SQLite-подключение в режиме read-only через URI.

        immutable=1 говорит SQLite, что файл не изменяется во время сессии,
        поэтому journal/WAL-файлы можно игнорировать. Это убирает прежнюю
        логику с копированием БД во временную директорию и заодно даёт
        guardrail-уровень безопасности: любая модифицирующая операция
        на таком соединении завершится ошибкой.
        """
        path = self._sqlite_path()
        conn = sqlite3.connect(self._sqlite_uri(path), uri=True)
        conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
        try:
            cur = conn.cursor()
            cur.execute(
                "SELECT name FROM sqlite_master "
                "WHERE type='table' AND name NOT LIKE 'sqlite_%' "
                "ORDER BY name"
            )
            table_names = [r[0] for r in cur.fetchall()]
            tables: list[TableInfo] = []
            for name in table_names:
                cur.execute(f'PRAGMA table_info("{name}")')
                cols = [
                    ColumnInfo(
                        name=row[1],
                        type=row[2] or "TEXT",
                        nullable=not row[3],
                        primary_key=bool(row[5]),
                    )
                    for row in cur.fetchall()
                ]
                samples: list[tuple] = []
                if n_sample_rows > 0:
                    try:
                        cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
                        samples = cur.fetchall()
                    except sqlite3.Error as e:
                        logger.debug("Не удалось получить sample-строки для %s: %s",
                                     name, e)
                tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
            return tables
        finally:
            conn.close()

    def _get_tables_postgres(self, n_sample_rows: int) -> list[TableInfo]:
        try:
            import psycopg2  # type: ignore
        except ImportError as e:
            raise ImportError("Установи psycopg2: pip install psycopg2-binary") from e

        conn = psycopg2.connect(self.connection_string)
        try:
            cur = conn.cursor()
            cur.execute(
                "SELECT table_name FROM information_schema.tables "
                "WHERE table_schema = 'public' AND table_type = 'BASE TABLE' "
                "ORDER BY table_name"
            )
            table_names = [r[0] for r in cur.fetchall()]
            tables: list[TableInfo] = []
            for name in table_names:
                cur.execute(
                    "SELECT column_name, data_type, is_nullable "
                    "FROM information_schema.columns "
                    "WHERE table_name = %s AND table_schema = 'public' "
                    "ORDER BY ordinal_position",
                    (name,),
                )
                cols = [
                    ColumnInfo(name=r[0], type=r[1], nullable=(r[2] == "YES"))
                    for r in cur.fetchall()
                ]
                samples: list[tuple] = []
                if n_sample_rows > 0:
                    cur.execute(f'SELECT * FROM "{name}" LIMIT {n_sample_rows}')
                    samples = cur.fetchall()
                tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
            return tables
        finally:
            conn.close()

    def _get_tables_mysql(self, n_sample_rows: int) -> list[TableInfo]:
        try:
            import pymysql  # type: ignore
        except ImportError as e:
            raise ImportError("Установи pymysql: pip install pymysql") from e

        parsed = urlparse(self.connection_string)
        conn = pymysql.connect(
            host=parsed.hostname,
            port=parsed.port or 3306,
            user=parsed.username,
            password=parsed.password,
            database=parsed.path.lstrip("/"),
        )
        try:
            cur = conn.cursor()
            cur.execute("SHOW TABLES")
            table_names = [r[0] for r in cur.fetchall()]
            tables: list[TableInfo] = []
            for name in table_names:
                cur.execute(f"DESCRIBE `{name}`")
                cols = [
                    ColumnInfo(
                        name=r[0], type=r[1],
                        nullable=(r[2] == "YES"),
                        primary_key=(r[3] == "PRI"),
                    )
                    for r in cur.fetchall()
                ]
                samples: list[tuple] = []
                if n_sample_rows > 0:
                    cur.execute(f"SELECT * FROM `{name}` LIMIT {n_sample_rows}")
                    samples = cur.fetchall()
                tables.append(TableInfo(name=name, columns=cols, sample_rows=samples))
            return tables
        finally:
            conn.close()

    def _sqlite_path(self) -> Path:
        cs = self.connection_string
        if cs.startswith("sqlite:///"):
            return Path(cs[10:])
        return Path(cs)

    @staticmethod
    def _sqlite_uri(path: Path) -> str:
        """Read-only URI для SQLite с игнорированием journal/WAL."""
        return f"file:{path}?mode=ro&immutable=1"

    @staticmethod
    def _normalize(cs: str) -> str:
        """Если передан просто путь к файлу — превращаем в sqlite:// URI.

        Если строка уже выглядит как URI (sqlite/postgres/mysql) —
        возвращаем как есть. Без этой проверки сценарий «передали
        корректный sqlite:///path» приводил к двойной нормализации
        и подключению к несуществующему пути.
        """
        cs = cs.strip()
        if cs.startswith(("sqlite:", "postgres", "mysql")):
            return cs
        if cs.endswith(".sqlite") or cs.endswith(".db"):
            return f"sqlite:///{cs}"
        return cs

    @staticmethod
    def _detect_type(cs: str) -> str:
        if cs.startswith("sqlite"):
            return "sqlite"
        if cs.startswith("postgresql") or cs.startswith("postgres"):
            return "postgresql"
        if cs.startswith("mysql"):
            return "mysql"
        raise ValueError(f"Не удалось определить тип БД: {cs}")