MaheshLEO4 commited on
Commit
b7e0e53
·
0 Parent(s):

Initial commit for DocChat

Browse files
.env ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # LLM Configuration
2
+ LLM_PROVIDER=google # google or openai
3
+
4
+ # API Keys
5
+ GOOGLE_API_KEY="REDACTED_LEAKED_KEY_ROTATE_IMMEDIATELY"  # SECURITY: a real API key was committed here — revoke/rotate it and supply it via environment secrets; never commit .env files
6
+ OPENAI_API_KEY="your_openai_api_key_here"
7
+
8
+ # Database Settings
9
+ CHROMA_DB_PATH=./chroma_db
10
+
11
+ # Retrieval Settings
12
+ VECTOR_SEARCH_K=10
13
+
14
+ # Cache Settings
15
+ CACHE_EXPIRE_DAYS=7
agents/__init.py__ ADDED (NOTE: misnamed — Python packages require "__init__.py"; as committed, "from agents import ..." will fail; the same typo recurs in config/ and document_processor/)
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .research_agent import ResearchAgent
2
+ from .verification_agent import VerificationAgent
3
+ from .workflow import AgentWorkflow
4
+
5
+ __all__ = ["ResearchAgent", "VerificationAgent", "AgentWorkflow"]
agents/relevance_checker.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ from langchain.schema import BaseRetriever
3
+ from langchain.prompts import ChatPromptTemplate
4
+ from langchain_core.output_parsers import StrOutputParser
5
+ from config.llm_config import llm_config
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
class RelevanceChecker:
    """Classifies how well retrieved document chunks cover a user question.

    Produces one of three labels — CAN_ANSWER, PARTIAL, NO_MATCH — which the
    workflow uses as a gate before running the research agent.
    """

    def __init__(self):
        """Initialize the relevance checker with configurable LLM."""
        logger.info("Initializing RelevanceChecker...")

        # Get LLM from configuration (low-temperature "relevance" profile)
        self.llm = llm_config.create_llm("relevance")

        # Create prompt template
        self.prompt_template = ChatPromptTemplate.from_messages([
            ("system", """You are an AI relevance checker between a user's question and provided document content.

Instructions:
- Classify how well the document content addresses the user's question.
- Respond with ONLY ONE of the following labels: CAN_ANSWER, PARTIAL, NO_MATCH.
- Do not include any additional text or explanation.

Label Definitions:
1) "CAN_ANSWER": The passages contain enough explicit information to fully answer the question.
2) "PARTIAL": The passages mention or discuss the question's topic but do not provide all the details needed for a complete answer.
3) "NO_MATCH": The passages do not discuss or mention the question's topic at all.

Important: If the passages mention or reference the topic or timeframe of the question in any way, even if incomplete, respond with "PARTIAL" instead of "NO_MATCH"."""),
            ("human", """Question: {question}

Passages:
{passages}

Respond ONLY with one of the following labels: CAN_ANSWER, PARTIAL, NO_MATCH""")
        ])

        # Create chain: prompt -> LLM -> plain-string output
        self.chain = self.prompt_template | self.llm | StrOutputParser()

        logger.info("RelevanceChecker initialized successfully.")

    def check(self, question: str, retriever: BaseRetriever, k: int = 3) -> str:
        """
        Check relevance between question and retrieved documents.

        Args:
            question: The user's question.
            retriever: Retriever queried for candidate chunks.
            k: Maximum number of retrieved chunks to include in the prompt.

        Returns:
            "CAN_ANSWER", "PARTIAL", or "NO_MATCH". Any retrieval or LLM
            failure degrades to "NO_MATCH" rather than raising.
        """
        logger.debug(f"RelevanceChecker.check called with question='{question}' and k={k}")

        # Retrieve document chunks
        try:
            top_docs = retriever.invoke(question)
        except Exception as e:
            logger.error(f"Error retrieving documents: {e}")
            return "NO_MATCH"

        if not top_docs:
            logger.debug("No documents returned from retriever.")
            return "NO_MATCH"

        # Combine the top k chunk texts
        document_content = "\n\n".join(doc.page_content for doc in top_docs[:k])
        logger.debug(f"Combined document content length: {len(document_content)} characters")

        try:
            # Get classification from LLM
            response = self.chain.invoke({
                "question": question,
                "passages": document_content
            })

            # Normalize, then validate. FIX: models sometimes wrap the label
            # ("Label: CAN_ANSWER.") despite the instructions, so exact
            # equality alone misclassified valid answers as NO_MATCH —
            # fall back to searching the response for a known label.
            raw = response.strip().upper()
            valid_labels = ("CAN_ANSWER", "PARTIAL", "NO_MATCH")

            if raw in valid_labels:
                classification = raw
            else:
                classification = next((label for label in valid_labels if label in raw), None)
                if classification is None:
                    logger.warning(f"Invalid classification received: '{raw}'. Defaulting to NO_MATCH.")
                    classification = "NO_MATCH"

            logger.debug(f"Classification: {classification}")
            return classification

        except Exception as e:
            logger.error(f"Error during relevance classification: {e}")
            return "NO_MATCH"
