added gradio
Browse files- agent/graph.py +31 -16
- agent/more_nodes.py +97 -0
- agent/nodes.py +47 -83
- agent/prompts.py +13 -0
- agent/state.py +11 -1
- agent/tools.py +14 -16
- config.py +5 -1
- core/chat_interface.py +196 -0
- core/rag_agent.py +10 -7
- main.py +9 -4
- ui/gradio_components.py +106 -81
agent/graph.py
CHANGED
|
@@ -6,31 +6,46 @@ from functools import partial
|
|
| 6 |
from .state import AgentState
|
| 7 |
from .nodes import *
|
| 8 |
|
| 9 |
-
|
|
|
|
| 10 |
"""Create the RAG agent graph."""
|
| 11 |
-
llm_with_tools = llm.with_tools(tools)
|
| 12 |
|
| 13 |
graph = StateGraph(AgentState)
|
| 14 |
checkpointer = MemorySaver()
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
|
| 19 |
-
graph.
|
| 20 |
-
graph.add_node("analyze_rewrite", partial(analyze_and_rewrite_query, llm=llm)) # analyze and rewrite query
|
| 21 |
-
graph.add_node("agent", partial(agent_node, llm_with_tools=llm_with_tools)) # generate answer based on retrieved info
|
| 22 |
-
graph.add_node("tools", tool_node)
|
| 23 |
|
| 24 |
-
graph.add_edge(
|
| 25 |
-
graph.add_edge("summarize", "analyze_rewrite")
|
| 26 |
-
graph.add_conditional_edges("analyze_rewrite", route_after_rewrite)
|
| 27 |
-
graph.add_edge("human_input", "analyze_rewrite")
|
| 28 |
-
graph.add_conditional_edges("agent", tools_condition)
|
| 29 |
-
graph.add_edge("tools", "agent")
|
| 30 |
|
| 31 |
agent_graph = graph.compile(
|
| 32 |
checkpointer=checkpointer,
|
| 33 |
-
interrupt_before=["human_input"]
|
| 34 |
)
|
| 35 |
|
| 36 |
-
return agent_graph
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from .state import AgentState
|
| 7 |
from .nodes import *
|
| 8 |
|
| 9 |
+
|
| 10 |
+
def create_agent_graph(llm, vectordb, search_tools) -> StateGraph:
    """Create and compile the RAG agent graph.

    Args:
        llm: Chat model used by the router and generation nodes.
        vectordb: Vector store queried by the retrieval node.
        search_tools: Web-search tools wrapped into a single ToolNode.

    Returns:
        The compiled agent graph, checkpointed with an in-memory saver.
    """
    graph = StateGraph(AgentState)
    checkpointer = MemorySaver()

    web_search_tool_node = ToolNode(search_tools)

    # --- Nodes ---
    graph.add_node("router_node", partial(router_node, llm=llm))
    graph.add_node("vectordb_node", partial(vectordb_node, vectorstore=vectordb))
    graph.add_node("web_search_node", web_search_tool_node)
    graph.add_node("generate_node", partial(generate_node, llm=llm))

    # --- Edges ---
    graph.add_edge(START, "router_node")

    graph.add_conditional_edges(
        "router_node",
        routing_logic,
        {
            # Output from routing_logic -> Target Node Name
            "vectordb_node": "vectordb_node",
            "web_search_node": "web_search_node",
            "generate_node": "generate_node",
            # routing_logic returns END for unclassified intents; without this
            # entry the conditional edge raises a KeyError at runtime.
            END: END,
        }
    )

    graph.add_edge("vectordb_node", "generate_node")
    graph.add_edge("web_search_node", "generate_node")
    graph.add_edge("generate_node", END)

    agent_graph = graph.compile(
        checkpointer=checkpointer,
    )

    return agent_graph


if __name__ == "__main__":
    pass
|
agent/more_nodes.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
|
| 2 |
+
from typing import Literal
|
| 3 |
+
|
| 4 |
+
from .state import AgentState, QueryAnalysis
|
| 5 |
+
from .prompts import *
|
| 6 |
+
|
| 7 |
+
def analyze_chat_and_summarize(state: AgentState, llm):
    """
    Analyzes chat history and summarizes key points for context.

    Returns a partial state update: {"conversation_summary": <str>}, empty
    when there is not enough usable history.
    """
    if len(state["messages"]) < 4:  # Need some history to summarize
        return {"conversation_summary": ""}

    # Extract relevant messages (excluding current query and system messages);
    # AI messages that only carry tool calls are skipped too.
    relevant_msgs = [
        msg for msg in state["messages"][:-1]  # Exclude current query
        if isinstance(msg, (HumanMessage, AIMessage))
        and not getattr(msg, "tool_calls", None)
    ]

    if not relevant_msgs:
        return {"conversation_summary": ""}

    summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
Discard irrelevant information, such as misunderstandings or off-topic queries/responses.
If there are no key topics, return an empty string.

"""
    for msg in relevant_msgs[-6:]:  # Last 6 messages for context
        role = "User" if isinstance(msg, HumanMessage) else "Assistant"
        summary_prompt += f"{role}: {msg.content}\n"

    summary_prompt += "\nBrief Summary:"
    # NOTE(review): with_config() takes RunnableConfig keys; a bare
    # temperature kwarg may be silently ignored by some chat models —
    # confirm, llm.bind(temperature=0.3) is the usual way.
    summary_response = llm.with_config(temperature=0.3).invoke([SystemMessage(content=summary_prompt)])
    return {"conversation_summary": summary_response.content}
|
| 36 |
+
|
| 37 |
+
def analyze_and_rewrite_query(state: AgentState, llm):
    """
    Analyzes user query and rewrites it for clarity, optionally using conversation context.

    Returns a partial state update: when the question is clear, the history is
    replaced by the rewritten query; otherwise an AI clarification request is
    appended and questionIsClear is set False.
    """
    last_message = state["messages"][-1]
    conversation_summary = state.get("conversation_summary", "")

    # Create analysis prompt. (The prompt helper embeds the summary itself; a
    # previously-built local "context_section" string was dead code and has
    # been removed.)
    query_analysis_prompt = get_query_analysis_prompt(last_message.content, conversation_summary)

    # NOTE(review): with_config(temperature=...) may be ignored by some chat
    # models — confirm against the model wrapper in use.
    llm_with_structure = llm.with_config(temperature=0.3).with_structured_output(QueryAnalysis)
    response = llm_with_structure.invoke([SystemMessage(content=query_analysis_prompt)])

    if response.is_clear:
        # Remove all non-system messages so the rewritten query stands alone.
        delete_all = [
            RemoveMessage(id=m.id)
            for m in state["messages"]
            if not isinstance(m, SystemMessage)
        ]

        # Format rewritten query: numbered list if it was split into several.
        rewritten = (
            "\n".join(f"{i+1}. {q}" for i, q in enumerate(response.questions))
            if len(response.questions) > 1
            else response.questions[0]
        )
        return {
            "questionIsClear": True,
            "messages": delete_all + [HumanMessage(content=rewritten)]
        }
    else:
        clarification = response.clarification_needed or "I need more information to understand your question."
        return {
            "questionIsClear": False,
            "messages": [AIMessage(content=clarification)]
        }
|
| 80 |
+
|
| 81 |
+
def human_input_node(state: AgentState):
    """No-op node: the graph interrupts here to wait for human input."""
    return {}
|
| 84 |
+
|
| 85 |
+
def route_after_rewrite(state: AgentState) -> Literal["agent", "human_input"]:
    """Send clear questions to the agent; otherwise pause for human input."""
    if state.get("questionIsClear", False):
        return "agent"
    return "human_input"
|
| 88 |
+
|
| 89 |
+
def agent_node(state: AgentState, llm_with_tools):
    """Main agent node: answer the query with a tool-enabled LLM call."""
    conversation = [get_system_prompt(), *state["messages"]]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}
|
| 95 |
+
|
| 96 |
+
if __name__ == "__main__":
|
| 97 |
+
pass
|
agent/nodes.py
CHANGED
|
@@ -1,97 +1,61 @@
|
|
| 1 |
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
|
| 2 |
from typing import Literal
|
|
|
|
| 3 |
|
| 4 |
from .state import AgentState, QueryAnalysis
|
| 5 |
from .prompts import *
|
|
|
|
| 6 |
|
| 7 |
-
def analyze_chat_and_summarize(state: AgentState, llm):
|
| 8 |
-
"""
|
| 9 |
-
Analyzes chat history and summarizes key points for context.
|
| 10 |
-
"""
|
| 11 |
-
if len(state["messages"]) < 4: # Need some history to summarize
|
| 12 |
-
return {"conversation_summary": ""}
|
| 13 |
-
|
| 14 |
-
# Extract relevant messages (excluding current query and system messages)
|
| 15 |
-
relevant_msgs = [
|
| 16 |
-
msg for msg in state["messages"][:-1] # Exclude current query
|
| 17 |
-
if isinstance(msg, (HumanMessage, AIMessage))
|
| 18 |
-
and not getattr(msg, "tool_calls", None)
|
| 19 |
-
]
|
| 20 |
-
|
| 21 |
-
if not relevant_msgs:
|
| 22 |
-
return {"conversation_summary": ""}
|
| 23 |
-
|
| 24 |
-
summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
|
| 25 |
-
Discard irrelevant information, such as misunderstandings or off-topic queries/responses.
|
| 26 |
-
If there are no key topics, return an empty string.
|
| 27 |
|
|
|
|
| 28 |
"""
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
def
|
| 38 |
"""
|
| 39 |
-
|
| 40 |
"""
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
context_section = (
|
| 45 |
-
f"**Conversation Context:**\n{conversation_summary}"
|
| 46 |
-
if conversation_summary.strip()
|
| 47 |
-
else "**Conversation Context:**\n[First query in conversation]"
|
| 48 |
)
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
if len(response.questions) > 1
|
| 68 |
-
else response.questions[0]
|
| 69 |
-
)
|
| 70 |
-
return {
|
| 71 |
-
"questionIsClear": True,
|
| 72 |
-
"messages": delete_all + [HumanMessage(content=rewritten)]
|
| 73 |
-
}
|
| 74 |
-
else:
|
| 75 |
-
clarification = response.clarification_needed or "I need more information to understand your question."
|
| 76 |
-
return {
|
| 77 |
-
"questionIsClear": False,
|
| 78 |
-
"messages": [AIMessage(content=clarification)]
|
| 79 |
-
}
|
| 80 |
-
|
| 81 |
-
def human_input_node(state: AgentState):
|
| 82 |
-
"""Placeholder node for human-in-the-loop interruption"""
|
| 83 |
-
return {}
|
| 84 |
-
|
| 85 |
-
def route_after_rewrite(state: AgentState) -> Literal["agent", "human_input"]:
|
| 86 |
-
"""Route to agent if question is clear, otherwise wait for human input"""
|
| 87 |
-
return "agent" if state.get("questionIsClear", False) else "human_input"
|
| 88 |
-
|
| 89 |
-
def agent_node(state: AgentState, llm_with_tools):
|
| 90 |
-
"""Main agent node that processes queries using tools"""
|
| 91 |
-
system_prompt = get_system_prompt()
|
| 92 |
-
messages = [system_prompt] + state["messages"]
|
| 93 |
-
response = llm_with_tools.invoke(messages)
|
| 94 |
-
return {"messages": [response]}
|
| 95 |
-
|
| 96 |
if __name__ == "__main__":
|
| 97 |
pass
|
|
|
|
| 1 |
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
|
| 2 |
from typing import Literal
|
| 3 |
+
from langgraph.graph import START, END
|
| 4 |
|
| 5 |
from .state import AgentState, QueryAnalysis
|
| 6 |
from .prompts import *
|
| 7 |
+
from .tools import intialize_chroma_vectorstore
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
|
| 10 |
+
def router_node(state: AgentState, llm):
    """
    Takes the query (and history). Decides the next step: vectordb, tools, or refuse.

    Returns a partial state update with the classification label; returning
    the whole (mutated) state would push every existing message back through
    the add_messages reducer.
    """
    query = state["messages"][-1].content
    rag_method_prompt = determine_rag_method_prompt()
    rag_method_result = llm.invoke([rag_method_prompt, HumanMessage(content=query)])
    rag_method = rag_method_result.content.strip().upper()
    return {"rag_method": rag_method}
