|
|
import gradio as gr
|
|
|
import constants as cte
|
|
|
from rag_smolagent import SmolAgent, MessageRole
|
|
|
from googlesearch import search
|
|
|
from typing import List, Tuple, Optional, Any
|
|
|
|
|
|
class ChatState:
    """Mutable, process-wide container for the chat session's settings and caches."""

    def __init__(self) -> None:
        # User-provided profile fields from the sidebar.
        self.age: str = ""
        self.residence: str = ""
        # Sentence combining age and residence, injected into the prompt.
        self.extra_info: str = ""
        # Answer customization ("Concise" or "Detailed").
        self.response_type: str = "Concise"
        # Caches refreshed on every agent call.
        self.sources: str = "There are no sources yet."
        self.cot_steps: str = ""
        # Assistant memory messages already rendered as CoT steps.
        self.seen_messages: List[Any] = []
|
|
|
|
|
|
# Module-level singletons shared by all Gradio callbacks below.
state = ChatState()


# RAG agent instance; re-created in clear_history() to reset its memory.
agent = SmolAgent()
|
|
|
|
|
|
def get_first_result(query: str) -> Optional[str]:
    """
    Searches Google using a processed version of the query string and returns
    the URL of the first result.

    The query is expected to be a comma-separated source description; the last
    one or two meaningful segments are combined with a fixed Spanish suffix to
    build the search terms.

    Args:
        query: Comma-separated source description.

    Returns:
        The first result URL, or None when the query has no comma or the
        search fails / yields nothing.
    """
    parts = query.split(",")
    # Bug fix: the original indexed parts[-2] before the try block, raising an
    # uncaught IndexError for queries without a comma. Treat that as "no result".
    if len(parts) < 2:
        return None

    if len(parts) > 3:
        msg = parts[-3] + parts[-2] + " manual de la renta 2023"
    else:
        msg = parts[-2] + " manual de la renta 2023"

    try:
        results = search(
            msg,
            num_results=1,
            lang="es",
        )
        # `search` returns a generator, which is always truthy, so the old
        # `if results else None` check was dead code. next(..., None) also
        # handles an empty result set without raising StopIteration.
        return next(results, None)
    except Exception as e:
        # Best effort: network/search errors degrade to "no URL found".
        print(f"Error: {e}")
        return None
|
|
|
|
|
|
def process_sources(sources: str) -> str:
    """
    Processes a string with multiple sources (one per line) and generates
    Markdown links using the first search result for each source.

    Lines with no search result get the placeholder target "Sin resultados".

    Args:
        sources: Newline-separated source descriptions.

    Returns:
        Newline-joined Markdown links, one per non-empty input line.
    """
    results: List[str] = []

    for raw_line in sources.strip().split('\n'):
        clean_line = raw_line.strip()
        if not clean_line:
            continue
        url = get_first_result(clean_line)
        # Bug fix: the original appended "\n" to the URL, which landed *inside*
        # the Markdown link target — "[src](http://...\n)" — breaking the link.
        url_str = url if url else "Sin resultados"
        results.append(f"[{clean_line}]({url_str})")

    return '\n'.join(results)
|
|
|
|
|
|
def update_age(value: str) -> None:
    """Stores the user's age (empty string when cleared) and rebuilds the extra info."""
    state.age = value or ""
    update_extra_info()
|
|
|
|
|
|
def update_residence(value: str) -> None:
    """Stores the user's residence (empty string when cleared) and rebuilds the extra info."""
    state.residence = value or ""
    update_extra_info()
|
|
|
|
|
|
def update_extra_info() -> None:
    """Rebuilds the prompt's extra-info sentence; empty unless both fields are set."""
    has_both = bool(state.age) and bool(state.residence)
    state.extra_info = (
        f"Tengo {state.age} años y resido en {state.residence}." if has_both else ""
    )
|
|
|
|
|
|
def update_response_type(value: str) -> None:
    """Records the selected response style ("Concise" / "Detailed") in the shared state."""
    state.response_type = value
|
|
|
|
|
|
def get_prompt(query: str, extra_info: str, response_type: str) -> str:
    """Fills the shared prompt template with the query, user context, and response style."""
    template = cte.PROMPT_TEMPLATE
    return template.format(
        query=query,
        extra_info=extra_info,
        response_type=response_type,
    )
