Spaces:
Sleeping
Sleeping
Guilherme Favaron
Major update: Add hybrid search, reranking, multiple LLMs, and UI improvements
1b447de
| """ | |
| Aba de Chat RAG Interativo | |
| Chat com painel lateral mostrando contextos recuperados e processo | |
| """ | |
| import time | |
| import uuid | |
| import gradio as gr | |
| from typing import List, Dict, Any | |
| from src.database import DatabaseManager | |
| from src.embeddings import EmbeddingManager | |
| from src.generation import GenerationManager | |
| from src.query_expansion import QueryExpander | |
def create_chat_tab(
    db_manager: DatabaseManager,
    embedding_manager: EmbeddingManager,
    generation_manager: GenerationManager,
    session_id: str
):
    """Build the interactive RAG chat tab.

    Lays out the chat column, a process side-panel (settings, retrieved
    contexts, rerank comparison, query expansion, prompt, metrics), and wires
    the send/submit/clear events.

    Args:
        db_manager: Vector store access (``search_similar``, chat/message
            persistence, query metrics).
        embedding_manager: Produces query embeddings via ``encode_single``.
        generation_manager: Builds the RAG prompt and generates the answer.
        session_id: Current session; scopes retrieval and persistence.

    Returns:
        dict: Handles to the main components (``chatbot``, ``msg_input``,
        ``send_btn``) for use by the caller.
    """
    with gr.Tab(" Chat RAG"):
        gr.Markdown("""
        ## Chat com Retrieval-Augmented Generation

        Faça perguntas e veja o processo RAG em ação:
        - Contextos são recuperados da base de conhecimento
        - O LLM usa esses contextos para gerar respostas precisas
        - Acompanhe cada passo do processo em tempo real
        """)

        with gr.Row():
            # Left column: the conversation itself.
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Conversa",
                    height=500
                )

                with gr.Row():
                    msg_input = gr.Textbox(
                        label="Sua mensagem",
                        placeholder="Digite sua pergunta...",
                        lines=2,
                        scale=4
                    )
                    send_btn = gr.Button(" Enviar", variant="primary", scale=1)

                clear_btn = gr.Button("🗑 Limpar Conversa")

            # Right column: live view of the RAG pipeline internals.
            with gr.Column(scale=1):
                gr.Markdown("### Painel de Processo")

                with gr.Accordion(" Configurações", open=True):
                    top_k_chat = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=4,
                        step=1,
                        label="Top K (chunks a recuperar)"
                    )

                    temperature_chat = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.3,
                        step=0.1,
                        label="Temperature"
                    )

                    max_tokens_chat = gr.Slider(
                        minimum=50,
                        maximum=2048,
                        value=512,
                        step=50,
                        label="Max Tokens"
                    )

                    use_reranking_chat = gr.Checkbox(
                        label="Usar Reranking",
                        value=True,
                        info="Reordena resultados com cross-encoder para melhor precisão"
                    )

                    use_query_expansion = gr.Checkbox(
                        label="Usar Query Expansion",
                        value=False,
                        info="Gera múltiplas variações da query para melhor cobertura"
                    )

                    # Hidden until "Usar Query Expansion" is checked
                    # (toggled by toggle_expansion_controls below).
                    expansion_method = gr.Radio(
                        choices=["llm", "template", "paraphrase"],
                        value="llm",
                        label="Método de Expansão",
                        info="LLM: melhor qualidade | Template: mais rápido | Paraphrase: balanceado",
                        visible=False
                    )

                    num_variations = gr.Slider(
                        minimum=1,
                        maximum=5,
                        value=2,
                        step=1,
                        label="Número de Variações",
                        info="Queries adicionais a gerar",
                        visible=False
                    )

                with gr.Accordion(" Contextos Recuperados", open=True):
                    contexts_display = gr.Dataframe(
                        headers=["Rank", "Score", "Fonte", "Preview"],
                        label="Chunks Relevantes",
                        wrap=True
                    )

                with gr.Accordion(" Impacto do Reranking", open=False):
                    rerank_comparison = gr.Dataframe(
                        headers=["Novo Rank", "Rank Original", "Score Original", "Score Rerank", "Mudança"],
                        label="Comparação Antes/Depois",
                        wrap=True
                    )

                with gr.Accordion(" Expansão de Query", open=False):
                    query_variations_display = gr.Dataframe(
                        headers=["#", "Query", "Resultados"],
                        label="Queries Geradas",
                        wrap=True
                    )

                with gr.Accordion(" Prompt Construído", open=False):
                    prompt_display = gr.Textbox(
                        label="Prompt enviado ao LLM",
                        lines=10,
                        max_lines=20,
                        interactive=False
                    )

                with gr.Accordion(" Métricas de Performance", open=False):
                    metrics_display = gr.JSON(label="Tempos de Processamento")

        # Conversation state.
        # NOTE(review): never read or written by any handler below — the
        # chatbot component itself carries the history. Confirm before removing.
        conversation_state = gr.State([])

        # Show/hide the expansion controls together with the checkbox.
        def toggle_expansion_controls(enabled):
            """Return visibility updates for the two expansion controls."""
            return gr.update(visible=enabled), gr.update(visible=enabled)

        use_query_expansion.change(
            fn=toggle_expansion_controls,
            inputs=[use_query_expansion],
            outputs=[expansion_method, num_variations]
        )

        def respond(message, history, top_k, temperature, max_tokens, use_reranking, use_expansion, method, n_vars):
            """Run the full RAG pipeline for one user message.

            Steps: optional query expansion -> retrieval (optionally fanned
            out over the expanded queries and deduplicated) -> optional
            cross-encoder reranking -> prompt build -> generation ->
            persistence -> UI updates.

            Returns a 6-tuple matching the wired outputs:
            (chat history, contexts table, prompt text, metrics dict,
            rerank comparison rows, query-variation rows).
            """
            # Empty/whitespace-only input: leave everything untouched
            # except clearing the side panels.
            if not message or not message.strip():
                return history, [], "", {}, [], []

            # Per-request metrics (all times in milliseconds).
            total_start = time.time()
            metrics = {}
            query_variations_data = []

            # Step 0: query expansion (if enabled). On the default path the
            # original message is the only query.
            queries_to_search = [message]
            if use_expansion:
                expansion_start = time.time()
                expander = QueryExpander(generation_manager)
                queries_to_search = expander.expand_query(message, num_variations=int(n_vars), method=method)
                expansion_time = (time.time() - expansion_start) * 1000
                metrics['expansion_time_ms'] = expansion_time
                metrics['num_queries'] = len(queries_to_search)

            # Step 1: retrieval.
            retrieve_start = time.time()

            # With expansion, search once per query variant and merge results.
            if use_expansion and len(queries_to_search) > 1:
                all_contexts = []
                seen_ids = set()

                for i, query in enumerate(queries_to_search, 1):
                    query_embedding = embedding_manager.encode_single(query, normalize=True)
                    # Over-fetch 2x when reranking so the cross-encoder has
                    # candidates to reorder before truncating to top_k.
                    retrieve_k = int(top_k) * 2 if use_reranking else int(top_k)
                    query_contexts = db_manager.search_similar(query_embedding, k=retrieve_k, session_id=session_id)

                    # Row for the "Expansão de Query" panel.
                    query_variations_data.append([i, query, len(query_contexts)])

                    # Merge, skipping chunks already seen under another query.
                    for ctx in query_contexts:
                        if ctx['id'] not in seen_ids:
                            all_contexts.append(ctx)
                            seen_ids.add(ctx['id'])

                # Sort merged pool by similarity score and keep the top slice.
                all_contexts.sort(key=lambda x: x.get('score', 0), reverse=True)
                retrieve_k = int(top_k) * 2 if use_reranking else int(top_k)
                contexts = all_contexts[:retrieve_k]
            else:
                # Single-query retrieval.
                query_embedding = embedding_manager.encode_single(message, normalize=True)
                retrieve_k = int(top_k) * 2 if use_reranking else int(top_k)
                contexts = db_manager.search_similar(query_embedding, k=retrieve_k, session_id=session_id)

            retrieve_time = (time.time() - retrieve_start) * 1000
            metrics['retrieval_time_ms'] = retrieve_time

            # Snapshot pre-rerank order so we can show rank movement.
            original_contexts = contexts.copy() if use_reranking else []

            # Step 1.5: reranking (if enabled).
            rerank_comparison_data = []
            if use_reranking and contexts:
                # Imported lazily so the cross-encoder is only loaded when used.
                from src.reranking import Reranker

                rerank_start = time.time()
                reranker = Reranker()
                contexts = reranker.rerank(message, contexts, top_k=int(top_k))
                rerank_time = (time.time() - rerank_start) * 1000
                metrics['reranking_time_ms'] = rerank_time

                # Build the before/after comparison rows.
                for i, ctx in enumerate(contexts, 1):
                    # Position this chunk held before reranking (-1 if the id
                    # is somehow not found in the snapshot).
                    original_pos = next((j+1 for j, c in enumerate(original_contexts) if c['id'] == ctx['id']), -1)
                    position_change = original_pos - i if original_pos != -1 else 0

                    rerank_comparison_data.append([
                        i,
                        original_pos,
                        f"{ctx.get('original_score', 0.0):.4f}",
                        f"{ctx.get('rerank_score', 0.0):.4f}",
                        # Positive movement gets an explicit "+" prefix.
                        f"+{position_change}" if position_change > 0 else str(position_change)
                    ])

            # Rows for the "Contextos Recuperados" panel (60-char preview;
            # rerank score preferred over the raw similarity score).
            contexts_table = []
            for i, ctx in enumerate(contexts, 1):
                preview = ctx['content'][:60] + "..." if len(ctx['content']) > 60 else ctx['content']
                score = ctx.get('rerank_score', ctx.get('score', 0.0))
                contexts_table.append([
                    i,
                    f"{score:.4f}",
                    ctx['title'],
                    preview
                ])

            # Step 2: build the RAG prompt from the final context set.
            prompt_start = time.time()
            prompt = generation_manager.build_rag_prompt(message, contexts)
            prompt_time = (time.time() - prompt_start) * 1000
            metrics['prompt_build_time_ms'] = prompt_time

            # Step 3: generate the answer.
            generate_start = time.time()
            response = generation_manager.generate(
                prompt,
                temperature=float(temperature),
                max_tokens=int(max_tokens)
            )
            generate_time = (time.time() - generate_start) * 1000
            metrics['generation_time_ms'] = generate_time

            # Append the formatted source list to the answer.
            response_with_sources = response + "\n" + generation_manager.format_sources(contexts)

            # Totals and request parameters for the metrics panel.
            total_time = (time.time() - total_start) * 1000
            metrics['total_time_ms'] = total_time
            metrics['num_contexts'] = len(contexts)
            # NOTE(review): stored as the raw slider value (float), unlike the
            # int(top_k) used everywhere else — confirm this is intentional.
            metrics['top_k'] = top_k
            metrics['temperature'] = temperature
            metrics['max_tokens'] = max_tokens

            # Persist the exchange and the query metrics for this session.
            chat_id = db_manager.get_chat_id(session_id)
            if chat_id:
                db_manager.save_message(chat_id, "user", message)
                db_manager.save_message(chat_id, "assistant", response_with_sources)

            db_manager.save_query_metric(
                session_id,
                message,
                len(contexts),
                retrieve_time,
                generate_time,
                total_time,
                int(top_k)
            )

            # Append the turn using the Chatbot "messages" format.
            new_history = history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response_with_sources}
            ]

            return new_history, contexts_table, prompt, metrics, rerank_comparison_data, query_variations_data

        def clear_conversation():
            """Reset the chat and every side panel to its empty state."""
            return [], [], "", {}, [], []

        # Wire events: button click and textbox submit run the same pipeline,
        # then clear the input box.
        send_btn.click(
            fn=respond,
            inputs=[msg_input, chatbot, top_k_chat, temperature_chat, max_tokens_chat, use_reranking_chat, use_query_expansion, expansion_method, num_variations],
            outputs=[chatbot, contexts_display, prompt_display, metrics_display, rerank_comparison, query_variations_display]
        ).then(
            lambda: "",
            outputs=[msg_input]
        )

        msg_input.submit(
            fn=respond,
            inputs=[msg_input, chatbot, top_k_chat, temperature_chat, max_tokens_chat, use_reranking_chat, use_query_expansion, expansion_method, num_variations],
            outputs=[chatbot, contexts_display, prompt_display, metrics_display, rerank_comparison, query_variations_display]
        ).then(
            lambda: "",
            outputs=[msg_input]
        )

        clear_btn.click(
            fn=clear_conversation,
            outputs=[chatbot, contexts_display, prompt_display, metrics_display, rerank_comparison, query_variations_display]
        )

        return {
            "chatbot": chatbot,
            "msg_input": msg_input,
            "send_btn": send_btn
        }