|
| 20 |
+
|
| 21 |
+
def routing_logic(state: AgentState) -> str:
    """Map the router's classification onto the next node name.

    Bug fix: the original signature carried a stray ``self`` parameter even
    though this is a module-level function; LangGraph would have passed the
    state as ``self`` and crashed with a missing-argument TypeError.
    """
    rag_method = state["rag_method"]
    if rag_method == "RAG":
        return "vectordb_node"
    if rag_method == "WEBSEARCH":
        return "web_search_node"
    if rag_method == "GENERAL":
        # Fallback when the question needs neither retrieval nor web search.
        return "generate_node"
    # The LLM violated the prompt and produced an unknown label.
    print(f"ERROR: Router returned unclassified intent: {rag_method}. Terminating flow.")
    return END
|
| 33 |
|
| 34 |
+
def vectordb_node(state: AgentState, llm=None, vectorstore=None):
    """
    Use vectordb to answer the query.

    Args:
        state: Current agent state; the last message holds the user query.
        llm: Unused; defaulted for compatibility. The graph binds only
            ``vectorstore`` via partial(), so a required ``llm`` parameter
            raised a TypeError at invocation time.
        vectorstore: Chroma vector store to query.

    Returns:
        Partial state update {"context": <joined page contents>}.
    """
    context_docs = vectorstore.similarity_search(
        query=state["messages"][-1].content,
        k=5
    )
    context = "\n\n".join(doc.page_content for doc in context_docs)
    # Return only the changed key rather than mutating the whole state.
    return {"context": context}
