File size: 10,316 Bytes
cfb0fa4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import logging
from typing import Optional, Tuple, List, Dict, Any

from open_webui.config import (
    MILVUS_URI,
    MILVUS_TOKEN,
    MILVUS_DB,
    MILVUS_COLLECTION_PREFIX,
    MILVUS_INDEX_TYPE,
    MILVUS_METRIC_TYPE,
    MILVUS_HNSW_M,
    MILVUS_HNSW_EFCONSTRUCTION,
    MILVUS_IVF_FLAT_NLIST,
)
from open_webui.retrieval.vector.main import (
    GetResult,
    SearchResult,
    VectorDBBase,
    VectorItem,
)
from pymilvus import (
    connections,
    utility,
    Collection,
    CollectionSchema,
    FieldSchema,
    DataType,
)

log = logging.getLogger(__name__)

RESOURCE_ID_FIELD = "resource_id"


class MilvusClient(VectorDBBase):
    def __init__(self):
        # Milvus collection names can only contain numbers, letters, and underscores.
        self.collection_prefix = MILVUS_COLLECTION_PREFIX.replace("-", "_")
        connections.connect(
            alias="default",
            uri=MILVUS_URI,
            token=MILVUS_TOKEN,
            db_name=MILVUS_DB,
        )

        # Main collection types for multi-tenancy
        self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
        self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
        self.FILE_COLLECTION = f"{self.collection_prefix}_files"
        self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web_search"
        self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash_based"
        self.shared_collections = [
            self.MEMORY_COLLECTION,
            self.KNOWLEDGE_COLLECTION,
            self.FILE_COLLECTION,
            self.WEB_SEARCH_COLLECTION,
            self.HASH_BASED_COLLECTION,
        ]

    def _get_collection_and_resource_id(self, collection_name: str) -> Tuple[str, str]:
        """
        Maps the traditional collection name to multi-tenant collection and resource ID.

        WARNING: This mapping relies on current Open WebUI naming conventions for
        collection names. If Open WebUI changes how it generates collection names
        (e.g., "user-memory-" prefix, "file-" prefix, web search patterns, or hash
        formats), this mapping will break and route data to incorrect collections.
        POTENTIALLY CAUSING HUGE DATA CORRUPTION, DATA CONSISTENCY ISSUES AND INCORRECT
        DATA MAPPING INSIDE THE DATABASE.
        """
        resource_id = collection_name

        if collection_name.startswith("user-memory-"):
            return self.MEMORY_COLLECTION, resource_id
        elif collection_name.startswith("file-"):
            return self.FILE_COLLECTION, resource_id
        elif collection_name.startswith("web-search-"):
            return self.WEB_SEARCH_COLLECTION, resource_id
        elif len(collection_name) == 63 and all(
            c in "0123456789abcdef" for c in collection_name
        ):
            return self.HASH_BASED_COLLECTION, resource_id
        else:
            return self.KNOWLEDGE_COLLECTION, resource_id

    def _create_shared_collection(self, mt_collection_name: str, dimension: int):
        fields = [
            FieldSchema(
                name="id",
                dtype=DataType.VARCHAR,
                is_primary=True,
                auto_id=False,
                max_length=36,
            ),
            FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
            FieldSchema(name="text", dtype=DataType.VARCHAR, max_length=65535),
            FieldSchema(name="metadata", dtype=DataType.JSON),
            FieldSchema(name=RESOURCE_ID_FIELD, dtype=DataType.VARCHAR, max_length=255),
        ]
        schema = CollectionSchema(fields, "Shared collection for multi-tenancy")
        collection = Collection(mt_collection_name, schema)

        index_params = {
            "metric_type": MILVUS_METRIC_TYPE,
            "index_type": MILVUS_INDEX_TYPE,
            "params": {},
        }
        if MILVUS_INDEX_TYPE == "HNSW":
            index_params["params"] = {
                "M": MILVUS_HNSW_M,
                "efConstruction": MILVUS_HNSW_EFCONSTRUCTION,
            }
        elif MILVUS_INDEX_TYPE == "IVF_FLAT":
            index_params["params"] = {"nlist": MILVUS_IVF_FLAT_NLIST}

        collection.create_index("vector", index_params)
        collection.create_index(RESOURCE_ID_FIELD)
        log.info(f"Created shared collection: {mt_collection_name}")
        return collection

    def _ensure_collection(self, mt_collection_name: str, dimension: int):
        if not utility.has_collection(mt_collection_name):
            self._create_shared_collection(mt_collection_name, dimension)

    def has_collection(self, collection_name: str) -> bool:
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return False

        collection = Collection(mt_collection)
        collection.load()
        res = collection.query(expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'", limit=1)
        return len(res) > 0

    def upsert(self, collection_name: str, items: List[VectorItem]):
        if not items:
            return
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        dimension = len(items[0]["vector"])
        self._ensure_collection(mt_collection, dimension)
        collection = Collection(mt_collection)

        entities = [
            {
                "id": item["id"],
                "vector": item["vector"],
                "text": item["text"],
                "metadata": item["metadata"],
                RESOURCE_ID_FIELD: resource_id,
            }
            for item in items
        ]
        collection.insert(entities)

    def search(
        self,
        collection_name: str,
        vectors: List[List[float]],
        filter: Optional[Dict] = None,
        limit: int = 10,
    ) -> Optional[SearchResult]:
        if not vectors:
            return None

        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return None

        collection = Collection(mt_collection)
        collection.load()

        search_params = {"metric_type": MILVUS_METRIC_TYPE, "params": {}}
        results = collection.search(
            data=vectors,
            anns_field="vector",
            param=search_params,
            limit=limit,
            expr=f"{RESOURCE_ID_FIELD} == '{resource_id}'",
            output_fields=["id", "text", "metadata"],
        )

        ids, documents, metadatas, distances = [], [], [], []
        for hits in results:
            batch_ids, batch_docs, batch_metadatas, batch_dists = [], [], [], []
            for hit in hits:
                batch_ids.append(hit.entity.get("id"))
                batch_docs.append(hit.entity.get("text"))
                batch_metadatas.append(hit.entity.get("metadata"))
                batch_dists.append(hit.distance)
            ids.append(batch_ids)
            documents.append(batch_docs)
            metadatas.append(batch_metadatas)
            distances.append(batch_dists)

        return SearchResult(
            ids=ids, documents=documents, metadatas=metadatas, distances=distances
        )

    def delete(
        self,
        collection_name: str,
        ids: Optional[List[str]] = None,
        filter: Optional[Dict[str, Any]] = None,
    ):
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return

        collection = Collection(mt_collection)

        # Build expression
        expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
        if ids:
            # Milvus expects a string list for 'in' operator
            id_list_str = ", ".join([f"'{id_val}'" for id_val in ids])
            expr.append(f"id in [{id_list_str}]")

        if filter:
            for key, value in filter.items():
                expr.append(f"metadata['{key}'] == '{value}'")

        collection.delete(" and ".join(expr))

    def reset(self):
        for collection_name in self.shared_collections:
            if utility.has_collection(collection_name):
                utility.drop_collection(collection_name)

    def delete_collection(self, collection_name: str):
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return

        collection = Collection(mt_collection)
        collection.delete(f"{RESOURCE_ID_FIELD} == '{resource_id}'")

    def query(
        self, collection_name: str, filter: Dict[str, Any], limit: Optional[int] = None
    ) -> Optional[GetResult]:
        mt_collection, resource_id = self._get_collection_and_resource_id(
            collection_name
        )
        if not utility.has_collection(mt_collection):
            return None

        collection = Collection(mt_collection)
        collection.load()

        expr = [f"{RESOURCE_ID_FIELD} == '{resource_id}'"]
        if filter:
            for key, value in filter.items():
                if isinstance(value, str):
                    expr.append(f"metadata['{key}'] == '{value}'")
                else:
                    expr.append(f"metadata['{key}'] == {value}")

        iterator = collection.query_iterator(
            expr=" and ".join(expr),
            output_fields=["id", "text", "metadata"],
            limit=limit if limit else -1,
        )

        all_results = []
        while True:
            batch = iterator.next()
            if not batch:
                iterator.close()
                break
            all_results.extend(batch)

        ids = [res["id"] for res in all_results]
        documents = [res["text"] for res in all_results]
        metadatas = [res["metadata"] for res in all_results]

        return GetResult(ids=[ids], documents=[documents], metadatas=[metadatas])

    def get(self, collection_name: str) -> Optional[GetResult]:
        return self.query(collection_name, filter={}, limit=None)

    def insert(self, collection_name: str, items: List[VectorItem]):
        return self.upsert(collection_name, items)