File size: 6,151 Bytes
cd46ce5
 
 
 
a7ed1bd
cd46ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7ed1bd
cd46ce5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""
Database utilities for ClipboardHealthAI application.
"""

from datetime import datetime, time, date
from enum import Enum
import re
from typing import Any, Dict, Type, TypeVar, get_args, get_origin

from bson import ObjectId
from pydantic import (
    AnyUrl,
    BaseModel,
    Field,
    GetCoreSchemaHandler,
    GetJsonSchemaHandler,
)
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema

# Type variable bound to pydantic's BaseModel; used as the return annotation
# of MongoBaseShortenModel.from_mongo.
T = TypeVar("T", bound=BaseModel)


class PyObjectId:
    """
    Custom type for handling MongoDB ObjectId in Pydantic models.

    As a field type it accepts a string and validates it into a
    ``bson.ObjectId``; in JSON mode the value is serialized back to a string
    and the JSON schema describes it as ``{"type": "string"}``.

    Instances wrap an ``ObjectId`` in ``.value`` and forward any attribute
    they do not define themselves to that wrapped object.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls, _source: type, _handler: GetCoreSchemaHandler
    ) -> core_schema.CoreSchema:
        """
        Define the core schema for Pydantic validation.

        ``validate`` takes only the raw value, so the *no-info* variant of the
        after-validator is required: the ``with_info`` variant calls the
        function with an extra ``ValidationInfo`` argument, which would make
        every validation fail with ``TypeError``.
        """
        return core_schema.no_info_after_validator_function(
            cls.validate,
            core_schema.str_schema(),
            # ObjectId is not JSON-serializable on its own; emit its string
            # form in JSON mode so the output matches the JSON schema below.
            serialization=core_schema.to_string_ser_schema(),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, _schema: core_schema.CoreSchema, _handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        """
        Define the JSON schema representation: a plain string.
        """
        return {"type": "string"}

    @classmethod
    def validate(cls, value: str) -> ObjectId:
        """
        Validate and convert a string to MongoDB ObjectId.

        Raises:
            ValueError: if ``ObjectId.is_valid`` rejects ``value`` (Pydantic
                turns this into a validation error when used as a validator).
        """
        if not ObjectId.is_valid(value):
            raise ValueError(f"Invalid ObjectId: {value}")
        return ObjectId(value)

    def __init__(self, value: str | None = None):
        """
        Wrap ``value`` as an ObjectId, or generate a new ObjectId if omitted.

        Raises:
            ValueError: if ``value`` is given but is not a valid ObjectId.
        """
        if value is None:
            self.value = ObjectId()
        else:
            self.value = self.validate(value)

    def __getattr__(self, item):
        """
        Delegate attribute access to the wrapped ObjectId.

        Only called for attributes not found normally. When ``value`` has not
        been assigned yet (e.g. an instance created through ``__new__`` by
        ``copy`` or ``pickle``) this raises ``AttributeError`` rather than
        ``KeyError``, so ``hasattr`` and those protocols keep working.
        """
        try:
            wrapped = self.__dict__["value"]
        except KeyError:
            raise AttributeError(item) from None
        return getattr(wrapped, item)

    def __str__(self):
        """
        Convert to string representation (the wrapped ObjectId's string form).
        """
        return str(self.value)

    def __repr__(self):
        """
        Debug representation, e.g. ``PyObjectId('...')``.
        """
        return f"{type(self).__name__}({str(self.value)!r})"


class MongoBaseModel(BaseModel):
    """
    Base model for MongoDB documents with serialization support.

    ``id`` defaults to the string form of a freshly generated ObjectId.
    ``to_mongo`` converts an instance into a plain dictionary for storage and
    ``from_mongo`` rebuilds an instance from such a dictionary.
    """

    id: str = Field(default_factory=lambda: str(PyObjectId()))

    class Config:  # pylint: disable=R0903
        """
        Configuration for the model.
        """

        arbitrary_types_allowed = True
        # Keys in the input that match no field are silently dropped.
        extra = "ignore"
        populate_by_name = True

    @staticmethod
    def serialize_s3_url(value: Any) -> Any:
        """
        Serialize an S3 URL.

        If ``value`` is a string containing both ``AWSAccessKeyId`` and
        ``Expires`` (a presigned URL's query parameters), return only the part
        after ``s3.amazonaws.com/`` up to the query string; otherwise return
        ``value`` unchanged.
        """
        # NOTE(review): only URLs carrying AWSAccessKeyId/Expires on a host
        # containing "s3.amazonaws.com/" are recognised; presigned URLs using
        # other query parameters or regional hostnames pass through unchanged
        # -- confirm that is intended.
        if (
            value
            and isinstance(value, str)
            and "AWSAccessKeyId" in value
            and "Expires" in value
        ):
            match = re.search(r"s3\.amazonaws\.com/([^?]+)", value)
            if match:
                return match.group(1)
        return value

    def to_mongo(self) -> Dict[str, Any]:
        """
        Convert the model instance to a MongoDB-compatible dictionary.

        Keys are field aliases, falling back to field names. Values are
        converted recursively:

        * nested models become dictionaries;
        * lists and tuples become lists of converted items;
        * dictionaries keep their keys and have their values converted;
        * Enum members become their ``.value``;
        * ``datetime``/``date``/``time`` become ISO 8601 strings;
        * ``AnyUrl`` becomes ``str``;
        * anything else goes through ``serialize_s3_url``.
        """

        def convert(value: Any) -> Any:
            if isinstance(value, BaseModel):
                return model_to_dict(value)
            if isinstance(value, (list, tuple)):
                return [convert(item) for item in value]
            if isinstance(value, dict):
                return {key: convert(item) for key, item in value.items()}
            if isinstance(value, Enum):
                # No truthiness check: falsy members such as an IntEnum with
                # value 0 or a StrEnum with value "" must be unwrapped too.
                return value.value
            if isinstance(value, (datetime, time, date)):
                return value.isoformat()
            if isinstance(value, AnyUrl):
                return str(value)
            return self.serialize_s3_url(value)

        def model_to_dict(model: BaseModel) -> Dict[str, Any]:
            # model_fields is read from the class (Pydantic v2 API).
            return {
                (field.alias or name): convert(getattr(model, name))
                for name, field in type(model).model_fields.items()
            }

        return model_to_dict(self)

    @classmethod
    def from_mongo(cls, data: Dict[str, Any] | None):
        """
        Create a model instance from MongoDB document data.

        Returns ``None`` when ``data`` is ``None``. After validation, fields
        annotated with an Enum class that still hold a raw value (presumably
        un-validated defaults) are converted to that Enum. The same applies to
        elements of ``list[E]`` / ``dict[K, E]`` fields, and recursively to
        nested models. ``None`` values are left untouched.
        """

        def as_enum_class(annotation: Any) -> type[Enum] | None:
            # The get_origin() guard keeps parameterized generics such as
            # list[E] away from issubclass(), which only accepts real classes.
            if (
                get_origin(annotation) is None
                and isinstance(annotation, type)
                and issubclass(annotation, Enum)
            ):
                return annotation
            return None

        def element_enum_class(annotation: Any) -> type[Enum] | None:
            # list[E] -> E, dict[K, E] -> E; anything else -> None.
            args = get_args(annotation)
            return as_enum_class(args[-1]) if args else None

        def restore_enums(inst: BaseModel) -> None:
            for name, field in type(inst).model_fields.items():
                value = getattr(inst, name)
                enum_cls = as_enum_class(field.annotation)
                if enum_cls is not None:
                    # enum_cls(None) would raise ValueError; keep None as is.
                    if value is not None:
                        setattr(inst, name, enum_cls(value))
                elif isinstance(value, BaseModel):
                    restore_enums(value)
                elif isinstance(value, (list, dict)):
                    item_enum = element_enum_class(field.annotation)
                    entries = (
                        list(enumerate(value))
                        if isinstance(value, list)
                        else list(value.items())
                    )
                    for key, item in entries:
                        if isinstance(item, BaseModel):
                            restore_enums(item)
                        elif item_enum is not None and item is not None:
                            value[key] = item_enum(item)

        if data is None:
            return None
        instance = cls(**data)
        restore_enums(instance)
        return instance


class MongoBaseShortenModel(BaseModel):
    """
    Base model for reading a reduced view of a MongoDB document.

    Subclasses declare only the fields they need. ``to_mongo_fields`` builds
    a projection selecting exactly those fields, and ``from_mongo`` builds an
    instance from the projected document.
    """

    id: str

    @classmethod
    def to_mongo_fields(cls) -> dict:
        """
        Build a MongoDB projection for this model's fields.

        Every field, inherited ones included, is selected under its alias
        (or its name when it has no alias). These are the keys that
        ``cls(**document)`` validates against. ``_id`` is always excluded and
        ``id`` always included.

        Returns:
            A dict mapping field keys to 1, with ``"_id"`` mapped to 0.
        """
        result = {
            key: 1
            for key in (
                field.alias or name for name, field in cls.model_fields.items()
            )
            if key != "_id"
        }
        result["_id"] = 0
        result["id"] = 1
        return result

    @classmethod
    def from_mongo(cls, mongo_obj: Dict[str, Any] | None) -> T:
        """
        Create an instance from a (projected) MongoDB document.

        Returns ``None`` when ``mongo_obj`` is ``None``, the same way
        ``MongoBaseModel.from_mongo`` does.
        """
        if mongo_obj is None:
            return None
        return cls(**mongo_obj)