|
| 45 |
+
|
| 46 |
+
def generate_node(state: AgentState, llm):
    """Generate the final answer, folding any retrieved context into the system prompt."""
    messages = state["messages"][-10:]  # Limit to last 10 messages to handle token limit
    # context is a string (set by vectordb_node); the original defaulted to []
    # which only worked by accident because empty-list is falsy.
    context = state.get("context", "")

    # NOTE(review): assumes get_system_prompt() returns a str here, yet
    # agent_node uses its result directly as a message object — confirm.
    system_content = get_system_prompt()

    if context:
        system_content += f"\n\nRelevant Context:\n{context}"

    messages_with_system = [SystemMessage(content=system_content)] + messages
    response = llm.invoke(messages_with_system)

    return {'messages': [response]}
|
| 59 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
if __name__ == "__main__":
|
| 61 |
pass
|
agent/prompts.py
CHANGED
|
@@ -16,6 +16,19 @@ You are an intelligent assistant that MUST use the available tools to answer que
|
|
| 16 |
7. **Return the final answer** derived from the most relevant results.
|
| 17 |
""")
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
def get_conversation_summary_prompt(messages):
|
| 20 |
"""Generate a prompt for conversation summarization."""
|
| 21 |
summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
|
|
|
|
| 16 |
7. **Return the final answer** derived from the most relevant results.
|
| 17 |
""")
|
| 18 |
|
| 19 |
+
def determine_rag_method_prompt() -> SystemMessage:
    """Build the router's classification prompt.

    Returns:
        A SystemMessage instructing the model to reply with exactly one of
        RAG / WEBSEARCH / GENERAL. (The original annotation said ``str``,
        but a message object is what is actually returned.)
    """
    return SystemMessage(content="""
    You are an rag method classification model. Given the user's query, you must classify the method to use
    as one and only one of the following options:

    1. **RAG**: The query likely relates to the internal, domain-specific documents you have access to.
    2. **WEBSEARCH**: The query requires real-time facts, general knowledge, or external information not in your documents.
    3. **GENERAL**: The query can be answered based on your existing knowledge without external resources.

    Respond STRICTLY with only one of these words: RAG, WEBSEARCH, or GENERAL. Do not include any punctuation, explanation, or extra text.
    """
    )
