| import os |
| import numpy as np |
| import google.generativeai as genai |
| from dotenv import load_dotenv |
| from typing import List, Dict, Optional, Union |
| import json |
| import pickle |
| import uuid |
| from qdrant_client import QdrantClient |
| from qdrant_client.http import models |
| from qdrant_client.http.models import Distance, VectorParams, PointStruct |
|
|
| |
| load_dotenv() |
|
|
class EmbeddingManager:
    """Generates text embeddings with the Gemini embedding API.

    Uses the ``models/text-embedding-004`` model (768-dimensional output)
    for both document indexing and query-time embeddings.
    """

    def __init__(self, api_key: Optional[str] = None):
        """Initialize the embedding manager with Gemini API.

        Args:
            api_key: Gemini API key; falls back to the GEMINI_API_KEY
                environment variable when not given.

        Raises:
            ValueError: If no API key is available from either source.
        """
        self.api_key = api_key or os.getenv('GEMINI_API_KEY')
        if not self.api_key:
            raise ValueError("GEMINI_API_KEY not found in environment variables")

        genai.configure(api_key=self.api_key)
        self.model_name = "models/text-embedding-004"

    def _embed(self, text: str, task_type: str, label: str) -> np.ndarray:
        """Call the Gemini embedding endpoint for *text*.

        Shared implementation for document and query embeddings (they
        differ only in ``task_type``). Errors are logged, not raised, so
        batch generation can skip failed items.

        Args:
            text: The text to embed.
            task_type: Gemini task type, e.g. "retrieval_document" or
                "retrieval_query".
            label: Human-readable name used in the error message.

        Returns:
            A float32 vector, or an empty array on failure.
        """
        try:
            result = genai.embed_content(
                model=self.model_name,
                content=text,
                task_type=task_type
            )
            return np.array(result['embedding'], dtype=np.float32)
        except Exception as e:
            print(f"Error generating {label}: {e}")
            return np.array([])

    def generate_embedding(self, text: str) -> np.ndarray:
        """Generate embedding for a single text (document indexing)."""
        return self._embed(text, "retrieval_document", "embedding")

    def generate_embeddings_batch(self, texts: List[str]) -> List[np.ndarray]:
        """Generate embeddings for multiple texts.

        Failed items are logged and skipped, so the returned list may be
        shorter than *texts*.
        """
        embeddings = []
        for i, text in enumerate(texts):
            print(f"Generating embedding {i+1}/{len(texts)}")
            embedding = self.generate_embedding(text)
            if embedding.size > 0:
                embeddings.append(embedding)
            else:
                print(f"Failed to generate embedding for text {i+1}")
        return embeddings

    def generate_query_embedding(self, query: str) -> np.ndarray:
        """Generate embedding for a query (search)."""
        return self._embed(query, "retrieval_query", "query embedding")
