raviix46 committed on
Commit
c2756e4
·
verified ·
1 Parent(s): 5664597

Create api.py

Browse files
Files changed (1) hide show
  1. api.py +176 -0
api.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Query, HTTPException
2
+ from pydantic import BaseModel
3
+ from typing import List, Optional, Any
4
+
5
+ from rag_sessions import (
6
+ start_session,
7
+ reset_session,
8
+ get_session,
9
+ update_entity_memory,
10
+ )
11
+ from rag_retrieval import (
12
+ rewrite_query,
13
+ retrieve_chunks,
14
+ build_answer,
15
+ log_trace,
16
+ extract_entities_for_turn,
17
+ )
18
+
19
# Application object; the route decorators in this module register on it.
app = FastAPI(
    title="Email Thread RAG API",
)
20
+
21
+
22
+ # ---------- Pydantic models ----------
23
+
24
# Request body for POST /start_session.
# (Documented with comments rather than a docstring so the generated
# JSON schema for this model is left unchanged.)
class StartSessionRequest(BaseModel):
    # Identifier of the thread the new session is started on.
    thread_id: str
26
+
27
+
28
# Response body returned by POST /start_session and POST /switch_thread.
class StartSessionResponse(BaseModel):
    # Value returned by start_session() for the newly created session.
    session_id: str
    # Echo of the thread_id supplied in the request.
    thread_id: str
31
+
32
+
33
# Request body for POST /ask.
class AskRequest(BaseModel):
    # Session to ask within; the endpoint answers 404 when it is unknown.
    session_id: str
    # The user's question for this turn.
    text: str
    # Optional body-level flag. The endpoint ORs it with the query-string
    # flag ?search_outside_thread=true, so either one enables the fallback.
    # Kept Optional so an explicit null in the body is still accepted.
    search_outside_thread: Optional[bool] = False
38
+
39
+
40
# One citation entry in an /ask response, built from the citation dicts
# returned by build_answer().
class Citation(BaseModel):
    message_id: str
    # Optional; None when the citation dict carries no "page_no" key.
    page_no: Optional[int] = None
    chunk_id: str
44
+
45
+
46
# One retrieved chunk as reported in an /ask response, built from the
# dicts returned by retrieve_chunks().
class RetrievedChunk(BaseModel):
    chunk_id: str
    thread_id: str
    message_id: str
    # Optional; None when the retrieval dict carries no "page_no" key.
    page_no: Optional[int] = None
    # The /ask endpoint fills in "email" when the retrieval dict has no "source".
    source: str
    # NOTE(review): presumably lexical (BM25), semantic and blended scores --
    # confirm the exact meaning and scale against rag_retrieval.
    score_bm25: float
    score_sem: float
    score_combined: float
55
+
56
+
57
# Response body for POST /ask.
class AskResponse(BaseModel):
    # Answer text produced by build_answer().
    answer: str
    # Citations produced by build_answer(), converted to Citation models.
    citations: List[Citation]
    # The rewritten query returned by rewrite_query() for this turn.
    rewrite: str
    # Chunks returned by retrieve_chunks(), converted to RetrievedChunk models.
    retrieved: List[RetrievedChunk]
    # Identifier returned by log_trace() for this turn.
    trace_id: str
63
+
64
+
65
# Request body for POST /switch_thread.
class SwitchThreadRequest(BaseModel):
    # Thread to switch to; the endpoint starts a fresh session on it.
    thread_id: str
67
+
68
+
69
# Request body for POST /reset_session.
class ResetSessionRequest(BaseModel):
    # Session to reset; the endpoint answers 404 when it is unknown.
    session_id: str
71
+
72
+
73
+ # ---------- Endpoints ----------
74
+
75
@app.post("/start_session", response_model=StartSessionResponse)
def api_start_session(payload: StartSessionRequest):
    """
    Open a new session bound to the thread_id given in the request body.

    Returns the new session_id together with the thread_id it is bound to.
    """
    thread_id = payload.thread_id
    return StartSessionResponse(
        session_id=start_session(thread_id),
        thread_id=thread_id,
    )
82
+
83
+
84
@app.post("/ask", response_model=AskResponse)
def api_ask(
    payload: AskRequest,
    search_outside_thread: bool = Query(
        False,
        description="Set to true to allow fallback search outside the active thread.",
    ),
):
    """
    Answer a question in the context of an existing session.

    - Retrieval is scoped to the session's thread by default.
    - Global search fallback is enabled with ?search_outside_thread=true
      or with "search_outside_thread": true in the request body.
    - Responds with 404 when the session_id is unknown.
    """
    session_id = payload.session_id
    question = payload.text

    session = get_session(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    # Either source (request body or query string) can switch the fallback on.
    allow_global = bool(payload.search_outside_thread) or bool(search_outside_thread)

    # Rewrite the question using the session, then retrieve with the rewrite.
    rewritten = rewrite_query(question, session)
    hits = retrieve_chunks(rewritten, session, allow_global)

    # Entities are extracted from the original question (not the rewrite)
    # plus the retrieved chunks; memory is only touched when something was found.
    turn_entities = extract_entities_for_turn(question, hits)
    if turn_entities:
        update_entity_memory(session_id, turn_entities)

    answer_text, raw_citations = build_answer(question, rewritten, hits)

    # log_trace() records the turn and hands back the id exposed to the client.
    trace = log_trace(session_id, question, rewritten, hits, answer_text, raw_citations)

    # Convert the retrieval dicts into response models.
    chunk_models = []
    for hit in hits:
        chunk_models.append(
            RetrievedChunk(
                chunk_id=hit["chunk_id"],
                thread_id=hit["thread_id"],
                message_id=hit["message_id"],
                page_no=hit.get("page_no"),
                source=hit.get("source", "email"),
                score_bm25=hit["score_bm25"],
                score_sem=hit["score_sem"],
                score_combined=hit["score_combined"],
            )
        )

    # Convert the citation dicts into response models.
    citation_models = []
    for cite in raw_citations:
        citation_models.append(
            Citation(
                message_id=cite["message_id"],
                page_no=cite.get("page_no"),
                chunk_id=cite["chunk_id"],
            )
        )

    return AskResponse(
        answer=answer_text,
        citations=citation_models,
        rewrite=rewritten,
        retrieved=chunk_models,
        trace_id=trace,
    )
154
+
155
+
156
@app.post("/switch_thread", response_model=StartSessionResponse)
def api_switch_thread(payload: SwitchThreadRequest):
    """
    Switch to another thread by starting a new session on it.

    This is the simplest interpretation of "switching": the request
    { "thread_id": "..." } yields fresh session info for that thread.
    """
    # The request carries no session_id, so no existing session is modified.
    new_session_id = start_session(payload.thread_id)
    response = StartSessionResponse(
        session_id=new_session_id,
        thread_id=payload.thread_id,
    )
    return response
165
+
166
+
167
@app.post("/reset_session")
def api_reset_session(payload: ResetSessionRequest):
    """
    Reset the memory of an existing session (same behavior as the UI reset).

    Responds with 404 when the session_id is unknown.
    """
    session_id = payload.session_id

    # Look the session up first so an unknown id yields 404 instead of a reset.
    existing = get_session(session_id)
    if existing is None:
        raise HTTPException(status_code=404, detail="Session not found")

    reset_session(session_id)
    return {"status": "ok", "session_id": session_id}