Spaces:
Paused
Paused
| """ | |
| SQL Database data source. | |
| This module provides data loading from SQL databases using SQLAlchemy, | |
| supporting PostgreSQL, MySQL, SQLite, and other databases. | |
| """ | |
| import logging | |
| import re | |
| from typing import Any, Dict, Iterator, List, Optional | |
| from urllib.parse import quote_plus | |
| from potato.data_sources.base import DataSource, SourceConfig | |
| logger = logging.getLogger(__name__) | |
class DatabaseSource(DataSource):
    """
    Data source for SQL databases.

    Loads data from SQL databases using SQLAlchemy, supporting:
    - PostgreSQL, MySQL, SQLite
    - Custom SQL queries or simple table select
    - Connection via connection string or individual parameters
    - Incremental loading via OFFSET/LIMIT

    Configuration with connection string:
        type: database
        connection_string: "${DATABASE_URL}"
        query: "SELECT id, text, metadata FROM items WHERE status = 'pending'"

    Configuration with individual parameters:
        type: database
        dialect: postgresql  # postgresql, mysql, sqlite
        host: "localhost"
        port: 5432
        database: "annotations"
        username: "${DB_USER}"
        password: "${DB_PASSWORD}"
        table: "items"  # Simple table select
        id_column: "id"
        text_column: "text"

    Note: Requires SQLAlchemy and appropriate database driver:
        pip install sqlalchemy psycopg2-binary  # PostgreSQL
        pip install sqlalchemy pymysql  # MySQL

    Note: Pagination appends LIMIT/OFFSET to the query as-is. No ORDER BY is
    added, so a custom ``query`` should specify its own ordering if pages
    must be stable across calls.
    """

    # Cached result of the SQLAlchemy import probe; None means "not probed yet".
    _HAS_SQLALCHEMY: Optional[bool] = None

    @classmethod
    def _check_dependencies(cls) -> bool:
        """
        Check if SQLAlchemy is available.

        The import is attempted once and the outcome is cached on the class.

        Returns:
            True if ``import sqlalchemy`` succeeds, False otherwise.
        """
        if cls._HAS_SQLALCHEMY is None:
            try:
                import sqlalchemy  # noqa: F401  (imported only to probe availability)
                cls._HAS_SQLALCHEMY = True
            except ImportError:
                cls._HAS_SQLALCHEMY = False
        return cls._HAS_SQLALCHEMY

    # Pattern for safe SQL identifiers (table/column names).
    # Allows ASCII word characters, dots (for schema.table) and dollar signs.
    # Quoting characters such as backticks or brackets are NOT accepted.
    _SAFE_IDENTIFIER_RE = re.compile(r'\A[\w][\w.$]*\Z', re.ASCII)

    @staticmethod
    def _validate_identifier(name: str) -> str:
        """
        Validate a SQL identifier (table or column name) against injection.

        Only allows alphanumeric characters, underscores, dots (for
        schema.table), and dollar signs. Rejects anything else to prevent
        SQL injection.

        Args:
            name: The identifier to check.

        Returns:
            The identifier unchanged, if it is valid.

        Raises:
            ValueError: If the identifier is empty, is not a string, or
                contains unsafe characters.
        """
        # The isinstance check keeps non-string config values (e.g. a number
        # parsed from YAML) on the ValueError path instead of raising a
        # TypeError from the regex engine.
        if (
            not name
            or not isinstance(name, str)
            or not DatabaseSource._SAFE_IDENTIFIER_RE.match(name)
        ):
            raise ValueError(
                f"Invalid SQL identifier: '{name}'. "
                f"Only alphanumeric characters, underscores, dots, and "
                f"dollar signs are allowed."
            )
        return name

    # Dialect to driver mapping
    DIALECT_DRIVERS = {
        'postgresql': 'postgresql+psycopg2',
        'postgres': 'postgresql+psycopg2',
        'mysql': 'mysql+pymysql',
        'sqlite': 'sqlite',
        'mssql': 'mssql+pyodbc',
    }

    # MySQL/MariaDB have no "unbounded" LIMIT; the MySQL manual recommends a
    # very large row count when only an offset is wanted.
    _MYSQL_MAX_ROWS = 18446744073709551615

    def __init__(self, config: SourceConfig):
        """
        Initialize the database source.

        Args:
            config: Source configuration; options are read from
                ``config.config`` with the defaults shown below.
        """
        super().__init__(config)
        options = config.config

        # Connection options
        self._connection_string = options.get("connection_string", "")
        self._dialect = options.get("dialect", "")
        self._host = options.get("host", "localhost")
        self._port = options.get("port")
        self._database = options.get("database", "")
        self._username = options.get("username", "")
        self._password = options.get("password", "")

        # Query options
        self._query = options.get("query", "")
        self._table = options.get("table", "")
        self._id_column = options.get("id_column", "id")
        self._text_column = options.get("text_column", "text")

        # Connection pooling options
        self._pool_size = options.get("pool_size", 5)
        self._pool_timeout = options.get("pool_timeout", 30)

        # Engine is created lazily by _get_engine(); count is cached by
        # get_total_count().
        self._engine = None
        self._total_count: Optional[int] = None

    def get_source_id(self) -> str:
        """Get unique identifier."""
        return self._source_id

    def validate_config(self) -> List[str]:
        """
        Validate source configuration.

        Returns:
            A list of human-readable error messages; empty when the
            configuration is valid.
        """
        errors = []

        # Must have connection string OR individual parameters
        if not self._connection_string:
            if not self._dialect:
                errors.append(
                    "Either 'connection_string' or 'dialect' is required"
                )
            elif self._dialect not in self.DIALECT_DRIVERS:
                errors.append(
                    f"Unknown dialect '{self._dialect}'. "
                    f"Supported: {', '.join(self.DIALECT_DRIVERS.keys())}"
                )
            if not self._database and self._dialect != 'sqlite':
                errors.append("'database' is required")

        # Must have query OR table
        if not self._query and not self._table:
            errors.append("Either 'query' or 'table' is required")

        # Validate table name if provided (prevent SQL injection)
        if self._table:
            try:
                self._validate_identifier(self._table)
            except ValueError as e:
                errors.append(str(e))

        return errors

    def is_available(self) -> bool:
        """
        Check if the source is available.

        Returns:
            False (with a logged warning) when SQLAlchemy cannot be
            imported, True otherwise.
        """
        if not self._check_dependencies():
            logger.warning(
                "SQLAlchemy not installed. "
                "Install with: pip install sqlalchemy"
            )
            return False
        return True

    def _dialect_name(self) -> str:
        """
        Return the backend name that will actually be connected to.

        An explicit connection string takes precedence over ``dialect``
        (mirroring _build_connection_string), so its URL scheme is used
        when present; any ``+driver`` suffix is dropped.
        """
        if self._connection_string:
            scheme = str(self._connection_string).split(":", 1)[0]
            return scheme.split("+", 1)[0].strip().lower()
        return self._dialect

    def _build_connection_string(self) -> str:
        """
        Build connection string from individual parameters.

        Returns:
            The configured ``connection_string`` if one was given,
            otherwise a URL assembled from dialect, credentials, host,
            port and database.
        """
        if self._connection_string:
            return self._connection_string

        if self._dialect == 'sqlite':
            return f"sqlite:///{self._database}"

        driver = self.DIALECT_DRIVERS.get(self._dialect, self._dialect)

        # Build URL with credentials. Both parts are URL-quoted so that
        # characters such as '@', ':' or '/' cannot be mistaken for URL
        # delimiters; str() tolerates numeric values coming from YAML.
        credentials = ""
        if self._username:
            credentials = quote_plus(str(self._username))
            if self._password:
                credentials += f":{quote_plus(str(self._password))}"
            credentials += "@"

        host_port = self._host
        if self._port:
            host_port += f":{self._port}"

        return f"{driver}://{credentials}{host_port}/{self._database}"

    def _get_engine(self):
        """
        Get or create the SQLAlchemy engine.

        The engine is created on first use and reused afterwards. Pooling
        options are only passed for non-SQLite backends.
        """
        if self._engine is not None:
            return self._engine

        from sqlalchemy import create_engine

        connection_string = self._build_connection_string()

        # Create engine with connection pooling. SQLite is detected from the
        # effective backend (including a sqlite:// connection string), since
        # not every SQLite pool class accepts these arguments.
        engine_kwargs: Dict[str, Any] = {}
        if self._dialect_name() != 'sqlite':
            engine_kwargs = {
                'pool_size': self._pool_size,
                'pool_timeout': self._pool_timeout,
                'pool_pre_ping': True,  # Enable connection health checks
            }

        self._engine = create_engine(connection_string, **engine_kwargs)
        return self._engine

    def _base_query(self) -> str:
        """
        Return the un-paginated SELECT statement.

        Uses the configured custom query (with trailing semicolons and
        whitespace removed so clauses can be appended) or a
        ``SELECT *`` over the validated table name.

        Raises:
            ValueError: If no query is set and the table name is invalid.
        """
        if self._query:
            # Strip trailing ';' even when followed by whitespace/newlines,
            # as is common for multi-line YAML strings.
            return self._query.strip().rstrip("; \t\r\n")

        # Validate table name to prevent SQL injection
        safe_table = self._validate_identifier(self._table)
        return f"SELECT * FROM {safe_table}"

    def _build_query(self, offset: int = 0, limit: Optional[int] = None) -> str:
        """
        Build the SQL query with optional pagination.

        Args:
            offset: Number of leading rows to skip; values <= 0 add no
                OFFSET clause.
            limit: Maximum number of rows, or None for no limit.

        Returns:
            The SQL text to execute.

        Raises:
            ValueError: If no query is set and the table name is invalid.
        """
        sql = self._base_query()

        # Coerce to int so only integers are ever interpolated into the SQL.
        offset = int(offset)

        if limit is not None:
            sql += f" LIMIT {int(limit)}"
        elif offset > 0:
            # SQLite and MySQL/MariaDB reject OFFSET without a LIMIT, so
            # supply that backend's "no upper bound" value.
            backend = self._dialect_name()
            if backend == 'sqlite':
                sql += " LIMIT -1"
            elif backend in ('mysql', 'mariadb'):
                sql += f" LIMIT {self._MYSQL_MAX_ROWS}"

        if offset > 0:
            sql += f" OFFSET {offset}"

        # NOTE(review): SQL Server (mssql) does not accept LIMIT/OFFSET in
        # this form; paginated reads against mssql likely need a custom query.
        return sql

    def _row_to_dict(self, row, columns: List[str]) -> Dict[str, Any]:
        """
        Convert a database row to a dictionary.

        Args:
            row: A positional row, with values in the same order as
                ``columns``.
            columns: Column names for the row.

        Returns:
            A dict mapping each column name to its value. Values exposing
            ``isoformat`` are converted to ISO strings; bytes-like values
            are decoded as UTF-8 with replacement of invalid sequences.
        """
        item = {}
        for col, value in zip(columns, row):
            # Handle special types
            if hasattr(value, 'isoformat'):  # datetime
                value = value.isoformat()
            elif isinstance(value, (bytes, bytearray)):
                # bytes/bytearray have no tobytes(), so decode them directly.
                value = value.decode('utf-8', errors='replace')
            elif hasattr(value, 'tobytes'):  # memoryview
                value = value.tobytes().decode('utf-8', errors='replace')
            item[col] = value
        return item

    def read_items(
        self,
        start: int = 0,
        count: Optional[int] = None
    ) -> Iterator[Dict[str, Any]]:
        """
        Read items from the database.

        Args:
            start: Number of leading rows to skip.
            count: Maximum number of rows to yield, or None for all.

        Yields:
            One dict per result row, as produced by _row_to_dict.
        """
        from sqlalchemy import text

        engine = self._get_engine()
        query = self._build_query(offset=start, limit=count)

        # The connection stays open while the generator is being consumed
        # and is released when the with-block exits.
        with engine.connect() as connection:
            result = connection.execute(text(query))

            # Get column names
            columns = list(result.keys())

            for row in result:
                yield self._row_to_dict(row, columns)

    def get_total_count(self) -> Optional[int]:
        """
        Get total number of items.

        The count is cached after the first successful lookup; refresh()
        and close() clear the cache.

        Returns:
            The row count, or None if it could not be determined (the
            error is logged).
        """
        if self._total_count is not None:
            return self._total_count

        from sqlalchemy import text

        try:
            engine = self._get_engine()

            base_query = self._base_query()
            if self._query:
                # Wrap query in count (query is admin-provided from YAML config)
                count_query = f"SELECT COUNT(*) FROM ({base_query}) AS subquery"
            else:
                # Table name was validated by _base_query()
                count_query = base_query.replace(
                    "SELECT * FROM", "SELECT COUNT(*) FROM", 1
                )

            with engine.connect() as connection:
                result = connection.execute(text(count_query))
                self._total_count = result.scalar()

            return self._total_count
        except Exception as e:
            # Deliberately best-effort: callers treat None as "unknown".
            logger.error("Error getting count: %s", e)
            return None

    def supports_partial_reading(self) -> bool:
        """Database sources support efficient partial reading via OFFSET/LIMIT."""
        return True

    def refresh(self) -> bool:
        """Refresh by clearing cached count."""
        self._total_count = None
        return True

    def get_status(self) -> Dict[str, Any]:
        """
        Get source status.

        Returns:
            The base-class status extended with dialect, database, table
            and whether a custom query is configured.
        """
        status = super().get_status()
        status["dialect"] = self._dialect
        status["database"] = self._database
        status["table"] = self._table
        status["has_custom_query"] = bool(self._query)
        return status

    def close(self) -> None:
        """Close the database connection and clear the cached count."""
        if self._engine is not None:
            self._engine.dispose()
            self._engine = None
        self._total_count = None