|
| 31 |
+
|
| 32 |
def get_conversation_summary_prompt(messages):
|
| 33 |
"""Generate a prompt for conversation summarization."""
|
| 34 |
summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
|
agent/state.py
CHANGED
|
@@ -2,13 +2,23 @@ from typing import TypedDict, Annotated, Sequence, Optional, List
|
|
| 2 |
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
|
| 3 |
from langgraph.graph.message import add_messages
|
| 4 |
from pydantic import BaseModel, Field
|
|
|
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
class AgentState(TypedDict):
|
| 8 |
messages: Annotated[Sequence[AnyMessage], add_messages]
|
|
|
|
|
|
|
|
|
|
| 9 |
questionIsClear: bool
|
| 10 |
conversation_summary: str = ""
|
| 11 |
-
|
|
|
|
|
|
|
| 12 |
class QueryAnalysis(BaseModel):
|
| 13 |
"""Structured output for query analysis"""
|
| 14 |
is_clear: bool = Field(description="Indicates if the user's question is clear and answerable")
|
|
|
|
| 2 |
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
|
| 3 |
from langgraph.graph.message import add_messages
|
| 4 |
from pydantic import BaseModel, Field
|
| 5 |
+
from enum import Enum
|
| 6 |
|
| 7 |
+
class RAG_method(str, Enum):
    """Closed set of retrieval strategies the router can choose from."""
    RAG = "RAG"              # answer from the local vector store
    WEBSEARCH = "WEBSEARCH"  # answer via web-search tools
    GENERAL = "GENERAL"      # answer from the model's own knowledge
