"""Simple_RAG / prepare_env.py

Prepare the retrieval environment for Simple_RAG: load chunked documents from
JSON, build a Chroma vector store and a BM25 index, combine them in an
EnsembleRetriever, and wrap the result in a MultiQueryRetriever backed by
Google Gemini.
"""
import json
import os
import glob
from typing import List, Optional
from dotenv import load_dotenv
import logging
# Load environment variables from .env file
load_dotenv()
from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_google_genai import GoogleGenerativeAI
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Custom output parser for a list of lines with better error handling."""

    def parse(self, text: str) -> List[str]:
        """Parse the LLM output into a list of queries."""
        try:
            lines = text.strip().split("\n")
            # Remove empty lines and clean up
            cleaned_lines = []
            for line in lines:
                cleaned = line.strip()
                if cleaned and not cleaned.startswith("#") and len(cleaned) > 5:
                    # Remove numbering if present (e.g., "1. ", "- ", etc.)
                    if cleaned[0].isdigit() and ". " in cleaned:
                        cleaned = cleaned.split(". ", 1)[1]
                    elif cleaned.startswith("- "):
                        cleaned = cleaned[2:]
                    cleaned_lines.append(cleaned)
            # Ensure we have at least one query
            if not cleaned_lines:
                cleaned_lines = [text.strip()]
            return cleaned_lines
        except Exception as e:
            logger.warning(f"Error parsing output: {e}. Returning original text.")
            return [text.strip()] if text.strip() else [""]
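
# Illustrative sketch (not executed anywhere in the pipeline): how
# LineListOutputParser normalises typical LLM output. The sample text is
# hypothetical.
#
#   >>> LineListOutputParser().parse("1. What is RAG?\n- How does retrieval work?\n\n# note")
#   ['What is RAG?', 'How does retrieval work?']
#
# Numbered prefixes ("1. ") and bullets ("- ") are stripped, blank or
# comment-like lines are dropped, and the raw text is returned as a single
# query if nothing survives the cleanup.
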
def create_custom_multi_query_retriever(
    base_retriever,
    llm,
    num_queries: int = 5,
    include_original: bool = True
):
    """Create a MultiQueryRetriever around the given base retriever.

    Note: the custom prompt and output parser below are currently disabled
    (kept as comments), so the retriever uses LangChain's default query
    generation prompt and `num_queries` has no effect. A wiring sketch for
    re-enabling the prompt follows this function.
    """
    # Custom prompt template for better query generation (currently unused)
    # query_prompt = PromptTemplate(
    #     input_variables=["question"],
    #     template="""You are an AI assistant specialized in generating diverse search queries.
    # Your task is to generate {num_queries} different versions of the given user question to retrieve relevant documents from a knowledge base.
    # Guidelines:
    # - Create variations that capture different aspects and perspectives of the question
    # - Use synonyms and alternative phrasings
    # - Consider different levels of specificity (broader and narrower)
    # - Focus on the core intent while varying the expression
    # - Each query should be a complete, well-formed question or statement
    # Original question: {question}
    # Generate {num_queries} alternative queries (one per line):""".replace("{num_queries}", str(num_queries))
    # )

    # Create the MultiQueryRetriever with the default prompt and parser
    multi_query_retriever = MultiQueryRetriever.from_llm(
        retriever=base_retriever,
        llm=llm,
        include_original=include_original
    )
    # # Override the output parser
    # multi_query_retriever.output_parser = LineListOutputParser()
    return multi_query_retriever
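
# If the custom query_prompt above is re-enabled, it could be wired in like
# this (hedged sketch; assumes the installed LangChain version accepts a
# `prompt` argument on MultiQueryRetriever.from_llm):
#
#   multi_query_retriever = MultiQueryRetriever.from_llm(
#       retriever=base_retriever,
#       llm=llm,
#       prompt=query_prompt,
#       include_original=include_original,
#   )
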
def validate_environment():
    """Validate that required environment variables are set."""
    required_vars = ["GOOGLE_API_KEY"]
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")
    logger.info("✅ Environment variables validated.")
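
# Illustrative .env contents read by load_dotenv() and checked by
# validate_environment() (the value below is a placeholder, not a real key):
#
#   GOOGLE_API_KEY=your-gemini-api-key
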
def load_documents_from_json(chunks_directory: str) -> List[Document]:
    """Load documents from JSON chunk files, skipping unreadable files and empty elements."""
    json_files = glob.glob(os.path.join(chunks_directory, "*.json"))
    if not json_files:
        raise ValueError(f"No JSON files found in directory: {chunks_directory}")
    logger.info(f"Found {len(json_files)} JSON files: {[os.path.basename(f) for f in json_files]}")

    documents = []
    total_processed = 0
    for json_file in json_files:
        try:
            logger.info(f"Processing: {os.path.basename(json_file)}")
            with open(json_file, "r", encoding="utf-8") as f:
                chunks_data = json.load(f)

            file_doc_count = 0
            for element in chunks_data:
                try:
                    text = element.get("text", "").strip()
                    if not text:  # Skip empty text
                        continue
                    metadata = {
                        "source": element.get("filename", "unknown"),
                        "filetype": element.get("filetype", "unknown"),
                        "element_id": element.get("element_id", "unknown"),
                        "json_source": os.path.basename(json_file)
                    }
                    # Add table-specific metadata if present
                    if element.get("type") == "TableElement" and element.get("table_text_as_html"):
                        metadata["table_text_as_html"] = element["table_text_as_html"]
                        metadata["element_type"] = "table"
                    else:
                        metadata["element_type"] = element.get("type", "text")
                    doc = Document(page_content=text, metadata=metadata)
                    documents.append(doc)
                    file_doc_count += 1
                except Exception as e:
                    logger.warning(f"Error processing element in {json_file}: {e}")
                    continue

            logger.info(f"  → Loaded {file_doc_count} documents from {os.path.basename(json_file)}")
            total_processed += file_doc_count
        except Exception as e:
            logger.error(f"Error processing file {json_file}: {e}")
            continue

    if not documents:
        raise ValueError("No valid documents were loaded from any JSON files.")
    logger.info(f"✅ Total loaded: {len(documents)} documents from {len(json_files)} JSON files.")
    return documents
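
# Illustrative shape of one element in a chunks JSON file, inferred from the
# keys read above (the field values shown are made up; any keys beyond those
# accessed in the loader are assumptions):
#
#   {
#     "text": "Article 12: eligibility conditions ...",
#     "filename": "guide_2023.pdf",
#     "filetype": "application/pdf",
#     "element_id": "a1b2c3",
#     "type": "NarrativeText",                      # or "TableElement"
#     "table_text_as_html": "<table>...</table>"    # tables only
#   }
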
def prepare_environment_and_retriever(
    chunks_directory: str = "./data/",
    model_name: str = "intfloat/multilingual-e5-base",
    collection_name: str = "Guide_2023_e5_multilingual",
    persist_directory: str = "chroma_db_multilingual",
    k_vector: int = 6,
    k_sparse: int = 2,
    ensemble_weights: Optional[List[float]] = None,
    llm_model_name: str = "gemini-2.0-flash-exp",
    num_query_variations: int = 5,
    include_original_query: bool = True,
    temperature: float = 0.1
):
    """
    Prepare the complete retrieval environment with MultiQueryRetriever.

    Args:
        chunks_directory: Directory containing JSON files with document chunks
        model_name: HuggingFace embedding model name
        collection_name: Chroma collection name
        persist_directory: Directory to persist Chroma database
        k_vector: Number of documents to retrieve from vector search
        k_sparse: Number of documents to retrieve from BM25 search
        ensemble_weights: Weights for ensemble retriever [vector, sparse]; defaults to [0.5, 0.5]
        llm_model_name: Google Gemini model name for query expansion
        num_query_variations: Number of query variations to generate
        include_original_query: Whether to include original query in search
        temperature: LLM temperature for query generation

    Returns:
        MultiQueryRetriever: Configured retriever ready for use
    """
    # Avoid a mutable default argument; fall back to equal vector/sparse weights
    if ensemble_weights is None:
        ensemble_weights = [0.5, 0.5]
    # Validate environment
    validate_environment()

    # Load documents
    documents = load_documents_from_json(chunks_directory)

    # Create embedding function
    logger.info(f"Creating embeddings with model: {model_name}")
    embedding_function = HuggingFaceEmbeddings(
        model_name=model_name,
    )

    # Create or load vector store
    logger.info("Creating/loading vector store...")
    try:
        # Try to load existing vectorstore first
        if os.path.exists(persist_directory):
            vectorstore = Chroma(
                collection_name=collection_name,
                embedding_function=embedding_function,
                persist_directory=persist_directory
            )
            logger.info("✅ Loaded existing vector store.")
        else:
            # Create new vectorstore
            vectorstore = Chroma.from_documents(
                documents=documents,
                embedding=embedding_function,
                collection_name=collection_name,
                persist_directory=persist_directory
            )
            logger.info("✅ Created new vector store with multilingual embeddings.")
    except Exception as e:
        logger.warning(f"Error with persistent storage: {e}. Creating in-memory store.")
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=embedding_function,
            collection_name=collection_name
        )

    # Create base retrievers
    logger.info("Setting up retrievers...")

    # Vector retriever
    vector_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": k_vector}
    )

    # BM25 (sparse) retriever
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = k_sparse

    # Ensemble retriever (combining vector + sparse search)
    ensemble_retriever = EnsembleRetriever(
        retrievers=[vector_retriever, bm25_retriever],
        weights=ensemble_weights
    )
    logger.info(f"✅ Ensemble retriever created with weights: {ensemble_weights}")
    # Language model for multi-query expansion
    logger.info(f"Initializing LLM: {llm_model_name}")
    try:
        llm = GoogleGenerativeAI(
            model=llm_model_name,
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=temperature,
            max_output_tokens=1000  # Reasonable limit for query generation
        )
        # Smoke-test the LLM with a simple call before wiring it in
        llm.invoke("Generate a simple test query about artificial intelligence.")
        logger.info("✅ LLM connection verified.")
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        raise

    # Create MultiQueryRetriever with custom configuration
    logger.info("Creating MultiQueryRetriever...")
    try:
        multi_query_retriever = create_custom_multi_query_retriever(
            base_retriever=ensemble_retriever,
            llm=llm,
            num_queries=num_query_variations,
            include_original=include_original_query
        )
        logger.info("✅ MultiQueryRetriever ready:")
        logger.info(f"  - Vector search: top-{k_vector}")
        logger.info(f"  - Sparse search: top-{k_sparse}")
        logger.info(f"  - Ensemble weights: {ensemble_weights}")
        logger.info(f"  - Query variations: {num_query_variations}")
        logger.info(f"  - Include original: {include_original_query}")
        return multi_query_retriever
    except Exception as e:
        logger.error(f"Error creating MultiQueryRetriever: {e}")
        logger.info("Falling back to ensemble retriever without query expansion.")
        return ensemble_retriever
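

# Minimal usage sketch (assumptions: chunk JSON files already exist under
# ./data/ and GOOGLE_API_KEY is set in .env); the sample question is made up.
if __name__ == "__main__":
    retriever = prepare_environment_and_retriever()
    docs = retriever.invoke("What are the eligibility conditions described in the guide?")
    for i, doc in enumerate(docs, start=1):
        logger.info(f"[{i}] {doc.metadata.get('source', 'unknown')}: {doc.page_content[:120]}...")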