agents/research_agent.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ from langchain.schema import Document
3
+ from langchain.prompts import ChatPromptTemplate
4
+ from langchain_core.output_parsers import StrOutputParser
5
+ from config.llm_config import llm_config
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
class ResearchAgent:
    """Drafts an answer to a question using only retrieved document context."""

    def __init__(self):
        """
        Initialize the research agent with configurable LLM.
        """
        logger.info("Initializing ResearchAgent...")

        # Get LLM from configuration ("research" profile)
        self.llm = llm_config.create_llm("research")
        # NOTE(review): this direct client is never used inside this class —
        # confirm no external caller relies on `agent.client` before removing.
        self.client = llm_config.create_direct_client()

        # Create prompt template: the system message constrains the model to
        # the supplied context only (no outside knowledge).
        self.prompt_template = ChatPromptTemplate.from_messages([
            ("system", """You are an AI assistant designed to provide precise and factual answers based on the given context.

Instructions:
- Answer the following question using only the provided context.
- Be clear, concise, and factual.
- Return as much information as you can get from the context.
- If the context doesn't contain enough information, say so explicitly.
- Do not add any information not present in the context.
- Format your answer in a clear, readable manner."""),
            ("human", """Question: {question}

Context:
{context}

Provide your answer below:""")
        ])

        # Create chain: prompt -> LLM -> plain-string output
        self.chain = self.prompt_template | self.llm | StrOutputParser()

        logger.info("ResearchAgent initialized successfully.")

    def generate(self, question: str, documents: List[Document]) -> Dict:
        """
        Generate an initial answer using the provided documents.

        Args:
            question: The user's question.
            documents: Retrieved chunks; their page_content is concatenated
                into the prompt context.

        Returns:
            Dict with "draft_answer" (stripped LLM output, or an apology
            string embedding the error on failure — this method never raises)
            and "context_used" (the concatenated context).
        """
        logger.info(f"ResearchAgent.generate called with question='{question}' and {len(documents)} documents.")

        # Combine the document contents
        context = "\n\n".join([doc.page_content for doc in documents])
        logger.debug(f"Combined context length: {len(context)} characters.")

        try:
            # Generate answer using LangChain chain
            draft_answer = self.chain.invoke({
                "question": question,
                "context": context
            })

            logger.info(f"Generated answer successfully. Length: {len(draft_answer)} characters.")

            return {
                "draft_answer": draft_answer.strip(),
                "context_used": context
            }

        except Exception as e:
            logger.error(f"Error during answer generation: {e}")
            return {
                "draft_answer": f"I cannot answer this question based on the provided documents. Error: {str(e)}",
                "context_used": context
            }
agents/verification_agent.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List
2
+ from langchain.schema import Document
3
+ from langchain.prompts import ChatPromptTemplate
4
+ from langchain_core.output_parsers import StrOutputParser
5
+ from config.llm_config import llm_config
6
+ import logging
7
+ import json
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
class VerificationAgent:
    """Checks a drafted answer against the source context and emits a
    markdown verification report (parsed from the LLM's JSON verdict)."""

    def __init__(self):
        """
        Initialize the verification agent with configurable LLM.
        """
        logger.info("Initializing VerificationAgent...")

        # Get LLM from configuration (deterministic "verification" profile)
        self.llm = llm_config.create_llm("verification")

        # Create prompt template for verification.
        # FIX: the literal JSON example must use doubled braces ({{ }}) —
        # ChatPromptTemplate treats single braces as template variables, so the
        # original single-brace JSON block broke prompt formatting at runtime.
        self.prompt_template = ChatPromptTemplate.from_messages([
            ("system", """You are an AI assistant designed to verify the accuracy and relevance of answers based on provided context.

You MUST respond in the exact JSON format specified below.

Instructions:
- Verify the answer against the provided context.
- Check for:
  1. Direct/indirect factual support (YES/NO)
  2. Unsupported claims (list any if present)
  3. Contradictions (list any if present)
  4. Relevance to the question (YES/NO)
- Provide additional details or explanations where relevant.
- If there are no unsupported claims or contradictions, use empty lists.
- If there are no additional details, use an empty string.

JSON Response Format:
{{
    "supported": "YES" or "NO",
    "unsupported_claims": ["claim1", "claim2", ...],
    "contradictions": ["contradiction1", "contradiction2", ...],
    "relevant": "YES" or "NO",
    "additional_details": "string"
}}"""),
            ("human", """Answer to verify: {answer}

Context:
{context}

Provide your verification in the specified JSON format:""")
        ])

        # Create chain: prompt -> LLM -> plain-string output
        self.chain = self.prompt_template | self.llm | StrOutputParser()

        logger.info("VerificationAgent initialized successfully.")

    @staticmethod
    def _extract_json(response: str) -> str:
        """Return the JSON object embedded in an LLM response.

        Models frequently wrap JSON in markdown code fences or surround it
        with prose; take the span from the first '{' to the last '}'. Falls
        back to the raw response when no brace pair is found (json.loads then
        fails and the caller's error path applies).
        """
        start = response.find("{")
        end = response.rfind("}")
        if start != -1 and end > start:
            return response[start:end + 1]
        return response

    def format_verification_report(self, verification: Dict) -> str:
        """
        Format the verification report dictionary into a readable paragraph.

        Missing keys default to the pessimistic values ("NO", empty lists),
        so a partially-formed verdict still renders.
        """
        supported = verification.get("supported", "NO")
        unsupported_claims = verification.get("unsupported_claims", [])
        contradictions = verification.get("contradictions", [])
        relevant = verification.get("relevant", "NO")
        additional_details = verification.get("additional_details", "")

        report = f"**Supported:** {supported}\n"

        if unsupported_claims:
            report += f"**Unsupported Claims:** {', '.join(unsupported_claims)}\n"
        else:
            report += f"**Unsupported Claims:** None\n"

        if contradictions:
            report += f"**Contradictions:** {', '.join(contradictions)}\n"
        else:
            report += f"**Contradictions:** None\n"

        report += f"**Relevant:** {relevant}\n"

        if additional_details:
            report += f"**Additional Details:** {additional_details}\n"
        else:
            report += f"**Additional Details:** None\n"

        return report

    def check(self, answer: str, documents: List[Document]) -> Dict:
        """
        Verify the answer against the provided documents.

        Args:
            answer: The draft answer to verify.
            documents: Source chunks forming the verification context.

        Returns:
            Dict with "verification_report" (markdown) and "context_used".
            Never raises: JSON-parse failures yield a pessimistic verdict and
            LLM failures yield an error-message report.
        """
        logger.info(f"VerificationAgent.check called with answer length={len(answer)} and {len(documents)} documents.")

        # Combine all document contents
        context = "\n\n".join([doc.page_content for doc in documents])
        logger.debug(f"Combined context length: {len(context)} characters.")

        try:
            # Get verification from LLM
            response = self.chain.invoke({
                "answer": answer,
                "context": context
            })

            # Parse JSON response. FIX: strip code fences / surrounding prose
            # first — raw json.loads rejected fenced responses.
            try:
                verification = json.loads(self._extract_json(response))
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse JSON response: {e}")
                verification = {
                    "supported": "NO",
                    "unsupported_claims": [],
                    "contradictions": [],
                    "relevant": "NO",
                    "additional_details": "Failed to parse verification response."
                }

            # Format report
            verification_report = self.format_verification_report(verification)
            logger.info("Verification completed successfully.")

            return {
                "verification_report": verification_report,
                "context_used": context
            }

        except Exception as e:
            logger.error(f"Error during verification: {e}")
            return {
                "verification_report": f"**Error during verification:** {str(e)}",
                "context_used": context
            }
