# rag/embeddings_qdrant.py
# Origin: Hugging Face upload by jessica45 — "updated rag" (commit 5f04d6e, verified)
import os
import numpy as np
import google.generativeai as genai
from dotenv import load_dotenv
from typing import List, Dict, Optional, Union
import json
import pickle
import uuid
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.models import Distance, VectorParams, PointStruct
# Load environment variables
load_dotenv()
class EmbeddingManager:
    """Generate text embeddings with the Google Gemini embedding API.

    Wraps ``genai.embed_content`` for both document-side and query-side
    embeddings (``retrieval_document`` / ``retrieval_query`` task types),
    returning float32 numpy vectors.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Initialize the embedding manager with Gemini API.

        Args:
            api_key: Gemini API key; falls back to the GEMINI_API_KEY
                environment variable when not given.

        Raises:
            ValueError: if no API key is available from either source.
        """
        self.api_key = api_key or os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not found in environment variables")
        genai.configure(api_key=self.api_key)
        self.model_name = "models/text-embedding-004"

    def _embed(self, text: str, task_type: str) -> np.ndarray:
        """Shared embed call for both document and query task types.

        Best-effort: errors are printed (not raised) and an empty array is
        returned, preserving the original failure behavior of both public
        methods.
        """
        try:
            result = genai.embed_content(
                model=self.model_name,
                content=text,
                task_type=task_type
            )
            return np.array(result['embedding'], dtype=np.float32)
        except Exception as e:
            # Keep the two original error messages exactly.
            label = "query embedding" if task_type == "retrieval_query" else "embedding"
            print(f"Error generating {label}: {e}")
            return np.array([])

    def generate_embedding(self, text: str) -> np.ndarray:
        """Generate embedding for a single text (document side)."""
        return self._embed(text, "retrieval_document")

    def generate_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Generate embeddings for multiple texts.

        NOTE(review): failed embeddings are silently skipped, so the
        returned list can be shorter than ``texts`` and lose positional
        alignment with any per-text metadata — callers that pair chunks
        with metadata (e.g. QdrantVectorStore.add_documents) should
        verify lengths match.
        """
        embeddings = []
        for i, text in enumerate(texts):
            print(f"Generating embedding {i+1}/{len(texts)}")
            embedding = self.generate_embedding(text)
            if embedding.size > 0:
                embeddings.append(embedding)
            else:
                print(f"Failed to generate embedding for text {i+1}")
        return embeddings

    def generate_query_embedding(self, query: str) -> np.ndarray:
        """Generate embedding for a query (search side)."""
        return self._embed(query, "retrieval_query")
class QdrantVectorStore:
    """Vector store backed by Qdrant for RAG document retrieval.

    Handles collection lifecycle, document upload (text + metadata
    payloads), and cosine-similarity search over 768-dimensional
    Gemini embeddings.
    """

    def __init__(self, collection_name: Optional[str] = None, url: Optional[str] = None, api_key: Optional[str] = None):
        """Initialize Qdrant vector store.

        Connects to Qdrant Cloud when both a URL and an API key are
        available (from arguments or the QDRANT_URL / QDRANT_API_KEY
        environment variables); otherwise falls back to a local instance
        at localhost:6333.

        Args:
            collection_name: target collection; defaults to the
                QDRANT_COLLECTION_NAME env var, then 'rag_documents'.
            url: Qdrant Cloud URL override.
            api_key: Qdrant Cloud API key override.
        """
        self.collection_name = collection_name or os.getenv('QDRANT_COLLECTION_NAME', 'rag_documents')
        # Get Qdrant configuration from environment
        qdrant_url = url or os.getenv('QDRANT_URL')
        qdrant_api_key = api_key or os.getenv('QDRANT_API_KEY')
        # Initialize Qdrant client
        if qdrant_url and qdrant_api_key:
            # Qdrant Cloud
            print(f"Connecting to Qdrant Cloud at {qdrant_url}")
            self.client = QdrantClient(
                url=qdrant_url,
                api_key=qdrant_api_key,
            )
        else:
            # Local Qdrant (default)
            print("Using local Qdrant instance at http://localhost:6333")
            self.client = QdrantClient("localhost", port=6333)
        self.embedding_dim = 768  # Gemini text-embedding-004 dimension

    def create_collection(self, force_recreate: bool = False):
        """Create the collection if missing; optionally drop and recreate.

        Args:
            force_recreate: delete an existing collection first.

        Raises:
            Exception: re-raises any Qdrant client error after logging.
        """
        try:
            # Check if collection exists
            collections = self.client.get_collections().collections
            collection_exists = any(col.name == self.collection_name for col in collections)
            if collection_exists and force_recreate:
                print(f"Deleting existing collection: {self.collection_name}")
                self.client.delete_collection(collection_name=self.collection_name)
                collection_exists = False
            if not collection_exists:
                print(f"Creating collection: {self.collection_name}")
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=self.embedding_dim, distance=Distance.COSINE),
                )
                print(f"โœ“ Collection '{self.collection_name}' created successfully")
            else:
                print(f"โœ“ Collection '{self.collection_name}' already exists")
        except Exception as e:
            print(f"Error creating collection: {e}")
            raise

    def add_documents(self, chunks: List[str], embeddings: List[np.ndarray], metadata: List[Dict] = None, session_id: Optional[str] = None):
        """Add documents with their embeddings to Qdrant.

        Args:
            chunks: list of text chunks
            embeddings: list of numpy embeddings corresponding to chunks
            metadata: optional list of dicts with metadata per chunk
            session_id: optional session identifier to attach to each point payload

        Raises:
            ValueError: if chunks/embeddings/metadata lengths differ.
            Exception: re-raises any Qdrant upload error after logging.
        """
        if metadata is None:
            metadata = [{"index": i} for i in range(len(chunks))]
        if len(chunks) != len(embeddings) or len(chunks) != len(metadata):
            raise ValueError("chunks, embeddings, and metadata must have the same length")
        # Ensure collection exists
        self.create_collection()
        # Prepare points for Qdrant
        points = []
        for i, (chunk, embedding, meta) in enumerate(zip(chunks, embeddings, metadata)):
            point_id = str(uuid.uuid4())
            # Combine text and metadata for payload
            payload = {
                "text": chunk,
                "metadata": meta
            }
            # Attach session info if provided
            if session_id is not None:
                payload["session_id"] = session_id
            point = PointStruct(
                id=point_id,
                vector=embedding.tolist(),
                payload=payload
            )
            points.append(point)
        # Upload points to Qdrant
        try:
            print(f"Uploading {len(points)} documents to Qdrant...")
            self.client.upsert(
                collection_name=self.collection_name,
                points=points
            )
            print(f"โœ“ Successfully uploaded {len(points)} documents")
        except Exception as e:
            print(f"Error uploading documents: {e}")
            raise

    def similarity_search(self, query_embedding: np.ndarray, top_k: int = 5, score_threshold: float = 0.0,
                          include_context: bool = False) -> List[Dict]:
        """
        Search for similar documents in Qdrant.

        Args:
            query_embedding: The query vector
            top_k: Number of results to return
            score_threshold: Minimum similarity score
            include_context: If True, try to include surrounding chunks for context

        Returns:
            A list of result dicts (id, similarity, chunk, metadata,
            source, citation, optional context); empty on any error.
        """
        try:
            search_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=top_k,
                score_threshold=score_threshold
            )
            results = []
            for hit in search_results:
                # Payload shape is what add_documents wrote: text + metadata.
                metadata = hit.payload['metadata']
                # Basic result structure
                result = {
                    'id': hit.id,
                    'similarity': hit.score,
                    'chunk': hit.payload['text'],
                    'metadata': metadata,
                    'source': {
                        'file_name': metadata.get('file_name', 'Unknown'),
                        'file_path': metadata.get('file_path', 'Unknown'),
                        'chunk_index': metadata.get('chunk_index', 0)
                    }
                }
                # Add context if requested
                if include_context:
                    result['context'] = self._get_surrounding_context(metadata)
                # Add citation format
                result['citation'] = f"{metadata.get('file_name', 'Unknown')} (chunk {metadata.get('chunk_index', 0)})"
                results.append(result)
            return results
        except Exception as e:
            print(f"Error searching documents: {e}")
            return []

    def _get_surrounding_context(self, metadata: Dict) -> Dict:
        """Get surrounding chunks for context (if available).

        Scans points sharing the hit's file_path (via a payload filter and
        a dummy zero query vector) and returns the previous/next chunk
        texts around the hit's chunk_index.

        NOTE(review): limit=10 means files with more than 10 chunks may
        not include the neighbors of later chunks — verify against
        expected document sizes.
        """
        try:
            file_path = metadata.get('file_path')
            chunk_index = metadata.get('chunk_index', 0)
            # Try to find adjacent chunks from the same file
            context_filter = {
                "must": [
                    {"key": "metadata.file_path", "match": {"value": file_path}}
                ]
            }
            # Search for chunks from same file
            context_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=[0.0] * self.embedding_dim,  # Dummy vector
                query_filter=context_filter,
                limit=10,
                score_threshold=0.0
            )
            # Sort by chunk index and get surrounding chunks
            file_chunks = []
            for hit in context_results:
                hit_metadata = hit.payload['metadata']
                if hit_metadata.get('chunk_index') is not None:
                    file_chunks.append({
                        'index': hit_metadata['chunk_index'],
                        'text': hit.payload['text']
                    })
            file_chunks.sort(key=lambda x: x['index'])
            # Find current chunk and get neighbors
            current_idx = None
            for i, chunk in enumerate(file_chunks):
                if chunk['index'] == chunk_index:
                    current_idx = i
                    break
            context = {
                'previous_chunk': file_chunks[current_idx - 1]['text'] if current_idx and current_idx > 0 else None,
                'next_chunk': file_chunks[current_idx + 1]['text'] if current_idx is not None and current_idx < len(file_chunks) - 1 else None,
                'total_chunks_in_file': len(file_chunks)
            }
            return context
        except Exception as e:
            print(f"Error getting context: {e}")
            return {'error': 'Could not retrieve context'}

    def get_relevant_passages(self, query_embedding: np.ndarray, top_k: int = 5) -> List[str]:
        """Return just the text passages for RAG prompt creation."""
        results = self.similarity_search(query_embedding, top_k)
        return [result['chunk'] for result in results if result['chunk']]

    def enhanced_search(self, query_embedding: np.ndarray, top_k: int = 5) -> str:
        """Return a formatted string with search results ready for RAG.

        Each entry shows similarity, citation, content, and (when
        available) surrounding context snippets; entries are separated
        by a "=" divider line.
        """
        results = self.similarity_search(query_embedding, top_k, include_context=True)
        if not results:
            return "No relevant documents found."
        formatted_results = []
        for i, result in enumerate(results, 1):
            formatted_result = f"""
