""" Aba de Chat RAG Interativo Chat com painel lateral mostrando contextos recuperados e processo """ import time import uuid import gradio as gr from typing import List, Dict, Any from src.database import DatabaseManager from src.embeddings import EmbeddingManager from src.generation import GenerationManager from src.query_expansion import QueryExpander def create_chat_tab( db_manager: DatabaseManager, embedding_manager: EmbeddingManager, generation_manager: GenerationManager, session_id: str ): """Cria aba de chat RAG interativo""" with gr.Tab(" Chat RAG"): gr.Markdown(""" ## Chat com Retrieval-Augmented Generation Faça perguntas e veja o processo RAG em ação: - Contextos são recuperados da base de conhecimento - O LLM usa esses contextos para gerar respostas precisas - Acompanhe cada passo do processo em tempo real """) with gr.Row(): with gr.Column(scale=2): chatbot = gr.Chatbot( label="Conversa", height=500 ) with gr.Row(): msg_input = gr.Textbox( label="Sua mensagem", placeholder="Digite sua pergunta...", lines=2, scale=4 ) send_btn = gr.Button(" Enviar", variant="primary", scale=1) clear_btn = gr.Button("🗑 Limpar Conversa") with gr.Column(scale=1): gr.Markdown("### Painel de Processo") with gr.Accordion(" Configurações", open=True): top_k_chat = gr.Slider( minimum=1, maximum=10, value=4, step=1, label="Top K (chunks a recuperar)" ) temperature_chat = gr.Slider( minimum=0.0, maximum=2.0, value=0.3, step=0.1, label="Temperature" ) max_tokens_chat = gr.Slider( minimum=50, maximum=2048, value=512, step=50, label="Max Tokens" ) use_reranking_chat = gr.Checkbox( label="Usar Reranking", value=True, info="Reordena resultados com cross-encoder para melhor precisão" ) use_query_expansion = gr.Checkbox( label="Usar Query Expansion", value=False, info="Gera múltiplas variações da query para melhor cobertura" ) expansion_method = gr.Radio( choices=["llm", "template", "paraphrase"], value="llm", label="Método de Expansão", info="LLM: melhor qualidade | Template: mais rápido | Paraphrase: balanceado", visible=False ) num_variations = gr.Slider( minimum=1, maximum=5, value=2, step=1, label="Número de Variações", info="Queries adicionais a gerar", visible=False ) with gr.Accordion(" Contextos Recuperados", open=True): contexts_display = gr.Dataframe( headers=["Rank", "Score", "Fonte", "Preview"], label="Chunks Relevantes", wrap=True ) with gr.Accordion(" Impacto do Reranking", open=False): rerank_comparison = gr.Dataframe( headers=["Novo Rank", "Rank Original", "Score Original", "Score Rerank", "Mudança"], label="Comparação Antes/Depois", wrap=True ) with gr.Accordion(" Expansão de Query", open=False): query_variations_display = gr.Dataframe( headers=["#", "Query", "Resultados"], label="Queries Geradas", wrap=True ) with gr.Accordion(" Prompt Construído", open=False): prompt_display = gr.Textbox( label="Prompt enviado ao LLM", lines=10, max_lines=20, interactive=False ) with gr.Accordion(" Métricas de Performance", open=False): metrics_display = gr.JSON(label="Tempos de Processamento") # Estado da conversa conversation_state = gr.State([]) # Toggle visibility dos controles de expansão def toggle_expansion_controls(enabled): return gr.update(visible=enabled), gr.update(visible=enabled) use_query_expansion.change( fn=toggle_expansion_controls, inputs=[use_query_expansion], outputs=[expansion_method, num_variations] ) def respond(message, history, top_k, temperature, max_tokens, use_reranking, use_expansion, method, n_vars): if not message or not message.strip(): return history, [], "", {}, [], [] # 
            # Metrics
            total_start = time.time()
            metrics = {}
            query_variations_data = []

            # Step 0: query expansion (if enabled)
            queries_to_search = [message]
            if use_expansion:
                expansion_start = time.time()
                expander = QueryExpander(generation_manager)
                queries_to_search = expander.expand_query(
                    message, num_variations=int(n_vars), method=method
                )
                expansion_time = (time.time() - expansion_start) * 1000
                metrics['expansion_time_ms'] = expansion_time
                metrics['num_queries'] = len(queries_to_search)

            # Step 1: retrieve
            retrieve_start = time.time()

            # Fetch extra candidates when reranking so the cross-encoder has
            # something to reorder
            retrieve_k = int(top_k) * 2 if use_reranking else int(top_k)

            # With expansion enabled, search with every query and merge the results
            if use_expansion and len(queries_to_search) > 1:
                all_contexts = []
                seen_ids = set()

                for i, query in enumerate(queries_to_search, 1):
                    query_embedding = embedding_manager.encode_single(query, normalize=True)
                    query_contexts = db_manager.search_similar(
                        query_embedding, k=retrieve_k, session_id=session_id
                    )

                    # Add to the variations list for display
                    query_variations_data.append([i, query, len(query_contexts)])

                    # Merge results, skipping duplicates
                    for ctx in query_contexts:
                        if ctx['id'] not in seen_ids:
                            all_contexts.append(ctx)
                            seen_ids.add(ctx['id'])

                # Sort by score and keep the top candidates
                all_contexts.sort(key=lambda x: x.get('score', 0), reverse=True)
                contexts = all_contexts[:retrieve_k]
            else:
                # Single-query search
                query_embedding = embedding_manager.encode_single(message, normalize=True)
                contexts = db_manager.search_similar(
                    query_embedding, k=retrieve_k, session_id=session_id
                )

            retrieve_time = (time.time() - retrieve_start) * 1000
            metrics['retrieval_time_ms'] = retrieve_time

            # Keep the original contexts for the before/after comparison
            original_contexts = contexts.copy() if use_reranking else []

            # Step 1.5: reranking (if enabled)
            rerank_comparison_data = []
            if use_reranking and contexts:
                # Lazy import so the cross-encoder is only loaded when needed
                from src.reranking import Reranker

                rerank_start = time.time()
                reranker = Reranker()
                contexts = reranker.rerank(message, contexts, top_k=int(top_k))
                rerank_time = (time.time() - rerank_start) * 1000
                metrics['reranking_time_ms'] = rerank_time

                # Build the comparison rows
                for i, ctx in enumerate(contexts, 1):
                    # Find the original position
                    original_pos = next(
                        (j + 1 for j, c in enumerate(original_contexts) if c['id'] == ctx['id']),
                        -1
                    )
                    position_change = original_pos - i if original_pos != -1 else 0

                    rerank_comparison_data.append([
                        i,
                        original_pos,
                        f"{ctx.get('original_score', 0.0):.4f}",
                        f"{ctx.get('rerank_score', 0.0):.4f}",
                        f"+{position_change}" if position_change > 0 else str(position_change)
                    ])

            # Prepare the contexts table for display
            contexts_table = []
            for i, ctx in enumerate(contexts, 1):
                preview = ctx['content'][:60] + "..." if len(ctx['content']) > 60 else ctx['content']
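                # Prefer the rerank score when reranking ran; otherwise fall
                # back to the raw vector-similarity score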
                score = ctx.get('rerank_score', ctx.get('score', 0.0))
                contexts_table.append([
                    i,
                    f"{score:.4f}",
                    ctx['title'],
                    preview
                ])

            # Step 2: build the prompt
            prompt_start = time.time()
            prompt = generation_manager.build_rag_prompt(message, contexts)
            prompt_time = (time.time() - prompt_start) * 1000
            metrics['prompt_build_time_ms'] = prompt_time

            # Step 3: generate
            generate_start = time.time()
            response = generation_manager.generate(
                prompt,
                temperature=float(temperature),
                max_tokens=int(max_tokens)
            )
            generate_time = (time.time() - generate_start) * 1000
            metrics['generation_time_ms'] = generate_time

            # Append the sources to the response
            response_with_sources = response + "\n" + generation_manager.format_sources(contexts)

            # Total time
            total_time = (time.time() - total_start) * 1000
            metrics['total_time_ms'] = total_time
            metrics['num_contexts'] = len(contexts)
            metrics['top_k'] = top_k
            metrics['temperature'] = temperature
            metrics['max_tokens'] = max_tokens

            # Persist to the database
            chat_id = db_manager.get_chat_id(session_id)
            if chat_id:
                db_manager.save_message(chat_id, "user", message)
                db_manager.save_message(chat_id, "assistant", response_with_sources)
                db_manager.save_query_metric(
                    session_id, message, len(contexts),
                    retrieve_time, generate_time, total_time, int(top_k)
                )

            # Update the chat history
            new_history = history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response_with_sources}
            ]

            return (new_history, contexts_table, prompt, metrics,
                    rerank_comparison_data, query_variations_data)

        def clear_conversation():
            return [], [], "", {}, [], []

        # Wire up events
        send_btn.click(
            fn=respond,
            inputs=[msg_input, chatbot, top_k_chat, temperature_chat, max_tokens_chat,
                    use_reranking_chat, use_query_expansion, expansion_method, num_variations],
            outputs=[chatbot, contexts_display, prompt_display, metrics_display,
                     rerank_comparison, query_variations_display]
        ).then(
            lambda: "",
            outputs=[msg_input]
        )

        msg_input.submit(
            fn=respond,
            inputs=[msg_input, chatbot, top_k_chat, temperature_chat, max_tokens_chat,
                    use_reranking_chat, use_query_expansion, expansion_method, num_variations],
            outputs=[chatbot, contexts_display, prompt_display, metrics_display,
                     rerank_comparison, query_variations_display]
        ).then(
            lambda: "",
            outputs=[msg_input]
        )

        clear_btn.click(
            fn=clear_conversation,
            outputs=[chatbot, contexts_display, prompt_display, metrics_display,
                     rerank_comparison, query_variations_display]
        )

    return {
        "chatbot": chatbot,
        "msg_input": msg_input,
        "send_btn": send_btn
    }
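

# Minimal usage sketch (illustrative only, not executed): wiring this tab into
# a gr.Blocks app. The manager constructors shown are assumptions inferred from
# the imports above; adjust them to the real src.* APIs.
#
#     import uuid
#     import gradio as gr
#
#     db = DatabaseManager()       # assumed constructor
#     emb = EmbeddingManager()     # assumed constructor
#     gen = GenerationManager()    # assumed constructor
#
#     with gr.Blocks() as demo:
#         components = create_chat_tab(db, emb, gen, session_id=str(uuid.uuid4()))
#     demo.launch()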