Spaces:
Running
Running
Yuvraj Sarathe commited on
Add comments to clarify RAG pipeline stages
Browse files- backend/app/rag/agent.py +19 -0
backend/app/rag/agent.py
CHANGED
|
@@ -74,11 +74,14 @@ def generate_answer(
|
|
| 74 |
Full RAG pipeline: retrieve β build context β generate answer.
|
| 75 |
Returns dict with 'answer' and 'sources'.
|
| 76 |
"""
|
|
|
|
| 77 |
client = get_llm_client()
|
| 78 |
|
| 79 |
# ββ Handle greetings βββββββββββββββββββββββββββββ
|
|
|
|
| 80 |
if is_greeting(question):
|
| 81 |
try:
|
|
|
|
| 82 |
messages = _chat_messages(
|
| 83 |
"You are Document AI Analyst, a friendly AI assistant for document analysis.",
|
| 84 |
question,
|
|
@@ -96,6 +99,7 @@ def generate_answer(
|
|
| 96 |
return {"answer": answer, "sources": []}
|
| 97 |
|
| 98 |
# ββ Retrieve relevant chunks βββββββββββββββββββββ
|
|
|
|
| 99 |
chunks = retrieve(
|
| 100 |
query=question,
|
| 101 |
user_id=user_id,
|
|
@@ -103,11 +107,13 @@ def generate_answer(
|
|
| 103 |
)
|
| 104 |
|
| 105 |
# ββ Build prompt βββββββββββββββββββββββββββββββββ
|
|
|
|
| 106 |
context = build_context(chunks)
|
| 107 |
user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
|
| 108 |
messages = _chat_messages(SYSTEM_PROMPT, user_content)
|
| 109 |
|
| 110 |
# ββ Generate answer ββββββββββββββββββββββββββββββ
|
|
|
|
| 111 |
try:
|
| 112 |
response = client.chat_completion(
|
| 113 |
messages=messages,
|
|
@@ -124,6 +130,7 @@ def generate_answer(
|
|
| 124 |
answer = f"I encountered an error generating a response. Please try again. Error: {str(e)}"
|
| 125 |
|
| 126 |
# ββ Format sources βββββββββββββββββββββββββββββββ
|
|
|
|
| 127 |
sources = [
|
| 128 |
{
|
| 129 |
"text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""),
|
|
@@ -147,17 +154,22 @@ def generate_answer_stream(
|
|
| 147 |
Streaming RAG pipeline β yields SSE-formatted chunks.
|
| 148 |
First yields sources, then streams answer tokens.
|
| 149 |
"""
|
|
|
|
| 150 |
client = get_llm_client()
|
| 151 |
|
| 152 |
# ββ Handle greetings βββββββββββββββββββββββββββββ
|
|
|
|
| 153 |
if is_greeting(question):
|
|
|
|
| 154 |
yield f"data: {json.dumps({'type': 'sources', 'data': []})}\n\n"
|
| 155 |
|
| 156 |
try:
|
|
|
|
| 157 |
messages = _chat_messages(
|
| 158 |
"You are Document AI Analyst, a friendly AI assistant for document analysis.",
|
| 159 |
question,
|
| 160 |
)
|
|
|
|
| 161 |
stream = client.chat_completion(
|
| 162 |
messages=messages,
|
| 163 |
model=settings.LLM_MODEL,
|
|
@@ -173,10 +185,12 @@ def generate_answer_stream(
|
|
| 173 |
except Exception as e:
|
| 174 |
yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
|
| 175 |
|
|
|
|
| 176 |
yield f"data: {json.dumps({'type': 'done'})}\n\n"
|
| 177 |
return
|
| 178 |
|
| 179 |
# ββ Retrieve relevant chunks βββββββββββββββββββββ
|
|
|
|
| 180 |
chunks = retrieve(
|
| 181 |
query=question,
|
| 182 |
user_id=user_id,
|
|
@@ -184,6 +198,7 @@ def generate_answer_stream(
|
|
| 184 |
)
|
| 185 |
|
| 186 |
# ββ Yield sources first ββββββββββββββββββββββββββ
|
|
|
|
| 187 |
sources = [
|
| 188 |
{
|
| 189 |
"text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""),
|
|
@@ -197,11 +212,13 @@ def generate_answer_stream(
|
|
| 197 |
yield f"data: {json.dumps({'type': 'sources', 'data': sources})}\n\n"
|
| 198 |
|
| 199 |
# ββ Build prompt βββββββββββββββββββββββββββββββββ
|
|
|
|
| 200 |
context = build_context(chunks)
|
| 201 |
user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
|
| 202 |
messages = _chat_messages(SYSTEM_PROMPT, user_content)
|
| 203 |
|
| 204 |
# ββ Stream answer tokens βββββββββββββββββββββββββ
|
|
|
|
| 205 |
try:
|
| 206 |
stream = client.chat_completion(
|
| 207 |
messages=messages,
|
|
@@ -216,8 +233,10 @@ def generate_answer_stream(
|
|
| 216 |
if delta:
|
| 217 |
yield f"data: {json.dumps({'type': 'token', 'data': delta})}\n\n"
|
| 218 |
|
|
|
|
| 219 |
except Exception as e:
|
| 220 |
logger.error(f"LLM streaming error: {e}")
|
| 221 |
yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
|
| 222 |
|
|
|
|
| 223 |
yield f"data: {json.dumps({'type': 'done'})}\n\n"
|
|
|
|
| 74 |
Full RAG pipeline: retrieve β build context β generate answer.
|
| 75 |
Returns dict with 'answer' and 'sources'.
|
| 76 |
"""
|
| 77 |
+
# Get HuggingFace InferenceClient singleton (created once, reused)
|
| 78 |
client = get_llm_client()
|
| 79 |
|
| 80 |
# ββ Handle greetings βββββββββββββββββββββββββββββ
|
| 81 |
+
# Short-circuit: if user just says "hello", skip RAG entirely
|
| 82 |
if is_greeting(question):
|
| 83 |
try:
|
| 84 |
+
# Send greeting to LLM with a friendly system prompt (no document context)
|
| 85 |
messages = _chat_messages(
|
| 86 |
"You are Document AI Analyst, a friendly AI assistant for document analysis.",
|
| 87 |
question,
|
|
|
|
| 99 |
return {"answer": answer, "sources": []}
|
| 100 |
|
| 101 |
# ββ Retrieve relevant chunks βββββββββββββββββββββ
|
| 102 |
+
# STAGE 1+2: Semantic search (ChromaDB) + cross-encoder reranking β top 5 chunks
|
| 103 |
chunks = retrieve(
|
| 104 |
query=question,
|
| 105 |
user_id=user_id,
|
|
|
|
| 107 |
)
|
| 108 |
|
| 109 |
# ββ Build prompt βββββββββββββββββββββββββββββββββ
|
| 110 |
+
# Format retrieved chunks into a readable context block, then inject into the RAG prompt template
|
| 111 |
context = build_context(chunks)
|
| 112 |
user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
|
| 113 |
messages = _chat_messages(SYSTEM_PROMPT, user_content)
|
| 114 |
|
| 115 |
# ββ Generate answer ββββββββββββββββββββββββββββββ
|
| 116 |
+
# STAGE 3: Send prompt to HuggingFace Inference API and get the generated answer
|
| 117 |
try:
|
| 118 |
response = client.chat_completion(
|
| 119 |
messages=messages,
|
|
|
|
| 130 |
answer = f"I encountered an error generating a response. Please try again. Error: {str(e)}"
|
| 131 |
|
| 132 |
# ββ Format sources βββββββββββββββββββββββββββββββ
|
| 133 |
+
# Truncate chunk text to 300 chars and attach metadata (filename, page, score, confidence) for frontend citation display
|
| 134 |
sources = [
|
| 135 |
{
|
| 136 |
"text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""),
|
|
|
|
| 154 |
Streaming RAG pipeline β yields SSE-formatted chunks.
|
| 155 |
First yields sources, then streams answer tokens.
|
| 156 |
"""
|
| 157 |
+
# Get HuggingFace InferenceClient singleton (created once, reused)
|
| 158 |
client = get_llm_client()
|
| 159 |
|
| 160 |
# ββ Handle greetings βββββββββββββββββββββββββββββ
|
| 161 |
+
# Short-circuit: if user just says "hello", skip RAG entirely
|
| 162 |
if is_greeting(question):
|
| 163 |
+
# Yield empty sources array first so frontend resets its citation display
|
| 164 |
yield f"data: {json.dumps({'type': 'sources', 'data': []})}\n\n"
|
| 165 |
|
| 166 |
try:
|
| 167 |
+
# Send greeting to LLM with a friendly system prompt (no document context)
|
| 168 |
messages = _chat_messages(
|
| 169 |
"You are Document AI Analyst, a friendly AI assistant for document analysis.",
|
| 170 |
question,
|
| 171 |
)
|
| 172 |
+
# Stream greeting response token-by-token via SSE
|
| 173 |
stream = client.chat_completion(
|
| 174 |
messages=messages,
|
| 175 |
model=settings.LLM_MODEL,
|
|
|
|
| 185 |
except Exception as e:
|
| 186 |
yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
|
| 187 |
|
| 188 |
+
# Signal end of stream, then exit early (no RAG)
|
| 189 |
yield f"data: {json.dumps({'type': 'done'})}\n\n"
|
| 190 |
return
|
| 191 |
|
| 192 |
# ββ Retrieve relevant chunks βββββββββββββββββββββ
|
| 193 |
+
# STAGE 1+2: Semantic search (ChromaDB) + cross-encoder reranking β top 5 chunks
|
| 194 |
chunks = retrieve(
|
| 195 |
query=question,
|
| 196 |
user_id=user_id,
|
|
|
|
| 198 |
)
|
| 199 |
|
| 200 |
# ββ Yield sources first ββββββββββββββββββββββββββ
|
| 201 |
+
# Yield all sources first β frontend needs them to render citation cards before the answer starts appearing
|
| 202 |
sources = [
|
| 203 |
{
|
| 204 |
"text": chunk["text"][:300] + ("..." if len(chunk["text"]) > 300 else ""),
|
|
|
|
| 212 |
yield f"data: {json.dumps({'type': 'sources', 'data': sources})}\n\n"
|
| 213 |
|
| 214 |
# ββ Build prompt βββββββββββββββββββββββββββββββββ
|
| 215 |
+
# Format retrieved chunks into a readable context block, then inject into the RAG prompt template
|
| 216 |
context = build_context(chunks)
|
| 217 |
user_content = RAG_PROMPT_TEMPLATE.format(context=context, question=question)
|
| 218 |
messages = _chat_messages(SYSTEM_PROMPT, user_content)
|
| 219 |
|
| 220 |
# ββ Stream answer tokens βββββββββββββββββββββββββ
|
| 221 |
+
# STAGE 3: Stream tokens from HuggingFace Inference API β forward each as an SSE 'token' event
|
| 222 |
try:
|
| 223 |
stream = client.chat_completion(
|
| 224 |
messages=messages,
|
|
|
|
| 233 |
if delta:
|
| 234 |
yield f"data: {json.dumps({'type': 'token', 'data': delta})}\n\n"
|
| 235 |
|
| 236 |
+
# If LLM fails mid-stream, yield an error event so frontend can display the message
|
| 237 |
except Exception as e:
|
| 238 |
logger.error(f"LLM streaming error: {e}")
|
| 239 |
yield f"data: {json.dumps({'type': 'error', 'data': str(e)})}\n\n"
|
| 240 |
|
| 241 |
+
# Signal end of stream to frontend (stops the streaming indicator)
|
| 242 |
yield f"data: {json.dumps({'type': 'done'})}\n\n"
|