**Result {i}** (Similarity: {result['similarity']:.3f})
**Source**: {result['citation']}
**Content**: {result['chunk']}
"""
            # Add context if available
            if 'context' in result and not result['context'].get('error'):
                context = result['context']
                if context.get('previous_chunk'):
                    formatted_result += f"\n**Previous Context**: ...{context['previous_chunk'][-100:]}"
                if context.get('next_chunk'):
                    formatted_result += f"\n**Following Context**: {context['next_chunk'][:100]}..."
            formatted_results.append(formatted_result)
        # BUG FIX: previously `"\n" + "="*50 + "\n".join(...)` — the "="
        # bar bound to the leading newline only, so results were joined by
        # a bare "\n" and the divider never appeared between them. Group
        # the separator explicitly so every result is delimited.
        separator = "\n" + "=" * 50 + "\n"
        return separator + separator.join(formatted_results)

    def get_collection_info(self) -> Dict:
        """Get information about the collection.

        Returns an empty dict on any client error (logged, not raised).
        """
        try:
            info = self.client.get_collection(collection_name=self.collection_name)
            return {
                'name': self.collection_name,
                'points_count': info.points_count,
                'vectors_count': info.vectors_count,
                'status': info.status
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}

    def delete_collection(self):
        """Delete the collection (best-effort; errors are logged)."""
        try:
            self.client.delete_collection(collection_name=self.collection_name)
            print(f"โœ“ Collection '{self.collection_name}' deleted")
        except Exception as e:
            print(f"Error deleting collection: {e}")
if __name__ == "__main__":
    # Smoke-test: embed a few sample passages, index them in Qdrant,
    # then run a basic and an enhanced (RAG-formatted) search.
    print("Testing Qdrant Vector Store...")
    try:
        manager = EmbeddingManager()
        store = QdrantVectorStore()

        # Sample passages to index.
        docs = [
            "This is a sample document about machine learning and artificial intelligence.",
            "Python is a great programming language for data science and AI development.",
            "Qdrant is a vector database that enables similarity search at scale."
        ]
        print("Generating embeddings...")
        doc_vectors = manager.generate_embeddings_batch(docs)

        if doc_vectors:
            # Per-document metadata, one entry per sample passage.
            doc_meta = [
                {"source": "sample_doc", "topic": topic, "index": idx}
                for idx, topic in enumerate(["machine_learning", "programming", "database"])
            ]
            store.add_documents(docs, doc_vectors, doc_meta)

            # Basic similarity search.
            question = "What is vector database?"
            question_vector = manager.generate_query_embedding(question)
            if question_vector.size > 0:
                print(f"\n๐Ÿ” BASIC SEARCH: {question}")
                for hit in store.similarity_search(question_vector, top_k=2):
                    print(f"Similarity: {hit['similarity']:.4f}")
                    print(f"Source: {hit['citation']}")
                    print(f"Text: {hit['chunk']}")
                    print(f"Topic: {hit['metadata']['topic']}")
                    print("---")

                # Enhanced search returns a single RAG-ready string.
                print(f"\n๐Ÿš€ ENHANCED SEARCH (RAG-ready format):")
                print(store.enhanced_search(question_vector, top_k=2))

            # Collection stats after indexing.
            print(f"\nCollection Info: {store.get_collection_info()}")
    except Exception as e:
        print(f"Error in test: {e}")
        print("Make sure:")
        print("1. Your GEMINI_API_KEY is valid in .env file")
        print("2. Qdrant is running (docker run -p 6333:6333 qdrant/qdrant) or configure Qdrant Cloud")