import os
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, TypeVar, Union

from loguru import logger
from surrealdb import AsyncSurreal, RecordID  # type: ignore

# Import the new connection module with retry logic
from open_notebook.database.connection import db_connection as db_connection_with_retry

T = TypeVar("T", Dict[str, Any], List[Dict[str, Any]])


class Repository:
    """
    Repository object that provides a clean interface for database operations.

    This wraps the module-level ``repo_*`` functions for easier use in domain
    models. It holds no state; every method delegates to the matching
    module-level function.
    """

    async def query(
        self, query_str: str, vars: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Execute a SurrealQL query (delegates to ``repo_query``)."""
        return await repo_query(query_str, vars)

    async def create(self, table: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a new record in ``table`` (delegates to ``repo_create``)."""
        return await repo_create(table, data)

    async def update(self, id: str, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Update an existing record identified by ``id``.

        The table name is the part of ``id`` before the first ``:`` (format
        ``table:id``); if ``id`` contains no ``:``, the whole value is used as
        the table name.
        """
        table = id.split(":")[0] if ":" in id else id
        return await repo_update(table, id, data)

    async def delete(self, record_id: Union[str, RecordID]) -> bool:
        """Delete a record and return ``True``.

        Failures propagate as the ``RuntimeError`` raised by ``repo_delete``;
        ``False`` is never returned.
        """
        await repo_delete(record_id)
        return True

    async def get(self, record_id: str) -> Optional[Dict[str, Any]]:
        """Fetch a record by ID.

        Returns the first row of ``SELECT * FROM <record_id>`` when it is a
        dict, otherwise ``None`` (including when nothing matched).
        """
        # NOTE(review): record_id is interpolated into the query text, so it
        # must come from trusted code rather than user input.
        results = await repo_query(f"SELECT * FROM {record_id}")
        if results:
            first = results[0]
            return first if isinstance(first, dict) else None
        return None

    async def upsert(
        self,
        table: str,
        id: Optional[str],
        data: Dict[str, Any],
        add_timestamp: bool = False,
    ) -> List[Dict[str, Any]]:
        """Create or update a record (delegates to ``repo_upsert``)."""
        return await repo_upsert(table, id, data, add_timestamp)

    async def relate(
        self,
        source: str,
        relationship: str,
        target: str,
        data: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Create a relationship between two records (delegates to ``repo_relate``)."""
        return await repo_relate(source, relationship, target, data)

    async def insert(
        self,
        table: str,
        data: List[Dict[str, Any]],
        ignore_duplicates: bool = False,
    ) -> List[Dict[str, Any]]:
        """Insert records into a table (delegates to ``repo_insert``)."""
        return await repo_insert(table, data, ignore_duplicates)


# Global repository instance for easy import
repo = Repository()


def get_database_url():
    """Return the SurrealDB connection URL with backward compatibility.

    ``SURREAL_URL`` is returned verbatim when set and non-empty. Otherwise a
    WebSocket RPC URL is built from the legacy ``SURREAL_ADDRESS`` (default
    ``localhost``) and ``SURREAL_PORT`` (default ``8000``) variables.
    """
    surreal_url = os.getenv("SURREAL_URL")
    if surreal_url:
        return surreal_url

    # Fallback to old format - WebSocket URL format
    address = os.getenv("SURREAL_ADDRESS", "localhost")
    port = os.getenv("SURREAL_PORT", "8000")
    # The port belongs in the authority (scheme://host:port/path), before the
    # path; placing it after "/rpc" would make it part of the path and the
    # client would connect to the scheme's default port instead.
    return f"ws://{address}:{port}/rpc"


def get_database_password():
    """Return the DB password from ``SURREAL_PASSWORD``, falling back to ``SURREAL_PASS``."""
    return os.getenv("SURREAL_PASSWORD") or os.getenv("SURREAL_PASS")


def parse_record_ids(obj: Any) -> Any:
    """Recursively convert every ``RecordID`` inside dicts/lists to its string form.

    Non-container, non-RecordID values are returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: parse_record_ids(v) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [parse_record_ids(item) for item in obj]
    elif isinstance(obj, RecordID):
        return str(obj)
    return obj


def ensure_record_id(value: Union[str, RecordID]) -> RecordID:
    """Return ``value`` as a ``RecordID``, parsing it when given a string."""
    if isinstance(value, RecordID):
        return value
    return RecordID.parse(value)


@asynccontextmanager
async def db_connection():
    """
    Get database connection with retry logic.

    Thin wrapper around ``open_notebook.database.connection.db_connection``,
    which (per its import comment) adds retry handling for SurrealDB startup
    delays.
    """
    async with db_connection_with_retry() as db:
        yield db


async def repo_query(
    query_str: str, vars: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Execute a SurrealQL query and return its results with RecordIDs as strings.

    Args:
        query_str: SurrealQL text. Pass values via ``vars`` (referenced as
            ``$name`` in the query) rather than interpolating them.
        vars: Optional query parameters.

    Raises:
        RuntimeError: If the driver returns an error string instead of rows,
            or raises RuntimeError itself (logged without a traceback).
        Exception: Any other driver error, logged with a traceback and re-raised.
    """
    async with db_connection() as connection:
        try:
            result = parse_record_ids(await connection.query(query_str, vars))
            # The driver can apparently report a failed query as a plain
            # string instead of raising; surface it as an error.
            if isinstance(result, str):
                raise RuntimeError(result)
            return result
        except RuntimeError as e:
            # NOTE(review): RuntimeError is treated as an expected failure
            # (query errors and, presumably, retriable transaction conflicts),
            # so it is logged without a stack trace.
            logger.error(str(e))
            raise
        except Exception as e:
            logger.exception(e)
            raise


async def repo_create(table: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """Create a new record in ``table`` with ``created``/``updated`` timestamps.

    Any ``id`` key in ``data`` is dropped so the database assigns the ID.
    The caller's dict is not modified.

    Raises:
        RuntimeError: On any failure. Non-RuntimeError exceptions are logged
            with a traceback and wrapped (original chained as ``__cause__``).
    """
    now = datetime.now(timezone.utc)
    # Work on a copy so the caller's dict is not mutated; a single timestamp
    # keeps created == updated on a freshly created record.
    record = {key: value for key, value in data.items() if key != "id"}
    record["created"] = now
    record["updated"] = now
    try:
        async with db_connection() as connection:
            return parse_record_ids(await connection.insert(table, record))
    except RuntimeError as e:
        logger.error(str(e))
        raise
    except Exception as e:
        logger.exception(e)
        raise RuntimeError("Failed to create record") from e


async def repo_relate(
    source: str, relationship: str, target: str, data: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Create a ``source->relationship->target`` edge with optional edge data.

    ``source``, ``relationship`` and ``target`` are interpolated into the
    query text and must come from trusted code; ``data`` is passed as a
    bound parameter.
    """
    if data is None:
        data = {}
    query = f"RELATE {source}->{relationship}->{target} CONTENT $data;"
    return await repo_query(
        query,
        {
            "data": data,
        },
    )


async def repo_upsert(
    table: str, id: Optional[str], data: Dict[str, Any], add_timestamp: bool = False
) -> List[Dict[str, Any]]:
    """Create or update a record by MERGE-ing ``data`` into it.

    Targets record ``id`` when given, otherwise the whole ``table``. Any
    ``id`` key in ``data`` is ignored; when ``add_timestamp`` is true an
    ``updated`` UTC timestamp is added. The caller's dict is not modified.
    """
    record = {key: value for key, value in data.items() if key != "id"}
    if add_timestamp:
        record["updated"] = datetime.now(timezone.utc)
    query = f"UPSERT {id if id else table} MERGE $data;"
    return await repo_query(query, {"data": record})


async def repo_update(
    table: str, id: Union[str, RecordID], data: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Update an existing record by table and id, MERGE-ing ``data`` into it.

    ``id`` is used as-is when it is a ``RecordID`` or already starts with
    ``"<table>:"``; otherwise it is prefixed with ``"<table>:"``. A string
    ``created`` value is parsed as ISO-8601 and ``updated`` is set to now
    (UTC). The caller's dict is not modified.

    Raises:
        RuntimeError: On any failure, with the original exception chained.
    """
    try:
        # If id already contains the table name, use it as is
        if isinstance(id, RecordID) or id.startswith(f"{table}:"):
            record_id = id
        else:
            record_id = f"{table}:{id}"
        record = {key: value for key, value in data.items() if key != "id"}
        if isinstance(record.get("created"), str):
            record["created"] = datetime.fromisoformat(record["created"])
        record["updated"] = datetime.now(timezone.utc)
        query = f"UPDATE {record_id} MERGE $data;"
        result = await repo_query(query, {"data": record})
        return parse_record_ids(result)
    except Exception as e:
        raise RuntimeError(f"Failed to update record: {str(e)}") from e


async def repo_delete(record_id: Union[str, RecordID]):
    """Delete a record by record id and return the driver's delete result.

    Raises:
        RuntimeError: On any failure (logged with a traceback, original chained).
    """
    try:
        async with db_connection() as connection:
            return await connection.delete(ensure_record_id(record_id))
    except Exception as e:
        logger.exception(e)
        raise RuntimeError(f"Failed to delete record: {str(e)}") from e


async def repo_insert(
    table: str, data: List[Dict[str, Any]], ignore_duplicates: bool = False
) -> List[Dict[str, Any]]:
    """Insert one or more records into ``table``.

    When ``ignore_duplicates`` is true and the error message contains
    ``"already contains"``, the failure is swallowed and ``[]`` is returned.

    Raises:
        RuntimeError: On any other failure (logged with a traceback,
            original chained).
    """
    try:
        async with db_connection() as connection:
            return parse_record_ids(await connection.insert(table, data))
    except Exception as e:
        if ignore_duplicates and "already contains" in str(e):
            return []
        logger.exception(e)
        raise RuntimeError("Failed to create record") from e