Add app\db\milvus.py
Browse files- app//db//milvus.py +117 -0
app//db//milvus.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import asyncio
|
| 3 |
+
import re
|
| 4 |
+
from typing import List, Dict, Any, Optional
|
| 5 |
+
|
| 6 |
+
from pymilvus import (
|
| 7 |
+
connections,
|
| 8 |
+
utility,
|
| 9 |
+
FieldSchema,
|
| 10 |
+
CollectionSchema,
|
| 11 |
+
DataType,
|
| 12 |
+
Collection
|
| 13 |
+
)
|
| 14 |
+
from app.core.config import settings
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger("rm_research.db.milvus")
|
| 17 |
+
|
| 18 |
+
class MilvusVectorDB:
|
| 19 |
+
"""
|
| 20 |
+
Institutional Scale Vector Intelligence Layer.
|
| 21 |
+
Optimized for high-recall academic searches with non-blocking I/O
|
| 22 |
+
and strict input sanitization to prevent expression injection.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self):
|
| 26 |
+
self.collection_name = "academic_knowledge_corpus"
|
| 27 |
+
self.dim = 768 # Tuned for scholarly transformer embeddings
|
| 28 |
+
self.alias = "default"
|
| 29 |
+
# Regex to ensure IDs are alphanumeric or standard UUID/Slug formats
|
| 30 |
+
self._sanitizer = re.compile(r"^[a-zA-Z0-9_\-]+$")
|
| 31 |
+
|
| 32 |
+
async def connect(self):
|
| 33 |
+
"""Establishes thread-safe connection to Milvus cluster."""
|
| 34 |
+
loop = asyncio.get_running_loop()
|
| 35 |
+
try:
|
| 36 |
+
if not connections.has_connection(self.alias):
|
| 37 |
+
await loop.run_in_executor(
|
| 38 |
+
None,
|
| 39 |
+
lambda: connections.connect(
|
| 40 |
+
alias=self.alias,
|
| 41 |
+
host=settings.MILVUS_HOST,
|
| 42 |
+
port=settings.MILVUS_PORT,
|
| 43 |
+
user=settings.MILVUS_USER,
|
| 44 |
+
password=settings.MILVUS_PASSWORD,
|
| 45 |
+
secure=True,
|
| 46 |
+
timeout=30
|
| 47 |
+
)
|
| 48 |
+
)
|
| 49 |
+
logger.info(f"Connected to Milvus: {settings.MILVUS_HOST}")
|
| 50 |
+
except Exception as e:
|
| 51 |
+
logger.critical(f"Milvus Auth Failure: {str(e)}")
|
| 52 |
+
raise
|
| 53 |
+
|
| 54 |
+
async def search_ann(
|
| 55 |
+
self,
|
| 56 |
+
query_vector: List[float],
|
| 57 |
+
limit: int = 10,
|
| 58 |
+
institution_id: Optional[str] = None,
|
| 59 |
+
disciplines: Optional[List[str]] = None
|
| 60 |
+
) -> List[Dict[str, Any]]:
|
| 61 |
+
"""
|
| 62 |
+
Executes Secure Approximate Nearest Neighbor (ANN) search.
|
| 63 |
+
Includes a whitelist-based filter builder to prevent injection attacks.
|
| 64 |
+
"""
|
| 65 |
+
await self.connect()
|
| 66 |
+
collection = Collection(self.collection_name)
|
| 67 |
+
loop = asyncio.get_running_loop()
|
| 68 |
+
|
| 69 |
+
# 1. Build & Sanitize Expression (Security Fix)
|
| 70 |
+
filters = []
|
| 71 |
+
|
| 72 |
+
if institution_id:
|
| 73 |
+
if self._sanitizer.match(institution_id):
|
| 74 |
+
filters.append(f"attributes['institution_id'] == '{institution_id}'")
|
| 75 |
+
else:
|
| 76 |
+
logger.warning(f"Sanitization block: Invalid institution_id '{institution_id}'")
|
| 77 |
+
|
| 78 |
+
if disciplines:
|
| 79 |
+
valid_dis = [d for d in disciplines if self._sanitizer.match(d)]
|
| 80 |
+
if valid_dis:
|
| 81 |
+
filters.append(f"attributes['discipline'] in {valid_dis}")
|
| 82 |
+
|
| 83 |
+
expr = " and ".join(filters) if filters else None
|
| 84 |
+
|
| 85 |
+
# 2. Execute Search in Executor
|
| 86 |
+
results = await loop.run_in_executor(
|
| 87 |
+
None,
|
| 88 |
+
lambda: collection.search(
|
| 89 |
+
data=[query_vector],
|
| 90 |
+
anns_field="embedding",
|
| 91 |
+
param={"metric_type": "COSINE", "params": {"ef": 128}},
|
| 92 |
+
limit=limit,
|
| 93 |
+
expr=expr,
|
| 94 |
+
output_fields=["paper_id", "attributes"]
|
| 95 |
+
)
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
return [
|
| 99 |
+
{
|
| 100 |
+
"paper_id": hit.entity.get("paper_id"),
|
| 101 |
+
"score": round(1.0 - hit.distance, 4), # Normalized similarity
|
| 102 |
+
"metadata": hit.entity.get("attributes")
|
| 103 |
+
} for hit in results[0]
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
async def insert_batch(self, vectors: List[List[float]], ids: List[str], metadata: List[Dict]):
|
| 107 |
+
"""Ingest batch into Milvus and flush to disk for persistence."""
|
| 108 |
+
await self.connect()
|
| 109 |
+
collection = Collection(self.collection_name)
|
| 110 |
+
loop = asyncio.get_running_loop()
|
| 111 |
+
|
| 112 |
+
await loop.run_in_executor(None, lambda: collection.insert([ids, vectors, metadata]))
|
| 113 |
+
await loop.run_in_executor(None, collection.flush)
|
| 114 |
+
logger.info(f"Ingested {len(ids)} artifacts.")
|
| 115 |
+
|
| 116 |
+
# Singleton instance
|
| 117 |
+
milvus_db = MilvusVectorDB()
|