added base logic
- README.md +1 -1
- agent/graph.py +36 -0
- agent/nodes.py +97 -0
- agent/prompts.py +75 -0
- agent/state.py +19 -0
- agent/tools.py +99 -0
- core/chat_interface.py +0 -0
- core/rag_agent.py +29 -0
- knowledge_base/chroma.py +0 -3
- knowledge_base/embeddings.py +0 -22
- requirements.txt +2 -1
- testing_main.py +11 -0
- ui/gradio_components.py +91 -0
README.md
CHANGED

```diff
@@ -9,7 +9,7 @@ rag_agent/
 │   └── rag_system.py
 ├── knowledge_base/          # Storage management
 │   ├── chroma.py            # Parent chunks storage (JSON)
-│   └── vector_db_manager.py
+│   └── vector_db_manager.py
 ├── agent_logic/             # LangGraph agent workflow
 │   ├── edges.py             # Conditional routing logic
 │   └── graph.py             # Graph construction and compilation
```
agent/graph.py
ADDED
@@ -0,0 +1,36 @@
```python
from langgraph.graph import START, StateGraph
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode, tools_condition
from functools import partial

from .state import AgentState
from .nodes import *

def create_agent_graph(llm, tools):
    """Create and compile the RAG agent graph."""
    llm_with_tools = llm.bind_tools(tools)  # bind_tools, not with_tools

    graph = StateGraph(AgentState)
    checkpointer = MemorySaver()

    tool_node = ToolNode(tools)

    # Nodes
    graph.add_node("summarize", partial(analyze_chat_and_summarize, llm=llm))       # summarize the last 6 messages
    graph.add_node("analyze_rewrite", partial(analyze_and_rewrite_query, llm=llm))  # analyze and rewrite the query
    graph.add_node("agent", partial(agent_node, llm_with_tools=llm_with_tools))     # answer from retrieved info
    graph.add_node("human_input", human_input_node)  # must be registered: the edges and interrupt_before below refer to it
    graph.add_node("tools", tool_node)

    # Edges
    graph.add_edge(START, "summarize")
    graph.add_edge("summarize", "analyze_rewrite")
    graph.add_conditional_edges("analyze_rewrite", route_after_rewrite)
    graph.add_edge("human_input", "analyze_rewrite")
    graph.add_conditional_edges("agent", tools_condition)
    graph.add_edge("tools", "agent")

    agent_graph = graph.compile(
        checkpointer=checkpointer,
        interrupt_before=["human_input"]
    )

    return agent_graph
```
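A quick way to exercise the compiled graph, as a sketch: the model name, thread id, and question below are illustrative, not part of this commit. `MemorySaver` checkpointing requires a `thread_id` in the config.

```python
# Hypothetical smoke test for create_agent_graph; placeholders throughout.
from langchain_google_genai import ChatGoogleGenerativeAI
from agent.graph import create_agent_graph
from agent.tools import search_chroma, wikipedia_search, arxiv_search

llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
graph = create_agent_graph(llm, [search_chroma, wikipedia_search, arxiv_search])

config = {"configurable": {"thread_id": "demo-thread"}}
result = graph.invoke(
    {"messages": [("user", "What does the knowledge base say about RAG?")]},
    config,
)
print(result["messages"][-1].content)
```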
agent/nodes.py
ADDED
@@ -0,0 +1,97 @@
```python
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
from typing import Literal

from .state import AgentState, QueryAnalysis
from .prompts import *

# NOTE: sampling temperature comes from the model's own configuration;
# Runnable.with_config() only accepts runtime config (tags, callbacks, ...),
# not model kwargs such as temperature.

def analyze_chat_and_summarize(state: AgentState, llm):
    """Analyze the chat history and summarize key points for context."""
    if len(state["messages"]) < 4:  # Need some history to summarize
        return {"conversation_summary": ""}

    # Extract relevant messages (excluding the current query, system messages,
    # and tool-calling turns)
    relevant_msgs = [
        msg for msg in state["messages"][:-1]  # Exclude current query
        if isinstance(msg, (HumanMessage, AIMessage))
        and not getattr(msg, "tool_calls", None)
    ]

    if not relevant_msgs:
        return {"conversation_summary": ""}

    # The prompt text lives in prompts.get_conversation_summary_prompt,
    # which also folds in the last 6 relevant messages.
    summary_prompt = get_conversation_summary_prompt(relevant_msgs)
    summary_response = llm.invoke([SystemMessage(content=summary_prompt)])
    return {"conversation_summary": summary_response.content}

def analyze_and_rewrite_query(state: AgentState, llm):
    """Analyze the user query and rewrite it for clarity, optionally using conversation context."""
    last_message = state["messages"][-1]
    conversation_summary = state.get("conversation_summary", "")

    # Create the analysis prompt (the helper builds the context section itself)
    query_analysis_prompt = get_query_analysis_prompt(last_message.content, conversation_summary)

    llm_with_structure = llm.with_structured_output(QueryAnalysis)
    response = llm_with_structure.invoke([SystemMessage(content=query_analysis_prompt)])

    if response.is_clear:
        # Remove all non-system messages so only the rewritten query remains
        delete_all = [
            RemoveMessage(id=m.id)
            for m in state["messages"]
            if not isinstance(m, SystemMessage)
        ]

        # Format the rewritten query (numbered list when it was split into sub-queries)
        rewritten = (
            "\n".join(f"{i+1}. {q}" for i, q in enumerate(response.questions))
            if len(response.questions) > 1
            else response.questions[0]
        )
        return {
            "questionIsClear": True,
            "messages": delete_all + [HumanMessage(content=rewritten)]
        }
    else:
        clarification = response.clarification_needed or "I need more information to understand your question."
        return {
            "questionIsClear": False,
            "messages": [AIMessage(content=clarification)]
        }

def human_input_node(state: AgentState):
    """Placeholder node for human-in-the-loop interruption."""
    return {}

def route_after_rewrite(state: AgentState) -> Literal["agent", "human_input"]:
    """Route to the agent if the question is clear, otherwise wait for human input."""
    return "agent" if state.get("questionIsClear", False) else "human_input"

def agent_node(state: AgentState, llm_with_tools):
    """Main agent node that answers queries using the bound tools."""
    system_prompt = get_system_prompt()
    messages = [system_prompt] + state["messages"]
    response = llm_with_tools.invoke(messages)
    return {"messages": [response]}

if __name__ == "__main__":
    pass
```
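Because the graph compiles with `interrupt_before=["human_input"]`, an unclear query pauses the run before that node. A sketch of the resume flow, reusing the hypothetical `graph` and `config` from the example above (the messages are illustrative):

```python
# The run stops before "human_input" when the query is unclear; the last
# message then holds the clarification request from analyze_rewrite.
state = graph.invoke({"messages": [("user", "what about it?")]}, config)
print(state["messages"][-1].content)

# Record the user's clarification as if the human_input node produced it,
# then resume the paused thread; invoke(None, ...) picks up where it stopped.
graph.update_state(
    config,
    {"messages": [("user", "I meant Chroma retrieval latency")]},
    as_node="human_input",
)
state = graph.invoke(None, config)
print(state["messages"][-1].content)
```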
agent/prompts.py
ADDED
@@ -0,0 +1,75 @@
```python
from langchain_core.messages import SystemMessage, HumanMessage

def get_system_prompt() -> SystemMessage:
    """Generate the system prompt for the RAG agent."""
    return SystemMessage(content="""
You are an intelligent assistant that MUST use the available tools to answer questions.

**MANDATORY WORKFLOW - Follow these steps for EVERY question:**
1. **Call `search_chroma`** with the user's query (k = 3-7) to find the most relevant chunks in the Chroma vector store.
2. **Review the retrieved chunks** and identify the relevant ones. The chunks contain content and metadata (such as `parent_id` and `source`).
3. **If additional context is needed**, retrieve more details from the source tools (e.g., Wikipedia or Arxiv) to provide the full answer.
4. **Use metadata** such as `source` and `parent_id` to help clarify or support the answer when applicable.
5. **Answer using ONLY the retrieved information**:
   - Combine relevant chunks and use metadata (e.g., citation sources) as needed to clarify or support the response.
6. **If no relevant information is found**, rewrite the query into an **answer-focused declarative statement** and search again **only once** using `search_chroma`.
7. **Return the final answer** derived from the most relevant results.
""")

def get_conversation_summary_prompt(messages):
    """Generate a prompt for conversation summarization."""
    summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
Discard irrelevant information, such as misunderstandings or off-topic queries/responses.
If there are no key topics, return an empty string.

"""
    for msg in messages[-6:]:  # Last 6 messages for context
        role = "User" if isinstance(msg, HumanMessage) else "Assistant"
        summary_prompt += f"{role}: {msg.content}\n"

    summary_prompt += "\n**Brief Summary:**"
    return summary_prompt

def get_query_analysis_prompt(query: str, conversation_summary: str = "") -> str:
    """Generate a prompt for query analysis and rewriting."""
    context_section = (
        f"**Conversation Context:**\n{conversation_summary}"
        if conversation_summary.strip()
        else "**Conversation Context:**\n[First query in conversation]"
    )

    return f"""
**Rewrite the user's query** to be clear, self-contained, and optimized for information retrieval.

**User Query:**
"{query}"

{context_section}

**Instructions:**

1. **Resolve references for follow-ups:**
   - If the query uses pronouns or refers to previous topics, use the context to make it self-contained.

2. **Ensure clarity for new queries:**
   - Make the query specific, concise, and unambiguous.

3. **Correct errors and interpret intent:**
   - If the query is grammatically incorrect, contains typos, or uses abbreviations, correct it and infer the intended meaning.

4. **Split only when necessary:**
   - If multiple distinct questions exist, split into **up to 3 focused sub-queries** to avoid over-segmentation.
   - Each sub-query must still be meaningful on its own.

5. **Optimize for search:**
   - Use **keywords, proper nouns, numbers, dates, and technical terms**.
   - Remove conversational filler, vague words, and redundancies.
   - Make the query concise and focused for information retrieval.

6. **Mark as unclear if intent is missing:**
   - This includes nonsense, gibberish, insults, or statements without an apparent question.
"""

if __name__ == "__main__":
    pass
```
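A quick sanity check of the prompt builder; the query and summary are illustrative:

```python
from agent.prompts import get_query_analysis_prompt

# A vague follow-up plus a summary: the analysis LLM should rewrite it
# into a self-contained query.
print(get_query_analysis_prompt(
    "how fast is it?",
    conversation_summary="User has been asking about Chroma similarity search.",
))
```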
agent/state.py
ADDED
@@ -0,0 +1,19 @@
```python
from typing import TypedDict, Annotated, Sequence, List
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field


class AgentState(TypedDict):
    messages: Annotated[Sequence[AnyMessage], add_messages]
    questionIsClear: bool
    # TypedDict fields cannot carry defaults; readers use state.get("conversation_summary", "")
    conversation_summary: str

class QueryAnalysis(BaseModel):
    """Structured output for query analysis"""
    is_clear: bool = Field(description="Indicates if the user's question is clear and answerable")
    questions: List[str] = Field(default_factory=list, description="List of rewritten, self-contained questions")
    clarification_needed: str = Field(default="", description="Explanation if the question is unclear")

if __name__ == "__main__":
    pass
```
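The `RemoveMessage` trick in `analyze_and_rewrite_query` relies on the `add_messages` reducer declared here. A minimal illustration of its merge semantics (the ids and contents are made up):

```python
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, RemoveMessage

existing = [HumanMessage(content="old question", id="m1")]
update = [RemoveMessage(id="m1"), HumanMessage(content="rewritten question", id="m2")]

# The reducer deletes by id and appends the rest.
merged = add_messages(existing, update)
print([m.content for m in merged])  # ['rewritten question']
```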
agent/tools.py
ADDED
@@ -0,0 +1,99 @@
```python
import json
from typing import List
from langchain_core.tools import tool
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.document_loaders import WikipediaLoader, ArxivLoader


from config import configs


def initialize_chroma_vectorstore():
    """Initialize and return the Chroma vector store."""
    dense_embeddings = HuggingFaceEmbeddings(
        model_name=configs["EMBEDDING_MODEL_NAME"]
    )

    vectorstore = Chroma(
        persist_directory=configs["PERSIST_PATH"],
        embedding_function=dense_embeddings,
        collection_name=configs["COLLECTION_NAME"]
    )
    return vectorstore


_VECTORSTORE = None

def _get_vectorstore():
    """Lazily create and cache the vector store; loading the embedding model is slow."""
    global _VECTORSTORE
    if _VECTORSTORE is None:
        _VECTORSTORE = initialize_chroma_vectorstore()
    return _VECTORSTORE


@tool
def search_chroma(query: str, k: int = 5) -> List[dict]:
    """Search for the top k most relevant chunks in the Chroma vector store.

    Args:
        query: Search query string
        k: Number of results to return
    """
    try:
        # The vector store is resolved internally rather than passed as a tool
        # argument: tool arguments are filled in by the LLM, so they must be
        # plain JSON values. similarity_search() takes no score_threshold, so
        # the 0.7 cut-off is applied via the relevance-score variant.
        scored = _get_vectorstore().similarity_search_with_relevance_scores(
            query, k=k, score_threshold=0.7
        )

        return [
            {
                "content": doc.page_content,
                "parent_id": doc.metadata.get("parent_id", ""),
                "source": doc.metadata.get("source", "")
            }
            for doc, score in scored
        ]

    except Exception as e:
        print(f"Error searching chunks: {e}")
        return []

@tool
def wikipedia_search(query: str) -> dict:
    """Search Wikipedia for a query and return up to 3 results.
    Args:
        query: The search query.
    Returns:
        dict with key 'wiki_results', containing a list of search results with
        'title', 'url', and 'snippet'.
    """
    try:
        search_docs = WikipediaLoader(query=query, load_max_docs=3).load()
        results = [
            {
                "title": doc.metadata.get("title", ""),
                "url": doc.metadata.get("source", ""),  # WikipediaLoader keeps the page URL under "source"
                "snippet": doc.page_content,
            }
            for doc in search_docs
        ]
        return {"wiki_results": results}
    except Exception as e:
        return {"wiki_results": f"Error retrieving results: {str(e)}"}

@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return up to 3 results.
    Args:
        query: The search query.
    Returns:
        dict with key 'arxiv_results', containing a list of search results with
        'title', 'url', and 'snippet'.
    """
    try:
        # ArxivLoader capitalizes its default metadata keys ("Title", ...);
        # load_all_available_meta=True also exposes "entry_id", the paper URL.
        search_docs = ArxivLoader(query=query, load_max_docs=3, load_all_available_meta=True).load()
        results = [
            {
                "title": doc.metadata.get("Title", ""),
                "url": doc.metadata.get("entry_id", ""),
                "snippet": doc.page_content,
            }
            for doc in search_docs
        ]
        return {"arxiv_results": results}
    except Exception as e:
        return {"arxiv_results": f"Error retrieving results: {str(e)}"}


if __name__ == "__main__":
    pass
```
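The tools can be exercised directly for debugging; LangChain tools take a dict of their arguments via `.invoke()`. The query strings below are illustrative:

```python
from agent.tools import search_chroma, wikipedia_search

# Hit the vector store without going through the agent graph.
for chunk in search_chroma.invoke({"query": "parent document retrieval", "k": 3}):
    print(chunk["source"], "->", chunk["content"][:80])

wiki = wikipedia_search.invoke({"query": "Retrieval-augmented generation"})
print(wiki["wiki_results"][0]["title"])
```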
core/chat_interface.py
ADDED
File without changes
core/rag_agent.py
ADDED
@@ -0,0 +1,29 @@
```python
import uuid
from langchain_google_genai import ChatGoogleGenerativeAI
import config
from agent.tools import *
from agent.graph import create_agent_graph

class RAGAgent:

    def __init__(self, collection_name=config.CHILD_COLLECTION):
        self.collection_name = collection_name
        self.retriever = initialize_chroma_vectorstore()
        self.thread_id = str(uuid.uuid4())

        self.llm = ChatGoogleGenerativeAI(model=config.LLM_MODEL, temperature=config.LLM_TEMPERATURE)
        # Bind the retrieval tools; with an empty list the agent could never
        # follow the mandatory search workflow in its system prompt.
        tools = [search_chroma, wikipedia_search, arxiv_search]
        self.agent_graph = create_agent_graph(self.llm, tools)

    def get_config(self):
        return {"configurable": {"thread_id": self.thread_id}}

    def reset_thread(self):
        try:
            self.agent_graph.checkpointer.delete_thread(self.thread_id)
        except Exception as e:
            print(f"Warning: Could not delete thread {self.thread_id}: {e}")
        self.thread_id = str(uuid.uuid4())

if __name__ == "__main__":
    pass
```
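A minimal sketch of driving `RAGAgent` end to end, assuming `GOOGLE_API_KEY` is set and the Chroma store has already been built; the question is illustrative:

```python
from core.rag_agent import RAGAgent

agent = RAGAgent()
state = agent.agent_graph.invoke(
    {"messages": [("user", "Summarize the stored notes on vector databases.")]},
    agent.get_config(),  # carries the thread_id for checkpointing
)
print(state["messages"][-1].content)

agent.reset_thread()  # drop the checkpointed history and start a new thread
```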
knowledge_base/chroma.py
CHANGED
```diff
@@ -50,8 +50,5 @@ if __name__ == "__main__":
         persist_directory=configs["PERSIST_PATH"]
     )
 
-    # Explicitly persist the data for immediate use
-    vectorstore.persist()
-
     print("✅ Success: Chroma vector store created and data persisted.")
     print(f"The vector database is now ready for query using the collection: '{configs['COLLECTION_NAME']}'")
```

(With a `persist_directory`, `langchain_chroma.Chroma` writes to disk automatically; the explicit `persist()` call no longer exists in the current API.)
knowledge_base/embeddings.py
DELETED
```diff
@@ -1,22 +0,0 @@
-from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_chroma import Chroma
-
-# 1. Define the custom embedding object
-dense_embeddings = HuggingFaceEmbeddings(
-    model_name="sentence-transformers/all-mpnet-base-v2"
-)
-
-# 2. Initialize the LangChain Chroma vector store, passing the embeddings
-vectorstore = Chroma.from_documents(
-    documents=["./docs/markdowns"],  # Placeholder for actual documents
-    embedding=dense_embeddings,
-    collection_name="langchain_mpnet_collection",
-    persist_directory="./knowledge_base/chroma_data"
-)
-
-# 3. Save the database (essential for persistence)
-vectorstore.persist()
-print("LangChain Chroma vector store created with custom embeddings and persisted.")
-
-if __name__ == "__main__":
-    pass
```
requirements.txt
CHANGED
```diff
@@ -12,4 +12,5 @@ pymupdf4llm
 langchain-community
 langchain_text_splitters
 pymupdf-layout
-sentence_transformers
+sentence_transformers
+gradio
```
testing_main.py
ADDED
@@ -0,0 +1,11 @@
```python
from config import configs
from knowledge_base.test_retrieval import PERSIST_PATH, EMBEDDING_MODEL_NAME, COLLECTION_NAME

if __name__ == "__main__":
    print("Testing configuration values...")
    for key, value in configs.items():
        print(f"{key}: {value}")
    print("✅ Configuration test completed successfully.")
    print(f"PERSIST_PATH: {PERSIST_PATH}")
    print(f"EMBEDDING_MODEL_NAME: {EMBEDDING_MODEL_NAME}")
    print(f"COLLECTION_NAME: {COLLECTION_NAME}")
```
ui/gradio_components.py
ADDED
@@ -0,0 +1,91 @@
```python
import gradio as gr
from core.chat_interface import ChatInterface
from core.document_manager import DocumentManager
from core.rag_system import RAGSystem

def create_gradio_ui():
    rag_system = RAGSystem()
    rag_system.initialize()

    doc_manager = DocumentManager(rag_system)
    chat_interface = ChatInterface(rag_system)

    def format_file_list():
        files = doc_manager.get_markdown_files()
        if not files:
            return "📁 No documents available in the knowledge base"
        return "\n".join(f"{f}" for f in files)

    def upload_handler(files, progress=gr.Progress()):
        if not files:
            return None, format_file_list()

        added, skipped = doc_manager.add_documents(
            files,
            progress_callback=lambda p, desc: progress(p, desc=desc)
        )

        gr.Info(f"✅ Added: {added} | Skipped: {skipped}")
        return None, format_file_list()

    def clear_handler():
        doc_manager.clear_all()
        gr.Info("🗑️ Removed all documents")
        return format_file_list()

    def chat_handler(msg, hist):
        return chat_interface.chat(msg, hist)

    def clear_chat_handler():
        chat_interface.clear_session()

    with gr.Blocks(title="Agentic RAG") as demo:

        with gr.Tab("Documents", elem_id="doc-management-tab"):
            gr.Markdown("## Add New Documents")
            gr.Markdown("Upload PDF or Markdown files. Duplicates will be automatically skipped.")

            files_input = gr.File(
                label="Drop PDF or Markdown files here",
                file_count="multiple",
                type="filepath",
                height=200,
                show_label=False
            )

            add_btn = gr.Button("Add Documents", variant="primary", size="md")

            gr.Markdown("## Current Documents in the Knowledge Base")
            file_list = gr.Textbox(
                value=format_file_list(),
                interactive=False,
                lines=7,
                max_lines=10,
                elem_id="file-list-box",
                show_label=False
            )

            with gr.Row():
                refresh_btn = gr.Button("Refresh", size="md")
                clear_btn = gr.Button("Clear All", variant="stop", size="md")

            add_btn.click(
                upload_handler,
                [files_input],
                [files_input, file_list],
                show_progress="corner"
            )
            refresh_btn.click(format_file_list, None, file_list)
            clear_btn.click(clear_handler, None, file_list)

        with gr.Tab("Chat"):
            chatbot = gr.Chatbot(
                height=600,
                placeholder="Ask me anything about your documents!",
                show_label=False
            )
            chatbot.clear(clear_chat_handler)

            gr.ChatInterface(fn=chat_handler, chatbot=chatbot)

    return demo
```
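Nothing in this commit launches the UI yet; a minimal, assumed entry point would look like the following (Gradio Blocks apps are started with `launch()`):

```python
# Hypothetical entry point for the UI; not part of this commit.
from ui.gradio_components import create_gradio_ui

if __name__ == "__main__":
    demo = create_gradio_ui()
    demo.launch()  # serves the app on http://127.0.0.1:7860 by default
```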
|