"""SqlExecutor -- vypolnyaet SQL-zapros na podklyuchennoy BD i vozvraschaet rezultat.
(Runs an SQL query against the connected database and returns the result.)

Primer (example):
executor = SqlExecutor("sqlite:///data/demo/sales.sqlite")
result = executor.run("SELECT SUM(amount) FROM orders WHERE status='paid'")
print(result.columns)
print(result.rows)
"""
from __future__ import annotations
import sqlite3
from dataclasses import dataclass, field
from pathlib import Path
from urllib.parse import urlparse
@dataclass
class QueryResult:
    """Result of executing one SQL query.

    ``error`` is ``None`` when the query succeeded; otherwise it holds a
    description of the failure.
    """

    columns: list[str]  # column names, in result order
    rows: list[list]  # one list of cell values per returned row
    row_count: int  # row count reported by the producer of this result
    sql: str  # the SQL text that was executed
    error: str | None = None  # failure message, or None on success

    @property
    def success(self) -> bool:
        """True when no error was recorded for the query."""
        return self.error is None

    def to_dict(self) -> dict:
        """Return the result as a plain dict with one key per field."""
        return {
            "columns": self.columns,
            "rows": self.rows,
            "row_count": self.row_count,
            "sql": self.sql,
            "error": self.error,
        }

    @staticmethod
    def _md_cell(value: object) -> str:
        """Render *value* as text that cannot break a Markdown table row."""
        # A line break would end the row early and a bare "|" would open a new
        # column, so flatten the former and escape the latter.
        text = str(value).replace("\r\n", " ").replace("\n", " ").replace("\r", " ")
        return text.replace("|", "\\|")

    def to_markdown_table(self) -> str:
        """Render the result as a Markdown (pipe) table.

        Returns ``"Oshibka: <error>"`` ("Error: ...") when an error was
        recorded and ``"(pustoy rezultat)"`` ("empty result") when there are
        no rows.
        """
        # Compare with None, exactly like ``success``, so an empty error
        # message is still reported as a failure, not as an empty result.
        if self.error is not None:
            return f"Oshibka: {self.error}"
        if not self.rows:
            return "(pustoy rezultat)"
        header = " | ".join(self._md_cell(c) for c in self.columns)
        sep = " | ".join(["---"] * len(self.columns))
        body = "\n".join(
            " | ".join(self._md_cell(v) for v in row) for row in self.rows
        )
        return f"{header}\n{sep}\n{body}"
