Spaces:
Running
Running
| from __future__ import annotations | |
| import time | |
| from typing import Any | |
| from app.core.database import ConnectionConfig, pool_manager | |
| from app.core.database.base import StatementResult | |
| from app.core.logger import get_logger | |
| from app.models.schemas import ( | |
| DatabaseQueryError, | |
| DatabaseQueryRequest, | |
| DatabaseQueryResponse, | |
| DatabaseValidateRequest, | |
| DatabaseValidateResponse, | |
| StatementResultSchema, | |
| TableValidationResult, | |
| ) | |
# Module-level logger named after this module; DatabaseService logs
# connection/validation/query outcomes through it.
_logger = get_logger(__name__)
| def _root_cause(exc: Exception) -> str: | |
| cause = exc.__cause__ or exc.__context__ | |
| if cause: | |
| return f"{exc} [{cause}]" | |
| return str(exc) | |
def _elapsed_ms(start_time: float) -> float:
    """Return milliseconds since *start_time* (a ``time.monotonic()`` reading), rounded to 2 places."""
    return round((time.monotonic() - start_time) * 1000, 2)


def _sql_string_literal(value: str) -> str:
    """Quote *value* as a single-quoted SQL string literal with backslash escaping.

    Both backslashes and single quotes are doubled. That is the correct
    encoding for literals whose dialect interprets backslash escapes (MySQL's
    default mode, PostgreSQL ``E'...'`` strings). Where backslash is literal
    instead, a value containing one simply fails to match; because quotes are
    doubled it still cannot terminate the literal early.
    """
    escaped = value.replace("\\", "\\\\").replace("'", "''")
    return f"'{escaped}'"


