Spaces:
Sleeping
Sleeping
Commit ·
b7e0e53
0
Parent(s):
Initial commit for DocChat
Browse files- .env +15 -0
- agents/__init.py__ +5 -0
- agents/relevance_checker.py +89 -0
- agents/research_agent.py +74 -0
- agents/verification_agent.py +134 -0
- agents/workflow.py +132 -0
- app.py +171 -0
- config/__init.py__ +4 -0
- config/__pycache__/llm_config.cpython-313.pyc +0 -0
- config/constants.py +8 -0
- config/llm_config.py +142 -0
- config/settings.py +38 -0
- config/test.py +49 -0
- document_processor/__init.py__ +3 -0
- document_processor/file_handler.py +92 -0
- requirements.txt +67 -0
- retriever/__init.py__ +3 -0
- retriever/builder.py +55 -0
- utils/__init.py__ +3 -0
- utils/logging.py +8 -0
.env
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# LLM Configuration
|
| 2 |
+
LLM_PROVIDER=google # google or openai
|
| 3 |
+
|
| 4 |
+
# API Keys
|
| 5 |
+
GOOGLE_API_KEY="your_google_api_key_here"
|
| 6 |
+
OPENAI_API_KEY="your_openai_api_key_here"
|
| 7 |
+
|
| 8 |
+
# Database Settings
|
| 9 |
+
CHROMA_DB_PATH=./chroma_db
|
| 10 |
+
|
| 11 |
+
# Retrieval Settings
|
| 12 |
+
VECTOR_SEARCH_K=10
|
| 13 |
+
|
| 14 |
+
# Cache Settings
|
| 15 |
+
CACHE_EXPIRE_DAYS=7
|
agents/__init.py__
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .research_agent import ResearchAgent
|
| 2 |
+
from .verification_agent import VerificationAgent
|
| 3 |
+
from .workflow import AgentWorkflow
|
| 4 |
+
|
| 5 |
+
__all__ = ["ResearchAgent", "VerificationAgent", "AgentWorkflow"]
|
agents/relevance_checker.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
from langchain.schema import BaseRetriever
|
| 3 |
+
from langchain.prompts import ChatPromptTemplate
|
| 4 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 5 |
+
from config.llm_config import llm_config
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)

class RelevanceChecker:
    """Classifies how well retrieved passages can answer a user's question.

    The LLM is prompted to reply with exactly one label:
    ``CAN_ANSWER``, ``PARTIAL``, or ``NO_MATCH``.
    """

    # Checked in priority order; CAN_ANSWER first so it wins if the model
    # echoes several labels in one response.
    _VALID_LABELS = ("CAN_ANSWER", "PARTIAL", "NO_MATCH")

    def __init__(self):
        """Initialize the relevance checker with configurable LLM."""
        logger.info("Initializing RelevanceChecker...")

        # Get LLM from configuration
        self.llm = llm_config.create_llm("relevance")

        # Create prompt template
        self.prompt_template = ChatPromptTemplate.from_messages([
            ("system", """You are an AI relevance checker between a user's question and provided document content.

Instructions:
- Classify how well the document content addresses the user's question.
- Respond with ONLY ONE of the following labels: CAN_ANSWER, PARTIAL, NO_MATCH.
- Do not include any additional text or explanation.

Label Definitions:
1) "CAN_ANSWER": The passages contain enough explicit information to fully answer the question.
2) "PARTIAL": The passages mention or discuss the question's topic but do not provide all the details needed for a complete answer.
3) "NO_MATCH": The passages do not discuss or mention the question's topic at all.

Important: If the passages mention or reference the topic or timeframe of the question in any way, even if incomplete, respond with "PARTIAL" instead of "NO_MATCH"."""),
            ("human", """Question: {question}

Passages:
{passages}

Respond ONLY with one of the following labels: CAN_ANSWER, PARTIAL, NO_MATCH""")
        ])

        # Create chain: prompt -> LLM -> plain string
        self.chain = self.prompt_template | self.llm | StrOutputParser()

        logger.info("RelevanceChecker initialized successfully.")

    def check(self, question: str, retriever: BaseRetriever, k: int = 3) -> str:
        """Check relevance between *question* and retrieved documents.

        Args:
            question: The user's question.
            retriever: Any retriever exposing ``invoke(question) -> docs``.
            k: Number of top chunks to combine into the passage text.

        Returns:
            One of "CAN_ANSWER", "PARTIAL", or "NO_MATCH". Any failure
            (retrieval error, empty results, LLM error, unparseable label)
            conservatively maps to "NO_MATCH".
        """
        logger.debug(f"RelevanceChecker.check called with question='{question}' and k={k}")

        # Retrieve document chunks; a retrieval failure is treated as "no match"
        # rather than propagated, so the workflow can answer gracefully.
        try:
            top_docs = retriever.invoke(question)
        except Exception as e:
            logger.error(f"Error retrieving documents: {e}")
            return "NO_MATCH"

        if not top_docs:
            logger.debug("No documents returned from retriever.")
            return "NO_MATCH"

        # Combine the top k chunk texts
        document_content = "\n\n".join(doc.page_content for doc in top_docs[:k])
        logger.debug(f"Combined document content length: {len(document_content)} characters")

        try:
            # Get classification from LLM
            response = self.chain.invoke({
                "question": question,
                "passages": document_content
            })
        except Exception as e:
            logger.error(f"Error during relevance classification: {e}")
            return "NO_MATCH"

        # Normalize, then accept a label even when the model wraps it in
        # extra text or punctuation (e.g. "Label: CAN_ANSWER."), instead of
        # the previous exact-equality check which silently degraded such
        # responses to NO_MATCH.
        normalized = response.strip().upper()
        for label in self._VALID_LABELS:
            if label in normalized:
                logger.debug(f"Classification: {label}")
                return label

        logger.warning(f"Invalid classification received: '{normalized}'. Defaulting to NO_MATCH.")
        return "NO_MATCH"
