# NOTE(review): removed non-Python viewer artifacts (file size, commit hash,
# and a line-number gutter) that had been pasted above the imports — they made
# the file syntactically invalid.
import gradio as gr
import constants as cte
from rag_smolagent import SmolAgent, MessageRole
from googlesearch import search
from typing import List, Tuple, Optional, Any
class ChatState:
    """Mutable, module-level container for one chat session's settings and caches."""

    def __init__(self) -> None:
        # "Concise" or "Detailed"; injected into the prompt template.
        self.response_type: str = "Concise"  # Default response type
        # User profile fields collected from the sidebar.
        self.age: str = ""
        self.residence: str = ""
        # Sentence built from age + residence, added to every prompt.
        self.extra_info: str = ""
        # Markdown shown in the "Sources" accordion until the agent cites any.
        self.sources: str = "There are no sources yet."
        # Agent-memory messages already rendered as CoT, to avoid duplicates.
        self.seen_messages: List[Any] = []
        # Formatted Chain-of-Thought text for the latest turn.
        self.cot_steps: str = ""
# Module-level singletons shared by all Gradio callbacks below.
state = ChatState()
# Initialize the SmolAgent with the specified data path
agent = SmolAgent()
def get_first_result(query: str) -> Optional[str]:
    """
    Search Google for a source citation and return the first result URL.

    The query is expected to be a comma-separated citation produced by the
    agent; its most specific trailing segments are combined with a fixed
    suffix so results are restricted to the 2023 "Manual de la Renta".

    Args:
        query: Comma-separated source string.

    Returns:
        The URL of the first search result, or None when the search fails
        or yields no results.
    """
    parts = query.split(",")
    if len(parts) > 3:
        msg = parts[-3] + parts[-2] + " manual de la renta 2023"  # search only in 2023's Manual
    elif len(parts) >= 2:
        msg = parts[-2] + " manual de la renta 2023"
    else:
        # Bug fix: a query without commas used to raise IndexError on
        # parts[-2]; fall back to the whole query instead.
        msg = query + " manual de la renta 2023"
    try:
        # Perform the search (updated configuration)
        results = search(
            msg,
            num_results=1,  # Desired number of results
            lang="es",      # Search language (optional)
        )
        # `search` returns a generator: next(..., None) avoids StopIteration
        # when there are no hits (the generator itself is always truthy).
        return next(results, None) if results else None
    except Exception as e:
        print(f"Error: {e}")
        return None
def process_sources(sources: str) -> str:
    """
    Convert a newline-separated list of source citations into Markdown links.

    Each non-empty line is resolved via get_first_result(); lines without a
    hit use the placeholder "Sin resultados" as the link target.

    Args:
        sources: Raw sources block extracted from the agent response.

    Returns:
        One Markdown link per source, joined by newlines.
    """
    results: List[str] = []
    for line in sources.strip().split('\n'):
        clean_line = line.strip()
        if not clean_line:
            continue  # skip blank separator lines
        url = get_first_result(clean_line)
        # Bug fix: the original appended "\n" inside the link target
        # ("[text](url\n)"), which breaks Markdown link rendering.
        url_str = url if url else "Sin resultados"
        results.append(f"[{clean_line}]({url_str})")
    return '\n'.join(results)
def update_age(value: str) -> None:
    """Stores the user's age in the shared state and rebuilds the extra info."""
    state.age = value or ""
    update_extra_info()
def update_residence(value: str) -> None:
    """Stores the user's residence in the shared state and rebuilds the extra info."""
    state.residence = value or ""
    update_extra_info()
def update_extra_info() -> None:
    """Rebuilds state.extra_info; it is non-empty only when BOTH age and residence are set."""
    have_profile = bool(state.age) and bool(state.residence)
    state.extra_info = (
        f"Tengo {state.age} años y resido en {state.residence}."
        if have_profile
        else ""
    )
def update_response_type(value: str) -> None:
    """Records the selected answer style ("Concise"/"Detailed") in the shared state."""
    state.response_type = value
def get_prompt(query: str, extra_info: str, response_type: str) -> str:
    """Fills the shared prompt template with the user query and session context."""
    return cte.PROMPT_TEMPLATE.format(
        query=query,
        extra_info=extra_info,
        response_type=response_type,
    )
