# Source: ai-csv-import / app/db/postgresql.py
# Last change: "Update app/db/postgresql.py" by Hamza4100 (commit 28dadd2, verified)
"""
PostgreSQL connection and operations module using SQLAlchemy
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import create_engine, text, inspect
from sqlalchemy.orm import sessionmaker, Session
from sqlalchemy.pool import NullPool
from sqlalchemy.exc import SQLAlchemyError
from app.config import DATABASE_URL
from app.models.schema import Base, SchemaMetadata, create_dynamic_table_sql
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class PostgreSQL:
    """PostgreSQL connection and operations handler using SQLAlchemy.

    The handler owns one SQLAlchemy engine and one session factory. Most
    operations report failure through their return value (a success flag
    plus an error message) rather than by raising.
    """

    def __init__(self, database_url: str):
        """
        Initialize the handler without opening a connection.

        Args:
            database_url: PostgreSQL connection string
        """
        self.database_url = database_url
        self.engine = None
        self.SessionLocal = None

    # ------------------------------------------------------------------
    # SQL building helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier.

        Embedded double quotes are doubled so the name cannot terminate the
        quoted identifier early.
        """
        return '"' + str(name).replace('"', '""') + '"'

    @classmethod
    def _build_insert(
        cls, table_name: str, doc: Dict[str, Any]
    ) -> Tuple[str, Dict[str, Any]]:
        """Build a parameterized INSERT statement for a single document.

        Bind parameters are named ``p0``, ``p1``, ... instead of after the
        columns, because column names (e.g. CSV headers containing spaces or
        hyphens) are not necessarily valid bind-parameter names.

        Returns:
            Tuple of (sql: str, params: dict) ready for ``text(sql)``.
        """
        table = cls._quote_identifier(table_name)
        if not doc:
            # "INSERT INTO t () VALUES ()" is not valid PostgreSQL.
            return f"INSERT INTO {table} DEFAULT VALUES", {}
        columns = ", ".join(cls._quote_identifier(key) for key in doc)
        params = {f"p{index}": value for index, value in enumerate(doc.values())}
        placeholders = ", ".join(f":{name}" for name in params)
        return f"INSERT INTO {table} ({columns}) VALUES ({placeholders})", params

    @classmethod
    def _build_create_table(
        cls, table_name: str, documents: List[Dict[str, Any]]
    ) -> Tuple[str, List[str]]:
        """Build the CREATE TABLE statement for the union of document keys.

        Every document key becomes a TEXT column; ``id`` and ``created_at``
        are added around them.

        Returns:
            Tuple of (sql: str, column_names: sorted list of document keys).
        """
        all_columns = set()
        for doc in documents:
            all_columns.update(doc.keys())
        column_names = sorted(all_columns)
        # NOTE(review): a document key named "id" or "created_at" collides
        # with the built-in columns below and makes the statement fail.
        definitions = ["id SERIAL PRIMARY KEY"]
        definitions.extend(f"{cls._quote_identifier(col)} TEXT" for col in column_names)
        definitions.append("created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP")
        sql = (
            f"CREATE TABLE IF NOT EXISTS {cls._quote_identifier(table_name)} "
            f"({', '.join(definitions)})"
        )
        return sql, column_names

    # ------------------------------------------------------------------
    # Connection management
    # ------------------------------------------------------------------
    def connect(self) -> bool:
        """
        Establish PostgreSQL connection and create engine

        The engine and session factory are only stored on the instance once
        the connection test and model table creation have succeeded; after a
        failure both attributes are None.

        Returns:
            bool: True if successful, False otherwise
        """
        engine = None
        try:
            engine = create_engine(
                self.database_url,
                poolclass=NullPool,  # No connection pooling for Supabase free tier
                echo=False,
                connect_args={"connect_timeout": 10},
            )
            # Test connection
            with engine.connect() as conn:
                conn.execute(text("SELECT 1"))
                conn.commit()
            # Create all tables from models
            Base.metadata.create_all(engine)
            session_factory = sessionmaker(autocommit=False, autoflush=False, bind=engine)
        except SQLAlchemyError as e:
            logger.error(f"Failed to connect to PostgreSQL: {e}")
        except Exception as e:
            logger.error(f"Unexpected error connecting to PostgreSQL: {e}")
        else:
            self.engine = engine
            self.SessionLocal = session_factory
            logger.info("Successfully connected to PostgreSQL")
            return True

        # Failure path: do not leave a half-initialised engine behind.
        if engine is not None:
            engine.dispose()
        self.engine = None
        self.SessionLocal = None
        return False

    def disconnect(self) -> None:
        """Close PostgreSQL connection and forget the engine and session factory"""
        if self.engine:
            self.engine.dispose()
            logger.info("PostgreSQL connection closed")
        # Reset so is_connected() reports False and later use reconnects.
        self.engine = None
        self.SessionLocal = None

    def is_connected(self) -> bool:
        """
        Check if database is truly connected and ready

        Runs a ``SELECT 1`` round trip when an engine and session factory
        are present.

        Returns:
            bool: True if ready, False otherwise
        """
        if not self.engine or not self.SessionLocal:
            return False
        try:
            # Test the connection
            with self.engine.connect() as conn:
                conn.execute(text("SELECT 1"))
            return True
        except Exception as e:
            logger.warning(f"Database connection check failed: {e}")
            return False

    def get_session(self) -> "Session":
        """
        Get a new database session, reconnecting first if necessary

        The caller is responsible for closing the returned session.

        Returns:
            SQLAlchemy Session

        Raises:
            RuntimeError: If database not connected
        """
        if not self.SessionLocal or not self.engine:
            # Try to reconnect
            logger.warning("Session not available, attempting to reconnect...")
            if not self.connect():
                raise RuntimeError("Database not connected. Connection attempt failed.")
        return self.SessionLocal()

    # ------------------------------------------------------------------
    # Row operations
    # ------------------------------------------------------------------
    def insert_documents(
        self, table_name: str, documents: List[Dict[str, Any]]
    ) -> Tuple[bool, str]:
        """
        Insert documents/rows into a table (creates table if not exists)

        All rows are inserted in one transaction: either every row is
        stored or none is.

        Args:
            table_name: Name of the table
            documents: List of dictionaries to insert (column_name: value)

        Returns:
            Tuple of (success: bool, error_message: str)
        """
        if not documents:
            return True, ""
        try:
            session = self.get_session()
            try:
                # Create table if it doesn't exist
                if not self.table_exists(table_name):
                    if not self._create_table_from_documents(table_name, documents):
                        error_msg = (
                            f"Error inserting documents into table '{table_name}': "
                            "table could not be created"
                        )
                        logger.error(error_msg)
                        return False, error_msg
                # Insert documents as rows; begin() commits on success and
                # rolls back if any row fails.
                with session.begin():
                    for doc in documents:
                        insert_sql, params = self._build_insert(table_name, doc)
                        session.execute(text(insert_sql), params)
                logger.info(f"Inserted {len(documents)} rows into table '{table_name}'")
                return True, ""
            finally:
                session.close()
        except Exception as e:
            error_msg = f"Error inserting documents into table '{table_name}': {e}"
            logger.error(error_msg)
            return False, error_msg

    def find_documents(
        self,
        table_name: str,
        limit: int = 1000,
        offset: int = 0,
        where_clause: Optional[str] = None,
    ) -> Tuple[bool, List[Dict[str, Any]], str]:
        """
        Retrieve rows from a table

        A table that does not exist yields an empty, successful result.

        Args:
            table_name: Name of the table
            limit: Maximum number of rows to return
            offset: Number of rows to skip (for pagination)
            where_clause: Optional WHERE clause for filtering

        Returns:
            Tuple of (success: bool, data: list, error_message: str)
        """
        try:
            session = self.get_session()
            try:
                if not self.table_exists(table_name):
                    return True, [], ""
                # Build query
                query_sql = f"SELECT * FROM {self._quote_identifier(table_name)}"
                if where_clause:
                    # SECURITY: where_clause is spliced into the statement
                    # verbatim. Callers must never pass untrusted input here.
                    query_sql += f" WHERE {where_clause}"
                # int() guarantees only numbers reach the SQL text.
                query_sql += f" LIMIT {int(limit)} OFFSET {int(offset)}"
                rows = session.execute(text(query_sql)).fetchall()
                # Convert rows to dictionaries
                documents = [dict(row._mapping) for row in rows]
                logger.info(
                    f"Retrieved {len(documents)} rows from table '{table_name}' with limit={limit}, offset={offset}"
                )
                return True, documents, ""
            finally:
                session.close()
        except Exception as e:
            error_msg = f"Error querying table '{table_name}': {e}"
            logger.error(error_msg)
            return False, [], error_msg

    # ------------------------------------------------------------------
    # Schema metadata
    # ------------------------------------------------------------------
    def save_schema(
        self, schema_name: str, schema_definition: Dict[str, Any], table_name: Optional[str] = None
    ) -> Tuple[bool, str]:
        """
        Save a schema mapping for reuse

        Args:
            schema_name: Name/identifier for the schema
            schema_definition: Dictionary containing the schema mapping
            table_name: Optional associated table name

        Returns:
            Tuple of (success: bool, error_message: str)
        """
        try:
            session = self.get_session()
            try:
                schema = SchemaMetadata(
                    name=schema_name, mapping=schema_definition, table_name=table_name
                )
                session.add(schema)
                session.commit()
                logger.info(f"Schema '{schema_name}' saved successfully")
                return True, ""
            finally:
                # close() also discards a transaction left open by a failure.
                session.close()
        except Exception as e:
            error_msg = f"Error saving schema: {e}"
            logger.error(error_msg)
            return False, error_msg

    def get_schemas(self) -> Tuple[bool, List[Dict[str, Any]], str]:
        """
        Retrieve all stored schemas

        Returns:
            Tuple of (success: bool, schemas: list, error_message: str)
        """
        try:
            session = self.get_session()
            try:
                schemas_list = [
                    schema.to_dict() for schema in session.query(SchemaMetadata).all()
                ]
                logger.info(f"Retrieved {len(schemas_list)} schemas")
                return True, schemas_list, ""
            finally:
                session.close()
        except Exception as e:
            error_msg = f"Error retrieving schemas: {e}"
            logger.error(error_msg)
            return False, [], error_msg

    # ------------------------------------------------------------------
    # Table introspection
    # ------------------------------------------------------------------
    def table_exists(self, table_name: str) -> bool:
        """
        Check if a table exists

        Any error during inspection is logged and reported as "does not
        exist".

        Args:
            table_name: Name of the table

        Returns:
            bool: True if table exists, False otherwise
        """
        try:
            inspector = inspect(self.engine)
            return table_name in inspector.get_table_names()
        except Exception as e:
            logger.error(f"Error checking table existence: {e}")
            return False

    def get_table_count(self, table_name: str) -> Tuple[bool, int, str]:
        """
        Get the number of rows in a table

        A table that does not exist is reported as a successful count of 0.

        Args:
            table_name: Name of the table

        Returns:
            Tuple of (success: bool, count: int, error_message: str)
        """
        try:
            # Obtain the session first: this triggers a reconnect when
            # needed, so a missing connection is reported as an error
            # instead of as "table has 0 rows".
            session = self.get_session()
            try:
                if not self.table_exists(table_name):
                    return True, 0, ""
                result = session.execute(
                    text(f"SELECT COUNT(*) as count FROM {self._quote_identifier(table_name)}")
                )
                return True, result.scalar() or 0, ""
            finally:
                session.close()
        except Exception as e:
            error_msg = f"Error counting rows in table '{table_name}': {e}"
            logger.error(error_msg)
            return False, 0, error_msg

    def get_all_tables(self) -> Tuple[bool, List[str], str]:
        """
        Get list of all user-created tables (excluding system tables)

        Tables whose name starts with ``pg_`` are treated as system tables.

        Returns:
            Tuple of (success: bool, tables: list, error_message: str)
        """
        try:
            inspector = inspect(self.engine)
            # Filter out system tables
            user_tables = [
                name for name in inspector.get_table_names() if not name.startswith("pg_")
            ]
            return True, user_tables, ""
        except Exception as e:
            error_msg = f"Error listing tables: {e}"
            logger.error(error_msg)
            return False, [], error_msg

    def _create_table_from_documents(
        self, table_name: str, documents: List[Dict[str, Any]]
    ) -> bool:
        """
        Create a table dynamically based on document structure

        Every key found in any document becomes a TEXT column.

        Args:
            table_name: Name of the table to create
            documents: Sample documents to infer schema from

        Returns:
            bool: True if successful, False otherwise
        """
        if not documents:
            return False
        try:
            create_table_sql, column_names = self._build_create_table(table_name, documents)
            session = self.get_session()
            try:
                session.execute(text(create_table_sql))
                session.commit()
            finally:
                session.close()
            logger.info(f"Created table '{table_name}' with columns: {column_names}")
            return True
        except Exception as e:
            logger.error(f"Error creating table '{table_name}': {e}")
            return False
# Global database instance, created lazily by get_db()
_db_instance: Optional[PostgreSQL] = None


def get_db() -> PostgreSQL:
    """
    Get or create the global PostgreSQL instance
    With auto-retry capability for connection failures

    A connection is attempted when the instance is first created and again
    on any later call that finds the instance without an engine or without
    a session factory. A failed attempt is only logged: the instance is
    returned regardless, so it may be unconnected.

    Returns:
        PostgreSQL instance
    """
    global _db_instance
    # If instance doesn't exist, create and try to connect
    if _db_instance is None:
        _db_instance = PostgreSQL(DATABASE_URL)
        if not _db_instance.connect():
            logger.warning("Initial connection failed, but instance created. Will retry on next request.")
    # If the instance exists but is not fully set up, try to reconnect.
    # SessionLocal is checked as well as engine because a failed connect()
    # can leave the engine assigned while the session factory is missing.
    elif _db_instance.engine is None or _db_instance.SessionLocal is None:
        logger.info("Attempting to reconnect to PostgreSQL...")
        if not _db_instance.connect():
            logger.warning("Reconnection attempt failed")
    return _db_instance