File size: 4,868 Bytes
f871fed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""
SurrealDB connection module with retry logic for containerized deployments.
This ensures the FastAPI app waits for SurrealDB to be ready before attempting connections.
"""
import asyncio
import os
from contextlib import asynccontextmanager
from typing import Optional

from loguru import logger
from surrealdb import AsyncSurreal, RecordID


class SurrealDBConnection:
    """
    Manage SurrealDB connections with retry logic.

    Settings that are not passed explicitly are read from environment
    variables and otherwise fall back to built-in defaults (see __init__).
    The client most recently opened by ``connect()`` is remembered so that
    ``close()`` can release it at shutdown.
    """

    def __init__(
        self,
        url: Optional[str] = None,
        username: Optional[str] = None,
        password: Optional[str] = None,
        namespace: Optional[str] = None,
        database: Optional[str] = None,
        max_retries: int = 5,
        retry_delay: float = 2,
    ):
        """
        Store connection settings; no network I/O happens here.

        Args:
            url: Endpoint; else $SURREAL_URL, else "ws://localhost:8000/rpc".
            username: Else $SURREAL_USER, else "root".
            password: Else $SURREAL_PASS, else $SURREAL_PASSWORD, else "root".
            namespace: Else $SURREAL_NAMESPACE, else "open_notebook".
            database: Else $SURREAL_DATABASE, else "main".
            max_retries: Total number of connection attempts; must be >= 1.
            retry_delay: Base delay in seconds. After failed attempt ``n``,
                ``connect()`` waits ``retry_delay * n`` seconds. Must be >= 0.

        Raises:
            ValueError: If ``max_retries`` < 1 (``connect()`` could then never
                make an attempt) or ``retry_delay`` < 0.
        """
        if max_retries < 1:
            raise ValueError(f"max_retries must be at least 1, got {max_retries!r}")
        if retry_delay < 0:
            raise ValueError(f"retry_delay must be non-negative, got {retry_delay!r}")

        # `or` means falsy arguments (None or "") also fall through to env/defaults.
        self.url = url or os.getenv("SURREAL_URL", "ws://localhost:8000/rpc")
        self.username = username or os.getenv("SURREAL_USER", "root")
        # SURREAL_PASS wins; SURREAL_PASSWORD is accepted as an alternative name.
        self.password = password or os.getenv("SURREAL_PASS") or os.getenv("SURREAL_PASSWORD", "root")
        self.namespace = namespace or os.getenv("SURREAL_NAMESPACE", "open_notebook")
        self.database = database or os.getenv("SURREAL_DATABASE", "main")
        self.max_retries = max_retries
        self.retry_delay = retry_delay
        # Client most recently returned by connect(); released by close().
        self._connection: Optional[AsyncSurreal] = None

    async def connect(self) -> "AsyncSurreal":
        """
        Open, authenticate and configure a new SurrealDB client, with retries.

        Makes up to ``max_retries`` attempts. After failed attempt ``n`` it
        sleeps ``retry_delay * n`` seconds (linear backoff) before trying
        again. A client created by a failed attempt is closed before retrying
        so that failed attempts do not leak connections.

        Returns:
            The connected client, which is also remembered as the current
            connection for ``close()``.

        Raises:
            ConnectionError: If every attempt fails; the last error is chained
                as ``__cause__``.
        """
        for attempt in range(1, self.max_retries + 1):
            db = None
            try:
                logger.info(f"Attempting to connect to SurrealDB at {self.url} (attempt {attempt}/{self.max_retries})")

                db = AsyncSurreal(self.url)

                # Sign in with credentials
                await db.signin({
                    "username": self.username,
                    "password": self.password,
                })

                # Select namespace and database
                await db.use(self.namespace, self.database)

            except Exception as e:
                logger.warning(f"Connection attempt {attempt}/{self.max_retries} failed: {str(e)}")
                # The client may have been created (and possibly opened) before
                # signin/use failed; release it instead of leaking it.
                if db is not None:
                    await self._close_quietly(db)

                if attempt < self.max_retries:
                    wait_time = self.retry_delay * attempt  # Linear backoff
                    logger.info(f"Retrying in {wait_time} seconds...")
                    await asyncio.sleep(wait_time)
                else:
                    logger.error(f"Failed to connect to SurrealDB after {self.max_retries} attempts")
                    raise ConnectionError(
                        f"Could not connect to SurrealDB at {self.url} after {self.max_retries} attempts. "
                        "Please ensure SurrealDB is running and accessible."
                    ) from e
            else:
                logger.success(f"Successfully connected to SurrealDB: {self.namespace}/{self.database}")
                self._connection = db
                return db

        # Only reachable if max_retries was set below 1 after construction.
        raise ConnectionError("Unexpected error in connection retry loop")

    @staticmethod
    async def _close_quietly(db) -> None:
        """Best-effort close of a client from a failed attempt; errors are logged at debug level and suppressed."""
        try:
            await db.close()
        except Exception as close_error:
            logger.debug(f"Ignoring error while closing failed SurrealDB client: {close_error}")

    async def close(self):
        """
        Close the remembered database connection, if any.

        Errors raised while closing are logged rather than propagated, and
        the remembered connection is always cleared afterwards.
        """
        if self._connection:
            try:
                await self._connection.close()
                logger.info("SurrealDB connection closed")
            except Exception as e:
                logger.error(f"Error closing connection: {e}")
            finally:
                self._connection = None

    @asynccontextmanager
    async def get_connection(self):
        """
        Context manager for database connections.

        Creates a new connection (via ``connect()``, with retries) for each
        context and closes it when the context exits.
        """
        db = await self.connect()
        try:
            yield db
        finally:
            try:
                await db.close()
            finally:
                # connect() remembered db as the current connection; forget it
                # so close() at shutdown does not close it a second time.
                if self._connection is db:
                    self._connection = None


# Global connection instance.
# Created once, at import time, with no arguments, so its settings come from
# the SURREAL_* environment variables (or the built-in defaults) as they are
# when this module is first imported. db_connection() and close_database()
# below both go through this single instance.
_db_connection = SurrealDBConnection()


@asynccontextmanager
async def db_connection():
    """
    Async context manager yielding a SurrealDB connection.

    Delegates to the module-wide ``_db_connection`` manager. A fresh
    connection is opened on entry, using that manager's retry logic, and
    closed when the ``async with`` block exits. This is the main entry
    point for database access in the application.
    """
    managed = _db_connection.get_connection()
    async with managed as connection:
        yield connection


async def initialize_database():
    """
    Verify at application startup that SurrealDB is reachable.

    Opens a connection through ``db_connection()`` (which retries), runs
    ``INFO FOR DB;`` as a smoke test, and closes the connection again.

    Returns:
        True once the test query has completed and the connection is closed.

    Raises:
        Exception: Whatever the connection attempt, the test query or the
            close raised (e.g. ConnectionError once retries are exhausted).
            It is logged before being re-raised.
    """
    logger.info("Initializing database connection...")
    try:
        async with db_connection() as db:
            # Test the connection with a simple query (SurrealDB 2.x compatible).
            # The result is not inspected; the query only has to succeed.
            await db.query("INFO FOR DB;")
    except Exception as e:
        logger.error(f"Database initialization failed: {e}")
        raise
    else:
        # Logged only after the test connection was also closed cleanly.
        logger.success("Database connection test successful")
        return True


async def close_database():
    """
    Release database resources when the application shuts down.

    Delegates to the module-wide manager's ``close()``. That method closes
    the connection it last recorded, if any, and logs rather than raises
    on close errors.
    """
    manager = _db_connection
    await manager.close()