codebook / potato /data_sources /sources /database_source.py
davidjurgens's picture
Deploy: Potato — Codebook Annotation
aceb1b2 verified
Raw
History Blame Contribute Delete
10.8 kB
"""
SQL Database data source.
This module provides data loading from SQL databases using SQLAlchemy,
supporting PostgreSQL, MySQL, SQLite, and other databases.
"""
import logging
import re
from typing import Any, Dict, Iterator, List, Optional
from urllib.parse import quote_plus
from potato.data_sources.base import DataSource, SourceConfig
# Module-level logger, registered under this module's dotted import name.
logger = logging.getLogger(name=__name__)
class DatabaseSource(DataSource):
    """
    Data source for SQL databases.

    Loads data from SQL databases using SQLAlchemy, supporting:
    - PostgreSQL, MySQL, SQLite
    - Custom SQL queries or simple table select
    - Connection via connection string or individual parameters
    - Incremental loading via OFFSET/LIMIT

    Configuration with connection string:
        type: database
        connection_string: "${DATABASE_URL}"
        query: "SELECT id, text, metadata FROM items WHERE status = 'pending'"

    Configuration with individual parameters:
        type: database
        dialect: postgresql  # postgresql, mysql, sqlite
        host: "localhost"
        port: 5432
        database: "annotations"
        username: "${DB_USER}"
        password: "${DB_PASSWORD}"
        table: "items"  # Simple table select
        id_column: "id"
        text_column: "text"

    Note: Requires SQLAlchemy and appropriate database driver:
        pip install sqlalchemy psycopg2-binary  # PostgreSQL
        pip install sqlalchemy pymysql  # MySQL

    NOTE(review): pagination appends LIMIT/OFFSET to the query text, which
    SQL Server ('mssql') does not accept. The built-in table select has no
    ORDER BY, so row order across pages is whatever the database returns;
    use a custom ``query`` with ORDER BY if stable paging matters.
    """

    # Cached result of the SQLAlchemy import probe (None until first checked).
    _HAS_SQLALCHEMY: Optional[bool] = None

    @classmethod
    def _check_dependencies(cls) -> bool:
        """Return True if SQLAlchemy is importable (result cached on the class)."""
        if cls._HAS_SQLALCHEMY is None:
            try:
                import sqlalchemy  # noqa: F401  (import probe only)
                cls._HAS_SQLALCHEMY = True
            except ImportError:
                cls._HAS_SQLALCHEMY = False
        return cls._HAS_SQLALCHEMY

    # Pattern for safe, unquoted SQL identifiers (table/column names): one or
    # more dot-separated, non-empty parts (for schema.table), each made of
    # ASCII word characters and dollar signs. Quoting characters such as
    # backticks, brackets or double quotes are intentionally NOT accepted.
    _SAFE_IDENTIFIER_RE = re.compile(r'\A\w[\w$]*(?:\.\w[\w$]*)*\Z', re.ASCII)

    @staticmethod
    def _validate_identifier(name: str) -> str:
        """
        Validate a SQL identifier (table or column name) against injection.

        Only allows ASCII alphanumeric characters, underscores, dollar signs,
        and dots separating non-empty parts (for schema.table). Anything else
        is rejected so the name can be interpolated into SQL text safely.

        Args:
            name: The identifier to check.

        Returns:
            The identifier, unchanged.

        Raises:
            ValueError: If the identifier is empty, not a string, or contains
                unsafe characters.
        """
        # A non-string value (e.g. an integer parsed from YAML) would make
        # re.match raise TypeError, which callers catching ValueError miss.
        if (
            not isinstance(name, str)
            or not name
            or not DatabaseSource._SAFE_IDENTIFIER_RE.match(name)
        ):
            raise ValueError(
                f"Invalid SQL identifier: '{name}'. "
                f"Only alphanumeric characters, underscores, dots, and "
                f"dollar signs are allowed."
            )
        return name

    # Dialect to driver mapping
    DIALECT_DRIVERS: Dict[str, str] = {
        'postgresql': 'postgresql+psycopg2',
        'postgres': 'postgresql+psycopg2',
        'mysql': 'mysql+pymysql',
        'sqlite': 'sqlite',
        'mssql': 'mssql+pyodbc',
    }

    # LIMIT used when only an OFFSET is requested: SQLite and MySQL reject
    # OFFSET without LIMIT, and 2**63 - 1 fits the LIMIT range of SQLite,
    # MySQL and PostgreSQL, so it acts as "no upper bound" portably.
    _UNBOUNDED_LIMIT = 2 ** 63 - 1

    def __init__(self, config: SourceConfig):
        """Initialize the database source from ``config.config`` options."""
        super().__init__(config)

        # Connection options
        self._connection_string = config.config.get("connection_string", "")
        self._dialect = config.config.get("dialect", "")
        self._host = config.config.get("host", "localhost")
        self._port = config.config.get("port")
        self._database = config.config.get("database", "")
        self._username = config.config.get("username", "")
        self._password = config.config.get("password", "")

        # Query options
        self._query = config.config.get("query", "")
        self._table = config.config.get("table", "")
        self._id_column = config.config.get("id_column", "id")
        self._text_column = config.config.get("text_column", "text")

        # Connection pooling options
        self._pool_size = config.config.get("pool_size", 5)
        self._pool_timeout = config.config.get("pool_timeout", 30)

        # Lazily created engine and cached row count
        self._engine = None
        self._total_count: Optional[int] = None

    def get_source_id(self) -> str:
        """Get unique identifier."""
        return self._source_id

    def validate_config(self) -> List[str]:
        """
        Validate source configuration.

        Returns:
            A list of human-readable error messages; empty when valid.
        """
        errors = []

        # Must have connection string OR individual parameters
        if not self._connection_string:
            if not self._dialect:
                errors.append(
                    "Either 'connection_string' or 'dialect' is required"
                )
            elif self._dialect not in self.DIALECT_DRIVERS:
                errors.append(
                    f"Unknown dialect '{self._dialect}'. "
                    f"Supported: {', '.join(self.DIALECT_DRIVERS.keys())}"
                )
            if not self._database and self._dialect != 'sqlite':
                errors.append("'database' is required")

        # Must have query OR table
        if not self._query and not self._table:
            errors.append("Either 'query' or 'table' is required")

        # Validate table name if provided (prevent SQL injection)
        if self._table:
            try:
                self._validate_identifier(self._table)
            except ValueError as e:
                errors.append(str(e))

        return errors

    def is_available(self) -> bool:
        """Check if the source is available (SQLAlchemy importable)."""
        if not self._check_dependencies():
            logger.warning(
                "SQLAlchemy not installed. "
                "Install with: pip install sqlalchemy"
            )
            return False
        return True

    def _build_connection_string(self) -> str:
        """
        Build the SQLAlchemy URL for this source.

        An explicit ``connection_string`` is returned unchanged. Otherwise the
        URL is assembled from dialect/host/port/database; the username and
        password are percent-encoded so characters such as '@', ':' or '/'
        cannot corrupt the URL structure.
        """
        if self._connection_string:
            return self._connection_string

        if self._dialect == 'sqlite':
            return f"sqlite:///{self._database}"

        # Unknown dialects are passed through verbatim as the URL scheme.
        driver = self.DIALECT_DRIVERS.get(self._dialect, self._dialect)

        # Build URL with credentials (str() guards against numeric YAML values)
        userpass = ""
        if self._username:
            userpass = quote_plus(str(self._username))
            if self._password:
                userpass += f":{quote_plus(str(self._password))}"
            userpass += "@"

        host_port = self._host
        if self._port:
            host_port += f":{self._port}"

        return f"{driver}://{userpass}{host_port}/{self._database}"

    @staticmethod
    def _is_sqlite_url(connection_string: str) -> bool:
        """Return True if the URL scheme's backend (before any '+driver') is sqlite."""
        scheme = connection_string.strip().split(':', 1)[0]
        return scheme.split('+', 1)[0].lower() == 'sqlite'

    def _get_engine(self):
        """
        Return the cached SQLAlchemy engine, creating it on first use.

        Pool sizing options are only passed for non-SQLite backends.
        NOTE(review): depending on the SQLAlchemy version, SQLite may use a
        pool class that rejects ``pool_size``/``pool_timeout``.
        """
        if self._engine is not None:
            return self._engine

        from sqlalchemy import create_engine

        connection_string = self._build_connection_string()

        # Decide from the URL actually used: an explicit connection_string
        # overrides `dialect`, and may point at SQLite with no dialect set.
        engine_kwargs: Dict[str, Any] = {}
        if not self._is_sqlite_url(connection_string):
            engine_kwargs = {
                'pool_size': self._pool_size,
                'pool_timeout': self._pool_timeout,
                'pool_pre_ping': True,  # Enable connection health checks
            }

        self._engine = create_engine(connection_string, **engine_kwargs)
        return self._engine

    def _base_query(self) -> str:
        """
        Return the unpaginated SELECT (custom query or validated table select).

        Trailing semicolons and whitespace are removed from a custom query so
        it can be extended with LIMIT/OFFSET or wrapped as a subquery; YAML
        block scalars commonly leave a newline after the final ';'.
        """
        if self._query:
            return self._query.rstrip('; \t\r\n\f\v').strip()
        # Validate table name to prevent SQL injection
        safe_table = self._validate_identifier(self._table)
        return f"SELECT * FROM {safe_table}"

    def _build_query(self, offset: int = 0, limit: Optional[int] = None) -> str:
        """
        Build the SQL query with optional pagination.

        Args:
            offset: Number of leading rows to skip (ignored when <= 0).
            limit: Maximum number of rows to return, or None for no limit.

        Returns:
            The SQL text to execute.
        """
        query = self._base_query()

        # Pagination uses int() so only integer literals reach the SQL text.
        if limit is not None:
            query += f" LIMIT {int(limit)}"
        elif offset > 0:
            # OFFSET alone is a syntax error in SQLite and MySQL.
            query += f" LIMIT {self._UNBOUNDED_LIMIT}"
        if offset > 0:
            query += f" OFFSET {int(offset)}"

        return query

    def _row_to_dict(self, row, columns: List[str]) -> Dict[str, Any]:
        """
        Convert a database row to a dictionary keyed by column name.

        Values with ``isoformat`` (datetime/date/time) become ISO strings and
        binary values (bytes, bytearray, memoryview) are decoded as UTF-8
        with replacement characters; everything else is passed through.
        """
        item = {}
        for col, value in zip(columns, row):
            if hasattr(value, 'isoformat'):  # datetime / date / time
                value = value.isoformat()
            elif isinstance(value, (bytes, bytearray)):
                value = value.decode('utf-8', errors='replace')
            elif hasattr(value, 'tobytes'):  # memoryview
                value = value.tobytes().decode('utf-8', errors='replace')
            item[col] = value
        return item

    def read_items(
        self,
        start: int = 0,
        count: Optional[int] = None
    ) -> Iterator[Dict[str, Any]]:
        """
        Read items from the database.

        Args:
            start: Number of rows to skip.
            count: Maximum number of rows to yield, or None for all remaining.

        Yields:
            One dictionary per row (see ``_row_to_dict``).
        """
        from sqlalchemy import text

        engine = self._get_engine()
        query = self._build_query(offset=start, limit=count)

        # The connection stays open until the generator is exhausted or closed.
        with engine.connect() as connection:
            result = connection.execute(text(query))
            columns = list(result.keys())
            for row in result:
                yield self._row_to_dict(row, columns)

    def get_total_count(self) -> Optional[int]:
        """
        Get total number of items.

        The count is cached until ``refresh()`` or ``close()``. Errors are
        logged and reported as None (best effort).
        """
        if self._total_count is not None:
            return self._total_count

        from sqlalchemy import text

        try:
            engine = self._get_engine()

            if self._query:
                # Wrap query in count (query is admin-provided from YAML config)
                count_query = (
                    f"SELECT COUNT(*) FROM ({self._base_query()}) AS subquery"
                )
            else:
                # Validate table name to prevent SQL injection
                safe_table = self._validate_identifier(self._table)
                count_query = f"SELECT COUNT(*) FROM {safe_table}"

            with engine.connect() as connection:
                result = connection.execute(text(count_query))
                self._total_count = result.scalar()

            return self._total_count

        except Exception as e:
            logger.error("Error getting count: %s", e)
            return None

    def supports_partial_reading(self) -> bool:
        """Database sources support efficient partial reading via OFFSET/LIMIT."""
        return True

    def refresh(self) -> bool:
        """Refresh by clearing cached count."""
        self._total_count = None
        return True

    def get_status(self) -> Dict[str, Any]:
        """Get source status, extended with database-specific fields."""
        status = super().get_status()
        status["dialect"] = self._dialect
        status["database"] = self._database
        status["table"] = self._table
        status["has_custom_query"] = bool(self._query)
        return status

    def close(self) -> None:
        """Dispose of the engine (if any) and clear the cached count."""
        if self._engine is not None:
            self._engine.dispose()
            self._engine = None
        self._total_count = None