import json
import os
import glob
import logging
from typing import List, Optional

from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_google_genai import GoogleGenerativeAI

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class LineListOutputParser(BaseOutputParser[List[str]]):
    """Custom output parser for a list of lines with better error handling."""

    def parse(self, text: str) -> List[str]:
        """Parse the LLM output into a list of queries."""
        try:
            lines = text.strip().split("\n")
            # Remove empty lines and clean up
            cleaned_lines = []
            for line in lines:
                cleaned = line.strip()
                if cleaned and not cleaned.startswith("#") and len(cleaned) > 5:
                    # Remove numbering if present (e.g., "1. ", "- ", etc.)
                    if cleaned[0].isdigit() and ". " in cleaned:
                        cleaned = cleaned.split(". ", 1)[1]
                    elif cleaned.startswith("- "):
                        cleaned = cleaned[2:]
                    cleaned_lines.append(cleaned)
            # Ensure we have at least one query
            if not cleaned_lines:
                cleaned_lines = [text.strip()]
            return cleaned_lines
        except Exception as e:
            logger.warning(f"Error parsing output: {e}. Returning original text.")
            return [text.strip()] if text.strip() else [""]
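
# Illustrative check of the parser above: numbered or bulleted LLM output is
# stripped down to bare queries, e.g.
#   LineListOutputParser().parse("1. What is RAG?\n2. How does it work?")
#   -> ["What is RAG?", "How does it work?"]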

def create_custom_multi_query_retriever(
    base_retriever,
    llm,
    num_queries: int = 5,
    include_original: bool = True
):
    """Create a custom MultiQueryRetriever with improved prompt."""
    # Custom prompt template for better query generation
    # query_prompt = PromptTemplate(
    #     input_variables=["question"],
    #     template="""You are an AI assistant specialized in generating diverse search queries.
    # Your task is to generate {num_queries} different versions of the given user question to retrieve relevant documents from a knowledge base.
    # Guidelines:
    # - Create variations that capture different aspects and perspectives of the question
    # - Use synonyms and alternative phrasings
    # - Consider different levels of specificity (broader and narrower)
    # - Focus on the core intent while varying the expression
    # - Each query should be a complete, well-formed question or statement
    # Original question: {question}
    # Generate {num_queries} alternative queries (one per line):""".replace("{num_queries}", str(num_queries))
    # )

    # Create the MultiQueryRetriever with custom components
    multi_query_retriever = MultiQueryRetriever.from_llm(
        retriever=base_retriever,
        llm=llm,
        include_original=include_original
    )
    # # Override the output parser
    # multi_query_retriever.output_parser = LineListOutputParser()
    return multi_query_retriever
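
# Note: while the custom prompt and parser above stay commented out,
# from_llm() falls back to LangChain's default multi-query prompt (a fixed
# number of rewrites), so the num_queries argument currently has no effect.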

def validate_environment():
    """Validate that required environment variables are set."""
    required_vars = ["GOOGLE_API_KEY"]
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")
    logger.info("✅ Environment variables validated.")
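
# Example .env entry read by load_dotenv() above (the value is a placeholder):
#   GOOGLE_API_KEY=your-api-key-here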

def load_documents_from_json(chunks_directory: str) -> List[Document]:
    """Load documents from JSON files with better error handling."""
    json_files = glob.glob(os.path.join(chunks_directory, "*.json"))
    if not json_files:
        raise ValueError(f"No JSON files found in directory: {chunks_directory}")
    logger.info(f"Found {len(json_files)} JSON files: {[os.path.basename(f) for f in json_files]}")

    documents = []
    total_processed = 0
    for json_file in json_files:
        try:
            logger.info(f"Processing: {os.path.basename(json_file)}")
            with open(json_file, "r", encoding="utf-8") as f:
                chunks_data = json.load(f)
            file_doc_count = 0
            for element in chunks_data:
                try:
                    text = element.get("text", "").strip()
                    if not text:  # Skip empty text
                        continue
                    metadata = {
                        "source": element.get("filename", "unknown"),
                        "filetype": element.get("filetype", "unknown"),
                        "element_id": element.get("element_id", "unknown"),
                        "json_source": os.path.basename(json_file)
                    }
                    # Add table-specific metadata if present
                    if element.get("type") == "TableElement" and element.get("table_text_as_html"):
                        metadata["table_text_as_html"] = element["table_text_as_html"]
                        # metadata["element_type"] = "table"
                    else:
                        metadata["element_type"] = element.get("type", "text")
                    doc = Document(page_content=text, metadata=metadata)
                    documents.append(doc)
                    file_doc_count += 1
                except Exception as e:
                    logger.warning(f"Error processing element in {json_file}: {e}")
                    continue
            logger.info(f"  ✓ Loaded {file_doc_count} documents from {os.path.basename(json_file)}")
            total_processed += file_doc_count
        except Exception as e:
            logger.error(f"Error processing file {json_file}: {e}")
            continue

    if not documents:
        raise ValueError("No valid documents were loaded from any JSON files.")
    logger.info(f"✅ Total loaded: {len(documents)} documents from {len(json_files)} JSON files.")
    return documents
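
# Shape of the chunk JSON consumed above, inferred from the keys read
# (matches unstructured-style element dumps):
# [
#   {"text": "...", "filename": "guide.pdf", "filetype": "application/pdf",
#    "element_id": "...", "type": "NarrativeText"},
#   {"text": "...", "type": "TableElement", "table_text_as_html": "<table>...</table>"}
# ]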

def prepare_environment_and_retriever(
    chunks_directory: str = "./data/",
    model_name: str = "intfloat/multilingual-e5-base",
    collection_name: str = "Guide_2023_e5_multilingual",
    persist_directory: str = "chroma_db_multilingual",
    k_vector: int = 6,
    k_sparse: int = 2,
    ensemble_weights: Optional[List[float]] = None,
    llm_model_name: str = "gemini-2.0-flash-exp",
    num_query_variations: int = 5,
    include_original_query: bool = True,
    temperature: float = 0.1
):
"""
Prepare the complete retrieval environment with MultiQueryRetriever.
Args:
chunks_directory: Directory containing JSON files with document chunks
model_name: HuggingFace embedding model name
collection_name: Chroma collection name
persist_directory: Directory to persist Chroma database
k_vector: Number of documents to retrieve from vector search
k_sparse: Number of documents to retrieve from BM25 search
ensemble_weights: Weights for ensemble retriever [vector, sparse]
llm_model_name: Google Gemini model name for query expansion
num_query_variations: Number of query variations to generate
include_original_query: Whether to include original query in search
temperature: LLM temperature for query generation
Returns:
MultiQueryRetriever: Configured retriever ready for use
"""
    # Avoid a mutable default argument for the ensemble weights
    if ensemble_weights is None:
        ensemble_weights = [0.5, 0.5]

    # Validate environment
    validate_environment()

    # Load documents
    documents = load_documents_from_json(chunks_directory)

    # Create embedding function
    logger.info(f"Creating embeddings with model: {model_name}")
    embedding_function = HuggingFaceEmbeddings(
        model_name=model_name,
    )

    # Create or load vector store
    logger.info("Creating/loading vector store...")
    try:
        # Try to load existing vectorstore first
        if os.path.exists(persist_directory):
            vectorstore = Chroma(
                collection_name=collection_name,
                embedding_function=embedding_function,
                persist_directory=persist_directory
            )
            logger.info("✅ Loaded existing vector store.")
        else:
            # Create new vectorstore
            vectorstore = Chroma.from_documents(
                documents=documents,
                embedding=embedding_function,
                collection_name=collection_name,
                persist_directory=persist_directory
            )
            logger.info("✅ Created new vector store with multilingual embeddings.")
    except Exception as e:
        logger.warning(f"Error with persistent storage: {e}. Creating in-memory store.")
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=embedding_function,
            collection_name=collection_name
        )

    # Create base retrievers
    logger.info("Setting up retrievers...")

    # Vector retriever
    vector_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": k_vector}
    )

    # BM25 (sparse) retriever
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = k_sparse

    # Ensemble retriever (combining vector + sparse search)
    ensemble_retriever = EnsembleRetriever(
        retrievers=[vector_retriever, bm25_retriever],
        weights=ensemble_weights
    )
    logger.info(f"✅ Ensemble retriever created with weights: {ensemble_weights}")
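
    # Note: EnsembleRetriever merges the two ranked lists with weighted
    # Reciprocal Rank Fusion, so the weights trade off the influence of the
    # dense and sparse rankings rather than averaging raw scores.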

    # Language model for multi-query expansion
    logger.info(f"Initializing LLM: {llm_model_name}")
    try:
        llm = GoogleGenerativeAI(
            model=llm_model_name,
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=temperature,
            max_output_tokens=1000  # Reasonable limit for query generation
        )
        # Smoke-test the LLM with a simple call
        llm.invoke("Generate a simple test query about artificial intelligence.")
        logger.info("✅ LLM connection verified.")
    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        raise

    # Create MultiQueryRetriever with custom configuration
    logger.info("Creating MultiQueryRetriever...")
    try:
        multi_query_retriever = create_custom_multi_query_retriever(
            base_retriever=ensemble_retriever,
            llm=llm,
            num_queries=num_query_variations,
            include_original=include_original_query
        )
        logger.info("✅ MultiQueryRetriever ready:")
        logger.info(f"  - Vector search: top-{k_vector}")
        logger.info(f"  - Sparse search: top-{k_sparse}")
        logger.info(f"  - Ensemble weights: {ensemble_weights}")
        logger.info(f"  - Query variations: {num_query_variations}")
        logger.info(f"  - Include original: {include_original_query}")
        return multi_query_retriever
    except Exception as e:
        logger.error(f"Error creating MultiQueryRetriever: {e}")
        logger.info("Falling back to ensemble retriever without query expansion.")
        return ensemble_retriever
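

# Minimal usage sketch; the query string is an assumption, and running it
# requires GOOGLE_API_KEY in the environment plus chunk JSON files in ./data/.
if __name__ == "__main__":
    retriever = prepare_environment_and_retriever()
    results = retriever.invoke("What are the eligibility requirements?")
    for doc in results:
        print(f"{doc.metadata.get('source', 'unknown')}: {doc.page_content[:120]}")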