|
| 11 |
|
| 12 |
class AgentState(TypedDict):
    """Shared LangGraph state for the RAG agent."""
    # Running chat history, merged by the add_messages reducer.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # Router's classification of the current query (RAG/WEBSEARCH/GENERAL).
    rag_method: RAG_method
    # Retrieved document text for the generator, if any.
    context: Optional[str]

    questionIsClear: bool
    # TypedDict fields cannot declare defaults (PEP 589); the original
    # ``conversation_summary: str = ""`` assignment was invalid — producers
    # must supply the value explicitly.
    conversation_summary: str
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
class QueryAnalysis(BaseModel):
|
| 23 |
"""Structured output for query analysis"""
|
| 24 |
is_clear: bool = Field(description="Indicates if the user's question is clear and answerable")
|
agent/tools.py
CHANGED
|
@@ -23,30 +23,28 @@ def intialize_chroma_vectorstore():
|
|
| 23 |
)
|
| 24 |
return vectorstore
|
| 25 |
|
| 26 |
-
|
| 27 |
@tool
|
| 28 |
-
def
|
| 29 |
-
"""Search for
|
| 30 |
-
|
| 31 |
Args:
|
| 32 |
-
query:
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
"""
|
| 35 |
try:
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
return [
|
| 39 |
{
|
| 40 |
-
"
|
| 41 |
-
"
|
| 42 |
-
"
|
| 43 |
}
|
| 44 |
-
for doc in
|
| 45 |
]
|
| 46 |
-
|
| 47 |
except Exception as e:
|
| 48 |
-
|
| 49 |
-
return []
|
| 50 |
|
| 51 |
@tool
|
| 52 |
def wikipedia_search(query: str) -> dict:
|
|
|
|
| 23 |
)
|
| 24 |
return vectorstore
|
| 25 |
|
|
|
|
| 26 |
@tool
def web_search_tavily(query: str) -> dict:
    """Search Tavily for a query and return up to 3 results.

    Args:
        query: The search query.
    Returns:
        dict with key 'web_results', containing a list of search results with
        'title', 'url', and 'content' keys; on failure the value is an error
        string instead of a list.
    """
    try:
        search_docs = TavilySearchResults(max_results=3).invoke(input=query)
        results = [
            {
                "title": doc.get("title", ""),
                "url": doc.get("url", ""),
                "content": doc.get("content", ""),
            }
            for doc in search_docs
        ]
        return {"web_results": results}
    except Exception as e:
        # Best-effort tool: surface the failure to the agent as text.
        return {"web_results": f"Error retrieving results: {str(e)}"}