|
|
|
|
class QdrantVectorStore:
    """Qdrant-backed vector store for document chunks in a RAG pipeline.

    Connects to Qdrant Cloud when both a URL and API key are available
    (arguments or the QDRANT_URL / QDRANT_API_KEY env vars), otherwise
    to a local instance on localhost:6333.
    """

    def __init__(self, collection_name: Optional[str] = None, url: Optional[str] = None, api_key: Optional[str] = None):
        """Initialize Qdrant vector store.

        Args:
            collection_name: Collection to use; defaults to the
                QDRANT_COLLECTION_NAME env var or 'rag_documents'.
            url: Qdrant Cloud URL (overrides QDRANT_URL).
            api_key: Qdrant Cloud API key (overrides QDRANT_API_KEY).
        """
        self.collection_name = collection_name or os.getenv('QDRANT_COLLECTION_NAME', 'rag_documents')

        qdrant_url = url or os.getenv('QDRANT_URL')
        qdrant_api_key = api_key or os.getenv('QDRANT_API_KEY')

        if qdrant_url and qdrant_api_key:
            print(f"Connecting to Qdrant Cloud at {qdrant_url}")
            self.client = QdrantClient(
                url=qdrant_url,
                api_key=qdrant_api_key,
            )
        else:
            print("Using local Qdrant instance at http://localhost:6333")
            self.client = QdrantClient("localhost", port=6333)

        # Dimensionality of Gemini's text-embedding-004 output vectors.
        self.embedding_dim = 768

    def create_collection(self, force_recreate: bool = False):
        """Create the collection if missing; optionally drop and recreate it.

        Args:
            force_recreate: If True, delete any existing collection first.

        Raises:
            Exception: Re-raises any client error after logging it.
        """
        try:
            collections = self.client.get_collections().collections
            collection_exists = any(col.name == self.collection_name for col in collections)

            if collection_exists and force_recreate:
                print(f"Deleting existing collection: {self.collection_name}")
                self.client.delete_collection(collection_name=self.collection_name)
                collection_exists = False

            if not collection_exists:
                print(f"Creating collection: {self.collection_name}")
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=VectorParams(size=self.embedding_dim, distance=Distance.COSINE),
                )
                print(f"✓ Collection '{self.collection_name}' created successfully")
            else:
                print(f"✓ Collection '{self.collection_name}' already exists")

        except Exception as e:
            print(f"Error creating collection: {e}")
            raise

    def add_documents(self, chunks: List[str], embeddings: List[np.ndarray], metadata: Optional[List[Dict]] = None, session_id: Optional[str] = None):
        """Add documents with their embeddings to Qdrant.

        Ensures the collection exists, then upserts one point per chunk
        with a random UUID id and a payload of {text, metadata[, session_id]}.

        Args:
            chunks: list of text chunks
            embeddings: list of numpy embeddings corresponding to chunks
            metadata: optional list of dicts with metadata per chunk
            session_id: optional session identifier to attach to each point payload

        Raises:
            ValueError: If the three lists differ in length.
            Exception: Re-raises any upload error after logging it.
        """
        if metadata is None:
            metadata = [{"index": i} for i in range(len(chunks))]

        if len(chunks) != len(embeddings) or len(chunks) != len(metadata):
            raise ValueError("chunks, embeddings, and metadata must have the same length")

        # Make sure the target collection exists before upserting.
        self.create_collection()

        points = []
        for chunk, embedding, meta in zip(chunks, embeddings, metadata):
            payload = {
                "text": chunk,
                "metadata": meta
            }

            if session_id is not None:
                payload["session_id"] = session_id

            points.append(PointStruct(
                id=str(uuid.uuid4()),
                vector=embedding.tolist(),
                payload=payload
            ))

        try:
            print(f"Uploading {len(points)} documents to Qdrant...")
            self.client.upsert(
                collection_name=self.collection_name,
                points=points
            )
            print(f"✓ Successfully uploaded {len(points)} documents")
        except Exception as e:
            print(f"Error uploading documents: {e}")
            raise

    def similarity_search(self, query_embedding: np.ndarray, top_k: int = 5, score_threshold: float = 0.0,
                          include_context: bool = False) -> List[Dict]:
        """
        Search for similar documents in Qdrant.

        Args:
            query_embedding: The query vector
            top_k: Number of results to return
            score_threshold: Minimum similarity score
            include_context: If True, try to include surrounding chunks for context

        Returns:
            A list of result dicts with id, similarity, chunk, metadata,
            source info and a preformatted citation string; empty on error.
        """
        try:
            search_results = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding.tolist(),
                limit=top_k,
                score_threshold=score_threshold
            )

            results = []
            for hit in search_results:
                metadata = hit.payload['metadata']

                result = {
                    'id': hit.id,
                    'similarity': hit.score,
                    'chunk': hit.payload['text'],
                    'metadata': metadata,
                    'source': {
                        'file_name': metadata.get('file_name', 'Unknown'),
                        'file_path': metadata.get('file_path', 'Unknown'),
                        'chunk_index': metadata.get('chunk_index', 0)
                    }
                }

                if include_context:
                    result['context'] = self._get_surrounding_context(metadata)

                # Ready-to-display citation, e.g. "doc.pdf (chunk 3)".
                result['citation'] = f"{metadata.get('file_name', 'Unknown')} (chunk {metadata.get('chunk_index', 0)})"

                results.append(result)

            return results

        except Exception as e:
            print(f"Error searching documents: {e}")
            return []

    def _get_surrounding_context(self, metadata: Dict) -> Dict:
        """Get surrounding chunks for context (if available).

        Fetches up to 10 chunks from the same file via a filtered scroll
        and returns the chunks immediately before/after the current one.
        (A plain scroll replaces the previous all-zero-vector search: a
        zero vector has no defined cosine similarity.)
        """
        try:
            file_path = metadata.get('file_path')
            chunk_index = metadata.get('chunk_index', 0)

            # Filter-only retrieval: every point from the same source file.
            context_filter = models.Filter(
                must=[
                    models.FieldCondition(
                        key="metadata.file_path",
                        match=models.MatchValue(value=file_path),
                    )
                ]
            )

            points, _next_offset = self.client.scroll(
                collection_name=self.collection_name,
                scroll_filter=context_filter,
                limit=10,
                with_payload=True,
            )

            file_chunks = []
            for point in points:
                point_meta = point.payload['metadata']
                if point_meta.get('chunk_index') is not None:
                    file_chunks.append({
                        'index': point_meta['chunk_index'],
                        'text': point.payload['text']
                    })

            file_chunks.sort(key=lambda c: c['index'])

            # Locate the current chunk among its siblings (None if absent).
            current_idx = next(
                (i for i, chunk in enumerate(file_chunks) if chunk['index'] == chunk_index),
                None,
            )

            return {
                'previous_chunk': file_chunks[current_idx - 1]['text'] if current_idx is not None and current_idx > 0 else None,
                'next_chunk': file_chunks[current_idx + 1]['text'] if current_idx is not None and current_idx < len(file_chunks) - 1 else None,
                'total_chunks_in_file': len(file_chunks)
            }

        except Exception as e:
            print(f"Error getting context: {e}")
            return {'error': 'Could not retrieve context'}

    def get_relevant_passages(self, query_embedding: np.ndarray, top_k: int = 5) -> List[str]:
        """Return just the text passages for RAG prompt creation."""
        results = self.similarity_search(query_embedding, top_k)
        return [result['chunk'] for result in results if result['chunk']]

    def enhanced_search(self, query_embedding: np.ndarray, top_k: int = 5) -> str:
        """Return a formatted string with search results ready for RAG.

        Each result is rendered with its similarity score, citation,
        content, and (when available) surrounding context snippets.
        """
        results = self.similarity_search(query_embedding, top_k, include_context=True)

        if not results:
            return "No relevant documents found."

        formatted_results = []
        for i, result in enumerate(results, 1):
            formatted_result = f"""
**Result {i}** (Similarity: {result['similarity']:.3f})
**Source**: {result['citation']}
**Content**: {result['chunk']}
"""

            if 'context' in result and not result['context'].get('error'):
                context = result['context']
                if context.get('previous_chunk'):
                    formatted_result += f"\n**Previous Context**: ...{context['previous_chunk'][-100:]}"
                if context.get('next_chunk'):
                    formatted_result += f"\n**Following Context**: {context['next_chunk'][:100]}..."

            formatted_results.append(formatted_result)

        # Divider before each result. (Fixes a precedence bug where
        # "\n" + "="*50 + "\n".join(...) joined results with a bare
        # newline and printed the bar only once.)
        divider = "\n" + "=" * 50 + "\n"
        return divider + divider.join(formatted_results)

    def get_collection_info(self) -> Dict:
        """Get information about the collection.

        Returns:
            A dict with name, points_count, vectors_count and status, or
            an empty dict on error.
        """
        try:
            info = self.client.get_collection(collection_name=self.collection_name)
            return {
                'name': self.collection_name,
                'points_count': info.points_count,
                'vectors_count': info.vectors_count,
                'status': info.status
            }
        except Exception as e:
            print(f"Error getting collection info: {e}")
            return {}

    def delete_collection(self):
        """Delete the collection (errors are logged, not raised)."""
        try:
            self.client.delete_collection(collection_name=self.collection_name)
            print(f"✓ Collection '{self.collection_name}' deleted")
        except Exception as e:
            print(f"Error deleting collection: {e}")
