Yuvraj Sarathe commited on
Commit
2927bc6
Β·
unverified Β·
1 Parent(s): d27b67c

Add comments to clarify RAG pipeline stages

Browse files
Files changed (1) hide show
  1. backend/app/rag/agent.py +19 -0
backend/app/rag/agent.py CHANGED
@@ -74,11 +74,14 @@ def generate_answer(
74
  Full RAG pipeline: retrieve β†’ build context β†’ generate answer.
75
  Returns dict with 'answer' and 'sources'.
76
  """
 
77
  client = get_llm_client()
78
 
79
  # ── Handle greetings ─────────────────────────────
 
80
  if is_greeting(question):
81
  try:
 
82
  messages = _chat_messages(
83
  "You are Document AI Analyst, a friendly AI assistant for document analysis.",
84
  question,
@@ -96,6 +99,7 @@ def generate_answer(
96
  return {"answer": answer, "sources": []}
97
 
98
  # ── Retrieve relevant chunks ─────────────────────
 
99
  chunks = retrieve(
100
  query=question,
101
  user_id=user_id,
@@ -103,11 +107,13 @@ def generate_answer(
103
  )
104
 
105
  # ── Build prompt ─────────────────────────────────
 
106
  context = build_context(chunks)
107
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
108
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
109
 
110
  # ── Generate answer ──────────────────────────────
 
111
  try:
112
  response = client.chat_completion(
113
  messages=messages,
@@ -124,6 +130,7 @@ def generate_answer(
124
  answer = f"I encountered an error generating a response. Please try again. Error: {str(e)}"
125
 
126
  # ── Format sources ───────────────────────────────
 
127
  sources = [
128
  {
129
  "text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""),
@@ -147,17 +154,22 @@ def generate_answer_stream(
147
  Streaming RAG pipeline β€” yields SSE-formatted chunks.
148
  First yields sources, then streams answer tokens.
149
  """
 
150
  client = get_llm_client()
151
 
152
  # ── Handle greetings ─────────────────────────────
 
153
  if is_greeting(question):
 
154
  yield f"data: {json.dumps({'type': 'sources', 'data': []})}\n\n"
155
 
156
  try:
 
157
  messages = _chat_messages(
158
  "You are Document AI Analyst, a friendly AI assistant for document analysis.",
159
  question,
160
  )
 
161
  stream = client.chat_completion(
162
  messages=messages,
163
  model=settings.LLM_MODEL,
@@ -173,10 +185,12 @@ def generate_answer_stream(
173
  except Exception as e:
174
  yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
175
 
 
176
  yield f"data: {json.dumps({'type': 'done'})}\n\n"
177
  return
178
 
179
  # ── Retrieve relevant chunks ─────────────────────
 
180
  chunks = retrieve(
181
  query=question,
182
  user_id=user_id,
@@ -184,6 +198,7 @@ def generate_answer_stream(
184
  )
185
 
186
  # ── Yield sources first ──────────────────────────
 
187
  sources = [
188
  {
189
  "text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""),
@@ -197,11 +212,13 @@ def generate_answer_stream(
197
  yield f"data: {json.dumps({'type': 'sources', 'data': sources})}\n\n"
198
 
199
  # ── Build prompt ─────────────────────────────────
 
200
  context = build_context(chunks)
201
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
202
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
203
 
204
  # ── Stream answer tokens ─────────────────────────
 
205
  try:
206
  stream = client.chat_completion(
207
  messages=messages,
@@ -216,8 +233,10 @@ def generate_answer_stream(
216
  if delta:
217
  yield f"data: {json.dumps({'type': 'token', 'data': delta})}\n\n"
218
 
 
219
  except Exception as e:
220
  logger.error(f"LLM streaming error: {e}")
221
  yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
222
 
 
223
  yield f"data: {json.dumps({'type': 'done'})}\n\n"
 
74
  Full RAG pipeline: retrieve β†’ build context β†’ generate answer.
75
  Returns dict with 'answer' and 'sources'.
76
  """
77
+ # Get HuggingFace InferenceClient singleton (created once, reused)
78
  client = get_llm_client()
79
 
80
  # ── Handle greetings ─────────────────────────────
81
+ # Short-circuit: if user just says "hello", skip RAG entirely
82
  if is_greeting(question):
83
  try:
84
+ # Send greeting to LLM with a friendly system prompt (no document context)
85
  messages = _chat_messages(
86
  "You are Document AI Analyst, a friendly AI assistant for document analysis.",
87
  question,
 
99
  return {"answer": answer, "sources": []}
100
 
101
  # ── Retrieve relevant chunks ─────────────────────
102
+ # STAGE 1+2: Semantic search (ChromaDB) + cross-encoder reranking β†’ top 5 chunks
103
  chunks = retrieve(
104
  query=question,
105
  user_id=user_id,
 
107
  )
108
 
109
  # ── Build prompt ─────────────────────────────────
110
+ # Format retrieved chunks into a readable context block, then inject into the RAG prompt template
111
  context = build_context(chunks)
112
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
113
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
114
 
115
  # ── Generate answer ──────────────────────────────
116
+ # STAGE 3: Send prompt to HuggingFace Inference API and get the generated answer
117
  try:
118
  response = client.chat_completion(
119
  messages=messages,
 
130
  answer = f"I encountered an error generating a response. Please try again. Error: {str(e)}"
131
 
132
  # ── Format sources ───────────────────────────────
133
+ # Truncate chunk text to 300 chars and attach metadata (filename, page, score, confidence) for frontend citation display
134
  sources = [
135
  {
136
  "text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""),
 
154
  Streaming RAG pipeline β€” yields SSE-formatted chunks.
155
  First yields sources, then streams answer tokens.
156
  """
157
+ # Get HuggingFace InferenceClient singleton (created once, reused)
158
  client = get_llm_client()
159
 
160
  # ── Handle greetings ─────────────────────────────
161
+ # Short-circuit: if user just says "hello", skip RAG entirely
162
  if is_greeting(question):
163
+ # Yield empty sources array first so frontend resets its citation display
164
  yield f"data: {json.dumps({'type': 'sources', 'data': []})}\n\n"
165
 
166
  try:
167
+ # Send greeting to LLM with a friendly system prompt (no document context)
168
  messages = _chat_messages(
169
  "You are Document AI Analyst, a friendly AI assistant for document analysis.",
170
  question,
171
  )
172
+ # Stream greeting response token-by-token via SSE
173
  stream = client.chat_completion(
174
  messages=messages,
175
  model=settings.LLM_MODEL,
 
185
  except Exception as e:
186
  yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
187
 
188
+ # Signal end of stream, then exit early (no RAG)
189
  yield f"data: {json.dumps({'type': 'done'})}\n\n"
190
  return
191
 
192
  # ── Retrieve relevant chunks ─────────────────────
193
+ # STAGE 1+2: Semantic search (ChromaDB) + cross-encoder reranking β†’ top 5 chunks
194
  chunks = retrieve(
195
  query=question,
196
  user_id=user_id,
 
198
  )
199
 
200
  # ── Yield sources first ──────────────────────────
201
+ # Yield all sources first β€” frontend needs them to render citation cards before the answer starts appearing
202
  sources = [
203
  {
204
  "text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""),
 
212
  yield f"data: {json.dumps({'type': 'sources', 'data': sources})}\n\n"
213
 
214
  # ── Build prompt ─────────────────────────────────
215
+ # Format retrieved chunks into a readable context block, then inject into the RAG prompt template
216
  context = build_context(chunks)
217
  user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
218
  messages = _chat_messages(SYSTEM_PROMPT, user_content)
219
 
220
  # ── Stream answer tokens ─────────────────────────
221
+ # STAGE 3: Stream tokens from HuggingFace Inference API β†’ forward each as an SSE 'token' event
222
  try:
223
  stream = client.chat_completion(
224
  messages=messages,
 
233
  if delta:
234
  yield f"data: {json.dumps({'type': 'token', 'data': delta})}\n\n"
235
 
236
+ # If LLM fails mid-stream, yield an error event so frontend can display the message
237
  except Exception as e:
238
  logger.error(f"LLM streaming error: {e}")
239
  yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
240
 
241
+ # Signal end of stream to frontend (stops the streaming indicator)
242
  yield f"data: {json.dumps({'type': 'done'})}\n\n"