# stacklogix/app/retriever.py
# (captured from deployment commit 6ca2339, committed by Deploy Bot)
"""Qdrant vector retriever — handles embedding queries and searching."""
from urllib.parse import urlparse
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer
import uuid as uuid_lib
from app.config import (
QDRANT_URL,
QDRANT_API_KEY,
COLLECTION_NAME,
EMBEDDING_MODEL,
EMBEDDING_DIMENSION,
TOP_K,
)
class Retriever:
    """Wraps a Qdrant collection for embedding-based vector search.

    Responsibilities:
      * connect to Qdrant using host/port parsed from QDRANT_URL,
      * embed text with a SentenceTransformer model,
      * upsert document chunks and run similarity searches.
    """

    def __init__(self):
        # qdrant_client defaults to port 6333 when given a bare host, so we
        # parse QDRANT_URL ourselves and pass host/port/https explicitly.
        parsed = urlparse(QDRANT_URL)
        host = parsed.hostname or "localhost"
        port = parsed.port or (443 if parsed.scheme == "https" else 80)
        use_https = parsed.scheme == "https"
        self.client = QdrantClient(
            host=host,
            port=port,
            api_key=QDRANT_API_KEY if QDRANT_API_KEY else None,
            prefer_grpc=False,
            https=use_https,
            timeout=30,
        )
        # Eager load: any model download/initialization cost is paid here.
        self.model = SentenceTransformer(EMBEDDING_MODEL)

    def ensure_collection(self) -> None:
        """Create COLLECTION_NAME with the configured vector params if missing."""
        collections = [c.name for c in self.client.get_collections().collections]
        if COLLECTION_NAME not in collections:
            self.client.create_collection(
                collection_name=COLLECTION_NAME,
                vectors_config=VectorParams(
                    size=EMBEDDING_DIMENSION,
                    distance=Distance.COSINE,
                ),
            )
            print(f"Created collection: {COLLECTION_NAME}")
        else:
            print(f"Collection '{COLLECTION_NAME}' already exists.")

    def embed_text(self, text: str) -> list[float]:
        """Embed a single text string into a dense vector."""
        return self.model.encode(text).tolist()

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of text strings (one vector per input)."""
        return self.model.encode(texts).tolist()

    def upsert_chunks(self, chunks: list[dict]) -> None:
        """Embed and upsert document chunks into Qdrant.

        Args:
            chunks: list of {"text": str, "metadata": dict}. The chunk text
                is stored in the point payload under the "text" key alongside
                the flattened metadata.
        """
        if not chunks:
            return
        texts = [c["text"] for c in chunks]
        embeddings = self.embed_texts(texts)
        # NOTE: every point gets a fresh random UUID, so re-ingesting the
        # same chunk creates a new point rather than overwriting the old one.
        points = [
            PointStruct(
                id=str(uuid_lib.uuid4()),
                vector=embedding,
                payload={**chunk["metadata"], "text": chunk["text"]},
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]
        # Upsert in batches of 100 to keep individual request sizes bounded.
        batch_size = 100
        for start in range(0, len(points), batch_size):
            self.client.upsert(
                collection_name=COLLECTION_NAME,
                points=points[start : start + batch_size],
            )

    def search(self, query: str, top_k: int = TOP_K) -> list[dict]:
        """Return the top_k chunks most similar to the query.

        Returns:
            list of {"text": str, "score": float, "metadata": dict}, where
            "metadata" is the stored payload minus the "text" field.
        """
        query_vector = self.embed_text(query)
        results = self.client.search(
            collection_name=COLLECTION_NAME,
            query_vector=query_vector,
            limit=top_k,
        )
        return [
            {
                "text": hit.payload.get("text", ""),
                "score": hit.score,
                "metadata": {k: v for k, v in hit.payload.items() if k != "text"},
            }
            for hit in results
        ]

    def get_collection_info(self) -> dict:
        """Return basic collection stats, or {"error": ...} on any failure."""
        try:
            info = self.client.get_collection(COLLECTION_NAME)
            return {
                "name": COLLECTION_NAME,
                "vectors_count": info.vectors_count,
                "points_count": info.points_count,
                "status": info.status.value,
            }
        except Exception as e:
            # Best-effort status endpoint: report the error, don't propagate.
            return {"error": str(e)}
# Module-level cache backing the lazily-constructed singleton.
_retriever = None


def get_retriever() -> Retriever:
    """Return the shared Retriever, constructing it on first use."""
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = Retriever()
    return _retriever