"""
Aba de Chat RAG Interativo
Chat com painel lateral mostrando contextos recuperados e processo
"""
import time
import gradio as gr
from src.database import DatabaseManager
from src.embeddings import EmbeddingManager
from src.generation import GenerationManager
from src.query_expansion import QueryExpander
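
# Interfaces this tab relies on (as exercised below):
#   EmbeddingManager.encode_single(text, normalize=...)      -> query embedding vector
#   DatabaseManager.search_similar(embedding, k, session_id) -> context dicts with at
#       least "id", "title", "content" and a similarity "score"
#   GenerationManager.build_rag_prompt / .generate / .format_sources -> prompt,
#       answer text, and a sources footer, respectively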


def create_chat_tab(
    db_manager: DatabaseManager,
    embedding_manager: EmbeddingManager,
    generation_manager: GenerationManager,
    session_id: str
):
    """Cria aba de chat RAG interativo"""

    with gr.Tab(" Chat RAG"):
        gr.Markdown("""
        ## Chat com Retrieval-Augmented Generation

        Faça perguntas e veja o processo RAG em ação:
        - Contextos são recuperados da base de conhecimento
        - O LLM usa esses contextos para gerar respostas precisas
        - Acompanhe cada passo do processo em tempo real
        """)

        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Conversation",
                    type="messages",  # respond() builds history as {"role", "content"} dicts
                    height=500
                )

                with gr.Row():
                    msg_input = gr.Textbox(
                        label="Your message",
                        placeholder="Type your question...",
                        lines=2,
                        scale=4
                    )
                    send_btn = gr.Button("Send", variant="primary", scale=1)

                clear_btn = gr.Button("🗑 Clear Conversation")

            with gr.Column(scale=1):
                gr.Markdown("###  Painel de Processo")

                with gr.Accordion("  Configurações", open=True):
                    top_k_chat = gr.Slider(
                        minimum=1,
                        maximum=10,
                        value=4,
                        step=1,
                        label="Top K (chunks a recuperar)"
                    )

                    temperature_chat = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        value=0.3,
                        step=0.1,
                        label="Temperature"
                    )

                    max_tokens_chat = gr.Slider(
                        minimum=50,
                        maximum=2048,
                        value=512,
                        step=50,
                        label="Max Tokens"
                    )

                    use_reranking_chat = gr.Checkbox(
                        label="Use Reranking",
                        value=True,
                        info="Reorders results with a cross-encoder for better precision"
                    )

                    use_query_expansion = gr.Checkbox(
                        label="Use Query Expansion",
                        value=False,
                        info="Generates multiple query variations for better coverage"
                    )

                    expansion_method = gr.Radio(
                        choices=["llm", "template", "paraphrase"],
                        value="llm",
                        label="Método de Expansão",
                        info="LLM: melhor qualidade | Template: mais rápido | Paraphrase: balanceado",
                        visible=False
                    )

                    num_variations = gr.Slider(
                        minimum=1,
                        maximum=5,
                        value=2,
                        step=1,
                        label="Número de Variações",
                        info="Queries adicionais a gerar",
                        visible=False
                    )

                with gr.Accordion(" Contextos Recuperados", open=True):
                    contexts_display = gr.Dataframe(
                        headers=["Rank", "Score", "Fonte", "Preview"],
                        label="Chunks Relevantes",
                        wrap=True
                    )

                with gr.Accordion(" Impacto do Reranking", open=False):
                    rerank_comparison = gr.Dataframe(
                        headers=["Novo Rank", "Rank Original", "Score Original", "Score Rerank", "Mudança"],
                        label="Comparação Antes/Depois",
                        wrap=True
                    )

                with gr.Accordion(" Expansão de Query", open=False):
                    query_variations_display = gr.Dataframe(
                        headers=["#", "Query", "Resultados"],
                        label="Queries Geradas",
                        wrap=True
                    )

                with gr.Accordion(" Prompt Construído", open=False):
                    prompt_display = gr.Textbox(
                        label="Prompt enviado ao LLM",
                        lines=10,
                        max_lines=20,
                        interactive=False
                    )

                with gr.Accordion("  Métricas de Performance", open=False):
                    metrics_display = gr.JSON(label="Tempos de Processamento")

        # Toggle visibility of the expansion controls
        def toggle_expansion_controls(enabled):
            return gr.update(visible=enabled), gr.update(visible=enabled)

        use_query_expansion.change(
            fn=toggle_expansion_controls,
            inputs=[use_query_expansion],
            outputs=[expansion_method, num_variations]
        )

        def respond(message, history, top_k, temperature, max_tokens, use_reranking, use_expansion, method, n_vars):
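            """Runs the full RAG pipeline for one message: optional query expansion,
            retrieval, optional reranking, prompt construction, and generation,
            collecting per-step timings along the way."""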
            if not message or not message.strip():
                return history, [], "", {}, [], []

            # Timing metrics (all durations in milliseconds)
            total_start = time.time()
            metrics = {}
            query_variations_data = []

            # Step 0: Query expansion (if enabled)
            queries_to_search = [message]
            if use_expansion:
                expansion_start = time.time()
                expander = QueryExpander(generation_manager)
                queries_to_search = expander.expand_query(message, num_variations=int(n_vars), method=method)
                expansion_time = (time.time() - expansion_start) * 1000
                metrics['expansion_time_ms'] = expansion_time
                metrics['num_queries'] = len(queries_to_search)

            # Step 1: Retrieve
            retrieve_start = time.time()

            # With expansion enabled, search with each query and merge the results
            if use_expansion and len(queries_to_search) > 1:
                all_contexts = []
                seen_ids = set()

                for i, query in enumerate(queries_to_search, 1):
                    query_embedding = embedding_manager.encode_single(query, normalize=True)
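                    # Over-fetch (2x) when reranking so the cross-encoder has a larger pool to reorder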
                    retrieve_k = int(top_k) * 2 if use_reranking else int(top_k)
                    query_contexts = db_manager.search_similar(query_embedding, k=retrieve_k, session_id=session_id)

                    # Record this variation for the display table
                    query_variations_data.append([i, query, len(query_contexts)])

                    # Merge results, skipping duplicate chunks
                    for ctx in query_contexts:
                        if ctx['id'] not in seen_ids:
                            all_contexts.append(ctx)
                            seen_ids.add(ctx['id'])

                # Sort by score and keep retrieve_k results (top_k * 2 when reranking, else top_k)
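                # Note: this assumes similarity scores are directly comparable across query variants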
                all_contexts.sort(key=lambda x: x.get('score', 0), reverse=True)
                retrieve_k = int(top_k) * 2 if use_reranking else int(top_k)
                contexts = all_contexts[:retrieve_k]
            else:
                # Standard single-query search
                query_embedding = embedding_manager.encode_single(message, normalize=True)
                retrieve_k = int(top_k) * 2 if use_reranking else int(top_k)
                contexts = db_manager.search_similar(query_embedding, k=retrieve_k, session_id=session_id)

            retrieve_time = (time.time() - retrieve_start) * 1000
            metrics['retrieval_time_ms'] = retrieve_time

            # Keep the original contexts for the before/after comparison
            original_contexts = contexts.copy() if use_reranking else []

            # Step 1.5: Reranking (if enabled)
            rerank_comparison_data = []
            if use_reranking and contexts:
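                # Imported lazily, presumably so the cross-encoder is only loaded when reranking is used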
                from src.reranking import Reranker
                rerank_start = time.time()
                reranker = Reranker()
                contexts = reranker.rerank(message, contexts, top_k=int(top_k))
                rerank_time = (time.time() - rerank_start) * 1000
                metrics['reranking_time_ms'] = rerank_time

                # Build the before/after comparison rows
                for i, ctx in enumerate(contexts, 1):
                    # Find the 1-based pre-rerank position; a positive change means the chunk moved up
                    original_pos = next((j+1 for j, c in enumerate(original_contexts) if c['id'] == ctx['id']), -1)
                    position_change = original_pos - i if original_pos != -1 else 0
                    rerank_comparison_data.append([
                        i,
                        original_pos,
                        f"{ctx.get('original_score', 0.0):.4f}",
                        f"{ctx.get('rerank_score', 0.0):.4f}",
                        f"+{position_change}" if position_change > 0 else str(position_change)
                    ])

            # Prepare the contexts table for display
            contexts_table = []
            for i, ctx in enumerate(contexts, 1):
                preview = ctx['content'][:60] + "..." if len(ctx['content']) > 60 else ctx['content']
                score = ctx.get('rerank_score', ctx.get('score', 0.0))
                contexts_table.append([
                    i,
                    f"{score:.4f}",
                    ctx['title'],
                    preview
                ])

            # Step 2: Build the prompt
            prompt_start = time.time()
            prompt = generation_manager.build_rag_prompt(message, contexts)
            prompt_time = (time.time() - prompt_start) * 1000
            metrics['prompt_build_time_ms'] = prompt_time

            # Step 3: Generate
            generate_start = time.time()
            response = generation_manager.generate(
                prompt,
                temperature=float(temperature),
                max_tokens=int(max_tokens)
            )
            generate_time = (time.time() - generate_start) * 1000
            metrics['generation_time_ms'] = generate_time

            # Append the sources footer to the response
            response_with_sources = response + "\n" + generation_manager.format_sources(contexts)

            # Total time
            total_time = (time.time() - total_start) * 1000
            metrics['total_time_ms'] = total_time
            metrics['num_contexts'] = len(contexts)
            metrics['top_k'] = top_k
            metrics['temperature'] = temperature
            metrics['max_tokens'] = max_tokens

            # Persist the exchange and query metrics to the database
            chat_id = db_manager.get_chat_id(session_id)
            if chat_id:
                db_manager.save_message(chat_id, "user", message)
                db_manager.save_message(chat_id, "assistant", response_with_sources)
                db_manager.save_query_metric(
                    session_id,
                    message,
                    len(contexts),
                    retrieve_time,
                    generate_time,
                    total_time,
                    int(top_k)
                )

            # Update the chat history (messages format)
            new_history = history + [
                {"role": "user", "content": message},
                {"role": "assistant", "content": response_with_sources}
            ]

            return new_history, contexts_table, prompt, metrics, rerank_comparison_data, query_variations_data

        def clear_conversation():
            return [], [], "", {}, [], []

        # Wire up events; the input box is cleared after each send
        send_btn.click(
            fn=respond,
            inputs=[msg_input, chatbot, top_k_chat, temperature_chat, max_tokens_chat, use_reranking_chat, use_query_expansion, expansion_method, num_variations],
            outputs=[chatbot, contexts_display, prompt_display, metrics_display, rerank_comparison, query_variations_display]
        ).then(
            lambda: "",
            outputs=[msg_input]
        )

        msg_input.submit(
            fn=respond,
            inputs=[msg_input, chatbot, top_k_chat, temperature_chat, max_tokens_chat, use_reranking_chat, use_query_expansion, expansion_method, num_variations],
            outputs=[chatbot, contexts_display, prompt_display, metrics_display, rerank_comparison, query_variations_display]
        ).then(
            lambda: "",
            outputs=[msg_input]
        )

        clear_btn.click(
            fn=clear_conversation,
            outputs=[chatbot, contexts_display, prompt_display, metrics_display, rerank_comparison, query_variations_display]
        )

    return {
        "chatbot": chatbot,
        "msg_input": msg_input,
        "send_btn": send_btn
    }
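

if __name__ == "__main__":
    # Minimal manual smoke test: mount the tab in a standalone Blocks app.
    # Assumption: the manager classes construct without arguments here; the
    # real constructors may require configuration (paths, models, API keys).
    import uuid

    with gr.Blocks() as demo:
        create_chat_tab(
            DatabaseManager(),
            EmbeddingManager(),
            GenerationManager(),
            session_id=str(uuid.uuid4()),
        )
    demo.launch()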