# Ragcore/app/core/vectorstore.py
# NinjainPJs's picture
# Fix document deletion: use point IDs instead of filter-based delete, add dropdown selector
# addfa4f
import logging

from qdrant_client import QdrantClient
from qdrant_client.http.models import (
    DatetimeRange,
    Distance,
    FieldCondition,
    Filter,
    MatchAny,
    MatchValue,
    PayloadSchemaType,
    PointIdsList,
    PointStruct,
    Range,
    VectorParams,
)

from app.config import get_settings
from app.models.document import Chunk
from app.models.schemas import SearchFilters

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class VectorStoreService:
    """Wrapper around one Qdrant collection that stores document chunks.

    Each point is one chunk: the point id is the chunk id, the vector is the
    chunk embedding, and the payload carries the chunk text plus the
    document-level metadata used for filtering.
    """

    def __init__(self, url: str, api_key: str, collection_name: str):
        """Create the Qdrant client and remember the target collection.

        Args:
            url: Qdrant endpoint passed to ``QdrantClient``.
            api_key: API key passed to ``QdrantClient``.
            collection_name: Collection every method of this service works on.
        """
        self.client = QdrantClient(url=url, api_key=api_key)
        self.collection_name = collection_name
        logger.info(f"Connected to Qdrant at {url}")

    def ensure_collection(self, vector_size: int = 384) -> None:
        """Create the collection if it is missing, then ensure payload indexes.

        Args:
            vector_size: Dimensionality used when the collection is created.
                It is not checked against an already existing collection.
        """
        collections = [c.name for c in self.client.get_collections().collections]
        if self.collection_name not in collections:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )
            logger.info(f"Created collection '{self.collection_name}' (dim={vector_size})")
        else:
            logger.info(f"Collection '{self.collection_name}' already exists")
        # Ensure payload indexes exist for filterable fields
        self._ensure_payload_indexes()

    def _ensure_payload_indexes(self) -> None:
        """Create payload indexes for fields used in filtering.

        Index creation is best-effort: a failure for one field is logged as a
        warning and the remaining fields are still attempted.
        """
        index_fields = {
            "document_id": PayloadSchemaType.KEYWORD,
            "source": PayloadSchemaType.KEYWORD,
            "doc_type": PayloadSchemaType.KEYWORD,
            "tags": PayloadSchemaType.KEYWORD,
            # Stored as ISO-8601 strings and queried with a datetime range in
            # _build_filter, so it needs a datetime index, not a keyword one.
            "created_date": PayloadSchemaType.DATETIME,
        }
        try:
            collection_info = self.client.get_collection(self.collection_name)
            existing_indexes = set(collection_info.payload_schema or {})
        except Exception as e:
            # Deliberately non-fatal: fall back to attempting every index.
            logger.warning(
                f"Could not read payload schema for '{self.collection_name}': {e}"
            )
            existing_indexes = set()
        for field_name, field_type in index_fields.items():
            if field_name not in existing_indexes:
                try:
                    self.client.create_payload_index(
                        collection_name=self.collection_name,
                        field_name=field_name,
                        field_schema=field_type,
                    )
                    logger.info(f"Created payload index: {field_name} ({field_type})")
                except Exception as e:
                    logger.warning(f"Could not create index for '{field_name}': {e}")

    def upsert_chunks(self, chunks: list[Chunk], embeddings: list[list[float]]) -> None:
        """Upsert chunks and their embeddings in batches of 100 points.

        Args:
            chunks: Chunks to store; ``chunk.chunk_id`` becomes the point id.
            embeddings: One vector per chunk, in the same order as ``chunks``.

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length.
        """
        # zip() would silently drop the unmatched tail, so fail loudly before
        # anything is written.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"chunks and embeddings must have the same length "
                f"(got {len(chunks)} chunks and {len(embeddings)} embeddings)"
            )
        batch_size = 100
        for i in range(0, len(chunks), batch_size):
            batch_chunks = chunks[i : i + batch_size]
            batch_embeddings = embeddings[i : i + batch_size]
            points = [
                PointStruct(
                    id=chunk.chunk_id,
                    vector=embedding,
                    payload={
                        "text": chunk.text,
                        "document_id": chunk.document_id,
                        "chunk_index": chunk.chunk_index,
                        "source": chunk.metadata.source,
                        "doc_type": chunk.metadata.doc_type,
                        "title": chunk.metadata.title,
                        "created_date": chunk.metadata.created_date.isoformat()
                        if chunk.metadata.created_date
                        else None,
                        "tags": chunk.metadata.tags,
                        "page_count": chunk.metadata.page_count,
                    },
                )
                for chunk, embedding in zip(batch_chunks, batch_embeddings)
            ]
            self.client.upsert(collection_name=self.collection_name, points=points)
        logger.info(f"Upserted {len(chunks)} chunks to '{self.collection_name}'")

    def search(
        self,
        query_vector: list[float],
        limit: int = 10,
        filters: SearchFilters | None = None,
    ) -> list[dict]:
        """Run a vector query and return the hits as plain dictionaries.

        Args:
            query_vector: Embedding of the query.
            limit: Maximum number of hits requested from Qdrant.
            filters: Optional payload filters; ignored when ``None`` or when
                ``filters.has_filters()`` is false.

        Returns:
            One dict per hit with ``chunk_id``, ``text``, ``score``,
            ``document_id`` and a ``metadata`` dict.
        """
        qdrant_filter = self._build_filter(filters) if filters and filters.has_filters() else None
        results = self.client.query_points(
            collection_name=self.collection_name,
            query=query_vector,
            limit=limit,
            query_filter=qdrant_filter,
        ).points
        hits = []
        for r in results:
            # A point may come back without a payload; treat that as empty.
            payload = r.payload or {}
            hits.append(
                {
                    "chunk_id": str(r.id),
                    "text": payload.get("text", ""),
                    "score": r.score,
                    "document_id": payload.get("document_id", ""),
                    "metadata": {
                        "source": payload.get("source", ""),
                        "doc_type": payload.get("doc_type", ""),
                        "title": payload.get("title"),
                        "created_date": payload.get("created_date"),
                        "tags": payload.get("tags", []),
                        "page_count": payload.get("page_count"),
                    },
                }
            )
        return hits

    def delete_document(self, document_id: str) -> int:
        """Delete every point whose payload ``document_id`` matches.

        Args:
            document_id: Identifier of the document to remove.

        Returns:
            Number of point ids that were found and submitted for deletion
            (0 when the document has no points).
        """
        # First, find all point IDs belonging to this document
        doc_filter = Filter(
            must=[FieldCondition(key="document_id", match=MatchValue(value=document_id))]
        )
        point_ids = []
        offset = None
        while True:
            results, next_offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=doc_filter,
                limit=100,
                offset=offset,
                with_payload=False,
                with_vectors=False,
            )
            point_ids.extend(r.id for r in results)
            if next_offset is None:
                break
            offset = next_offset
        if not point_ids:
            logger.warning(f"No points found for document '{document_id}'")
            return 0
        # Delete by point IDs (requires only write permission, not manage)
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=PointIdsList(points=point_ids),
        )
        logger.info(f"Deleted {len(point_ids)} points for document '{document_id}'")
        return len(point_ids)

    def scroll_all(self, batch_size: int = 100) -> list[dict]:
        """Page through the whole collection and return every point.

        Args:
            batch_size: Page size passed to each ``scroll`` call.

        Returns:
            One dict per point with ``chunk_id``, ``text``, ``document_id``
            and a ``metadata`` dict (``source``, ``doc_type``, ``title``,
            ``tags``). Vectors are not fetched.
        """
        all_points = []
        offset = None
        while True:
            results, next_offset = self.client.scroll(
                collection_name=self.collection_name,
                limit=batch_size,
                offset=offset,
                with_payload=True,
                with_vectors=False,
            )
            for r in results:
                # A point may come back without a payload; treat that as empty.
                payload = r.payload or {}
                all_points.append({
                    "chunk_id": str(r.id),
                    "text": payload.get("text", ""),
                    "document_id": payload.get("document_id", ""),
                    "metadata": {
                        "source": payload.get("source", ""),
                        "doc_type": payload.get("doc_type", ""),
                        "title": payload.get("title"),
                        "tags": payload.get("tags", []),
                    },
                })
            if next_offset is None:
                break
            offset = next_offset
        return all_points

    def get_document_ids(self) -> list[dict]:
        """Summarise the collection as one entry per distinct document id.

        Returns:
            Dicts with ``document_id``, ``source``, ``title``, ``doc_type``
            and ``num_chunks``, in first-seen order. The descriptive fields
            come from the first chunk seen for each document.
        """
        all_points = self.scroll_all()
        docs: dict[str, dict] = {}
        for p in all_points:
            doc_id = p["document_id"]
            if doc_id not in docs:
                docs[doc_id] = {
                    "document_id": doc_id,
                    "source": p["metadata"]["source"],
                    "title": p["metadata"].get("title"),
                    "doc_type": p["metadata"]["doc_type"],
                    "num_chunks": 0,
                }
            docs[doc_id]["num_chunks"] += 1
        return list(docs.values())

    def count(self) -> int:
        """Return the point count reported by ``get_collection``.

        Returns:
            ``points_count`` from the collection info, or 0 when the server
            does not report one.
        """
        info = self.client.get_collection(self.collection_name)
        return info.points_count or 0

    @staticmethod
    def _build_filter(filters: SearchFilters) -> Filter | None:
        """Translate ``SearchFilters`` into a Qdrant ``Filter``.

        All populated fields are combined with ``must``; ``tags`` matches
        when any of the requested tags is present.

        Returns:
            The filter, or ``None`` when no field is populated.
        """
        conditions = []
        if filters.source:
            conditions.append(FieldCondition(key="source", match=MatchValue(value=filters.source)))
        if filters.doc_type:
            conditions.append(FieldCondition(key="doc_type", match=MatchValue(value=filters.doc_type)))
        if filters.tags:
            conditions.append(FieldCondition(key="tags", match=MatchAny(any=filters.tags)))
        if filters.date_from or filters.date_to:
            # created_date is stored as an ISO-8601 string. The numeric Range
            # model only accepts floats, so dates need DatetimeRange.
            conditions.append(
                FieldCondition(
                    key="created_date",
                    range=DatetimeRange(
                        gte=filters.date_from or None,
                        lte=filters.date_to or None,
                    ),
                )
            )
        return Filter(must=conditions) if conditions else None
# Lazily created process-wide instance; populated by get_vectorstore().
_vectorstore: VectorStoreService | None = None


def get_vectorstore() -> VectorStoreService:
    """Return the shared VectorStoreService, creating it on first use.

    The service is built from ``get_settings()`` (``qdrant_url``,
    ``qdrant_api_key``, ``qdrant_collection``) and its collection is ensured
    with ``embedding_dim`` as the vector size.

    Returns:
        The cached service instance.
    """
    global _vectorstore
    if _vectorstore is None:
        settings = get_settings()
        store = VectorStoreService(
            url=settings.qdrant_url,
            api_key=settings.qdrant_api_key,
            collection_name=settings.qdrant_collection,
        )
        # Publish the instance only after the collection has been ensured.
        # If this raises, the next call retries instead of handing out a
        # service whose collection was never verified.
        store.ensure_collection(vector_size=settings.embedding_dim)
        _vectorstore = store
    return _vectorstore