Spaces:
Running
Running
Asish Karthikeya Gogineni committed on
Commit ·
986715f
1
Parent(s): 4af2457
fix: Optimize RAG retrieval and enable streaming
Browse files
- Increased retrieval context: Top-30 docs (was 5)
- Removed context truncation: Providing full file content to LLM
- Boosted README.md priority in retrieval
- Enabled streaming chat responses for lower latency
- code_chatbot/graph_rag.py +5 -1
- code_chatbot/rag.py +94 -63
code_chatbot/graph_rag.py
CHANGED
|
@@ -53,7 +53,11 @@ class GraphEnhancedRetriever(BaseRetriever):
|
|
| 53 |
if any(file_path.endswith(ext) for ext in config_extensions):
|
| 54 |
return 50
|
| 55 |
|
| 56 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
text_extensions = [".txt", ".md", ".rst"]
|
| 58 |
if any(file_path.endswith(ext) for ext in text_extensions):
|
| 59 |
return 30
|
|
|
|
| 53 |
if any(file_path.endswith(ext) for ext in config_extensions):
|
| 54 |
return 50
|
| 55 |
|
| 56 |
+
# Low priority: Text/doc files (often too generic)
|
| 57 |
+
# EXCEPTION: README files are critical for context
|
| 58 |
+
if "readme" in file_path.lower():
|
| 59 |
+
return 90
|
| 60 |
+
|
| 61 |
text_extensions = [".txt", ".md", ".rst"]
|
| 62 |
if any(file_path.endswith(ext) for ext in text_extensions):
|
| 63 |
return 30
|
code_chatbot/rag.py
CHANGED
|
@@ -363,75 +363,106 @@ class ChatEngine:
|
|
| 363 |
return f"Error: {str(e)}", []
|
| 364 |
|
| 365 |
def _linear_chat(self, question: str) -> Tuple[str, List[dict]]:
|
| 366 |
-
|
| 367 |
-
"""
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
try:
|
| 374 |
-
# Contextualize query based on history
|
| 375 |
-
contextualized_query = self._contextualize_query(question, self.chat_history)
|
| 376 |
-
|
| 377 |
-
# Retrieve relevant documents
|
| 378 |
-
docs = self.retriever.invoke(contextualized_query)
|
| 379 |
-
logger.info(f"Retrieved {len(docs)} documents")
|
| 380 |
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
for doc in docs[:5] # Limit to top 5 docs
|
| 388 |
-
])
|
| 389 |
-
|
| 390 |
-
# Extract sources
|
| 391 |
-
sources = []
|
| 392 |
-
for doc in docs[:5]:
|
| 393 |
-
file_path = doc.metadata.get("file_path") or doc.metadata.get("source", "unknown")
|
| 394 |
-
sources.append({
|
| 395 |
-
"file_path": file_path,
|
| 396 |
-
"url": doc.metadata.get("url", f"file://{file_path}"),
|
| 397 |
-
})
|
| 398 |
-
|
| 399 |
-
# Build prompt with history - use provider-specific prompt
|
| 400 |
-
from code_chatbot.prompts import get_prompt_for_provider
|
| 401 |
-
base_prompt = get_prompt_for_provider("linear_rag", self.provider)
|
| 402 |
-
qa_system_prompt = base_prompt.format(
|
| 403 |
-
repo_name=self.repo_name,
|
| 404 |
-
context=context_text
|
| 405 |
-
)
|
| 406 |
-
|
| 407 |
-
# Build messages with history
|
| 408 |
-
messages = [SystemMessage(content=qa_system_prompt)]
|
| 409 |
-
|
| 410 |
-
# Add chat history
|
| 411 |
-
for msg in self.chat_history[-10:]: # Last 10 messages for context
|
| 412 |
-
messages.append(msg)
|
| 413 |
-
|
| 414 |
-
# Add current question
|
| 415 |
-
messages.append(HumanMessage(content=question))
|
| 416 |
|
| 417 |
-
|
| 418 |
-
|
| 419 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
|
| 425 |
-
#
|
| 426 |
-
|
| 427 |
-
self.chat_history = self.chat_history[-20:]
|
| 428 |
|
| 429 |
-
|
| 430 |
|
| 431 |
-
except Exception as e:
|
| 432 |
-
logger.error(f"Error during chat: {e}", exc_info=True)
|
| 433 |
-
return f"Error: {str(e)}", []
|
| 434 |
-
|
| 435 |
def clear_memory(self):
|
| 436 |
"""Clear the conversation history."""
|
| 437 |
self.chat_history.clear()
|
|
|
|
| 363 |
return f"Error: {str(e)}", []
|
| 364 |
|
| 365 |
def _linear_chat(self, question: str) -> Tuple[str, List[dict]]:
|
| 366 |
+
def _prepare_chat_context(self, question: str):
|
| 367 |
+
"""Prepare messages and sources for chat/stream."""
|
| 368 |
+
# 1. Retrieve relevant documents
|
| 369 |
+
query_for_retrieval = question
|
| 370 |
+
if len(question) < 5 and len(self.chat_history) > 0:
|
| 371 |
+
# Enhance short queries with history
|
| 372 |
+
query_for_retrieval = f"{self.chat_history[-1].content} {question}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
+
# Increase retrieval limit to 30 docs since Gemini has large context
|
| 375 |
+
docs = self.retriever.get_relevant_documents(query_for_retrieval)
|
| 376 |
+
|
| 377 |
+
if not docs:
|
| 378 |
+
# Return empty context if no docs found
|
| 379 |
+
return None, [], ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
+
# Build context from documents - Use FULL content, not truncated
|
| 382 |
+
# Gemini 1.5/2.0 can handle 1M+ tokens, so we should provide as much context as possible.
|
| 383 |
+
context_parts = []
|
| 384 |
+
for doc in docs[:30]: # Use top 30 documents
|
| 385 |
+
file_path = doc.metadata.get('file_path', 'unknown')
|
| 386 |
+
content = doc.page_content
|
| 387 |
+
context_parts.append(f"File: {file_path}\nContent:\n{content}\n---")
|
| 388 |
|
| 389 |
+
context_text = "\n\n".join(context_parts)
|
| 390 |
+
|
| 391 |
+
# Extract sources
|
| 392 |
+
sources = []
|
| 393 |
+
for doc in docs[:30]:
|
| 394 |
+
file_path = doc.metadata.get("file_path") or doc.metadata.get("source", "unknown")
|
| 395 |
+
sources.append({
|
| 396 |
+
"file_path": file_path,
|
| 397 |
+
"url": doc.metadata.get("url", f"file://{file_path}"),
|
| 398 |
+
})
|
| 399 |
+
|
| 400 |
+
# Build prompt with history - use provider-specific prompt
|
| 401 |
+
from code_chatbot.prompts import get_prompt_for_provider
|
| 402 |
+
base_prompt = get_prompt_for_provider("linear_rag", self.provider)
|
| 403 |
+
qa_system_prompt = base_prompt.format(
|
| 404 |
+
repo_name=self.repo_name,
|
| 405 |
+
context=context_text
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# Build messages with history
|
| 409 |
+
messages = [SystemMessage(content=qa_system_prompt)]
|
| 410 |
+
|
| 411 |
+
# Add chat history
|
| 412 |
+
for msg in self.chat_history[-10:]: # Last 10 messages for context
|
| 413 |
+
messages.append(msg)
|
| 414 |
+
|
| 415 |
+
# Add current question
|
| 416 |
+
messages.append(HumanMessage(content=question))
|
| 417 |
+
|
| 418 |
+
return messages, sources, context_text
|
| 419 |
+
|
| 420 |
+
def chat(self, question: str) -> tuple[str, list]:
|
| 421 |
+
"""Blocking chat method."""
|
| 422 |
+
messages, sources, _ = self._prepare_chat_context(question)
|
| 423 |
+
|
| 424 |
+
if not messages:
|
| 425 |
+
return "I don't have any information about this codebase. Please make sure the codebase has been indexed properly.", []
|
| 426 |
+
|
| 427 |
+
# Get response from LLM
|
| 428 |
+
response_msg = self.llm.invoke(messages)
|
| 429 |
+
answer = response_msg.content
|
| 430 |
+
|
| 431 |
+
# Update chat history
|
| 432 |
+
self.chat_history.append(HumanMessage(content=question))
|
| 433 |
+
self.chat_history.append(AIMessage(content=answer))
|
| 434 |
+
|
| 435 |
+
# Keep history manageable (last 20 messages)
|
| 436 |
+
if len(self.chat_history) > 20:
|
| 437 |
+
self.chat_history = self.chat_history[-20:]
|
| 438 |
+
|
| 439 |
+
return answer, sources
|
| 440 |
+
|
| 441 |
+
def stream_chat(self, question: str):
|
| 442 |
+
"""Streaming chat method returning (generator, sources)."""
|
| 443 |
+
messages, sources, _ = self._prepare_chat_context(question)
|
| 444 |
+
|
| 445 |
+
if not messages:
|
| 446 |
+
def empty_gen(): yield "I don't have any information about this codebase."
|
| 447 |
+
return empty_gen(), []
|
| 448 |
+
|
| 449 |
+
# Update history with USER message immediately
|
| 450 |
+
self.chat_history.append(HumanMessage(content=question))
|
| 451 |
+
if len(self.chat_history) > 20: self.chat_history = self.chat_history[-20:]
|
| 452 |
+
|
| 453 |
+
# Generator wrapper to capture full response for history
|
| 454 |
+
def response_generator():
|
| 455 |
+
full_response = ""
|
| 456 |
+
for chunk in self.llm.stream(messages):
|
| 457 |
+
content = chunk.content
|
| 458 |
+
full_response += content
|
| 459 |
+
yield content
|
| 460 |
|
| 461 |
+
# Update history with AI message after generation
|
| 462 |
+
self.chat_history.append(AIMessage(content=full_response))
|
|
|
|
| 463 |
|
| 464 |
+
return response_generator(), sources
|
| 465 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 466 |
def clear_memory(self):
|
| 467 |
"""Clear the conversation history."""
|
| 468 |
self.chat_history.clear()
|