kith777 committed on
Commit
0fc97a4
Β·
1 Parent(s): 30ee88a

added gradio

Browse files
agent/graph.py CHANGED
@@ -6,31 +6,46 @@ from functools import partial
6
  from .state import AgentState
7
  from .nodes import *
8
 
9
- def create_agent_graph(llm, tools) -> StateGraph:
 
10
  """Create the RAG agent graph."""
11
- llm_with_tools = llm.with_tools(tools)
12
 
13
  graph = StateGraph(AgentState)
14
  checkpointer = MemorySaver()
15
 
16
- tool_node = ToolNode(tools)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
- # Nodes
19
- graph.add_node("summarize", partial(analyze_chat_and_summarize, llm=llm)) # summarize last 6 messages
20
- graph.add_node("analyze_rewrite", partial(analyze_and_rewrite_query, llm=llm)) # analyze and rewrite query
21
- graph.add_node("agent", partial(agent_node, llm_with_tools=llm_with_tools)) # generate answer based on retrieved info
22
- graph.add_node("tools", tool_node)
23
 
24
- graph.add_edge(START, "summarize")
25
- graph.add_edge("summarize", "analyze_rewrite")
26
- graph.add_conditional_edges("analyze_rewrite", route_after_rewrite)
27
- graph.add_edge("human_input", "analyze_rewrite")
28
- graph.add_conditional_edges("agent", tools_condition)
29
- graph.add_edge("tools", "agent")
30
 
31
  agent_graph = graph.compile(
32
  checkpointer=checkpointer,
33
- interrupt_before=["human_input"]
34
  )
35
 
36
- return agent_graph
 
 
 
 
6
  from .state import AgentState
7
  from .nodes import *
8
 
9
+
10
def create_agent_graph(llm, vectordb, search_tools):
    """Create and compile the RAG agent graph.

    Args:
        llm: Chat model used by the router and generation nodes.
        vectordb: Vector store queried by the retrieval node.
        search_tools: Web-search tools wrapped into a single ToolNode.

    Returns:
        The compiled agent graph (checkpointed in memory). The old
        ``-> StateGraph`` annotation was dropped: ``compile()`` returns a
        compiled graph, not a StateGraph.
    """
    graph = StateGraph(AgentState)
    checkpointer = MemorySaver()

    web_search_tool_node = ToolNode(search_tools)

    # --- Nodes ---
    graph.add_node("router_node", partial(router_node, llm=llm))
    # vectordb_node's signature is (state, llm, vectorstore); both keyword
    # arguments must be bound here, otherwise LangGraph invokes the node with
    # only `state` and raises TypeError for the missing `llm`.
    graph.add_node("vectordb_node", partial(vectordb_node, llm=llm, vectorstore=vectordb))
    # NOTE(review): ToolNode executes tool_calls attached to the last message.
    # The router only classifies text and never emits tool calls — confirm the
    # web-search path actually receives tool-call messages at runtime.
    graph.add_node("web_search_node", web_search_tool_node)
    graph.add_node("generate_node", partial(generate_node, llm=llm))

    # --- Edges ---
    graph.add_edge(START, "router_node")

    graph.add_conditional_edges(
        "router_node",
        routing_logic,
        {
            # Output of routing_logic -> target node name.
            "vectordb_node": "vectordb_node",
            "web_search_node": "web_search_node",
            "generate_node": "generate_node",
            # routing_logic returns END for unclassified intents; a value
            # missing from this map is an error, so END must be listed too.
            END: END,
        },
    )

    # Both retrieval paths feed into the final answer generator.
    graph.add_edge("vectordb_node", "generate_node")
    graph.add_edge("web_search_node", "generate_node")

    graph.add_edge("generate_node", END)

    agent_graph = graph.compile(
        checkpointer=checkpointer,
    )

    return agent_graph


if __name__ == "__main__":
    pass
