Spaces:
Paused
Paused
| """ | |
| Simplified Database Service | |
| Removes complexity and improves maintainability | |
| """ | |
| from typing import Any, Optional, TypeVar | |
| from fastapi import HTTPException | |
| from sqlalchemy import text | |
| from core.database import get_db | |
| from core.models.base import Base | |
# Type variable for ORM model classes, bound to the project's declarative ``Base``.
# The service methods take ``type[GenericModel]`` and return ``GenericModel`` so a
# call with a concrete model class is typed as returning that same class.
GenericModel = TypeVar("GenericModel", bound=Base)
class DatabaseService:
    """Simplified database service with common CRUD and raw-SQL patterns.

    Each operation opens its own session via ``self.db_session()``. Read
    helpers log errors and return an empty/neutral value. Write helpers roll
    back on failure, then either raise ``HTTPException`` (``create``) or log
    and return ``None``/``False``.

    NOTE(review): every public method is ``async`` but calls the session
    synchronously (nothing is awaited), so a slow query blocks the event loop.

    NOTE(review): assumes ``get_db()`` returns a context manager that yields a
    session. A bare FastAPI-style ``yield`` dependency is a plain generator and
    would fail in ``with``. TODO: confirm against ``core.database``.
    """

    def __init__(self):
        # Session factory, called once per operation.
        self.db_session = get_db
        self._connection = None
        self.is_connected = False

    async def initialize(self):
        """Mark the service as connected.

        No real connection is opened in this simplified version; the flag is
        set and a message is printed.
        """
        try:
            # For this simplified version, we'll simulate connection
            self.is_connected = True
            print("Database service initialized (simplified)")
        except Exception as e:
            print(f"Database initialization error: {e}")
            self.is_connected = False

    @staticmethod
    def _apply_filters(query, model_class, filters: Optional[dict[str, Any]]):
        """Add an equality filter for each key in *filters* that is an attribute of *model_class*.

        Keys that are not attributes of the model are silently ignored.
        """
        for key, value in (filters or {}).items():
            if hasattr(model_class, key):
                query = query.filter(getattr(model_class, key) == value)
        return query

    async def create(self, model_data: dict[str, Any], model_class: type[GenericModel]) -> GenericModel:
        """Insert a new ``model_class`` row built from ``model_data`` and return it.

        Args:
            model_data: Keyword arguments passed to the model constructor.
            model_class: ORM model class to instantiate.

        Returns:
            The committed and refreshed model instance.

        Raises:
            HTTPException: Status 500 if obtaining the session, constructing the
                model, or committing fails; the original error is the cause.
        """
        try:
            with self.db_session() as session:
                try:
                    db_model = model_class(**model_data)
                    session.add(db_model)
                    session.commit()
                    session.refresh(db_model)
                    return db_model
                except Exception:
                    # Roll back inside the ``with`` so ``session`` is bound and still open.
                    session.rollback()
                    raise
        except Exception as e:
            raise HTTPException(
                status_code=500, detail=f"Failed to create {model_class.__name__}: {str(e)}"
            ) from e

    async def get_by_id(self, model_class: type[GenericModel], record_id: str) -> Optional[GenericModel]:
        """Return the row whose ``id`` equals ``record_id``.

        Returns ``None`` both when no row matches and when a database error
        occurs (the error is printed).
        """
        try:
            with self.db_session() as session:
                return session.query(model_class).filter(model_class.id == record_id).first()
        except Exception as e:
            print(f"Database error: {e}")
            return None

    async def get_all(
        self,
        model_class: type[GenericModel],
        filters: Optional[dict[str, Any]] = None,
        pagination: Optional[dict[str, Any]] = None,
    ) -> list[GenericModel]:
        """Return rows of ``model_class``, optionally filtered and paginated.

        Args:
            model_class: ORM model class to query.
            filters: Column name -> value equality filters. Keys that are not
                model attributes are ignored.
            pagination: Optional ``{"page": int, "per_page": int}``. Defaults
                are page 1 and 100 per page; pages are 1-based and a page below
                1 is treated as 1.

        Returns:
            The matching rows, or ``[]`` on a database error (which is printed).
        """
        try:
            with self.db_session() as session:
                query = self._apply_filters(session.query(model_class), model_class, filters)
                if pagination:
                    # Clamp so the computed OFFSET can never be negative.
                    page = max(pagination.get("page", 1), 1)
                    per_page = pagination.get("per_page", 100)
                    query = apply_pagination(query, page, per_page)
                return query.all()
        except Exception as e:
            print(f"Database error: {e}")
            return []

    async def update(
        self, model_class: type[GenericModel], record_id: str, update_data: dict[str, Any]
    ) -> Optional[GenericModel]:
        """Set attributes from ``update_data`` on the row with ``id == record_id``.

        Keys that are not attributes of the loaded instance are skipped.

        Returns:
            The refreshed instance. Returns ``None`` if no row matches, or if
            the update fails (the transaction is rolled back and the error
            printed).
        """
        try:
            with self.db_session() as session:
                try:
                    db_model = session.query(model_class).filter(model_class.id == record_id).first()
                    if db_model is None:
                        return None
                    for key, value in update_data.items():
                        if hasattr(db_model, key):
                            setattr(db_model, key, value)
                    session.commit()
                    session.refresh(db_model)
                    return db_model
                except Exception:
                    session.rollback()
                    raise
        except Exception as e:
            print(f"Database update error: {e}")
            return None

    async def delete(self, model_class: type[GenericModel], record_id: str) -> bool:
        """Delete the row with ``id == record_id``.

        Returns:
            ``True`` if a row was deleted. ``False`` if none matched, or if the
            delete failed (rolled back and printed).
        """
        try:
            with self.db_session() as session:
                try:
                    db_model = session.query(model_class).filter(model_class.id == record_id).first()
                    if db_model is None:
                        return False
                    session.delete(db_model)
                    session.commit()
                    return True
                except Exception:
                    session.rollback()
                    raise
        except Exception as e:
            print(f"Database delete error: {e}")
            return False

    async def execute_query(self, query: str, params: Optional[dict[str, Any]] = None) -> list[dict[str, Any]]:
        """Run a raw SQL query and return its rows as column-name -> value dicts.

        ``query`` is wrapped in SQLAlchemy ``text()``, which binds named
        ``:name`` placeholders from the ``params`` mapping.

        Returns:
            One dict per row, or ``[]`` on error (which is printed).
        """
        try:
            with self.db_session() as session:
                result = session.execute(text(query), params or {})
                columns = result.keys()
                return [dict(zip(columns, row)) for row in result.fetchall()]
        except Exception as e:
            print(f"Query execution error: {e}")
            return []

    async def execute_insert(self, query: str, params: Optional[dict[str, Any]] = None) -> Optional[str]:
        """Execute and commit a raw SQL insert, returning the driver's last row id.

        ``query`` is wrapped in ``text()``; ``params`` supplies named ``:name``
        bind values.

        Returns:
            ``str(lastrowid)``, or ``None`` if the result has no last row id or
            the statement failed (rolled back and printed).
        """
        try:
            with self.db_session() as session:
                try:
                    result = session.execute(text(query), params or {})
                    session.commit()
                except Exception:
                    session.rollback()
                    raise
                # The attribute may exist but be None (driver gave no id); don't return "None".
                last_id = getattr(result, "lastrowid", None)
                return str(last_id) if last_id is not None else None
        except Exception as e:
            print(f"Insert execution error: {e}")
            return None

    async def count(self, model_class: type[GenericModel], filters: Optional[dict[str, Any]] = None) -> int:
        """Count rows of ``model_class`` matching optional equality ``filters``.

        Filter keys that are not model attributes are ignored. Returns ``0`` on
        a database error (which is printed).
        """
        try:
            with self.db_session() as session:
                query = self._apply_filters(session.query(model_class), model_class, filters)
                return query.count()
        except Exception as e:
            print(f"Database count error: {e}")
            return 0
