derm-ai / app /services /vector_database_search.py
muhammadnoman76's picture
update
e02b28a
import os
import uuid
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_qdrant import Qdrant
from qdrant_client import QdrantClient, models
from qdrant_client.http.exceptions import UnexpectedResponse
from dotenv import load_dotenv
import logging
# Load variables from a local .env file (if one exists) into the environment.
load_dotenv()

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Mirror GEMINI_API_KEY into GOOGLE_API_KEY for the Google embeddings client.
# os.environ only accepts str values, so assigning the None returned by
# os.getenv() for an unset variable would raise TypeError at import time.
_gemini_api_key = os.getenv("GEMINI_API_KEY")
if _gemini_api_key:
    os.environ["GOOGLE_API_KEY"] = _gemini_api_key
elif not os.getenv("GOOGLE_API_KEY"):
    logger.warning("GEMINI_API_KEY is not set; Google embeddings may fail to authenticate.")

# Qdrant connection settings; the URL and API key are None when unset.
QDRANT_URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "dermatology_docs")
class VectorDatabaseSearch:
    """Store PDF chunks in a Qdrant collection and search them by similarity.

    The constructor attempts to connect immediately. If credentials are
    missing or the connection fails, the instance stays usable but
    ``is_initialized`` is False and every public method returns an empty or
    False result instead of raising.
    """

    # Vector size used when the collection is created.
    # NOTE(review): must match the embedding model's output size - confirm
    # that "models/embedding-001" produces 768-dimensional vectors.
    _VECTOR_SIZE = 768

    # Number of points requested per scroll page in get_book_info().
    _SCROLL_PAGE_SIZE = 1000

    def __init__(self, collection_name=QDRANT_COLLECTION_NAME):
        """Create the embeddings client and try to connect to Qdrant.

        Args:
            collection_name: Name of the Qdrant collection to use.
        """
        self.collection_name = collection_name
        self.embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
        self.client = None
        self.vectorstore = None
        self.is_initialized = False
        # Initialize connection
        self._initialize_connection()

    def _initialize_connection(self):
        """Connect to Qdrant, ensure the collection exists, build the vector store.

        Sets ``self.is_initialized`` to True only when every step succeeds.
        Errors are logged, never raised.
        """
        self.is_initialized = False

        # Check if credentials are available
        if not QDRANT_URL or not QDRANT_API_KEY:
            logger.warning("Qdrant credentials not found. Vector search will be disabled.")
            return

        try:
            self.client = QdrantClient(
                url=QDRANT_URL,
                api_key=QDRANT_API_KEY,
                timeout=30,
            )
            # Test connection
            self.client.get_collections()

            # A failed collection setup must leave the instance disabled;
            # previously the flag was unconditionally reset to True below.
            if not self._initialize_collection():
                return

            self.vectorstore = Qdrant(
                client=self.client,
                collection_name=self.collection_name,
                embeddings=self.embeddings,
            )
            self.is_initialized = True
            logger.info(f"Successfully connected to Qdrant collection: {self.collection_name}")
        except UnexpectedResponse as e:
            # Raised for any unexpected HTTP response, not only auth failures.
            logger.error(f"Unexpected response from Qdrant (check URL and API key): {e}")
        except Exception as e:
            logger.error(f"Error initializing Qdrant connection: {e}")

    def _initialize_collection(self):
        """Create the collection if it does not exist yet.

        Returns:
            True when the collection exists or was created, False when there
            is no client or the Qdrant calls failed.
        """
        if not self.client:
            return False
        try:
            existing = self.client.get_collections().collections
            if any(c.name == self.collection_name for c in existing):
                # Log how much data the existing collection holds.
                collection_info = self.client.get_collection(self.collection_name)
                logger.info(
                    f"Collection {self.collection_name} exists with {collection_info.points_count} points"
                )
            else:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=models.VectorParams(
                        size=self._VECTOR_SIZE,
                        distance=models.Distance.COSINE,
                    ),
                )
                logger.info(f"Created new collection: {self.collection_name}")
            return True
        except Exception as e:
            logger.error(f"Error initializing collection: {e}")
            self.is_initialized = False
            return False

    def add_pdf(self, pdf_path):
        """Split a PDF into chunks and add them to the vector store.

        Args:
            pdf_path: Path of the PDF file to load.

        Returns:
            True when the chunks were added, False when the database is not
            initialized or any step failed (the error is logged).
        """
        if not self.is_initialized:
            logger.error("Vector database not initialized. Cannot add PDF.")
            return False
        try:
            docs = PyPDFLoader(pdf_path).load()
            splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
            split_docs = splitter.split_documents(docs)

            # The file name without its extension identifies the book.
            book_name = os.path.splitext(os.path.basename(pdf_path))[0]
            logger.info(f"Processing {book_name} with {len(split_docs)} chunks")

            # Replace the loader's metadata with the fields this service uses.
            for doc in split_docs:
                doc.metadata = {
                    "source": book_name,
                    "page": doc.metadata.get('page', 1),
                    "id": str(uuid.uuid4()),
                }

            self.vectorstore.add_documents(split_docs)
            logger.info(f"Successfully added {len(split_docs)} chunks from {book_name}")
            return True
        except Exception as e:
            logger.error(f"Error adding PDF: {e}")
            return False

    def search(self, query, top_k=5):
        """Return the chunks most similar to ``query``.

        Args:
            query: Text to search for.
            top_k: Maximum number of results to return.

        Returns:
            A list of dicts with keys ``source``, ``page``, ``content`` (the
            first 500 characters of the chunk) and ``confidence`` (a
            percentage rounded to two decimals). The list is empty when the
            database is not initialized, the collection is empty, or an
            error occurred.
        """
        if not self.is_initialized:
            logger.warning("Vector database not initialized. Returning empty results.")
            return []
        try:
            # Check if collection has any data
            collection_info = self.client.get_collection(self.collection_name)
            if collection_info.points_count == 0:
                logger.warning(f"Collection {self.collection_name} is empty. No documents to search.")
                return []

            # Perform similarity search
            results = self.vectorstore.similarity_search_with_score(query, k=top_k)

            formatted = []
            for doc, score in results:
                # For the cosine metric Qdrant reports a similarity score
                # (higher means closer), so the score itself is the
                # confidence. Clamp it so negative similarities map to 0%.
                confidence = max(0.0, min(float(score), 1.0)) * 100
                formatted.append({
                    "source": doc.metadata.get('source', 'Unknown'),
                    "page": doc.metadata.get('page', 0),
                    "content": doc.page_content[:500],
                    "confidence": round(confidence, 2),
                })

            logger.info(f"Found {len(formatted)} results for query: {query[:50]}...")
            return formatted
        except Exception as e:
            logger.error(f"Search error: {e}")
            return []

    def get_book_info(self):
        """Retrieve the unique book sources stored in the collection.

        Returns:
            A list of distinct ``source`` values, in no particular order. The
            list is empty when the database is not initialized, the
            collection is missing or empty, or an error occurred.
        """
        if not self.is_initialized:
            logger.warning("Vector database not initialized.")
            return []
        try:
            # Check if collection exists
            collections = self.client.get_collections()
            if not any(c.name == self.collection_name for c in collections.collections):
                logger.info(f"Collection {self.collection_name} does not exist yet")
                return []

            # Get collection info
            collection_info = self.client.get_collection(self.collection_name)
            if collection_info.points_count == 0:
                logger.info("Collection is empty")
                return []

            # Page through every point: a single book can span more chunks
            # than one page holds, so reading only the first page would miss
            # books stored later in the collection.
            books = set()
            next_offset = None
            while True:
                points, next_offset = self.client.scroll(
                    collection_name=self.collection_name,
                    limit=self._SCROLL_PAGE_SIZE,
                    offset=next_offset,
                    with_payload=True,
                    with_vectors=False,
                )
                for point in points:
                    payload = getattr(point, 'payload', None)
                    if not payload:
                        continue
                    # The source may be nested under "metadata" or stored at
                    # the top level of the payload; accept either layout.
                    metadata = payload.get('metadata')
                    if isinstance(metadata, dict) and 'source' in metadata:
                        books.add(metadata['source'])
                    elif 'source' in payload:
                        books.add(payload['source'])
                # scroll() returns None as the next offset after the last page.
                if next_offset is None:
                    break

            logger.info(f"Found {len(books)} unique books in collection")
            return list(books)
        except Exception as e:
            logger.error(f"Error retrieving book info: {e}")
            return []

    def is_available(self):
        """Check if the vector database is initialized and contains data.

        Returns:
            True when the collection holds at least one point, otherwise
            False (including when the check itself fails).
        """
        if not self.is_initialized:
            return False
        try:
            collection_info = self.client.get_collection(self.collection_name)
            return collection_info.points_count > 0
        except Exception as e:
            # Catch Exception rather than everything so KeyboardInterrupt and
            # SystemExit still propagate.
            logger.error(f"Error checking vector database availability: {e}")
            return False