from __future__ import annotations

import time
from typing import Any

from app.core.database import ConnectionConfig, pool_manager
from app.core.database.base import StatementResult
from app.core.logger import get_logger
from app.models.schemas import (
    DatabaseQueryError,
    DatabaseQueryRequest,
    DatabaseQueryResponse,
    DatabaseValidateRequest,
    DatabaseValidateResponse,
    StatementResultSchema,
    TableValidationResult,
)

_logger = get_logger(__name__)


def _root_cause(exc: Exception) -> str:
    """Describe *exc* together with the exception that led to it.

    Returns ``str(exc)`` followed, in square brackets, by the explicit cause
    (``raise ... from other``) or, failing that, the implicit context. Only
    one level of the exception chain is reported. A context that was
    deliberately hidden with ``raise ... from None`` is not shown.
    """
    cause = exc.__cause__
    if cause is None and not exc.__suppress_context__:
        cause = exc.__context__
    if cause is not None:
        return f"{exc} [{cause}]"
    return str(exc)


def _elapsed_ms(start_time: float) -> float:
    """Return milliseconds since *start_time* (a ``time.monotonic()`` value), rounded to 2 places."""
    return round((time.monotonic() - start_time) * 1000, 2)


def _sql_string_literal(value: object, *, postgres: bool) -> str:
    """Render *value* as a quoted SQL string literal.

    The names checked during validation come from the request, so they must
    never be spliced into SQL verbatim. Backslashes and single quotes are
    doubled. For PostgreSQL an ``E'...'`` escape string is produced, so
    backslashes are interpreted as escapes whatever
    ``standard_conforming_strings`` is set to. For the MySQL-style branch,
    backslash is an escape character in the default SQL mode.
    NOTE(review): under MySQL's NO_BACKSLASH_ESCAPES mode a name containing
    a backslash would simply not match. That is still safe, but confirm the
    server mode if such names matter.
    """
    escaped = str(value).replace("\\", "\\\\").replace("'", "''")
    return f"E'{escaped}'" if postgres else f"'{escaped}'"


def _table_exists_query(db_type: str, config: ConnectionConfig, name: str) -> str:
    """Build a query that returns rows only when table *name* exists."""
    if db_type == "postgresql":
        schema = _sql_string_literal(config.selected_schema, postgres=True)
        table = _sql_string_literal(name, postgres=True)
        return (
            "SELECT 1 FROM information_schema.tables "
            f"WHERE table_schema = {schema} AND table_name = {table}"
        )
    # Every non-PostgreSQL SQL backend is checked against the current
    # database via MySQL-style DATABASE().
    table = _sql_string_literal(name, postgres=False)
    return (
        "SELECT 1 FROM information_schema.tables "
        f"WHERE TABLE_SCHEMA = DATABASE() AND table_name = {table}"
    )


async def _check_mongodb(
    executor: Any,
    config: ConnectionConfig,
    tables: list[str],
    results: list[TableValidationResult],
) -> None:
    """Ping MongoDB and append one existence result per requested collection.

    Raises whatever the driver raises; the caller turns that into a failure
    response.
    """
    # NOTE(review): relies on the executor's private pool accessor.
    pool = await executor._get_or_create_pool()
    db = pool[config.database]
    await db.command("ping")
    if tables:
        existing = set(await db.list_collection_names())
        results.extend(
            TableValidationResult(name=name, exists=name in existing)
            for name in tables
        )


async def _check_sql(
    executor: Any,
    config: ConnectionConfig,
    db_type: str,
    tables: list[str],
    results: list[TableValidationResult],
) -> None:
    """Run a trivial query, then append one existence result per requested table.

    Results are appended as each table is checked, so *results* holds the
    completed checks if a later query raises.
    """
    test = await executor.execute(["SELECT 1 AS test"])
    if not test or not test[0].success:
        raise RuntimeError(test[0].error if test else "No response")
    for name in tables:
        r = await executor.execute([_table_exists_query(db_type, config, name)])
        exists = bool(r and r[0].success and r[0].rows > 0)
        results.append(TableValidationResult(name=name, exists=exists))


