import os
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from pydantic import Field

from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder

from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_compressors import FlashrankRerank
from langchain_community.retrievers import PineconeHybridSearchRetriever
from langchain_core.callbacks.manager import CallbackManagerForRetrieverRun
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings

load_dotenv()

GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")
embed_model = os.getenv("EMBEDDING_MODEL")
llm_model = os.getenv("LLM_MODEL")
rerank_model = os.getenv("RERANK_MODEL")


# Dense vector embeddings
class FixedDimensionGoogleGenerativeAIEmbeddings(GoogleGenerativeAIEmbeddings):
    """A wrapper that fixes the output_dimensionality for embedding methods."""

    # Pydantic-compatible field storing the output dimension, so it is
    # visible to external validation checks.
    output_dimensionality: Optional[int] = Field(
        default=None, description="The fixed output dimension for embeddings."
    )

    def embed_documents(self, texts: List[str], **kwargs) -> List[List[float]]:
        if self.output_dimensionality is not None:
            kwargs["output_dimensionality"] = self.output_dimensionality
        return super().embed_documents(texts, **kwargs)

    def embed_query(self, text: str, **kwargs) -> List[float]:
        if self.output_dimensionality is not None:
            kwargs["output_dimensionality"] = self.output_dimensionality
        return super().embed_query(text, **kwargs)


embeddings = FixedDimensionGoogleGenerativeAIEmbeddings(
    google_api_key=GOOGLE_API_KEY,
    model=embed_model,
    output_dimensionality=768,
)
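# Quick sanity-check sketch (left commented out so importing this module does
# not call the Gemini API): with output_dimensionality=768, both methods
# should return 768-dimensional vectors regardless of the model's native size.
#
#   vec = embeddings.embed_query("hello world")
#   assert len(vec) == 768
#   chunks = embeddings.embed_documents(["first chunk", "second chunk"])
#   assert all(len(v) == 768 for v in chunks)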
# Sparse (BM25) encoder with the library's default parameters
bm25_encoder = BM25Encoder().default()

index_name = "personal-assistant"
pc = Pinecone(api_key=PINECONE_API_KEY)
index = pc.Index(index_name)


class CustomHybridSearchRetriever(PineconeHybridSearchRetriever):
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
    ) -> List[Document]:
        """Get documents relevant to the query using hybrid search, with a
        fallback to dense-only search."""
        try:
            # Try hybrid (dense + sparse) search first
            return super()._get_relevant_documents(query, run_manager=run_manager)
        except Exception as e:
            # If sparse encoding fails, fall back to dense-only search
            if "Sparse vector must contain at least one value" in str(e):
                print("Falling back to dense-only search for query:", query)
                # Generate the dense embedding for the query
                embedding = self.embeddings.embed_query(query)
                # Query the index with only the dense vector
                results = self.index.query(
                    vector=embedding,
                    top_k=self.top_k,
                    include_metadata=True,
                    namespace=self.namespace,
                )
                # Convert Pinecone matches to LangChain documents
                return self._process_pinecone_results(results)
            # Re-raise any other error unchanged
            raise

    def _process_pinecone_results(self, results) -> List[Document]:
        """Process Pinecone results into Document objects."""
        docs = []
        for result in results.matches:
            metadata = result.metadata or {}
            # The indexed text is stored under the "text" metadata key
            docs.append(
                Document(page_content=metadata.pop("text", ""), metadata=metadata)
            )
        return docs


namespace = "portfolio"
base_retriever = CustomHybridSearchRetriever(
    embeddings=embeddings,
    sparse_encoder=bm25_encoder,
    index=index,
    top_k=30,
    namespace=namespace,
)

# Rerank the top 30 hybrid hits down to the 5 most relevant documents
reranker_compressor = FlashrankRerank(model=rerank_model, top_n=5)
retriever = ContextualCompressionRetriever(
    base_compressor=reranker_compressor,
    base_retriever=base_retriever,
)

llm = ChatGoogleGenerativeAI(
    model=llm_model,
    google_api_key=GOOGLE_API_KEY,
    temperature=0.5,
)

# In-memory session store; histories are lost when the process restarts
store: Dict[str, BaseChatMessageHistory] = {}


def get_full_session_history(session_id: str) -> BaseChatMessageHistory:
    if session_id not in store:
        print(f"INFO: Creating new chat history for session: {session_id}")
        store[session_id] = ChatMessageHistory()
    return store[session_id]


MAX_HISTORY_TURNS = 3
MAX_HISTORY_MESSAGES = MAX_HISTORY_TURNS * 2  # each turn is a human/AI message pair


def limit_history_for_rag_chain(input_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Trim chat_history to the last MAX_HISTORY_MESSAGES messages."""
    modified_input = input_dict.copy()
    if "chat_history" in modified_input:
        history = modified_input["chat_history"]
        if isinstance(history, list) and all(isinstance(m, BaseMessage) for m in history):
            modified_input["chat_history"] = history[-MAX_HISTORY_MESSAGES:]
        else:
            print(
                "WARN: 'chat_history' in input_dict is not a list of "
                "BaseMessages. Passing as is."
            )
    return modified_input
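# Sketch of the trimming behavior (hypothetical messages; with
# MAX_HISTORY_TURNS = 3 as above, at most 6 messages survive):
#
#   from langchain_core.messages import AIMessage, HumanMessage
#   history = [HumanMessage("q1"), AIMessage("a1"), HumanMessage("q2"),
#              AIMessage("a2"), HumanMessage("q3"), AIMessage("a3"),
#              HumanMessage("q4"), AIMessage("a4")]
#   trimmed = limit_history_for_rag_chain({"chat_history": history})
#   # trimmed["chat_history"] holds the last 6 messages (q2/a2 ... q4/a4)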
Passing as is.") return modified_input retriever_prompt_template = os.getenv("RETRIEVER_PROMPT").format(max_turns=MAX_HISTORY_TURNS) contextualize_q_prompt = ChatPromptTemplate.from_messages( [ ("system", retriever_prompt_template), MessagesPlaceholder(variable_name="chat_history"), # This will receive the limited history ("human", "{input}"), ] ) # The history-aware retriever now uses the new, reranking-enabled retriever history_aware_retriever = create_history_aware_retriever( llm, retriever, contextualize_q_prompt ) SYSTEM_PROMPT = os.getenv("SYSTEM_PROMPT").format(max_turns=MAX_HISTORY_TURNS) qa_prompt = ChatPromptTemplate.from_messages( [ ("system", SYSTEM_PROMPT), # MessagesPlaceholder(variable_name="context"), # Make sure your prompt includes `context` ("human", "{input}"), ] ) qa_chain = create_stuff_documents_chain(llm, qa_prompt) rag_chain = create_retrieval_chain(history_aware_retriever, qa_chain) conversational_rag_chain = RunnableWithMessageHistory( runnable=RunnableLambda(limit_history_for_rag_chain) | rag_chain, get_session_history=get_full_session_history, input_messages_key="input", history_messages_key="chat_history", output_messages_key="answer", ) def chat(query: str, session_id: str): response = conversational_rag_chain.invoke( {"input": query}, config={"configurable": {"session_id": session_id}} ) return response.get("answer", "No answer found.")