File size: 4,240 Bytes
e7d525e
 
aac26cf
 
ddbe657
 
 
 
 
aac26cf
ddbe657
 
 
 
 
 
aac26cf
e7d525e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddbe657
e7d525e
ddbe657
 
e7d525e
ddbe657
 
 
 
 
 
 
e7d525e
 
ddbe657
e7d525e
 
 
 
ddbe657
e7d525e
 
ddbe657
e7d525e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddbe657
 
 
 
 
 
 
 
 
 
e7d525e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddbe657
e7d525e
 
ddbe657
 
e7d525e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddbe657
 
 
 
 
 
e7d525e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import gradio as gr

from email_rag.rag_data import THREAD_OPTIONS
from email_rag.rag_sessions import (
    start_session,
    reset_session,
    get_session,
    update_entity_memory,
)
from email_rag.rag_retrieval import (
    rewrite_query,
    retrieve_chunks,
    build_answer,
    log_trace,
    extract_entities_for_turn,
)
from email_rag.rag_timeline import build_timeline


def init_session_ui(thread_id: str):
    """Open a new session for the thread picked in the UI.

    Returns a ``(session_id, status_message)`` pair. When ``thread_id`` is
    falsy no session is started and the id slot is ``None``.
    """
    if thread_id:
        new_session_id = start_session(thread_id)
        status = f"Started session for thread: {thread_id}"
        return new_session_id, status
    return None, "Please select a thread to start."


def chat_ui(user_text: str, session_id: str, search_outside_thread: bool):
    """Answer one user question inside the active session.

    Steps, in order: rewrite the query, retrieve chunks, extract entities and
    update the session's entity memory, build the answer, record the turn in
    ``session["recent_turns"]``, and log a trace.

    Args:
        user_text: The question typed by the user.
        session_id: Id of the active session; falsy when none was started.
        search_outside_thread: Forwarded unchanged to ``retrieve_chunks``.

    Returns:
        A ``(answer, rewritten_query, debug_text)`` tuple. When there is no
        session id, the question is blank, or the session cannot be found,
        the first element is a user-facing message and the other two are
        empty strings.
    """
    if not session_id:
        return "Please start a session by selecting a thread.", "", ""

    # A blank question would otherwise run the whole pipeline, be written to
    # the trace log, and occupy one of the five remembered turns.
    if not user_text or not user_text.strip():
        return "Please enter a question.", "", ""

    session = get_session(session_id)
    if session is None:
        return "Session not found. Please start again.", "", ""

    # 1) Rewrite query using thread + entity memory
    rewrite = rewrite_query(user_text, session)

    # 2) Retrieve chunks
    retrieved = retrieve_chunks(rewrite, session, search_outside_thread)

    # 3) Extract entities from this turn + retrieved evidence, update memory
    new_entities = extract_entities_for_turn(user_text, retrieved)
    if new_entities:
        update_entity_memory(session_id, new_entities)

    # 4) Build grounded answer
    answer, citations = build_answer(user_text, rewrite, retrieved)

    # 5) Update simple turn memory, keeping only the five most recent turns
    session["recent_turns"].append({"user": user_text, "answer": answer})
    if len(session["recent_turns"]) > 5:
        session["recent_turns"] = session["recent_turns"][-5:]

    # 6) Log trace for this turn
    log_trace(session_id, user_text, rewrite, retrieved, answer, citations)

    # 7) Debug: one line per retrieved chunk with its ids and scores
    debug_retrieved = "\n".join(
        f"{r['chunk_id']} (msg={r['message_id']}, "
        f"bm25={r['score_bm25']:.3f}, sem={r['score_sem']:.3f}, "
        f"combined={r['score_combined']:.3f})"
        for r in retrieved
    )

    return answer, rewrite, debug_retrieved


def reset_session_ui(session_id: str):
    """Reset the given session, if any, and clear the UI's session slot.

    Returns a ``(session_id, status_message)`` pair where the id is always an
    empty string and the message is always ``"Session reset."``.
    """
    has_session = bool(session_id)
    if has_session:
        reset_session(session_id)
    cleared_id, status = "", "Session reset."
    return cleared_id, status


def timeline_ui(session_id: str):
    """Return the timeline for the thread bound to the active session.

    Yields a user-facing message instead when no session id is supplied or
    when the id does not resolve to a session.
    """
    if session_id:
        current = get_session(session_id)
        if current is not None:
            return build_timeline(current["thread_id"])
        return "Session not found. Please start again."
    return "Please start a session by selecting a thread."


# Build the Gradio UI. Components are declared in display order and each
# button is wired to one of the *_ui callbacks defined above.
with gr.Blocks() as demo:
    gr.Markdown("# 📧 Email Thread RAG Assistant\nAsk questions about a selected Enron email thread.")

    # Thread picker, session start button, and status line.
    with gr.Row():
        thread_dd = gr.Dropdown(
            choices=THREAD_OPTIONS,
            label="Select Thread ID",
            # Preselect the first thread when any exist.
            value=THREAD_OPTIONS[0] if THREAD_OPTIONS else None,
            interactive=True,
        )
        start_btn = gr.Button("Start Session")
        # Carries the session id between callbacks; None until a session starts.
        session_state = gr.State(value=None)
        status_box = gr.Markdown("")

    # Starting a session stores the new id and shows a status message.
    start_btn.click(
        fn=init_session_ui,
        inputs=[thread_dd],
        outputs=[session_state, status_box],
    )

    with gr.Row():
        user_box = gr.Textbox(label="Your question", lines=2)

    # Search-scope toggle and action buttons.
    with gr.Row():
        search_toggle = gr.Checkbox(label="Search outside selected thread", value=False)
        ask_btn = gr.Button("Ask")
        reset_btn = gr.Button("Reset Session")
        timeline_btn = gr.Button("Show Timeline")

    answer_box = gr.Markdown(label="Answer")
    timeline_box = gr.Markdown(label="Thread timeline")

    # Collapsed by default; shows the rewritten query and retrieved chunk scores.
    with gr.Accordion("Debug info", open=False):
        rewrite_box = gr.Textbox(label="Rewritten query", interactive=False)
        retrieved_box = gr.Textbox(label="Retrieved chunks", interactive=False)

    # chat_ui returns (answer, rewritten query, debug text), mapped in that order.
    ask_btn.click(
        fn=chat_ui,
        inputs=[user_box, session_state, search_toggle],
        outputs=[answer_box, rewrite_box, retrieved_box],
    )

    # reset_session_ui returns ("", "Session reset."), so the stored id is cleared.
    reset_btn.click(
        fn=reset_session_ui,
        inputs=[session_state],
        outputs=[session_state, status_box],
    )

    # Renders the timeline (or a user-facing message) for the active session.
    timeline_btn.click(
        fn=timeline_ui,
        inputs=[session_state],
        outputs=[timeline_box],
    )

# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()