agent/more_nodes.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
2
+ from typing import Literal
3
+
4
+ from .state import AgentState, QueryAnalysis
5
+ from .prompts import *
6
+
7
def analyze_chat_and_summarize(state: AgentState, llm):
    """
    Analyzes chat history and summarizes key points for context.

    Reads state["messages"], asks `llm` for a 1-2 sentence summary of the
    last few human/assistant turns, and returns a partial state update
    {"conversation_summary": <str>} (empty string when there is not enough
    usable history).
    """
    if len(state["messages"]) < 4: # Need some history to summarize
        return {"conversation_summary": ""}

    # Extract relevant messages (excluding current query and system messages);
    # tool-calling AI messages are skipped because they carry no conversational text.
    relevant_msgs = [
        msg for msg in state["messages"][:-1] # Exclude current query
        if isinstance(msg, (HumanMessage, AIMessage))
        and not getattr(msg, "tool_calls", None)
    ]

    if not relevant_msgs:
        return {"conversation_summary": ""}

    summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
    Discard irrelevant information, such as misunderstandings or off-topic queries/responses.
    If there are no key topics, return an empty string.

    """
    for msg in relevant_msgs[-6:]: # Last 6 messages for context
        role = "User" if isinstance(msg, HumanMessage) else "Assistant"
        summary_prompt += f"{role}: {msg.content}\n"

    summary_prompt += "\nBrief Summary:"
    # NOTE(review): with_config(temperature=0.3) puts `temperature` into the
    # runnable config, not the model call — verify it actually lowers the
    # sampling temperature for this LLM (llm.bind(...) may be intended).
    summary_response = llm.with_config(temperature=0.3).invoke([SystemMessage(content=summary_prompt)])
    return {"conversation_summary": summary_response.content}
36
+
37
def analyze_and_rewrite_query(state: AgentState, llm):
    """
    Analyzes the user's query and rewrites it for clarity.

    Uses structured output (QueryAnalysis). When the query is clear, every
    non-system message is removed and replaced by a single HumanMessage
    holding the rewritten question(s); otherwise an AIMessage asking for
    clarification is appended.

    Returns:
        Partial state update with "questionIsClear" and "messages".
    """
    last_message = state["messages"][-1]
    conversation_summary = state.get("conversation_summary", "")

    # Create analysis prompt (the summary may be empty on the first turn).
    # The old unused `context_section` local was removed as dead code.
    query_analysis_prompt = get_query_analysis_prompt(last_message.content, conversation_summary)

    llm_with_structure = llm.with_config(temperature=0.3).with_structured_output(QueryAnalysis)
    response = llm_with_structure.invoke([SystemMessage(content=query_analysis_prompt)])

    if response.is_clear:
        # Remove all non-system messages so only the rewritten query remains.
        delete_all = [
            RemoveMessage(id=m.id)
            for m in state["messages"]
            if not isinstance(m, SystemMessage)
        ]

        # Format rewritten query; guard against an empty questions list so we
        # fall back to the original text instead of raising IndexError.
        if not response.questions:
            rewritten = last_message.content
        elif len(response.questions) > 1:
            rewritten = "\n".join([f"{i+1}. {q}" for i, q in enumerate(response.questions)])
        else:
            rewritten = response.questions[0]
        return {
            "questionIsClear": True,
            "messages": delete_all + [HumanMessage(content=rewritten)]
        }
    else:
        clarification = response.clarification_needed or "I need more information to understand your question."
        return {
            "questionIsClear": False,
            "messages": [AIMessage(content=clarification)]
        }
80
+
81
def human_input_node(state: AgentState):
    """Placeholder node for human-in-the-loop interruption.

    Performs no work itself: the graph is interrupted before this node so
    the user can supply input. Returns an empty state update.
    """
    return {}
84
+
85
def route_after_rewrite(state: AgentState) -> Literal["agent", "human_input"]:
    """Route to the agent when the question is clear; otherwise pause for human input."""
    if state.get("questionIsClear", False):
        return "agent"
    return "human_input"
88
+
89
def agent_node(state: AgentState, llm_with_tools):
    """Main agent node: answer the current query with the tool-bound LLM."""
    # Prepend the system prompt to the running conversation before invoking.
    conversation = [get_system_prompt()] + state["messages"]
    reply = llm_with_tools.invoke(conversation)
    return {"messages": [reply]}

if __name__ == "__main__":
    pass
agent/nodes.py CHANGED
@@ -1,97 +1,61 @@
1
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
2
  from typing import Literal
 
3
 
4
  from .state import AgentState, QueryAnalysis
5
  from .prompts import *
 
6
 
7
- def analyze_chat_and_summarize(state: AgentState, llm):
8
- """
9
- Analyzes chat history and summarizes key points for context.
10
- """
11
- if len(state["messages"]) < 4: # Need some history to summarize
12
- return {"conversation_summary": ""}
13
-
14
- # Extract relevant messages (excluding current query and system messages)
15
- relevant_msgs = [
16
- msg for msg in state["messages"][:-1] # Exclude current query
17
- if isinstance(msg, (HumanMessage, AIMessage))
18
- and not getattr(msg, "tool_calls", None)
19
- ]
20
-
21
- if not relevant_msgs:
22
- return {"conversation_summary": ""}
23
-
24
- summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
25
- Discard irrelevant information, such as misunderstandings or off-topic queries/responses.
26
- If there are no key topics, return an empty string.
27
 
 
28
  """
29
- for msg in relevant_msgs[-6:]: # Last 6 messages for context
30
- role = "User" if isinstance(msg, HumanMessage) else "Assistant"
31
- summary_prompt += f"{role}: {msg.content}\n"
32
-
33
- summary_prompt += "\nBrief Summary:"
34
- summary_response = llm.with_config(temperature=0.3).invoke([SystemMessage(content=summary_prompt)])
35
- return {"conversation_summary": summary_response.content}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- def analyze_and_rewrite_query(state: AgentState, llm):
38
  """
39
- Analyzes user query and rewrites it for clarity, optionally using conversation context.
40
  """
