# GradioApps / vector_db.py
# Author: nrigheriu
# Commit 869f31e: "added app files"
from qdrant_client import QdrantClient
from qdrant_client.models import VectorParams, Distance, PointStruct
class QdrantStorage:
    """Thin wrapper around a single Qdrant collection of text chunks.

    Points are written with arbitrary payloads via ``upsert``; ``search``
    reads back the ``"text"`` and ``"source"`` payload keys of the hits.
    """

    def __init__(self, path="./qdrant_storage", collection="docs", dim=3072):
        """Open Qdrant in local mode and ensure the collection exists.

        Args:
            path: Directory passed to ``QdrantClient(path=...)`` (local mode);
                data already stored there is reused.
            collection: Name of the collection to read from and write to.
            dim: Vector size, used only when the collection must be created
                (cosine distance). An existing collection is left as-is.
        """
        # Use local mode - this will use your existing data
        self.client = QdrantClient(path=path)
        self.collection = collection
        if not self.client.collection_exists(self.collection):
            self.client.create_collection(
                collection_name=self.collection,
                vectors_config=VectorParams(size=dim, distance=Distance.COSINE),
            )

    def upsert(self, ids, vectors, payloads):
        """Insert or overwrite points in the collection.

        Args:
            ids: Point ids; element ``i`` pairs with ``vectors[i]`` and
                ``payloads[i]``.
            vectors: One vector per id.
            payloads: One payload dict per id.

        Raises:
            ValueError: If the three sequences have different lengths
                (previously this raised IndexError or silently dropped data).
        """
        if not (len(ids) == len(vectors) == len(payloads)):
            raise ValueError(
                "ids, vectors and payloads must have the same length "
                f"(got {len(ids)}, {len(vectors)}, {len(payloads)})"
            )
        if not ids:
            return  # nothing to write; skip the client round-trip
        points = [
            PointStruct(id=point_id, vector=vector, payload=payload)
            for point_id, vector, payload in zip(ids, vectors, payloads)
        ]
        self.client.upsert(collection_name=self.collection, points=points)

    def search(self, query_vector, top_k: int = 5, source_filter: "str | None" = None):
        """Return the text of the ``top_k`` nearest points.

        Args:
            query_vector: Query embedding.
            top_k: Maximum number of hits to request.
            source_filter: If truthy, only points whose payload ``"source"``
                equals this value are searched; ``None``/``""`` searches all.

        Returns:
            ``{"contexts": [...], "sources": [...]}`` where ``contexts`` holds
            the non-empty ``"text"`` payloads in rank order and ``sources``
            the distinct non-empty ``"source"`` values of those hits, ordered
            by first appearance (i.e. by best rank).
        """
        query_filter = None
        if source_filter:
            from qdrant_client.models import FieldCondition, Filter, MatchValue

            query_filter = Filter(
                must=[FieldCondition(key="source", match=MatchValue(value=source_filter))]
            )
        hits = self._query(query_vector, top_k, query_filter)
        return self._collect(hits)

    def _query(self, query_vector, limit, query_filter):
        """Run a nearest-neighbour query, preferring the Query API when present."""
        # ``QdrantClient.search`` is deprecated (and removed in newer
        # qdrant-client releases) in favour of ``query_points``; fall back to
        # ``search`` only for older clients that lack the Query API.
        if hasattr(self.client, "query_points"):
            response = self.client.query_points(
                collection_name=self.collection,
                query=query_vector,
                query_filter=query_filter,
                with_payload=True,
                limit=limit,
            )
            return response.points
        return self.client.search(
            collection_name=self.collection,
            query_vector=query_vector,
            query_filter=query_filter,
            with_payload=True,
            limit=limit,
        )

    @staticmethod
    def _collect(hits):
        """Fold ranked hits into the ``{"contexts", "sources"}`` result dict."""
        contexts = []
        # dict used as an insertion-ordered set: keeps sources in rank order
        # and makes the output deterministic (a set's order is arbitrary).
        sources = {}
        for hit in hits:
            payload = getattr(hit, "payload", None) or {}
            text = payload.get("text", "")
            if not text:
                continue
            contexts.append(text)
            source = payload.get("source", "")
            if source:
                sources[source] = None
        return {"contexts": contexts, "sources": list(sources)}