File size: 5,156 Bytes
8871df9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""SqlExecutor -- vypolnyaet SQL-zapros na podklyuchennoy BD i vozvraschaet rezultat.

Primer:
    executor = SqlExecutor("sqlite:///data/demo/sales.sqlite")
    result = executor.run("SELECT SUM(amount) FROM orders WHERE status='paid'")
    print(result.columns)
    print(result.rows)
"""

from __future__ import annotations

import sqlite3
from dataclasses import dataclass, field
from pathlib import Path
from urllib.parse import urlparse


@dataclass
class QueryResult:
    """Result of executing one SQL statement.

    ``error`` is ``None`` on success and holds the error message otherwise.
    """

    columns: list[str]  # column names, in result order
    rows: list[list]  # fetched rows; each row is a list of column values
    row_count: int  # number of rows held in ``rows``
    sql: str  # the SQL text that produced this result
    error: str | None = None  # error message, or None on success

    @property
    def success(self) -> bool:
        """True when the query completed without an error."""
        return self.error is None

    def to_dict(self) -> dict:
        """Return the result as a plain dict keyed by the field names."""
        return {
            "columns": self.columns,
            "rows": self.rows,
            "row_count": self.row_count,
            "sql": self.sql,
            "error": self.error,
        }

    def to_markdown_table(self) -> str:
        """Render the result as a (GitHub-flavoured) Markdown table.

        Returns an error line when the query failed and a placeholder when
        there are no rows. Header and cell text is escaped so that ``|`` and
        line breaks inside values cannot break the table layout.
        """
        # Use ``success`` so an empty-string error is still reported as an error,
        # consistent with the ``success`` property.
        if not self.success:
            return f"Oshibka: {self.error}"
        if not self.rows:
            return "(pustoy rezultat)"

        def cell(value: object) -> str:
            # A raw "|" would open a new column and a newline would end the
            # row early, so escape the former and flatten the latter.
            text = str(value).replace("\r\n", " ").replace("\r", " ").replace("\n", " ")
            return text.replace("|", "\\|")

        header = " | ".join(cell(c) for c in self.columns)
        sep = " | ".join(["---"] * len(self.columns))
        body = "\n".join(" | ".join(cell(v) for v in row) for row in self.rows)
        return f"{header}\n{sep}\n{body}"


class SqlExecutor:
    """Run SQL against the database named by a connection string.

    The back end is inferred from the connection string (see ``_detect_type``):
    SQLite (``sqlite:///path`` or a bare ``.sqlite``/``.sqlite3``/``.db`` path),
    PostgreSQL through ``psycopg2``, or MySQL through ``pymysql``. Each call to
    ``run`` opens a fresh connection, fetches at most ``MAX_ROWS`` rows and
    closes the connection again; ``commit()`` is never called.
    """

    # Upper bound on the number of rows fetched per query.
    MAX_ROWS = 500

    def __init__(self, connection_string: str):
        """Remember *connection_string* (whitespace-stripped) and detect its back end.

        Raises:
            ValueError: if the database type cannot be determined.
        """
        self.connection_string = connection_string.strip()
        self._db_type = self._detect_type(self.connection_string)

    def run(self, sql: str) -> QueryResult:
        """Execute *sql* and return a ``QueryResult``.

        Any ``Exception`` raised along the way (missing driver, connection
        failure, SQL error, missing SQLite file) is not propagated: its message
        is returned in ``QueryResult.error`` with empty columns and rows.
        """
        try:
            if self._db_type == "sqlite":
                return self._run_sqlite(sql)
            if self._db_type == "postgresql":
                return self._run_postgres(sql)
            if self._db_type == "mysql":
                return self._run_mysql(sql)
            # Not reachable through __init__ (which raises on unknown types);
            # kept as a defensive guard.
            return QueryResult(columns=[], rows=[], row_count=0, sql=sql,
                               error=f"Neizvestnyy tip BD: {self._db_type}")
        except Exception as e:  # deliberate boundary: failures become results
            return QueryResult(columns=[], rows=[], row_count=0, sql=sql, error=str(e))

    def _fetch(self, cur, sql: str) -> QueryResult:
        """Execute *sql* on DB-API cursor *cur* and collect up to ``MAX_ROWS`` rows."""
        cur.execute(sql)
        # ``description`` is None when the statement returns no result set.
        cols = [d[0] for d in (cur.description or [])]
        rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
        return QueryResult(columns=cols, rows=rows, row_count=len(rows), sql=sql)

    def _run_sqlite(self, sql: str) -> QueryResult:
        """Run *sql* on the SQLite file, through a temporary snapshot if it is busy.

        Raises:
            FileNotFoundError: if the database file does not exist.
        """
        import shutil

        source = self._sqlite_path()
        # sqlite3.connect() would silently create an empty database at a
        # mistyped path; report the missing file instead.
        if str(source) != ":memory:" and not source.exists():
            raise FileNotFoundError(f"Fayl BD SQLite ne nayden: {source}")
        path = self._safe_sqlite_path(source)
        try:
            conn = sqlite3.connect(str(path))
            try:
                # Undecodable bytes in TEXT values become U+FFFD instead of raising.
                conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
                return self._fetch(conn.cursor(), sql)
            finally:
                conn.close()
        finally:
            if path != source:
                # The snapshot and its copied sidecar files live in a private
                # temporary directory created by _safe_sqlite_path.
                shutil.rmtree(path.parent, ignore_errors=True)

    def _run_postgres(self, sql: str) -> QueryResult:
        """Run *sql* through psycopg2, passing the connection string as-is."""
        try:
            import psycopg2  # type: ignore
        except ImportError as e:
            raise ImportError("Ustanovi psycopg2: pip install psycopg2-binary") from e

        conn = psycopg2.connect(self.connection_string)
        try:
            return self._fetch(conn.cursor(), sql)
        finally:
            conn.close()

    def _run_mysql(self, sql: str) -> QueryResult:
        """Run *sql* through pymysql using the parts of a ``mysql://`` URL."""
        try:
            import pymysql  # type: ignore
        except ImportError as e:
            raise ImportError("Ustanovi pymysql: pip install pymysql") from e
        from urllib.parse import unquote

        parsed = urlparse(self.connection_string)
        # urlparse leaves userinfo and path percent-encoded (e.g. "p%40ss"),
        # so decode them before handing them to the driver.
        user = unquote(parsed.username) if parsed.username is not None else None
        password = unquote(parsed.password) if parsed.password is not None else None
        conn = pymysql.connect(
            host=parsed.hostname,
            port=parsed.port or 3306,
            user=user,
            password=password,
            database=unquote(parsed.path.lstrip("/")),
        )
        try:
            return self._fetch(conn.cursor(), sql)
        finally:
            conn.close()

    def _sqlite_path(self) -> Path:
        """Return the file path from a ``sqlite:///...`` URL or a bare path.

        ``sqlite:///rel/path`` yields ``rel/path`` and ``sqlite:////abs/path``
        yields ``/abs/path``; anything else is used verbatim.
        """
        prefix = "sqlite:///"
        cs = self.connection_string
        return Path(cs[len(prefix):]) if cs.startswith(prefix) else Path(cs)

    @staticmethod
    def _safe_sqlite_path(path: Path) -> Path:
        """Return *path*, or a private snapshot of it while the DB is being written.

        If a rollback journal or WAL file exists next to the database, the main
        file and those sidecar files are copied into a fresh temporary
        directory and the path of the copy is returned. Copying the sidecars
        lets SQLite recover them when the snapshot is opened; committed data
        that has not been checkpointed yet exists only in the ``-wal`` file.
        When the returned path differs from *path*, the caller owns
        ``returned.parent`` and must delete it.
        """
        import shutil
        import tempfile

        suffixes = ("-journal", "-wal")
        if not any(Path(f"{path}{s}").exists() for s in suffixes):
            return path
        # mkdtemp creates the directory atomically with private permissions,
        # unlike the deprecated, race-prone tempfile.mktemp().
        tmp_dir = Path(tempfile.mkdtemp(prefix="sqlexec-"))
        snapshot = tmp_dir / (path.name or "db.sqlite")
        try:
            # copyfile (not copy2) so the snapshot stays writable even when the
            # source is read-only; recovery on open may need to write to it.
            shutil.copyfile(path, snapshot)
            for s in suffixes:
                try:
                    shutil.copyfile(f"{path}{s}", f"{snapshot}{s}")
                except FileNotFoundError:
                    pass  # absent, or removed by the writer since the check above
        except BaseException:
            shutil.rmtree(tmp_dir, ignore_errors=True)
            raise
        return snapshot

    @staticmethod
    def _detect_type(cs: str) -> str:
        """Classify a connection string as "sqlite", "postgresql" or "mysql".

        Raises:
            ValueError: if no rule matches.
        """
        sqlite_exts = (".sqlite", ".sqlite3", ".db")
        if cs.startswith("sqlite"):
            return "sqlite"
        # A bare file path with a SQLite extension is SQLite even if its name
        # starts with "postgres"/"mysql"; a URL is classified by its scheme so
        # that e.g. "mysql://host/app.db" is not mistaken for a SQLite file.
        if "://" not in cs and cs.endswith(sqlite_exts):
            return "sqlite"
        if cs.startswith("postgres"):  # also covers "postgresql"
            return "postgresql"
        if cs.startswith("mysql"):
            return "mysql"
        if cs.endswith(sqlite_exts):
            return "sqlite"
        raise ValueError(f"Ne udalos opredelit tip BD: {cs}")