Spaces:
Sleeping
Sleeping
| from datetime import datetime | |
| from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union, cast | |
| from loguru import logger | |
| from pydantic import BaseModel, ValidationError, field_validator, model_validator | |
| from open_notebook.database.repository import ( | |
| ensure_record_id, | |
| repo_create, | |
| repo_delete, | |
| repo_query, | |
| repo_relate, | |
| repo_update, | |
| repo_upsert, | |
| ) | |
| from open_notebook.exceptions import ( | |
| DatabaseOperationError, | |
| InvalidInputError, | |
| NotFoundError, | |
| ) | |
# Type variable bound to ObjectModel so the classmethods below are typed as
# returning instances of whichever subclass they are called on.
T = TypeVar("T", bound="ObjectModel")


class ObjectModel(BaseModel):
    """Base class for models persisted in a table named by ``table_name``.

    Subclasses set the ``table_name`` class variable. Persistence goes through
    the ``repo_*`` helpers imported from ``open_notebook.database.repository``.
    """

    id: Optional[str] = None
    table_name: ClassVar[str] = ""
    nullable_fields: ClassVar[set[str]] = set()  # Fields that can be saved as None
    created: Optional[datetime] = None
    updated: Optional[datetime] = None

    @classmethod
    async def get_all(cls: Type[T], order_by=None) -> List[T]:
        """Return every record of this model's table as model instances.

        Args:
            order_by: Optional text placed verbatim after ``ORDER BY``.

        Returns:
            The rows that could be turned into instances. Rows that fail
            construction are logged at critical level and skipped.

        Raises:
            DatabaseOperationError: On any failure, including being called on
                a class whose ``table_name`` is empty.
        """
        try:
            if not cls.table_name:
                # Reached when called directly on ObjectModel (no table bound).
                raise InvalidInputError(
                    "get_all() must be called from a specific model class"
                )

            query = f"SELECT * FROM {cls.table_name}"
            if order_by:
                # NOTE(review): order_by is interpolated into the query text,
                # not bound as a parameter. Callers must not pass untrusted
                # input here.
                query = f"{query} ORDER BY {order_by}"

            rows = await repo_query(query)

            objects = []
            for row in rows:
                try:
                    objects.append(cls(**row))
                except Exception as e:
                    # Deliberate best-effort: one bad row does not fail the
                    # whole listing.
                    logger.critical(f"Error creating object: {str(e)}")
            return objects
        except Exception as e:
            logger.error(f"Error fetching all {cls.table_name}: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    @classmethod
    async def get(cls: Type[T], id: str) -> T:
        """Fetch a single record by id and build the matching model.

        The table name is taken from the part of ``id`` before the first
        colon. When it differs from ``cls.table_name`` the ObjectModel
        subclass registered for that table is used instead.

        Args:
            id: Record id, e.g. ``"table:key"``.

        Raises:
            InvalidInputError: If ``id`` is empty.
            NotFoundError: If the record is missing, or on any other failure
                during lookup (including an unknown table name).
        """
        if not id:
            raise InvalidInputError("ID cannot be empty")
        try:
            # Everything before the first colon; the whole id if no colon.
            table_name = id.split(":", 1)[0]

            if cls.table_name and cls.table_name == table_name:
                target_class: Type[T] = cls
            else:
                found_class = cls._get_class_by_table_name(table_name)
                if not found_class:
                    raise InvalidInputError(f"No class found for table {table_name}")
                target_class = cast(Type[T], found_class)

            result = await repo_query("SELECT * FROM $id", {"id": ensure_record_id(id)})
            if not result:
                raise NotFoundError(f"{table_name} with id {id} not found")
            return target_class(**result[0])
        except Exception as e:
            logger.error(f"Error fetching object with id {id}: {str(e)}")
            logger.exception(e)
            raise NotFoundError(f"Object with id {id} not found - {str(e)}") from e

    @classmethod
    def _get_class_by_table_name(cls, table_name: str) -> Optional[Type["ObjectModel"]]:
        """Find the appropriate subclass based on table_name.

        Walks every direct and indirect subclass of ObjectModel (depth-first,
        parent before its descendants) and returns the first whose
        ``table_name`` matches, or ``None`` when there is no match.
        """

        def iter_subclasses(base: Type["ObjectModel"]):
            for subclass in base.__subclasses__():
                yield subclass
                yield from iter_subclasses(subclass)

        for subclass in iter_subclasses(ObjectModel):
            if hasattr(subclass, "table_name") and subclass.table_name == table_name:
                return subclass
        return None

    def needs_embedding(self) -> bool:
        """Return whether ``save()`` should compute an embedding (default: no)."""
        return False

    def get_embedding_content(self) -> Optional[str]:
        """Return the text to embed on save, or ``None`` (default) for nothing."""
        return None

    async def save(self) -> None:
        """Validate the model, then create or update its record.

        A record is created when ``self.id`` is ``None`` and updated
        otherwise. ``updated`` is always set to the current local time;
        ``created`` is set on creation and carried over on update. When
        ``needs_embedding()`` is true and there is content to embed, an
        ``embedding`` value is added to the saved data (an empty list when no
        embedding model is available). Fields returned by the repository are
        copied back onto this instance.

        Raises:
            ValidationError: If strict validation of the current data fails.
            RuntimeError: Propagated unchanged.
            DatabaseOperationError: For any other failure.
        """
        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import - confirm.
        from open_notebook.domain.models import model_manager

        timestamp_format = "%Y-%m-%d %H:%M:%S"
        try:
            self.model_validate(self.model_dump(), strict=True)
            data = self._prepare_save_data()
            data["updated"] = datetime.now().strftime(timestamp_format)

            if self.needs_embedding():
                embedding_content = self.get_embedding_content()
                if embedding_content:
                    embedding_model = await model_manager.get_embedding_model()
                    if embedding_model:
                        data["embedding"] = (
                            await embedding_model.aembed([embedding_content])
                        )[0]
                    else:
                        logger.warning(
                            "No embedding model found. Content will not be searchable."
                        )
                        data["embedding"] = []

            repo_result: Union[List[Dict[str, Any]], Dict[str, Any]]
            if self.id is None:
                data["created"] = datetime.now().strftime(timestamp_format)
                repo_result = await repo_create(self.__class__.table_name, data)
            else:
                data["created"] = (
                    self.created.strftime(timestamp_format)
                    if isinstance(self.created, datetime)
                    else self.created
                )
                logger.debug(f"Updating record with id {self.id}")
                repo_result = await repo_update(
                    self.__class__.table_name, self.id, data
                )

            # Normalise to a list so both return shapes are handled alike.
            result_list: List[Dict[str, Any]] = (
                repo_result if isinstance(repo_result, list) else [repo_result]
            )
            if result_list:
                first_result = result_list[0]
                if isinstance(first_result, dict):
                    for key, value in first_result.items():
                        if not hasattr(self, key):
                            continue
                        current = getattr(self, key)
                        if isinstance(current, BaseModel):
                            # Rebuild nested models from the returned mapping.
                            setattr(self, key, type(current)(**value))
                        else:
                            setattr(self, key, value)
                elif isinstance(first_result, str):
                    # Database returned an error string
                    logger.error(
                        f"Database returned string instead of dict: {first_result}"
                    )
                    raise DatabaseOperationError(f"Database error: {first_result}")

        except ValidationError as e:
            logger.error(f"Validation failed: {e}")
            raise
        except RuntimeError:
            # Transaction conflicts should propagate for retry
            raise
        except Exception as e:
            logger.error(f"Error saving record: {e}")
            raise DatabaseOperationError(e) from e

    def _prepare_save_data(self) -> Dict[str, Any]:
        """Return the model dump without ``None`` values.

        Keys listed in ``nullable_fields`` are kept even when ``None``.
        """
        nullable = self.__class__.nullable_fields
        return {
            key: value
            for key, value in self.model_dump().items()
            if value is not None or key in nullable
        }

    async def delete(self) -> bool:
        """Delete this record and return the result of ``repo_delete``.

        Raises:
            InvalidInputError: If the instance has no id.
            DatabaseOperationError: If the repository call fails.
        """
        if self.id is None:
            raise InvalidInputError("Cannot delete object without an ID")
        try:
            logger.debug(f"Deleting record with id {self.id}")
            return await repo_delete(self.id)
        except Exception as e:
            logger.error(
                f"Error deleting {self.__class__.table_name} with id {self.id}: {str(e)}"
            )
            raise DatabaseOperationError(
                f"Failed to delete {self.__class__.table_name}"
            ) from e

    async def relate(
        self, relationship: str, target_id: str, data: Optional[Dict] = None
    ) -> Any:
        """Create a relationship from this record to ``target_id``.

        Args:
            relationship: Relationship name passed to ``repo_relate``.
            target_id: Id of the target record.
            data: Optional payload for the relationship. ``None`` (the
                default) is sent as a fresh empty dict, so no dict object is
                shared between calls.

        Returns:
            Whatever ``repo_relate`` returns.

        Raises:
            InvalidInputError: If ``relationship``, ``target_id`` or
                ``self.id`` is empty.
            DatabaseOperationError: If the repository call fails.
        """
        if not relationship or not target_id or not self.id:
            raise InvalidInputError("Relationship and target ID must be provided")
        payload = {} if data is None else data
        try:
            return await repo_relate(
                source=self.id, relationship=relationship, target=target_id, data=payload
            )
        except Exception as e:
            logger.error(f"Error creating relationship: {str(e)}")
            logger.exception(e)
            raise DatabaseOperationError(e) from e

    @field_validator("created", "updated", mode="before")
    @classmethod
    def parse_datetime(cls, value):
        """Convert ISO-8601 strings to ``datetime``; pass other values through.

        A trailing ``Z`` is rewritten to ``+00:00`` before parsing.
        """
        if isinstance(value, str):
            return datetime.fromisoformat(value.replace("Z", "+00:00"))
        return value
class RecordModel(BaseModel):
    """Model with a single shared instance per ``record_id``.

    Subclasses set the ``record_id`` class variable. Constructing the class
    again returns the existing instance (after applying any keyword
    arguments to it). Database state is loaded lazily via ``get_instance()``
    and written with ``update()``.
    """

    record_id: ClassVar[str]
    auto_save: ClassVar[bool] = (
        False  # Default to False, can be overridden in subclasses
    )
    _instances: ClassVar[Dict[str, "RecordModel"]] = {}  # Store instances by record_id

    class Config:
        validate_assignment = True
        arbitrary_types_allowed = True
        extra = "allow"
        from_attributes = True
        defer_build = True

    def __new__(cls, **kwargs):
        """Return the cached instance for ``cls.record_id``, creating it once."""
        existing = cls._instances.get(cls.record_id)
        if existing is not None:
            # Apply any new values to the already-built instance.
            for key, value in kwargs.items():
                setattr(existing, key, value)
            return existing

        instance = super().__new__(cls)
        cls._instances[cls.record_id] = instance
        return instance

    def __init__(self, **kwargs):
        """Initialise the instance only the first time it is constructed.

        Python calls ``__init__`` after every ``__new__``, including when
        ``__new__`` returned the cached instance, so later calls are ignored.
        """
        if hasattr(self, "_initialized"):
            return
        object.__setattr__(self, "__dict__", {})
        # Database loading is async and happens later in _load_from_db();
        # only the provided kwargs are applied here.
        super().__init__(**kwargs)
        object.__setattr__(self, "_initialized", True)
        object.__setattr__(self, "_db_loaded", False)

    def _apply_db_row(self, result: Any) -> None:
        """Copy values from a query result onto matching attributes.

        ``result`` may be a list of rows (the first row is used) or a single
        dict. Anything else, or an empty result, is ignored. Values are
        written with ``object.__setattr__`` so assignment validation is not
        triggered.
        """
        if isinstance(result, list):
            row = result[0] if result else None
        else:
            row = result
        if not isinstance(row, dict):
            return
        for key, value in row.items():
            if hasattr(self, key):
                object.__setattr__(self, key, value)

    async def _load_from_db(self):
        """Load data from database if not already loaded"""
        if getattr(self, "_db_loaded", False):
            return
        result = await repo_query(
            "SELECT * FROM ONLY $record_id",
            {"record_id": ensure_record_id(self.record_id)},
        )
        # An empty result (record does not exist yet) leaves attributes as is.
        self._apply_db_row(result)
        object.__setattr__(self, "_db_loaded", True)

    @classmethod
    async def get_instance(cls) -> "RecordModel":
        """Get or create the singleton instance and load from DB"""
        instance = cls()
        await instance._load_from_db()
        return instance

    @model_validator(mode="after")
    def auto_save_validator(self):
        """Warn when ``auto_save`` is enabled; saving itself must be awaited."""
        if self.__class__.auto_save:
            # Auto-save can't work with async - log warning
            logger.warning(
                f"Auto-save is enabled for {self.__class__.__name__} but update() is now async. Call await instance.update() manually."
            )
        return self

    async def update(self):
        """Upsert the model's fields and refresh the instance from the database.

        The table is the class's ``table_name`` attribute when it has one,
        otherwise ``"record"``.

        Returns:
            This instance.
        """
        model_cls = type(self)
        # Read model_fields from the class, not the instance.
        data = {
            field_name: getattr(self, field_name)
            for field_name, field_info in model_cls.model_fields.items()
            if not str(field_info.annotation).startswith("typing.ClassVar")
        }

        await repo_upsert(
            getattr(model_cls, "table_name", "record"),
            self.record_id,
            data,
        )

        result = await repo_query(
            "SELECT * FROM $record_id", {"record_id": ensure_record_id(self.record_id)}
        )
        # Same result handling as _load_from_db: list of rows or single dict.
        self._apply_db_row(result)
        return self

    @classmethod
    def clear_instance(cls):
        """Clear the singleton instance (useful for testing)"""
        cls._instances.pop(cls.record_id, None)

    async def patch(self, model_dict: dict):
        """Update model attributes from dictionary and save"""
        for key, value in model_dict.items():
            setattr(self, key, value)
        await self.update()