Spaces:
Paused
Paused
File size: 10,824 Bytes
aceb1b2 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 | """
SQL Database data source.
This module provides data loading from SQL databases using SQLAlchemy,
supporting PostgreSQL, MySQL, SQLite, and other databases.
"""
import logging
import re
from typing import Any, Dict, Iterator, List, Optional
from urllib.parse import quote_plus
from potato.data_sources.base import DataSource, SourceConfig
logger = logging.getLogger(__name__)
class DatabaseSource(DataSource):
"""
Data source for SQL databases.
Loads data from SQL databases using SQLAlchemy, supporting:
- PostgreSQL, MySQL, SQLite
- Custom SQL queries or simple table select
- Connection via connection string or individual parameters
- Incremental loading via OFFSET/LIMIT
Configuration with connection string:
type: database
connection_string: "${DATABASE_URL}"
query: "SELECT id, text, metadata FROM items WHERE status = 'pending'"
Configuration with individual parameters:
type: database
dialect: postgresql # postgresql, mysql, sqlite
host: "localhost"
port: 5432
database: "annotations"
username: "${DB_USER}"
password: "${DB_PASSWORD}"
table: "items" # Simple table select
id_column: "id"
text_column: "text"
Note: Requires SQLAlchemy and appropriate database driver:
pip install sqlalchemy psycopg2-binary # PostgreSQL
pip install sqlalchemy pymysql # MySQL
"""
# Check for optional dependencies
_HAS_SQLALCHEMY = None
@classmethod
def _check_dependencies(cls) -> bool:
"""Check if SQLAlchemy is available."""
if cls._HAS_SQLALCHEMY is None:
try:
import sqlalchemy
cls._HAS_SQLALCHEMY = True
except ImportError:
cls._HAS_SQLALCHEMY = False
return cls._HAS_SQLALCHEMY
# Pattern for safe SQL identifiers (table/column names)
# Allows: word chars, dots for schema.table, backticks/brackets for quoted identifiers
_SAFE_IDENTIFIER_RE = re.compile(r'\A[\w][\w.$]*\Z', re.ASCII)
@staticmethod
def _validate_identifier(name: str) -> str:
"""
Validate a SQL identifier (table or column name) against injection.
Only allows alphanumeric characters, underscores, dots (for schema.table),
and dollar signs. Rejects anything else to prevent SQL injection.
Raises:
ValueError: If the identifier contains unsafe characters
"""
if not name or not DatabaseSource._SAFE_IDENTIFIER_RE.match(name):
raise ValueError(
f"Invalid SQL identifier: '{name}'. "
f"Only alphanumeric characters, underscores, dots, and "
f"dollar signs are allowed."
)
return name
# Dialect to driver mapping
DIALECT_DRIVERS = {
'postgresql': 'postgresql+psycopg2',
'postgres': 'postgresql+psycopg2',
'mysql': 'mysql+pymysql',
'sqlite': 'sqlite',
'mssql': 'mssql+pyodbc',
}
def __init__(self, config: SourceConfig):
"""Initialize the database source."""
super().__init__(config)
# Connection options
self._connection_string = config.config.get("connection_string", "")
self._dialect = config.config.get("dialect", "")
self._host = config.config.get("host", "localhost")
self._port = config.config.get("port")
self._database = config.config.get("database", "")
self._username = config.config.get("username", "")
self._password = config.config.get("password", "")
# Query options
self._query = config.config.get("query", "")
self._table = config.config.get("table", "")
self._id_column = config.config.get("id_column", "id")
self._text_column = config.config.get("text_column", "text")
# Connection pooling options
self._pool_size = config.config.get("pool_size", 5)
self._pool_timeout = config.config.get("pool_timeout", 30)
self._engine = None
self._total_count: Optional[int] = None
def get_source_id(self) -> str:
"""Get unique identifier."""
return self._source_id
def validate_config(self) -> List[str]:
"""Validate source configuration."""
errors = []
# Must have connection string OR individual parameters
if not self._connection_string:
if not self._dialect:
errors.append(
"Either 'connection_string' or 'dialect' is required"
)
elif self._dialect not in self.DIALECT_DRIVERS:
errors.append(
f"Unknown dialect '{self._dialect}'. "
f"Supported: {', '.join(self.DIALECT_DRIVERS.keys())}"
)
if not self._database and self._dialect != 'sqlite':
errors.append("'database' is required")
# Must have query OR table
if not self._query and not self._table:
errors.append("Either 'query' or 'table' is required")
# Validate table name if provided (prevent SQL injection)
if self._table:
try:
self._validate_identifier(self._table)
except ValueError as e:
errors.append(str(e))
return errors
def is_available(self) -> bool:
"""Check if the source is available."""
if not self._check_dependencies():
logger.warning(
"SQLAlchemy not installed. "
"Install with: pip install sqlalchemy"
)
return False
return True
def _build_connection_string(self) -> str:
"""Build connection string from individual parameters."""
if self._connection_string:
return self._connection_string
driver = self.DIALECT_DRIVERS.get(self._dialect, self._dialect)
if self._dialect == 'sqlite':
return f"sqlite:///{self._database}"
# Build URL with credentials
if self._username:
userpass = self._username
if self._password:
userpass += f":{quote_plus(self._password)}"
userpass += "@"
else:
userpass = ""
host_port = self._host
if self._port:
host_port += f":{self._port}"
return f"{driver}://{userpass}{host_port}/{self._database}"
def _get_engine(self):
"""Get or create the SQLAlchemy engine."""
if self._engine:
return self._engine
from sqlalchemy import create_engine
connection_string = self._build_connection_string()
# Create engine with connection pooling
engine_kwargs = {}
if self._dialect != 'sqlite':
engine_kwargs = {
'pool_size': self._pool_size,
'pool_timeout': self._pool_timeout,
'pool_pre_ping': True, # Enable connection health checks
}
self._engine = create_engine(connection_string, **engine_kwargs)
return self._engine
def _build_query(self, offset: int = 0, limit: Optional[int] = None) -> str:
"""Build the SQL query with optional pagination."""
if self._query:
base_query = self._query.rstrip(';')
else:
# Validate table name to prevent SQL injection
safe_table = self._validate_identifier(self._table)
base_query = f"SELECT * FROM {safe_table}"
# Add pagination using validated integer values
if limit is not None or offset > 0:
if limit is not None:
base_query += f" LIMIT {int(limit)}"
if offset > 0:
base_query += f" OFFSET {int(offset)}"
return base_query
def _row_to_dict(self, row, columns: List[str]) -> Dict[str, Any]:
"""Convert a database row to a dictionary."""
item = {}
for i, col in enumerate(columns):
value = row[i]
# Handle special types
if hasattr(value, 'isoformat'): # datetime
value = value.isoformat()
elif hasattr(value, 'tobytes'): # memoryview/bytes
value = value.tobytes().decode('utf-8', errors='replace')
item[col] = value
return item
def read_items(
self,
start: int = 0,
count: Optional[int] = None
) -> Iterator[Dict[str, Any]]:
"""Read items from the database."""
from sqlalchemy import text
engine = self._get_engine()
query = self._build_query(offset=start, limit=count)
with engine.connect() as connection:
result = connection.execute(text(query))
# Get column names
columns = list(result.keys())
for row in result:
item = self._row_to_dict(row, columns)
yield item
def get_total_count(self) -> Optional[int]:
"""Get total number of items."""
if self._total_count is not None:
return self._total_count
from sqlalchemy import text
try:
engine = self._get_engine()
if self._query:
# Wrap query in count (query is admin-provided from YAML config)
count_query = f"SELECT COUNT(*) FROM ({self._query.rstrip(';')}) AS subquery"
else:
# Validate table name to prevent SQL injection
safe_table = self._validate_identifier(self._table)
count_query = f"SELECT COUNT(*) FROM {safe_table}"
with engine.connect() as connection:
result = connection.execute(text(count_query))
self._total_count = result.scalar()
return self._total_count
except Exception as e:
logger.error(f"Error getting count: {e}")
return None
def supports_partial_reading(self) -> bool:
"""Database sources support efficient partial reading via OFFSET/LIMIT."""
return True
def refresh(self) -> bool:
"""Refresh by clearing cached count."""
self._total_count = None
return True
def get_status(self) -> Dict[str, Any]:
"""Get source status."""
status = super().get_status()
status["dialect"] = self._dialect
status["database"] = self._database
status["table"] = self._table
status["has_custom_query"] = bool(self._query)
return status
def close(self) -> None:
"""Close the database connection."""
if self._engine:
self._engine.dispose()
self._engine = None
self._total_count = None
|