agents/research_agent.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
from langchain.schema import Document
|
| 3 |
+
from langchain.prompts import ChatPromptTemplate
|
| 4 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 5 |
+
from config.llm_config import llm_config
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)

class ResearchAgent:
    """Drafts an answer to a question using only the supplied document context."""

    def __init__(self):
        """Build the LLM, the prompt template, and the LCEL chain."""
        logger.info("Initializing ResearchAgent...")

        # Models come from the central configuration so providers can be swapped.
        self.llm = llm_config.create_llm("research")
        self.client = llm_config.create_direct_client()

        messages = [
            ("system", """You are an AI assistant designed to provide precise and factual answers based on the given context.

Instructions:
- Answer the following question using only the provided context.
- Be clear, concise, and factual.
- Return as much information as you can get from the context.
- If the context doesn't contain enough information, say so explicitly.
- Do not add any information not present in the context.
- Format your answer in a clear, readable manner."""),
            ("human", """Question: {question}

Context:
{context}

Provide your answer below:"""),
        ]
        self.prompt_template = ChatPromptTemplate.from_messages(messages)

        # prompt -> LLM -> plain string
        self.chain = self.prompt_template | self.llm | StrOutputParser()

        logger.info("ResearchAgent initialized successfully.")

    def generate(self, question: str, documents: List[Document]) -> Dict:
        """Generate an initial (draft) answer grounded in *documents*.

        Returns a dict with keys ``draft_answer`` (stripped LLM output, or an
        error message on failure) and ``context_used`` (the concatenated text
        the model was shown).
        """
        logger.info(f"ResearchAgent.generate called with question='{question}' and {len(documents)} documents.")

        # Concatenate every chunk into a single context string.
        context = "\n\n".join(doc.page_content for doc in documents)
        logger.debug(f"Combined context length: {len(context)} characters.")

        try:
            payload = {"question": question, "context": context}
            draft = self.chain.invoke(payload)
        except Exception as e:
            # Surface the failure inside the answer rather than raising, so
            # the workflow keeps running.
            logger.error(f"Error during answer generation: {e}")
            return {
                "draft_answer": f"I cannot answer this question based on the provided documents. Error: {str(e)}",
                "context_used": context,
            }

        logger.info(f"Generated answer successfully. Length: {len(draft)} characters.")
        return {
            "draft_answer": draft.strip(),
            "context_used": context,
        }
agents/verification_agent.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, List
|
| 2 |
+
from langchain.schema import Document
|
| 3 |
+
from langchain.prompts import ChatPromptTemplate
|
| 4 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 5 |
+
from config.llm_config import llm_config
|
| 6 |
+
import logging
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)

class VerificationAgent:
    """Verifies a draft answer against document context via an LLM JSON verdict."""

    def __init__(self):
        """Initialize the verification agent with configurable LLM."""
        logger.info("Initializing VerificationAgent...")

        # Get LLM from configuration
        self.llm = llm_config.create_llm("verification")

        # Create prompt template for verification.
        # NOTE: the literal JSON example uses doubled braces ({{ }}) because
        # ChatPromptTemplate applies f-string templating — single braces here
        # would be interpreted as template variables and break .invoke().
        self.prompt_template = ChatPromptTemplate.from_messages([
            ("system", """You are an AI assistant designed to verify the accuracy and relevance of answers based on provided context.

You MUST respond in the exact JSON format specified below.

Instructions:
- Verify the answer against the provided context.
- Check for:
  1. Direct/indirect factual support (YES/NO)
  2. Unsupported claims (list any if present)
  3. Contradictions (list any if present)
  4. Relevance to the question (YES/NO)
- Provide additional details or explanations where relevant.
- If there are no unsupported claims or contradictions, use empty lists.
- If there are no additional details, use an empty string.

JSON Response Format:
{{
    "supported": "YES" or "NO",
    "unsupported_claims": ["claim1", "claim2", ...],
    "contradictions": ["contradiction1", "contradiction2", ...],
    "relevant": "YES" or "NO",
    "additional_details": "string"
}}"""),
            ("human", """Answer to verify: {answer}

Context:
{context}

Provide your verification in the specified JSON format:""")
        ])

        # Create chain: prompt -> LLM -> plain string
        self.chain = self.prompt_template | self.llm | StrOutputParser()

        logger.info("VerificationAgent initialized successfully.")

    @staticmethod
    def _parse_verification_json(response: str) -> Dict:
        """Parse the LLM's JSON verdict, tolerating a Markdown code fence.

        Models frequently wrap JSON in ``` / ```json fences; strip them
        before parsing. Raises json.JSONDecodeError on malformed JSON.
        """
        cleaned = response.strip()
        if cleaned.startswith("```"):
            cleaned = cleaned.strip("`")
            if cleaned.lower().startswith("json"):
                cleaned = cleaned[4:]
            cleaned = cleaned.strip()
        return json.loads(cleaned)

    def format_verification_report(self, verification: Dict) -> str:
        """Format the verification report dictionary into a readable paragraph.

        Missing keys default to the most conservative values ("NO", empty).
        """
        supported = verification.get("supported", "NO")
        unsupported_claims = verification.get("unsupported_claims", [])
        contradictions = verification.get("contradictions", [])
        relevant = verification.get("relevant", "NO")
        additional_details = verification.get("additional_details", "")

        report = f"**Supported:** {supported}\n"

        if unsupported_claims:
            report += f"**Unsupported Claims:** {', '.join(unsupported_claims)}\n"
        else:
            report += "**Unsupported Claims:** None\n"

        if contradictions:
            report += f"**Contradictions:** {', '.join(contradictions)}\n"
        else:
            report += "**Contradictions:** None\n"

        report += f"**Relevant:** {relevant}\n"

        if additional_details:
            report += f"**Additional Details:** {additional_details}\n"
        else:
            report += "**Additional Details:** None\n"

        return report

    def check(self, answer: str, documents: List[Document]) -> Dict:
        """Verify *answer* against *documents*.

        Returns a dict with ``verification_report`` (formatted Markdown text,
        or an error message on failure) and ``context_used``.
        """
        logger.info(f"VerificationAgent.check called with answer length={len(answer)} and {len(documents)} documents.")

        # Combine all document contents
        context = "\n\n".join(doc.page_content for doc in documents)
        logger.debug(f"Combined context length: {len(context)} characters.")

        try:
            # Get verification from LLM
            response = self.chain.invoke({
                "answer": answer,
                "context": context
            })

            # Parse JSON response; fall back to a conservative verdict when
            # the model's output is not valid JSON.
            try:
                verification = self._parse_verification_json(response)
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse JSON response: {e}")
                verification = {
                    "supported": "NO",
                    "unsupported_claims": [],
                    "contradictions": [],
                    "relevant": "NO",
                    "additional_details": "Failed to parse verification response."
                }

            # Format report
            verification_report = self.format_verification_report(verification)
            logger.info("Verification completed successfully.")

            return {
                "verification_report": verification_report,
                "context_used": context
            }

        except Exception as e:
            logger.error(f"Error during verification: {e}")
            return {
                "verification_report": f"**Error during verification:** {str(e)}",
                "context_used": context
            }