def chatbot_response(message: str, history: List[Tuple[str, str]]) -> str:
    """
    Runs the agent on the user's message and post-processes its output.

    Side effects: updates state.sources with the cited sources (or a
    placeholder sentence), and rebuilds state.cot_steps from the agent's
    memory, tracking already-rendered messages in state.seen_messages.

    Args:
        message: The raw user message.
        history: Chat history (unused here; kept for the Gradio signature).

    Returns:
        The cleaned answer text, stripped of sources and the "Respuesta:" marker.
    """
    response: str = agent(get_prompt(message, state.extra_info, state.response_type))

    # Split off the sources section, if the agent produced one.
    # (partition avoids the original's double split on the same marker.)
    if "Fuentes:" in response:
        body, _, sources_text = response.partition("Fuentes:")
        state.sources = sources_text.strip()
        response = body.strip()
    else:
        state.sources = "No sources were used for the generation of this message"

    # Keep only the text after the first "Respuesta:" marker when present.
    if "Respuesta:" in response:
        answer: str = response.split("Respuesta:", 1)[1].strip()
    else:
        answer = response

    # Extract Chain of Thought (CoT) steps from the agent's memory.
    # Bug fixes vs. the original: use short-circuit `and` instead of the
    # bitwise `&`, and do not shadow the `message` parameter with the loop
    # variable.
    cot_messages: List[Any] = agent.agent.write_memory_to_messages()
    state.cot_steps = ""
    cot_step_counter: int = 0
    for mem_msg in cot_messages:
        if mem_msg.get('role') == MessageRole.ASSISTANT and mem_msg not in state.seen_messages:
            for content in mem_msg.get('content', []):
                if content.get('type') == 'text' and 'Thought:' in content.get('text', ''):
                    state.cot_steps += f"\n============================== Step {cot_step_counter} ==============================\n"
                    cot_step_counter += 1
                    state.cot_steps += (content['text'] + "\n")
            state.seen_messages.append(mem_msg)
    return answer.strip()
def respond(message, chat_history):
    """Streams one chat turn: echo the user, show a progress placeholder, then the answer."""
    # Show the user's message immediately with an empty bot slot.
    chat_history.append((message, ""))
    yield "", chat_history

    # Swap the empty slot for a progress indicator while the agent works.
    chat_history[-1] = (message, "⏳ Procesando respuesta...")
    yield "", chat_history

    bot_message: str = chatbot_response(message, chat_history)

    # Attach a collapsible "Sources" section when the agent cited any.
    if state.sources != "No sources were used for the generation of this message":
        try:
            all_url = process_sources(state.sources)
        except Exception:
            # Best-effort: fall back to the raw source text if lookup fails.
            all_url = state.sources
        bot_message += (
            "\n\n<details>\n"
            " <summary style='color: #0090ff; cursor: pointer;'>Sources</summary>\n\n"
            f" {all_url}\n\n"
            "</details>"
        )

    # Attach the Chain-of-Thought trace, when one was captured.
    if state.cot_steps:
        bot_message += (
            "\n\n<details>\n"
            " <summary style='color: #0090ff; cursor: pointer;'>Chain of Thought</summary>\n\n"
            f" {state.cot_steps}\n\n"
            "</details>"
        )

    chat_history[-1] = (message, bot_message)
    yield "", chat_history
def clear_history(chat_history: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Empties the chat history and rebuilds the agent so its memory starts fresh."""
    global agent
    chat_history.clear()
    # A brand-new SmolAgent drops all conversational memory.
    agent = SmolAgent()
    state.sources = "There are no sources yet."
    state.seen_messages.clear()
    return chat_history
# Gradio UI setup
# NOTE(review): theme path is hard-coded to a Lightning AI studio workspace —
# confirm it exists on the deployment host.
custom_theme = gr.Theme.load("/teamspace/studios/this_studio/AgenticRAG/theme.json")
with gr.Blocks(
    fill_height=True,
    fill_width=True,
    theme=custom_theme
) as demo:
    ## Demo Sidebar
    with gr.Sidebar():
        gr.Markdown("## User Information")
        age_input = gr.Textbox(label="Age", placeholder="Unknown", value=state.age)
        residence_input = gr.Textbox(label="Residence", placeholder="Unknown", value=state.residence)
        # Update values when fields change
        age_input.change(update_age, inputs=age_input)
        residence_input.change(update_residence, inputs=residence_input)
        gr.Markdown("## Answer customization")
        # Dropdown for response type
        response_type_dropdown = gr.Dropdown(
            label="Response Type",
            choices=["Concise", "Detailed"],
            value="Concise"
        )
        response_type_dropdown.change(update_response_type, inputs=response_type_dropdown)
        # Clear History button
        clear_button = gr.Button("Clear History", variant="secondary")
        # External link to the 2023 tax manual the searches are restricted to.
        web_link = gr.HTML("<a href='https://sede.agenciatributaria.gob.es/Sede/Ayuda/23Manual/100.html' target='_blank'>Abrir página web</a>")
    gr.Markdown("# RAG ChatBot Agent")
    with gr.Column(scale=1):
        chatbot = gr.Chatbot(
            label="Chat",
            bubble_full_width=True,
            scale=1,
            avatar_images=(None, "https://logosandtypes.com/wp-content/uploads/2022/03/cognizant.svg"),
            group_consecutive_messages=False
        )
        msg = gr.Textbox(
            show_label=False,
            placeholder="What do you want to know?",
            submit_btn=True
        )
        # Submit and clear history functionality
        msg.submit(respond, [msg, chatbot], [msg, chatbot])
        clear_button.click(clear_history, inputs=chatbot, outputs=chatbot)
# share=True publishes a temporary public Gradio link.
demo.launch(share=True)