import json
import os
import glob
from typing import List, Optional
from dotenv import load_dotenv
import logging

# Load environment variables from .env file
load_dotenv()

from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_google_genai import GoogleGenerativeAI

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LineListOutputParser(BaseOutputParser[List[str]]):
    """Custom output parser for a list of lines with better error handling."""

    def parse(self, text: str) -> List[str]:
        """Parse the LLM output into a list of queries."""
        try:
            lines = text.strip().split("\n")
            # Remove empty lines and clean up
            cleaned_lines = []
            for line in lines:
                cleaned = line.strip()
                if cleaned and not cleaned.startswith("#") and len(cleaned) > 5:
                    # Remove numbering if present (e.g., "1. ", "- ", etc.)
                    if cleaned[0].isdigit() and ". " in cleaned:
                        cleaned = cleaned.split(". ", 1)[1]
                    elif cleaned.startswith("- "):
                        cleaned = cleaned[2:]
                    cleaned_lines.append(cleaned)

            # Ensure we have at least one query
            if not cleaned_lines:
                cleaned_lines = [text.strip()]

            return cleaned_lines
        except Exception as e:
            logger.warning(f"Error parsing output: {e}. Returning original text.")
            return [text.strip()] if text.strip() else [""]


def create_custom_multi_query_retriever(
    base_retriever,
    llm,
    num_queries: int = 5,
    include_original: bool = True
):
    """Create a MultiQueryRetriever for the given base retriever and LLM.

    A custom prompt and output parser are kept below (commented out) for
    reference; the default LangChain prompt is currently used, and
    `num_queries` is only consumed by that commented-out prompt.
    """
    # Custom prompt template for better query generation
    # query_prompt = PromptTemplate(
    #     input_variables=["question"],
    #     template="""You are an AI assistant specialized in generating diverse search queries.
    # Your task is to generate {num_queries} different versions of the given user question to retrieve relevant documents from a knowledge base.
    # Guidelines:
    # - Create variations that capture different aspects and perspectives of the question
    # - Use synonyms and alternative phrasings
    # - Consider different levels of specificity (broader and narrower)
    # - Focus on the core intent while varying the expression
    # - Each query should be a complete, well-formed question or statement
    # Original question: {question}
    # Generate {num_queries} alternative queries (one per line):""".replace("{num_queries}", str(num_queries))
    # )

    # Create the MultiQueryRetriever with custom components
    multi_query_retriever = MultiQueryRetriever.from_llm(
        retriever=base_retriever,
        llm=llm,
        include_original=include_original
    )

    # # Override the output parser
    # multi_query_retriever.output_parser = LineListOutputParser()

    return multi_query_retriever


def validate_environment():
    """Validate that required environment variables are set."""
    required_vars = ["GOOGLE_API_KEY"]
    missing_vars = [var for var in required_vars if not os.getenv(var)]

    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")

    logger.info("✅ Environment variables validated.")


def load_documents_from_json(chunks_directory: str) -> List[Document]:
    """Load documents from JSON files with better error handling."""
    json_files = glob.glob(os.path.join(chunks_directory, "*.json"))

    if not json_files:
        raise ValueError(f"No JSON files found in directory: {chunks_directory}")

    logger.info(f"Found {len(json_files)} JSON files: {[os.path.basename(f) for f in json_files]}")

    documents = []
    total_processed = 0

    for json_file in json_files:
        try:
            logger.info(f"Processing: {os.path.basename(json_file)}")
            with open(json_file, "r", encoding="utf-8") as f:
                chunks_data = json.load(f)

            file_doc_count = 0
            for element in chunks_data:
                try:
                    text = element.get("text", "").strip()
                    if not text:  # Skip empty text
                        continue

                    metadata = {
                        "source": element.get("filename", "unknown"),
                        "filetype": element.get("filetype", "unknown"),
                        "element_id": element.get("element_id", "unknown"),
                        "json_source": os.path.basename(json_file)
                    }

                    # Add table-specific metadata if present
                    if element.get("type") == "TableElement" and element.get("table_text_as_html"):
                        metadata["table_text_as_html"] = element["table_text_as_html"]
                        # metadata["element_type"] = "table"
                    else:
                        metadata["element_type"] = element.get("type", "text")

                    doc = Document(page_content=text, metadata=metadata)
                    documents.append(doc)
                    file_doc_count += 1
                except Exception as e:
                    logger.warning(f"Error processing element in {json_file}: {e}")
                    continue

            logger.info(f" → Loaded {file_doc_count} documents from {os.path.basename(json_file)}")
            total_processed += file_doc_count

        except Exception as e:
            logger.error(f"Error processing file {json_file}: {e}")
            continue

    if not documents:
        raise ValueError("No valid documents were loaded from any JSON files.")

    logger.info(f"✅ Total loaded: {len(documents)} documents from {len(json_files)} JSON files.")
    return documents