agents/workflow.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.graph import StateGraph, END
2
+ from typing import TypedDict, List, Dict
3
+ from .research_agent import ResearchAgent
4
+ from .verification_agent import VerificationAgent
5
+ from .relevance_checker import RelevanceChecker
6
+ from langchain.schema import Document
7
+ from langchain.retrievers import EnsembleRetriever
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class AgentState(TypedDict):
    # Shared state flowing between LangGraph nodes.
    question: str                  # user's query
    documents: List[Document]      # chunks retrieved up-front in full_pipeline
    draft_answer: str              # written by the research node (or the NO_MATCH short-circuit)
    verification_report: str       # written by the verify node
    is_relevant: bool              # set by the relevance-check node
    retriever: EnsembleRetriever   # hybrid retriever; relevance check re-queries it
19
+
20
class AgentWorkflow:
    """Orchestrates relevance-check -> research -> verification as a compiled
    LangGraph state machine."""

    def __init__(self):
        self.researcher = ResearchAgent()
        self.verifier = VerificationAgent()
        self.relevance_checker = RelevanceChecker()
        # Compile once during initialization; each question only pays for .invoke().
        self.compiled_workflow = self.build_workflow()

    def build_workflow(self):
        """Create and compile the multi-agent workflow graph."""
        workflow = StateGraph(AgentState)

        # Nodes
        workflow.add_node("check_relevance", self._check_relevance_step)
        workflow.add_node("research", self._research_step)
        workflow.add_node("verify", self._verification_step)

        # Edges: relevance gate first, then research -> verify, with a
        # verification-driven retry loop back into research.
        workflow.set_entry_point("check_relevance")
        workflow.add_conditional_edges(
            "check_relevance",
            self._decide_after_relevance_check,
            {
                "relevant": "research",
                "irrelevant": END
            }
        )
        workflow.add_edge("research", "verify")
        # NOTE(review): "re_research" re-runs research with the SAME documents,
        # so a persistently failing verification cycles until LangGraph's
        # recursion limit aborts the run — consider a retry counter in state.
        workflow.add_conditional_edges(
            "verify",
            self._decide_next_step,
            {
                "re_research": "research",
                "end": END
            }
        )
        return workflow.compile()

    def _check_relevance_step(self, state: AgentState) -> Dict:
        """Classify whether the indexed corpus can address the question."""
        classification = self.relevance_checker.check(
            question=state["question"],
            retriever=state["retriever"],
            k=20
        )

        # Full or partial coverage — either way we proceed to research.
        if classification in ("CAN_ANSWER", "PARTIAL"):
            return {"is_relevant": True}

        # NO_MATCH: short-circuit the workflow with a canned answer.
        return {
            "is_relevant": False,
            "draft_answer": "This question isn't related (or there's no data) for your query. Please ask another question relevant to the uploaded document(s)."
        }

    def _decide_after_relevance_check(self, state: AgentState) -> str:
        """Route to research when relevant, otherwise end the workflow."""
        decision = "relevant" if state["is_relevant"] else "irrelevant"
        # Consistency fix: use the module logger instead of bare print().
        logger.debug(f"_decide_after_relevance_check -> {decision}")
        return decision

    def full_pipeline(self, question: str, retriever: EnsembleRetriever):
        """Run the compiled workflow for one question.

        Args:
            question: The user's question.
            retriever: Hybrid retriever over the uploaded documents.

        Returns:
            Dict with "draft_answer" and "verification_report".

        Raises:
            Re-raises any failure from retrieval or graph execution after logging.
        """
        try:
            logger.debug(f"Starting full_pipeline with question='{question}'")
            documents = retriever.invoke(question)
            logger.info(f"Retrieved {len(documents)} relevant documents (from .invoke)")

            initial_state = AgentState(
                question=question,
                documents=documents,
                draft_answer="",
                verification_report="",
                is_relevant=False,
                retriever=retriever
            )

            final_state = self.compiled_workflow.invoke(initial_state)

            return {
                "draft_answer": final_state["draft_answer"],
                "verification_report": final_state["verification_report"]
            }
        except Exception as e:
            logger.error(f"Workflow execution failed: {e}")
            raise

    def _research_step(self, state: AgentState) -> Dict:
        """Generate a draft answer from the retrieved documents."""
        logger.debug(f"Entered _research_step with question='{state['question']}'")
        result = self.researcher.generate(state["question"], state["documents"])
        logger.debug("Researcher returned draft answer.")
        return {"draft_answer": result["draft_answer"]}

    def _verification_step(self, state: AgentState) -> Dict:
        """Verify the draft answer against the retrieved documents."""
        logger.debug("Entered _verification_step. Verifying the draft answer...")
        result = self.verifier.check(state["draft_answer"], state["documents"])
        logger.debug("VerificationAgent returned a verification report.")
        return {"verification_report": result["verification_report"]}

    def _decide_next_step(self, state: AgentState) -> str:
        """Loop back to research when verification rejects the draft.

        BUG FIX: the report is markdown produced by
        VerificationAgent.format_verification_report (e.g. "**Supported:** NO"),
        so the previous bare "Supported: NO" substring could never match and
        the re-research branch was unreachable.
        """
        verification_report = state["verification_report"]
        logger.debug(f"_decide_next_step with verification_report='{verification_report}'")
        if "**Supported:** NO" in verification_report or "**Relevant:** NO" in verification_report:
            logger.info("Verification indicates re-research needed.")
            return "re_research"
        logger.info("Verification successful, ending workflow.")
        return "end"
