# rag_template/ui/chat_tab.py
# Author: Guilherme Favaron
# Major update: Add hybrid search, reranking, multiple LLMs, and UI improvements
# Commit: 1b447de
"""
Aba de Chat RAG Interativo
Chat com painel lateral mostrando contextos recuperados e processo
"""
import time
import uuid
import gradio as gr
from typing import List, Dict, Any
from src.database import DatabaseManager
from src.embeddings import EmbeddingManager
from src.generation import GenerationManager
from src.query_expansion import QueryExpander
def create_chat_tab(
    db_manager: DatabaseManager,
    embedding_manager: EmbeddingManager,
    generation_manager: GenerationManager,
    session_id: str
):
    """Build the interactive RAG chat tab.

    Lays out a two-column Gradio UI: the left column holds the chatbot and the
    message input; the right column exposes retrieval/generation settings plus
    inspection panels (retrieved contexts, reranking impact, query expansion,
    the constructed prompt, and timing metrics), then wires the event handlers.

    Args:
        db_manager: Persistence and vector-search backend (chat history,
            similarity search, query metrics).
        embedding_manager: Encodes queries into embeddings for retrieval.
        generation_manager: Builds RAG prompts and calls the LLM.
        session_id: Scopes retrieval, chat history and metrics to the session.

    Returns:
        Dict exposing the main components: ``chatbot``, ``msg_input``,
        ``send_btn``.
    """
    with gr.Tab(" Chat RAG"):
        gr.Markdown("""
## Chat com Retrieval-Augmented Generation
Faça perguntas e veja o processo RAG em ação:
- Contextos são recuperados da base de conhecimento
- O LLM usa esses contextos para gerar respostas precisas
- Acompanhe cada passo do processo em tempo real
""")
        with gr.Row():
            # Left column: conversation and input controls.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Conversa",
                    height=500
                )
                with gr.Row():
                    msg_input = gr.Textbox(
                        label="Sua mensagem",
                        placeholder="Digite sua pergunta...",
                        lines=2,
                        scale=4
                    )
                    send_btn = gr.Button(" Enviar", variant="primary", scale=1)
                clear_btn = gr.Button("🗑 Limpar Conversa")

            # Right column: settings and process-inspection panels.
            with gr.Column(scale=1):
                gr.Markdown("### Painel de Processo")
                with gr.Accordion(" Configurações", open=True):
                    top_k_chat = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=4,
                        step=1,
                        label="Top K (chunks a recuperar)"
                    )
                    temperature_chat = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.3,
                        step=0.1,
                        label="Temperature"
                    )
                    max_tokens_chat = gr.Slider(
                        minimum=50,
                        maximum=2048,
                        value=512,
                        step=50,
                        label="Max Tokens"
                    )
                    use_reranking_chat = gr.Checkbox(
                        label="Usar Reranking",
                        value=True,
                        info="Reordena resultados com cross-encoder para melhor precisão"
                    )
                    use_query_expansion = gr.Checkbox(
                        label="Usar Query Expansion",
                        value=False,
                        info="Gera múltiplas variações da query para melhor cobertura"
                    )
                    # Hidden until query expansion is enabled (see .change below).
                    expansion_method = gr.Radio(
                        choices=["llm", "template", "paraphrase"],
                        value="llm",
                        label="Método de Expansão",
                        info="LLM: melhor qualidade | Template: mais rápido | Paraphrase: balanceado",
                        visible=False
                    )
                    num_variations = gr.Slider(
                        minimum=1,
                        maximum=5,
                        value=2,
                        step=1,
                        label="Número de Variações",
                        info="Queries adicionais a gerar",
                        visible=False
                    )
                with gr.Accordion(" Contextos Recuperados", open=True):
                    contexts_display = gr.Dataframe(
                        headers=["Rank", "Score", "Fonte", "Preview"],
                        label="Chunks Relevantes",
                        wrap=True
                    )
                with gr.Accordion(" Impacto do Reranking", open=False):
                    rerank_comparison = gr.Dataframe(
                        headers=["Novo Rank", "Rank Original", "Score Original", "Score Rerank", "Mudança"],
                        label="Comparação Antes/Depois",
                        wrap=True
                    )
                with gr.Accordion(" Expansão de Query", open=False):
                    query_variations_display = gr.Dataframe(
                        headers=["#", "Query", "Resultados"],
                        label="Queries Geradas",
                        wrap=True
                    )
                with gr.Accordion(" Prompt Construído", open=False):
                    prompt_display = gr.Textbox(
                        label="Prompt enviado ao LLM",
                        lines=10,
                        max_lines=20,
                        interactive=False
                    )
                with gr.Accordion(" Métricas de Performance", open=False):
                    metrics_display = gr.JSON(label="Tempos de Processamento")

        # Conversation state (currently unused by the handlers but kept as
        # part of the tab's component set).
        conversation_state = gr.State([])

        def toggle_expansion_controls(enabled):
            """Mirror the expansion checkbox onto both expansion controls' visibility."""
            return gr.update(visible=enabled), gr.update(visible=enabled)

        # Reveal the expansion method/variation controls only when enabled.
        use_query_expansion.change(
            fn=toggle_expansion_controls,
            inputs=[use_query_expansion],
            outputs=[expansion_method, num_variations]
        )

        def respond(message, history, top_k, temperature, max_tokens, use_reranking, use_expansion, method, n_vars):
            """Run one RAG turn: (expand), retrieve, (rerank), build prompt, generate.

            Args:
                message: User question from the textbox.
                history: Current chatbot history (list of role/content dicts).
                top_k: Number of chunks the final prompt should use.
                temperature: LLM sampling temperature.
                max_tokens: LLM generation cap.
                use_reranking: If True, over-fetch 2x and rerank with a cross-encoder.
                use_expansion: If True, expand the query into variations first.
                method: Query-expansion strategy ("llm" | "template" | "paraphrase").
                n_vars: Number of extra query variations to generate.

            Returns:
                Tuple matching the six output components: (new history,
                contexts table rows, prompt text, metrics dict, rerank
                comparison rows, query-variation rows).
            """
            if not message or not message.strip():
                # Nothing to process: keep history, clear the side panels.
                return history, [], "", {}, [], []

            total_start = time.time()
            metrics = {}
            query_variations_data = []

            # Normalize slider values once (Gradio sliders may deliver floats);
            # over-fetch 2x when reranking so the cross-encoder has candidates
            # to demote. Hoisted: this expression is loop-invariant.
            k_final = int(top_k)
            retrieve_k = k_final * 2 if use_reranking else k_final

            # Step 0: query expansion (optional).
            queries_to_search = [message]
            if use_expansion:
                expansion_start = time.time()
                expander = QueryExpander(generation_manager)
                queries_to_search = expander.expand_query(message, num_variations=int(n_vars), method=method)
                metrics['expansion_time_ms'] = (time.time() - expansion_start) * 1000
                metrics['num_queries'] = len(queries_to_search)

            # Step 1: retrieval.
            retrieve_start = time.time()
            if use_expansion and len(queries_to_search) > 1:
                # Search with each query variation and merge the result sets,
                # deduplicating by chunk id.
                all_contexts = []
                seen_ids = set()
                for i, query in enumerate(queries_to_search, 1):
                    query_embedding = embedding_manager.encode_single(query, normalize=True)
                    query_contexts = db_manager.search_similar(query_embedding, k=retrieve_k, session_id=session_id)
                    # Row for the "Expansão de Query" panel.
                    query_variations_data.append([i, query, len(query_contexts)])
                    for ctx in query_contexts:
                        if ctx['id'] not in seen_ids:
                            all_contexts.append(ctx)
                            seen_ids.add(ctx['id'])
                # Keep only the best-scoring merged candidates.
                all_contexts.sort(key=lambda x: x.get('score', 0), reverse=True)
                contexts = all_contexts[:retrieve_k]
            else:
                # Single-query search.
                query_embedding = embedding_manager.encode_single(message, normalize=True)
                contexts = db_manager.search_similar(query_embedding, k=retrieve_k, session_id=session_id)
            retrieve_time = (time.time() - retrieve_start) * 1000
            metrics['retrieval_time_ms'] = retrieve_time

            # Snapshot pre-rerank order so position changes can be reported.
            original_contexts = contexts.copy() if use_reranking else []

            # Step 1.5: reranking (optional).
            rerank_comparison_data = []
            if use_reranking and contexts:
                from src.reranking import Reranker
                rerank_start = time.time()
                reranker = Reranker()
                contexts = reranker.rerank(message, contexts, top_k=k_final)
                metrics['reranking_time_ms'] = (time.time() - rerank_start) * 1000
                # Before/after comparison rows for the reranking panel.
                for i, ctx in enumerate(contexts, 1):
                    # Pre-rerank rank of this chunk (-1 if not in the snapshot).
                    original_pos = next((j + 1 for j, c in enumerate(original_contexts) if c['id'] == ctx['id']), -1)
                    position_change = original_pos - i if original_pos != -1 else 0
                    rerank_comparison_data.append([
                        i,
                        original_pos,
                        f"{ctx.get('original_score', 0.0):.4f}",
                        f"{ctx.get('rerank_score', 0.0):.4f}",
                        f"+{position_change}" if position_change > 0 else str(position_change)
                    ])

            # Build the side-panel table of retrieved chunks.
            contexts_table = []
            for i, ctx in enumerate(contexts, 1):
                preview = ctx['content'][:60] + "..." if len(ctx['content']) > 60 else ctx['content']
                # Prefer the rerank score when present, else the retrieval score.
                score = ctx.get('rerank_score', ctx.get('score', 0.0))
                contexts_table.append([
                    i,
                    f"{score:.4f}",
                    ctx['title'],
                    preview
                ])

            # Step 2: assemble the RAG prompt from the retrieved contexts.
            prompt_start = time.time()
            prompt = generation_manager.build_rag_prompt(message, contexts)
            metrics['prompt_build_time_ms'] = (time.time() - prompt_start) * 1000

            # Step 3: LLM generation.
            generate_start = time.time()
            response = generation_manager.generate(
                prompt,
                temperature=float(temperature),
                max_tokens=int(max_tokens)
            )
            generate_time = (time.time() - generate_start) * 1000
            metrics['generation_time_ms'] = generate_time

            # Append the cited sources to the answer shown to the user.
            response_with_sources = response + "\n" + generation_manager.format_sources(contexts)

            total_time = (time.time() - total_start) * 1000
            metrics['total_time_ms'] = total_time
            metrics['num_contexts'] = len(contexts)
            # Store the normalized values actually used by the pipeline.
            metrics['top_k'] = k_final
            metrics['temperature'] = float(temperature)
            metrics['max_tokens'] = int(max_tokens)

            # Persist the exchange; messages need an existing chat record,
            # while query metrics are keyed by session_id directly.
            chat_id = db_manager.get_chat_id(session_id)
            if chat_id:
                db_manager.save_message(chat_id, "user", message)
                db_manager.save_message(chat_id, "assistant", response_with_sources)
            db_manager.save_query_metric(
                session_id,
                message,
                len(contexts),
                retrieve_time,
                generate_time,
                total_time,
                k_final
            )

            # Append the turn to the chatbot history (messages format).
            new_history = history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response_with_sources}
            ]
            return new_history, contexts_table, prompt, metrics, rerank_comparison_data, query_variations_data

        def clear_conversation():
            """Reset the chat and every inspection panel to its empty state."""
            return [], [], "", {}, [], []

        # Wire events: send button and Enter both run a RAG turn, then clear
        # the input textbox.
        send_btn.click(
            fn=respond,
            inputs=[msg_input, chatbot, top_k_chat, temperature_chat, max_tokens_chat, use_reranking_chat, use_query_expansion, expansion_method, num_variations],
            outputs=[chatbot, contexts_display, prompt_display, metrics_display, rerank_comparison, query_variations_display]
        ).then(
            lambda: "",
            outputs=[msg_input]
        )
        msg_input.submit(
            fn=respond,
            inputs=[msg_input, chatbot, top_k_chat, temperature_chat, max_tokens_chat, use_reranking_chat, use_query_expansion, expansion_method, num_variations],
            outputs=[chatbot, contexts_display, prompt_display, metrics_display, rerank_comparison, query_variations_display]
        ).then(
            lambda: "",
            outputs=[msg_input]
        )
        clear_btn.click(
            fn=clear_conversation,
            outputs=[chatbot, contexts_display, prompt_display, metrics_display, rerank_comparison, query_variations_display]
        )

    return {
        "chatbot": chatbot,
        "msg_input": msg_input,
        "send_btn": send_btn
    }