class SqlExecutor:
    """Run SQL on the database named by a connection string.

    The backend is chosen once, in ``__init__``, by ``_detect_type``:

    * SQLite -- strings starting with ``sqlite`` or paths ending in
      ``.sqlite``, ``.sqlite3``, ``.db`` or ``.db3``;
    * PostgreSQL -- ``postgres://`` / ``postgresql://`` URLs (needs ``psycopg2``);
    * MySQL -- ``mysql://`` URLs (needs ``pymysql``).

    The PostgreSQL and MySQL drivers are imported only when a query runs.
    Each query opens a fresh connection and closes it without calling
    ``commit()``, so whether a data-modifying statement persists depends on
    the driver's transaction handling. NOTE(review): this looks intended for
    read-only queries -- confirm.
    """

    # Upper bound on rows fetched per query. Extra rows are silently dropped
    # and ``row_count`` counts only the rows actually fetched.
    MAX_ROWS = 500

    # File suffixes that identify a bare path as a SQLite database.
    _SQLITE_SUFFIXES = (".sqlite", ".sqlite3", ".db", ".db3")
    # Sidecar files whose presence makes us query a snapshot copy instead.
    _SQLITE_SIDECARS = ("-journal", "-wal")

    def __init__(self, connection_string: str):
        """Store the connection string and detect the backend.

        Raises:
            ValueError: if the backend cannot be determined from the string.
        """
        self.connection_string = connection_string.strip()
        self._db_type = self._detect_type(self.connection_string)

    def run(self, sql: str) -> QueryResult:
        """Execute *sql* and return its result.

        Any ``Exception`` (missing driver, connection failure, SQL error, ...)
        is caught and returned as a ``QueryResult`` whose ``error`` is set.
        """
        try:
            if self._db_type == "sqlite":
                return self._run_sqlite(sql)
            elif self._db_type == "postgresql":
                return self._run_postgres(sql)
            elif self._db_type == "mysql":
                return self._run_mysql(sql)
            else:
                return QueryResult(columns=[], rows=[], row_count=0, sql=sql,
                                   error=f"Neizvestnyy tip BD: {self._db_type}")
        except Exception as e:
            # str(e) is "" for some exceptions; fall back to the class name so
            # the error text is never empty.
            return QueryResult(columns=[], rows=[], row_count=0, sql=sql,
                               error=str(e) or type(e).__name__)

    def _fetch(self, conn, sql: str) -> QueryResult:
        """Execute *sql* on an open DB-API connection and collect up to MAX_ROWS rows."""
        cur = conn.cursor()
        cur.execute(sql)
        # cursor.description is None for statements that return no rows.
        cols = [d[0] for d in (cur.description or [])]
        rows = [list(r) for r in cur.fetchmany(self.MAX_ROWS)]
        return QueryResult(columns=cols, rows=rows, row_count=len(rows), sql=sql)

    def _run_sqlite(self, sql: str) -> QueryResult:
        """Run *sql* against the SQLite database from the connection string.

        Raises:
            FileNotFoundError: if the database file does not exist.
        """
        source = self._sqlite_path()
        # sqlite3.connect() would silently create an empty database for a
        # mistyped path; report the problem instead (":memory:" is no file).
        if str(source) != ":memory:" and not source.exists():
            raise FileNotFoundError(f"SQLite database not found: {source}")
        path = self._safe_sqlite_path(source)
        try:
            conn = sqlite3.connect(str(path))
            try:
                # Decode TEXT as UTF-8, replacing invalid bytes instead of failing.
                conn.text_factory = lambda b: b.decode("utf-8", errors="replace")
                return self._fetch(conn, sql)
            finally:
                conn.close()
        finally:
            if path != source:
                # Temporary snapshot created by _safe_sqlite_path.
                self._remove_sqlite_files(path)

    def _run_postgres(self, sql: str) -> QueryResult:
        """Run *sql* on PostgreSQL via ``psycopg2`` using the raw connection string."""
        try:
            import psycopg2  # type: ignore
        except ImportError as e:
            raise ImportError("Ustanovi psycopg2: pip install psycopg2-binary") from e
        conn = psycopg2.connect(self.connection_string)
        try:
            return self._fetch(conn, sql)
        finally:
            conn.close()

    def _run_mysql(self, sql: str) -> QueryResult:
        """Run *sql* on MySQL via ``pymysql``, parsing a ``mysql://`` URL."""
        from urllib.parse import unquote

        try:
            import pymysql  # type: ignore
        except ImportError as e:
            raise ImportError("Ustanovi pymysql: pip install pymysql") from e

        def _decode(part):
            # urlparse keeps %-escapes, so credentials containing reserved
            # characters (e.g. "@" written as "%40") must be decoded here.
            return unquote(part) if part is not None else None

        parsed = urlparse(self.connection_string)
        conn = pymysql.connect(
            host=parsed.hostname,
            port=parsed.port or 3306,
            user=_decode(parsed.username),
            password=_decode(parsed.password),
            database=unquote(parsed.path.lstrip("/")),
        )
        try:
            return self._fetch(conn, sql)
        finally:
            conn.close()

    def _sqlite_path(self) -> Path:
        """Extract the file path from a ``sqlite:///<path>`` URL or a bare path."""
        cs = self.connection_string
        prefix = "sqlite:///"
        if cs.startswith(prefix):
            # "sqlite:///rel/x.db" -> "rel/x.db"; "sqlite:////abs/x.db" -> "/abs/x.db"
            return Path(cs[len(prefix):])
        return Path(cs)

    @staticmethod
    def _safe_sqlite_path(path: Path) -> Path:
        """Return the path to open for a query against the database at *path*.

        When a ``-journal`` or ``-wal`` sidecar exists next to *path*, the
        database and those sidecars are copied to a new temporary file and
        the copy's path is returned. The caller owns the copy and should
        delete it with ``_remove_sqlite_files``. Otherwise *path* itself is
        returned.
        """
        import os
        import shutil
        import tempfile

        present = [
            s for s in SqlExecutor._SQLITE_SIDECARS if Path(f"{path}{s}").exists()
        ]
        if not present:
            return path
        # mkstemp creates the file atomically (unlike the deprecated, racy
        # mktemp), so no other process can claim the name first.
        fd, name = tempfile.mkstemp(suffix=".sqlite")
        os.close(fd)
        tmp = Path(name)
        try:
            shutil.copy2(path, tmp)
            # A -wal file holds committed pages not yet written back to the
            # main file, and a -journal file is needed to roll back a
            # half-written transaction; copying the main file alone could
            # yield a stale or inconsistent snapshot.
            for suffix in present:
                try:
                    shutil.copy2(f"{path}{suffix}", f"{tmp}{suffix}")
                except FileNotFoundError:
                    pass  # sidecar vanished after the exists() check
        except Exception:
            SqlExecutor._remove_sqlite_files(tmp)
            raise
        return tmp

    @staticmethod
    def _remove_sqlite_files(path: Path) -> None:
        """Best-effort deletion of a temporary SQLite copy and its sidecars."""
        for suffix in ("", "-journal", "-wal", "-shm"):
            try:
                Path(f"{path}{suffix}").unlink()
            except OSError:
                pass  # already gone or not removable; never mask the query outcome

    @staticmethod
    def _detect_type(cs: str) -> str:
        """Map a connection string to "sqlite", "postgresql" or "mysql".

        Raises:
            ValueError: if no known prefix or file suffix matches.
        """
        prefixes = (("sqlite", "sqlite"), ("postgres", "postgresql"), ("mysql", "mysql"))
        # An explicit URL scheme wins over the file-suffix heuristic, so e.g.
        # "mysql://host/sales.db" is not mistaken for a SQLite file.
        if "://" in cs:
            for prefix, db_type in prefixes:
                if cs.startswith(prefix):
                    return db_type
        if cs.endswith(SqlExecutor._SQLITE_SUFFIXES):
            return "sqlite"
        for prefix, db_type in prefixes:
            if cs.startswith(prefix):
                return db_type
        raise ValueError(f"Ne udalos opredelit tip BD: {cs}")