app.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import hashlib
3
+ from typing import List, Dict
4
+ import os
5
+
6
+ from document_processor.file_handler import DocumentProcessor
7
+ from retriever.builder import RetrieverBuilder
8
+ from agents.workflow import AgentWorkflow
9
+ from config import constants, settings
10
+ from utils.logging import logger
11
+
12
+ # -------------------------
13
+ # Example Data
14
+ # -------------------------
15
+ EXAMPLES = {
16
+ "Google 2024 Environmental Report": {
17
+ "question": "Retrieve the data center PUE efficiency values in Singapore 2nd facility in 2019 and 2022. Also retrieve regional average CFE in Asia pacific in 2023",
18
+ "file_paths": ["examples/google-2024-environmental-report.pdf"]
19
+ },
20
+ "DeepSeek-R1 Technical Report": {
21
+ "question": "Summarize DeepSeek-R1 model's performance evaluation on all coding tasks against OpenAI o1-mini model",
22
+ "file_paths": ["examples/DeepSeek Technical Report.pdf"]
23
+ }
24
+ }
25
+
26
+ # -------------------------
27
+ # Utils
28
+ # -------------------------
29
def _get_file_hashes(uploaded_files: List) -> frozenset:
    """Generate SHA-256 hashes for uploaded files.

    Args:
        uploaded_files: Gradio file objects; each exposes its temp path as `.name`.

    Returns:
        frozenset of hex digests — order-independent and hashable, so it can be
        compared against the cached set to detect document changes.
    """
    hashes = set()
    for file in uploaded_files:
        digest = hashlib.sha256()
        # Hash in chunks so large uploads never need to fit in memory at once
        # (the original read each whole file into RAM).
        with open(file.name, "rb") as f:
            while chunk := f.read(1 << 20):
                digest.update(chunk)
        hashes.add(digest.hexdigest())
    return frozenset(hashes)
