Spaces:
Sleeping
Sleeping
File size: 3,064 Bytes
64d7fdf | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 | from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct, Filter, FieldCondition, MatchValue
from app.config import config, settings
from app.utils.logger import logger
from typing import List
import uuid
class VectorStore:
    """Thin wrapper around a lazily created ``QdrantClient``.

    The default collection name is read from
    ``config["database"]["qdrant"]["collection_name"]``. The client itself is
    only created on first use (see ``connect`` / ``get_client``).
    """

    def __init__(self):
        # No network I/O here: the client is created on demand by connect().
        self.client = None
        self.collection_name = config["database"]["qdrant"]["collection_name"]

    def connect(self):
        """Create the Qdrant client if it does not exist yet, and return it.

        The URL comes from ``config["database"]["qdrant"]["url"]`` and the API
        key from ``settings.qdrant_api_key``. A falsy key (e.g. ``""``) is
        passed to the client as ``None``.
        """
        if self.client is None:
            qdrant_url = config["database"]["qdrant"]["url"]
            api_key = settings.qdrant_api_key or None
            self.client = QdrantClient(
                url=qdrant_url,
                api_key=api_key,
            )
            logger.info("Qdrant connected")
        return self.client

    def create_collection(self, vector_size: int = None):
        """Create ``self.collection_name`` (cosine distance) if it is missing.

        Args:
            vector_size: Dimensionality of the stored vectors. When ``None``,
                falls back to ``config["database"]["qdrant"]["vector_size"]``.
        """
        if vector_size is None:
            vector_size = config["database"]["qdrant"]["vector_size"]
        client = self.get_client()
        if client.collection_exists(self.collection_name):
            logger.info(f"Qdrant collection already exists: {self.collection_name}")
            return
        client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=vector_size,
                distance=Distance.COSINE,
            ),
        )
        logger.info(f"Created Qdrant collection: {self.collection_name}")

    def get_client(self):
        """Return the Qdrant client, connecting first if necessary."""
        if self.client is None:
            self.connect()
        return self.client

    @staticmethod
    def _build_payload(doc) -> dict:
        """Build the Qdrant payload for one document.

        Metadata keys are stored at the top level of the payload; these are
        the keys ``delete_by_metadata`` filters on. ``"text"`` is written last
        so that a metadata entry named ``"text"`` cannot overwrite the
        document content.
        """
        return {**doc.metadata, "text": doc.page_content}

    async def add_documents(
        self,
        collection_name: str,
        documents: List,
        embeddings: List[List[float]],
    ) -> List[str]:
        """Upsert documents and their embeddings into ``collection_name``.

        Args:
            collection_name: Target Qdrant collection.
            documents: Objects exposing ``page_content`` and ``metadata``
                (presumably LangChain ``Document`` objects; confirm against
                callers).
            embeddings: One vector per document, in the same order.

        Returns:
            The generated point ids (random UUID4 strings), in input order.
            Returns an empty list, without contacting Qdrant, when there is
            nothing to add.

        Raises:
            ValueError: If ``documents`` and ``embeddings`` differ in length.
        """
        # NOTE(review): QdrantClient is synchronous, so this coroutine blocks
        # the event loop for the duration of the upsert. Consider
        # AsyncQdrantClient or offloading to a thread if that matters.
        documents = list(documents)
        embeddings = list(embeddings)
        # zip() would silently drop the unmatched tail; fail loudly instead.
        if len(documents) != len(embeddings):
            raise ValueError(
                f"Got {len(documents)} documents but {len(embeddings)} embeddings"
            )
        if not documents:
            # Nothing to write: skip the connection and the empty upsert.
            return []

        point_ids = [str(uuid.uuid4()) for _ in documents]
        points = [
            PointStruct(
                id=point_id,
                vector=embedding,
                payload=self._build_payload(doc),
            )
            for point_id, doc, embedding in zip(point_ids, documents, embeddings)
        ]
        client = self.get_client()
        client.upsert(
            collection_name=collection_name,
            points=points,
        )
        logger.info(f"Added {len(points)} documents to Qdrant")
        return point_ids

    async def delete_by_metadata(self, collection_name: str, metadata_key: str, metadata_value: str):
        """Delete every point whose payload has ``metadata_key == metadata_value``.

        Args:
            collection_name: Collection to delete from.
            metadata_key: Top-level payload key to match.
            metadata_value: Exact value the key must equal.
        """
        # NOTE(review): synchronous client call inside a coroutine; it blocks
        # the event loop (see add_documents).
        client = self.get_client()
        client.delete(
            collection_name=collection_name,
            points_selector=Filter(
                must=[
                    FieldCondition(
                        key=metadata_key,
                        match=MatchValue(value=metadata_value),
                    )
                ]
            ),
        )
        logger.info(f"Deleted documents with {metadata_key}={metadata_value} from Qdrant")
# Shared module-level instance. Constructing it reads the collection name from
# config, but no Qdrant connection is made until a client is first requested.
vector_store: VectorStore = VectorStore()
|