# NOTE: removed hosting-page scrape artifacts ("Spaces: Sleeping") that were
# not part of this module's source.
| import os | |
| from pinecone import Pinecone | |
| from dotenv import load_dotenv | |
| import logging | |
| from typing import Optional, List, Dict, Any, Union, Tuple | |
| import time | |
| from langchain_google_genai import GoogleGenerativeAIEmbeddings | |
| import google.generativeai as genai | |
| from langchain_core.retrievers import BaseRetriever | |
| from langchain.callbacks.manager import Callbacks | |
| from langchain_core.documents import Document | |
| from langchain_core.pydantic_v1 import Field | |
# Configure logging
logger = logging.getLogger(__name__)

# Load environment variables from a local .env file (if present)
load_dotenv()

# Pinecone / Google credentials (read from the environment)
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
PINECONE_INDEX_NAME = os.getenv("PINECONE_INDEX_NAME")
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

# Pinecone retrieval configuration, overridable via environment variables
DEFAULT_LIMIT_K = int(os.getenv("PINECONE_DEFAULT_LIMIT_K", "10"))
DEFAULT_TOP_K = int(os.getenv("PINECONE_DEFAULT_TOP_K", "6"))
DEFAULT_SIMILARITY_METRIC = os.getenv("PINECONE_DEFAULT_SIMILARITY_METRIC", "cosine")
DEFAULT_SIMILARITY_THRESHOLD = float(os.getenv("PINECONE_DEFAULT_SIMILARITY_THRESHOLD", "0.75"))
# Strip surrounding whitespace and drop empty entries so env values such as
# "cosine, dotproduct" or a trailing comma do not yield unusable metric names
# (the raw split() result would make the membership check in search_vectors
# reject valid metrics).
ALLOWED_METRICS = [
    metric.strip()
    for metric in os.getenv("PINECONE_ALLOWED_METRICS", "cosine,dotproduct,euclidean").split(",")
    if metric.strip()
]

# Export constants for importing elsewhere
__all__ = [
    'get_pinecone_index',
    'check_db_connection',
    'search_vectors',
    'upsert_vectors',
    'delete_vectors',
    'fetch_metadata',
    'get_chain',
    'DEFAULT_TOP_K',
    'DEFAULT_LIMIT_K',
    'DEFAULT_SIMILARITY_METRIC',
    'DEFAULT_SIMILARITY_THRESHOLD',
    'ALLOWED_METRICS',
    'ThresholdRetriever'
]

# Configure the Google Generative AI client only when a key is available
if GOOGLE_API_KEY:
    genai.configure(api_key=GOOGLE_API_KEY)

# Module-level singletons: Pinecone client, index handle, cached retriever
pc = None
index = None
_retriever_instance = None

# Log loudly at import time when required configuration is missing
if not PINECONE_API_KEY:
    logger.error("PINECONE_API_KEY is not set in environment variables")
if not PINECONE_INDEX_NAME:
    logger.error("PINECONE_INDEX_NAME is not set in environment variables")
