File size: 11,577 Bytes
7f51074
 
c23c6b4
 
7f51074
c23c6b4
7f51074
 
 
c23c6b4
7f51074
c23c6b4
 
7f51074
 
 
 
 
 
 
c23c6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f51074
c23c6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f51074
c23c6b4
 
 
 
 
 
 
 
 
7f51074
c23c6b4
7f51074
c23c6b4
 
 
 
 
 
7f51074
c23c6b4
 
 
 
 
 
 
 
 
 
 
 
 
7f51074
c23c6b4
 
 
 
 
 
7f51074
c23c6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f51074
c23c6b4
7f51074
 
c23c6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f51074
c23c6b4
 
 
 
 
7f51074
 
 
 
c23c6b4
7f51074
 
 
 
 
c23c6b4
 
7f51074
c23c6b4
7f51074
 
c23c6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7f51074
c23c6b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
import json
import os
import glob
from typing import List, Optional
from dotenv import load_dotenv
import logging

# Load environment variables from .env file
load_dotenv()

from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_community.vectorstores import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.retrievers import BM25Retriever
from langchain.retrievers import EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_google_genai import GoogleGenerativeAI

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class LineListOutputParser(BaseOutputParser[List[str]]):
    """Parse LLM output into a list of query strings, one per line.

    Blank lines, comment lines (starting with '#'), and very short
    fragments are discarded; leading "1. " enumeration or "- " bullet
    markers are stripped from each surviving line.
    """

    def parse(self, text: str) -> List[str]:
        """Return the cleaned, non-empty queries found in ``text``."""
        try:
            queries: List[str] = []
            for raw_line in text.strip().split("\n"):
                candidate = raw_line.strip()
                # Keep only substantive lines: skip blanks, comments,
                # and fragments too short to be a real query.
                if not candidate or candidate.startswith("#") or len(candidate) <= 5:
                    continue
                # Strip a leading "N. " enumeration or "- " bullet marker.
                if candidate[0].isdigit() and ". " in candidate:
                    candidate = candidate.split(". ", 1)[1]
                elif candidate.startswith("- "):
                    candidate = candidate[2:]
                queries.append(candidate)

            # Fall back to the whole response if nothing survived filtering.
            return queries or [text.strip()]
        except Exception as e:
            logger.warning(f"Error parsing output: {e}. Returning original text.")
            return [text.strip()] if text.strip() else [""]

def create_custom_multi_query_retriever(
    base_retriever,
    llm,
    num_queries: int = 5,
    include_original: bool = True
):
    """Build a MultiQueryRetriever around ``base_retriever``.

    Args:
        base_retriever: Retriever the generated query variations are run
            against.
        llm: Language model used to generate the alternative queries.
        num_queries: Kept for interface compatibility. The stock
            MultiQueryRetriever prompt controls how many variations are
            generated, so this value is currently NOT forwarded — wire a
            custom PromptTemplate into ``from_llm`` to honor it.
        include_original: Whether the original question is searched in
            addition to the generated variations.

    Returns:
        MultiQueryRetriever: Retriever that expands each question into
        multiple query variations before searching.
    """
    # Use the stock prompt and output parser; a previous custom prompt
    # experiment was removed (it lived here as commented-out code).
    return MultiQueryRetriever.from_llm(
        retriever=base_retriever,
        llm=llm,
        include_original=include_original
    )

def validate_environment():
    """Fail fast when a required environment variable is unset.

    Raises:
        ValueError: If one or more required variables are missing.
    """
    required_vars = ["GOOGLE_API_KEY"]
    missing_vars = []
    for var in required_vars:
        # Empty-string values count as missing, same as unset.
        if not os.getenv(var):
            missing_vars.append(var)

    if missing_vars:
        raise ValueError(f"Missing required environment variables: {missing_vars}")

    logger.info("βœ… Environment variables validated.")

def load_documents_from_json(chunks_directory: str) -> List[Document]:
    """Load Document objects from every ``*.json`` chunk file in a directory.

    Each JSON file is expected to hold a list of chunk dicts with at least
    a ``text`` field; source/filetype/element metadata is carried over onto
    each Document. Files or elements that fail to parse are logged and
    skipped rather than aborting the whole load.

    Raises:
        ValueError: If no JSON files exist, or no valid documents result.
    """
    json_files = glob.glob(os.path.join(chunks_directory, "*.json"))

    if not json_files:
        raise ValueError(f"No JSON files found in directory: {chunks_directory}")

    logger.info(f"Found {len(json_files)} JSON files: {[os.path.basename(f) for f in json_files]}")

    documents: List[Document] = []
    total_processed = 0

    for json_path in json_files:
        try:
            logger.info(f"Processing: {os.path.basename(json_path)}")

            with open(json_path, "r", encoding="utf-8") as fh:
                chunk_records = json.load(fh)

            per_file_count = 0
            for record in chunk_records:
                try:
                    content = record.get("text", "").strip()
                    # Elements without usable text contribute nothing.
                    if not content:
                        continue

                    meta = {
                        "source": record.get("filename", "unknown"),
                        "filetype": record.get("filetype", "unknown"),
                        "element_id": record.get("element_id", "unknown"),
                        "json_source": os.path.basename(json_path)
                    }

                    # Tables keep their HTML rendering; everything else
                    # records its element type (defaulting to "text").
                    if record.get("type") == "TableElement" and record.get("table_text_as_html"):
                        meta["table_text_as_html"] = record["table_text_as_html"]
                    else:
                        meta["element_type"] = record.get("type", "text")

                    documents.append(Document(page_content=content, metadata=meta))
                    per_file_count += 1

                except Exception as e:
                    logger.warning(f"Error processing element in {json_path}: {e}")
                    continue

            logger.info(f"  β†’ Loaded {per_file_count} documents from {os.path.basename(json_path)}")
            total_processed += per_file_count

        except Exception as e:
            logger.error(f"Error processing file {json_path}: {e}")
            continue

    if not documents:
        raise ValueError("No valid documents were loaded from any JSON files.")

    logger.info(f"βœ… Total loaded: {len(documents)} documents from {len(json_files)} JSON files.")
    return documents