41
- last_message = state["messages"][-1]
42
- conversation_summary = state.get("conversation_summary", "")
43
-
44
- context_section = (
45
- f"**Conversation Context:**\n{conversation_summary}"
46
- if conversation_summary.strip()
47
- else "**Conversation Context:**\n[First query in conversation]"
48
  )
49
-
50
- # Create analysis prompt
51
- query_analysis_prompt = get_query_analysis_prompt(last_message.content, conversation_summary)
52
-
53
- llm_with_structure = llm.with_config(temperature=0.3).with_structured_output(QueryAnalysis)
54
- response = llm_with_structure.invoke([SystemMessage(content=query_analysis_prompt)])
55
-
56
- if response.is_clear:
57
- # Remove all non-system messages
58
- delete_all = [
59
- RemoveMessage(id=m.id)
60
- for m in state["messages"]
61
- if not isinstance(m, SystemMessage)
62
- ]
63
-
64
- # Format rewritten query
65
- rewritten = (
66
- "\n".join([f"{i+1}. {q}" for i, q in enumerate(response.questions)])
67
- if len(response.questions) > 1
68
- else response.questions[0]
69
- )
70
- return {
71
- "questionIsClear": True,
72
- "messages": delete_all + [HumanMessage(content=rewritten)]
73
- }
74
- else:
75
- clarification = response.clarification_needed or "I need more information to understand your question."
76
- return {
77
- "questionIsClear": False,
78
- "messages": [AIMessage(content=clarification)]
79
- }
80
-
81
- def human_input_node(state: AgentState):
82
- """Placeholder node for human-in-the-loop interruption"""
83
- return {}
84
-
85
- def route_after_rewrite(state: AgentState) -> Literal["agent", "human_input"]:
86
- """Route to agent if question is clear, otherwise wait for human input"""
87
- return "agent" if state.get("questionIsClear", False) else "human_input"
88
-
89
- def agent_node(state: AgentState, llm_with_tools):
90
- """Main agent node that processes queries using tools"""
91
- system_prompt = get_system_prompt()
92
- messages = [system_prompt] + state["messages"]
93
- response = llm_with_tools.invoke(messages)
94
- return {"messages": [response]}
95
-
96
  if __name__ == "__main__":
97
  pass
 
1
  from langchain_core.messages import HumanMessage, AIMessage, SystemMessage, RemoveMessage
2
  from typing import Literal
3
+ from langgraph.graph import START, END
4
 
5
  from .state import AgentState, QueryAnalysis
6
  from .prompts import *
7
+ from .tools import intialize_chroma_vectorstore
8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
def router_node(state: AgentState, llm):
    """
    Classify the latest user query and record which RAG method to use.

    Asks the LLM to label the query as RAG, WEBSEARCH, or GENERAL and
    returns the partial state update {"rag_method": <label>}.
    """
    query = state["messages"][-1].content
    rag_method_prompt = determine_rag_method_prompt()
    rag_method_result = llm.invoke([rag_method_prompt, HumanMessage(content=query)])
    # Normalize: the prompt demands a bare keyword, but strip/upper defends
    # against stray whitespace or casing in the model output.
    rag_method = rag_method_result.content.strip().upper()
    # Return only the changed key instead of mutating and returning the whole
    # state, which would re-feed every message through the add_messages reducer.
    return {"rag_method": rag_method}
20
+
21
def routing_logic(state: AgentState) -> str:
    """Map the router's classification to the next node name.

    The stray ``self`` parameter was removed: LangGraph calls a
    conditional-edge function with the state as its only argument, so the
    old (self, state) signature shifted the state into ``self`` and crashed
    on the missing second argument.

    Returns END when the LLM produced an unrecognized label.
    """
    rag_method = state["rag_method"]
    if rag_method == "RAG":
        return "vectordb_node"
    elif rag_method == "WEBSEARCH":
        return "web_search_node"
    elif rag_method == "GENERAL":
        # Fallback: the question needs neither retrieval nor web search.
        return "generate_node"
    else:
        # The LLM violated the prompt and output an unknown word.
        print(f"ERROR: Router returned unclassified intent: {rag_method}. Terminating flow.")
        return END
33
 
34
def vectordb_node(state: AgentState, llm=None, vectorstore=None):
    """
    Retrieve supporting context for the query from the vector store.

    Args:
        state: Graph state; the last message's content is the search query.
        llm: Unused; kept (now optional) for signature compatibility — the
            graph wiring binds only ``vectorstore`` via functools.partial,
            so a required ``llm`` would raise TypeError at invocation.
        vectorstore: Store exposing ``similarity_search`` (e.g. Chroma).

    Returns:
        Partial state update {"context": <page contents joined by blank lines>}.
    """
    context_docs = vectorstore.similarity_search(
        query=state["messages"][-1].content,
        k=5
    )
    context = "\n\n".join(doc.page_content for doc in context_docs)
    # Return only the delta; returning the whole mutated state would re-merge
    # every message through the add_messages reducer.
    return {"context": context}
