Spaces:
Paused
Paused
| """ | |
| PostgreSQL connection and operations module using SQLAlchemy | |
| """ | |
| import logging | |
| from typing import Any, Dict, List, Optional, Tuple | |
| from sqlalchemy import create_engine, text, inspect | |
| from sqlalchemy.orm import sessionmaker, Session | |
| from sqlalchemy.pool import NullPool | |
| from sqlalchemy.exc import SQLAlchemyError | |
| from app.config import DATABASE_URL | |
| from app.models.schema import Base, SchemaMetadata, create_dynamic_table_sql | |
logger = logging.getLogger(__name__)


class PostgreSQL:
    """PostgreSQL connection and operations handler using SQLAlchemy.

    Attributes:
        database_url: Connection string handed to ``create_engine``.
        engine: SQLAlchemy engine, or ``None`` while not connected.
        SessionLocal: Session factory bound to ``engine``, or ``None`` while
            not connected.
    """

    # Column names that _create_table_from_documents always declares itself;
    # document keys with these names must not be declared a second time.
    _RESERVED_COLUMNS = frozenset({"id", "created_at"})

    def __init__(self, database_url: str):
        """
        Initialize PostgreSQL connection

        Args:
            database_url: PostgreSQL connection string
        """
        self.database_url = database_url
        self.engine = None
        self.SessionLocal = None

    @staticmethod
    def _quote_identifier(name: Any) -> str:
        """Return ``name`` as a double-quoted SQL identifier.

        Embedded double quotes are doubled so that a table or column name
        cannot terminate the quoted identifier early.
        """
        return '"' + str(name).replace('"', '""') + '"'

    @classmethod
    def _build_insert_statement(
        cls, table_name: str, doc: Dict[str, Any]
    ) -> Tuple[str, Dict[str, Any]]:
        """Build an INSERT statement and its bind parameters for one document.

        Bind parameters are named ``p0``, ``p1``, ... instead of reusing the
        column names, because a column name is not necessarily a valid bind
        parameter name (e.g. it may contain spaces or dashes).

        Args:
            table_name: Name of the target table
            doc: Mapping of column name to value

        Returns:
            Tuple of (sql: str, params: dict)
        """
        table = cls._quote_identifier(table_name)
        if not doc:
            # "INSERT INTO t () VALUES ()" is not valid PostgreSQL.
            return f"INSERT INTO {table} DEFAULT VALUES", {}
        columns = ", ".join(cls._quote_identifier(key) for key in doc)
        params = {f"p{index}": value for index, value in enumerate(doc.values())}
        placeholders = ", ".join(f":{name}" for name in params)
        return f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", params

    @staticmethod
    def _dispose_quietly(engine: Any) -> None:
        """Dispose ``engine`` if there is one, logging instead of raising."""
        if engine is None:
            return
        try:
            engine.dispose()
        except Exception as e:
            logger.warning(f"Error disposing PostgreSQL engine: {e}")

    def _reset_after_failed_connect(self, new_engine: Any) -> None:
        """Drop all connection state so a later call can retry from scratch."""
        self._dispose_quietly(new_engine)
        if self.engine is not new_engine:
            self._dispose_quietly(self.engine)
        self.engine = None
        self.SessionLocal = None

    def connect(self) -> bool:
        """
        Establish PostgreSQL connection and create engine

        On failure both ``engine`` and ``SessionLocal`` are reset to ``None``,
        so the instance is never left half-initialised.

        Returns:
            bool: True if successful, False otherwise
        """
        new_engine = None
        try:
            new_engine = create_engine(
                self.database_url,
                poolclass=NullPool,  # No connection pooling for Supabase free tier
                echo=False,
                connect_args={"connect_timeout": 10},
            )
            # Test connection (read-only probe, so there is nothing to commit)
            with new_engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            # Create all tables from models
            Base.metadata.create_all(new_engine)
            session_factory = sessionmaker(
                autocommit=False, autoflush=False, bind=new_engine
            )
        except SQLAlchemyError as e:
            logger.error(f"Failed to connect to PostgreSQL: {e}")
            self._reset_after_failed_connect(new_engine)
            return False
        except Exception as e:
            logger.error(f"Unexpected error connecting to PostgreSQL: {e}")
            self._reset_after_failed_connect(new_engine)
            return False

        # Publish the new state only once every step above has succeeded.
        stale_engine = self.engine
        self.engine = new_engine
        self.SessionLocal = session_factory
        self._dispose_quietly(stale_engine)
        logger.info("Successfully connected to PostgreSQL")
        return True

    def disconnect(self) -> None:
        """Close PostgreSQL connection

        Disposes the engine if one exists; the ``engine`` and ``SessionLocal``
        attributes themselves are left in place.
        """
        if self.engine:
            self.engine.dispose()
            logger.info("PostgreSQL connection closed")

    def is_connected(self) -> bool:
        """
        Check if database is truly connected and ready

        Returns:
            bool: True if ready, False otherwise
        """
        if not self.engine or not self.SessionLocal:
            return False
        try:
            # Test the connection
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            return True
        except Exception as e:
            logger.warning(f"Database connection check failed: {e}")
            return False

    def get_session(self) -> "Session":
        """
        Get a new database session

        If the instance is not connected yet, one reconnection attempt is
        made first. The caller is responsible for closing the session.

        Returns:
            SQLAlchemy Session

        Raises:
            RuntimeError: If database not connected
        """
        if not self.SessionLocal or not self.engine:
            # Try to reconnect
            logger.warning("Session not available, attempting to reconnect...")
            if not self.connect():
                raise RuntimeError("Database not connected. Connection attempt failed.")
        return self.SessionLocal()

    def insert_documents(
        self, table_name: str, documents: List[Dict[str, Any]]
    ) -> Tuple[bool, str]:
        """
        Insert documents/rows into a table (creates table if not exists)

        All rows are inserted inside a single transaction.

        Args:
            table_name: Name of the table
            documents: List of dictionaries to insert (column_name: value)

        Returns:
            Tuple of (success: bool, error_message: str)
        """
        if not documents:
            return True, ""
        try:
            session = self.get_session()
            try:
                # Create table if it doesn't exist
                if not self.table_exists(table_name):
                    if not self._create_table_from_documents(table_name, documents):
                        raise RuntimeError(
                            f"table '{table_name}' does not exist and could not be created"
                        )
                # Insert documents as rows
                with session.begin():
                    for doc in documents:
                        insert_sql, params = self._build_insert_statement(table_name, doc)
                        session.execute(text(insert_sql), params)
            finally:
                # Release the connection on success and on failure alike.
                session.close()
            logger.info(f"Inserted {len(documents)} rows into table '{table_name}'")
            return True, ""
        except Exception as e:
            error_msg = f"Error inserting documents into table '{table_name}': {e}"
            logger.error(error_msg)
            return False, error_msg

    def find_documents(
        self,
        table_name: str,
        limit: int = 1000,
        offset: int = 0,
        where_clause: Optional[str] = None,
    ) -> Tuple[bool, List[Dict[str, Any]], str]:
        """
        Retrieve rows from a table

        Args:
            table_name: Name of the table
            limit: Maximum number of rows to return (non-negative)
            offset: Number of rows to skip (for pagination, non-negative)
            where_clause: Optional WHERE clause for filtering. It is inserted
                into the statement verbatim, so it must never be built from
                untrusted input.

        Returns:
            Tuple of (success: bool, data: list, error_message: str)
        """
        try:
            # Coerce before touching the database so that a non-numeric value
            # can never reach the SQL text.
            row_limit = int(limit)
            row_offset = int(offset)
            if row_limit < 0 or row_offset < 0:
                raise ValueError("limit and offset must be non-negative integers")

            session = self.get_session()
            try:
                if not self.table_exists(table_name):
                    return True, [], ""
                # Build query
                query_sql = f"SELECT * FROM {self._quote_identifier(table_name)}"
                if where_clause:
                    # NOTE(security): raw SQL fragment supplied by the caller.
                    query_sql += f" WHERE {where_clause}"
                query_sql += " LIMIT :row_limit OFFSET :row_offset"
                result = session.execute(
                    text(query_sql),
                    {"row_limit": row_limit, "row_offset": row_offset},
                )
                # Convert rows to dictionaries
                documents = [dict(row._mapping) for row in result.fetchall()]
            finally:
                # Also runs on the early "table missing" return above.
                session.close()
            logger.info(
                f"Retrieved {len(documents)} rows from table '{table_name}' with limit={limit}, offset={offset}"
            )
            return True, documents, ""
        except Exception as e:
            error_msg = f"Error querying table '{table_name}': {e}"
            logger.error(error_msg)
            return False, [], error_msg

    def save_schema(
        self, schema_name: str, schema_definition: Dict[str, Any], table_name: Optional[str] = None
    ) -> Tuple[bool, str]:
        """
        Save a schema mapping for reuse

        Args:
            schema_name: Name/identifier for the schema
            schema_definition: Dictionary containing the schema mapping
            table_name: Optional associated table name

        Returns:
            Tuple of (success: bool, error_message: str)
        """
        try:
            session = self.get_session()
            try:
                schema = SchemaMetadata(
                    name=schema_name, mapping=schema_definition, table_name=table_name
                )
                session.add(schema)
                session.commit()
            finally:
                # Closing also discards the transaction if the commit failed.
                session.close()
            logger.info(f"Schema '{schema_name}' saved successfully")
            return True, ""
        except Exception as e:
            error_msg = f"Error saving schema: {e}"
            logger.error(error_msg)
            return False, error_msg

    def get_schemas(self) -> Tuple[bool, List[Dict[str, Any]], str]:
        """
        Retrieve all stored schemas

        Returns:
            Tuple of (success: bool, schemas: list, error_message: str)
        """
        try:
            session = self.get_session()
            try:
                schemas = session.query(SchemaMetadata).all()
                # Convert while the session is still open.
                schemas_list = [schema.to_dict() for schema in schemas]
            finally:
                session.close()
            logger.info(f"Retrieved {len(schemas_list)} schemas")
            return True, schemas_list, ""
        except Exception as e:
            error_msg = f"Error retrieving schemas: {e}"
            logger.error(error_msg)
            return False, [], error_msg

    def table_exists(self, table_name: str) -> bool:
        """
        Check if a table exists

        Any error while inspecting the database is logged and reported as
        "table does not exist".

        Args:
            table_name: Name of the table

        Returns:
            bool: True if table exists, False otherwise
        """
        try:
            inspector = inspect(self.engine)
            tables = inspector.get_table_names()
            return table_name in tables
        except Exception as e:
            logger.error(f"Error checking table existence: {e}")
            return False

    def get_table_count(self, table_name: str) -> Tuple[bool, int, str]:
        """
        Get the number of rows in a table

        A table that does not exist is reported as having 0 rows.

        Args:
            table_name: Name of the table

        Returns:
            Tuple of (success: bool, count: int, error_message: str)
        """
        try:
            if not self.table_exists(table_name):
                return True, 0, ""
            session = self.get_session()
            try:
                count_sql = (
                    f"SELECT COUNT(*) as count FROM {self._quote_identifier(table_name)}"
                )
                count = session.execute(text(count_sql)).scalar()
            finally:
                session.close()
            return True, count, ""
        except Exception as e:
            error_msg = f"Error counting rows in table '{table_name}': {e}"
            logger.error(error_msg)
            return False, 0, error_msg

    def get_all_tables(self) -> Tuple[bool, List[str], str]:
        """
        Get list of all user-created tables (excluding system tables)

        Tables whose name starts with ``pg_`` are treated as system tables.

        Returns:
            Tuple of (success: bool, tables: list, error_message: str)
        """
        try:
            inspector = inspect(self.engine)
            all_tables = inspector.get_table_names()
            # Filter out system tables
            user_tables = [t for t in all_tables if not t.startswith("pg_")]
            return True, user_tables, ""
        except Exception as e:
            error_msg = f"Error listing tables: {e}"
            logger.error(error_msg)
            return False, [], error_msg

    def _create_table_from_documents(
        self, table_name: str, documents: List[Dict[str, Any]]
    ) -> bool:
        """
        Create a table dynamically based on document structure

        The table gets an ``id SERIAL PRIMARY KEY`` column, one TEXT column
        per distinct document key, and a ``created_at`` timestamp column.

        Args:
            table_name: Name of the table to create
            documents: Sample documents to infer schema from

        Returns:
            bool: True if successful, False otherwise
        """
        try:
            if not documents:
                return False
            # Get all columns from all documents
            all_columns = set()
            for doc in documents:
                all_columns.update(str(key) for key in doc)

            # Assemble the definitions as a list so that an empty set of
            # inferred columns cannot leave a dangling comma in the SQL.
            column_defs = ["id SERIAL PRIMARY KEY"]
            # Create all inferred columns as TEXT; keys named like the
            # built-in columns would otherwise be declared twice.
            column_defs.extend(
                f"{self._quote_identifier(col)} TEXT"
                for col in sorted(all_columns)
                if col not in self._RESERVED_COLUMNS
            )
            column_defs.append("created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP")
            create_table_sql = (
                f"CREATE TABLE IF NOT EXISTS {self._quote_identifier(table_name)} "
                f"({', '.join(column_defs)})"
            )

            session = self.get_session()
            try:
                session.execute(text(create_table_sql))
                session.commit()
            finally:
                session.close()
            logger.info(f"Created table '{table_name}' with columns: {list(all_columns)}")
            return True
        except Exception as e:
            logger.error(f"Error creating table '{table_name}': {e}")
            return False
# Global database instance, created lazily by get_db()
_db_instance: Optional[PostgreSQL] = None


def get_db() -> PostgreSQL:
    """
    Get or create the global PostgreSQL instance

    The instance is created on first use. Whenever it exists but is not
    fully connected (no engine or no session factory), one reconnection
    attempt is made. A failed attempt is only logged: the instance is
    returned regardless, so a later call can retry.

    Returns:
        PostgreSQL instance (possibly not connected)
    """
    global _db_instance
    # If instance doesn't exist, create and try to connect
    if _db_instance is None:
        _db_instance = PostgreSQL(DATABASE_URL)
        if not _db_instance.connect():
            logger.warning("Initial connection failed, but instance created. Will retry on next request.")
    # If instance exists but is not fully connected, try to reconnect.
    # Checking SessionLocal as well covers the case where the engine was
    # created but the connection attempt failed afterwards.
    elif _db_instance.engine is None or _db_instance.SessionLocal is None:
        logger.info("Attempting to reconnect to PostgreSQL...")
        if not _db_instance.connect():
            logger.warning("Reconnection attempt failed")
    return _db_instance