def _query_failure(elapsed_ms: float, prefix: str, exc: Exception) -> DatabaseQueryResponse:
    """Build the unsuccessful query response for an exception raised before any results."""
    return DatabaseQueryResponse(
        success=False,
        execution_time_ms=elapsed_ms,
        error=DatabaseQueryError(
            message=f"{prefix}: {_root_cause(exc)}",
            code=type(exc).__name__,
        ),
    )


class DatabaseService:
    """Connection validation and query execution on top of ``pool_manager`` executors."""

    async def validate_connection(
        self, request: DatabaseValidateRequest
    ) -> DatabaseValidateResponse:
        """Check that the target database is reachable and which tables exist.

        Errors while acquiring the executor or running the checks are logged
        and reported as ``success=False`` with ``error_message`` populated.
        On a failure after connecting, ``tables`` contains the results that
        were gathered before the error.
        """
        start_time = time.monotonic()
        config = ConnectionConfig(**request.to_connection_config())
        tables = list(request.table_or_collection_names)
        # Filled incrementally by the helpers so partial results survive errors.
        results: list[TableValidationResult] = []

        try:
            executor = await pool_manager.get_executor(config)
        except Exception as exc:
            elapsed = _elapsed_ms(start_time)
            _logger.error("Connection failed for %s: %s", config.safe_repr, exc)
            return DatabaseValidateResponse(
                success=False,
                time_ms=elapsed,
                message="Connection failed",
                connection_details=config.safe_repr,
                error_message=_root_cause(exc),
                tables=[],
            )

        try:
            if request.db_type == "mongodb":
                await _check_mongodb(executor, config, tables, results)
            else:
                await _check_sql(executor, config, request.db_type, tables, results)
        except Exception as exc:
            elapsed = _elapsed_ms(start_time)
            _logger.error("Validation failed for %s: %s", config.safe_repr, exc)
            return DatabaseValidateResponse(
                success=False,
                time_ms=elapsed,
                message="Validation failed",
                connection_details=config.safe_repr,
                error_message=_root_cause(exc),
                tables=results,
            )

        return DatabaseValidateResponse(
            success=True,
            time_ms=_elapsed_ms(start_time),
            message="Connected successfully",
            connection_details=config.safe_repr,
            tables=results,
        )

    async def execute_query(self, request: DatabaseQueryRequest) -> DatabaseQueryResponse:
        """Execute ``request.query`` and report per-statement results.

        Exceptions from acquiring the executor or from ``executor.execute``
        are logged and returned as an unsuccessful response carrying a
        ``DatabaseQueryError``. Otherwise ``success`` is True only when every
        statement result reports success.
        """
        start_time = time.monotonic()
        config = ConnectionConfig(**request.to_connection_config())

        try:
            executor = await pool_manager.get_executor(config)
        except Exception as exc:
            elapsed = _elapsed_ms(start_time)
            _logger.error(
                "Failed to acquire executor for %s: %s",
                config.safe_repr,
                exc,
            )
            return _query_failure(elapsed, "Connection failed", exc)

        try:
            results = await executor.execute(
                request.query,
                use_transaction=request.use_transaction,
            )
        except Exception as exc:
            elapsed = _elapsed_ms(start_time)
            _logger.error(
                "Query execution failed for %s: %s",
                config.safe_repr,
                exc,
            )
            return _query_failure(elapsed, "Execution failed", exc)

        elapsed = _elapsed_ms(start_time)
        statement_results = [
            StatementResultSchema(
                success=r.success,
                rows=r.rows,
                data=r.data,
                error=r.error,
                error_code=r.error_code,
            )
            for r in results
        ]

        if all(r.success for r in results):
            _logger.info(
                "Query success for %s (%d stmts, %.2fms)",
                config.safe_repr,
                len(results),
                elapsed,
            )
            return DatabaseQueryResponse(
                success=True,
                execution_time_ms=elapsed,
                results=statement_results,
            )

        _logger.warning(
            "Query partial/full failure for %s (%d stmts, %.2fms)",
            config.safe_repr,
            len(results),
            elapsed,
        )
        return DatabaseQueryResponse(
            success=False,
            execution_time_ms=elapsed,
            results=statement_results,
        )