45
+
46
def generate_node(state: AgentState, llm):
    """
    Generate the final answer from recent messages plus retrieved context.

    Returns:
        Partial state update {"messages": [<AI reply>]}.
    """
    messages = state["messages"][-10:]  # Limit to last 10 messages to handle token limit
    # `context` is a string written by vectordb_node; default to "" (the old
    # [] default had the wrong type and would be interpolated as a literal list).
    context = state.get("context", "")

    # NOTE(review): assumes get_system_prompt() returns a plain string here
    # (it is concatenated below) — confirm against prompts.py.
    system_content = get_system_prompt()

    if context:
        system_content += f"\n\nRelevant Context:\n{context}"

    messages_with_system = [SystemMessage(content=system_content)] + messages
    response = llm.invoke(messages_with_system)

    return {'messages': [response]}


if __name__ == "__main__":
    pass
agent/prompts.py CHANGED
@@ -16,6 +16,19 @@ You are an intelligent assistant that MUST use the available tools to answer que
16
  7. **Return the final answer** derived from the most relevant results.
17
  """)
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  def get_conversation_summary_prompt(messages):
20
  """Generate a prompt for conversation summarization."""
21
  summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
 
16
  7. **Return the final answer** derived from the most relevant results.
17
  """)
18
 
19
def determine_rag_method_prompt() -> SystemMessage:
    """Build the system message instructing the LLM to classify the query.

    Returns:
        A SystemMessage (the old ``-> str`` annotation was wrong: the body
        returns a SystemMessage, never a plain string).
    """
    return SystemMessage(content="""
    You are an rag method classification model. Given the user's query, you must classify the method to use
    as one and only one of the following options:

    1. **RAG**: The query likely relates to the internal, domain-specific documents you have access to.
    2. **WEBSEARCH**: The query requires real-time facts, general knowledge, or external information not in your documents.
    3. **GENERAL**: The query can be answered based on your existing knowledge without external resources.

    Respond STRICTLY with only one of these words: RAG, WEBSEARCH, or GENERAL. Do not include any punctuation, explanation, or extra text.
    """
    )
+
32
  def get_conversation_summary_prompt(messages):
33
  """Generate a prompt for conversation summarization."""
34
  summary_prompt = """**Summarize the key topics and context from this conversation concisely (1-2 sentences max).**
agent/state.py CHANGED
@@ -2,13 +2,23 @@ from typing import TypedDict, Annotated, Sequence, Optional, List
2
  from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
3
  from langgraph.graph.message import add_messages
4
  from pydantic import BaseModel, Field
 
5
 
 
 
 
 
6
 
7
  class AgentState(TypedDict):
8
  messages: Annotated[Sequence[AnyMessage], add_messages]
 
 
 
9
  questionIsClear: bool
10
  conversation_summary: str = ""
11
-
 
 
12
  class QueryAnalysis(BaseModel):
13
  """Structured output for query analysis"""
14
  is_clear: bool = Field(description="Indicates if the user's question is clear and answerable")
 
2
  from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage, ToolMessage
3
  from langgraph.graph.message import add_messages
4
  from pydantic import BaseModel, Field
5
+ from enum import Enum
6
 
7
class RAG_method(str, Enum):
    """Closed set of retrieval strategies the router may choose.

    Inherits from str so members compare equal to the plain-string labels
    the router stores from the LLM's raw output.
    """
    RAG = "RAG"              # answer from the internal vector store
    WEBSEARCH = "WEBSEARCH"  # answer via external web-search tools
    GENERAL = "GENERAL"      # answer directly from the LLM's own knowledge
11
 
12
class AgentState(TypedDict):
    """Shared LangGraph state for the RAG agent.

    Note: the old ``conversation_summary: str = ""`` assignment was dropped —
    PEP 589 forbids default values in a TypedDict body (they are ignored at
    runtime and rejected by type checkers).
    """
    # Conversation so far; merged via the add_messages reducer.
    messages: Annotated[Sequence[AnyMessage], add_messages]
    # Router's chosen strategy (stored as the raw "RAG"/"WEBSEARCH"/"GENERAL" string).
    rag_method: RAG_method
    # Retrieved document text set by vectordb_node (blank-line joined), if any.
    context: Optional[str]

    questionIsClear: bool
    conversation_summary: str
22
  class QueryAnalysis(BaseModel):
23
  """Structured output for query analysis"""
24
  is_clear: bool = Field(description="Indicates if the user's question is clear and answerable")
agent/tools.py CHANGED
@@ -23,30 +23,28 @@ def intialize_chroma_vectorstore():
23
  )
24
  return vectorstore
25
 
26
-
27
  @tool
28
- def search_chroma(vectorstore: Chroma, query: str, k: int = 5) -> List[dict]:
29
- """Search for the top K most relevant chunks from Chroma vector store.
30
-
31
  Args:
32
- query: Search query string
33
- k: Number of results to return
 
 
34
  """
35
  try:
36
- results = vectorstore.similarity_search(query, k=k, score_threshold=0.7)
37
-
38
- return [
39
  {
40
- "content": doc.page_content,
41
- "parent_id": doc.metadata.get("parent_id", ""),
42
- "source": doc.metadata.get("source", "")
43
  }
44
- for doc in results
45
  ]
46
-
47
  except Exception as e:
48
- print(f"Error searching chunks: {e}")
49
- return []
50
 
51
  @tool
52
  def wikipedia_search(query: str) -> dict:
 
23
  )
24
  return vectorstore
25
 
 
26
@tool
def web_search_tavily(query: str) -> dict:
    """Search Tavily for a query and return up to 3 results.

    Args:
        query: The search query.
    Returns:
        dict with key 'web_results', containing a list of search results with
        'title', 'url', and 'content' (the old docstring wrongly promised
        'source' and 'page' keys). On failure 'web_results' holds an error
        string instead of a list — callers must tolerate both shapes.
    """
    try:
        search_docs = TavilySearchResults(max_results=3).invoke(input=query)
        results = [
            {
                "title": doc.get("title", ""),
                "url": doc.get("url", ""),
                "content": doc.get("content", ""),
            }
            for doc in search_docs
        ]
        return {"web_results": results}
    except Exception as e:
        # Best-effort tool: surface the failure to the agent instead of raising.
        return {"web_results": f"Error retrieving results: {str(e)}"}
 
48
 
49
  @tool
50
  def wikipedia_search(query: str) -> dict:
config.py CHANGED
@@ -2,7 +2,11 @@ import os
2
 
3
  configs = {
4
  "DATA_PATH": "./docs/markdowns",
5
- "PERSIST_PATH": "./knowledge_base/chroma_data",
6
  "EMBEDDING_MODEL_NAME": "sentence-transformers/all-mpnet-base-v2",
7
  "COLLECTION_NAME": "langchain_mpnet_collection"
8
  }
 
 
 
 
 
2
 
3
# Central configuration shared across the project.
configs = dict(
    DATA_PATH="./docs/markdowns",
    PERSIST_PATH="./chroma_data",
    EMBEDDING_MODEL_NAME="sentence-transformers/all-mpnet-base-v2",
    COLLECTION_NAME="langchain_mpnet_collection",
)

if __name__ == "__main__":
    # When run as a script, export every setting as an environment variable.
    os.environ.update(configs)
core/chat_interface.py CHANGED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from core.rag_agent import RAGAgent
3
+ from core.document_manager import DocumentManager
4
+ import os
5
+
6
# Initialize components: module-level singletons shared by all Gradio callbacks.
doc_manager = DocumentManager()
rag_agent = None  # RAGAgent, created lazily by initialize_agent() on first chat
9
+
10
def initialize_agent():
    """Lazily construct the shared RAGAgent singleton and return it."""
    global rag_agent
    if rag_agent is not None:
        return rag_agent
    rag_agent = RAGAgent()
    return rag_agent
+
17
def upload_files(files):
    """Add each uploaded file to the knowledge base.

    Returns a (status_text, file_list_text) pair for the Gradio outputs.
    """
    if not files:
        return "No files selected", get_file_list()

    statuses = []
    for uploaded in files:
        try:
            statuses.append(doc_manager.add_document(uploaded.name))
        except Exception as e:
            # Report per-file failures instead of aborting the whole batch.
            statuses.append(f"Error processing {os.path.basename(uploaded.name)}: {str(e)}")

    return "\n".join(statuses), get_file_list()
31
+
32
def get_file_list():
    """Return a printable listing of the documents in the knowledge base."""
    try:
        docs = doc_manager.list_documents()
        if not docs:
            return "No documents in knowledge base"
        # NOTE(review): "β€’" looks like a mojibake'd bullet ("•") — confirm encoding.
        return "\n".join(f"β€’ {name}" for name in docs)
    except Exception as e:
        return f"Error listing files: {str(e)}"
41
+
42
def clear_database():
    """Remove every document from the knowledge base.

    Returns a (status_text, file_list_text) pair for the Gradio outputs.
    """
    try:
        status = doc_manager.clear_all()
    except Exception as e:
        status = f"Error clearing database: {str(e)}"
    return status, get_file_list()
49
+
50
def chat_with_agent(message, history):
    """Run one chat turn through the RAG agent and return its reply text.

    Blank messages are ignored (the existing history is returned unchanged).
    Errors are rendered as a chat message rather than raised into Gradio.
    """
    if not message.strip():
        return history

    try:
        agent = initialize_agent()

        # Stream state snapshots and keep the content of the latest message;
        # the final snapshot holds the agent's answer.
        reply = ""
        stream = agent.agent_graph.stream(
            {"messages": [("user", message)]},
            agent.get_config(),
            stream_mode="values",
        )
        for event in stream:
            if "messages" in event and len(event["messages"]) > 0:
                candidate = event["messages"][-1]
                if hasattr(candidate, "content"):
                    reply = candidate.content

        return reply or "I apologize, but I couldn't generate a response. Please try again."

    except Exception as e:
        return f"Error: {str(e)}"
77
+
78
def reset_conversation():
    """Start a new conversation thread and clear the chat display."""
    global rag_agent
    if rag_agent:
        rag_agent.reset_thread()
    # Returning None clears the Chatbot component.
    return None
84
+
85
def create_gradio_ui():
    """Create the complete Gradio interface.

    Builds a two-tab app: a Documents tab for uploading/listing/clearing the
    knowledge base, and a Chat tab wired to chat_with_agent. Returns the
    gr.Blocks instance (launched by the caller).
    """

    with gr.Blocks(title="RAG Agent with Agentic Memory", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # πŸ€– RAG Agent with Agentic Memory

        Upload documents and chat with an intelligent agent that uses:
        - πŸ“š **Local Knowledge Base** (ChromaDB)
        - πŸ” **Web Search** (Tavily)
        - πŸ“– **Wikipedia**
        - πŸŽ“ **ArXiv** (Academic Papers)
        """)

        with gr.Tabs():
            # Documents Tab
            with gr.Tab("πŸ“„ Documents"):
                gr.Markdown("### Upload and Manage Documents")
                gr.Markdown("Upload PDF or Markdown files to add them to the knowledge base.")

                with gr.Row():
                    with gr.Column(scale=2):
                        file_upload = gr.File(
                            label="Upload Documents",
                            file_count="multiple",
                            file_types=[".pdf", ".md"]
                        )
                        upload_btn = gr.Button("πŸ“€ Add to Knowledge Base", variant="primary")
                        upload_status = gr.Textbox(label="Upload Status", lines=3)

                    with gr.Column(scale=1):
                        # value= is evaluated once at build time; the refresh
                        # button re-queries the document manager afterwards.
                        file_list = gr.Textbox(
                            label="Documents in Knowledge Base",
                            lines=10,
                            value=get_file_list()
                        )
                        refresh_btn = gr.Button("πŸ”„ Refresh List")
                        clear_btn = gr.Button("πŸ—‘οΈ Clear All Documents", variant="stop")

                # Connect document management buttons
                upload_btn.click(
                    fn=upload_files,
                    inputs=[file_upload],
                    outputs=[upload_status, file_list]
                )

                refresh_btn.click(
                    fn=get_file_list,
                    outputs=[file_list]
                )

                clear_btn.click(
                    fn=clear_database,
                    outputs=[upload_status, file_list]
                )

            # Chat Tab
            with gr.Tab("πŸ’¬ Chat"):
                gr.Markdown("### Chat with Your Documents")
                gr.Markdown("Ask questions about your documents or any topic. The agent will search multiple sources.")

                chatbot = gr.Chatbot(
                    label="Conversation",
                    height=500,
                    show_label=True,
                    avatar_images=(None, "πŸ€–")
                )

                with gr.Row():
                    msg = gr.Textbox(
                        label="Your Message",
                        placeholder="Ask me anything about your documents or general knowledge...",
                        scale=4
                    )
                    submit_btn = gr.Button("Send", variant="primary", scale=1)

                with gr.Row():
                    clear_chat_btn = gr.Button("πŸ”„ Reset Conversation")
                    gr.Markdown("*Note: Resetting clears the conversation history*")

                # Chat interface
                # NOTE(review): retry_btn/undo_btn/clear_btn kwargs were removed
                # in Gradio 5.x — confirm the pinned gradio version accepts them.
                chat_interface = gr.ChatInterface(
                    fn=chat_with_agent,
                    chatbot=chatbot,
                    textbox=msg,
                    submit_btn=submit_btn,
                    retry_btn=None,
                    undo_btn=None,
                    clear_btn=None
                )

                clear_chat_btn.click(
                    fn=reset_conversation,
                    outputs=[chatbot]
                )

        gr.Markdown("""
        ---
        ### πŸ”§ How it works:
        1. **Upload documents** in the Documents tab
        2. **Ask questions** in the Chat tab
        3. The agent will:
        - Analyze your query
        - Search relevant sources
        - Provide comprehensive answers with citations
        """)

    return demo