agents/workflow.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph.graph import StateGraph, END
|
| 2 |
+
from typing import TypedDict, List, Dict
|
| 3 |
+
from .research_agent import ResearchAgent
|
| 4 |
+
from .verification_agent import VerificationAgent
|
| 5 |
+
from .relevance_checker import RelevanceChecker
|
| 6 |
+
from langchain.schema import Document
|
| 7 |
+
from langchain.retrievers import EnsembleRetriever
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)

class AgentState(TypedDict):
    """Shared state passed between langgraph nodes."""
    question: str                  # the user's question
    documents: List[Document]      # chunks retrieved once, up front, in full_pipeline
    draft_answer: str              # produced by the research node (or the "irrelevant" message)
    verification_report: str       # formatted Markdown report from the verification node
    is_relevant: bool              # set by the relevance-check node
    retriever: EnsembleRetriever   # hybrid retriever built by the app layer

class AgentWorkflow:
    """Wires RelevanceChecker -> ResearchAgent -> VerificationAgent into a langgraph StateGraph."""

    def __init__(self):
        self.researcher = ResearchAgent()
        self.verifier = VerificationAgent()
        self.relevance_checker = RelevanceChecker()
        self.compiled_workflow = self.build_workflow()  # Compile once during initialization

    def build_workflow(self):
        """Create and compile the multi-agent workflow."""
        workflow = StateGraph(AgentState)

        # Add nodes
        workflow.add_node("check_relevance", self._check_relevance_step)
        workflow.add_node("research", self._research_step)
        workflow.add_node("verify", self._verification_step)

        # Define edges: relevance gate first; irrelevant questions short-circuit to END.
        workflow.set_entry_point("check_relevance")
        workflow.add_conditional_edges(
            "check_relevance",
            self._decide_after_relevance_check,
            {
                "relevant": "research",
                "irrelevant": END
            }
        )
        workflow.add_edge("research", "verify")
        # NOTE(review): the verify -> research edge has no retry cap; if
        # _decide_next_step ever returns "re_research" repeatedly this loops
        # until langgraph's recursion limit — confirm a max-retry guard is wanted.
        workflow.add_conditional_edges(
            "verify",
            self._decide_next_step,
            {
                "re_research": "research",
                "end": END
            }
        )
        return workflow.compile()

    def _check_relevance_step(self, state: AgentState) -> Dict:
        """Classify the question against the corpus; k=20 widens the passage sample."""
        retriever = state["retriever"]
        classification = self.relevance_checker.check(
            question=state["question"],
            retriever=retriever,
            k=20
        )

        if classification == "CAN_ANSWER":
            # We have enough info to proceed
            return {"is_relevant": True}

        elif classification == "PARTIAL":
            # There's partial coverage, but we can still proceed
            return {
                "is_relevant": True
            }

        else:  # classification == "NO_MATCH"
            # Pre-populate draft_answer so full_pipeline has something to return.
            return {
                "is_relevant": False,
                "draft_answer": "This question isn't related (or there's no data) for your query. Please ask another question relevant to the uploaded document(s)."
            }


    def _decide_after_relevance_check(self, state: AgentState) -> str:
        """Route to 'research' when relevant, otherwise straight to END."""
        decision = "relevant" if state["is_relevant"] else "irrelevant"
        print(f"[DEBUG] _decide_after_relevance_check -> {decision}")
        return decision

    def full_pipeline(self, question: str, retriever: EnsembleRetriever):
        """Run the compiled workflow end to end for one question.

        Documents are retrieved once here and carried in the state; the
        relevance node re-queries the retriever independently.
        """
        try:
            print(f"[DEBUG] Starting full_pipeline with question='{question}'")
            documents = retriever.invoke(question)
            logger.info(f"Retrieved {len(documents)} relevant documents (from .invoke)")

            initial_state = AgentState(
                question=question,
                documents=documents,
                draft_answer="",
                verification_report="",
                is_relevant=False,
                retriever=retriever
            )

            final_state = self.compiled_workflow.invoke(initial_state)

            return {
                "draft_answer": final_state["draft_answer"],
                "verification_report": final_state["verification_report"]
            }
        except Exception as e:
            logger.error(f"Workflow execution failed: {e}")
            raise

    def _research_step(self, state: AgentState) -> Dict:
        """Generate a draft answer from the pre-retrieved documents."""
        print(f"[DEBUG] Entered _research_step with question='{state['question']}'")
        result = self.researcher.generate(state["question"], state["documents"])
        print("[DEBUG] Researcher returned draft answer.")
        return {"draft_answer": result["draft_answer"]}

    def _verification_step(self, state: AgentState) -> Dict:
        """Verify the draft answer against the same documents."""
        print("[DEBUG] Entered _verification_step. Verifying the draft answer...")
        result = self.verifier.check(state["draft_answer"], state["documents"])
        print("[DEBUG] VerificationAgent returned a verification report.")
        return {"verification_report": result["verification_report"]}

    def _decide_next_step(self, state: AgentState) -> str:
        """Decide whether to re-run research based on the verification report."""
        verification_report = state["verification_report"]
        print(f"[DEBUG] _decide_next_step with verification_report='{verification_report}'")
        # NOTE(review): format_verification_report emits "**Supported:** NO"
        # (with Markdown bold markers), so the plain substrings below never
        # match and this branch appears unreachable — re-research is never
        # triggered. Confirm the intended matching before changing it, since
        # fixing the substring would activate the uncapped retry loop above.
        if "Supported: NO" in verification_report or "Relevant: NO" in verification_report:
            logger.info("[DEBUG] Verification indicates re-research needed.")
            return "re_research"
        else:
            logger.info("[DEBUG] Verification successful, ending workflow.")
            return "end"
