File size: 4,936 Bytes
aba4ae4
 
c2756e4
 
 
 
1419aa3
c2756e4
 
 
 
 
1419aa3
c2756e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aba4ae4
c2756e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aba4ae4
 
 
c2756e4
 
 
 
 
 
 
 
 
 
 
 
 
 
aba4ae4
 
c2756e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aba4ae4
c2756e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# api.py
import time
from fastapi import FastAPI, Query, HTTPException
from pydantic import BaseModel
from typing import List, Optional, Any

from email_rag.rag_sessions import (
    start_session,
    reset_session,
    get_session,
    update_entity_memory,
)
from email_rag.rag_retrieval import (
    rewrite_query,
    retrieve_chunks,
    build_answer,
    log_trace,
    extract_entities_for_turn,
)

# Module-level FastAPI application; the endpoints below register on it via @app.post.
app = FastAPI(title="Email Thread RAG API")


# ---------- Pydantic models ----------

class StartSessionRequest(BaseModel):
    """Request body for POST /start_session."""

    thread_id: str  # thread the new session will be bound to (passed to start_session)


class StartSessionResponse(BaseModel):
    """Response body for POST /start_session and POST /switch_thread."""

    session_id: str  # id returned by start_session(); used by /ask and /reset_session
    thread_id: str  # echo of the thread_id the session was started on


class AskRequest(BaseModel):
    """Request body for POST /ask."""

    session_id: str  # existing session id; /ask answers 404 if it is unknown
    text: str  # the user's question as typed (also fed to entity extraction)
    # Optional body flag. The /ask endpoint also accepts the query parameter
    # ?search_outside_thread=true and ORs the two, so either one enables
    # searching outside the active thread. None is accepted and treated as False.
    search_outside_thread: Optional[bool] = False


class Citation(BaseModel):
    """One citation attached to an /ask answer (built from build_answer's output)."""

    message_id: str  # email message the cited text comes from
    # May be None; presumably only set for paged sources such as attachments — TODO confirm.
    page_no: Optional[int] = None
    chunk_id: str  # id of the retrieved chunk being cited


class RetrievedChunk(BaseModel):
    """One chunk returned by retrieval, echoed back in the /ask response for inspection."""

    chunk_id: str
    thread_id: str  # may differ from the session's thread when outside-thread search is on
    message_id: str
    page_no: Optional[int] = None  # None when the retriever supplies no page number
    source: str  # /ask fills in "email" when the retriever omits this key
    # Scores as produced by retrieve_chunks. By their names, presumably a BM25
    # lexical score, a semantic (embedding) score, and the fused ranking score —
    # verify against rag_retrieval.
    score_bm25: float
    score_sem: float
    score_combined: float


class AskResponse(BaseModel):
    """Response body for POST /ask."""

    answer: str  # generated answer text
    citations: List[Citation]  # sources backing the answer
    rewrite: str  # the question after rewrite_query (what retrieval actually used)
    retrieved: List[RetrievedChunk]  # every chunk retrieval returned, with scores
    trace_id: str  # id returned by log_trace for this turn
    # Wall-clock seconds (time.perf_counter) spanning rewrite -> retrieval ->
    # entity-memory update -> answer building; excludes trace logging and
    # response formatting.
    latency_sec: float


class SwitchThreadRequest(BaseModel):
    """Request body for POST /switch_thread."""

    thread_id: str  # thread to open a new session on


class ResetSessionRequest(BaseModel):
    """Request body for POST /reset_session."""

    session_id: str  # session whose memory should be reset; 404 if unknown


# ---------- Endpoints ----------

@app.post("/start_session", response_model=StartSessionResponse)
def api_start_session(payload: StartSessionRequest):
    """
    Open a fresh session bound to ``payload.thread_id``.

    Returns the newly created session id together with the thread it is bound to.
    """
    thread_id = payload.thread_id
    return StartSessionResponse(
        session_id=start_session(thread_id),
        thread_id=thread_id,
    )


def _to_retrieved_chunk(row):
    """Convert one raw retrieval hit (a mapping) into a RetrievedChunk model.

    The id and score keys are required; ``page_no`` may be absent (-> None) and
    ``source`` falls back to "email" when missing.
    """
    return RetrievedChunk(
        chunk_id=row["chunk_id"],
        thread_id=row["thread_id"],
        message_id=row["message_id"],
        page_no=row.get("page_no"),
        source=row.get("source", "email"),
        score_bm25=row["score_bm25"],
        score_sem=row["score_sem"],
        score_combined=row["score_combined"],
    )


def _to_citation(item):
    """Convert one raw citation mapping from build_answer into a Citation model."""
    return Citation(
        message_id=item["message_id"],
        page_no=item.get("page_no"),
        chunk_id=item["chunk_id"],
    )


@app.post("/ask", response_model=AskResponse)
def api_ask(
    payload: AskRequest,
    search_outside_thread: bool = Query(
        False,
        description="Set to true to allow fallback search outside the active thread.",
    ),
):
    """
    Answer a question inside an existing session.

    Retrieval is scoped to the session's thread unless outside-thread search is
    enabled, either by the query parameter ?search_outside_thread=true or by the
    body field ``search_outside_thread``; either one being true enables it.

    Raises HTTPException(404) when the session id is unknown. The reported
    ``latency_sec`` covers rewrite, retrieval, entity-memory update and answer
    building only; trace logging and response formatting happen afterwards.
    """
    session = get_session(payload.session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    # Body flag OR query flag; Optional body value None counts as False.
    outside = bool(payload.search_outside_thread or search_outside_thread)
    question = payload.text
    sid = payload.session_id

    started = time.perf_counter()

    # Rewrite the question using the session context, then retrieve with it.
    rewritten = rewrite_query(question, session)
    hits = retrieve_chunks(rewritten, session, outside)

    # Remember any entities mentioned this turn for later rewrites.
    entities = extract_entities_for_turn(question, hits)
    if entities:
        update_entity_memory(sid, entities)

    answer, raw_citations = build_answer(question, rewritten, hits)

    latency = time.perf_counter() - started

    trace_id = log_trace(sid, question, rewritten, hits, answer, raw_citations)

    # Retrieved chunks are converted before citations (same order as before).
    retrieved_models = [_to_retrieved_chunk(row) for row in hits]
    citation_models = [_to_citation(item) for item in raw_citations]

    return AskResponse(
        answer=answer,
        citations=citation_models,
        rewrite=rewritten,
        retrieved=retrieved_models,
        trace_id=trace_id,
        latency_sec=latency,
    )


@app.post("/switch_thread", response_model=StartSessionResponse)
def api_switch_thread(payload: SwitchThreadRequest):
    """
    Switch threads by opening a brand-new session on ``payload.thread_id``.

    The request carries no session id, so no existing session is referenced;
    the response has the same shape as /start_session ({session_id, thread_id}).
    """
    target_thread = payload.thread_id
    new_session_id = start_session(target_thread)
    return StartSessionResponse(session_id=new_session_id, thread_id=target_thread)


@app.post("/reset_session")
def api_reset_session(payload: ResetSessionRequest):
    """
    Clear the memory of an existing session (same effect as the UI reset).

    Returns {"status": "ok", "session_id": ...} on success and raises
    HTTPException(404) when the session id is unknown.
    """
    sid = payload.session_id
    if get_session(sid) is not None:
        reset_session(sid)
        return {"status": "ok", "session_id": sid}
    raise HTTPException(status_code=404, detail="Session not found")