Nam Fam committed on
Commit ·
1d49f4a
1
Parent(s): f6d2db2
update files
Browse files
- agents/agent_graph.py +8 -4
- agents/chat/chat_agent.py +27 -16
- agents/documentsearch/documentsearch_agent.py +59 -15
- agents/orchestrator/orchestrator_agent.py +35 -39
- agents/websearch/websearch_agent.py +1 -1
- app.py +66 -10
- main.py +1 -1
agents/agent_graph.py
CHANGED
|
@@ -23,7 +23,7 @@ def chat_node(state):
|
|
| 23 |
|
| 24 |
def documentsearch_node(state):
|
| 25 |
print("[AgentGraph] Invoking DocumentSearchAgent")
|
| 26 |
-
agent = DocumentSearchAgent()
|
| 27 |
user_input = state['user_input']
|
| 28 |
chat_history = state.get('chat_history', [])
|
| 29 |
response, context, trace = agent.run(user_input, chat_history)
|
|
@@ -80,6 +80,9 @@ def finalize_node(state):
|
|
| 80 |
"""
|
| 81 |
Finalize response using OrchestratorAgent.system_prompt.
|
| 82 |
"""
|
|
|
|
|
|
|
|
|
|
| 83 |
return orchestrator.finalize_response(state)
|
| 84 |
|
| 85 |
def input_guard_node(state):
|
|
@@ -89,6 +92,7 @@ def input_guard_node(state):
|
|
| 89 |
return orchestrator.validate_input(state)
|
| 90 |
|
| 91 |
def orchestrator_node(state):
|
|
|
|
| 92 |
user_input = state['user_input']
|
| 93 |
chat_history = state.get('chat_history', [])
|
| 94 |
agent_name = orchestrator._decide_agent(user_input, chat_history=chat_history)
|
|
@@ -148,13 +152,13 @@ def build_agent_graph():
|
|
| 148 |
graph.add_edge('chat', 'finalize')
|
| 149 |
graph.add_edge('websearch', 'finalize')
|
| 150 |
graph.add_edge('documentsearch', 'evaluate')
|
| 151 |
-
# Evaluate documentsearch and fallback based on search toggles
|
| 152 |
graph.add_conditional_edges(
|
| 153 |
'evaluate',
|
| 154 |
lambda state: (
|
| 155 |
-
'finalize' if state.get('
|
|
|
|
| 156 |
else 'websearch' if state.get('enable_websearch', True)
|
| 157 |
-
# else 'documentsearch' if state.get('enable_docsearch', True)
|
| 158 |
else 'chat'
|
| 159 |
),
|
| 160 |
{
|
|
|
|
| 23 |
|
| 24 |
def documentsearch_node(state):
|
| 25 |
print("[AgentGraph] Invoking DocumentSearchAgent")
|
| 26 |
+
agent = DocumentSearchAgent(urls=state.get('doc_urls'), pdf_files=state.get('pdf_files'))
|
| 27 |
user_input = state['user_input']
|
| 28 |
chat_history = state.get('chat_history', [])
|
| 29 |
response, context, trace = agent.run(user_input, chat_history)
|
|
|
|
| 80 |
"""
|
| 81 |
Finalize response using OrchestratorAgent.system_prompt.
|
| 82 |
"""
|
| 83 |
+
# If document search indicated no sources, skip refinement
|
| 84 |
+
if any(entry.get('step') == 'no_sources' for entry in state.get('trace', [])):
|
| 85 |
+
return state
|
| 86 |
return orchestrator.finalize_response(state)
|
| 87 |
|
| 88 |
def input_guard_node(state):
|
|
|
|
| 92 |
return orchestrator.validate_input(state)
|
| 93 |
|
| 94 |
def orchestrator_node(state):
|
| 95 |
+
print("[AgentGraph] Invoking OrchestratorAgent")
|
| 96 |
user_input = state['user_input']
|
| 97 |
chat_history = state.get('chat_history', [])
|
| 98 |
agent_name = orchestrator._decide_agent(user_input, chat_history=chat_history)
|
|
|
|
| 152 |
graph.add_edge('chat', 'finalize')
|
| 153 |
graph.add_edge('websearch', 'finalize')
|
| 154 |
graph.add_edge('documentsearch', 'evaluate')
|
| 155 |
+
# Evaluate documentsearch and fallback based on search toggles, but if no_sources, finalize immediately
|
| 156 |
graph.add_conditional_edges(
|
| 157 |
'evaluate',
|
| 158 |
lambda state: (
|
| 159 |
+
'finalize' if any(step.get('step') == 'no_sources' for step in state.get('trace', []))
|
| 160 |
+
else 'finalize' if state.get('answered', False)
|
| 161 |
else 'websearch' if state.get('enable_websearch', True)
|
|
|
|
| 162 |
else 'chat'
|
| 163 |
),
|
| 164 |
{
|
agents/chat/chat_agent.py
CHANGED
|
@@ -16,7 +16,7 @@ class ChatAgent(AgentBase):
|
|
| 16 |
prompt = PromptTemplate(
|
| 17 |
input_variables=["input"],
|
| 18 |
template=(
|
| 19 |
-
"You are
|
| 20 |
"\nQuestion: {input}\n{agent_scratchpad}"
|
| 21 |
)
|
| 22 |
)
|
|
@@ -106,10 +106,15 @@ class ChatAgent(AgentBase):
|
|
| 106 |
user_input = state['user_input']
|
| 107 |
history_str = state.get('history_str', '')
|
| 108 |
prompt = (
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
)
|
|
|
|
| 113 |
state['prompt'] = prompt
|
| 114 |
state.setdefault('trace', []).append({'step': 'build_prompt', 'prompt': prompt})
|
| 115 |
return state
|
|
@@ -119,12 +124,13 @@ class ChatAgent(AgentBase):
|
|
| 119 |
messages = [HumanMessage(content=state['prompt'])]
|
| 120 |
# res = self.llm.chat_model.invoke(messages)
|
| 121 |
res = self.agent_executor.invoke({"input": state['prompt']})
|
|
|
|
| 122 |
output = res.get("output")
|
| 123 |
actions = res.get("actions", [])
|
| 124 |
state['response'] = output
|
| 125 |
state['handled'] = True
|
| 126 |
state['actions'] = actions
|
| 127 |
-
state.setdefault('trace', []).append({'step': 'llm_call', 'response': output, 'actions': actions
|
| 128 |
return state
|
| 129 |
|
| 130 |
|
|
@@ -171,14 +177,19 @@ class ChatAgent(AgentBase):
|
|
| 171 |
|
| 172 |
def run(self, user_input: str, chat_history=None):
|
| 173 |
# Try the tool-enabled agent executor first
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
prompt = PromptTemplate(
|
| 17 |
input_variables=["input"],
|
| 18 |
template=(
|
| 19 |
+
"You are Kia, the Know-It-All assistant. Use the provided tools when appropriate.\n"
|
| 20 |
"\nQuestion: {input}\n{agent_scratchpad}"
|
| 21 |
)
|
| 22 |
)
|
|
|
|
| 106 |
user_input = state['user_input']
|
| 107 |
history_str = state.get('history_str', '')
|
| 108 |
prompt = (
|
| 109 |
+
"You are Kia, the Know-It-All assistant.\n"
|
| 110 |
+
"Your task is to answer the user's question.\n"
|
| 111 |
+
"Here is the conversation so far:\n"
|
| 112 |
+
f"{history_str}\n"
|
| 113 |
+
"Here is the user's question:\n"
|
| 114 |
+
f"{user_input}\n"
|
| 115 |
+
"Your response:\n"
|
| 116 |
)
|
| 117 |
+
|
| 118 |
state['prompt'] = prompt
|
| 119 |
state.setdefault('trace', []).append({'step': 'build_prompt', 'prompt': prompt})
|
| 120 |
return state
|
|
|
|
| 124 |
messages = [HumanMessage(content=state['prompt'])]
|
| 125 |
# res = self.llm.chat_model.invoke(messages)
|
| 126 |
res = self.agent_executor.invoke({"input": state['prompt']})
|
| 127 |
+
# print('res', res)
|
| 128 |
output = res.get("output")
|
| 129 |
actions = res.get("actions", [])
|
| 130 |
state['response'] = output
|
| 131 |
state['handled'] = True
|
| 132 |
state['actions'] = actions
|
| 133 |
+
state.setdefault('trace', []).append({'step': 'llm_call', 'response': output, 'actions': actions})
|
| 134 |
return state
|
| 135 |
|
| 136 |
|
|
|
|
| 177 |
|
| 178 |
def run(self, user_input: str, chat_history=None):
|
| 179 |
# Try the tool-enabled agent executor first
|
| 180 |
+
# Fallback to the original graph-based flow
|
| 181 |
+
state = {'user_input': user_input, 'chat_history': chat_history or [], 'trace': []}
|
| 182 |
+
result = self.graph.invoke(state)
|
| 183 |
+
return result['response'], result['trace']
|
| 184 |
+
|
| 185 |
+
# try:
|
| 186 |
+
# # Use invoke to capture function calls and args
|
| 187 |
+
# result = self.agent_executor.invoke({"input": user_input})
|
| 188 |
+
# output = result.get("output")
|
| 189 |
+
# actions = result.get("actions", [])
|
| 190 |
+
# return output, actions
|
| 191 |
+
# except Exception:
|
| 192 |
+
# # Fallback to the original graph-based flow
|
| 193 |
+
# state = {'user_input': user_input, 'chat_history': chat_history or [], 'trace': []}
|
| 194 |
+
# result = self.graph.invoke(state)
|
| 195 |
+
# return result['response'], result['trace']
|
agents/documentsearch/documentsearch_agent.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from langchain_community.document_loaders import WebBaseLoader
|
| 2 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 3 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 4 |
from langchain_community.vectorstores import FAISS
|
|
@@ -14,26 +14,58 @@ def get_embeddings():
|
|
| 14 |
return embeddings
|
| 15 |
|
| 16 |
class DocumentSearchAgent(AgentBase):
|
| 17 |
-
def __init__(self, urls=None):
|
| 18 |
-
#
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
| 24 |
-
docs = splitter.split_documents(
|
| 25 |
embeddings = get_embeddings()
|
| 26 |
self.vectorstore = FAISS.from_documents(docs, embeddings)
|
| 27 |
self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})
|
| 28 |
self.graph = self.build_graph()
|
|
|
|
| 29 |
|
| 30 |
def build_graph(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
def rephrase(state):
|
| 32 |
-
|
| 33 |
-
state.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
return state
|
| 35 |
def retrieve(state):
|
| 36 |
-
query = state['
|
| 37 |
results = self.retriever.invoke(query)
|
| 38 |
# print("len(results)", len(results))
|
| 39 |
# Store raw results instead of joining them
|
|
@@ -45,6 +77,7 @@ class DocumentSearchAgent(AgentBase):
|
|
| 45 |
context = state.get('context', [])
|
| 46 |
# print("len(context)", len(context))
|
| 47 |
user_input = state['user_input']
|
|
|
|
| 48 |
llm = LLM()
|
| 49 |
|
| 50 |
# Format context as markdown table, one row per document
|
|
@@ -59,25 +92,36 @@ class DocumentSearchAgent(AgentBase):
|
|
| 59 |
context_table = "| # | Content |\n|---|---------|\n" + "\n".join(table_rows)
|
| 60 |
|
| 61 |
prompt = (
|
| 62 |
-
|
| 63 |
f"Excerpts (in markdown table format):\n{context_table}\n\n"
|
| 64 |
-
f"User: {
|
| 65 |
)
|
| 66 |
response = llm.generate(prompt)
|
| 67 |
state['response'] = response
|
| 68 |
state.setdefault('trace', []).append({'step': 'synthesize', 'prompt': prompt, 'response': response})
|
| 69 |
return state
|
| 70 |
graph = StateGraph(dict)
|
|
|
|
| 71 |
graph.add_node('rephrase', rephrase)
|
| 72 |
graph.add_node('retrieve', retrieve)
|
| 73 |
graph.add_node('synthesize', synthesize)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
graph.add_edge('rephrase', 'retrieve')
|
| 75 |
graph.add_edge('retrieve', 'synthesize')
|
| 76 |
graph.add_edge('synthesize', END)
|
| 77 |
-
graph.set_entry_point('
|
| 78 |
return graph.compile()
|
| 79 |
|
| 80 |
def run(self, query: str, chat_history=None):
|
| 81 |
state = {'user_input': query, 'chat_history': chat_history or [], 'trace': []}
|
| 82 |
result = self.graph.invoke(state)
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langchain_community.document_loaders import WebBaseLoader, PyPDFLoader
|
| 2 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 3 |
from langchain_google_genai import GoogleGenerativeAIEmbeddings
|
| 4 |
from langchain_community.vectorstores import FAISS
|
|
|
|
| 14 |
return embeddings
|
| 15 |
|
| 16 |
class DocumentSearchAgent(AgentBase):
|
| 17 |
+
def __init__(self, urls=None, pdf_files=None):
|
| 18 |
+
# Prepare raw documents list (only from provided URLs or PDFs)
|
| 19 |
+
docs_raw = []
|
| 20 |
+
# Load from URLs
|
| 21 |
+
if urls:
|
| 22 |
+
loader = WebBaseLoader(urls)
|
| 23 |
+
docs_raw.extend(loader.load())
|
| 24 |
+
# Load from uploaded PDF files
|
| 25 |
+
if pdf_files:
|
| 26 |
+
import tempfile, os
|
| 27 |
+
for uploaded_file in pdf_files:
|
| 28 |
+
suffix = os.path.splitext(uploaded_file.name)[1] or ".pdf"
|
| 29 |
+
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
|
| 30 |
+
tmp.write(uploaded_file.read())
|
| 31 |
+
tmp.close()
|
| 32 |
+
loader_pdf = PyPDFLoader(tmp.name)
|
| 33 |
+
docs_raw.extend(loader_pdf.load())
|
| 34 |
+
# If no sources provided, skip indexing and prepare guard-only graph
|
| 35 |
+
if not docs_raw:
|
| 36 |
+
self.vectorstore = None
|
| 37 |
+
self.retriever = None
|
| 38 |
+
self.graph = self.build_graph()
|
| 39 |
+
return
|
| 40 |
+
# Split documents into chunks and build index
|
| 41 |
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
|
| 42 |
+
docs = splitter.split_documents(docs_raw)
|
| 43 |
embeddings = get_embeddings()
|
| 44 |
self.vectorstore = FAISS.from_documents(docs, embeddings)
|
| 45 |
self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})
|
| 46 |
self.graph = self.build_graph()
|
| 47 |
+
self.llm = LLM()
|
| 48 |
|
| 49 |
def build_graph(self):
|
| 50 |
+
# Guard node: check for document sources
|
| 51 |
+
def check_sources(state):
|
| 52 |
+
if self.retriever is None:
|
| 53 |
+
msg = "π Tính năng tìm kiếm tài liệu đã bật, nhưng bạn chưa tải lên tài liệu hoặc nhập URL. Vui lòng bổ sung để tiếp tục."
|
| 54 |
+
state['response'] = msg
|
| 55 |
+
state['handled'] = True
|
| 56 |
+
state.setdefault('trace', []).append({'step': 'no_sources', 'response': msg})
|
| 57 |
+
return state
|
| 58 |
def rephrase(state):
|
| 59 |
+
history_str = ""
|
| 60 |
+
chat_history = state.get('chat_history', [])
|
| 61 |
+
for turn in chat_history:
|
| 62 |
+
history_str += f"User: {turn['user']}\nBot: {turn['bot']}\n"
|
| 63 |
+
rephrased_query = self.llm.generate(f"User: {state['user_input']}\nBot: Rephrase the user's question so it is clear and complete for a document search. Only output the rephrased question.\n{history_str}").strip()
|
| 64 |
+
state['rephrased_query'] = rephrased_query
|
| 65 |
+
state.setdefault('trace', []).append({'step': 'rephrase', 'query': state['user_input'], 'rephrased_query': rephrased_query})
|
| 66 |
return state
|
| 67 |
def retrieve(state):
|
| 68 |
+
query = state['rephrased_query']
|
| 69 |
results = self.retriever.invoke(query)
|
| 70 |
# print("len(results)", len(results))
|
| 71 |
# Store raw results instead of joining them
|
|
|
|
| 77 |
context = state.get('context', [])
|
| 78 |
# print("len(context)", len(context))
|
| 79 |
user_input = state['user_input']
|
| 80 |
+
rephrased_query = state['rephrased_query']
|
| 81 |
llm = LLM()
|
| 82 |
|
| 83 |
# Format context as markdown table, one row per document
|
|
|
|
| 92 |
context_table = "| # | Content |\n|---|---------|\n" + "\n".join(table_rows)
|
| 93 |
|
| 94 |
prompt = (
|
| 95 |
+
"You are Kia, the Know-It-All assistant. Based on the following provided document excerpts, answer the user's question as accurately and concisely as possible.\n\n"
|
| 96 |
f"Excerpts (in markdown table format):\n{context_table}\n\n"
|
| 97 |
+
f"User: {rephrased_query}\nBot:"
|
| 98 |
)
|
| 99 |
response = llm.generate(prompt)
|
| 100 |
state['response'] = response
|
| 101 |
state.setdefault('trace', []).append({'step': 'synthesize', 'prompt': prompt, 'response': response})
|
| 102 |
return state
|
| 103 |
graph = StateGraph(dict)
|
| 104 |
+
graph.add_node('check_sources', check_sources)
|
| 105 |
graph.add_node('rephrase', rephrase)
|
| 106 |
graph.add_node('retrieve', retrieve)
|
| 107 |
graph.add_node('synthesize', synthesize)
|
| 108 |
+
# If no sources, exit immediately
|
| 109 |
+
graph.add_conditional_edges(
|
| 110 |
+
'check_sources',
|
| 111 |
+
lambda state: state.get('handled', False),
|
| 112 |
+
{True: END, False: 'rephrase'}
|
| 113 |
+
)
|
| 114 |
graph.add_edge('rephrase', 'retrieve')
|
| 115 |
graph.add_edge('retrieve', 'synthesize')
|
| 116 |
graph.add_edge('synthesize', END)
|
| 117 |
+
graph.set_entry_point('check_sources')
|
| 118 |
return graph.compile()
|
| 119 |
|
| 120 |
def run(self, query: str, chat_history=None):
|
| 121 |
state = {'user_input': query, 'chat_history': chat_history or [], 'trace': []}
|
| 122 |
result = self.graph.invoke(state)
|
| 123 |
+
# Safely extract response, context, and trace
|
| 124 |
+
response = result.get('response', '')
|
| 125 |
+
context = result.get('context', [])
|
| 126 |
+
trace = result.get('trace', [])
|
| 127 |
+
return response, context, trace
|
agents/orchestrator/orchestrator_agent.py
CHANGED
|
@@ -11,11 +11,9 @@ class OrchestratorAgent:
|
|
| 11 |
self.llm = LLM()
|
| 12 |
# System prompt for all agents
|
| 13 |
self.system_prompt = (
|
| 14 |
-
"You are
|
| 15 |
-
"Provide accurate, concise, and clear responses in a professional style. "
|
| 16 |
"Respond in the user's language. "
|
| 17 |
-
"Do not provide personal or sensitive data, refuse harmful or inappropriate requests
|
| 18 |
-
"avoid off-topic responses (politely decline queries outside Thang Long University domain). "
|
| 19 |
)
|
| 20 |
|
| 21 |
def _decide_agent(self, user_input: str, chat_history=None) -> str:
|
|
@@ -80,15 +78,13 @@ class OrchestratorAgent:
|
|
| 80 |
context_table = "| Role | Message |\n|------|---------|\n" + "\n".join(table_rows)
|
| 81 |
|
| 82 |
prompt = (
|
| 83 |
-
"As an AI assistant, analyze the user's query and conversation context to determine if it requires searching through
|
| 84 |
"Context:\n"
|
| 85 |
f"Recent conversation:\n{context_table}\n\n"
|
| 86 |
f"Current query: {user_input}\n\n"
|
| 87 |
"Consider these factors:\n"
|
| 88 |
-
"1.
|
| 89 |
-
"2.
|
| 90 |
-
"3. Is the information likely to be found in university documentation?\n"
|
| 91 |
-
"4. Would document search be more reliable than web search for this query?\n\n"
|
| 92 |
"Output ONLY 'documentsearch' if document search is most appropriate, or 'other' if not."
|
| 93 |
)
|
| 94 |
|
|
@@ -117,25 +113,13 @@ class OrchestratorAgent:
|
|
| 117 |
if table_rows:
|
| 118 |
history_table = "| Role | Message |\n|------|---------|\n" + "\n".join(table_rows)
|
| 119 |
|
| 120 |
-
#
|
| 121 |
-
# prompt = (
|
| 122 |
-
# "You are an orchestrator for a multi-agent assistant. "
|
| 123 |
-
# "Decide which agent should handle the user's latest message: "
|
| 124 |
-
# "output ONLY 'chat' (for general conversation, Q&A, reasoning, etc.) "
|
| 125 |
-
# "or 'websearch' (if the user is asking for real-time, factual, or web-based information.)\n\n"
|
| 126 |
-
# f"Conversation so far:\n{history_table}\n\n"
|
| 127 |
-
# f"Current query: {user_input}\n\n"
|
| 128 |
-
# "Which agent should handle this? (chat/websearch):"
|
| 129 |
-
# )
|
| 130 |
-
# decision = self.llm.generate(prompt).strip().lower()
|
| 131 |
-
# return "websearch" if "websearch" in decision else "chat"
|
| 132 |
-
# New custom prompt for Thang Long University:
|
| 133 |
prompt = (
|
| 134 |
-
"You are an orchestrator
|
| 135 |
"Route the user's query to the most suitable agent: "
|
| 136 |
-
"'chat' for
|
| 137 |
"'websearch' for real-time or external data, "
|
| 138 |
-
"'documentsearch' for
|
| 139 |
f"Conversation history:\n{history_table}\n\n"
|
| 140 |
f"User query: {user_input}\n\n"
|
| 141 |
"Respond with ONLY the agent name (chat/websearch/documentsearch)."
|
|
@@ -162,12 +146,12 @@ class OrchestratorAgent:
|
|
| 162 |
"websearch_results": []
|
| 163 |
}
|
| 164 |
|
| 165 |
-
def finalize_response(self,
|
| 166 |
"""
|
| 167 |
Refine the routed agent's output using the system prompt.
|
| 168 |
"""
|
| 169 |
-
raw =
|
| 170 |
-
question =
|
| 171 |
# Construct detailed refinement prompt
|
| 172 |
prompt = (
|
| 173 |
f"{self.system_prompt}\n\n"
|
|
@@ -176,18 +160,30 @@ class OrchestratorAgent:
|
|
| 176 |
f"{raw}\n\n"
|
| 177 |
"Please refine this answer: polite, professional, and in the user's language; focus solely on answering the question without extraneous remarks. "
|
| 178 |
"Return only the final answer text."
|
|
|
|
|
|
|
| 179 |
)
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
#
|
| 183 |
-
|
| 184 |
-
result_state
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
'step': 'finalize',
|
| 186 |
'prompt': prompt,
|
| 187 |
-
'response':
|
| 188 |
'agent': 'orchestrator'
|
| 189 |
})
|
| 190 |
-
return
|
| 191 |
|
| 192 |
def validate_input(self, state: dict) -> dict:
|
| 193 |
"""
|
|
@@ -199,7 +195,7 @@ class OrchestratorAgent:
|
|
| 199 |
# Empty input
|
| 200 |
if not user_input:
|
| 201 |
state['guard_failed'] = True
|
| 202 |
-
state['response'] = "π Sorry, I didn't catch that. Please ask a question
|
| 203 |
state['agent'] = 'orchestrator'
|
| 204 |
# Record validation failure
|
| 205 |
state.setdefault('trace', []).append({'step': 'validate', 'reason': 'empty input', 'agent': 'orchestrator'})
|
|
@@ -212,15 +208,15 @@ class OrchestratorAgent:
|
|
| 212 |
history_str += f"User: {turn['user']}\nBot: {turn['bot']}\n"
|
| 213 |
classification_prompt = (
|
| 214 |
f"Given the following conversation history:\n{history_str}"
|
| 215 |
-
"Classify the following query as ON_TOPIC or OFF_TOPIC
|
| 216 |
f"Query: \"{user_input}\"\n"
|
| 217 |
-
"General greetings are not considered off-topic."
|
| 218 |
"Respond with exactly ON_TOPIC or OFF_TOPIC."
|
| 219 |
)
|
| 220 |
classification = self.llm.generate(classification_prompt).strip().upper()
|
| 221 |
if classification != "ON_TOPIC":
|
| 222 |
state['guard_failed'] = False
|
| 223 |
-
state['response'] = "π Sorry, I can only answer
|
| 224 |
state['agent'] = 'orchestrator'
|
| 225 |
# Record off-topic validation
|
| 226 |
state.setdefault('trace', []).append({'step': 'validate', 'reason': 'off topic_lm', 'agent': 'orchestrator'})
|
|
|
|
| 11 |
self.llm = LLM()
|
| 12 |
# System prompt for all agents
|
| 13 |
self.system_prompt = (
|
| 14 |
+
"You are Kia, the Know-It-All assistant. Provide accurate, concise, and clear responses in a professional style. "
|
|
|
|
| 15 |
"Respond in the user's language. "
|
| 16 |
+
"Do not provide personal or sensitive data, and refuse harmful or inappropriate requests."
|
|
|
|
| 17 |
)
|
| 18 |
|
| 19 |
def _decide_agent(self, user_input: str, chat_history=None) -> str:
|
|
|
|
| 78 |
context_table = "| Role | Message |\n|------|---------|\n" + "\n".join(table_rows)
|
| 79 |
|
| 80 |
prompt = (
|
| 81 |
+
"As an AI assistant, analyze the user's query and conversation context to determine if it requires searching through the provided documents (uploaded PDFs or URLs).\n\n"
|
| 82 |
"Context:\n"
|
| 83 |
f"Recent conversation:\n{context_table}\n\n"
|
| 84 |
f"Current query: {user_input}\n\n"
|
| 85 |
"Consider these factors:\n"
|
| 86 |
+
"1. Does the query require locating information within the provided documents?\n"
|
| 87 |
+
"2. Is document search more reliable than web search for this query?\n\n"
|
|
|
|
|
|
|
| 88 |
"Output ONLY 'documentsearch' if document search is most appropriate, or 'other' if not."
|
| 89 |
)
|
| 90 |
|
|
|
|
| 113 |
if table_rows:
|
| 114 |
history_table = "| Role | Message |\n|------|---------|\n" + "\n".join(table_rows)
|
| 115 |
|
| 116 |
+
# New custom prompt for general assistant:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
prompt = (
|
| 118 |
+
"You are an orchestrator for a general AI assistant. "
|
| 119 |
"Route the user's query to the most suitable agent: "
|
| 120 |
+
"'chat' for general conversation or reasoning, "
|
| 121 |
"'websearch' for real-time or external data, "
|
| 122 |
+
"'documentsearch' for searching within provided documents (PDFs/URLs). "
|
| 123 |
f"Conversation history:\n{history_table}\n\n"
|
| 124 |
f"User query: {user_input}\n\n"
|
| 125 |
"Respond with ONLY the agent name (chat/websearch/documentsearch)."
|
|
|
|
| 146 |
"websearch_results": []
|
| 147 |
}
|
| 148 |
|
| 149 |
+
def finalize_response(self, state: dict) -> dict:
|
| 150 |
"""
|
| 151 |
Refine the routed agent's output using the system prompt.
|
| 152 |
"""
|
| 153 |
+
raw = state.get('response', '')
|
| 154 |
+
question = state.get('user_input', '')
|
| 155 |
# Construct detailed refinement prompt
|
| 156 |
prompt = (
|
| 157 |
f"{self.system_prompt}\n\n"
|
|
|
|
| 160 |
f"{raw}\n\n"
|
| 161 |
"Please refine this answer: polite, professional, and in the user's language; focus solely on answering the question without extraneous remarks. "
|
| 162 |
"Return only the final answer text."
|
| 163 |
+
"Chỉ trong trường hợp draft answer chứa nhiều mâu thuẫn, không trả lời đầy đủ, hoặc không chính xác, hãy tinh chỉnh. "
|
| 164 |
+
"Nếu không, hãy giữ nguyên."
|
| 165 |
)
|
| 166 |
+
|
| 167 |
+
# refined = self.llm.generate(prompt).strip()
|
| 168 |
+
# result_state['response'] = refined
|
| 169 |
+
# # Label finalized output under orchestrator
|
| 170 |
+
# result_state['agent'] = 'orchestrator'
|
| 171 |
+
# result_state.setdefault('trace', []).append({
|
| 172 |
+
# 'step': 'finalize',
|
| 173 |
+
# 'prompt': prompt,
|
| 174 |
+
# 'response': refined,
|
| 175 |
+
# 'agent': 'orchestrator'
|
| 176 |
+
# })
|
| 177 |
+
|
| 178 |
+
state['response'] = raw
|
| 179 |
+
state['agent'] = 'orchestrator'
|
| 180 |
+
state.setdefault('trace', []).append({
|
| 181 |
'step': 'finalize',
|
| 182 |
'prompt': prompt,
|
| 183 |
+
'response': raw,
|
| 184 |
'agent': 'orchestrator'
|
| 185 |
})
|
| 186 |
+
return state
|
| 187 |
|
| 188 |
def validate_input(self, state: dict) -> dict:
|
| 189 |
"""
|
|
|
|
| 195 |
# Empty input
|
| 196 |
if not user_input:
|
| 197 |
state['guard_failed'] = True
|
| 198 |
+
state['response'] = "π Sorry, I didn't catch that. Please ask a question."
|
| 199 |
state['agent'] = 'orchestrator'
|
| 200 |
# Record validation failure
|
| 201 |
state.setdefault('trace', []).append({'step': 'validate', 'reason': 'empty input', 'agent': 'orchestrator'})
|
|
|
|
| 208 |
history_str += f"User: {turn['user']}\nBot: {turn['bot']}\n"
|
| 209 |
classification_prompt = (
|
| 210 |
f"Given the following conversation history:\n{history_str}"
|
| 211 |
+
"Classify the following query as ON_TOPIC or OFF_TOPIC with respect to a general conversational assistant.\n"
|
| 212 |
f"Query: \"{user_input}\"\n"
|
| 213 |
+
"General greetings are not considered off-topic.\n"
|
| 214 |
"Respond with exactly ON_TOPIC or OFF_TOPIC."
|
| 215 |
)
|
| 216 |
classification = self.llm.generate(classification_prompt).strip().upper()
|
| 217 |
if classification != "ON_TOPIC":
|
| 218 |
state['guard_failed'] = False
|
| 219 |
+
state['response'] = "π Sorry, I can only answer on-topic questions."
|
| 220 |
state['agent'] = 'orchestrator'
|
| 221 |
# Record off-topic validation
|
| 222 |
state.setdefault('trace', []).append({'step': 'validate', 'reason': 'off topic_lm', 'agent': 'orchestrator'})
|
agents/websearch/websearch_agent.py
CHANGED
|
@@ -17,7 +17,7 @@ class WebSearchAgent(AgentBase):
|
|
| 17 |
history_str += f"User: {turn['user']}\nBot: {turn['bot']}\n"
|
| 18 |
state['history_str'] = history_str
|
| 19 |
rephrase_prompt = (
|
| 20 |
-
f"You are a helpful assistant
|
| 21 |
f"Given the following conversation history:\n"
|
| 22 |
f"{history_str}"
|
| 23 |
f"User: {state['user_input']}\n"
|
|
|
|
| 17 |
history_str += f"User: {turn['user']}\nBot: {turn['bot']}\n"
|
| 18 |
state['history_str'] = history_str
|
| 19 |
rephrase_prompt = (
|
| 20 |
+
f"You are a helpful assistant.\n"
|
| 21 |
f"Given the following conversation history:\n"
|
| 22 |
f"{history_str}"
|
| 23 |
f"User: {state['user_input']}\n"
|
app.py
CHANGED
|
@@ -13,13 +13,13 @@ import time
|
|
| 13 |
def generate_response_message(response):
|
| 14 |
full_response = ""
|
| 15 |
response_words = response.split()
|
| 16 |
-
with st.chat_message("Kia", avatar="🤖"):
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
return full_response
|
| 24 |
|
| 25 |
# Function to generate initial message
|
|
@@ -59,8 +59,13 @@ with st.sidebar:
|
|
| 59 |
if 'enable_websearch' not in st.session_state:
|
| 60 |
st.session_state['enable_websearch'] = True
|
| 61 |
with st.expander("Agent Settings", expanded=True):
|
| 62 |
-
|
| 63 |
st.checkbox("Enable Document Search", value=st.session_state['enable_docsearch'], key='enable_docsearch')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
st.checkbox("Enable Web Search", value=st.session_state['enable_websearch'], key='enable_websearch')
|
| 65 |
|
| 66 |
# Chat Controls
|
|
@@ -73,7 +78,7 @@ with st.sidebar:
|
|
| 73 |
# Settings Section
|
| 74 |
with st.expander("🛠️ Display Settings", expanded=True):
|
| 75 |
DEV_MODE = st.checkbox("Enable Dev Mode", value=False)
|
| 76 |
-
STREAMING_ENABLED = st.checkbox("Enable response streaming", value=True)
|
| 77 |
# # Theme selector
|
| 78 |
# if 'theme' not in st.session_state:
|
| 79 |
# st.session_state.theme = "Light"
|
|
@@ -177,11 +182,56 @@ for i, turn in enumerate(st.session_state['chat_history']):
|
|
| 177 |
continue
|
| 178 |
# Assistant response rendering (streaming optional)
|
| 179 |
if i == len(st.session_state['chat_history']) - 1 and STREAMING_ENABLED and st.session_state['should_stream']:
|
| 180 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
st.session_state['should_stream'] = False
|
| 182 |
else:
|
| 183 |
with st.chat_message("Kia", avatar="🤖"):
|
| 184 |
st.markdown(turn['bot'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
# DEV_MODE trace expander (all completed turns)
|
| 186 |
if DEV_MODE and turn.get('trace'):
|
| 187 |
trace = turn.get('trace', [])
|
|
@@ -234,11 +284,17 @@ if st.session_state['waiting_for_response']:
|
|
| 234 |
user_input = st.session_state['chat_history'][-1]['user']
|
| 235 |
agent_graph = build_agent_graph()
|
| 236 |
chat_history = st.session_state['chat_history'][:-1] # Exclude the waiting message
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
state = {
|
| 238 |
"user_input": user_input,
|
| 239 |
"chat_history": chat_history,
|
| 240 |
"enable_docsearch": st.session_state['enable_docsearch'],
|
| 241 |
"enable_websearch": st.session_state['enable_websearch'],
|
|
|
|
|
|
|
| 242 |
}
|
| 243 |
|
| 244 |
# --- Invoke agent graph (may take time) with error handling ---
|
|
|
|
| 13 |
def generate_response_message(response):
|
| 14 |
full_response = ""
|
| 15 |
response_words = response.split()
|
| 16 |
+
# with st.chat_message("Kia", avatar="🤖"):
|
| 17 |
+
message_placeholder = st.empty()
|
| 18 |
+
for word in response_words:
|
| 19 |
+
full_response += word + " "
|
| 20 |
+
message_placeholder.markdown(full_response + "▌")
|
| 21 |
+
time.sleep(0.05)
|
| 22 |
+
message_placeholder.markdown(full_response)
|
| 23 |
return full_response
|
| 24 |
|
| 25 |
# Function to generate initial message
|
|
|
|
| 59 |
if 'enable_websearch' not in st.session_state:
|
| 60 |
st.session_state['enable_websearch'] = True
|
| 61 |
with st.expander("Agent Settings", expanded=True):
|
|
|
|
| 62 |
st.checkbox("Enable Document Search", value=st.session_state['enable_docsearch'], key='enable_docsearch')
|
| 63 |
+
if st.session_state['enable_docsearch']:
|
| 64 |
+
st.write("Upload PDF(s) or enter URL(s) to index for document search")
|
| 65 |
+
st.file_uploader("Upload PDF document(s)", type=["pdf"], accept_multiple_files=True, key="pdf_files")
|
| 66 |
+
st.text_input("Enter URL(s), comma-separated", key="doc_urls")
|
| 67 |
+
if st.button("Rebuild Index"):
|
| 68 |
+
st.experimental_rerun()
|
| 69 |
st.checkbox("Enable Web Search", value=st.session_state['enable_websearch'], key='enable_websearch')
|
| 70 |
|
| 71 |
# Chat Controls
|
|
|
|
| 78 |
# Settings Section
|
| 79 |
with st.expander("🛠️ Display Settings", expanded=True):
|
| 80 |
DEV_MODE = st.checkbox("Enable Dev Mode", value=False)
|
| 81 |
+
STREAMING_ENABLED = st.checkbox("Enable response streaming (just the effect π)", value=True)
|
| 82 |
# # Theme selector
|
| 83 |
# if 'theme' not in st.session_state:
|
| 84 |
# st.session_state.theme = "Light"
|
|
|
|
| 182 |
continue
|
| 183 |
# Assistant response rendering (streaming optional)
|
| 184 |
if i == len(st.session_state['chat_history']) - 1 and STREAMING_ENABLED and st.session_state['should_stream']:
|
| 185 |
+
with st.chat_message("Kia", avatar="🤖"):
|
| 186 |
+
generate_response_message(turn['bot'])
|
| 187 |
+
|
| 188 |
+
# Display sources for websearch in streaming mode
|
| 189 |
+
if turn.get('trace') is not None:
|
| 190 |
+
for step in turn['trace']:
|
| 191 |
+
if step.get('agent') == 'websearch' and step.get('step') == 'search':
|
| 192 |
+
raw_results = step.get('raw_results')
|
| 193 |
+
if raw_results:
|
| 194 |
+
st.markdown("**Sources:**")
|
| 195 |
+
if isinstance(raw_results, dict) and raw_results.get('organic'):
|
| 196 |
+
for idx, item in enumerate(raw_results['organic'], 1):
|
| 197 |
+
link = item.get('link')
|
| 198 |
+
title = item.get('title', '')
|
| 199 |
+
st.markdown(f"{idx}. [{title}]({link})")
|
| 200 |
+
elif isinstance(raw_results, list):
|
| 201 |
+
for idx, item in enumerate(raw_results, 1):
|
| 202 |
+
link = item.get('link') if isinstance(item, dict) else item
|
| 203 |
+
title = item.get('title', '') if isinstance(item, dict) else ''
|
| 204 |
+
st.markdown(f"{idx}. [{title}]({link})")
|
| 205 |
+
break
|
| 206 |
st.session_state['should_stream'] = False
|
| 207 |
else:
|
| 208 |
with st.chat_message("Kia", avatar="🤖"):
|
| 209 |
st.markdown(turn['bot'])
|
| 210 |
+
# st.markdown(turn['trace'])
|
| 211 |
+
|
| 212 |
+
if turn['trace'] is not None:
|
| 213 |
+
for step in turn['trace']:
|
| 214 |
+
if step.get('agent') == 'websearch' and step.get('step') == 'search':
|
| 215 |
+
# Extract raw_results from trace
|
| 216 |
+
raw_results = step.get('raw_results')
|
| 217 |
+
# Display citation sources for websearch agent
|
| 218 |
+
if raw_results:
|
| 219 |
+
st.markdown("**Sources:**")
|
| 220 |
+
# Prefer organic list from raw_results
|
| 221 |
+
if isinstance(raw_results, dict) and raw_results.get('organic'):
|
| 222 |
+
for idx, item in enumerate(raw_results['organic'], 1):
|
| 223 |
+
link = item.get('link')
|
| 224 |
+
title = item.get('title', '')
|
| 225 |
+
st.markdown(f"{idx}. [{title}]({link})")
|
| 226 |
+
# Fallback if raw_results itself is list
|
| 227 |
+
elif isinstance(raw_results, list):
|
| 228 |
+
for idx, item in enumerate(raw_results, 1):
|
| 229 |
+
link = item.get('link') if isinstance(item, dict) else item
|
| 230 |
+
title = item.get('title', '') if isinstance(item, dict) else ''
|
| 231 |
+
st.markdown(f"{idx}. [{title}]({link})")
|
| 232 |
+
break
|
| 233 |
+
|
| 234 |
+
|
| 235 |
# DEV_MODE trace expander (all completed turns)
|
| 236 |
if DEV_MODE and turn.get('trace'):
|
| 237 |
trace = turn.get('trace', [])
|
|
|
|
| 284 |
user_input = st.session_state['chat_history'][-1]['user']
|
| 285 |
agent_graph = build_agent_graph()
|
| 286 |
chat_history = st.session_state['chat_history'][:-1] # Exclude the waiting message
|
| 287 |
+
doc_urls = None
|
| 288 |
+
if st.session_state.get('doc_urls'):
|
| 289 |
+
doc_urls = [u.strip() for u in st.session_state['doc_urls'].split(",") if u.strip()]
|
| 290 |
+
pdf_files = st.session_state.get('pdf_files', [])
|
| 291 |
state = {
|
| 292 |
"user_input": user_input,
|
| 293 |
"chat_history": chat_history,
|
| 294 |
"enable_docsearch": st.session_state['enable_docsearch'],
|
| 295 |
"enable_websearch": st.session_state['enable_websearch'],
|
| 296 |
+
"doc_urls": doc_urls,
|
| 297 |
+
"pdf_files": pdf_files,
|
| 298 |
}
|
| 299 |
|
| 300 |
# --- Invoke agent graph (may take time) with error handling ---
|
main.py
CHANGED
|
@@ -5,7 +5,7 @@ def main():
|
|
| 5 |
# Visualize the agent graph every time it is built
|
| 6 |
visualize_agent_graph(agent_graph, as_image=True, save_to_file=True)
|
| 7 |
chat_history = []
|
| 8 |
-
print("
|
| 9 |
while True:
|
| 10 |
user_input = input("You: ")
|
| 11 |
if user_input.lower() in ["exit", "quit"]:
|
|
|
|
| 5 |
# Visualize the agent graph every time it is built
|
| 6 |
visualize_agent_graph(agent_graph, as_image=True, save_to_file=True)
|
| 7 |
chat_history = []
|
| 8 |
+
print("Kia - Your Know-It-All Assistant (type 'exit' to quit)")
|
| 9 |
while True:
|
| 10 |
user_input = input("You: ")
|
| 11 |
if user_input.lower() in ["exit", "quit"]:
|