app.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import hashlib
|
| 3 |
+
from typing import List, Dict
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from document_processor.file_handler import DocumentProcessor
|
| 7 |
+
from retriever.builder import RetrieverBuilder
|
| 8 |
+
from agents.workflow import AgentWorkflow
|
| 9 |
+
from config import constants, settings
|
| 10 |
+
from utils.logging import logger
|
| 11 |
+
|
| 12 |
+
# -------------------------
|
| 13 |
+
# Example Data
|
| 14 |
+
# -------------------------
|
| 15 |
+
# Canned demo inputs: each entry maps a display name to the question to ask
# and the document path(s) to load. Paths are relative to the app's working
# directory; load_example() in main() checks each with os.path.exists before use.
EXAMPLES = {
    "Google 2024 Environmental Report": {
        "question": "Retrieve the data center PUE efficiency values in Singapore 2nd facility in 2019 and 2022. Also retrieve regional average CFE in Asia pacific in 2023",
        "file_paths": ["examples/google-2024-environmental-report.pdf"]
    },
    "DeepSeek-R1 Technical Report": {
        "question": "Summarize DeepSeek-R1 model's performance evaluation on all coding tasks against OpenAI o1-mini model",
        "file_paths": ["examples/DeepSeek Technical Report.pdf"]
    }
}
|
| 25 |
+
|
| 26 |
+
# -------------------------
|
| 27 |
+
# Utils
|
| 28 |
+
# -------------------------
|
| 29 |
+
def _get_file_hashes(uploaded_files: List) -> frozenset:
    """Generate SHA-256 hashes for uploaded files.

    Used to detect whether the uploaded document set changed between
    submissions (so the retriever is only rebuilt when needed).

    Args:
        uploaded_files: Objects exposing a ``.name`` filesystem path
            (Gradio upload handles).

    Returns:
        A frozenset of hex digests — order-insensitive, hashable, and
        directly comparable against the cached set in session state.
    """
    hashes = set()
    for file in uploaded_files:
        digest = hashlib.sha256()
        with open(file.name, "rb") as f:
            # Hash in 1 MiB chunks instead of f.read(): uploads may be up to
            # the 50 MB per-file limit, so avoid loading a whole file at once.
            for chunk in iter(lambda: f.read(1 << 20), b""):
                digest.update(chunk)
        hashes.add(digest.hexdigest())
    return frozenset(hashes)
|
| 36 |
+
|
| 37 |
+
# -------------------------
|
| 38 |
+
# Main App
|
| 39 |
+
# -------------------------
|
| 40 |
+
def main():
    """Build and launch the DocChat Gradio UI.

    Wires the document processor, hybrid-retriever builder, and the
    multi-agent workflow into a Blocks app with per-session state.
    """
    # Heavy objects built once per process and captured by the closures below.
    processor = DocumentProcessor()
    retriever_builder = RetrieverBuilder()
    workflow = AgentWorkflow()

    # -------------------------
    # Custom CSS
    # -------------------------
    css = """
    .title {
        font-size: 1.5em !important;
        text-align: center !important;
        color: #FFD700;
    }
    .subtitle {
        font-size: 1em !important;
        text-align: center !important;
        color: #FFD700;
    }
    .text {
        text-align: center;
    }
    """

    # -------------------------
    # Gradio UI
    # -------------------------
    with gr.Blocks(theme=gr.themes.Citrus(), title="DocChat 🐥", css=css) as demo:
        gr.Markdown("## DocChat: powered by Docling 🐥 and LangGraph", elem_classes="subtitle")
        gr.Markdown("# How it works ✨:", elem_classes="title")
        gr.Markdown("📤 Upload your document(s), enter your query then hit Submit 📝", elem_classes="text")
        gr.Markdown("Or you can select one of the examples from the drop-down menu, select Load Example then hit Submit 📝", elem_classes="text")
        gr.Markdown("⚠️ **Note:** DocChat only accepts documents in these formats: '.pdf', '.docx', '.txt', '.md'", elem_classes="text")

        # Session state: cached file hashes + the retriever built from them,
        # so the index is only rebuilt when the uploaded set changes.
        session_state = gr.State({
            "file_hashes": frozenset(),
            "retriever": None
        })

        # -------------------------
        # Layout
        # -------------------------
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Example 📂")
                example_dropdown = gr.Dropdown(
                    label="Select an Example 🐥",
                    choices=list(EXAMPLES.keys()),
                    value=None
                )
                load_example_btn = gr.Button("Load Example 🛠️")
                files = gr.Files(label="📄 Upload Documents", file_types=constants.ALLOWED_TYPES)
                question = gr.Textbox(label="❓ Question", lines=3)
                submit_btn = gr.Button("Submit 🚀")

            with gr.Column():
                answer_output = gr.Textbox(label="🐥 Answer", interactive=False)
                verification_output = gr.Textbox(label="✅ Verification Report")

        # -------------------------
        # Load Example Function
        # -------------------------
        def load_example(example_key: str):
            """Return (file list, question text) for the chosen EXAMPLES entry.

            Missing files are logged and skipped rather than raising.
            """
            if not example_key or example_key not in EXAMPLES:
                return [], ""
            ex_data = EXAMPLES[example_key]
            file_paths = ex_data["file_paths"]
            question_text = ex_data["question"]

            loaded_files = []
            for path in file_paths:
                if os.path.exists(path):
                    loaded_files.append(path)
                else:
                    logger.warning(f"File not found: {path}")

            return loaded_files, question_text

        load_example_btn.click(
            fn=load_example,
            inputs=[example_dropdown],
            outputs=[files, question]
        )

        # -------------------------
        # Process Question
        # -------------------------
        def process_question(question_text: str, uploaded_files: List, state: Dict):
            """Validate inputs, (re)build the retriever if the uploads changed,
            run the agent workflow, and return (answer, report, state).

            Errors are surfaced in the answer box instead of raising, so the
            UI never shows a traceback.
            """
            try:
                if not question_text.strip():
                    raise ValueError("❌ Question cannot be empty")
                if not uploaded_files:
                    raise ValueError("❌ No documents uploaded")

                current_hashes = _get_file_hashes(uploaded_files)

                # Rebuild the index only when the uploaded document set changed.
                if state["retriever"] is None or current_hashes != state["file_hashes"]:
                    logger.info("Processing new/changed documents...")
                    chunks = processor.process(uploaded_files)
                    retriever = retriever_builder.build_hybrid_retriever(chunks)
                    state.update({
                        "file_hashes": current_hashes,
                        "retriever": retriever
                    })

                result = workflow.full_pipeline(
                    question=question_text,
                    retriever=state["retriever"]
                )
                return result["draft_answer"], result["verification_report"], state

            except Exception as e:
                logger.error(f"Processing error: {str(e)}")
                return f"❌ Error: {str(e)}", "", state

        submit_btn.click(
            fn=process_question,
            inputs=[question, files, session_state],
            outputs=[answer_output, verification_output, session_state]
        )

    # -------------------------
    # Hugging Face launch (no local args)
    # -------------------------
    demo.launch()

