"""Qdrant vector retriever — handles embedding queries and searching."""
from urllib.parse import urlparse
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams, PointStruct
from sentence_transformers import SentenceTransformer
import uuid as uuid_lib
from app.config import (
QDRANT_URL,
QDRANT_API_KEY,
COLLECTION_NAME,
EMBEDDING_MODEL,
EMBEDDING_DIMENSION,
TOP_K,
)
class Retriever:
    """Wraps Qdrant for vector search operations.

    Embeds text with a SentenceTransformer model and reads/writes vectors
    in a single Qdrant collection (``COLLECTION_NAME``).
    """

    def __init__(self):
        # Parse URL into host/port for qdrant_client (avoids default port 6333 issue).
        parsed = urlparse(QDRANT_URL)
        host = parsed.hostname or "localhost"
        port = parsed.port or (443 if parsed.scheme == "https" else 80)
        use_https = parsed.scheme == "https"
        self.client = QdrantClient(
            host=host,
            port=port,
            api_key=QDRANT_API_KEY or None,  # empty string -> no auth header
            prefer_grpc=False,
            https=use_https,
            timeout=30,
        )
        # NOTE(review): loads the embedding model eagerly — first construction
        # may download weights; see get_retriever() for the lazy singleton.
        self.model = SentenceTransformer(EMBEDDING_MODEL)

    def ensure_collection(self):
        """Create the collection if it doesn't exist (idempotent)."""
        collections = [c.name for c in self.client.get_collections().collections]
        if COLLECTION_NAME not in collections:
            self.client.create_collection(
                collection_name=COLLECTION_NAME,
                vectors_config=VectorParams(
                    size=EMBEDDING_DIMENSION,
                    distance=Distance.COSINE,
                ),
            )
            print(f"Created collection: {COLLECTION_NAME}")
        else:
            print(f"Collection '{COLLECTION_NAME}' already exists.")

    def embed_text(self, text: str) -> list[float]:
        """Embed a single text string into a dense vector."""
        return self.model.encode(text).tolist()

    def embed_texts(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of text strings (one vector per input)."""
        return self.model.encode(texts).tolist()

    def upsert_chunks(self, chunks: list[dict]):
        """
        Upsert document chunks into Qdrant.

        Each chunk: {"text": str, "metadata": dict}. The text is stored in the
        payload alongside the metadata so search results can return it.
        """
        if not chunks:
            return
        texts = [c["text"] for c in chunks]
        embeddings = self.embed_texts(texts)
        # Random UUIDs as point IDs: re-ingesting the same chunk creates a new
        # point rather than overwriting — TODO confirm that is intended.
        points = [
            PointStruct(
                id=str(uuid_lib.uuid4()),
                vector=embedding,
                payload={**chunk["metadata"], "text": chunk["text"]},
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]
        # Upsert in batches of 100 to keep request sizes bounded.
        batch_size = 100
        for start in range(0, len(points), batch_size):
            self.client.upsert(
                collection_name=COLLECTION_NAME,
                points=points[start : start + batch_size],
            )

    def search(self, query: str, top_k: int = TOP_K) -> list[dict]:
        """
        Search for relevant chunks.

        Returns list of {"text": str, "score": float, "metadata": dict},
        with the stored "text" key split out of the payload.
        """
        query_vector = self.embed_text(query)
        results = self.client.search(
            collection_name=COLLECTION_NAME,
            query_vector=query_vector,
            limit=top_k,
        )
        return [
            {
                "text": hit.payload.get("text", ""),
                "score": hit.score,
                "metadata": {k: v for k, v in hit.payload.items() if k != "text"},
            }
            for hit in results
        ]

    def get_collection_info(self) -> dict:
        """Get collection stats; returns {"error": ...} instead of raising."""
        try:
            info = self.client.get_collection(COLLECTION_NAME)
            return {
                "name": COLLECTION_NAME,
                "vectors_count": info.vectors_count,
                "points_count": info.points_count,
                "status": info.status.value,
            }
        except Exception as e:
            # Best-effort status endpoint: surface the failure as data.
            return {"error": str(e)}
# Module-level cache for the one shared Retriever instance (lazy-loaded).
_retriever = None


def get_retriever() -> Retriever:
    """Return the process-wide Retriever, constructing it on first use."""
    global _retriever
    if _retriever is not None:
        return _retriever
    _retriever = Retriever()
    return _retriever
|