# NotebookLM-Clone — backend/rag_service.py
# Author: rahulrb99 — commit a524c65: "Add citations, retrieval improvements, chat isolation"
"""RAG chat service - retrieve chunks, call LLM, persist messages."""
import re
from backend.chat_service import save_message, load_chat
from backend.llm_client import DEFAULT_MODEL, get_llm_client
from backend.retrieval_service import retrieve_chunks
# Cap on prior chat messages replayed into the prompt (pairs -> 2*N messages; see truncation below).
MAX_HISTORY_MESSAGES = 20
# Number of chunks fetched per query by retrieve_chunks and numbered [1]..[TOP_K] in the context.
TOP_K = 5
def _validate_citations(text: str, num_chunks: int) -> str:
"""Strip or fix citation numbers [N] where N > num_chunks."""
if num_chunks <= 0:
return text
def replace_citation(match):
n = int(match.group(1))
if 1 <= n <= num_chunks:
return match.group(0)
return ""
return re.sub(r"\[(\d+)\]", replace_citation, text)
def rag_chat(notebook_id: str, query: str, chat_history: list, user_id: str | None = None) -> tuple[str, list, list[dict]]:
    """
    RAG chat: retrieve chunks, build prompt, call LLM, persist, return answer and updated history.

    chat_history: list of [user_msg, assistant_msg] pairs (Gradio Chatbot format).
    user_id: for ownership validation; messages are only saved if notebook belongs to user.
    Returns: (assistant_reply, updated_history, chunks).
    chunks: list of dicts with id, content, metadata, similarity for citation display.
    """
    save_message(notebook_id, user_id, "user", query)

    chunks = retrieve_chunks(notebook_id, query, top_k=TOP_K)

    # Number each passage [1], [2], ... so the model's citations map back to chunks.
    context = "\n\n".join(
        f"[{i}] {chunk['content']}" for i, chunk in enumerate(chunks, 1)
    ) or "(No relevant sources found.)"

    system_content = (
        "You are a helpful assistant. Answer ONLY from the provided context. "
        "Cite sources using [1], [2], etc. corresponding to the numbered passages. "
        "If the answer is not in the context, say so clearly.\n\n"
        f"Context:\n{context}"
    )

    # Keep only the most recent pairs so the prompt stays bounded
    # (MAX_HISTORY_MESSAGES counts individual messages, i.e. 2 per pair).
    # Guard max_pairs <= 0 explicitly: chat_history[-0:] would keep EVERYTHING.
    max_pairs = MAX_HISTORY_MESSAGES // 2
    truncated = chat_history[-max_pairs:] if max_pairs > 0 else []

    messages = [{"role": "system", "content": system_content}]
    for user_msg, asst_msg in truncated:
        # Skip empty halves of a pair (e.g. a pending assistant reply).
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if asst_msg:
            messages.append({"role": "assistant", "content": asst_msg})
    messages.append({"role": "user", "content": query})

    try:
        client = get_llm_client()
        response = client.chat.completions.create(
            model=DEFAULT_MODEL,
            messages=messages,
            max_tokens=512,
        )
        raw_answer = response.choices[0].message.content or ""
        # Drop hallucinated citation indices that point past the retrieved passages.
        answer = _validate_citations(raw_answer, len(chunks))
    except Exception as e:
        # Boundary handler: surface provider/network failures to the UI as the
        # reply (and persist it below) instead of crashing the chat loop.
        answer = f"Error calling model: {e}"

    save_message(notebook_id, user_id, "assistant", answer)
    updated_history = chat_history + [[query, answer]]
    return answer, updated_history, chunks