36
+
37
+ # -------------------------
38
+ # Main App
39
+ # -------------------------
40
def main():
    """Build and launch the DocChat Gradio app.

    Wires together document processing, hybrid retrieval, and the multi-agent
    workflow behind a two-column UI (inputs left, answer/verification right).
    """
    # Heavy objects are created once per process and shared by all requests.
    processor = DocumentProcessor()
    retriever_builder = RetrieverBuilder()
    workflow = AgentWorkflow()

    # -------------------------
    # Custom CSS
    # -------------------------
    css = """
    .title {
        font-size: 1.5em !important;
        text-align: center !important;
        color: #FFD700;
    }
    .subtitle {
        font-size: 1em !important;
        text-align: center !important;
        color: #FFD700;
    }
    .text {
        text-align: center;
    }
    """

    # -------------------------
    # Gradio UI
    # -------------------------
    with gr.Blocks(theme=gr.themes.Citrus(), title="DocChat 🐥", css=css) as demo:
        gr.Markdown("## DocChat: powered by Docling 🐥 and LangGraph", elem_classes="subtitle")
        gr.Markdown("# How it works ✨:", elem_classes="title")
        gr.Markdown("📤 Upload your document(s), enter your query then hit Submit 📝", elem_classes="text")
        gr.Markdown("Or you can select one of the examples from the drop-down menu, select Load Example then hit Submit 📝", elem_classes="text")
        gr.Markdown("⚠️ **Note:** DocChat only accepts documents in these formats: '.pdf', '.docx', '.txt', '.md'", elem_classes="text")

        # Per-session state: file hashes identify the current document set so
        # the retriever is rebuilt only when the uploads actually change.
        session_state = gr.State({
            "file_hashes": frozenset(),
            "retriever": None
        })

        # -------------------------
        # Layout
        # -------------------------
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Example 📂")
                example_dropdown = gr.Dropdown(
                    label="Select an Example 🐥",
                    choices=list(EXAMPLES.keys()),
                    value=None
                )
                load_example_btn = gr.Button("Load Example 🛠️")
                files = gr.Files(label="📄 Upload Documents", file_types=constants.ALLOWED_TYPES)
                question = gr.Textbox(label="❓ Question", lines=3)
                submit_btn = gr.Button("Submit 🚀")

            with gr.Column():
                answer_output = gr.Textbox(label="🐥 Answer", interactive=False)
                verification_output = gr.Textbox(label="✅ Verification Report")

        # -------------------------
        # Load Example Function
        # -------------------------
        def load_example(example_key: str):
            """Populate the file and question widgets from a bundled example.

            Returns ([], "") for an unknown key; silently drops (and logs)
            example files missing from disk.
            """
            if not example_key or example_key not in EXAMPLES:
                return [], ""
            ex_data = EXAMPLES[example_key]
            file_paths = ex_data["file_paths"]
            question_text = ex_data["question"]

            loaded_files = []
            for path in file_paths:
                if os.path.exists(path):
                    loaded_files.append(path)
                else:
                    logger.warning(f"File not found: {path}")

            return loaded_files, question_text

        load_example_btn.click(
            fn=load_example,
            inputs=[example_dropdown],
            outputs=[files, question]
        )

        # -------------------------
        # Process Question
        # -------------------------
        def process_question(question_text: str, uploaded_files: List, state: Dict):
            """Validate inputs, (re)build the retriever if the documents
            changed, run the agent workflow, and return
            (answer, verification_report, updated_state).

            Errors are caught and surfaced in the answer box rather than raised.
            """
            try:
                if not question_text.strip():
                    raise ValueError("❌ Question cannot be empty")
                if not uploaded_files:
                    raise ValueError("❌ No documents uploaded")

                current_hashes = _get_file_hashes(uploaded_files)

                # Rebuild the retriever only when the uploaded set changed.
                if state["retriever"] is None or current_hashes != state["file_hashes"]:
                    logger.info("Processing new/changed documents...")
                    chunks = processor.process(uploaded_files)
                    retriever = retriever_builder.build_hybrid_retriever(chunks)
                    state.update({
                        "file_hashes": current_hashes,
                        "retriever": retriever
                    })

                result = workflow.full_pipeline(
                    question=question_text,
                    retriever=state["retriever"]
                )
                return result["draft_answer"], result["verification_report"], state

            except Exception as e:
                logger.error(f"Processing error: {str(e)}")
                return f"❌ Error: {str(e)}", "", state

        submit_btn.click(
            fn=process_question,
            inputs=[question, files, session_state],
            outputs=[answer_output, verification_output, session_state]
        )

    # -------------------------
    # Hugging Face launch (no local args)
    # -------------------------
    demo.launch()
166
+
167
+ # -------------------------
168
+ # Run App
169
+ # -------------------------
170
+ if __name__ == "__main__":
171
+ main()
config/__init.py__ ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .settings import settings
2
+ from .constants import MAX_FILE_SIZE, MAX_TOTAL_SIZE, ALLOWED_TYPES
3
+
4
+ __all__ = ["settings", "MAX_FILE_SIZE", "MAX_TOTAL_SIZE", "ALLOWED_TYPES"]
config/__pycache__/llm_config.cpython-313.pyc ADDED
Binary file (6.21 kB). View file
 
config/constants.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Maximum allowed size for a single file (50 MB)
MAX_FILE_SIZE: int = 50 * 1024 * 1024

# Maximum allowed total size for all uploaded files (200 MB)
MAX_TOTAL_SIZE: int = 200 * 1024 * 1024