| # Helper functions for common database patterns | |
def apply_pagination(query, page: int, per_page: int):
    """Limit *query* to one page of results.

    Pages are 1-based. A ``page`` below 1 is treated as page 1; otherwise the
    computed offset would be negative, which some databases (for example
    PostgreSQL) reject.

    Args:
        query: Object exposing chainable ``offset()`` and ``limit()``, such as
            a SQLAlchemy ``Query``.
        page: 1-based page number.
        per_page: Maximum number of rows per page.

    Returns:
        ``query.offset((page - 1) * per_page).limit(per_page)``, using the
        clamped page.
    """
    offset = (max(page, 1) - 1) * per_page
    return query.offset(offset).limit(per_page)
def build_safe_where_clause(
    filters: dict[str, Any],
    allowed_columns: list[str],
    *,
    params: Optional[dict[str, Any]] = None,
) -> str:
    """Build the body of a WHERE clause from equality / IN filters.

    Column names are checked against ``allowed_columns``. A list value becomes
    an ``IN`` test, ``None`` becomes ``IS NULL``, and any other value becomes an
    equality test. An empty list becomes ``1=0`` (matches nothing), because
    ``IN ()`` is not valid SQL.

    Args:
        filters: Column name -> value.
        allowed_columns: Whitelist of column names that may appear.
        params: Selects the rendering mode.
            - Given (recommended, injection-safe): every value is emitted as a
              named placeholder (``:where_N``) and stored in this dict, so the
              SQL can be run with SQLAlchemy ``text()`` and the same dict.
              Existing keys in the dict are never overwritten.
            - Omitted (legacy): string values are inlined as quoted literals
              with single quotes doubled, other scalars are inlined via
              ``str()``, and list values become ``%s`` placeholders the caller
              must bind. Prefer passing ``params`` for untrusted input.

    Returns:
        Conditions joined with ``" AND "``, or ``"1=1"`` when ``filters`` is empty.

    Raises:
        ValueError: If a filter key is not in ``allowed_columns``.
    """
    if not filters:
        return "1=1"

    def bind(value: Any) -> str:
        # Choose a placeholder name not already present in the caller's dict.
        index = len(params)
        name = f"where_{index}"
        while name in params:
            index += 1
            name = f"where_{index}"
        params[name] = value
        return f":{name}"

    allowed = set(allowed_columns)
    conditions: list[str] = []
    for key, value in filters.items():
        if key not in allowed:
            raise ValueError(f"Unsafe column name: {key}")
        if value is None:
            # "= NULL" is never true in SQL; NULL must be tested with IS NULL.
            conditions.append(f"{key} IS NULL")
        elif isinstance(value, list):
            if not value:
                conditions.append("1=0")
            elif params is not None:
                placeholders = ", ".join(bind(item) for item in value)
                conditions.append(f"{key} IN ({placeholders})")
            else:
                placeholders = ", ".join(["%s"] * len(value))
                conditions.append(f"{key} IN ({placeholders})")
        elif params is not None:
            conditions.append(f"{key} = {bind(value)}")
        elif isinstance(value, str):
            # Standard SQL escaping for a quote inside a string literal.
            escaped = value.replace("'", "''")
            conditions.append(f"{key} = '{escaped}'")
        else:
            conditions.append(f"{key} = {value}")
    return " AND ".join(conditions)