# Docker_Deploy/src/python/vector_store.py
# Author: Shaheryar Shah
# Commit message: Add backend files for RAG Chatbot Docker deployment
# Commit: bec06d9
import uuid
import sys
from typing import List, Dict, Any, Optional
from qdrant_client import QdrantClient
from qdrant_client.http import models
import os
import logging

# Put this file's directory first on sys.path so the sibling ``config`` module
# is importable regardless of the process's working directory. The path is made
# absolute because ``os.path.dirname(__file__)`` can be relative (or ``''``),
# which would silently break if the working directory changes later.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from config import QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME

logger = logging.getLogger(__name__)
class VectorStore:
    """
    Store and retrieve document embeddings in a single Qdrant collection.

    Every method operates on the collection named by ``COLLECTION_NAME``
    (from ``config``). The Qdrant client is created once in ``__init__``.
    """

    def __init__(self):
        """Create the Qdrant client from ``QDRANT_URL`` / ``QDRANT_API_KEY``.

        If an API key is configured, it is passed to the client and gRPC is
        preferred. Otherwise the client connects using only the URL.
        """
        if QDRANT_API_KEY:
            self.client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                prefer_grpc=True,
            )
        else:
            self.client = QdrantClient(url=QDRANT_URL)

    def create_collection(self, vector_size: int = 1536):
        """Create the collection (cosine distance) if it doesn't exist.

        Args:
            vector_size: Vector dimensionality for a newly created collection.
                Ignored when the collection already exists; the existing
                collection's configuration is neither checked nor changed.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            collections = self.client.get_collections().collections
            if not any(col.name == COLLECTION_NAME for col in collections):
                self.client.create_collection(
                    collection_name=COLLECTION_NAME,
                    vectors_config=models.VectorParams(
                        size=vector_size,
                        distance=models.Distance.COSINE,
                    ),
                )
                logger.info(f"Created collection: {COLLECTION_NAME}")
            else:
                logger.info(f"Collection {COLLECTION_NAME} already exists")
        except Exception as e:
            logger.error(f"Error creating collection: {str(e)}")
            raise

    def add_documents(self, documents: List[Dict[str, Any]]):
        """Upload documents with precomputed embeddings to the collection.

        Args:
            documents: Dicts with optional keys ``'content'`` (text),
                ``'embedding'`` (the vector) and ``'metadata'`` (dict).
                - ``source``, ``file_name`` and ``file_path`` are always stored
                  from metadata, defaulting to ``''``.
                - ``chunk_id`` and ``total_chunks`` are stored only when present.

        Each document gets a fresh random UUID as its point ID, so adding the
        same document twice stores it twice rather than overwriting it.

        Raises:
            Exception: Any client error is logged and re-raised. A document
                without an ``'embedding'`` is sent with an empty vector, which
                presumably Qdrant rejects.
        """
        if not documents:
            # Nothing to upload; skip the round-trip to Qdrant.
            logger.info(f"No documents to add to collection {COLLECTION_NAME}")
            return
        try:
            points = []
            for doc in documents:
                # Treat an explicit ``None`` metadata the same as a missing one.
                metadata = doc.get('metadata') or {}
                payload = {
                    "content": doc.get('content', ''),
                    "source": metadata.get('source', ''),
                    "file_name": metadata.get('file_name', ''),
                    "file_path": metadata.get('file_path', ''),
                }
                # Optional chunking info is only stored when the caller supplied it.
                for key in ('chunk_id', 'total_chunks'):
                    if key in metadata:
                        payload[key] = metadata[key]
                points.append(
                    models.PointStruct(
                        id=str(uuid.uuid4()),
                        vector=doc.get('embedding', []),
                        payload=payload,
                    )
                )
            self.client.upload_points(
                collection_name=COLLECTION_NAME,
                points=points,
            )
            logger.info(f"Added {len(points)} documents to collection {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error adding documents: {str(e)}")
            raise

    def delete_collection(self):
        """Delete the whole collection.

        NOTE(review): the behaviour when the collection does not exist depends
        on the Qdrant server/client; this method does not check beforehand.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            self.client.delete_collection(collection_name=COLLECTION_NAME)
            logger.info(f"Deleted collection: {COLLECTION_NAME}")
        except Exception as e:
            logger.error(f"Error deleting collection: {str(e)}")
            raise

    def delete_documents_by_source(self, source: str):
        """Delete every point whose payload ``source`` equals *source*.

        Matching point IDs are collected by paging through ``scroll`` until
        Qdrant reports no further page. The matching points are then deleted in
        a single request. Earlier versions read only the first page, which
        silently left any matches beyond 10000 points in place.

        Raises:
            Exception: Any client error is logged and re-raised.
        """
        try:
            source_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="source",
                        match=models.MatchValue(value=source),
                    )
                ]
            )
            point_ids = []
            offset = None
            while True:
                # Only IDs are needed, so skip payloads and vectors.
                records, offset = self.client.scroll(
                    collection_name=COLLECTION_NAME,
                    scroll_filter=source_filter,
                    limit=10000,
                    offset=offset,
                    with_payload=False,
                    with_vectors=False,
                )
                point_ids.extend(record.id for record in records)
                if offset is None:
                    break

            if point_ids:
                self.client.delete(
                    collection_name=COLLECTION_NAME,
                    points_selector=models.PointIdsList(points=point_ids),
                )
                logger.info(f"Deleted {len(point_ids)} documents from source: {source}")
            else:
                logger.info(f"No documents found from source: {source}")
        except Exception as e:
            logger.error(f"Error deleting documents by source: {str(e)}")
            raise

    def search_similar(self, query_embedding: List[float], top_k: int = 5) -> List[Dict[str, Any]]:
        """Return up to *top_k* points nearest to *query_embedding*.

        Returns:
            A list of dicts with keys ``'content'``, ``'source'``, ``'score'``
            and ``'id'``, in the order returned by Qdrant. On any client error,
            the error is logged and an empty list is returned; callers cannot
            distinguish an error from an empty result.
        """
        try:
            results = self.client.search(
                collection_name=COLLECTION_NAME,
                query_vector=query_embedding,
                limit=top_k,
            )
            hits = []
            for hit in results:
                # ``payload`` is optional on scored points; treat missing as empty.
                payload = hit.payload or {}
                hits.append({
                    'content': payload.get('content', ''),
                    'source': payload.get('source', ''),
                    'score': hit.score,
                    'id': hit.id,
                })
            return hits
        except Exception as e:
            logger.error(f"Error searching for similar documents: {str(e)}")
            return []

    def get_all_documents_count(self) -> int:
        """Return the number of points in the collection.

        Returns 0 if Qdrant reports no count (``points_count`` is ``None``). On
        any client error, the error is logged and 0 is returned.
        """
        try:
            info = self.client.get_collection(collection_name=COLLECTION_NAME)
            return info.points_count or 0
        except Exception as e:
            logger.error(f"Error getting document count: {str(e)}")
            return 0