|
|
|
|
if __name__ == "__main__":
    # Smoke test: embed a few sample texts, index them in Qdrant, then
    # exercise basic search, enhanced search, and collection info.
    print("Testing Qdrant Vector Store...")

    try:
        embedder = EmbeddingManager()
        store = QdrantVectorStore()

        samples = [
            "This is a sample document about machine learning and artificial intelligence.",
            "Python is a great programming language for data science and AI development.",
            "Qdrant is a vector database that enables similarity search at scale.",
        ]

        print("Generating embeddings...")
        vectors = embedder.generate_embeddings_batch(samples)

        if vectors:
            # Per-chunk metadata mirroring the sample texts above.
            topics = ["machine_learning", "programming", "database"]
            meta = [
                {"source": "sample_doc", "topic": topic, "index": idx}
                for idx, topic in enumerate(topics)
            ]

            store.add_documents(samples, vectors, meta)

            question = "What is vector database?"
            question_vec = embedder.generate_query_embedding(question)

            if question_vec.size > 0:
                print(f"\n๐ BASIC SEARCH: {question}")
                for hit in store.similarity_search(question_vec, top_k=2):
                    print(f"Similarity: {hit['similarity']:.4f}")
                    print(f"Source: {hit['citation']}")
                    print(f"Text: {hit['chunk']}")
                    print(f"Topic: {hit['metadata']['topic']}")
                    print("---")

                print(f"\n๐ ENHANCED SEARCH (RAG-ready format):")
                print(store.enhanced_search(question_vec, top_k=2))

                print(f"\nCollection Info: {store.get_collection_info()}")

    except Exception as e:
        print(f"Error in test: {e}")
        print("Make sure:")
        print("1. Your GEMINI_API_KEY is valid in .env file")
        print("2. Qdrant is running (docker run -p 6333:6333 qdrant/qdrant) or configure Qdrant Cloud")