Spaces:
Sleeping
Sleeping
| import os | |
| from contextlib import asynccontextmanager | |
| from datetime import datetime, timezone | |
| from typing import Any, Dict, List, Optional, TypeVar, Union | |
| from loguru import logger | |
| from surrealdb import AsyncSurreal, RecordID # type: ignore | |
| # Import the new connection module with retry logic | |
| from open_notebook.database.connection import db_connection as db_connection_with_retry | |
| T = TypeVar("T", Dict[str, Any], List[Dict[str, Any]]) | |
class Repository:
    """
    Object-style facade over the module-level ``repo_*`` helpers.

    Each method forwards to the matching helper so domain models can depend
    on one object instead of importing every function separately.
    """

    async def query(
        self, query_str: str, vars: Optional[Dict[str, Any]] = None
    ) -> List[Dict[str, Any]]:
        """Run a SurrealQL statement, optionally binding ``vars``."""
        rows = await repo_query(query_str, vars)
        return rows

    async def create(self, table: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a record in ``table`` from ``data``."""
        created = await repo_create(table, data)
        return created

    async def update(self, id: str, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Update the record identified by ``id`` (expected form ``table:id``)."""
        # The table is whatever precedes the first ":"; an id without a
        # colon is passed through as the table name unchanged.
        if ":" in id:
            table = id.split(":", 1)[0]
        else:
            table = id
        updated = await repo_update(table, id, data)
        return updated

    async def delete(self, record_id: Union[str, RecordID]) -> bool:
        """Delete a record; returns ``True`` whenever the helper does not raise."""
        await repo_delete(record_id)
        return True

    async def get(self, record_id: str) -> Optional[Dict[str, Any]]:
        """Fetch one record by id, or ``None`` when nothing usable comes back."""
        # NOTE: record_id is interpolated directly into the statement.
        rows = await repo_query(f"SELECT * FROM {record_id}")
        if not rows:
            return None
        first = rows[0]
        if isinstance(first, dict):
            return first
        return None

    async def upsert(
        self,
        table: str,
        id: Optional[str],
        data: Dict[str, Any],
        add_timestamp: bool = False,
    ) -> List[Dict[str, Any]]:
        """Create the record if missing, otherwise merge ``data`` into it."""
        outcome = await repo_upsert(table, id, data, add_timestamp)
        return outcome

    async def relate(
        self,
        source: str,
        relationship: str,
        target: str,
        data: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """Link ``source`` to ``target`` through ``relationship``."""
        outcome = await repo_relate(source, relationship, target, data)
        return outcome

    async def insert(
        self, table: str, data: List[Dict[str, Any]], ignore_duplicates: bool = False
    ) -> List[Dict[str, Any]]:
        """Insert a list of records into ``table``."""
        outcome = await repo_insert(table, data, ignore_duplicates)
        return outcome


# Shared module-level instance so callers can simply import ``repo``.
repo = Repository()
def get_database_url():
    """Return the SurrealDB connection URL.

    A non-empty ``SURREAL_URL`` environment variable is returned verbatim.
    Otherwise the URL is assembled from the legacy ``SURREAL_ADDRESS``
    (default ``"localhost"``) and ``SURREAL_PORT`` (default ``"8000"``)
    variables as a WebSocket URL ending in ``/rpc``.
    """
    surreal_url = os.getenv("SURREAL_URL")
    if surreal_url:
        return surreal_url

    # Legacy configuration: separate address and port variables.
    address = os.getenv("SURREAL_ADDRESS", "localhost")
    port = os.getenv("SURREAL_PORT", "8000")
    # The port is part of the URL authority and must precede the path;
    # the previous "ws://host/rpc:port" form put it inside the path.
    return f"ws://{address}:{port}/rpc"
def get_database_password():
    """Return the database password from the environment.

    ``SURREAL_PASSWORD`` is preferred; when it is unset or empty the value
    of the older ``SURREAL_PASS`` variable is returned instead (which may
    itself be ``None``).
    """
    preferred = os.getenv("SURREAL_PASSWORD")
    if preferred:
        return preferred
    # Fall back to the legacy variable name.
    return os.getenv("SURREAL_PASS")
def parse_record_ids(obj: Any) -> Any:
    """Return a copy of ``obj`` with every ``RecordID`` replaced by ``str(...)``.

    Dicts and lists are walked recursively (dict keys are kept as-is, only
    values are converted); any other value is returned unchanged.
    """
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[key] = parse_record_ids(value)
        return converted
    if isinstance(obj, list):
        return list(map(parse_record_ids, obj))
    # Leaf value: stringify record ids, pass everything else through.
    return str(obj) if isinstance(obj, RecordID) else obj
def ensure_record_id(value: Union[str, RecordID]) -> RecordID:
    """Coerce ``value`` to a ``RecordID``.

    An existing ``RecordID`` is returned as-is; anything else is handed to
    ``RecordID.parse``.
    """
    return value if isinstance(value, RecordID) else RecordID.parse(value)
@asynccontextmanager
async def db_connection():
    """
    Async context manager yielding a database connection.

    Delegates to ``db_connection_with_retry`` (imported from
    ``open_notebook.database.connection``) and yields the connection it
    provides. The ``@asynccontextmanager`` decorator is required: every
    caller in this module uses ``async with db_connection() as ...``, which
    a bare async generator does not support.
    """
    async with db_connection_with_retry() as db:
        yield db
async def repo_query(
    query_str: str, vars: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Run a SurrealQL statement and return its result with record ids as strings.

    A result that is a plain string is raised as ``RuntimeError``. Every
    exception is logged and re-raised unchanged.
    """
    async with db_connection() as connection:
        try:
            raw = await connection.query(query_str, vars)
            parsed = parse_record_ids(raw)
            if not isinstance(parsed, str):
                return parsed
            # A bare string result is treated as an error message.
            raise RuntimeError(parsed)
        except RuntimeError as err:
            # RuntimeError covers retriable transaction conflicts: log the
            # message only, without a stack trace.
            logger.error(str(err))
            raise
        except Exception as err:
            logger.exception(err)
            raise
async def repo_create(table: str, data: Dict[str, Any]) -> Dict[str, Any]:
    """Create a new record in ``table``.

    Any ``id`` key in ``data`` is ignored, and ``created`` / ``updated`` are
    both set to the same timezone-aware UTC timestamp. The caller's dict is
    not modified; a shallow copy is sent to the database.

    Returns the inserted record(s) with record ids converted to strings.

    Raises:
        RuntimeError: re-raised as-is when the driver raises one; any other
            failure is wrapped in ``RuntimeError("Failed to create record")``
            with the original exception attached as the cause.
    """
    # Work on a copy so the caller's dict is left untouched.
    payload = {key: value for key, value in data.items() if key != "id"}
    # One timestamp, so a fresh record has created == updated.
    now = datetime.now(timezone.utc)
    payload["created"] = now
    payload["updated"] = now
    try:
        async with db_connection() as connection:
            return parse_record_ids(await connection.insert(table, payload))
    except RuntimeError as e:
        logger.error(str(e))
        raise
    except Exception as e:
        logger.exception(e)
        raise RuntimeError("Failed to create record") from e
async def repo_relate(
    source: str, relationship: str, target: str, data: Optional[Dict[str, Any]] = None
) -> List[Dict[str, Any]]:
    """Issue a ``RELATE source->relationship->target`` statement.

    ``data`` becomes the relation's content and is bound as ``$data``; an
    empty dict is used when it is ``None``. ``source``, ``relationship`` and
    ``target`` are interpolated into the statement text.
    """
    content = {} if data is None else data
    statement = f"RELATE {source}->{relationship}->{target} CONTENT $data;"
    bindings = {"data": content}
    return await repo_query(statement, bindings)
async def repo_upsert(
    table: str, id: Optional[str], data: Dict[str, Any], add_timestamp: bool = False
) -> List[Dict[str, Any]]:
    """Create or update a record by merging ``data`` into it.

    The statement targets ``id`` when it is truthy and ``table`` otherwise.
    Any ``id`` key in ``data`` is ignored. With ``add_timestamp`` set, an
    ``updated`` field holding the current UTC time is added. The caller's
    dict is not modified; a shallow copy is sent to the database.
    """
    # Work on a copy so the caller's dict is left untouched.
    payload = {key: value for key, value in data.items() if key != "id"}
    if add_timestamp:
        payload["updated"] = datetime.now(timezone.utc)
    query = f"UPSERT {id if id else table} MERGE $data;"
    return await repo_query(query, {"data": payload})
async def repo_update(
    table: str, id: Union[str, RecordID], data: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Merge ``data`` into an existing record.

    ``id`` may be a ``RecordID``, a full ``table:id`` string, or a bare id
    that is prefixed with ``table``. Any ``id`` key in ``data`` is ignored,
    an ISO-format string under ``created`` is converted to ``datetime``,
    and ``updated`` is set to the current UTC time. The caller's dict is
    not modified; a shallow copy is sent to the database.

    Raises:
        RuntimeError: on any failure, with the original exception attached
            as the cause.
    """
    try:
        # Use the id as-is when it already names the record in this table.
        if isinstance(id, RecordID) or id.startswith(f"{table}:"):
            record_id = id
        else:
            record_id = f"{table}:{id}"

        # Work on a copy so the caller's dict is left untouched.
        payload = {key: value for key, value in data.items() if key != "id"}
        if isinstance(payload.get("created"), str):
            payload["created"] = datetime.fromisoformat(payload["created"])
        payload["updated"] = datetime.now(timezone.utc)

        query = f"UPDATE {record_id} MERGE $data;"
        result = await repo_query(query, {"data": payload})
        return parse_record_ids(result)
    except Exception as e:
        raise RuntimeError(f"Failed to update record: {str(e)}") from e
async def repo_delete(record_id: Union[str, RecordID]):
    """Delete a record by its record id.

    ``record_id`` is coerced with ``ensure_record_id`` before being passed
    to the connection's ``delete``; that call's return value is returned.

    Raises:
        RuntimeError: on any failure, with the original exception attached
            as the cause (and logged with its traceback).
    """
    try:
        async with db_connection() as connection:
            return await connection.delete(ensure_record_id(record_id))
    except Exception as e:
        logger.exception(e)
        raise RuntimeError(f"Failed to delete record: {str(e)}") from e
async def repo_insert(
    table: str, data: List[Dict[str, Any]], ignore_duplicates: bool = False
) -> List[Dict[str, Any]]:
    """Insert a list of records into ``table``.

    Returns the inserted records with record ids converted to strings.
    When ``ignore_duplicates`` is true and the failure message contains
    ``"already contains"``, the error is swallowed and ``[]`` is returned.

    Raises:
        RuntimeError: on any other failure, with the original exception
            attached as the cause (and logged with its traceback).
    """
    try:
        async with db_connection() as connection:
            return parse_record_ids(await connection.insert(table, data))
    except Exception as e:
        # Duplicate-record errors are recognised by their message text.
        if ignore_duplicates and "already contains" in str(e):
            return []
        logger.exception(e)
        raise RuntimeError("Failed to create record") from e