if __name__ == "__main__":
    demo = create_gradio_ui()
    demo.launch(share=False, server_name="127.0.0.1", server_port=7860)
core/rag_agent.py CHANGED
@@ -5,15 +5,18 @@ from agent.tools import *
5
  from agent.graph import create_agent_graph
6
 
7
  class RAGAgent:
8
-
9
- def __init__(self, collection_name=config.CHILD_COLLECTION):
10
- self.collection_name = collection_name
11
- self.retriever = intialize_chroma_vectorstore()
12
  self.thread_id = str(uuid.uuid4())
13
 
14
- self.llm = ChatGoogleGenerativeAI(model=config.LLM_MODEL, temperature=config.LLM_TEMPERATURE)
15
- tools = []
16
- self.agent_graph = create_agent_graph(self.llm, tools)
 
 
 
 
 
 
17
 
18
  def get_config(self):
19
  return {"configurable": {"thread_id": self.thread_id}}
 
5
  from agent.graph import create_agent_graph
6
 
7
  class RAGAgent:
8
    def __init__(self):
        """Wire up the LLM, vector store, search tools, and agent graph.

        Each instance gets a fresh thread id so its checkpointed
        conversation state is isolated from other agents.
        """
        self.thread_id = str(uuid.uuid4())

        self.llm = ChatGoogleGenerativeAI(
            model=config.LLM_MODEL,
            temperature=config.LLM_TEMPERATURE
        )

        # Local knowledge base used by the graph's retrieval node.
        vectordb = intialize_chroma_vectorstore()

        # External search tools wrapped into the graph's web-search ToolNode.
        search_tools = [web_search_tavily, arxiv_search, wikipedia_search]
        self.agent_graph = create_agent_graph(self.llm, vectordb, search_tools)
