import os
from operator import itemgetter
from typing import List

from dotenv import load_dotenv

from langchain_groq import ChatGroq
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import RunnableLambda, RunnableParallel
from langchain_core.runnables.history import RunnableWithMessageHistory

from utils.query_expansion import expand_query_simple

def create_multi_query_retriever(base_retriever, llm=None, strategy: str = "balanced"):
    """
    Wraps a base retriever with query expansion capabilities.

    This function creates a retriever that:
    1. Expands the input query into multiple variations
    2. Retrieves documents for each query variation
    3. Combines and deduplicates results
    4. Returns the merged document set

    Args:
        base_retriever: The underlying retriever (e.g., EnsembleRetriever, FAISS).
        llm: Optional LLM for advanced query expansion.
        strategy: "quick" (2 queries), "balanced" (3-4), or "comprehensive" (5-6).

    Returns:
        A function that performs multi-query retrieval.

    Example:
        >>> multi_retriever = create_multi_query_retriever(ensemble_retriever, strategy="balanced")
        >>> docs = multi_retriever("How do I debug Python code?")
    """

    def multi_query_retrieve(query: str) -> List[Document]:
        """
        Retrieves documents using expanded query variations.

        Args:
            query: Original user query.

        Returns:
            List of unique documents from all query variations.
        """
        query_variations = expand_query_simple(query, strategy=strategy, llm=llm)

        print("\n🔍 Query Expansion Active:")
        print(f"   Original: {query}")
        print(f"   Generated {len(query_variations)} variations:")
        for i, var in enumerate(query_variations, 1):
            print(f"   {i}. {var}")

        all_docs = []
        seen_content = set()

        for i, query_var in enumerate(query_variations):
            try:
                docs = base_retriever.invoke(query_var)
                print(f"   ✓ Query {i + 1}: Retrieved {len(docs)} docs")

                # Deduplicate on page content so a chunk retrieved by several
                # query variations is kept only once.
                for doc in docs:
                    content_hash = hash(doc.page_content)
                    if content_hash not in seen_content:
                        seen_content.add(content_hash)
                        all_docs.append(doc)

            except Exception as e:
                print(f"   ✗ Query {i + 1}: Error - {str(e)[:50]}")
                continue

        print(f"   📊 Total unique documents: {len(all_docs)}")

        # Cap the merged set so the downstream prompt stays within the
        # model's context window.
        return all_docs[:20]

    return multi_query_retrieve
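
# Usage sketch, assuming a FAISS vector store built elsewhere in the app
# (`vectorstore` and the query string below are illustrative, not part of
# this module):
#
#   faiss_retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
#   multi_retriever = create_multi_query_retriever(faiss_retriever, strategy="quick")
#   docs = multi_retriever("What are the deployment steps?")  # <= 20 unique docs
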
def create_rag_chain(retriever, get_session_history_func, enable_query_expansion=True, expansion_strategy="balanced"):
    """
    Creates an advanced Retrieval-Augmented Generation (RAG) chain with hybrid search,
    query expansion, query rewriting, answer refinement, and conversational memory.

    Args:
        retriever: A configured LangChain retriever object.
        get_session_history_func: A function that returns the chat history for a session id.
        enable_query_expansion: Whether to enable multi-query expansion (default: True).
        expansion_strategy: "quick", "balanced", or "comprehensive" (default: "balanced").

    Returns:
        A LangChain runnable object representing the RAG chain with memory.

    Raises:
        ValueError: If the GROQ_API_KEY is missing.
    """
    load_dotenv()

    api_key = os.getenv("GROQ_API_KEY")

    if not api_key or api_key == "your_groq_api_key_here":
        error_msg = "GROQ_API_KEY not found or not configured properly.\n"

        if os.getenv("SPACE_ID") or os.getenv("SPACES_ZERO_GPU"):
            error_msg += (
                "For Hugging Face Spaces: Set GROQ_API_KEY in your Space's Settings > Repository Secrets.\n"
                "Go to your Space settings and add GROQ_API_KEY as a secret variable."
            )
        else:
            error_msg += (
                "For local development: Set your GROQ API key in the .env file.\n"
                "Copy .env.example to .env and add your actual API key.\n"
                "Get your API key from: https://console.groq.com/keys"
            )

        raise ValueError(error_msg)

    llm = ChatGroq(model_name="openai/gpt-oss-20b", api_key=api_key, temperature=0.3)

    if enable_query_expansion:
        print(f"\n✨ Query Expansion ENABLED (strategy: {expansion_strategy})")
        enhanced_retriever = create_multi_query_retriever(
            base_retriever=retriever,
            llm=llm,
            strategy=expansion_strategy,
        )
    else:
        print("\n⚠️ Query Expansion DISABLED")
        enhanced_retriever = retriever

    print("\nSetting up query rewriting chain...")
    # The chat history and follow-up question are injected through the
    # MessagesPlaceholder and human message below, so the system prompt
    # carries only the instructions.
    rewrite_template = """You are an expert at optimizing search queries for document retrieval.

Your task: Transform the user's question into an optimized search query that will retrieve the most relevant information from the document database.

Guidelines:
1. Incorporate context from the chat history to make the query standalone
2. Expand abbreviations and clarify ambiguous terms
3. Include key technical terms and synonyms that might appear in documents
4. For complex questions, preserve all important aspects
5. Keep queries specific and focused

IMPORTANT: Output ONLY the optimized search query, nothing else."""
    rewrite_prompt = ChatPromptTemplate.from_messages([
        ("system", rewrite_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "Based on our conversation, reformulate this question to be a standalone query: {question}"),
    ])
    query_rewriter = rewrite_prompt | llm | StrOutputParser()
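
    # Illustrative sketch of the rewriter's behavior (the transcript below is
    # invented, not real model output; HumanMessage is from
    # langchain_core.messages):
    #   query_rewriter.invoke({
    #       "question": "How do I fix it?",
    #       "chat_history": [HumanMessage("My FAISS search returns no results")],
    #   })
    #   -> "troubleshooting empty results from FAISS vector similarity search"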
    print("\nSetting up main RAG chain...")
    rag_template = """You are Cognichat, an expert AI assistant developed by Ritesh and Alish.
Your primary function is to provide accurate, relevant answers based ONLY on the information in the provided context.

IMPORTANT INSTRUCTIONS:
1. ONLY use information from the Context below - do not use external knowledge
2. If the answer is not in the Context, clearly state: "I don't have enough information in the document to answer that question."
3. When answering from the Context, be specific and cite relevant details
4. For complex or lengthy documents, synthesize information from multiple parts of the Context if needed
5. If the Context has partial information, acknowledge what you know and what's missing
6. Provide clear, well-structured answers with examples from the Context when available
7. If the question requires information not in the Context, explain what information would be needed

Context (from the uploaded documents):
{context}

---
Based on the context above, provide a clear and accurate answer to the user's question."""
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", rag_template),
        MessagesPlaceholder(variable_name="chat_history"),
        ("human", "{question}"),
    ])

    def format_docs(docs: List[Document]) -> str:
        """Joins retrieved documents into one plain-text context block."""
        return "\n\n".join(doc.page_content for doc in docs)

    # Rewrite the question, retrieve with the (optionally expanded) retriever,
    # and format the documents into the prompt's {context} slot.
    setup_and_retrieval = RunnableParallel({
        "context": query_rewriter | enhanced_retriever | format_docs,
        "question": itemgetter("question"),
        "chat_history": itemgetter("chat_history"),
    })

    conversational_rag_chain = (
        setup_and_retrieval
        | rag_prompt
        | llm
        | StrOutputParser()
    )

    chain_with_memory = RunnableWithMessageHistory(
        conversational_rag_chain,
        get_session_history_func,
        input_messages_key="question",
        history_messages_key="chat_history",
    )
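
    # Usage sketch for the memory wrapper: the session id travels in the run
    # config (the "user-123" id below is illustrative):
    #   chain_with_memory.invoke(
    #       {"question": "Summarize the document."},
    #       config={"configurable": {"session_id": "user-123"}},
    #   )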
    print("\nSetting up answer refinement chain...")
    refine_template = """You are an expert editor specializing in making technical and complex information clear and accessible.

Your task: Refine the given answer to improve clarity, structure, and readability while maintaining ALL original information.

Guidelines:
1. Improve sentence structure and flow
2. Use formatting (bullet points, numbered lists, bold text) when it enhances understanding
3. Break long paragraphs into digestible sections
4. Ensure technical terms are used correctly
5. Add logical transitions between ideas
6. NEVER add new information not in the original answer
7. NEVER remove important details
8. If the answer states lack of information, keep that explicit

Original Answer:
{answer}

Refined Answer:"""
    refine_prompt = ChatPromptTemplate.from_template(refine_template)
    refinement_chain = refine_prompt | llm | StrOutputParser()
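
    # Illustrative call (the input string and result are invented):
    #   refinement_chain.invoke({"answer": "python lists is mutable, tuples is not"})
    #   -> a cleaned-up, well-formatted version of the same content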
    def run_with_memory(input_dict: dict, config=None) -> dict:
        """Invokes the conversational chain, then hands its answer to the refiner.

        The session id may arrive either embedded in the input dict under the
        "config" key (the convention this module's callers use) or through the
        standard LCEL run config, which RunnableLambda passes in as `config`.
        """
        run_config = input_dict.get("config") or config
        return {"answer": chain_with_memory.invoke(input_dict, config=run_config)}

    final_chain = RunnableLambda(run_with_memory) | refinement_chain

    print("\nFinalizing the complete chain with memory...")
    return final_chain
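
# Minimal end-to-end sketch (hypothetical wiring: `retriever` is assumed to be
# built elsewhere in the app, and GROQ_API_KEY must be set in .env):
#
#   from langchain_core.chat_history import InMemoryChatMessageHistory
#
#   _histories = {}
#
#   def get_session_history(session_id: str):
#       return _histories.setdefault(session_id, InMemoryChatMessageHistory())
#
#   chain = create_rag_chain(retriever, get_session_history)
#   answer = chain.invoke({
#       "question": "What does the document say about pricing?",
#       "config": {"configurable": {"session_id": "demo"}},
#   })
#   print(answer)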