class DatabaseService:
    """Validates database connections and runs queries via ``pool_manager`` executors."""

    async def validate_connection(self, request: DatabaseValidateRequest) -> DatabaseValidateResponse:
        """Check that the target database is reachable and that requested tables exist.

        Args:
            request: Connection parameters plus ``table_or_collection_names``
                to look up (tables for SQL backends, collections for MongoDB).

        Returns:
            A response with ``success=True`` and one ``TableValidationResult``
            per requested name on success. Errors while acquiring the executor
            or while running the checks are not raised; they are reported with
            ``success=False`` and an ``error_message``. If a failure happens
            part-way through the table checks, the results gathered so far are
            still returned.
        """
        start_time = time.monotonic()
        config = ConnectionConfig(**request.to_connection_config())
        tables = list(request.table_or_collection_names)
        # Appended to incrementally by the helpers so a failure part-way
        # through still reports the names that were already checked.
        results: list[TableValidationResult] = []

        try:
            executor = await pool_manager.get_executor(config)
        except Exception as exc:
            elapsed = _elapsed_ms(start_time)
            _logger.error("Connection failed for %s: %s", config.safe_repr, exc)
            return DatabaseValidateResponse(
                success=False,
                time_ms=elapsed,
                message="Connection failed",
                connection_details=config.safe_repr,
                error_message=_root_cause(exc),
                tables=[],
            )

        try:
            if request.db_type == "mongodb":
                await self._validate_mongodb(executor, config, tables, results)
            else:
                await self._validate_sql(executor, config, request.db_type, tables, results)
        except Exception as exc:
            elapsed = _elapsed_ms(start_time)
            _logger.error("Validation failed for %s: %s", config.safe_repr, exc)
            return DatabaseValidateResponse(
                success=False,
                time_ms=elapsed,
                message="Validation failed",
                connection_details=config.safe_repr,
                error_message=_root_cause(exc),
                tables=results,
            )

        return DatabaseValidateResponse(
            success=True,
            time_ms=_elapsed_ms(start_time),
            message="Connected successfully",
            connection_details=config.safe_repr,
            tables=results,
        )

    @staticmethod
    async def _validate_mongodb(
        executor: Any,
        config: ConnectionConfig,
        tables: list[str],
        results: list[TableValidationResult],
    ) -> None:
        """Ping the MongoDB database and append an existence result per requested collection."""
        # NOTE(review): relies on the executor's private ``_get_or_create_pool``;
        # a public accessor on the executor would be preferable.
        pool = await executor._get_or_create_pool()
        db = pool[config.database]
        await db.command("ping")
        if not tables:
            return
        # One round trip for all collections, then O(1) membership tests.
        existing = set(await db.list_collection_names())
        for name in tables:
            results.append(TableValidationResult(name=name, exists=name in existing))

    @staticmethod
    async def _validate_sql(
        executor: Any,
        config: ConnectionConfig,
        db_type: str,
        tables: list[str],
        results: list[TableValidationResult],
    ) -> None:
        """Run a ``SELECT 1`` probe, then append an existence result per requested table.

        ``db_type == "postgresql"`` looks tables up in the selected schema; any
        other SQL type uses the MySQL-style ``DATABASE()`` function.

        Raises:
            RuntimeError: if the ``SELECT 1`` probe returns nothing or fails.
        """
        probe = await executor.execute(["SELECT 1 AS test"])
        if not probe or not probe[0].success:
            detail = probe[0].error if probe else "No response"
            # The driver may report failure without error text; avoid a bare "None".
            raise RuntimeError(detail or "Connectivity check failed")
        if not tables:
            return

        if db_type == "postgresql":
            # E'...' literals always interpret backslash escapes, independent of
            # the server's standard_conforming_strings setting, which matches
            # how _sql_string_literal escapes.
            literal_prefix = "E"
            schema = config.selected_schema
            # Assumption: selected_schema may be None/empty when the caller did
            # not choose one (previously that searched for the schema 'None');
            # fall back to the session's current schema. TODO confirm.
            if schema:
                schema_filter = f"table_schema = E{_sql_string_literal(str(schema))}"
            else:
                schema_filter = "table_schema = current_schema()"
        else:
            literal_prefix = ""
            schema_filter = "TABLE_SCHEMA = DATABASE()"

        for name in tables:
            # Names come from the request, so they are quoted, never interpolated raw.
            name_literal = literal_prefix + _sql_string_literal(str(name))
            query = (
                "SELECT 1 FROM information_schema.tables "
                f"WHERE {schema_filter} AND table_name = {name_literal}"
            )
            response = await executor.execute([query])
            exists = bool(response and response[0].success and response[0].rows > 0)
            results.append(TableValidationResult(name=name, exists=exists))

    @staticmethod
    def _query_failure(elapsed_ms: float, stage: str, exc: Exception) -> DatabaseQueryResponse:
        """Build a failed query response whose error is ``"<stage>: <root cause>"``."""
        return DatabaseQueryResponse(
            success=False,
            execution_time_ms=elapsed_ms,
            error=DatabaseQueryError(
                message=f"{stage}: {_root_cause(exc)}",
                code=type(exc).__name__,
            ),
        )

    async def execute_query(self, request: DatabaseQueryRequest) -> DatabaseQueryResponse:
        """Execute ``request.query`` on a pooled executor and report per-statement results.

        Args:
            request: Connection parameters, the statement(s) to run, and
                ``use_transaction`` (forwarded to the executor).

        Returns:
            If acquiring the executor or executing raises, a response with
            ``success=False`` and a ``DatabaseQueryError`` whose ``code`` is the
            exception class name. Otherwise a response carrying one
            ``StatementResultSchema`` per statement result; ``success`` is True
            only when every statement succeeded.
        """
        start_time = time.monotonic()
        config = ConnectionConfig(**request.to_connection_config())

        try:
            executor = await pool_manager.get_executor(config)
        except Exception as exc:
            elapsed = _elapsed_ms(start_time)
            _logger.error(
                "Failed to acquire executor for %s: %s",
                config.safe_repr, exc,
            )
            return self._query_failure(elapsed, "Connection failed", exc)

        try:
            results = await executor.execute(
                request.query,
                use_transaction=request.use_transaction,
            )
        except Exception as exc:
            elapsed = _elapsed_ms(start_time)
            _logger.error(
                "Query execution failed for %s: %s",
                config.safe_repr, exc,
            )
            return self._query_failure(elapsed, "Execution failed", exc)

        elapsed = _elapsed_ms(start_time)
        statement_results = [
            StatementResultSchema(
                success=r.success,
                rows=r.rows,
                data=r.data,
                error=r.error,
                error_code=r.error_code,
            )
            for r in results
        ]
        overall_success = all(r.success for r in results)
        if overall_success:
            _logger.info(
                "Query success for %s (%d stmts, %.2fms)",
                config.safe_repr, len(results), elapsed,
            )
        else:
            _logger.warning(
                "Query partial/full failure for %s (%d stmts, %.2fms)",
                config.safe_repr, len(results), elapsed,
            )
        return DatabaseQueryResponse(
            success=overall_success,
            execution_time_ms=elapsed,
            results=statement_results,
        )