20
 
21
  def get_config(self):
22
  return {"configurable": {"thread_id": self.thread_id}}
main.py CHANGED
@@ -1,6 +1,11 @@
1
- def main():
2
- print("Hello from rag-agent!")
3
-
4
 
5
  if __name__ == "__main__":
6
- main()
 
 
 
 
 
 
 
 
1
from ui.gradio_components import create_gradio_ui


if __name__ == "__main__":
    # Build the Gradio UI and serve it on localhost only.
    print("πŸš€ Launching RAG Agent UI...")
    app = create_gradio_ui()
    app.launch(
        share=False,
        server_name="127.0.0.1",
        server_port=7860,
        show_error=True,
    )
ui/gradio_components.py CHANGED
@@ -1,91 +1,116 @@
1
  import gradio as gr
2
- from core.chat_interface import ChatInterface
3
- from core.document_manager import DocumentManager
4
- from core.rag_system import RAGSystem
5
 
6
- def create_gradio_ui():
7
- rag_system = RAGSystem()
8
- rag_system.initialize()
9
-
10
- doc_manager = DocumentManager(rag_system)
11
- chat_interface = ChatInterface(rag_system)
12
-
13
- def format_file_list():
14
- files = doc_manager.get_markdown_files()
15
- if not files:
16
- return "πŸ“­ No documents available in the knowledge base"
17
- return "\n".join([f"{f}" for f in files])
 
 
18
 
19
- def upload_handler(files, progress=gr.Progress()):
20
- if not files:
21
- return None, format_file_list()
22
-
23
- added, skipped = doc_manager.add_documents(
24
- files,
25
- progress_callback=lambda p, desc: progress(p, desc=desc)
26
- )
27
 
28
- gr.Info(f"βœ… Added: {added} | Skipped: {skipped}")
29
- return None, format_file_list()
30
-
31
- def clear_handler():
32
- doc_manager.clear_all()
33
- gr.Info(f"πŸ—‘οΈ Removed all documents")
34
- return format_file_list()
35
-
36
- def chat_handler(msg, hist):
37
- return chat_interface.chat(msg, hist)
 
 
 
 
 
 
38
 
39
- def clear_chat_handler():
40
- chat_interface.clear_session()
 
 
 
 
 
 
 
 
 
 
41
 
42
- with gr.Blocks(title="Agentic RAG") as demo:
 
 
43
 
44
- with gr.Tab("Documents", elem_id="doc-management-tab"):
45
- gr.Markdown("## Add New Documents")
46
- gr.Markdown("Upload PDF or Markdown files. Duplicates will be automatically skipped.")
47
-
48
- files_input = gr.File(
49
- label="Drop PDF or Markdown files here",
50
- file_count="multiple",
51
- type="filepath",
52
- height=200,
53
- show_label=False
54
- )
55
-
56
- add_btn = gr.Button("Add Documents", variant="primary", size="md")
57
-
58
- gr.Markdown("## Current Documents in the Knowledge Base")
59
- file_list = gr.Textbox(
60
- value=format_file_list(),
61
- interactive=False,
62
- lines = 7,
63
- max_lines=10,
64
- elem_id="file-list-box",
65
- show_label=False
66
- )
67
-
68
- with gr.Row():
69
- refresh_btn = gr.Button("Refresh", size="md")
70
- clear_btn = gr.Button("Clear All", variant="stop", size="md")
71
-
72
- add_btn.click(
73
- upload_handler,
74
- [files_input],
75
- [files_input, file_list],
76
- show_progress="corner"
77
- )
78
- refresh_btn.click(format_file_list, None, file_list)
79
- clear_btn.click(clear_handler, None, file_list)
80
 
