Cheh Kit Hong
commited on
Commit
Β·
aa018e3
1
Parent(s):
0fc97a4
fixing gradio
Browse files- README.md +2 -4
- agent/graph.py +14 -5
- agent/more_nodes.py +0 -97
- agent/nodes.py +32 -5
- agent/prompts.py +19 -75
- agent/state.py +1 -1
- config.py +4 -1
- core/rag_agent.py +8 -3
- requirements.txt +2 -1
- test_scripts.py +349 -0
- ui/gradio_components.py +131 -57
README.md
CHANGED
|
@@ -2,14 +2,12 @@ rag_agent/
|
|
| 2 |
βββ app.py # Main Gradio application entry point
|
| 3 |
βββ config.py # Configuration hub (models, chunk sizes, providers)
|
| 4 |
βββ util.py # PDF to markdown conversion
|
| 5 |
-
βββ document_chunker.py # Chunking strategy
|
| 6 |
βββ core/ # Core RAG components orchestration
|
| 7 |
β βββ chat_interface.py
|
| 8 |
β βββ document_manager.py
|
| 9 |
β βββ rag_system.py
|
| 10 |
-
βββ knowledge_base/
|
| 11 |
-
|
| 12 |
-
β βββ vector_db_manager.py
|
| 13 |
βββ agent_logic/ # LangGraph agent workflow
|
| 14 |
β βββ edges.py # Conditional routing logic
|
| 15 |
β βββ graph.py # Graph construction and compilation
|
|
|
|
| 2 |
βββ app.py # Main Gradio application entry point
|
| 3 |
βββ config.py # Configuration hub (models, chunk sizes, providers)
|
| 4 |
βββ util.py # PDF to markdown conversion
|
|
|
|
| 5 |
βββ core/ # Core RAG components orchestration
|
| 6 |
β βββ chat_interface.py
|
| 7 |
β βββ document_manager.py
|
| 8 |
β βββ rag_system.py
|
| 9 |
+
βββ knowledge_base/ # for create chromadb
|
| 10 |
+
βββ chroma_data/ # chroma vectorstore data
|
|
|
|
| 11 |
βββ agent_logic/ # LangGraph agent workflow
|
| 12 |
β βββ edges.py # Conditional routing logic
|
| 13 |
β βββ graph.py # Graph construction and compilation
|
agent/graph.py
CHANGED
|
@@ -13,12 +13,14 @@ def create_agent_graph(llm, vectordb, search_tools) -> StateGraph:
|
|
| 13 |
graph = StateGraph(AgentState)
|
| 14 |
checkpointer = MemorySaver()
|
| 15 |
|
|
|
|
| 16 |
web_search_tool_node = ToolNode(search_tools)
|
| 17 |
|
| 18 |
# --- Nodes ---
|
| 19 |
graph.add_node("router_node", partial(router_node, llm=llm))
|
| 20 |
graph.add_node("vectordb_node", partial(vectordb_node, vectorstore=vectordb))
|
| 21 |
-
graph.add_node("
|
|
|
|
| 22 |
graph.add_node("generate_node", partial(generate_node, llm=llm))
|
| 23 |
|
| 24 |
# --- Edges ---
|
|
@@ -28,16 +30,23 @@ def create_agent_graph(llm, vectordb, search_tools) -> StateGraph:
|
|
| 28 |
"router_node",
|
| 29 |
routing_logic,
|
| 30 |
{
|
| 31 |
-
# Output from routing_logic -> Target Node Name
|
| 32 |
"vectordb_node": "vectordb_node",
|
| 33 |
-
"
|
| 34 |
"generate_node": "generate_node",
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
}
|
| 37 |
)
|
| 38 |
|
| 39 |
graph.add_edge("vectordb_node", "generate_node")
|
| 40 |
-
graph.add_edge("
|
| 41 |
|
| 42 |
graph.add_edge("generate_node", END)
|
| 43 |
|
|
|
|
| 13 |
graph = StateGraph(AgentState)
|
| 14 |
checkpointer = MemorySaver()
|
| 15 |
|
| 16 |
+
llm_with_tools = llm.bind_tools(search_tools)
|
| 17 |
web_search_tool_node = ToolNode(search_tools)
|
| 18 |
|
| 19 |
# --- Nodes ---
|
| 20 |
graph.add_node("router_node", partial(router_node, llm=llm))
|
| 21 |
graph.add_node("vectordb_node", partial(vectordb_node, vectorstore=vectordb))
|
| 22 |
+
graph.add_node("web_search_agent_node", partial(web_search_agent_node, llm=llm_with_tools))
|
| 23 |
+
graph.add_node("web_search_tool_node", web_search_tool_node)
|
| 24 |
graph.add_node("generate_node", partial(generate_node, llm=llm))
|
| 25 |
|
| 26 |
# --- Edges ---
|
|
|
|
| 30 |
"router_node",
|
| 31 |
routing_logic,
|
| 32 |
{
|
|
|
|
| 33 |
"vectordb_node": "vectordb_node",
|
| 34 |
+
"web_search_agent_node": "web_search_agent_node",
|
| 35 |
"generate_node": "generate_node",
|
| 36 |
+
}
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
graph.add_conditional_edges(
|
| 40 |
+
"web_search_agent_node",
|
| 41 |
+
tools_condition,
|
| 42 |
+
{
|
| 43 |
+
"tools": "web_search_tool_node", # Changed key from node name to "tools"
|
| 44 |
+
"__end__": "generate_node", # Changed key from "generate_node" to "__end__"
|
| 45 |
}
|
| 46 |
)
|
| 47 |
|
| 48 |
graph.add_edge("vectordb_node", "generate_node")
|
| 49 |
+
graph.add_edge("web_search_tool_node", "generate_node")
|
| 50 |
|
| 51 |
graph.add_edge("generate_node", END)
|
| 52 |
|
agent/more_nodes.py
DELETED
|
@@ -1,97 +0,0 @@
|
|
| 1 |
-
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
|
| 2 |
-
from typing import Literal
|
| 3 |
-
|
| 4 |
-
from .state import AgentState, QueryAnalysis
|
| 5 |
-
from .prompts import *
|
| 6 |
-
|
| 7 |
-
def analyze_chat_and_summarize(state: AgentState, llm):
|
| 8 |
-
"""
|
| 9 |
-
Analyzes chat history and summarizes key points for context.
|
| 10 |
-
"""
|
| 11 |
-
if len(state["messages"]) < 4: # Need some history to summarize
|
| 12 |
-
return {"conversation_summary": ""}
|
| 13 |
-
|
| 14 |
-
# Extract relevant messages (excluding current query and system messages)
|
| 15 |
-
relevant_msgs = [
|
| 16 |
-
msg for msg in state["messages"][:-1] # Exclude current query
|
| 17 |
-
if isinstance(msg, (HumanMessage, AIMessage))
|
| 18 |
-
and not getattr(msg, "tool_calls", None)
|
| 19 |
-
]
|
| 20 |
-
|
| 21 |
-
if not relevant_msgs:
|
| 22 |
-
return {"conversation_summary": ""}
|
| 23 |
-
|
| 24 |
-
summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
|
| 25 |
-
Discard irrelevant information, such as misunderstandings or off-topic queries/responses.
|
| 26 |
-
If there are no key topics, return an empty string.
|
| 27 |
-
|
| 28 |
-
"""
|
| 29 |
-
for msg in relevant_msgs[-6:]: # Last 6 messages for context
|
| 30 |
-
role = "User" if isinstance(msg, HumanMessage) else "Assistant"
|
| 31 |
-
summary_prompt += f"{role}: {msg.content}\n"
|
| 32 |
-
|
| 33 |
-
summary_prompt += "\nBrief Summary:"
|
| 34 |
-
summary_response = llm.with_config(temperature=0.3).invoke([SystemMessage(content=summary_prompt)])
|
| 35 |
-
return {"conversation_summary": summary_response.content}
|
| 36 |
-
|
| 37 |
-
def analyze_and_rewrite_query(state: AgentState, llm):
|
| 38 |
-
"""
|
| 39 |
-
Analyzes user query and rewrites it for clarity, optionally using conversation context.
|
| 40 |
-
"""
|
| 41 |
-
last_message = state["messages"][-1]
|
| 42 |
-
conversation_summary = state.get("conversation_summary", "")
|
| 43 |
-
|
| 44 |
-
context_section = (
|
| 45 |
-
f"**Conversation Context:**\n{conversation_summary}"
|
| 46 |
-
if conversation_summary.strip()
|
| 47 |
-
else "**Conversation Context:**\n[First query in conversation]"
|
| 48 |
-
)
|
| 49 |
-
|
| 50 |
-
# Create analysis prompt
|
| 51 |
-
query_analysis_prompt = get_query_analysis_prompt(last_message.content, conversation_summary)
|
| 52 |
-
|
| 53 |
-
llm_with_structure = llm.with_config(temperature=0.3).with_structured_output(QueryAnalysis)
|
| 54 |
-
response = llm_with_structure.invoke([SystemMessage(content=query_analysis_prompt)])
|
| 55 |
-
|
| 56 |
-
if response.is_clear:
|
| 57 |
-
# Remove all non-system messages
|
| 58 |
-
delete_all = [
|
| 59 |
-
RemoveMessage(id=m.id)
|
| 60 |
-
for m in state["messages"]
|
| 61 |
-
if not isinstance(m, SystemMessage)
|
| 62 |
-
]
|
| 63 |
-
|
| 64 |
-
# Format rewritten query
|
| 65 |
-
rewritten = (
|
| 66 |
-
"\n".join([f"{i+1}. {q}" for i, q in enumerate(response.questions)])
|
| 67 |
-
if len(response.questions) > 1
|
| 68 |
-
else response.questions[0]
|
| 69 |
-
)
|
| 70 |
-
return {
|
| 71 |
-
"questionIsClear": True,
|
| 72 |
-
"messages": delete_all + [HumanMessage(content=rewritten)]
|
| 73 |
-
}
|
| 74 |
-
else:
|
| 75 |
-
clarification = response.clarification_needed or "I need more information to understand your question."
|
| 76 |
-
return {
|
| 77 |
-
"questionIsClear": False,
|
| 78 |
-
"messages": [AIMessage(content=clarification)]
|
| 79 |
-
}
|
| 80 |
-
|
| 81 |
-
def human_input_node(state: AgentState):
|
| 82 |
-
"""Placeholder node for human-in-the-loop interruption"""
|
| 83 |
-
return {}
|
| 84 |
-
|
| 85 |
-
def route_after_rewrite(state: AgentState) -> Literal["agent", "human_input"]:
|
| 86 |
-
"""Route to agent if question is clear, otherwise wait for human input"""
|
| 87 |
-
return "agent" if state.get("questionIsClear", False) else "human_input"
|
| 88 |
-
|
| 89 |
-
def agent_node(state: AgentState, llm_with_tools):
|
| 90 |
-
"""Main agent node that processes queries using tools"""
|
| 91 |
-
system_prompt = get_system_prompt()
|
| 92 |
-
messages = [system_prompt] + state["messages"]
|
| 93 |
-
response = llm_with_tools.invoke(messages)
|
| 94 |
-
return {"messages": [response]}
|
| 95 |
-
|
| 96 |
-
if __name__ == "__main__":
|
| 97 |
-
pass
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
agent/nodes.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
|
| 2 |
from typing import Literal
|
| 3 |
from langgraph.graph import START, END
|
| 4 |
|
|
@@ -13,17 +13,17 @@ def router_node(state: AgentState, llm):
|
|
| 13 |
"""
|
| 14 |
query = state["messages"][-1].content
|
| 15 |
rag_method_prompt = determine_rag_method_prompt()
|
| 16 |
-
rag_method_result = llm.invoke([rag_method_prompt, HumanMessage(content=query)])
|
| 17 |
rag_method = rag_method_result.content.strip().upper()
|
| 18 |
state["rag_method"] = rag_method
|
| 19 |
return state
|
| 20 |
|
| 21 |
-
def routing_logic(
|
| 22 |
rag_method = state["rag_method"]
|
| 23 |
if rag_method == "RAG":
|
| 24 |
return "vectordb_node"
|
| 25 |
elif rag_method == "WEBSEARCH":
|
| 26 |
-
return "
|
| 27 |
elif rag_method == "GENERAL":
|
| 28 |
return "generate_node" # fallback to generate_node if the question do not requires RAG or websearch
|
| 29 |
else:
|
|
@@ -31,7 +31,7 @@ def routing_logic(self, state: AgentState) -> str:
|
|
| 31 |
print(f"ERROR: Router returned unclassified intent: {rag_method}. Terminating flow.")
|
| 32 |
return END
|
| 33 |
|
| 34 |
-
def vectordb_node(state: AgentState,
|
| 35 |
"""
|
| 36 |
Use vectordb to answer the query.
|
| 37 |
"""
|
|
@@ -43,12 +43,39 @@ def vectordb_node(state: AgentState, llm, vectorstore):
|
|
| 43 |
state["context"] = context
|
| 44 |
return state
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
def generate_node(state: AgentState, llm):
|
| 47 |
messages = state["messages"][-10:] # Limit to last 10 messages to handle token limit
|
| 48 |
context = state.get("context", [])
|
| 49 |
|
| 50 |
system_content = get_system_prompt()
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
if context:
|
| 53 |
system_content += f"\n\nRelevant Context:\n{context}"
|
| 54 |
|
|
|
|
| 1 |
+
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, ToolMessage, RemoveMessage
|
| 2 |
from typing import Literal
|
| 3 |
from langgraph.graph import START, END
|
| 4 |
|
|
|
|
| 13 |
"""
|
| 14 |
query = state["messages"][-1].content
|
| 15 |
rag_method_prompt = determine_rag_method_prompt()
|
| 16 |
+
rag_method_result = llm.invoke([SystemMessage(content=rag_method_prompt), HumanMessage(content=query)])
|
| 17 |
rag_method = rag_method_result.content.strip().upper()
|
| 18 |
state["rag_method"] = rag_method
|
| 19 |
return state
|
| 20 |
|
| 21 |
+
def routing_logic(state: AgentState) -> str:
|
| 22 |
rag_method = state["rag_method"]
|
| 23 |
if rag_method == "RAG":
|
| 24 |
return "vectordb_node"
|
| 25 |
elif rag_method == "WEBSEARCH":
|
| 26 |
+
return "web_search_agent_node"
|
| 27 |
elif rag_method == "GENERAL":
|
| 28 |
return "generate_node" # fallback to generate_node if the question do not requires RAG or websearch
|
| 29 |
else:
|
|
|
|
| 31 |
print(f"ERROR: Router returned unclassified intent: {rag_method}. Terminating flow.")
|
| 32 |
return END
|
| 33 |
|
| 34 |
+
def vectordb_node(state: AgentState, vectorstore):
|
| 35 |
"""
|
| 36 |
Use vectordb to answer the query.
|
| 37 |
"""
|
|
|
|
| 43 |
state["context"] = context
|
| 44 |
return state
|
| 45 |
|
| 46 |
+
def web_search_agent_node(state: AgentState, llm):
|
| 47 |
+
"""
|
| 48 |
+
LLM agent that decides which web search tools to call.
|
| 49 |
+
This generates an AIMessage with tool_calls.
|
| 50 |
+
"""
|
| 51 |
+
messages = state["messages"]
|
| 52 |
+
|
| 53 |
+
# Add instruction to use tools
|
| 54 |
+
system_msg = SystemMessage(content="""You are a web search assistant.
|
| 55 |
+
Use the available search tools (web_search_tavily, wikipedia_search) to find information about the user's query.
|
| 56 |
+
Call the appropriate tool with the query.""")
|
| 57 |
+
|
| 58 |
+
messages_with_system = [system_msg] + messages
|
| 59 |
+
|
| 60 |
+
# LLM with tools bound will generate AIMessage with tool_calls
|
| 61 |
+
response = llm.invoke(messages_with_system)
|
| 62 |
+
|
| 63 |
+
return {"messages": [response]}
|
| 64 |
+
|
| 65 |
def generate_node(state: AgentState, llm):
|
| 66 |
messages = state["messages"][-10:] # Limit to last 10 messages to handle token limit
|
| 67 |
context = state.get("context", [])
|
| 68 |
|
| 69 |
system_content = get_system_prompt()
|
| 70 |
|
| 71 |
+
# Extract web search results from ToolMessages if available
|
| 72 |
+
if not context:
|
| 73 |
+
for msg in reversed(messages):
|
| 74 |
+
if isinstance(msg, ToolMessage):
|
| 75 |
+
# Web search results come as ToolMessage content
|
| 76 |
+
if msg.content:
|
| 77 |
+
context += f"\n\n{msg.content}"
|
| 78 |
+
|
| 79 |
if context:
|
| 80 |
system_content += f"\n\nRelevant Context:\n{context}"
|
| 81 |
|
agent/prompts.py
CHANGED
|
@@ -2,87 +2,31 @@ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
|
|
| 2 |
|
| 3 |
def get_system_prompt() -> SystemMessage:
|
| 4 |
"""Generate the system prompt for the RAG agent."""
|
| 5 |
-
return
|
| 6 |
-
You are
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
4. **Use metadata** such as `source` and `parent_id` to help clarify or support the answer when applicable.
|
| 13 |
-
5. **Answer using ONLY the retrieved information**:
|
| 14 |
-
- Combine relevant chunks and use metadata (e.g., citation sources) as needed to clarify or support the response.
|
| 15 |
-
6. **If no relevant information is found**, rewrite the query into an **answer-focused declarative statement** and search again **only once** using `search_chroma`.
|
| 16 |
-
7. **Return the final answer** derived from the most relevant results.
|
| 17 |
-
""")
|
| 18 |
|
| 19 |
def determine_rag_method_prompt() -> str:
|
| 20 |
-
return
|
| 21 |
-
You are
|
| 22 |
as one and only one of the following options:
|
| 23 |
|
| 24 |
-
1. **RAG**: The query
|
| 25 |
-
2. **WEBSEARCH**: The query
|
| 26 |
-
3. **GENERAL**: The query
|
| 27 |
-
|
| 28 |
-
Respond STRICTLY with only one of these words: RAG, WEBSEARCH, or GENERAL. Do not include any punctuation, explanation, or extra text.
|
| 29 |
-
"""
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
def get_conversation_summary_prompt(messages):
|
| 33 |
-
"""Generate a prompt for conversation summarization."""
|
| 34 |
-
summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
|
| 35 |
-
Discard irrelevant information, such as misunderstandings or off-topic queries/responses.
|
| 36 |
-
If there are no key topics, return an empty string.
|
| 37 |
-
|
| 38 |
-
"""
|
| 39 |
-
|
| 40 |
-
for msg in messages[-6:]:
|
| 41 |
-
role = "User" if isinstance(msg, HumanMessage) else "Assistant"
|
| 42 |
-
summary_prompt += f"{role}: {msg.content}\n"
|
| 43 |
-
|
| 44 |
-
summary_prompt += "\n**Brief Summary:**"
|
| 45 |
-
return summary_prompt
|
| 46 |
-
|
| 47 |
-
def get_query_analysis_prompt(query: str, conversation_summary: str = "") -> str:
|
| 48 |
-
"""Generate a prompt for query analysis and rewriting."""
|
| 49 |
-
context_section = (
|
| 50 |
-
f"**Conversation Context:**\n{conversation_summary}"
|
| 51 |
-
if conversation_summary.strip()
|
| 52 |
-
else "**Conversation Context:**\n[First query in conversation]"
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
return f"""
|
| 56 |
-
**Rewrite the user's query** to be clear, self-contained, and optimized for information retrieval.
|
| 57 |
-
|
| 58 |
-
**User Query:**
|
| 59 |
-
"{query}"
|
| 60 |
-
|
| 61 |
-
{context_section}
|
| 62 |
-
|
| 63 |
-
**Instructions:**
|
| 64 |
-
|
| 65 |
-
1. **Resolve references for follow-ups:**
|
| 66 |
-
- If the query uses pronouns or refers to previous topics, use the context to make it self-contained.
|
| 67 |
-
|
| 68 |
-
2. **Ensure clarity for new queries:**
|
| 69 |
-
- Make the query specific, concise, and unambiguous.
|
| 70 |
-
|
| 71 |
-
3. **Correct errors and interpret intent:**
|
| 72 |
-
- If the query is grammatically incorrect, contains typos, or has abbreviations, correct it and infer the intended meaning.
|
| 73 |
-
|
| 74 |
-
4. **Split only when necessary:**
|
| 75 |
-
- If multiple distinct questions exist, split into **up to 3 focused sub-queries** to avoid over-segmentation.
|
| 76 |
-
- Each sub-query must still be meaningful on its own.
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
|
|
|
| 82 |
|
| 83 |
-
|
| 84 |
-
- This includes nonsense, gibberish, insults, or statements without an apparent question.
|
| 85 |
"""
|
| 86 |
-
|
| 87 |
if __name__ == "__main__":
|
| 88 |
pass
|
|
|
|
| 2 |
|
| 3 |
def get_system_prompt() -> SystemMessage:
|
| 4 |
"""Generate the system prompt for the RAG agent."""
|
| 5 |
+
return """
|
| 6 |
+
You are a helpful assistant tasked with answering questions using a set of tools.
|
| 7 |
+
Follow the ReAct framework: iteratively reason through the problem step-by-step, use tools when necessary, and refine your approach based on tool outputs.
|
| 8 |
+
You will be provided with relevant context from the knowledge base if required. Use this context to inform your response, but feel free to supplement with your own knowledge when appropriate. Context will be provided in the state under 'context' key.
|
| 9 |
+
You will also have access to web search tools like Tavily, Wikipedia or Arxiv.
|
| 10 |
+
DO NOT make any assumptions.
|
| 11 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
def determine_rag_method_prompt() -> str:
|
| 14 |
+
return """
|
| 15 |
+
You are a query classification model. Given the user's query, you must classify the method to use
|
| 16 |
as one and only one of the following options:
|
| 17 |
|
| 18 |
+
1. **RAG**: The query asks about specific documents, papers, or systems like DeepAnalyze, AgentMem, SAM3, SAM 3, SAM3D, DeepSeek-OCR, or any technical architecture/implementation details from research papers.
|
| 19 |
+
2. **WEBSEARCH**: The query asks for current events, latest news, real-time information after January 2024, or general factual knowledge not in specialized documents.
|
| 20 |
+
3. **GENERAL**: The query is a simple calculation, definition, reasoning task, or common knowledge question that doesn't need external data.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
+
**Examples:**
|
| 23 |
+
- "What is DeepAnalyze?" β RAG
|
| 24 |
+
- "Explain SAM 3 architecture" β RAG
|
| 25 |
+
- "Latest AI news in 2025" β WEBSEARCH
|
| 26 |
+
- "What is 15 times 7?" β GENERAL
|
| 27 |
|
| 28 |
+
Respond STRICTLY with only one word: RAG, WEBSEARCH, or GENERAL. No punctuation or extra text.
|
|
|
|
| 29 |
"""
|
| 30 |
+
|
| 31 |
if __name__ == "__main__":
|
| 32 |
pass
|
agent/state.py
CHANGED
|
@@ -18,7 +18,7 @@ class AgentState(TypedDict):
|
|
| 18 |
conversation_summary: str = ""
|
| 19 |
|
| 20 |
|
| 21 |
-
|
| 22 |
class QueryAnalysis(BaseModel):
|
| 23 |
"""Structured output for query analysis"""
|
| 24 |
is_clear: bool = Field(description="Indicates if the user's question is clear and answerable")
|
|
|
|
| 18 |
conversation_summary: str = ""
|
| 19 |
|
| 20 |
|
| 21 |
+
# Implement later if needed, omit first
|
| 22 |
class QueryAnalysis(BaseModel):
|
| 23 |
"""Structured output for query analysis"""
|
| 24 |
is_clear: bool = Field(description="Indicates if the user's question is clear and answerable")
|
config.py
CHANGED
|
@@ -4,7 +4,10 @@ configs = {
|
|
| 4 |
"DATA_PATH": "./docs/markdowns",
|
| 5 |
"PERSIST_PATH": "./chroma_data",
|
| 6 |
"EMBEDDING_MODEL_NAME": "sentence-transformers/all-mpnet-base-v2",
|
| 7 |
-
"COLLECTION_NAME": "langchain_mpnet_collection"
|
|
|
|
|
|
|
|
|
|
| 8 |
}
|
| 9 |
|
| 10 |
if __name__ == "__main__":
|
|
|
|
| 4 |
"DATA_PATH": "./docs/markdowns",
|
| 5 |
"PERSIST_PATH": "./chroma_data",
|
| 6 |
"EMBEDDING_MODEL_NAME": "sentence-transformers/all-mpnet-base-v2",
|
| 7 |
+
"COLLECTION_NAME": "langchain_mpnet_collection",
|
| 8 |
+
"LLM_MODEL_NAME": "gemini-2.0-flash",
|
| 9 |
+
"TEMPERATURE": 0.2,
|
| 10 |
+
"MAX_TOKENS": 2048,
|
| 11 |
}
|
| 12 |
|
| 13 |
if __name__ == "__main__":
|
core/rag_agent.py
CHANGED
|
@@ -1,16 +1,21 @@
|
|
| 1 |
import uuid
|
| 2 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 3 |
-
import
|
| 4 |
from agent.tools import *
|
| 5 |
from agent.graph import create_agent_graph
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
class RAGAgent:
|
| 8 |
def __init__(self):
|
| 9 |
self.thread_id = str(uuid.uuid4())
|
| 10 |
|
| 11 |
self.llm = ChatGoogleGenerativeAI(
|
| 12 |
-
model=
|
| 13 |
-
temperature=
|
|
|
|
| 14 |
)
|
| 15 |
|
| 16 |
vectordb = intialize_chroma_vectorstore()
|
|
|
|
| 1 |
import uuid
|
| 2 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 3 |
+
from config import configs
|
| 4 |
from agent.tools import *
|
| 5 |
from agent.graph import create_agent_graph
|
| 6 |
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
load_dotenv()
|
| 10 |
+
|
| 11 |
class RAGAgent:
|
| 12 |
def __init__(self):
|
| 13 |
self.thread_id = str(uuid.uuid4())
|
| 14 |
|
| 15 |
self.llm = ChatGoogleGenerativeAI(
|
| 16 |
+
model=configs["LLM_MODEL_NAME"],
|
| 17 |
+
temperature=configs["TEMPERATURE"],
|
| 18 |
+
max_tokens=configs["MAX_TOKENS"]
|
| 19 |
)
|
| 20 |
|
| 21 |
vectordb = intialize_chroma_vectorstore()
|
requirements.txt
CHANGED
|
@@ -13,4 +13,5 @@ langchain-community
|
|
| 13 |
langchain_text_splitters
|
| 14 |
pymupdf-layout
|
| 15 |
sentence_transformers
|
| 16 |
-
gradio
|
|
|
|
|
|
| 13 |
langchain_text_splitters
|
| 14 |
pymupdf-layout
|
| 15 |
sentence_transformers
|
| 16 |
+
gradio
|
| 17 |
+
python-dotenv
|
test_scripts.py
ADDED
|
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Test script for RAG Agent logic.
|
| 3 |
+
Tests the agent workflow, nodes, state management, and retrieval.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import sys
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
# Add project root to path
|
| 10 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 11 |
+
|
| 12 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 13 |
+
from agent.state import AgentState
|
| 14 |
+
from core.rag_agent import RAGAgent
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def print_separator(title: str):
|
| 18 |
+
"""Print a visual separator."""
|
| 19 |
+
print("\n" + "="*70)
|
| 20 |
+
print(f" {title}")
|
| 21 |
+
print("="*70 + "\n")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def test_agent_initialization():
|
| 25 |
+
"""Test RAGAgent can be initialized properly."""
|
| 26 |
+
print_separator("TEST 1: Agent Initialization")
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
agent = RAGAgent()
|
| 30 |
+
print("β RAGAgent initialized successfully")
|
| 31 |
+
print(f" - Thread ID: {agent.thread_id}")
|
| 32 |
+
print(f" - LLM Model: {agent.llm.model_name if hasattr(agent.llm, 'model_name') else 'initialized'}")
|
| 33 |
+
print(f" - Graph: {type(agent.agent_graph).__name__}")
|
| 34 |
+
return agent
|
| 35 |
+
except Exception as e:
|
| 36 |
+
print(f"β Failed to initialize RAGAgent: {e}")
|
| 37 |
+
import traceback
|
| 38 |
+
traceback.print_exc()
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_simple_query(agent: RAGAgent):
|
| 43 |
+
"""Test a simple query execution."""
|
| 44 |
+
print_separator("TEST 2: Simple Query")
|
| 45 |
+
|
| 46 |
+
if agent is None:
|
| 47 |
+
print("β Skipping - agent not initialized")
|
| 48 |
+
return False
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
query = "What is DeepAnalyze?"
|
| 52 |
+
print(f"Query: '{query}'")
|
| 53 |
+
|
| 54 |
+
initial_state = {
|
| 55 |
+
"messages": [HumanMessage(content=query)],
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
result = agent.agent_graph.invoke(
|
| 59 |
+
initial_state,
|
| 60 |
+
config=agent.get_config()
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
messages = result.get("messages", [])
|
| 64 |
+
ai_messages = [m for m in messages if isinstance(m, AIMessage)]
|
| 65 |
+
|
| 66 |
+
if ai_messages:
|
| 67 |
+
print(f"β Query executed successfully")
|
| 68 |
+
print(f" Total messages: {len(messages)}")
|
| 69 |
+
print(f" Response length: {len(ai_messages[-1].content)} chars")
|
| 70 |
+
print(f"\n Response preview:")
|
| 71 |
+
print(f" {ai_messages[-1].content[:300]}...")
|
| 72 |
+
return True
|
| 73 |
+
else:
|
| 74 |
+
print(f"β No AI response generated")
|
| 75 |
+
return False
|
| 76 |
+
|
| 77 |
+
except Exception as e:
|
| 78 |
+
print(f"β Query execution failed: {e}")
|
| 79 |
+
import traceback
|
| 80 |
+
traceback.print_exc()
|
| 81 |
+
return False
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def test_rag_query(agent: RAGAgent):
|
| 85 |
+
"""Test a query that should use RAG (local documents)."""
|
| 86 |
+
print_separator("TEST 3: RAG Query")
|
| 87 |
+
|
| 88 |
+
if agent is None:
|
| 89 |
+
print("β Skipping - agent not initialized")
|
| 90 |
+
return False
|
| 91 |
+
|
| 92 |
+
try:
|
| 93 |
+
query = "Explain the architecture of SAM 3"
|
| 94 |
+
print(f"Query: '{query}' (should use local documents)")
|
| 95 |
+
|
| 96 |
+
initial_state = {
|
| 97 |
+
"messages": [HumanMessage(content=query)],
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
result = agent.agent_graph.invoke(
|
| 101 |
+
initial_state,
|
| 102 |
+
config=agent.get_config()
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
messages = result.get("messages", [])
|
| 106 |
+
rag_method = result.get("rag_method", "UNKNOWN")
|
| 107 |
+
ai_messages = [m for m in messages if isinstance(m, AIMessage)]
|
| 108 |
+
|
| 109 |
+
print(f" Routing decision: {rag_method}")
|
| 110 |
+
|
| 111 |
+
if ai_messages:
|
| 112 |
+
print(f"β RAG query executed")
|
| 113 |
+
print(f" Response preview:")
|
| 114 |
+
print(f" {ai_messages[-1].content[:300]}...")
|
| 115 |
+
return True
|
| 116 |
+
else:
|
| 117 |
+
print(f"β No response generated")
|
| 118 |
+
return False
|
| 119 |
+
|
| 120 |
+
except Exception as e:
|
| 121 |
+
print(f"β RAG query failed: {e}")
|
| 122 |
+
import traceback
|
| 123 |
+
traceback.print_exc()
|
| 124 |
+
return False
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def test_web_search_query(agent: RAGAgent):
|
| 128 |
+
"""Test a query that should use web search."""
|
| 129 |
+
print_separator("TEST 4: Web Search Query")
|
| 130 |
+
|
| 131 |
+
if agent is None:
|
| 132 |
+
print("β Skipping - agent not initialized")
|
| 133 |
+
return False
|
| 134 |
+
|
| 135 |
+
try:
|
| 136 |
+
query = "What's the latest news about AI in 2025?"
|
| 137 |
+
print(f"Query: '{query}' (should use web search)")
|
| 138 |
+
|
| 139 |
+
initial_state = {
|
| 140 |
+
"messages": [HumanMessage(content=query)],
|
| 141 |
+
}
|
| 142 |
+
|
| 143 |
+
result = agent.agent_graph.invoke(
|
| 144 |
+
initial_state,
|
| 145 |
+
config=agent.get_config()
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
messages = result.get("messages", [])
|
| 149 |
+
rag_method = result.get("rag_method", "UNKNOWN")
|
| 150 |
+
ai_messages = [m for m in messages if isinstance(m, AIMessage)]
|
| 151 |
+
|
| 152 |
+
print(f" Routing decision: {rag_method}")
|
| 153 |
+
|
| 154 |
+
if ai_messages:
|
| 155 |
+
print(f"β Web search query executed")
|
| 156 |
+
print(f" Response preview:")
|
| 157 |
+
print(f" {ai_messages[-1].content[:300]}...")
|
| 158 |
+
return True
|
| 159 |
+
else:
|
| 160 |
+
print(f"β No response generated")
|
| 161 |
+
return False
|
| 162 |
+
|
| 163 |
+
except Exception as e:
|
| 164 |
+
print(f"β Web search query failed: {e}")
|
| 165 |
+
import traceback
|
| 166 |
+
traceback.print_exc()
|
| 167 |
+
return False
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def test_general_query(agent: RAGAgent):
|
| 171 |
+
"""Test a general query that doesn't need RAG or web search."""
|
| 172 |
+
print_separator("TEST 5: General Query")
|
| 173 |
+
|
| 174 |
+
if agent is None:
|
| 175 |
+
print("β Skipping - agent not initialized")
|
| 176 |
+
return False
|
| 177 |
+
|
| 178 |
+
try:
|
| 179 |
+
query = "What is 15 multiplied by 7?"
|
| 180 |
+
print(f"Query: '{query}' (should use general LLM)")
|
| 181 |
+
|
| 182 |
+
initial_state = {
|
| 183 |
+
"messages": [HumanMessage(content=query)],
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
result = agent.agent_graph.invoke(
|
| 187 |
+
initial_state,
|
| 188 |
+
config=agent.get_config()
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
messages = result.get("messages", [])
|
| 192 |
+
rag_method = result.get("rag_method", "UNKNOWN")
|
| 193 |
+
ai_messages = [m for m in messages if isinstance(m, AIMessage)]
|
| 194 |
+
|
| 195 |
+
print(f" Routing decision: {rag_method}")
|
| 196 |
+
|
| 197 |
+
if ai_messages:
|
| 198 |
+
print(f"β General query executed")
|
| 199 |
+
print(f" Response: {ai_messages[-1].content}")
|
| 200 |
+
return True
|
| 201 |
+
else:
|
| 202 |
+
print(f"β No response generated")
|
| 203 |
+
return False
|
| 204 |
+
|
| 205 |
+
except Exception as e:
|
| 206 |
+
print(f"β General query failed: {e}")
|
| 207 |
+
import traceback
|
| 208 |
+
traceback.print_exc()
|
| 209 |
+
return False
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def test_conversation_memory(agent: RAGAgent):
|
| 213 |
+
"""Test multi-turn conversation with memory."""
|
| 214 |
+
print_separator("TEST 6: Conversation Memory")
|
| 215 |
+
|
| 216 |
+
if agent is None:
|
| 217 |
+
print("β Skipping - agent not initialized")
|
| 218 |
+
return False
|
| 219 |
+
|
| 220 |
+
try:
|
| 221 |
+
# Reset thread for clean test
|
| 222 |
+
agent.reset_thread()
|
| 223 |
+
|
| 224 |
+
# First turn
|
| 225 |
+
print("Turn 1: 'What is DeepAnalyze?'")
|
| 226 |
+
state1 = {
|
| 227 |
+
"messages": [HumanMessage(content="What is DeepAnalyze?")],
|
| 228 |
+
}
|
| 229 |
+
result1 = agent.agent_graph.invoke(state1, config=agent.get_config())
|
| 230 |
+
|
| 231 |
+
ai_msg_1 = [m for m in result1["messages"] if isinstance(m, AIMessage)]
|
| 232 |
+
if not ai_msg_1:
|
| 233 |
+
print("β No response in turn 1")
|
| 234 |
+
return False
|
| 235 |
+
|
| 236 |
+
print(f"β Turn 1 response: {ai_msg_1[-1].content[:100]}...")
|
| 237 |
+
|
| 238 |
+
# Second turn - follow-up question
|
| 239 |
+
print("\nTurn 2: 'What are its main features?' (requires context)")
|
| 240 |
+
state2 = {
|
| 241 |
+
"messages": [HumanMessage(content="What are its main features?")],
|
| 242 |
+
}
|
| 243 |
+
result2 = agent.agent_graph.invoke(state2, config=agent.get_config())
|
| 244 |
+
|
| 245 |
+
ai_msg_2 = [m for m in result2["messages"] if isinstance(m, AIMessage)]
|
| 246 |
+
if not ai_msg_2:
|
| 247 |
+
print("β No response in turn 2")
|
| 248 |
+
return False
|
| 249 |
+
|
| 250 |
+
print(f"β Turn 2 response: {ai_msg_2[-1].content[:100]}...")
|
| 251 |
+
|
| 252 |
+
# Check if response makes sense in context
|
| 253 |
+
response = ai_msg_2[-1].content.lower()
|
| 254 |
+
if "deepanalyze" in response or "feature" in response or "agent" in response:
|
| 255 |
+
print("β Conversation memory working - response uses context")
|
| 256 |
+
return True
|
| 257 |
+
else:
|
| 258 |
+
print("β Response may not be using conversation context properly")
|
| 259 |
+
return True # Still pass, as it generated a response
|
| 260 |
+
|
| 261 |
+
except Exception as e:
|
| 262 |
+
print(f"β Conversation memory test failed: {e}")
|
| 263 |
+
import traceback
|
| 264 |
+
traceback.print_exc()
|
| 265 |
+
return False
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def test_thread_reset(agent: RAGAgent):
|
| 269 |
+
"""Test thread reset functionality."""
|
| 270 |
+
print_separator("TEST 7: Thread Reset")
|
| 271 |
+
|
| 272 |
+
if agent is None:
|
| 273 |
+
print("β Skipping - agent not initialized")
|
| 274 |
+
return False
|
| 275 |
+
|
| 276 |
+
try:
|
| 277 |
+
old_thread_id = agent.thread_id
|
| 278 |
+
print(f"Old thread ID: {old_thread_id}")
|
| 279 |
+
|
| 280 |
+
agent.reset_thread()
|
| 281 |
+
|
| 282 |
+
new_thread_id = agent.thread_id
|
| 283 |
+
print(f"New thread ID: {new_thread_id}")
|
| 284 |
+
|
| 285 |
+
if old_thread_id != new_thread_id:
|
| 286 |
+
print("β Thread reset successfully")
|
| 287 |
+
return True
|
| 288 |
+
else:
|
| 289 |
+
print("β Thread ID unchanged after reset")
|
| 290 |
+
return False
|
| 291 |
+
|
| 292 |
+
except Exception as e:
|
| 293 |
+
print(f"β Thread reset failed: {e}")
|
| 294 |
+
import traceback
|
| 295 |
+
traceback.print_exc()
|
| 296 |
+
return False
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def run_all_tests():
|
| 300 |
+
"""Run all tests and provide summary."""
|
| 301 |
+
print("\n" + "β"*70)
|
| 302 |
+
print(" RAG AGENT TEST SUITE")
|
| 303 |
+
print("β"*70)
|
| 304 |
+
|
| 305 |
+
# Initialize agent once
|
| 306 |
+
agent = test_agent_initialization()
|
| 307 |
+
|
| 308 |
+
if agent is None:
|
| 309 |
+
print("\nβ Cannot proceed - agent initialization failed")
|
| 310 |
+
return False
|
| 311 |
+
|
| 312 |
+
tests = [
|
| 313 |
+
("Simple Query", lambda: test_simple_query(agent)),
|
| 314 |
+
("RAG Query", lambda: test_rag_query(agent)),
|
| 315 |
+
("Web Search Query", lambda: test_web_search_query(agent)),
|
| 316 |
+
("General Query", lambda: test_general_query(agent)),
|
| 317 |
+
("Conversation Memory", lambda: test_conversation_memory(agent)),
|
| 318 |
+
("Thread Reset", lambda: test_thread_reset(agent)),
|
| 319 |
+
]
|
| 320 |
+
|
| 321 |
+
results = {}
|
| 322 |
+
for name, test_func in tests:
|
| 323 |
+
try:
|
| 324 |
+
results[name] = test_func()
|
| 325 |
+
except Exception as e:
|
| 326 |
+
print(f"\nβ Test '{name}' crashed: {e}")
|
| 327 |
+
import traceback
|
| 328 |
+
traceback.print_exc()
|
| 329 |
+
results[name] = False
|
| 330 |
+
|
| 331 |
+
# Print summary
|
| 332 |
+
print_separator("TEST SUMMARY")
|
| 333 |
+
passed = sum(results.values())
|
| 334 |
+
total = len(results)
|
| 335 |
+
|
| 336 |
+
for name, passed_test in results.items():
|
| 337 |
+
status = "β PASS" if passed_test else "β FAIL"
|
| 338 |
+
print(f"{status}: {name}")
|
| 339 |
+
|
| 340 |
+
print(f"\n{'='*70}")
|
| 341 |
+
print(f" TOTAL: {passed}/{total} tests passed ({passed/total*100:.1f}%)")
|
| 342 |
+
print(f"{'='*70}\n")
|
| 343 |
+
|
| 344 |
+
return passed == total
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
if __name__ == "__main__":
|
| 348 |
+
success = run_all_tests()
|
| 349 |
+
sys.exit(0 if success else 1)
|
ui/gradio_components.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
| 2 |
from core.rag_agent import RAGAgent
|
| 3 |
|
| 4 |
# Initialize components
|
|
@@ -19,81 +20,143 @@ def chat_with_agent(message, history):
|
|
| 19 |
try:
|
| 20 |
agent = initialize_agent()
|
| 21 |
|
| 22 |
-
#
|
| 23 |
-
|
| 24 |
-
for
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
):
|
| 29 |
-
if "messages" in event and len(event["messages"]) > 0:
|
| 30 |
-
last_message = event["messages"][-1]
|
| 31 |
-
if hasattr(last_message, "content"):
|
| 32 |
-
response_text = last_message.content
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
except Exception as e:
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def reset_conversation():
|
| 43 |
"""Reset the conversation thread"""
|
| 44 |
global rag_agent
|
| 45 |
if rag_agent:
|
| 46 |
rag_agent.reset_thread()
|
| 47 |
-
return
|
| 48 |
|
| 49 |
def create_gradio_ui():
|
| 50 |
"""Create the complete Gradio interface"""
|
| 51 |
|
| 52 |
-
with gr.Blocks(title="RAG Agent with Agentic Memory"
|
| 53 |
gr.Markdown("""
|
| 54 |
# π€ RAG Agent with Agentic Memory
|
| 55 |
|
| 56 |
Chat with an intelligent agent that uses:
|
| 57 |
-
- π **Local Knowledge Base** (ChromaDB)
|
| 58 |
-
- π **Web Search** (Tavily)
|
| 59 |
-
- π **Wikipedia**
|
| 60 |
-
- π **ArXiv**
|
| 61 |
""")
|
| 62 |
|
| 63 |
-
gr.Markdown("### Chat with Your Documents")
|
| 64 |
-
gr.Markdown("Ask questions about your documents or any topic. The agent will search multiple sources.")
|
| 65 |
-
|
| 66 |
-
chatbot = gr.Chatbot(
|
| 67 |
-
label="Conversation",
|
| 68 |
-
height=500,
|
| 69 |
-
show_label=True,
|
| 70 |
-
avatar_images=(None, "π€")
|
| 71 |
-
)
|
| 72 |
-
|
| 73 |
with gr.Row():
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
#
|
| 86 |
-
|
| 87 |
-
fn=
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
submit_btn=submit_btn,
|
| 91 |
-
retry_btn=None,
|
| 92 |
-
undo_btn=None,
|
| 93 |
-
clear_btn=None
|
| 94 |
)
|
| 95 |
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
fn=reset_conversation,
|
| 98 |
outputs=[chatbot]
|
| 99 |
)
|
|
@@ -101,16 +164,27 @@ def create_gradio_ui():
|
|
| 101 |
gr.Markdown("""
|
| 102 |
---
|
| 103 |
### π§ How it works:
|
| 104 |
-
1. **
|
| 105 |
2. The agent will:
|
| 106 |
-
- Analyze your query
|
| 107 |
-
- Search relevant sources (
|
| 108 |
-
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
""")
|
| 111 |
|
| 112 |
return demo
|
| 113 |
|
| 114 |
if __name__ == "__main__":
|
| 115 |
demo = create_gradio_ui()
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from langchain_core.messages import HumanMessage, AIMessage
|
| 3 |
from core.rag_agent import RAGAgent
|
| 4 |
|
| 5 |
# Initialize components
|
|
|
|
| 20 |
try:
|
| 21 |
agent = initialize_agent()
|
| 22 |
|
| 23 |
+
# Convert Gradio history format to LangChain messages
|
| 24 |
+
messages = []
|
| 25 |
+
for user_msg, assistant_msg in history:
|
| 26 |
+
messages.append(HumanMessage(content=user_msg))
|
| 27 |
+
if assistant_msg:
|
| 28 |
+
messages.append(AIMessage(content=assistant_msg))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
+
# Add current user message
|
| 31 |
+
messages.append(HumanMessage(content=message))
|
| 32 |
|
| 33 |
+
# Create initial state
|
| 34 |
+
initial_state = {
|
| 35 |
+
"messages": messages,
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
# Invoke the agent graph
|
| 39 |
+
result = agent.agent_graph.invoke(
|
| 40 |
+
initial_state,
|
| 41 |
+
config=agent.get_config()
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
# Extract AI response
|
| 45 |
+
result_messages = result.get("messages", [])
|
| 46 |
+
ai_messages = [m for m in result_messages if isinstance(m, AIMessage)]
|
| 47 |
+
|
| 48 |
+
if ai_messages:
|
| 49 |
+
# Get the last AI message
|
| 50 |
+
response = ai_messages[-1].content
|
| 51 |
+
|
| 52 |
+
# Add routing info as metadata (optional)
|
| 53 |
+
rag_method = result.get("rag_method", "UNKNOWN")
|
| 54 |
+
response_with_metadata = f"{response}\n\n*[Source: {rag_method}]*"
|
| 55 |
+
|
| 56 |
+
# Return history in Gradio's format [[user, bot], [user, bot], ...]
|
| 57 |
+
new_history = history + [[message, response_with_metadata]]
|
| 58 |
+
return new_history
|
| 59 |
+
else:
|
| 60 |
+
new_history = history + [[message, "β οΈ No response generated. Please try again."]]
|
| 61 |
+
return new_history
|
| 62 |
+
|
| 63 |
except Exception as e:
|
| 64 |
+
error_msg = f"β Error: {str(e)}"
|
| 65 |
+
print(f"Chat error: {e}")
|
| 66 |
+
import traceback
|
| 67 |
+
traceback.print_exc()
|
| 68 |
+
|
| 69 |
+
new_history = history + [[message, error_msg]]
|
| 70 |
+
return new_history
|
| 71 |
|
| 72 |
def reset_conversation():
|
| 73 |
"""Reset the conversation thread"""
|
| 74 |
global rag_agent
|
| 75 |
if rag_agent:
|
| 76 |
rag_agent.reset_thread()
|
| 77 |
+
return [] # Clear chat history
|
| 78 |
|
| 79 |
def create_gradio_ui():
|
| 80 |
"""Create the complete Gradio interface"""
|
| 81 |
|
| 82 |
+
with gr.Blocks(title="RAG Agent with Agentic Memory") as demo:
|
| 83 |
gr.Markdown("""
|
| 84 |
# π€ RAG Agent with Agentic Memory
|
| 85 |
|
| 86 |
Chat with an intelligent agent that uses:
|
| 87 |
+
- π **Local Knowledge Base** (ChromaDB) - Research papers on DeepAnalyze, AgentMem, SAM3, etc.
|
| 88 |
+
- π **Web Search** (Tavily) - Real-time information and current events
|
| 89 |
+
- π **Wikipedia** - General knowledge
|
| 90 |
+
- π **ArXiv** - Academic papers
|
| 91 |
""")
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
with gr.Row():
|
| 94 |
+
with gr.Column(scale=4):
|
| 95 |
+
gr.Markdown("### π¬ Chat Interface")
|
| 96 |
+
|
| 97 |
+
chatbot = gr.Chatbot(
|
| 98 |
+
label="Conversation",
|
| 99 |
+
height=500,
|
| 100 |
+
show_label=False,
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
with gr.Row():
|
| 104 |
+
msg = gr.Textbox(
|
| 105 |
+
label="Your Message",
|
| 106 |
+
placeholder="Ask me anything about your documents or general knowledge...",
|
| 107 |
+
scale=5,
|
| 108 |
+
show_label=False
|
| 109 |
+
)
|
| 110 |
+
submit_btn = gr.Button("Send π€", variant="primary", scale=1)
|
| 111 |
+
|
| 112 |
+
with gr.Row():
|
| 113 |
+
clear_btn = gr.Button("π Reset Conversation", variant="secondary")
|
| 114 |
+
|
| 115 |
+
with gr.Column(scale=1):
|
| 116 |
+
gr.Markdown("### π Agent Status")
|
| 117 |
+
status_box = gr.Markdown("*Ready*")
|
| 118 |
+
|
| 119 |
+
gr.Markdown("### π‘ Example Queries")
|
| 120 |
+
gr.Markdown("""
|
| 121 |
+
**Local Documents (RAG):**
|
| 122 |
+
- What is DeepAnalyze?
|
| 123 |
+
- Explain SAM 3 architecture
|
| 124 |
+
- What is AgentMem?
|
| 125 |
+
|
| 126 |
+
**Web Search:**
|
| 127 |
+
- Latest AI news in 2025
|
| 128 |
+
- Current events in technology
|
| 129 |
+
|
| 130 |
+
**General:**
|
| 131 |
+
- What is 15 Γ 7?
|
| 132 |
+
- Explain machine learning
|
| 133 |
+
""")
|
| 134 |
|
| 135 |
+
# Event handlers
|
| 136 |
+
def submit_message(message, history):
|
| 137 |
+
"""Handle message submission with status update"""
|
| 138 |
+
if not message.strip():
|
| 139 |
+
return history, ""
|
| 140 |
+
|
| 141 |
+
# Get response
|
| 142 |
+
new_history = chat_with_agent(message, history)
|
| 143 |
+
|
| 144 |
+
return new_history, ""
|
| 145 |
|
| 146 |
+
# Wire up events
|
| 147 |
+
msg.submit(
|
| 148 |
+
fn=submit_message,
|
| 149 |
+
inputs=[msg, chatbot],
|
| 150 |
+
outputs=[chatbot, msg]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
)
|
| 152 |
|
| 153 |
+
submit_btn.click(
|
| 154 |
+
fn=submit_message,
|
| 155 |
+
inputs=[msg, chatbot],
|
| 156 |
+
outputs=[chatbot, msg]
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
clear_btn.click(
|
| 160 |
fn=reset_conversation,
|
| 161 |
outputs=[chatbot]
|
| 162 |
)
|
|
|
|
| 164 |
gr.Markdown("""
|
| 165 |
---
|
| 166 |
### π§ How it works:
|
| 167 |
+
1. **Type your question** in the text box
|
| 168 |
2. The agent will:
|
| 169 |
+
- π§ Analyze your query to determine the best source
|
| 170 |
+
- π Search relevant sources (Local docs, Web, Wikipedia)
|
| 171 |
+
- π Generate a comprehensive answer
|
| 172 |
+
- πΎ Remember conversation context for follow-up questions
|
| 173 |
+
3. Use **Reset Conversation** to start a new thread
|
| 174 |
+
|
| 175 |
+
---
|
| 176 |
+
*Powered by LangGraph + LangChain + ChromaDB + Anthropic Claude*
|
| 177 |
""")
|
| 178 |
|
| 179 |
return demo
|
| 180 |
|
| 181 |
if __name__ == "__main__":
|
| 182 |
demo = create_gradio_ui()
|
| 183 |
+
print("π Starting Gradio interface...")
|
| 184 |
+
print("π Running on: http://127.0.0.1:7860")
|
| 185 |
+
demo.launch(
|
| 186 |
+
share=False,
|
| 187 |
+
server_name="127.0.0.1",
|
| 188 |
+
server_port=7860,
|
| 189 |
+
show_error=True
|
| 190 |
+
)
|