KnowledgeMesh / app /services /vector_store.py
pkheria's picture
pushing to git
b5e0c74
Raw
History Blame Contribute Delete
3.1 kB
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, PointStruct, VectorParams
from app.core.config import settings
from app.core.models import Chunk, SearchResult
class QdrantVectorStore:
    """Store and search text chunks in a single Qdrant collection.

    Each chunk becomes one Qdrant point. The point's vector is the chunk's
    embedding, and its payload holds the chunk text plus its index, source
    type, source, title and metadata.
    """

    def __init__(self, collection_name: str | None = None):
        """Create a Qdrant client bound to one collection.

        Args:
            collection_name: Collection to operate on. When None or empty,
                ``settings.QDRANT_COLLECTION_NAME`` is used instead.
        """
        self.collection_name = collection_name or settings.QDRANT_COLLECTION_NAME
        self.client = QdrantClient(
            url=settings.get_qdrant_url(),
            # An empty/unset API key setting is passed as None rather than "".
            api_key=settings.QDRANT_API_KEY or None,
            timeout=60,
        )

    def ensure_collection(self, vector_size: int) -> None:
        """Create the collection with cosine distance if it does not exist yet.

        Args:
            vector_size: Vector dimensionality used when the collection is created.

        Note:
            An already-existing collection is left untouched. Its configured
            vector size is NOT compared against ``vector_size``.
        """
        existing_names = {c.name for c in self.client.get_collections().collections}
        if self.collection_name in existing_names:
            return
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
        )

    def upsert_chunks(self, chunks: list[Chunk], embeddings: list[list[float]]) -> None:
        """Upsert one point per chunk, using ``chunk.id`` as the point id.

        Args:
            chunks: Chunks to store. ``chunks[i]`` is paired with ``embeddings[i]``.
            embeddings: One vector per chunk. The length of the first vector
                sizes the collection if it has to be created.

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length.
        """
        if len(chunks) != len(embeddings):
            raise ValueError("Chunks and embeddings must have the same length.")
        if not chunks:
            # Nothing to write; also avoids indexing embeddings[0] below.
            return
        self.ensure_collection(vector_size=len(embeddings[0]))
        # NOTE(review): Qdrant only accepts unsigned-int or UUID point ids;
        # presumably Chunk.id satisfies that -- confirm against Chunk's model.
        points = [
            PointStruct(
                id=chunk.id,
                vector=embedding,
                payload={
                    "text": chunk.text,
                    "chunk_index": chunk.index,
                    # presumably source_type is an Enum; its .value is stored.
                    "source_type": chunk.source_type.value,
                    "source": chunk.source,
                    "title": chunk.title,
                    "metadata": chunk.metadata,
                },
            )
            for chunk, embedding in zip(chunks, embeddings, strict=True)
        ]
        self.client.upsert(collection_name=self.collection_name, points=points)

    def search(self, query_embedding: list[float], limit: int = 5) -> list[SearchResult]:
        """Return up to ``limit`` scored matches for ``query_embedding``.

        Results keep the order in which the Qdrant client returns them.

        Args:
            query_embedding: Query vector; it must match the collection's dimensionality.
            limit: Maximum number of results to return.

        Returns:
            One ``SearchResult`` per hit. Missing or None payload string fields
            become "", and missing or None metadata becomes {}.
        """
        if hasattr(self.client, "query_points"):
            # Query API: the response wraps the scored points in ``.points``.
            hits = self.client.query_points(
                collection_name=self.collection_name,
                query=query_embedding,
                limit=limit,
                with_payload=True,
            ).points
        else:
            # Fallback for client versions that only expose ``search``.
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                limit=limit,
                with_payload=True,
            )
        return [self._hit_to_result(hit) for hit in hits]

    @staticmethod
    def _payload_str(payload: dict, key: str) -> str:
        """Return ``payload[key]`` as a string; a missing key or None yields ""."""
        value = payload.get(key)
        # str(None) would produce the literal "None", so map None to "" explicitly.
        return "" if value is None else str(value)

    @classmethod
    def _hit_to_result(cls, hit) -> SearchResult:
        """Convert one scored point returned by the client into a ``SearchResult``."""
        payload = hit.payload or {}
        return SearchResult(
            score=float(hit.score),
            text=cls._payload_str(payload, "text"),
            title=cls._payload_str(payload, "title"),
            source=cls._payload_str(payload, "source"),
            source_type=cls._payload_str(payload, "source_type"),
            # ``or {}`` also covers metadata stored as None, which dict() rejects.
            metadata=dict(payload.get("metadata") or {}),
        )