81
- with gr.Tab("Chat"):
82
- chatbot = gr.Chatbot(
83
- height=600,
84
- placeholder="Ask me anything about your documents!",
85
- show_label=False
 
 
 
 
 
 
 
86
  )
87
- chatbot.clear(clear_chat_handler)
88
-
89
- gr.ChatInterface(fn=chat_handler, chatbot=chatbot)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
- return demo
 
 
 
 
 
1
  import gradio as gr
2
+ from core.rag_agent import RAGAgent
 
 
3
 
4
# Initialize components: module-level singleton, created lazily on first chat.
rag_agent = None
6
+
7
def initialize_agent():
    """Construct the shared RAGAgent on first use; reuse it afterwards."""
    global rag_agent
    if rag_agent is not None:
        return rag_agent
    rag_agent = RAGAgent()
    return rag_agent
13
+
14
def chat_with_agent(message, history):
    """Run one chat turn through the RAG agent and return its reply text.

    Blank messages are ignored (the existing history is returned unchanged).
    Errors are rendered as a chat message rather than raised into Gradio.
    """
    if not message.strip():
        return history

    try:
        agent = initialize_agent()

        # Stream state snapshots, keeping the latest message's content;
        # the final snapshot holds the agent's answer.
        reply = ""
        stream = agent.agent_graph.stream(
            {"messages": [("user", message)]},
            agent.get_config(),
            stream_mode="values",
        )
        for event in stream:
            if "messages" in event and len(event["messages"]) > 0:
                candidate = event["messages"][-1]
                if hasattr(candidate, "content"):
                    reply = candidate.content

        return reply or "I apologize, but I couldn't generate a response. Please try again."

    except Exception as e:
        return f"Error: {str(e)}"
41
+
42
def reset_conversation():
    """Start a new conversation thread and clear the chat display."""
    global rag_agent
    if rag_agent:
        rag_agent.reset_thread()
    # Returning None clears the Chatbot component.
    return None
48
+
49
def create_gradio_ui():
    """Create the complete Gradio interface.

    Builds a single-page chat app wired to chat_with_agent, with a reset
    button that starts a new agent thread. Returns the gr.Blocks instance
    (launched by the caller).
    """

    with gr.Blocks(title="RAG Agent with Agentic Memory", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # πŸ€– RAG Agent with Agentic Memory

        Chat with an intelligent agent that uses:
        - πŸ“š **Local Knowledge Base** (ChromaDB)
        - πŸ” **Web Search** (Tavily)
        - πŸ“– **Wikipedia**
        - πŸŽ“ **ArXiv** (Academic Papers)
        """)

        gr.Markdown("### Chat with Your Documents")
        gr.Markdown("Ask questions about your documents or any topic. The agent will search multiple sources.")

        chatbot = gr.Chatbot(
            label="Conversation",
            height=500,
            show_label=True,
            avatar_images=(None, "πŸ€–")
        )

        with gr.Row():
            msg = gr.Textbox(
                label="Your Message",
                placeholder="Ask me anything about your documents or general knowledge...",
                scale=4
            )
            submit_btn = gr.Button("Send", variant="primary", scale=1)

        with gr.Row():
            clear_chat_btn = gr.Button("πŸ”„ Reset Conversation")
            gr.Markdown("*Note: Resetting clears the conversation history*")

        # Chat interface
        # NOTE(review): retry_btn/undo_btn/clear_btn kwargs were removed in
        # Gradio 5.x — confirm the pinned gradio version accepts them.
        chat_interface = gr.ChatInterface(
            fn=chat_with_agent,
            chatbot=chatbot,
            textbox=msg,
            submit_btn=submit_btn,
            retry_btn=None,
            undo_btn=None,
            clear_btn=None
        )

        clear_chat_btn.click(
            fn=reset_conversation,
            outputs=[chatbot]
        )

        gr.Markdown("""
        ---
        ### πŸ”§ How it works:
        1. **Ask questions** in the chat
        2. The agent will:
        - Analyze your query
        - Search relevant sources (ChromaDB, Web, Wikipedia, ArXiv)
        - Provide comprehensive answers with citations
        3. Use **Reset Conversation** to start fresh
        """)

    return demo

if __name__ == "__main__":
    demo = create_gradio_ui()
    demo.launch(share=False, server_name="127.0.0.1", server_port=7860)