Spaces:
Runtime error
Runtime error
| import uuid | |
| from abc import ABC | |
| from typing import Generic, Type, TypeVar | |
| from loguru import logger | |
| from pydantic import UUID4, BaseModel, Field | |
| from pymongo import errors | |
| from llm_engineering.domain.exceptions import ImproperlyConfigured | |
| from llm_engineering.infrastructure.db.mongo import connection | |
| from llm_engineering.settings import settings | |
| _database = connection.get_database(settings.DATABASE_NAME) | |
| T = TypeVar("T", bound="NoSQLBaseDocument") | |
| class NoSQLBaseDocument(BaseModel, Generic[T], ABC): | |
| id: UUID4 = Field(default_factory=uuid.uuid4) | |
| def __eq__(self, value: object) -> bool: | |
| if not isinstance(value, self.__class__): | |
| return False | |
| return self.id == value.id | |
| def __hash__(self) -> int: | |
| return hash(self.id) | |
| def from_mongo(cls: Type[T], data: dict) -> T: | |
| """Convert "_id" (str object) into "id" (UUID object).""" | |
| if not data: | |
| raise ValueError("Data is empty.") | |
| id = data.pop("_id") | |
| return cls(**dict(data, id=id)) | |
| def to_mongo(self: T, **kwargs) -> dict: | |
| """Convert "id" (UUID object) into "_id" (str object).""" | |
| exclude_unset = kwargs.pop("exclude_unset", False) | |
| by_alias = kwargs.pop("by_alias", True) | |
| parsed = self.model_dump(exclude_unset=exclude_unset, by_alias=by_alias, **kwargs) | |
| if "_id" not in parsed and "id" in parsed: | |
| parsed["_id"] = str(parsed.pop("id")) | |
| for key, value in parsed.items(): | |
| if isinstance(value, uuid.UUID): | |
| parsed[key] = str(value) | |
| return parsed | |
| def model_dump(self: T, **kwargs) -> dict: | |
| dict_ = super().model_dump(**kwargs) | |
| for key, value in dict_.items(): | |
| if isinstance(value, uuid.UUID): | |
| dict_[key] = str(value) | |
| return dict_ | |
| def save(self: T, **kwargs) -> T | None: | |
| collection = _database[self.get_collection_name()] | |
| try: | |
| collection.insert_one(self.to_mongo(**kwargs)) | |
| return self | |
| except errors.WriteError: | |
| logger.exception("Failed to insert document.") | |
| return None | |
| def get_or_create(cls: Type[T], **filter_options) -> T: | |
| collection = _database[cls.get_collection_name()] | |
| try: | |
| instance = collection.find_one(filter_options) | |
| if instance: | |
| return cls.from_mongo(instance) | |
| new_instance = cls(**filter_options) | |
| new_instance = new_instance.save() | |
| return new_instance | |
| except errors.OperationFailure: | |
| logger.exception(f"Failed to retrieve document with filter options: {filter_options}") | |
| raise | |
| def bulk_insert(cls: Type[T], documents: list[T], **kwargs) -> bool: | |
| collection = _database[cls.get_collection_name()] | |
| try: | |
| collection.insert_many(doc.to_mongo(**kwargs) for doc in documents) | |
| return True | |
| except (errors.WriteError, errors.BulkWriteError): | |
| logger.error(f"Failed to insert documents of type {cls.__name__}") | |
| return False | |
| def find(cls: Type[T], **filter_options) -> T | None: | |
| collection = _database[cls.get_collection_name()] | |
| try: | |
| instance = collection.find_one(filter_options) | |
| if instance: | |
| return cls.from_mongo(instance) | |
| return None | |
| except errors.OperationFailure: | |
| logger.error("Failed to retrieve document") | |
| return None | |
| def bulk_find(cls: Type[T], **filter_options) -> list[T]: | |
| collection = _database[cls.get_collection_name()] | |
| try: | |
| instances = collection.find(filter_options) | |
| return [document for instance in instances if (document := cls.from_mongo(instance)) is not None] | |
| except errors.OperationFailure: | |
| logger.error("Failed to retrieve documents") | |
| return [] | |
| def get_collection_name(cls: Type[T]) -> str: | |
| if not hasattr(cls, "Settings") or not hasattr(cls.Settings, "name"): | |
| raise ImproperlyConfigured( | |
| "Document should define an Settings configuration class with the name of the collection." | |
| ) | |
| return cls.Settings.name | |