import json
import os
import glob
import logging
from typing import List, Optional

from dotenv import load_dotenv

# Load environment variables from .env file (before any library reads GOOGLE_API_KEY)
load_dotenv()

from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_google_genai import GoogleGenerativeAI

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Custom output parser for a list of lines with better error handling."""

    def parse(self, text: str) -> List[str]:
        """Parse the LLM output into a list of queries."""
        try:
            lines = text.strip().split("\n")
            # Remove empty lines and clean up
            cleaned_lines = []
            for line in lines:
                cleaned = line.strip()
                if cleaned and not cleaned.startswith("#") and len(cleaned) > 5:
                    # Strip leading numbering or bullets if present (e.g., "1. ", "- ")
                    if cleaned[0].isdigit() and ". " in cleaned:
                        cleaned = cleaned.split(". ", 1)[1]
                    elif cleaned.startswith("- "):
                        cleaned = cleaned[2:]
                    cleaned_lines.append(cleaned)
            # Ensure we return at least one query
            if not cleaned_lines:
                cleaned_lines = [text.strip()]
            return cleaned_lines
        except Exception as e:
            logger.warning(f"Error parsing output: {e}. Returning original text.")
            return [text.strip()] if text.strip() else [""]
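
# Illustrative behavior of the parser above (the LLM output here is hypothetical):
#   LineListOutputParser().parse("1. What is hybrid search?\n2. How does BM25 rank documents?")
#   -> ["What is hybrid search?", "How does BM25 rank documents?"]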
def create_custom_multi_query_retriever(
    base_retriever,
    llm,
    num_queries: int = 5,
    include_original: bool = True
):
    """Create a MultiQueryRetriever around the given base retriever.

    The custom prompt and output parser below are kept for reference but are
    currently disabled, so `num_queries` has no effect and the stock
    MultiQueryRetriever prompt is used instead.
    """
    # Custom prompt template for better query generation (currently disabled):
    # query_prompt = PromptTemplate(
    #     input_variables=["question"],
    #     template="""You are an AI assistant specialized in generating diverse search queries.
    # Your task is to generate {num_queries} different versions of the given user question to retrieve relevant documents from a knowledge base.
    # Guidelines:
    # - Create variations that capture different aspects and perspectives of the question
    # - Use synonyms and alternative phrasings
    # - Consider different levels of specificity (broader and narrower)
    # - Focus on the core intent while varying the expression
    # - Each query should be a complete, well-formed question or statement
    # Original question: {question}
    # Generate {num_queries} alternative queries (one per line):""".replace("{num_queries}", str(num_queries))
    # )

    # Create the MultiQueryRetriever with the default prompt
    multi_query_retriever = MultiQueryRetriever.from_llm(
        retriever=base_retriever,
        llm=llm,
        include_original=include_original
    )
    # Override the output parser (currently disabled):
    # multi_query_retriever.output_parser = LineListOutputParser()
    return multi_query_retriever
def validate_environment():
    """Validate that required environment variables are set."""
    required_vars = ["GOOGLE_API_KEY"]
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")
    logger.info("✓ Environment variables validated.")
def load_documents_from_json(chunks_directory: str) -> List[Document]:
    """Load documents from JSON files with better error handling."""
    json_files = glob.glob(os.path.join(chunks_directory, "*.json"))
    if not json_files:
        raise ValueError(f"No JSON files found in directory: {chunks_directory}")
    logger.info(f"Found {len(json_files)} JSON files: {[os.path.basename(f) for f in json_files]}")

    documents = []
    total_processed = 0
    for json_file in json_files:
        try:
            logger.info(f"Processing: {os.path.basename(json_file)}")
            with open(json_file, "r", encoding="utf-8") as f:
                chunks_data = json.load(f)
            file_doc_count = 0
            for element in chunks_data:
                try:
                    text = element.get("text", "").strip()
                    if not text:  # Skip empty text
                        continue
                    metadata = {
                        "source": element.get("filename", "unknown"),
                        "filetype": element.get("filetype", "unknown"),
                        "element_id": element.get("element_id", "unknown"),
                        "json_source": os.path.basename(json_file)
                    }
                    # Add table-specific metadata if present
                    if element.get("type") == "TableElement" and element.get("table_text_as_html"):
                        metadata["table_text_as_html"] = element["table_text_as_html"]
                        # metadata["element_type"] = "table"
                    else:
                        metadata["element_type"] = element.get("type", "text")
                    doc = Document(page_content=text, metadata=metadata)
                    documents.append(doc)
                    file_doc_count += 1
                except Exception as e:
                    logger.warning(f"Error processing element in {json_file}: {e}")
                    continue
            logger.info(f"  ✓ Loaded {file_doc_count} documents from {os.path.basename(json_file)}")
            total_processed += file_doc_count
        except Exception as e:
            logger.error(f"Error processing file {json_file}: {e}")
            continue

    if not documents:
        raise ValueError("No valid documents were loaded from any JSON files.")
    logger.info(f"✓ Total loaded: {len(documents)} documents from {len(json_files)} JSON files.")
    return documents
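
# The loader above expects each JSON file to contain a list of element dicts.
# The keys are the ones read by load_documents_from_json; the values below are
# illustrative, not from a real file:
# [
#   {"text": "Chapter 1 ...", "filename": "guide.pdf", "filetype": "application/pdf",
#    "element_id": "e1", "type": "NarrativeText"},
#   {"text": "Fees by category ...", "type": "TableElement",
#    "table_text_as_html": "<table>...</table>"}
# ]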
def prepare_environment_and_retriever(
    chunks_directory: str = "./data/",
    model_name: str = "intfloat/multilingual-e5-base",
    collection_name: str = "Guide_2023_e5_multilingual",
    persist_directory: str = "chroma_db_multilingual",
    k_vector: int = 6,
    k_sparse: int = 2,
    ensemble_weights: Optional[List[float]] = None,
    llm_model_name: str = "gemini-2.0-flash-exp",
    num_query_variations: int = 5,
    include_original_query: bool = True,
    temperature: float = 0.1
):
    """
    Prepare the complete retrieval environment with MultiQueryRetriever.

    Args:
        chunks_directory: Directory containing JSON files with document chunks
        model_name: HuggingFace embedding model name
        collection_name: Chroma collection name
        persist_directory: Directory to persist the Chroma database
        k_vector: Number of documents to retrieve from vector search
        k_sparse: Number of documents to retrieve from BM25 search
        ensemble_weights: Weights for the ensemble retriever [vector, sparse];
            defaults to [0.5, 0.5]
        llm_model_name: Google Gemini model name for query expansion
        num_query_variations: Number of query variations to generate
        include_original_query: Whether to include the original query in the search
        temperature: LLM temperature for query generation

    Returns:
        The configured MultiQueryRetriever, or the plain ensemble retriever
        if query expansion could not be set up.
    """
    # Avoid a mutable default argument
    if ensemble_weights is None:
        ensemble_weights = [0.5, 0.5]
    # Validate environment
    validate_environment()

    # Load documents
    documents = load_documents_from_json(chunks_directory)

    # Create embedding function
    logger.info(f"Creating embeddings with model: {model_name}")
    embedding_function = HuggingFaceEmbeddings(
        model_name=model_name,
    )

    # Create or load vector store
    logger.info("Creating/loading vector store...")
    try:
        # Try to load an existing vector store first
        if os.path.exists(persist_directory):
            vectorstore = Chroma(
                collection_name=collection_name,
                embedding_function=embedding_function,
                persist_directory=persist_directory
            )
            logger.info("✓ Loaded existing vector store.")
        else:
            # Create a new vector store
            vectorstore = Chroma.from_documents(
                documents=documents,
                embedding=embedding_function,
                collection_name=collection_name,
                persist_directory=persist_directory
            )
            logger.info("✓ Created new vector store with multilingual embeddings.")
    except Exception as e:
        logger.warning(f"Error with persistent storage: {e}. Creating in-memory store.")
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=embedding_function,
            collection_name=collection_name
        )
    # Create base retrievers
    logger.info("Setting up retrievers...")

    # Vector (dense) retriever
    vector_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": k_vector}
    )

    # BM25 (sparse) retriever
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = k_sparse

    # Ensemble retriever (combining vector + sparse search)
    ensemble_retriever = EnsembleRetriever(
        retrievers=[vector_retriever, bm25_retriever],
        weights=ensemble_weights
    )
    logger.info(f"✓ Ensemble retriever created with weights: {ensemble_weights}")
    # Language model for multi-query expansion
    logger.info(f"Initializing LLM: {llm_model_name}")
    try:
        llm = GoogleGenerativeAI(
            model=llm_model_name,
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=temperature,
            max_output_tokens=1000  # Reasonable limit for query generation
        )
        # Smoke-test the LLM with a simple call; raises if the key or model is invalid
        llm.invoke("Generate a simple test query about artificial intelligence.")
        logger.info("✓ LLM connection verified.")
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        raise
    # Create MultiQueryRetriever with custom configuration
    logger.info("Creating MultiQueryRetriever...")
    try:
        multi_query_retriever = create_custom_multi_query_retriever(
            base_retriever=ensemble_retriever,
            llm=llm,
            num_queries=num_query_variations,
            include_original=include_original_query
        )
        logger.info("✓ MultiQueryRetriever ready:")
        logger.info(f"  - Vector search: top-{k_vector}")
        logger.info(f"  - Sparse search: top-{k_sparse}")
        logger.info(f"  - Ensemble weights: {ensemble_weights}")
        logger.info(f"  - Query variations: {num_query_variations}")
        logger.info(f"  - Include original: {include_original_query}")
        return multi_query_retriever
    except Exception as e:
        logger.error(f"Error creating MultiQueryRetriever: {e}")
        logger.info("Falling back to ensemble retriever without query expansion.")
        return ensemble_retriever
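

# Usage sketch (illustrative): assumes a populated ./data/ directory and a
# GOOGLE_API_KEY in .env; the sample question below is hypothetical.
if __name__ == "__main__":
    retriever = prepare_environment_and_retriever()
    # Retrievers expose the standard Runnable interface: invoke() takes the
    # query string and returns a list of Documents.
    docs = retriever.invoke("What are the eligibility requirements described in the guide?")
    for doc in docs:
        print(doc.metadata.get("source"), "->", doc.page_content[:100])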