|
|
|
|
|
|
def chatbot_response(message: str, history: List[Tuple[str, str]]) -> str:
    """
    Calls the agent, separates the main response from its sources, and
    extracts any new chain-of-thought (CoT) steps from the agent's memory.

    Side effects: overwrites state.sources and state.cot_steps, and appends
    newly rendered assistant messages to state.seen_messages.

    Args:
        message: The user's question.
        history: Gradio chat history (unused here; kept for the callback API).

    Returns:
        The answer text, with the "Fuentes:"/"Respuesta:" markers stripped.
    """
    response: str = agent(get_prompt(message, state.extra_info, state.response_type))

    # The agent is prompted to append its sources after a "Fuentes:" marker.
    # maxsplit=1 keeps everything after the first marker (the original
    # unbounded split silently dropped text after a second occurrence) and
    # avoids splitting the string twice.
    if "Fuentes:" in response:
        response, sources_part = response.split("Fuentes:", 1)
        state.sources = sources_part.strip()
        response = response.strip()
    else:
        state.sources = "No sources were used for the generation of this message"

    # The answer proper follows a "Respuesta:" marker when present.
    if "Respuesta:" in response:
        answer: str = response.split("Respuesta:", 1)[1].strip()
    else:
        answer = response

    # Collect the "Thought:" steps that have not been shown to the user yet.
    cot_messages: List[Any] = agent.agent.write_memory_to_messages()
    state.cot_steps = ""
    cot_step_counter: int = 0

    # Loop variable renamed: the original reused `message`, shadowing the
    # user-query parameter. `and` replaces the bitwise `&` — it short-circuits
    # and is the correct boolean operator here.
    for memory_message in cot_messages:
        if (memory_message.get('role') == MessageRole.ASSISTANT
                and memory_message not in state.seen_messages):
            for content in memory_message.get('content', []):
                if content.get('type') == 'text' and 'Thought:' in content.get('text', ''):
                    state.cot_steps += f"\n============================== Step {cot_step_counter} ==============================\n"
                    cot_step_counter += 1
                    state.cot_steps += (content['text'] + "\n")
            state.seen_messages.append(memory_message)

    return answer.strip()
|
|
|
|
|
|
def respond(message, chat_history):
    """
    Gradio callback: echoes the user message, shows a progress placeholder,
    then yields the final answer with optional Sources / CoT accordions.

    Yields:
        Tuples of (textbox value, updated chat history).
    """
    # Echo the user message immediately so the UI feels responsive.
    chat_history.append((message, ""))
    yield "", chat_history

    chat_history[-1] = (message, "⏳ Procesando respuesta...")
    yield "", chat_history

    bot_message: str = chatbot_response(message, chat_history)

    # Attach the sources accordion when the agent reported any sources.
    if state.sources != "No sources were used for the generation of this message":
        try:
            all_url = process_sources(state.sources)
        except Exception:
            # Best effort: fall back to the raw source list if lookup fails.
            all_url = state.sources
        bot_message += _build_accordion("Sources", all_url)

    # Attach the chain-of-thought accordion when any steps were collected.
    if state.cot_steps:
        bot_message += _build_accordion("Chain of Thought", state.cot_steps)

    chat_history[-1] = (message, bot_message)
    yield "", chat_history


def _build_accordion(summary: str, body: str) -> str:
    """Renders a collapsible HTML <details> block (deduplicates the markup
    previously repeated verbatim for Sources and Chain of Thought)."""
    return (
        "\n\n<details>\n"
        f" <summary style='color: #0090ff; cursor: pointer;'>{summary}</summary>\n\n"
        f" {body}\n\n"
        "</details>"
    )
|
|
|
|
|
|
def clear_history(chat_history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Wipes the visible chat history and resets the agent plus cached sources."""
    global agent

    chat_history.clear()
    # A brand-new SmolAgent discards the previous conversation's memory.
    agent = SmolAgent()
    state.sources = "There are no sources yet."
    state.seen_messages.clear()
    return chat_history
|
|
|
|
|
|
|
|
|
# Load the custom Gradio theme bundled with the project.
custom_theme = gr.Theme.load("/teamspace/studios/this_studio/AgenticRAG/theme.json")


with gr.Blocks(
    fill_height=True,
    fill_width=True,
    theme=custom_theme
) as demo:

    # Sidebar: user profile inputs and answer-customization controls.
    with gr.Sidebar():
        gr.Markdown("## User Information")
        age_input = gr.Textbox(label="Age", placeholder="Unknown", value=state.age)
        residence_input = gr.Textbox(label="Residence", placeholder="Unknown", value=state.residence)

        # Keep the shared ChatState in sync with the sidebar fields.
        age_input.change(update_age, inputs=age_input)
        residence_input.change(update_residence, inputs=residence_input)

        gr.Markdown("## Answer customization")

        response_type_dropdown = gr.Dropdown(
            label="Response Type",
            choices=["Concise", "Detailed"],
            value="Concise"
        )
        response_type_dropdown.change(update_response_type, inputs=response_type_dropdown)

        clear_button = gr.Button("Clear History", variant="secondary")

        # External link to the official 2023 tax-manual page.
        web_link = gr.HTML("<a href='https://sede.agenciatributaria.gob.es/Sede/Ayuda/23Manual/100.html' target='_blank'>Abrir página web</a>")

    gr.Markdown("# RAG ChatBot Agent")

    # Main column: chat display plus the message input box.
    with gr.Column(scale=1):
        chatbot = gr.Chatbot(
            label="Chat",
            bubble_full_width=True,
            scale=1,
            avatar_images=(None, "https://logosandtypes.com/wp-content/uploads/2022/03/cognizant.svg"),
            group_consecutive_messages=False
        )

        msg = gr.Textbox(
            show_label=False,
            placeholder="What do you want to know?",
            submit_btn=True
        )

    # Wire the chat input and the clear button to their callbacks.
    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear_button.click(clear_history, inputs=chatbot, outputs=chatbot)

# NOTE(review): share=True exposes a public Gradio link — confirm this is intended.
demo.launch(share=True)