def prepare_environment_and_retriever(
    chunks_directory: str = "./data/",
    model_name: str = "intfloat/multilingual-e5-base",
    collection_name: str = "Guide_2023_e5_multilingual",
    persist_directory: str = "chroma_db_multilingual",
    k_vector: int = 6,
    k_sparse: int = 2,
    ensemble_weights: Optional[List[float]] = None,
    llm_model_name: str = "gemini-2.0-flash-exp",
    num_query_variations: int = 5,
    include_original_query: bool = True,
    temperature: float = 0.1
):
    """
    Prepare the complete retrieval environment with MultiQueryRetriever.
    Args:
        chunks_directory: Directory containing JSON files with document chunks
        model_name: HuggingFace embedding model name
        collection_name: Chroma collection name
        persist_directory: Directory to persist Chroma database
        k_vector: Number of documents to retrieve from vector search
        k_sparse: Number of documents to retrieve from BM25 search
        ensemble_weights: Weights for ensemble retriever [vector, sparse] (default: [0.5, 0.5])
        llm_model_name: Google Gemini model name for query expansion
        num_query_variations: Number of query variations to generate
        include_original_query: Whether to include original query in search
        temperature: LLM temperature for query generation

    Returns:
        MultiQueryRetriever: Configured retriever ready for use (falls back to
        the plain ensemble retriever if query expansion cannot be set up)
    """
    if ensemble_weights is None:
        ensemble_weights = [0.5, 0.5]

    # Validate environment
    validate_environment()

    # Load documents
    documents = load_documents_from_json(chunks_directory)

    # Create embedding function
    logger.info(f"Creating embeddings with model: {model_name}")
    embedding_function = HuggingFaceEmbeddings(
        model_name=model_name,
    )

    # Create or load vector store
    logger.info("Creating/loading vector store...")
    try:
        # Try to load existing vectorstore first
        if os.path.exists(persist_directory):
            vectorstore = Chroma(
                collection_name=collection_name,
                embedding_function=embedding_function,
                persist_directory=persist_directory
            )
            logger.info("✅ Loaded existing vector store.")
        else:
            # Create new vectorstore
            vectorstore = Chroma.from_documents(
                documents=documents,
                embedding=embedding_function,
                collection_name=collection_name,
                persist_directory=persist_directory
            )
            logger.info("✅ Created new vector store with multilingual embeddings.")
    except Exception as e:
        logger.warning(f"Error with persistent storage: {e}. Creating in-memory store.")
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=embedding_function,
            collection_name=collection_name
        )

    # Create base retrievers
    logger.info("Setting up retrievers...")

    # Vector retriever
    vector_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": k_vector}
    )

    # BM25 (sparse) retriever
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = k_sparse

    # Ensemble retriever (combining vector + sparse search)
    ensemble_retriever = EnsembleRetriever(
        retrievers=[vector_retriever, bm25_retriever],
        weights=ensemble_weights
    )
    logger.info(f"✅ Ensemble retriever created with weights: {ensemble_weights}")

    # Language model for multi-query expansion
    logger.info(f"Initializing LLM: {llm_model_name}")
    try:
        llm = GoogleGenerativeAI(
            model=llm_model_name,
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=temperature,
            max_output_tokens=1000  # Reasonable limit for query generation
        )
        # Smoke-test the LLM with a simple call
        llm.invoke("Generate a simple test query about artificial intelligence.")
        logger.info("✅ LLM connection verified.")
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        raise

    # Create MultiQueryRetriever with custom configuration
    logger.info("Creating MultiQueryRetriever...")
    try:
        multi_query_retriever = create_custom_multi_query_retriever(
            base_retriever=ensemble_retriever,
            llm=llm,
            num_queries=num_query_variations,
            include_original=include_original_query
        )

        logger.info("✅ MultiQueryRetriever ready:")
        logger.info(f"   - Vector search: top-{k_vector}")
        logger.info(f"   - Sparse search: top-{k_sparse}")
        logger.info(f"   - Ensemble weights: {ensemble_weights}")
        logger.info(f"   - Query variations: {num_query_variations}")
        logger.info(f"   - Include original: {include_original_query}")

        return multi_query_retriever

    except Exception as e:
logger.error(f"Error creating MultiQueryRetriever: {e}") logger.info("Falling back to ensemble retriever without query expansion.") return ensemble_retriever