|
|
|
|
| 48 |
|
| 49 |
@tool
|
| 50 |
def wikipedia_search(query: str) -> dict:
|
config.py
CHANGED
|
@@ -2,7 +2,11 @@ import os
|
|
| 2 |
|
| 3 |
configs = {
|
| 4 |
"DATA_PATH": "./docs/markdowns",
|
| 5 |
-
"PERSIST_PATH": "./
|
| 6 |
"EMBEDDING_MODEL_NAME": "sentence-transformers/all-mpnet-base-v2",
|
| 7 |
"COLLECTION_NAME": "langchain_mpnet_collection"
|
| 8 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
# Central configuration for the RAG pipeline.
configs = {
    "DATA_PATH": "./docs/markdowns",  # source documents to index
    "PERSIST_PATH": "./chroma_data",  # ChromaDB persistence directory
    "EMBEDDING_MODEL_NAME": "sentence-transformers/all-mpnet-base-v2",
    "COLLECTION_NAME": "langchain_mpnet_collection"
}

if __name__ == "__main__":
    # NOTE(review): this export only runs when config.py is executed directly;
    # importing the module does NOT populate os.environ — confirm intent.
    for key, value in configs.items():
        os.environ[key] = value
|
core/chat_interface.py
CHANGED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from core.rag_agent import RAGAgent
|
| 3 |
+
from core.document_manager import DocumentManager
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
# Initialize components
doc_manager = DocumentManager()  # module-level singleton for document CRUD
rag_agent = None  # created lazily by initialize_agent() on first chat
|
| 9 |
+
|
| 10 |
+
def initialize_agent():
    """Return the shared RAGAgent, constructing it on first use (lazy init)."""
    global rag_agent
    agent = rag_agent
    if agent is None:
        agent = RAGAgent()
        rag_agent = agent
    return agent
|
| 16 |
+
|
| 17 |
+
def upload_files(files):
    """Add each uploaded file to the knowledge base; report per-file status."""
    if not files:
        return "No files selected", get_file_list()

    results = []
    for file in files:
        try:
            results.append(doc_manager.add_document(file.name))
        except Exception as e:
            results.append(f"Error processing {os.path.basename(file.name)}: {str(e)}")

    return "\n".join(results), get_file_list()
|
| 31 |
+
|
| 32 |
+
def get_file_list():
    """Render the knowledge-base document names as a bulleted string."""
    try:
        files = doc_manager.list_documents()
        if files:
            return "\n".join(f"• {f}" for f in files)
        return "No documents in knowledge base"
    except Exception as e:
        return f"Error listing files: {str(e)}"
|
| 41 |
+
|
| 42 |
+
def clear_database():
    """Remove every document from the knowledge base; report the outcome."""
    try:
        status = doc_manager.clear_all()
    except Exception as e:
        status = f"Error clearing database: {str(e)}"
    return status, get_file_list()
|
| 49 |
+
|
| 50 |
+
def chat_with_agent(message, history):
    """Handle chat interactions with the RAG agent.

    Args:
        message: The user's new message.
        history: Prior chat turns (managed by gr.ChatInterface; unused here —
            the agent keeps its own thread state via the checkpointer).

    Returns:
        The agent's reply as a string. gr.ChatInterface expects a string
        response; the original returned ``history`` (a list) on blank input.
    """
    if not message.strip():
        return ""

    try:
        agent = initialize_agent()

        # Stream the agent's response; keep the content of the latest message.
        response_text = ""
        for event in agent.agent_graph.stream(
            {"messages": [("user", message)]},
            agent.get_config(),
            stream_mode="values"
        ):
            if "messages" in event and len(event["messages"]) > 0:
                last_message = event["messages"][-1]
                if hasattr(last_message, "content"):
                    response_text = last_message.content

        if not response_text:
            response_text = "I apologize, but I couldn't generate a response. Please try again."

        return response_text

    except Exception as e:
        return f"Error: {str(e)}"
|
| 77 |
+
|
| 78 |
+
def reset_conversation():
    """Start a fresh conversation thread on the agent (if one exists)."""
    global rag_agent
    if rag_agent is not None:
        rag_agent.reset_thread()
    return None  # Clear chat history
|
| 84 |
+
|
| 85 |
+
def create_gradio_ui():
    """Create the complete Gradio interface (Documents + Chat tabs)."""

    with gr.Blocks(title="RAG Agent with Agentic Memory", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🤖 RAG Agent with Agentic Memory

        Upload documents and chat with an intelligent agent that uses:
        - 📚 **Local Knowledge Base** (ChromaDB)
        - 🔍 **Web Search** (Tavily)
        - 📖 **Wikipedia**
        - 📄 **ArXiv** (Academic Papers)
        """)

        with gr.Tabs():
            # Documents Tab
            with gr.Tab("📄 Documents"):
                gr.Markdown("### Upload and Manage Documents")
                gr.Markdown("Upload PDF or Markdown files to add them to the knowledge base.")

                with gr.Row():
                    with gr.Column(scale=2):
                        file_upload = gr.File(
                            label="Upload Documents",
                            file_count="multiple",
                            file_types=[".pdf", ".md"]
                        )
                        upload_btn = gr.Button("📤 Add to Knowledge Base", variant="primary")
                        upload_status = gr.Textbox(label="Upload Status", lines=3)

                    with gr.Column(scale=1):
                        file_list = gr.Textbox(
                            label="Documents in Knowledge Base",
                            lines=10,
                            value=get_file_list()
                        )
                        refresh_btn = gr.Button("🔄 Refresh List")
                        clear_btn = gr.Button("🗑️ Clear All Documents", variant="stop")

                # Connect document management buttons
                upload_btn.click(
                    fn=upload_files,
                    inputs=[file_upload],
                    outputs=[upload_status, file_list]
                )

                refresh_btn.click(
                    fn=get_file_list,
                    outputs=[file_list]
                )

                clear_btn.click(
                    fn=clear_database,
                    outputs=[upload_status, file_list]
                )

            # Chat Tab
            with gr.Tab("💬 Chat"):
                gr.Markdown("### Chat with Your Documents")
                gr.Markdown("Ask questions about your documents or any topic. The agent will search multiple sources.")

                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_label=True,
                    avatar_images=(None, "🤖")
                )

                with gr.Row():
                    msg = gr.Textbox(
                        label="Your Message",
                        placeholder="Ask me anything about your documents or general knowledge...",
                        scale=4
                    )
                    submit_btn = gr.Button("Send", variant="primary", scale=1)

                with gr.Row():
                    clear_chat_btn = gr.Button("🔄 Reset Conversation")
                    gr.Markdown("*Note: Resetting clears the conversation history*")

                # Chat interface
                # NOTE(review): retry_btn/undo_btn/clear_btn were removed from
                # gr.ChatInterface in Gradio 5.x — confirm the installed gradio
                # version still accepts these keyword arguments.
                chat_interface = gr.ChatInterface(
                    fn=chat_with_agent,
                    chatbot=chatbot,
                    textbox=msg,
                    submit_btn=submit_btn,
                    retry_btn=None,
                    undo_btn=None,
                    clear_btn=None
                )

                clear_chat_btn.click(
                    fn=reset_conversation,
                    outputs=[chatbot]
                )

        gr.Markdown("""
        ---
        ### 🧠 How it works:
        1. **Upload documents** in the Documents tab
        2. **Ask questions** in the Chat tab
        3. The agent will:
           - Analyze your query
           - Search relevant sources
           - Provide comprehensive answers with citations
        """)

    return demo


if __name__ == "__main__":
    demo = create_gradio_ui()
    demo.launch(share=False, server_name="127.0.0.1", server_port=7860)
|
core/rag_agent.py
CHANGED
|
@@ -5,15 +5,18 @@ from agent.tools import *
|
|
| 5 |
from agent.graph import create_agent_graph
|
| 6 |
|
| 7 |
class RAGAgent:
|
| 8 |
-
|
| 9 |
-
def __init__(self, collection_name=config.CHILD_COLLECTION):
|
| 10 |
-
self.collection_name = collection_name
|
| 11 |
-
self.retriever = intialize_chroma_vectorstore()
|
| 12 |
self.thread_id = str(uuid.uuid4())
|
| 13 |
|
| 14 |
-
self.llm = ChatGoogleGenerativeAI(
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
def get_config(self):
|
| 19 |
return {"configurable": {"thread_id": self.thread_id}}
|
|
|
|
| 5 |
from agent.graph import create_agent_graph
|
| 6 |
|
| 7 |
class RAGAgent:
|
| 8 |
+
    def __init__(self):
        """Build the agent: LLM, vector store, search tools, compiled graph."""
        # Fresh thread id per agent instance -> isolated checkpointer memory.
        self.thread_id = str(uuid.uuid4())

        self.llm = ChatGoogleGenerativeAI(
            model=config.LLM_MODEL,
            temperature=config.LLM_TEMPERATURE
        )

        # NOTE: helper name is misspelled ("intialize") in tools.py; kept as-is.
        vectordb = intialize_chroma_vectorstore()

        search_tools = [web_search_tavily, arxiv_search, wikipedia_search]
        self.agent_graph = create_agent_graph(self.llm, vectordb, search_tools)
|
| 20 |
|
| 21 |
def get_config(self):
|
| 22 |
return {"configurable": {"thread_id": self.thread_id}}
|
main.py
CHANGED
|
@@ -1,6 +1,11 @@
|
|
| 1 |
-
|
| 2 |
-
print("Hello from rag-agent!")
|
| 3 |
-
|
| 4 |
|
| 5 |
if __name__ == "__main__":
|
| 6 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ui.gradio_components import create_gradio_ui
|
|
|
|
|
|
|
| 2 |
|
| 3 |
if __name__ == "__main__":
    # Entry point: build the Gradio UI and serve it locally.
    print("🚀 Launching RAG Agent UI...")
    demo = create_gradio_ui()
    demo.launch(
        share=False,              # local only; set True for a public link
        server_name="127.0.0.1",  # bind to localhost
        server_port=7860,         # Gradio's default port
        show_error=True           # surface tracebacks in the UI
    )
|
ui/gradio_components.py
CHANGED
|
@@ -1,91 +1,116 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from core.
|
| 3 |
-
from core.document_manager import DocumentManager
|
| 4 |
-
from core.rag_system import RAGSystem
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
return None, format_file_list()
|
| 22 |
-
|
| 23 |
-
added, skipped = doc_manager.add_documents(
|
| 24 |
-
files,
|
| 25 |
-
progress_callback=lambda p, desc: progress(p, desc=desc)
|
| 26 |
-
)
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
with gr.Blocks(title="Agentic
|
|
|
|
|
|
|
| 43 |
|
| 44 |
-
with
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
show_label=False
|
| 54 |
-
)
|
| 55 |
-
|
| 56 |
-
add_btn = gr.Button("Add Documents", variant="primary", size="md")
|
| 57 |
-
|
| 58 |
-
gr.Markdown("## Current Documents in the Knowledge Base")
|
| 59 |
-
file_list = gr.Textbox(
|
| 60 |
-
value=format_file_list(),
|
| 61 |
-
interactive=False,
|
| 62 |
-
lines = 7,
|
| 63 |
-
max_lines=10,
|
| 64 |
-
elem_id="file-list-box",
|
| 65 |
-
show_label=False
|
| 66 |
-
)
|
| 67 |
-
|
| 68 |
-
with gr.Row():
|
| 69 |
-
refresh_btn = gr.Button("Refresh", size="md")
|
| 70 |
-
clear_btn = gr.Button("Clear All", variant="stop", size="md")
|
| 71 |
-
|
| 72 |
-
add_btn.click(
|
| 73 |
-
upload_handler,
|
| 74 |
-
[files_input],
|
| 75 |
-
[files_input, file_list],
|
| 76 |
-
show_progress="corner"
|
| 77 |
-
)
|
| 78 |
-
refresh_btn.click(format_file_list, None, file_list)
|
| 79 |
-
clear_btn.click(clear_handler, None, file_list)
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
)
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
-
return demo
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from core.rag_agent import RAGAgent
|
|
|
|
|
|
|
| 3 |
|
| 4 |
+
# Initialize components
|
| 5 |
+
rag_agent = None
|
| 6 |
+
|
| 7 |
+
def initialize_agent():
|
| 8 |
+
"""Initialize RAG agent lazily"""
|
| 9 |
+
global rag_agent
|
| 10 |
+
if rag_agent is None:
|
| 11 |
+
rag_agent = RAGAgent()
|
| 12 |
+
return rag_agent
|
| 13 |
+
|
| 14 |
+
def chat_with_agent(message, history):
|
| 15 |
+
"""Handle chat interactions with the RAG agent"""
|
| 16 |
+
if not message.strip():
|
| 17 |
+
return history
|
| 18 |
|
| 19 |
+
try:
|
| 20 |
+
agent = initialize_agent()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
# Stream the agent's response
|
| 23 |
+
response_text = ""
|
| 24 |
+
for event in agent.agent_graph.stream(
|
| 25 |
+
{"messages": [("user", message)]},
|
| 26 |
+
agent.get_config(),
|
| 27 |
+
stream_mode="values"
|
| 28 |
+
):
|
| 29 |
+
if "messages" in event and len(event["messages"]) > 0:
|
| 30 |
+
last_message = event["messages"][-1]
|
| 31 |
+
if hasattr(last_message, "content"):
|
| 32 |
+
response_text = last_message.content
|
| 33 |
+
|
| 34 |
+
if not response_text:
|
| 35 |
+
response_text = "I apologize, but I couldn't generate a response. Please try again."
|
| 36 |
+
|
| 37 |
+
return response_text
|
| 38 |
|
| 39 |
+
except Exception as e:
|
| 40 |
+
return f"Error: {str(e)}"
|
| 41 |
+
|
| 42 |
+
def reset_conversation():
|
| 43 |
+
"""Reset the conversation thread"""
|
| 44 |
+
global rag_agent
|
| 45 |
+
if rag_agent:
|
| 46 |
+
rag_agent.reset_thread()
|
| 47 |
+
return None # Clear chat history
|
| 48 |
+
|
| 49 |
+
def create_gradio_ui():
|
| 50 |
+
"""Create the complete Gradio interface"""
|
| 51 |
|
| 52 |
+
with gr.Blocks(title="RAG Agent with Agentic Memory", theme=gr.themes.Soft()) as demo:
|
| 53 |
+
gr.Markdown("""
|
| 54 |
+
# π€ RAG Agent with Agentic Memory
|
| 55 |
|
| 56 |
+
Chat with an intelligent agent that uses:
|
| 57 |
+
- π **Local Knowledge Base** (ChromaDB)
|
| 58 |
+
- π **Web Search** (Tavily)
|
| 59 |
+
- π **Wikipedia**
|
| 60 |
+
- π **ArXiv** (Academic Papers)
|
| 61 |
+
""")
|
| 62 |
+
|
| 63 |
+
gr.Markdown("### Chat with Your Documents")
|
| 64 |
+
gr.Markdown("Ask questions about your documents or any topic. The agent will search multiple sources.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
+
chatbot = gr.Chatbot(
|
| 67 |
+
label="Conversation",
|
| 68 |
+
height=500,
|
| 69 |
+
show_label=True,
|
| 70 |
+
avatar_images=(None, "π€")
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
with gr.Row():
|
| 74 |
+
msg = gr.Textbox(
|
| 75 |
+
label="Your Message",
|
| 76 |
+
placeholder="Ask me anything about your documents or general knowledge...",
|
| 77 |
+
scale=4
|
| 78 |
)
|
| 79 |
+
submit_btn = gr.Button("Send", variant="primary", scale=1)
|
| 80 |
+
|
| 81 |
+
with gr.Row():
|
| 82 |
+
clear_chat_btn = gr.Button("π Reset Conversation")
|
| 83 |
+
gr.Markdown("*Note: Resetting clears the conversation history*")
|
| 84 |
+
|
| 85 |
+
# Chat interface
|
| 86 |
+
chat_interface = gr.ChatInterface(
|
| 87 |
+
fn=chat_with_agent,
|
| 88 |
+
chatbot=chatbot,
|
| 89 |
+
textbox=msg,
|
| 90 |
+
submit_btn=submit_btn,
|
| 91 |
+
retry_btn=None,
|
| 92 |
+
undo_btn=None,
|
| 93 |
+
clear_btn=None
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
clear_chat_btn.click(
|
| 97 |
+
fn=reset_conversation,
|
| 98 |
+
outputs=[chatbot]
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
gr.Markdown("""
|
| 102 |
+
---
|
| 103 |
+
### π§ How it works:
|
| 104 |
+
1. **Ask questions** in the chat
|
| 105 |
+
2. The agent will:
|
| 106 |
+
- Analyze your query
|
| 107 |
+
- Search relevant sources (ChromaDB, Web, Wikipedia, ArXiv)
|
| 108 |
+
- Provide comprehensive answers with citations
|
| 109 |
+
3. Use **Reset Conversation** to start fresh
|
| 110 |
+
""")
|
| 111 |
|
| 112 |
+
return demo
|
| 113 |
+
|
| 114 |
+
if __name__ == "__main__":
|
| 115 |
+
demo = create_gradio_ui()
|
| 116 |
+
demo.launch(share=False, server_name="127.0.0.1", server_port=7860)
|