# stochastic / vector_store.py
# Sonu Prasad
# fix: update search API to query_points for qdrant-client 1.9+
# e252af5
import hashlib
from dataclasses import dataclass
from typing import Optional

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
from sentence_transformers import SentenceTransformer

from config import config
@dataclass
class DocumentChunk:
    """One chunk of text taken from a paper, with its identifying metadata.

    The four leading fields are required; the two title fields default to
    the empty string when they are not supplied.
    """

    # Identifier of this chunk.
    chunk_id: str
    # Identifier of the paper the chunk belongs to.
    paper_id: str
    # Name of the paper the chunk belongs to.
    paper_name: str
    # Text of the chunk.
    content: str
    # Optional section heading; empty when not supplied.
    section_title: str = ""
    # Optional subsection heading; empty when not supplied.
    subsection_title: str = ""
@dataclass
class SearchResult:
    """A single hit returned by a vector search."""

    # The chunk rebuilt from the stored payload of the hit.
    chunk: DocumentChunk
    # Score reported by the search backend for this hit.
    score: float
    # Position of the hit within the result list.
    rank: int
class QdrantVectorStore:
    """Store and search embedded document chunks in a Qdrant collection.

    The Qdrant client is only created when both ``config.QDRANT_URL`` and
    ``config.QDRANT_API_KEY`` are set. Without a client every public method
    is a no-op that returns an empty/zero/False result.
    """

    # Dimension of the vectors stored in the collection.
    # NOTE(review): presumably has to match the output size of
    # config.EMBEDDING_MODEL - confirm when changing the model.
    VECTOR_SIZE = 384
    # Point-count ceiling; the collection is reset at 90% of this value.
    MAX_VECTORS_FREE_TIER = 1_000_000
    # Number of points requested per page when scrolling the collection.
    _SCROLL_PAGE_SIZE = 10_000

    def __init__(self):
        """Create the store and connect/load everything via ``_initialize``."""
        self.client: Optional[QdrantClient] = None
        self.model: Optional[SentenceTransformer] = None
        self._initialize()

    def _initialize(self):
        """Create the Qdrant client (when configured) and load the embedder."""
        if config.QDRANT_URL and config.QDRANT_API_KEY:
            self.client = QdrantClient(
                url=config.QDRANT_URL,
                api_key=config.QDRANT_API_KEY,
            )
            self._ensure_collection()
        self.model = SentenceTransformer(config.EMBEDDING_MODEL)

    @staticmethod
    def _point_id(chunk_id: str) -> int:
        """Return a stable integer point id in ``[0, 2**63)`` for *chunk_id*.

        The built-in ``hash()`` is salted per process for strings, so it
        yields a different id on every run and re-indexing a chunk would add
        a duplicate point instead of overwriting it. A SHA-256 digest is
        identical across processes and machines.
        """
        digest = hashlib.sha256(chunk_id.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big") % (2**63)

    def _ensure_collection(self):
        """Create the collection if it is missing and index ``paper_name``.

        Raises:
            Exception: whatever the client raises while checking for or
                creating the collection (e.g. connection or auth errors).
        """
        # Ask the server directly instead of treating *any* get_collection
        # failure as "collection does not exist".
        if not self.client.collection_exists(config.QDRANT_COLLECTION):
            self.client.create_collection(
                collection_name=config.QDRANT_COLLECTION,
                vectors_config=models.VectorParams(
                    size=self.VECTOR_SIZE,
                    distance=models.Distance.COSINE,
                ),
            )
        try:
            self.client.create_payload_index(
                collection_name=config.QDRANT_COLLECTION,
                field_name="paper_name",
                field_schema=models.PayloadSchemaType.KEYWORD,
            )
        except Exception as e:
            # Index creation stays best-effort, but the failure is reported
            # rather than silently discarded.
            print(f"Could not create payload index on paper_name: {e}")

    def _check_and_cleanup_if_needed(self):
        """Reset the collection once it reaches 90% of the point ceiling.

        The reset deletes the whole collection and recreates it empty.
        Errors are printed and otherwise ignored.
        """
        if not self.client:
            return
        try:
            info = self.client.get_collection(config.QDRANT_COLLECTION)
            # points_count may be reported as None; treat that as empty.
            points_count = info.points_count or 0
            if points_count >= self.MAX_VECTORS_FREE_TIER * 0.9:
                self.client.delete_collection(config.QDRANT_COLLECTION)
                self._ensure_collection()
                print("Qdrant collection reset due to approaching limit")
        except Exception as e:
            print(f"Error checking collection: {e}")

    def add_chunks(self, chunks: "list[DocumentChunk]") -> int:
        """Embed *chunks* and upsert them into the collection.

        Args:
            chunks: Chunks to index; their ``content`` is what gets embedded.

        Returns:
            The number of chunks sent to Qdrant, or 0 when *chunks* is empty
            or no client is configured.
        """
        if not chunks or not self.client:
            return 0
        self._check_and_cleanup_if_needed()
        embeddings = self.model.encode(
            [chunk.content for chunk in chunks],
            normalize_embeddings=True,
        )
        points = [
            models.PointStruct(
                id=self._point_id(chunk.chunk_id),
                vector=embedding.tolist(),
                payload={
                    "chunk_id": chunk.chunk_id,
                    "paper_id": chunk.paper_id,
                    "paper_name": chunk.paper_name,
                    "content": chunk.content,
                    "section_title": chunk.section_title,
                    "subsection_title": chunk.subsection_title,
                },
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]
        self.client.upsert(
            collection_name=config.QDRANT_COLLECTION,
            points=points,
        )
        return len(chunks)

    def search(
        self,
        query: str,
        top_k: Optional[int] = None,
        paper_filter: Optional[str] = None,
    ) -> "list[SearchResult]":
        """Return the chunks most similar to *query*.

        Args:
            query: Text to embed and search for.
            top_k: Maximum number of hits; falls back to
                ``config.TOP_K_CHUNKS`` when None or 0.
            paper_filter: When given, only chunks whose ``paper_name``
                payload equals this value are considered.

        Returns:
            Hits in the order Qdrant returned them, ranked from 1. An empty
            list when no client is configured.
        """
        if not self.client:
            return []
        top_k = top_k or config.TOP_K_CHUNKS
        query_embedding = self.model.encode(query, normalize_embeddings=True)
        filter_condition = None
        if paper_filter:
            filter_condition = models.Filter(
                must=[
                    models.FieldCondition(
                        key="paper_name",
                        match=models.MatchValue(value=paper_filter),
                    )
                ]
            )
        results = self.client.query_points(
            collection_name=config.QDRANT_COLLECTION,
            query=query_embedding.tolist(),
            query_filter=filter_condition,
            limit=top_k,
        )
        return [
            SearchResult(
                chunk=DocumentChunk(
                    chunk_id=hit.payload["chunk_id"],
                    paper_id=hit.payload["paper_id"],
                    paper_name=hit.payload["paper_name"],
                    content=hit.payload["content"],
                    section_title=hit.payload.get("section_title", ""),
                    subsection_title=hit.payload.get("subsection_title", ""),
                ),
                score=hit.score,
                rank=rank,
            )
            for rank, hit in enumerate(results.points, start=1)
        ]

    def get_papers(self) -> list[dict]:
        """List indexed papers together with their chunk counts.

        Returns:
            One ``{"paper_name": ..., "chunk_count": ...}`` dict per distinct
            non-empty ``paper_name``. An empty list when no client is
            configured or when reading the collection fails.
        """
        if not self.client:
            return []
        counts: dict[str, int] = {}
        offset = None
        try:
            # A single scroll call returns at most one page, so keep
            # following the returned offset until the collection is exhausted.
            while True:
                points, offset = self.client.scroll(
                    collection_name=config.QDRANT_COLLECTION,
                    limit=self._SCROLL_PAGE_SIZE,
                    offset=offset,
                    with_payload=["paper_name"],
                )
                for point in points:
                    name = (point.payload or {}).get("paper_name", "")
                    if name:
                        counts[name] = counts.get(name, 0) + 1
                if offset is None:
                    break
        except Exception as e:
            print(f"Error listing papers: {e}")
            return []
        return [
            {"paper_name": name, "chunk_count": count}
            for name, count in counts.items()
        ]

    def delete_paper(self, paper_name: str) -> bool:
        """Delete every point whose ``paper_name`` payload equals *paper_name*.

        Returns:
            True when the delete request succeeded, False when no client is
            configured or the request raised.
        """
        if not self.client:
            return False
        try:
            self.client.delete(
                collection_name=config.QDRANT_COLLECTION,
                points_selector=models.FilterSelector(
                    filter=models.Filter(
                        must=[
                            models.FieldCondition(
                                key="paper_name",
                                match=models.MatchValue(value=paper_name),
                            )
                        ]
                    )
                ),
            )
        except Exception as e:
            print(f"Error deleting paper {paper_name!r}: {e}")
            return False
        return True

    def get_stats(self) -> dict:
        """Return the number of indexed papers and chunks.

        Returns:
            ``{"papers_indexed": int, "chunks_indexed": int}``; both values
            are 0 when no client is configured or the lookup fails.
        """
        empty_stats = {"papers_indexed": 0, "chunks_indexed": 0}
        if not self.client:
            return empty_stats
        try:
            info = self.client.get_collection(config.QDRANT_COLLECTION)
            papers = self.get_papers()
            return {
                "papers_indexed": len(papers),
                # points_count may be reported as None; expose it as 0.
                "chunks_indexed": info.points_count or 0,
            }
        except Exception as e:
            print(f"Error reading collection stats: {e}")
            return empty_stats
# Module-level instance; constructing it here runs QdrantVectorStore.__init__
# (and therefore _initialize) when this module is imported.
vector_store = QdrantVectorStore()