# Allowed file types for upload (lowercase extensions, including the dot)
ALLOWED_TYPES: list[str] = [".txt", ".pdf", ".docx", ".md"]
config/llm_config.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Configuration Manager
3
+ Centralizes all LLM model configurations for easy switching
4
+ """
5
+ from typing import Dict, Any, Optional
6
+ from enum import Enum
7
+ import os
8
+ from google import genai
9
+ from google.genai import types
10
+ from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
11
+ import logging
12
+ from dotenv import load_dotenv
13
+ load_dotenv()
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
class ModelProvider(Enum):
    """Supported LLM providers."""
    GOOGLE = "google"
    OPENAI = "openai"
    ANTHROPIC = "anthropic"

class LLMConfig:
    """Configuration manager for LLM models.

    Maps logical tasks ("research", "verification", "relevance", "embedding")
    to provider-specific model names and default sampling parameters.
    """

    # Model configurations for different tasks, per provider.
    # NOTE(review): ANTHROPIC is declared in ModelProvider but has no entry
    # here, so selecting it fails _validate_config — confirm whether that
    # provider is planned or should be dropped.
    MODELS = {
        ModelProvider.GOOGLE: {
            "research": "gemini-1.5-pro",
            "verification": "gemini-1.5-flash",
            "relevance": "gemini-1.5-flash",
            "embedding": "text-embedding-004",
        },
        ModelProvider.OPENAI: {
            "research": "gpt-4-turbo",
            "verification": "gpt-4-turbo",
            "relevance": "gpt-4-turbo",
            "embedding": "text-embedding-3-large",
        }
    }

    # Default sampling parameters for each task.
    DEFAULT_PARAMS = {
        "research": {
            "temperature": 0.3,
            "max_tokens": 300,
            "top_p": 0.95,
        },
        "verification": {
            "temperature": 0.0,
            "max_tokens": 200,
            "top_p": 0.9,
        },
        "relevance": {
            "temperature": 0.0,
            "max_tokens": 10,
            "top_p": 0.9,
        }
    }

    def __init__(self, provider: Optional[ModelProvider] = None):
        """
        Initialize LLM configuration.

        Args:
            provider: Model provider to use. When omitted, the provider is
                resolved from the LLM_PROVIDER environment variable (as set in
                .env), falling back to Google — previously the env setting was
                silently ignored and Google was always used.

        Raises:
            ValueError: if the provider's API key env var is missing or the
                provider has no model table.
        """
        if provider is None:
            provider = self._provider_from_env()
        self.provider = provider
        self.api_key = self._get_api_key()
        self._validate_config()

    @staticmethod
    def _provider_from_env() -> ModelProvider:
        """Resolve the provider from the LLM_PROVIDER env var (default: google)."""
        # Tolerate inline comments such as "google # google or openai" in .env.
        raw = os.getenv("LLM_PROVIDER", "google").split("#")[0].strip().lower()
        try:
            return ModelProvider(raw)
        except ValueError:
            logger.warning(f"Unknown LLM_PROVIDER '{raw}'; defaulting to google.")
            return ModelProvider.GOOGLE

    def _get_api_key(self) -> str:
        """Get API key for the configured provider from the environment."""
        if self.provider == ModelProvider.GOOGLE:
            key = os.getenv("GOOGLE_API_KEY")
            if not key:
                raise ValueError("GOOGLE_API_KEY environment variable is required")
            return key
        elif self.provider == ModelProvider.OPENAI:
            key = os.getenv("OPENAI_API_KEY")
            if not key:
                raise ValueError("OPENAI_API_KEY environment variable is required")
            return key
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

    def _validate_config(self):
        """Validate that the provider has a model table configured."""
        if self.provider not in self.MODELS:
            raise ValueError(f"Provider {self.provider} not configured")

    def get_model_name(self, task: str) -> str:
        """Get the model name for a specific task.

        Raises:
            ValueError: if the task has no model configured for this provider.
        """
        if task not in self.MODELS[self.provider]:
            raise ValueError(f"Task {task} not configured for provider {self.provider}")
        return self.MODELS[self.provider][task]

    def get_model_params(self, task: str) -> Dict[str, Any]:
        """Get a copy of the default parameters for a task ({} if unknown)."""
        # .copy() so callers can't mutate the shared class-level defaults.
        return self.DEFAULT_PARAMS.get(task, {}).copy()

    def create_llm(self, task: str):
        """Create a chat-LLM instance for a specific task.

        Raises:
            ValueError: for providers without an implementation (OpenAI is
                declared but not yet wired up).
        """
        model_name = self.get_model_name(task)
        params = self.get_model_params(task)

        if self.provider == ModelProvider.GOOGLE:
            return ChatGoogleGenerativeAI(
                model=model_name,
                google_api_key=self.api_key,
                temperature=params.get("temperature", 0.3),
                max_tokens=params.get("max_tokens", None),
                top_p=params.get("top_p", 0.95),
            )
        elif self.provider == ModelProvider.OPENAI:
            # Would use ChatOpenAI here
            pass

        raise ValueError(f"Provider {self.provider} not implemented")

    def create_embedding(self):
        """Create an embedding-model instance.

        Raises:
            ValueError: for providers without an implementation.
        """
        if self.provider == ModelProvider.GOOGLE:
            return GoogleGenerativeAIEmbeddings(
                model="models/text-embedding-004",
                google_api_key=self.api_key
            )
        elif self.provider == ModelProvider.OPENAI:
            # Would use OpenAIEmbeddings here
            pass

        raise ValueError(f"Provider {self.provider} not implemented")

    def create_direct_client(self):
        """Create a raw provider SDK client, or None where not applicable."""
        if self.provider == ModelProvider.GOOGLE:
            client = genai.Client(api_key=self.api_key)
            return client
        return None
140
+
141
+ # Global configuration instance
142
+ llm_config = LLMConfig()
config/settings.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from pydantic_settings import BaseSettings
3
+ from .constants import MAX_FILE_SIZE, MAX_TOTAL_SIZE, ALLOWED_TYPES
4
+ import os
5
+
6
class Settings(BaseSettings):
    """Application settings loaded from the environment / .env file.

    Fields without defaults (GOOGLE_API_KEY) are required: constructing
    Settings() raises a validation error if they are absent.
    """

    # LLM Provider settings
    LLM_PROVIDER: str = "google"  # "google" or "openai"

    # API Keys
    GOOGLE_API_KEY: str  # required — no default
    OPENAI_API_KEY: str = ""

    # Optional settings with defaults (mirrored from config.constants so they
    # can still be overridden via the environment)
    MAX_FILE_SIZE: int = MAX_FILE_SIZE
    MAX_TOTAL_SIZE: int = MAX_TOTAL_SIZE
    ALLOWED_TYPES: list = ALLOWED_TYPES

    # Database settings
    CHROMA_DB_PATH: str = "./chroma_db"
    CHROMA_COLLECTION_NAME: str = "documents"

    # Retrieval settings
    VECTOR_SEARCH_K: int = 10
    # Weights for the hybrid retriever — presumably [keyword, vector];
    # TODO confirm ordering against RetrieverBuilder.build_hybrid_retriever.
    HYBRID_RETRIEVER_WEIGHTS: list = [0.4, 0.6]

    # Logging settings
    LOG_LEVEL: str = "INFO"

    # Cache settings
    CACHE_DIR: str = "document_cache"
    CACHE_EXPIRE_DAYS: int = 7

    class Config:
        # pydantic-settings (v1-style inner Config): source values from .env
        env_file = ".env"
        env_file_encoding = "utf-8"
37
+
38
+ settings = Settings()
config/test.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Test file for LLMConfig
3
+ Run: python test_llm_config.py
4
+ """
5
+
6
+ from llm_config import LLMConfig, ModelProvider
7
+
8
+ def test_basic_config():
9
+ print("🔹 Testing basic configuration...")
10
+ config = LLMConfig(provider=ModelProvider.GOOGLE)
11
+ print("Provider:", config.provider.value)
12
+ print("API Key loaded: ✅")
13
+
14
+ def test_model_names():
15
+ print("\n🔹 Testing model name resolution...")
16
+ config = LLMConfig()
17
+ print("Research model:", config.get_model_name("research"))
18
+ print("Verification model:", config.get_model_name("verification"))
19
+ print("Relevance model:", config.get_model_name("relevance"))
20
+
21
+ def test_llm_creation():
22
+ print("\n🔹 Testing LLM creation...")
23
+ config = LLMConfig()
24
+ llm = config.create_llm("research")
25
+ print("LLM instance created:", type(llm))
26
+
27
+ def test_embedding_creation():
28
+ print("\n🔹 Testing embedding creation...")
29
+ config = LLMConfig()
30
+ embedding = config.create_embedding()
31
+ print("Embedding instance created:", type(embedding))
32
+
33
+ def test_direct_client():
34
+ print("\n🔹 Testing direct Gemini client...")
35
+ config = LLMConfig()
36
+ client = config.create_direct_client()
37
+ print("Direct client created:", type(client))
38
+
39
+ if __name__ == "__main__":
40
+ try:
41
+ test_basic_config()
42
+ test_model_names()
43
+ test_llm_creation()
44
+ test_embedding_creation()
45
+ test_direct_client()
46
+ print("\n✅ ALL TESTS PASSED")
47
+ except Exception as e:
48
+ print("\n❌ TEST FAILED")
49
+ print("Error:", e)
document_processor/__init.py__ ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .file_handler import DocumentProcessor
2
+
3
+ __all__ = ["DocumentProcessor"]
document_processor/file_handler.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import hashlib
3
+ import pickle
4
+ from datetime import datetime, timedelta
5
+ from pathlib import Path
6
+ from typing import List
7
+ from docling.document_converter import DocumentConverter
8
+ from langchain_text_splitters import MarkdownHeaderTextSplitter
9
+ from config import constants
10
+ from config.settings import settings
11
+ from utils.logging import logger
12
+
13
+ class DocumentProcessor:
14
+ def __init__(self):
15
+ self.headers = [("#", "Header 1"), ("##", "Header 2")]
16
+ self.cache_dir = Path(settings.CACHE_DIR)
17
+ self.cache_dir.mkdir(parents=True, exist_ok=True)
18
+
19
+ def validate_files(self, files: List) -> None:
20
+ """Validate the total size of the uploaded files."""
21
+ total_size = sum(os.path.getsize(f.name) for f in files)
22
+ if total_size > constants.MAX_TOTAL_SIZE:
23
+ raise ValueError(f"Total size exceeds {constants.MAX_TOTAL_SIZE//1024//1024}MB limit")
24
+
25
+ def process(self, files: List) -> List:
26
+ """Process files with caching for subsequent queries"""
27
+ self.validate_files(files)
28
+ all_chunks = []
29
+ seen_hashes = set()
30
+
31
+ for file in files:
32
+ try:
33
+ # Generate content-based hash for caching
34
+ with open(file.name, "rb") as f:
35
+ file_hash = self._generate_hash(f.read())
36
+
37
+ cache_path = self.cache_dir / f"{file_hash}.pkl"
38
+
39
+ if self._is_cache_valid(cache_path):
40
+ logger.info(f"Loading from cache: {file.name}")
41
+ chunks = self._load_from_cache(cache_path)
42
+ else:
43
+ logger.info(f"Processing and caching: {file.name}")
44
+ chunks = self._process_file(file)
45
+ self._save_to_cache(chunks, cache_path)
46
+
47
+ # Deduplicate chunks across files
48
+ for chunk in chunks:
49
+ chunk_hash = self._generate_hash(chunk.page_content.encode())
50
+ if chunk_hash not in seen_hashes:
51
+ all_chunks.append(chunk)
52
+ seen_hashes.add(chunk_hash)
53
+
54
+ except Exception as e:
55
+ logger.error(f"Failed to process {file.name}: {str(e)}")
56
+ continue
57
+
58
+ logger.info(f"Total unique chunks: {len(all_chunks)}")
59
+ return all_chunks
60
+
61
+ def _process_file(self, file) -> List:
62
+ """Original processing logic with Docling"""
63
+ if not file.name.endswith(('.pdf', '.docx', '.txt', '.md')):
64
+ logger.warning(f"Skipping unsupported file type: {file.name}")
65
+ return []
66
+
67
+ converter = DocumentConverter()
68
+ markdown = converter.convert(file.name).document.export_to_markdown()
69
+ splitter = MarkdownHeaderTextSplitter(self.headers)
70
+ return splitter.split_text(markdown)
71
+
72
+ def _generate_hash(self, content: bytes) -> str:
73
+ return hashlib.sha256(content).hexdigest()
74
+
75
+ def _save_to_cache(self, chunks: List, cache_path: Path):
76
+ with open(cache_path, "wb") as f:
77
+ pickle.dump({
78
+ "timestamp": datetime.now().timestamp(),
79
+ "chunks": chunks
80
+ }, f)
81
+
82
+ def _load_from_cache(self, cache_path: Path) -> List:
83
+ with open(cache_path, "rb") as f:
84
+ data = pickle.load(f)
85
+ return data["chunks"]
86
+
87
+ def _is_cache_valid(self, cache_path: Path) -> bool:
88
+ if not cache_path.exists():
89
+ return False
90
+
91
+ cache_age = datetime.now() - datetime.fromtimestamp(cache_path.stat().st_mtime)
92
+ return cache_age < timedelta(days=settings.CACHE_EXPIRE_DAYS)
requirements.txt ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core Python
2
+ python-dotenv==1.0.1
3
+ pydantic==2.10.6
4
+ pydantic-settings==2.7.1
5
+ typing-extensions==4.12.2
6
+
7
+ # Web Framework
8
+ fastapi==0.115.7
9
+ uvicorn[standard]==0.34.0
10
+ gradio==5.13.2
11
+
12
+ # LangChain Core
13
+ langchain==0.3.16
14
+ langchain-core==0.3.32
15
+ langchain-community==0.3.16
16
+ langchain-text-splitters==0.3.5
17
+ langgraph==0.2.68
18
+
19
+ # LLM Providers
20
+ langchain-google-genai==2.1.2
21
+ google-generativeai==0.8.4
22
+ langchain-openai==0.3.2
23
+ openai==1.60.2
24
+
25
+ # Embeddings & Vector Stores
26
+ chromadb==0.6.3
27
+ langchain-chroma==0.2.4
28
+ sentence-transformers==3.0.1
29
+
30
+ # Document Processing
31
+ docling==2.15.0
32
+ pypdf==5.2.0
33
+ python-docx==1.1.2
34
+ markdown==3.6
35
+ beautifulsoup4==4.12.3
36
+ lxml==5.3.0
37
+
38
+ # Text Processing & Retrieval
39
+ rank-bm25==0.2.2
40
+ nltk==3.9.1
41
+ scikit-learn==1.6.0
42
+ numpy==1.26.4
43
+
44
+ # Caching & Hashing
45
+ cachetools==5.5.1
46
+
47
+ # Logging
48
+ loguru==0.7.3
49
+
50
+ # Utilities
51
+ python-multipart==0.0.20
52
+ aiofiles==23.2.1
53
+ pillow==10.4.0
54
+ tqdm==4.67.1
55
+ tenacity==9.0.0
56
+ backoff==2.2.1
57
+ httpx==0.28.1
58
+ requests==2.32.3
59
+ orjson==3.10.15
60
+
61
+ # Development & Testing
62
+ pytest==8.3.4
63
+ pytest-asyncio==0.23.7
64
+ black==24.10.0
65
+ isort==5.13.2
66
+ mypy==1.13.0
67
+ ruff==0.9.3
retriever/__init.py__ ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .builder import RetrieverBuilder
2
+
3
+ __all__ = ["RetrieverBuilder"]
retriever/builder.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.vectorstores import Chroma
2
+ from langchain_community.retrievers import BM25Retriever
3
+ from langchain.retrievers import EnsembleRetriever
4
+ from config.settings import settings
5
+ from config.llm_config import llm_config
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
class RetrieverBuilder:
    """Builds an ensemble retriever mixing keyword (BM25) and vector search."""

    def __init__(self):
        """Load the embedding model used for the dense side of the hybrid."""
        logger.info("Initializing RetrieverBuilder...")
        self.embeddings = llm_config.create_embedding()
        logger.info("RetrieverBuilder initialized successfully.")

    def build_hybrid_retriever(self, docs):
        """Return an EnsembleRetriever over *docs*, blending BM25 and Chroma results."""
        try:
            logger.info(f"Building hybrid retriever with {len(docs)} documents")

            # Dense side: embed the documents into a persistent Chroma collection.
            store = Chroma.from_documents(
                documents=docs,
                embedding=self.embeddings,
                persist_directory=settings.CHROMA_DB_PATH,
                collection_name=settings.CHROMA_COLLECTION_NAME,
            )
            logger.info("Vector store created successfully.")

            # Sparse side: classic keyword ranking over the same documents.
            keyword_side = BM25Retriever.from_documents(docs)
            logger.info("BM25 retriever created successfully.")

            dense_side = store.as_retriever(
                search_kwargs={"k": settings.VECTOR_SEARCH_K}
            )
            logger.info("Vector retriever created successfully.")

            # Blend both rankings using the configured weights.
            ensemble = EnsembleRetriever(
                retrievers=[keyword_side, dense_side],
                weights=settings.HYBRID_RETRIEVER_WEIGHTS,
            )
            logger.info("Hybrid retriever created successfully.")

            return ensemble

        except Exception as e:
            logger.error(f"Failed to build hybrid retriever: {e}")
            raise
utils/__init.py__ ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .logging import logger
2
+
3
+ __all__ = ["logger"]
utils/logging.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Central application logger (loguru). Other modules import `logger` from
# here instead of configuring the stdlib logging module themselves.
from loguru import logger

# Mirror all log records to a rotating file in the working directory.
logger.add(
    "app.log",
    rotation="10 MB",      # start a new file once the current one reaches 10 MB
    retention="30 days",   # delete rotated files older than 30 days
    format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}"
)