| # Initialize Pinecone | |
| def init_pinecone(): | |
| """Initialize pinecone connection using new API""" | |
| global pc, index | |
| try: | |
| # Only initialize if not already initialized | |
| if pc is None: | |
| logger.info(f"Initializing Pinecone connection to index {PINECONE_INDEX_NAME}...") | |
| # Check if API key and index name are set | |
| if not PINECONE_API_KEY: | |
| logger.error("PINECONE_API_KEY is not set in environment variables") | |
| return None | |
| if not PINECONE_INDEX_NAME: | |
| logger.error("PINECONE_INDEX_NAME is not set in environment variables") | |
| return None | |
| # Initialize Pinecone client using the new API | |
| pc = Pinecone(api_key=PINECONE_API_KEY) | |
| try: | |
| # Check if index exists | |
| index_list = pc.list_indexes() | |
| if not hasattr(index_list, 'names') or PINECONE_INDEX_NAME not in index_list.names(): | |
| logger.error(f"Index {PINECONE_INDEX_NAME} does not exist in Pinecone") | |
| return None | |
| # Get existing index | |
| index = pc.Index(PINECONE_INDEX_NAME) | |
| logger.info(f"Pinecone connection established to index {PINECONE_INDEX_NAME}") | |
| except Exception as connection_error: | |
| logger.error(f"Error connecting to Pinecone index: {connection_error}") | |
| return None | |
| return index | |
| except ImportError as e: | |
| logger.error(f"Required package for Pinecone is missing: {e}") | |
| return None | |
| except Exception as e: | |
| logger.error(f"Unexpected error initializing Pinecone: {e}") | |
| return None | |
| # Get Pinecone index singleton | |
| def get_pinecone_index(): | |
| """Get Pinecone index""" | |
| global index | |
| if index is None: | |
| index = init_pinecone() | |
| return index | |
| # Check Pinecone connection | |
| def check_db_connection(): | |
| """Check Pinecone connection""" | |
| try: | |
| pinecone_index = get_pinecone_index() | |
| if pinecone_index is None: | |
| return False | |
| # Check index information to confirm connection is working | |
| stats = pinecone_index.describe_index_stats() | |
| # Get total vector count from the new result structure | |
| total_vectors = stats.get('total_vector_count', 0) | |
| if hasattr(stats, 'namespaces'): | |
| # If there are namespaces, calculate total vector count from namespaces | |
| total_vectors = sum(ns.get('vector_count', 0) for ns in stats.namespaces.values()) | |
| logger.info(f"Pinecone connection is working. Total vectors: {total_vectors}") | |
| return True | |
| except Exception as e: | |
| logger.error(f"Error in Pinecone connection: {e}") | |
| return False | |
| # Convert similarity score based on the metric | |
| def convert_score(score: float, metric: str) -> float: | |
| """ | |
| Convert similarity score to a 0-1 scale based on the metric used. | |
| For metrics like euclidean distance where lower is better, we invert the score. | |
| Args: | |
| score: The raw similarity score | |
| metric: The similarity metric used | |
| Returns: | |
| A normalized score between 0-1 where higher means more similar | |
| """ | |
| if metric.lower() in ["euclidean", "l2"]: | |
| # For distance metrics (lower is better), we inverse and normalize | |
| # Assuming max reasonable distance is 2.0 for normalized vectors | |
| return max(0, 1 - (score / 2.0)) | |
| else: | |
| # For cosine and dot product (higher is better), return as is | |
| return score | |
| # Filter results based on similarity threshold | |
| def filter_by_threshold(results, threshold: float, metric: str) -> List[Dict]: | |
| """ | |
| Filter query results based on similarity threshold. | |
| Args: | |
| results: The query results from Pinecone | |
| threshold: The similarity threshold (0-1) | |
| metric: The similarity metric used | |
| Returns: | |
| Filtered list of matches | |
| """ | |
| filtered_matches = [] | |
| if not hasattr(results, 'matches'): | |
| return filtered_matches | |
| for match in results.matches: | |
| # Get the score | |
| score = getattr(match, 'score', 0) | |
| # Convert score based on metric | |
| normalized_score = convert_score(score, metric) | |
| # Filter based on threshold | |
| if normalized_score >= threshold: | |
| # Add normalized score as an additional attribute | |
| match.normalized_score = normalized_score | |
| filtered_matches.append(match) | |
| return filtered_matches | |
| # Search vectors in Pinecone with advanced options | |
| async def search_vectors( | |
| query_vector, | |
| top_k: int = DEFAULT_TOP_K, | |
| limit_k: int = DEFAULT_LIMIT_K, | |
| similarity_metric: str = DEFAULT_SIMILARITY_METRIC, | |
| similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD, | |
| namespace: str = "Default", | |
| filter: Optional[Dict] = None | |
| ) -> Dict: | |
| """ | |
| Search for most similar vectors in Pinecone with advanced filtering options. | |
| Args: | |
| query_vector: The query vector | |
| top_k: Number of results to return (after threshold filtering) | |
| limit_k: Maximum number of results to retrieve from Pinecone | |
| similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean) | |
| similarity_threshold: Threshold for similarity (0-1) | |
| namespace: Namespace to search in | |
| filter: Filter query | |
| Returns: | |
| Search results with matches filtered by threshold | |
| """ | |
| try: | |
| # Validate parameters | |
| if similarity_metric not in ALLOWED_METRICS: | |
| logger.warning(f"Invalid similarity metric: {similarity_metric}. Using default: {DEFAULT_SIMILARITY_METRIC}") | |
| similarity_metric = DEFAULT_SIMILARITY_METRIC | |
| if limit_k < top_k: | |
| logger.warning(f"limit_k ({limit_k}) must be greater than or equal to top_k ({top_k}). Setting limit_k to {top_k}") | |
| limit_k = top_k | |
| # Perform search directly without cache | |
| pinecone_index = get_pinecone_index() | |
| if pinecone_index is None: | |
| logger.error("Failed to get Pinecone index for search") | |
| return None | |
| # Query Pinecone with the provided metric and higher limit_k to allow for threshold filtering | |
| results = pinecone_index.query( | |
| vector=query_vector, | |
| top_k=limit_k, # Retrieve more results than needed to allow for threshold filtering | |
| namespace=namespace, | |
| filter=filter, | |
| include_metadata=True, | |
| include_values=False, # No need to return vector values to save bandwidth | |
| metric=similarity_metric # Specify similarity metric | |
| ) | |
| # Filter results by threshold | |
| filtered_matches = filter_by_threshold(results, similarity_threshold, similarity_metric) | |
| # Limit to top_k after filtering | |
| filtered_matches = filtered_matches[:top_k] | |
| # Create a new results object with filtered matches | |
| results.matches = filtered_matches | |
| # Log search result metrics | |
| match_count = len(filtered_matches) | |
| logger.info(f"Pinecone search returned {match_count} matches after threshold filtering (metric: {similarity_metric}, threshold: {similarity_threshold}, namespace: {namespace})") | |
| return results | |
| except Exception as e: | |
| logger.error(f"Error searching vectors: {e}") | |
| return None | |
| # Upsert vectors to Pinecone | |
| async def upsert_vectors(vectors, namespace="Default"): | |
| """Upsert vectors to Pinecone index""" | |
| try: | |
| pinecone_index = get_pinecone_index() | |
| if pinecone_index is None: | |
| logger.error("Failed to get Pinecone index for upsert") | |
| return None | |
| response = pinecone_index.upsert( | |
| vectors=vectors, | |
| namespace=namespace | |
| ) | |
| # Log upsert metrics | |
| upserted_count = response.get('upserted_count', 0) | |
| logger.info(f"Upserted {upserted_count} vectors to Pinecone") | |
| return response | |
| except Exception as e: | |
| logger.error(f"Error upserting vectors: {e}") | |
| return None | |
| # Delete vectors from Pinecone | |
| async def delete_vectors(ids, namespace="Default"): | |
| """Delete vectors from Pinecone index""" | |
| try: | |
| pinecone_index = get_pinecone_index() | |
| if pinecone_index is None: | |
| logger.error("Failed to get Pinecone index for delete") | |
| return False | |
| response = pinecone_index.delete( | |
| ids=ids, | |
| namespace=namespace | |
| ) | |
| logger.info(f"Deleted vectors with IDs {ids} from Pinecone") | |
| return True | |
| except Exception as e: | |
| logger.error(f"Error deleting vectors: {e}") | |
| return False | |
| # Fetch vector metadata from Pinecone | |
| async def fetch_metadata(ids, namespace="Default"): | |
| """Fetch metadata for specific vector IDs""" | |
| try: | |
| pinecone_index = get_pinecone_index() | |
| if pinecone_index is None: | |
| logger.error("Failed to get Pinecone index for fetch") | |
| return None | |
| response = pinecone_index.fetch( | |
| ids=ids, | |
| namespace=namespace | |
| ) | |
| return response | |
| except Exception as e: | |
| logger.error(f"Error fetching vector metadata: {e}") | |
| return None | |
| # Create a custom retriever class for Langchain integration | |
| class ThresholdRetriever(BaseRetriever): | |
| """ | |
| Custom retriever that supports threshold-based filtering and multiple similarity metrics. | |
| This integrates with the Langchain ecosystem while using our advanced retrieval logic. | |
| """ | |
| vectorstore: Any = Field(description="Vector store to use for retrieval") | |
| embeddings: Any = Field(description="Embeddings model to use for retrieval") | |
| search_kwargs: Dict[str, Any] = Field(default_factory=dict, description="Search kwargs for the vectorstore") | |
| top_k: int = Field(default=DEFAULT_TOP_K, description="Number of results to return after filtering") | |
| limit_k: int = Field(default=DEFAULT_LIMIT_K, description="Maximum number of results to retrieve from Pinecone") | |
| similarity_metric: str = Field(default=DEFAULT_SIMILARITY_METRIC, description="Similarity metric to use") | |
| similarity_threshold: float = Field(default=DEFAULT_SIMILARITY_THRESHOLD, description="Threshold for similarity") | |
| namespace: str = "Default" | |
| class Config: | |
| """Configuration for this pydantic object.""" | |
| arbitrary_types_allowed = True | |
| async def search_vectors_sync( | |
| self, query_vector, | |
| top_k: int = DEFAULT_TOP_K, | |
| limit_k: int = DEFAULT_LIMIT_K, | |
| similarity_metric: str = DEFAULT_SIMILARITY_METRIC, | |
| similarity_threshold: float = DEFAULT_SIMILARITY_THRESHOLD, | |
| namespace: str = "Default", | |
| filter: Optional[Dict] = None | |
| ) -> Dict: | |
| """Synchronous wrapper for search_vectors""" | |
| import asyncio | |
| try: | |
| # Get current event loop or create a new one | |
| try: | |
| loop = asyncio.get_event_loop() | |
| except RuntimeError: | |
| loop = asyncio.new_event_loop() | |
| asyncio.set_event_loop(loop) | |
| # Use event loop to run async function | |
| if loop.is_running(): | |
| # If we're in an event loop, use asyncio.create_task | |
| task = asyncio.create_task(search_vectors( | |
| query_vector=query_vector, | |
| top_k=top_k, | |
| limit_k=limit_k, | |
| similarity_metric=similarity_metric, | |
| similarity_threshold=similarity_threshold, | |
| namespace=namespace, | |
| filter=filter | |
| )) | |
| return await task | |
| else: | |
| # If not in an event loop, just await directly | |
| return await search_vectors( | |
| query_vector=query_vector, | |
| top_k=top_k, | |
| limit_k=limit_k, | |
| similarity_metric=similarity_metric, | |
| similarity_threshold=similarity_threshold, | |
| namespace=namespace, | |
| filter=filter | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error in search_vectors_sync: {e}") | |
| return None | |
| def _get_relevant_documents( | |
| self, query: str, *, run_manager: Callbacks = None | |
| ) -> List[Document]: | |
| """ | |
| Get documents relevant to the query using threshold-based retrieval. | |
| Args: | |
| query: The query string | |
| run_manager: The callbacks manager | |
| Returns: | |
| List of relevant documents | |
| """ | |
| # Generate embedding for query using the embeddings model | |
| try: | |
| # Use the embeddings model we stored in the class | |
| embedding = self.embeddings.embed_query(query) | |
| except Exception as e: | |
| logger.error(f"Error generating embedding: {e}") | |
| # Fallback to creating a new embedding model if needed | |
| embedding_model = GoogleGenerativeAIEmbeddings(model="models/embedding-001") | |
| embedding = embedding_model.embed_query(query) | |
| # Perform search with advanced options - avoid asyncio.run() | |
| import asyncio | |
| # Get or create event loop | |
| try: | |
| loop = asyncio.get_event_loop() | |
| except RuntimeError: | |
| loop = asyncio.new_event_loop() | |
| asyncio.set_event_loop(loop) | |
| # Run asynchronous search in a safe way | |
| if loop.is_running(): | |
| # We're inside an existing event loop (like in FastAPI) | |
| # Use a different approach - convert it to a synchronous call | |
| from concurrent.futures import ThreadPoolExecutor | |
| import functools | |
| # Define a wrapper function to run in a thread | |
| def run_async_in_thread(): | |
| # Create a new event loop for this thread | |
| thread_loop = asyncio.new_event_loop() | |
| asyncio.set_event_loop(thread_loop) | |
| # Run the coroutine and return the result | |
| return thread_loop.run_until_complete(search_vectors( | |
| query_vector=embedding, | |
| top_k=self.top_k, | |
| limit_k=self.limit_k, | |
| similarity_metric=self.similarity_metric, | |
| similarity_threshold=self.similarity_threshold, | |
| namespace=self.namespace, | |
| # filter=self.search_kwargs.get("filter", None) | |
| )) | |
| # Run the async function in a thread | |
| with ThreadPoolExecutor() as executor: | |
| search_result = executor.submit(run_async_in_thread).result() | |
| else: | |
| # No event loop running, we can use run_until_complete | |
| search_result = loop.run_until_complete(search_vectors( | |
| query_vector=embedding, | |
| top_k=self.top_k, | |
| limit_k=self.limit_k, | |
| similarity_metric=self.similarity_metric, | |
| similarity_threshold=self.similarity_threshold, | |
| namespace=self.namespace, | |
| # filter=self.search_kwargs.get("filter", None) | |
| )) | |
| # Convert to documents | |
| documents = [] | |
| if search_result and hasattr(search_result, 'matches'): | |
| for match in search_result.matches: | |
| # Extract metadata | |
| metadata = {} | |
| if hasattr(match, 'metadata'): | |
| metadata = match.metadata | |
| # Add score to metadata | |
| score = getattr(match, 'score', 0) | |
| normalized_score = getattr(match, 'normalized_score', score) | |
| metadata['score'] = score | |
| metadata['normalized_score'] = normalized_score | |
| # Extract text | |
| text = metadata.get('text', '') | |
| if 'text' in metadata: | |
| del metadata['text'] # Remove from metadata since it's the content | |
| # Create Document | |
| doc = Document( | |
| page_content=text, | |
| metadata=metadata | |
| ) | |
| documents.append(doc) | |
| return documents | |
| # Get the retrieval chain with Pinecone vector store | |
| def get_chain( | |
| index_name=PINECONE_INDEX_NAME, | |
| namespace="Default", | |
| top_k=DEFAULT_TOP_K, | |
| limit_k=DEFAULT_LIMIT_K, | |
| similarity_metric=DEFAULT_SIMILARITY_METRIC, | |
| similarity_threshold=DEFAULT_SIMILARITY_THRESHOLD | |
| ): | |
| """ | |
| Get the retrieval chain with Pinecone vector store using threshold-based retrieval. | |
| Args: | |
| index_name: Pinecone index name | |
| namespace: Pinecone namespace | |
| top_k: Number of results to return after filtering | |
| limit_k: Maximum number of results to retrieve from Pinecone | |
| similarity_metric: Similarity metric to use (cosine, dotproduct, euclidean) | |
| similarity_threshold: Threshold for similarity (0-1) | |
| Returns: | |
| ThresholdRetriever instance | |
| """ | |
| global _retriever_instance | |
| try: | |
| # If already initialized with same parameters, return cached instance | |
| if _retriever_instance is not None: | |
| return _retriever_instance | |
| start_time = time.time() | |
| logger.info("Initializing new retriever chain with threshold-based filtering") | |
| # Initialize embeddings model | |
| embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001") | |
| # Get index | |
| pinecone_index = get_pinecone_index() | |
| if not pinecone_index: | |
| logger.error("Failed to get Pinecone index for retriever chain") | |
| return None | |
| # Get statistics for logging | |
| try: | |
| stats = pinecone_index.describe_index_stats() | |
| total_vectors = stats.get('total_vector_count', 0) | |
| logger.info(f"Pinecone index stats - Total vectors: {total_vectors}") | |
| except Exception as e: | |
| logger.error(f"Error getting index stats: {e}") | |
| # Use Pinecone from langchain_community.vectorstores | |
| from langchain_community.vectorstores import Pinecone as LangchainPinecone | |
| logger.info(f"Creating Pinecone vectorstore with index: {index_name}, namespace: {namespace}") | |
| vectorstore = LangchainPinecone.from_existing_index( | |
| embedding=embeddings, | |
| index_name=index_name, | |
| namespace=namespace, | |
| text_key="text" | |
| ) | |
| # Create threshold-based retriever | |
| logger.info(f"Creating ThresholdRetriever with top_k={top_k}, limit_k={limit_k}, " + | |
| f"metric={similarity_metric}, threshold={similarity_threshold}") | |
| # Create ThresholdRetriever with both vectorstore and embeddings | |
| _retriever_instance = ThresholdRetriever( | |
| vectorstore=vectorstore, | |
| embeddings=embeddings, # Pass embeddings separately | |
| top_k=top_k, | |
| limit_k=limit_k, | |
| similarity_metric=similarity_metric, | |
| similarity_threshold=similarity_threshold | |
| ) | |
| logger.info(f"Pinecone retriever initialized in {time.time() - start_time:.2f} seconds") | |
| return _retriever_instance | |
| except Exception as e: | |
| logger.error(f"Error creating retrieval chain: {e}") | |
| return None |