# -------------------------
# Run App
# -------------------------
if __name__ == "__main__":
    main()
|
config/__init.py__
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .settings import settings
|
| 2 |
+
from .constants import MAX_FILE_SIZE, MAX_TOTAL_SIZE, ALLOWED_TYPES
|
| 3 |
+
|
| 4 |
+
__all__ = ["settings", "MAX_FILE_SIZE", "MAX_TOTAL_SIZE", "ALLOWED_TYPES"]
|
config/__pycache__/llm_config.cpython-313.pyc
ADDED
|
Binary file (6.21 kB). View file
|
|
|
config/constants.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Maximum allowed size for a single file (50 MB)
|
| 2 |
+
MAX_FILE_SIZE: int = 50 * 1024 * 1024
|
| 3 |
+
|
| 4 |
+
# Maximum allowed total size for all uploaded files (200 MB)
|
| 5 |
+
MAX_TOTAL_SIZE: int = 200 * 1024 * 1024
|
| 6 |
+
|
| 7 |
+
# Allowed file types for upload
|
| 8 |
+
ALLOWED_TYPES: list = [".txt", ".pdf", ".docx", ".md"]
|
config/llm_config.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM Configuration Manager
|
| 3 |
+
Centralizes all LLM model configurations for easy switching
|
| 4 |
+
"""
|
| 5 |
+
from typing import Dict, Any, Optional
|
| 6 |
+
from enum import Enum
|
| 7 |
+
import os
|
| 8 |
+
from google import genai
|
| 9 |
+
from google.genai import types
|
| 10 |
+
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
|
| 11 |
+
import logging
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
class ModelProvider(Enum):
    """Supported LLM providers."""
    GOOGLE = "google"
    OPENAI = "openai"
    ANTHROPIC = "anthropic"


class LLMConfig:
    """Configuration manager for LLM models.

    Maps each task ("research", "verification", "relevance", "embedding")
    to a provider-specific model name plus default sampling parameters,
    and builds the corresponding LangChain / client objects.
    """

    # Model names per provider and task.
    # NOTE(review): ANTHROPIC is declared in ModelProvider but has no entry
    # here (nor an API-key mapping), so LLMConfig(ModelProvider.ANTHROPIC)
    # raises ValueError in _get_api_key.
    MODELS = {
        ModelProvider.GOOGLE: {
            "research": "gemini-1.5-pro",
            "verification": "gemini-1.5-flash",
            "relevance": "gemini-1.5-flash",
            "embedding": "text-embedding-004",
        },
        ModelProvider.OPENAI: {
            "research": "gpt-4-turbo",
            "verification": "gpt-4-turbo",
            "relevance": "gpt-4-turbo",
            "embedding": "text-embedding-3-large",
        }
    }

    # Default generation parameters per task. There is deliberately no
    # "embedding" entry: embeddings take no sampling parameters, and
    # get_model_params() falls back to an empty dict.
    DEFAULT_PARAMS = {
        "research": {
            "temperature": 0.3,
            "max_tokens": 300,
            "top_p": 0.95,
        },
        "verification": {
            "temperature": 0.0,
            "max_tokens": 200,
            "top_p": 0.9,
        },
        "relevance": {
            "temperature": 0.0,
            "max_tokens": 10,
            "top_p": 0.9,
        }
    }

    # Environment variable holding the API key for each provider.
    _API_KEY_ENV = {
        ModelProvider.GOOGLE: "GOOGLE_API_KEY",
        ModelProvider.OPENAI: "OPENAI_API_KEY",
    }

    def __init__(self, provider: ModelProvider = ModelProvider.GOOGLE):
        """
        Initialize LLM configuration with specified provider.

        Args:
            provider: Model provider to use (default: Google)

        Raises:
            ValueError: if the provider is unsupported or its API key
                environment variable is unset.
        """
        self.provider = provider
        self.api_key = self._get_api_key()
        self._validate_config()

    def _get_api_key(self) -> str:
        """Read the provider's API key from the environment.

        Raises:
            ValueError: if the provider has no key mapping or the key is unset.
        """
        env_var = self._API_KEY_ENV.get(self.provider)
        if env_var is None:
            raise ValueError(f"Unsupported provider: {self.provider}")
        key = os.getenv(env_var)
        if not key:
            raise ValueError(f"{env_var} environment variable is required")
        return key

    def _validate_config(self):
        """Ensure the provider has a model table configured."""
        if self.provider not in self.MODELS:
            raise ValueError(f"Provider {self.provider} not configured")

    def get_model_name(self, task: str) -> str:
        """Return the model name configured for *task* under the current provider.

        Raises:
            ValueError: if the task is not configured for this provider.
        """
        if task not in self.MODELS[self.provider]:
            raise ValueError(f"Task {task} not configured for provider {self.provider}")
        return self.MODELS[self.provider][task]

    def get_model_params(self, task: str) -> Dict[str, Any]:
        """Return a fresh copy of the default parameters for *task*.

        Returns an empty dict for tasks with no parameter defaults
        (e.g. "embedding"). The copy keeps callers from mutating
        DEFAULT_PARAMS in place.
        """
        return self.DEFAULT_PARAMS.get(task, {}).copy()

    def create_llm(self, task: str):
        """Create a chat-model instance for *task*.

        Raises:
            ValueError: for providers without an implementation (OPENAI
                is declared in MODELS but not yet wired up here).
        """
        model_name = self.get_model_name(task)
        params = self.get_model_params(task)

        if self.provider == ModelProvider.GOOGLE:
            return ChatGoogleGenerativeAI(
                model=model_name,
                google_api_key=self.api_key,
                temperature=params.get("temperature", 0.3),
                max_tokens=params.get("max_tokens", None),
                top_p=params.get("top_p", 0.95),
            )
        elif self.provider == ModelProvider.OPENAI:
            # Would use ChatOpenAI here; falls through to the raise below.
            pass

        raise ValueError(f"Provider {self.provider} not implemented")

    def create_embedding(self):
        """Create an embedding-model instance for the current provider.

        Raises:
            ValueError: for providers without an implementation.
        """
        if self.provider == ModelProvider.GOOGLE:
            # Use the configured embedding model instead of a hard-coded
            # name so MODELS stays the single source of truth. The Google
            # API expects the "models/<name>" prefix.
            return GoogleGenerativeAIEmbeddings(
                model=f"models/{self.get_model_name('embedding')}",
                google_api_key=self.api_key
            )
        elif self.provider == ModelProvider.OPENAI:
            # Would use OpenAIEmbeddings here; falls through to the raise below.
            pass

        raise ValueError(f"Provider {self.provider} not implemented")

    def create_direct_client(self):
        """Create a raw provider SDK client, or None if the provider needs none."""
        if self.provider == ModelProvider.GOOGLE:
            client = genai.Client(api_key=self.api_key)
            return client
        return None
| 140 |
+
|
| 141 |
+
# Global configuration instance, shared across the app.
# NOTE(review): constructed at import time, so merely importing this module
# raises ValueError unless GOOGLE_API_KEY is set in the environment/.env.
llm_config = LLMConfig()
|
config/settings.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from pydantic_settings import BaseSettings
|
| 3 |
+
from .constants import MAX_FILE_SIZE, MAX_TOTAL_SIZE, ALLOWED_TYPES
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
class Settings(BaseSettings):
    """Application configuration, loaded from the environment and .env file."""

    # LLM Provider settings
    LLM_PROVIDER: str = "google" # "google" or "openai"

    # API Keys — GOOGLE_API_KEY has no default, so Settings() raises a
    # pydantic validation error when it is missing from the environment/.env.
    GOOGLE_API_KEY: str
    OPENAI_API_KEY: str = ""

    # Upload limits, defaulted from config.constants
    MAX_FILE_SIZE: int = MAX_FILE_SIZE
    MAX_TOTAL_SIZE: int = MAX_TOTAL_SIZE
    ALLOWED_TYPES: list[str] = ALLOWED_TYPES

    # Database settings
    CHROMA_DB_PATH: str = "./chroma_db"
    CHROMA_COLLECTION_NAME: str = "documents"

    # Retrieval settings
    VECTOR_SEARCH_K: int = 10
    # [BM25 weight, vector weight] for the ensemble retriever
    HYBRID_RETRIEVER_WEIGHTS: list[float] = [0.4, 0.6]

    # Logging settings
    LOG_LEVEL: str = "INFO"

    # Cache settings
    CACHE_DIR: str = "document_cache"
    CACHE_EXPIRE_DAYS: int = 7

    class Config:
        # Read variables from a .env file at the project root.
        env_file = ".env"
        env_file_encoding = "utf-8"

# Singleton settings instance, created at import time.
settings = Settings()
|
config/test.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test file for LLMConfig
|
| 3 |
+
Run: python config/test.py
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from llm_config import LLMConfig, ModelProvider
|
| 7 |
+
|
| 8 |
+
def test_basic_config():
    """Smoke-test LLMConfig construction with the Google provider."""
    print("🔹 Testing basic configuration...")
    cfg = LLMConfig(provider=ModelProvider.GOOGLE)
    print("Provider:", cfg.provider.value)
    print("API Key loaded: ✅")
|
| 13 |
+
|
| 14 |
+
def test_model_names():
    """Check that each chat task resolves to a configured model name."""
    print("\n🔹 Testing model name resolution...")
    cfg = LLMConfig()
    print("Research model:", cfg.get_model_name("research"))
    print("Verification model:", cfg.get_model_name("verification"))
    print("Relevance model:", cfg.get_model_name("relevance"))
|
| 20 |
+
|
| 21 |
+
def test_llm_creation():
    """Verify a chat-LLM instance can be built for the research task."""
    print("\n🔹 Testing LLM creation...")
    cfg = LLMConfig()
    research_llm = cfg.create_llm("research")
    print("LLM instance created:", type(research_llm))
|
| 26 |
+
|
| 27 |
+
def test_embedding_creation():
    """Verify an embedding-model instance can be built."""
    print("\n🔹 Testing embedding creation...")
    cfg = LLMConfig()
    emb = cfg.create_embedding()
    print("Embedding instance created:", type(emb))
+
|
| 33 |
+
def test_direct_client():
    """Verify the raw Gemini SDK client can be constructed."""
    print("\n🔹 Testing direct Gemini client...")
    cfg = LLMConfig()
    sdk_client = cfg.create_direct_client()
    print("Direct client created:", type(sdk_client))
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
    # Run every check in order; the first failure aborts the suite.
    checks = (
        test_basic_config,
        test_model_names,
        test_llm_creation,
        test_embedding_creation,
        test_direct_client,
    )
    try:
        for check in checks:
            check()
        print("\n✅ ALL TESTS PASSED")
    except Exception as e:
        print("\n❌ TEST FAILED")
        print("Error:", e)
|
document_processor/__init.py__
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .file_handler import DocumentProcessor
|
| 2 |
+
|
| 3 |
+
__all__ = ["DocumentProcessor"]
|
document_processor/file_handler.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import hashlib
|
| 3 |
+
import pickle
|
| 4 |
+
from datetime import datetime, timedelta
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import List
|
| 7 |
+
from docling.document_converter import DocumentConverter
|
| 8 |
+
from langchain_text_splitters import MarkdownHeaderTextSplitter
|
| 9 |
+
from config import constants
|
| 10 |
+
from config.settings import settings
|
| 11 |
+
from utils.logging import logger
|
| 12 |
+
|
| 13 |
+
class DocumentProcessor:
    """Convert uploaded files to markdown chunks, with an on-disk cache.

    Files are converted with Docling, split on markdown headers, and the
    resulting chunks are pickled under a content-hash filename so repeat
    uploads of identical bytes skip reprocessing.
    """

    def __init__(self):
        # Split on top-level ("#") and second-level ("##") markdown headers.
        self.headers = [("#", "Header 1"), ("##", "Header 2")]
        self.cache_dir = Path(settings.CACHE_DIR)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def validate_files(self, files: List) -> None:
        """Validate individual and combined sizes of the uploaded files.

        Raises:
            ValueError: if a single file exceeds MAX_FILE_SIZE, or the
                combined upload exceeds MAX_TOTAL_SIZE.
        """
        total_size = 0
        for f in files:
            size = os.path.getsize(f.name)
            # Enforce the per-file limit from config.constants (previously
            # declared but never checked anywhere).
            if size > constants.MAX_FILE_SIZE:
                raise ValueError(
                    f"{f.name} exceeds {constants.MAX_FILE_SIZE//1024//1024}MB per-file limit"
                )
            total_size += size
        if total_size > constants.MAX_TOTAL_SIZE:
            raise ValueError(f"Total size exceeds {constants.MAX_TOTAL_SIZE//1024//1024}MB limit")

    def process(self, files: List) -> List:
        """Process files with caching and return deduplicated chunks.

        Returns:
            A list of unique chunks across all files. Files that fail to
            process are logged and skipped (best-effort batch).
        """
        self.validate_files(files)
        all_chunks = []
        seen_hashes = set()

        for file in files:
            try:
                # Key the cache on the file's *content*, not its name, so a
                # renamed duplicate still hits the cache.
                with open(file.name, "rb") as f:
                    file_hash = self._generate_hash(f.read())

                cache_path = self.cache_dir / f"{file_hash}.pkl"

                if self._is_cache_valid(cache_path):
                    logger.info(f"Loading from cache: {file.name}")
                    chunks = self._load_from_cache(cache_path)
                else:
                    logger.info(f"Processing and caching: {file.name}")
                    chunks = self._process_file(file)
                    self._save_to_cache(chunks, cache_path)

                # Deduplicate chunks across files by content hash.
                for chunk in chunks:
                    chunk_hash = self._generate_hash(chunk.page_content.encode())
                    if chunk_hash not in seen_hashes:
                        all_chunks.append(chunk)
                        seen_hashes.add(chunk_hash)

            except Exception as e:
                # One bad file must not abort the whole batch.
                logger.error(f"Failed to process {file.name}: {str(e)}")
                continue

        logger.info(f"Total unique chunks: {len(all_chunks)}")
        return all_chunks

    def _process_file(self, file) -> List:
        """Convert one file to markdown with Docling and split it on headers."""
        # Use the shared ALLOWED_TYPES list instead of a duplicated literal
        # so the supported extensions stay in one place.
        if not file.name.endswith(tuple(constants.ALLOWED_TYPES)):
            logger.warning(f"Skipping unsupported file type: {file.name}")
            return []

        converter = DocumentConverter()
        markdown = converter.convert(file.name).document.export_to_markdown()
        splitter = MarkdownHeaderTextSplitter(self.headers)
        return splitter.split_text(markdown)

    def _generate_hash(self, content: bytes) -> str:
        """Return the SHA-256 hex digest of *content*."""
        return hashlib.sha256(content).hexdigest()

    def _save_to_cache(self, chunks: List, cache_path: Path) -> None:
        """Pickle *chunks* (plus a creation timestamp) to *cache_path*."""
        # NOTE(review): the stored timestamp is currently unused;
        # _is_cache_valid relies on the file's mtime instead.
        with open(cache_path, "wb") as f:
            pickle.dump({
                "timestamp": datetime.now().timestamp(),
                "chunks": chunks
            }, f)

    def _load_from_cache(self, cache_path: Path) -> List:
        """Unpickle cached chunks from *cache_path*.

        The cache directory is written only by this process; never point it
        at untrusted data, since pickle.load executes arbitrary code.
        """
        with open(cache_path, "rb") as f:
            data = pickle.load(f)
        return data["chunks"]

    def _is_cache_valid(self, cache_path: Path) -> bool:
        """Return True if the cache file exists and is younger than the expiry."""
        if not cache_path.exists():
            return False

        cache_age = datetime.now() - datetime.fromtimestamp(cache_path.stat().st_mtime)
        return cache_age < timedelta(days=settings.CACHE_EXPIRE_DAYS)
|
requirements.txt
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core Python
|
| 2 |
+
python-dotenv==1.0.1
|
| 3 |
+
pydantic==2.10.6
|
| 4 |
+
pydantic-settings==2.7.1
|
| 5 |
+
typing-extensions==4.12.2
|
| 6 |
+
|
| 7 |
+
# Web Framework
|
| 8 |
+
fastapi==0.115.7
|
| 9 |
+
uvicorn[standard]==0.34.0
|
| 10 |
+
gradio==5.13.2
|
| 11 |
+
|
| 12 |
+
# LangChain Core
|
| 13 |
+
langchain==0.3.16
|
| 14 |
+
langchain-core==0.3.32
|
| 15 |
+
langchain-community==0.3.16
|
| 16 |
+
langchain-text-splitters==0.3.5
|
| 17 |
+
langgraph==0.2.68
|
| 18 |
+
|
| 19 |
+
# LLM Providers
|
| 20 |
+
langchain-google-genai==2.1.2
|
| 21 |
+
google-generativeai==0.8.4
|
| 22 |
+
langchain-openai==0.3.2
|
| 23 |
+
openai==1.60.2
|
| 24 |
+
|
| 25 |
+
# Embeddings & Vector Stores
|
| 26 |
+
chromadb==0.6.3
|
| 27 |
+
langchain-chroma==0.2.4
|
| 28 |
+
sentence-transformers==3.0.1
|
| 29 |
+
|
| 30 |
+
# Document Processing
|
| 31 |
+
docling==2.15.0
|
| 32 |
+
pypdf==5.2.0
|
| 33 |
+
python-docx==1.1.2
|
| 34 |
+
markdown==3.6
|
| 35 |
+
beautifulsoup4==4.12.3
|
| 36 |
+
lxml==5.3.0
|
| 37 |
+
|
| 38 |
+
# Text Processing & Retrieval
|
| 39 |
+
rank-bm25==0.2.2
|
| 40 |
+
nltk==3.9.1
|
| 41 |
+
scikit-learn==1.6.0
|
| 42 |
+
numpy==1.26.4
|
| 43 |
+
|
| 44 |
+
# Caching & Hashing
|
| 45 |
+
cachetools==5.5.1
|
| 46 |
+
|
| 47 |
+
# Logging
|
| 48 |
+
loguru==0.7.3
|
| 49 |
+
|
| 50 |
+
# Utilities
|
| 51 |
+
python-multipart==0.0.20
|
| 52 |
+
aiofiles==23.2.1
|
| 53 |
+
pillow==10.4.0
|
| 54 |
+
tqdm==4.67.1
|
| 55 |
+
tenacity==9.0.0
|
| 56 |
+
backoff==2.2.1
|
| 57 |
+
httpx==0.28.1
|
| 58 |
+
requests==2.32.3
|
| 59 |
+
orjson==3.10.15
|
| 60 |
+
|
| 61 |
+
# Development & Testing
|
| 62 |
+
pytest==8.3.4
|
| 63 |
+
pytest-asyncio==0.23.7
|
| 64 |
+
black==24.10.0
|
| 65 |
+
isort==5.13.2
|
| 66 |
+
mypy==1.13.0
|
| 67 |
+
ruff==0.9.3
|
retriever/__init.py__
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .builder import RetrieverBuilder
|
| 2 |
+
|
| 3 |
+
__all__ = ["RetrieverBuilder"]
|
retriever/builder.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.vectorstores import Chroma
|
| 2 |
+
from langchain_community.retrievers import BM25Retriever
|
| 3 |
+
from langchain.retrievers import EnsembleRetriever
|
| 4 |
+
from config.settings import settings
|
| 5 |
+
from config.llm_config import llm_config
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
logger = logging.getLogger(__name__)
|
| 9 |
+
|
| 10 |
+
class RetrieverBuilder:
    """Builds a hybrid (keyword + dense-vector) retriever over a document set."""

    def __init__(self):
        """Initialize the retriever builder with embeddings."""
        logger.info("Initializing RetrieverBuilder...")

        # The embedding model comes from the central LLM configuration.
        self.embeddings = llm_config.create_embedding()

        logger.info("RetrieverBuilder initialized successfully.")

    def build_hybrid_retriever(self, docs):
        """Build a hybrid retriever using BM25 and vector-based retrieval."""
        try:
            logger.info(f"Building hybrid retriever with {len(docs)} documents")

            # Dense side: embed every document into a persistent Chroma store.
            store = Chroma.from_documents(
                documents=docs,
                embedding=self.embeddings,
                persist_directory=settings.CHROMA_DB_PATH,
                collection_name=settings.CHROMA_COLLECTION_NAME
            )
            logger.info("Vector store created successfully.")

            # Sparse side: classic keyword matching over the same documents.
            keyword_retriever = BM25Retriever.from_documents(docs)
            logger.info("BM25 retriever created successfully.")

            # Expose the vector store as a retriever returning the top K hits.
            dense_retriever = store.as_retriever(
                search_kwargs={"k": settings.VECTOR_SEARCH_K}
            )
            logger.info("Vector retriever created successfully.")

            # Blend both with the configured weights (BM25 first, dense second).
            ensemble = EnsembleRetriever(
                retrievers=[keyword_retriever, dense_retriever],
                weights=settings.HYBRID_RETRIEVER_WEIGHTS
            )
            logger.info("Hybrid retriever created successfully.")

            return ensemble

        except Exception as e:
            logger.error(f"Failed to build hybrid retriever: {e}")
            raise
|
utils/__init.py__
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .logging import logger
|
| 2 |
+
|
| 3 |
+
__all__ = ["logger"]
|
utils/logging.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from loguru import logger

# Mirror all log output to app.log, rotating whenever the file reaches
# 10 MB and deleting rotated files older than 30 days.
logger.add(
    "app.log",
    rotation="10 MB",
    retention="30 days",
    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"
)
|