Bromeo777 committed on
Commit
96f5ddf
·
verified ·
1 Parent(s): 8f3856f

Add app\db\milvus.py

Browse files
Files changed (1) hide show
  1. app//db//milvus.py +117 -0
app//db//milvus.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import asyncio
3
+ import re
4
+ from typing import List, Dict, Any, Optional
5
+
6
+ from pymilvus import (
7
+ connections,
8
+ utility,
9
+ FieldSchema,
10
+ CollectionSchema,
11
+ DataType,
12
+ Collection
13
+ )
14
+ from app.core.config import settings
15
+
16
+ logger = logging.getLogger("rm_research.db.milvus")  # module-wide logger registered under a fixed name rather than __name__
17
+
18
class MilvusVectorDB:
    """
    Institutional Scale Vector Intelligence Layer.

    Async facade over a single Milvus collection. Every pymilvus call that
    talks to the server is pushed onto the default executor so the event
    loop is not blocked, and filter expressions are assembled only from
    values that pass a strict character whitelist.
    """

    def __init__(self):
        """Set static collection parameters; no connection is opened here."""
        self.collection_name = "academic_knowledge_corpus"
        self.dim = 768  # Tuned for scholarly transformer embeddings
        self.alias = "default"
        # Whitelist for values interpolated into filter expressions:
        # alphanumerics, underscore and hyphen (covers UUIDs and slugs).
        # Anchored with \Z instead of $ because $ also matches just before a
        # trailing newline, which would let "abc\n" through the whitelist.
        self._sanitizer = re.compile(r"^[a-zA-Z0-9_\-]+\Z")

    def _is_safe(self, value: Any) -> bool:
        """Return True when *value* is a str made only of whitelisted characters."""
        return isinstance(value, str) and self._sanitizer.fullmatch(value) is not None

    def _build_filters(
        self,
        institution_id: Optional[str],
        disciplines: Optional[List[str]],
    ) -> Optional[List[str]]:
        """
        Translate the optional search filters into Milvus boolean clauses.

        Returns:
            A list of clauses (empty when no filter was requested), or
            ``None`` when a filter was requested but none of its values pass
            the whitelist. ``None`` means "this query can match nothing";
            callers must not fall back to an unfiltered search.
        """
        clauses: List[str] = []

        if institution_id:
            if not self._is_safe(institution_id):
                # %r escapes control characters so the rejected value cannot
                # forge extra log lines.
                logger.warning(
                    "Sanitization block: Invalid institution_id %r", institution_id
                )
                return None
            clauses.append(f"attributes['institution_id'] == '{institution_id}'")

        if disciplines:
            accepted = [d for d in disciplines if self._is_safe(d)]
            if not accepted:
                logger.warning(
                    "Sanitization block: no valid discipline in %r", disciplines
                )
                return None
            # Render the list explicitly rather than relying on Python's
            # list repr happening to be valid Milvus expression syntax.
            quoted = ", ".join(f"'{d}'" for d in accepted)
            clauses.append(f"attributes['discipline'] in [{quoted}]")

        return clauses

    async def connect(self):
        """
        Open the Milvus connection for ``self.alias`` if none is registered.

        The blocking ``connections.connect`` call runs in the default
        executor. Any exception is logged at CRITICAL level and re-raised.
        """
        loop = asyncio.get_running_loop()
        try:
            if not connections.has_connection(self.alias):
                await loop.run_in_executor(
                    None,
                    lambda: connections.connect(
                        alias=self.alias,
                        host=settings.MILVUS_HOST,
                        port=settings.MILVUS_PORT,
                        user=settings.MILVUS_USER,
                        password=settings.MILVUS_PASSWORD,
                        secure=True,
                        timeout=30,
                    ),
                )
                logger.info(f"Connected to Milvus: {settings.MILVUS_HOST}")
        except Exception as e:
            logger.critical(f"Milvus Auth Failure: {str(e)}")
            raise

    async def search_ann(
        self,
        query_vector: List[float],
        limit: int = 10,
        institution_id: Optional[str] = None,
        disciplines: Optional[List[str]] = None
    ) -> List[Dict[str, Any]]:
        """
        Executes Secure Approximate Nearest Neighbor (ANN) search.

        Args:
            query_vector: Embedding to search with.
            limit: Maximum number of hits to return.
            institution_id: Optional institution to restrict results to.
            disciplines: Optional disciplines to restrict results to; entries
                that fail the whitelist are ignored.

        Returns:
            One dict per hit with ``paper_id``, ``score`` (cosine similarity,
            higher is closer) and ``metadata``. Returns an empty list when a
            requested filter contains no valid value: the search fails
            closed instead of silently running unfiltered.
        """
        # Validate before touching the network: a rejected filter must never
        # degrade into an unfiltered, cross-institution search.
        clauses = self._build_filters(institution_id, disciplines)
        if clauses is None:
            return []
        expr = " and ".join(clauses) if clauses else None

        await self.connect()
        loop = asyncio.get_running_loop()

        def _run_search():
            # Collection() itself issues an RPC, so it is built inside the
            # executor thread and bound to this instance's connection alias.
            collection = Collection(self.collection_name, using=self.alias)
            return collection.search(
                data=[query_vector],
                anns_field="embedding",
                param={"metric_type": "COSINE", "params": {"ef": 128}},
                limit=limit,
                expr=expr,
                output_fields=["paper_id", "attributes"],
            )

        results = await loop.run_in_executor(None, _run_search)

        return [
            {
                "paper_id": hit.entity.get("paper_id"),
                # With metric_type COSINE, Milvus reports the similarity
                # itself in hit.distance (larger means closer), so it is
                # passed through; 1 - distance would invert the ranking.
                "score": round(hit.distance, 4),
                "metadata": hit.entity.get("attributes"),
            }
            for hit in results[0]
        ]

    async def insert_batch(self, vectors: List[List[float]], ids: List[str], metadata: List[Dict]):
        """
        Ingest batch into Milvus and flush to disk for persistence.

        Args:
            vectors: One embedding per record.
            ids: One identifier per record.
            metadata: One attributes dict per record.

        Raises:
            ValueError: If the three columns differ in length.

        An empty batch is a no-op. Data is sent column-wise as
        ``[ids, vectors, metadata]``.
        NOTE(review): this assumes the collection schema declares its fields
        in that order; confirm against the schema definition.
        """
        if not (len(ids) == len(vectors) == len(metadata)):
            raise ValueError(
                "insert_batch requires equally sized columns: "
                f"ids={len(ids)}, vectors={len(vectors)}, metadata={len(metadata)}"
            )
        if not ids:
            return

        await self.connect()
        loop = asyncio.get_running_loop()

        def _insert_and_flush():
            # Build the handle, insert and flush in one executor hop so no
            # blocking pymilvus call runs on the event-loop thread.
            collection = Collection(self.collection_name, using=self.alias)
            collection.insert([ids, vectors, metadata])
            collection.flush()

        await loop.run_in_executor(None, _insert_and_flush)
        logger.info(f"Ingested {len(ids)} artifacts.")
115
+
116
+ # Singleton instance created at import time; the constructor only sets attributes and compiles the sanitizer regex, and each operation awaits connect() itself.
117
+ milvus_db = MilvusVectorDB()