Spaces:
Running
Running
| """ | |
| SurrealDB connection module with retry logic for containerized deployments. | |
| This ensures the FastAPI app waits for SurrealDB to be ready before attempting connections. | |
| """ | |
| import asyncio | |
| import os | |
| from contextlib import asynccontextmanager | |
| from typing import Optional | |
| from loguru import logger | |
| from surrealdb import AsyncSurreal, RecordID | |
class SurrealDBConnection:
    """Manages SurrealDB connections with retry logic.

    Settings not passed explicitly fall back to environment variables:

    - ``url``: ``SURREAL_URL`` (default ``ws://localhost:8000/rpc``)
    - ``username``: ``SURREAL_USER`` (default ``root``)
    - ``password``: ``SURREAL_PASS``, then ``SURREAL_PASSWORD`` (default ``root``)
    - ``namespace``: ``SURREAL_NAMESPACE`` (default ``open_notebook``)
    - ``database``: ``SURREAL_DATABASE`` (default ``main``)

    The fallback uses ``or``, so falsy arguments (e.g. ``""``) also fall back
    to the environment.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        namespace: Optional[str] = None,
        database: Optional[str] = None,
        max_retries: int = 5,
        retry_delay: int = 2,
    ):
        self.url = url or os.getenv("SURREAL_URL", "ws://localhost:8000/rpc")
        self.username = username or os.getenv("SURREAL_USER", "root")
        self.password = (
            password
            or os.getenv("SURREAL_PASS")
            or os.getenv("SURREAL_PASSWORD", "root")
        )
        self.namespace = namespace or os.getenv("SURREAL_NAMESPACE", "open_notebook")
        self.database = database or os.getenv("SURREAL_DATABASE", "main")
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        # Most recent client handed out by connect(); released by close().
        self._connection: Optional["AsyncSurreal"] = None

    def _backoff_delay(self, attempt: int) -> float:
        """Return the seconds to wait after failed ``attempt`` (1-based).

        The delay doubles per attempt: retry_delay, 2*retry_delay,
        4*retry_delay, ...
        """
        return self.retry_delay * 2 ** (attempt - 1)

    @staticmethod
    async def _discard(db) -> None:
        """Best-effort close of a client whose connection attempt failed."""
        try:
            await db.close()
        except Exception as close_error:
            # The attempt already failed; don't let cleanup mask that error.
            logger.debug(f"Ignoring error while discarding failed connection: {close_error}")

    async def connect(self) -> "AsyncSurreal":
        """Connect to SurrealDB with retry logic.

        Makes up to ``max_retries`` attempts with exponential backoff between
        them (see ``_backoff_delay``). Each attempt signs in and selects the
        configured namespace/database. The successful client is also stored
        on the instance so that ``close()`` can release it.

        Returns:
            A signed-in client with the namespace and database selected.

        Raises:
            ConnectionError: if every attempt fails (chained from the last
                error), or if ``max_retries`` is less than 1.
        """
        for attempt in range(1, self.max_retries + 1):
            db = None
            try:
                logger.info(f"Attempting to connect to SurrealDB at {self.url} (attempt {attempt}/{self.max_retries})")
                db = AsyncSurreal(self.url)
                # Sign in with credentials
                await db.signin({
                    "username": self.username,
                    "password": self.password,
                })
                # Select namespace and database
                await db.use(self.namespace, self.database)
            except Exception as e:
                logger.warning(f"Connection attempt {attempt}/{self.max_retries} failed: {e}")
                if db is not None:
                    # Don't leak the half-initialised client from this attempt.
                    await self._discard(db)
                if attempt < self.max_retries:
                    wait_time = self._backoff_delay(attempt)
                    logger.info(f"Retrying in {wait_time} seconds...")
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(f"Failed to connect to SurrealDB after {self.max_retries} attempts")
                    raise ConnectionError(
                        f"Could not connect to SurrealDB at {self.url} after {self.max_retries} attempts. "
                        "Please ensure SurrealDB is running and accessible."
                    ) from e
            else:
                logger.success(f"Successfully connected to SurrealDB: {self.namespace}/{self.database}")
                self._connection = db
                return db
        # Reached only when the loop never runs at all.
        raise ConnectionError(
            f"Could not connect to SurrealDB: max_retries must be at least 1 (got {self.max_retries})"
        )

    async def close(self):
        """Close the stored connection, if any; errors are logged, not raised."""
        if self._connection:
            try:
                await self._connection.close()
                logger.info("SurrealDB connection closed")
            except Exception as e:
                logger.error(f"Error closing connection: {e}")
            finally:
                self._connection = None

    @asynccontextmanager
    async def get_connection(self):
        """Async context manager that yields a fresh connection.

        A new client is opened via ``connect()`` on entry and closed on exit,
        including when the body raises.

        Usage::

            async with manager.get_connection() as db:
                ...
        """
        db = await self.connect()
        try:
            yield db
        finally:
            try:
                await db.close()
            finally:
                # connect() stored db as the instance connection. Forget it
                # so a later close() does not close it a second time.
                if self._connection is db:
                    self._connection = None
# Module-wide connection manager shared by db_connection() and close_database().
# It is built with no arguments, so its settings come from the SURREAL_*
# environment variables (or defaults) as they are at import time.
_db_connection: SurrealDBConnection = SurrealDBConnection()
@asynccontextmanager
async def db_connection():
    """Provide a database connection, with automatic retry logic, for an ``async with`` block.

    This is the main function used throughout the application. It delegates
    to the module-level ``_db_connection`` manager's ``get_connection()``.

    Usage::

        async with db_connection() as db:
            await db.query("...")

    Without ``@asynccontextmanager`` this would be a bare async generator, and
    ``async with db_connection()`` would raise ``TypeError``.
    """
    async with _db_connection.get_connection() as db:
        yield db
async def initialize_database():
    """Initialize the database connection at application startup.

    Opens a connection and runs ``INFO FOR DB;`` as a smoke test, so startup
    fails fast if SurrealDB is not ready to accept requests.

    Returns:
        True once the test query succeeds.

    Raises:
        Exception: whatever opening the connection or running the query
            raised, re-raised after being logged.
    """
    logger.info("Initializing database connection...")
    try:
        async with db_connection() as db:
            # Test the connection with a simple query (SurrealDB 2.x compatible).
            # Only success matters; the query result is not used.
            await db.query("INFO FOR DB;")
    except Exception as e:
        logger.error(f"Database initialization failed: {e}")
        raise
    else:
        logger.success("Database connection test successful")
        return True
async def close_database() -> None:
    """Release the shared SurrealDB connection during application shutdown.

    Delegates to ``close()`` on the module-level connection manager.
    """
    manager = _db_connection
    await manager.close()