def prepare_environment_and_retriever(
    chunks_directory: str = "./data/",
    model_name: str = "intfloat/multilingual-e5-base",
    collection_name: str = "Guide_2023_e5_multilingual",
    persist_directory: str = "chroma_db_multilingual",
    k_vector: int = 6,
    k_sparse: int = 2,
    ensemble_weights: Optional[List[float]] = None,
    llm_model_name: str = "gemini-2.0-flash-exp",
    num_query_variations: int = 5,
    include_original_query: bool = True,
    temperature: float = 0.1
):
    """
    Prepare the complete retrieval environment with MultiQueryRetriever.

    Args:
        chunks_directory: Directory containing JSON files with document chunks
        model_name: HuggingFace embedding model name
        collection_name: Chroma collection name
        persist_directory: Directory to persist Chroma database
        k_vector: Number of documents to retrieve from vector search
        k_sparse: Number of documents to retrieve from BM25 search
        ensemble_weights: Weights for ensemble retriever [vector, sparse];
            defaults to [0.5, 0.5] when None
        llm_model_name: Google Gemini model name for query expansion
        num_query_variations: Number of query variations to generate
        include_original_query: Whether to include original query in search
        temperature: LLM temperature for query generation

    Returns:
        MultiQueryRetriever: Configured retriever ready for use (falls back
        to the plain EnsembleRetriever if multi-query setup fails)

    Raises:
        ValueError: If required environment variables or documents are missing
    """

    # Resolve the default here rather than in the signature — a mutable
    # default list would be shared across calls.
    if ensemble_weights is None:
        ensemble_weights = [0.5, 0.5]

    # Validate environment
    validate_environment()

    # Load documents
    documents = load_documents_from_json(chunks_directory)

    # Create embedding function
    logger.info(f"Creating embeddings with model: {model_name}")
    embedding_function = HuggingFaceEmbeddings(
        model_name=model_name,
    )

    # Create or load vector store
    logger.info("Creating/loading vector store...")
    try:
        # Try to load existing vectorstore first
        if os.path.exists(persist_directory):
            vectorstore = Chroma(
                collection_name=collection_name,
                embedding_function=embedding_function,
                persist_directory=persist_directory
            )
            logger.info("βœ… Loaded existing vector store.")
        else:
            # Create new vectorstore
            vectorstore = Chroma.from_documents(
                documents=documents,
                embedding=embedding_function,
                collection_name=collection_name,
                persist_directory=persist_directory
            )
            logger.info("βœ… Created new vector store with multilingual embeddings.")
    except Exception as e:
        # Persistence failure is non-fatal: fall back to an in-memory store.
        logger.warning(f"Error with persistent storage: {e}. Creating in-memory store.")
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=embedding_function,
            collection_name=collection_name
        )

    # Create base retrievers
    logger.info("Setting up retrievers...")

    # Vector retriever (dense similarity search)
    vector_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": k_vector}
    )

    # BM25 (sparse) retriever
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = k_sparse

    # Ensemble retriever (combining vector + sparse search)
    ensemble_retriever = EnsembleRetriever(
        retrievers=[vector_retriever, bm25_retriever],
        weights=ensemble_weights
    )
    logger.info(f"βœ… Ensemble retriever created with weights: {ensemble_weights}")

    # Language model for multi-query expansion
    logger.info(f"Initializing LLM: {llm_model_name}")
    try:
        llm = GoogleGenerativeAI(
            model=llm_model_name,
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=temperature,
            max_output_tokens=1000  # Reasonable limit for query generation
        )

        # Smoke-test the LLM with a simple call; the response itself is unused.
        llm.invoke("Generate a simple test query about artificial intelligence.")
        logger.info("βœ… LLM connection verified.")

    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        raise

    # Create MultiQueryRetriever with custom configuration
    logger.info("Creating MultiQueryRetriever...")
    try:
        multi_query_retriever = create_custom_multi_query_retriever(
            base_retriever=ensemble_retriever,
            llm=llm,
            num_queries=num_query_variations,
            include_original=include_original_query
        )

        logger.info(f"βœ… MultiQueryRetriever ready:")
        logger.info(f"  - Vector search: top-{k_vector}")
        logger.info(f"  - Sparse search: top-{k_sparse}")
        logger.info(f"  - Ensemble weights: {ensemble_weights}")
        logger.info(f"  - Query variations: {num_query_variations}")
        logger.info(f"  - Include original: {include_original_query}")

        return multi_query_retriever

    except Exception as e:
        # Query expansion is an enhancement; degrade gracefully to the
        # plain ensemble retriever rather than failing the whole setup.
        logger.error(f"Error creating MultiQueryRetriever: {e}")
        logger.info("Falling back to ensemble retriever without query expansion.")
        return ensemble_retriever