# Clone_Lm / backend / rag.py
# skumar54 — commit 9c9ce67: "NotebookLM clone: Gradio app, backend, Gemini artifacts"
"""
RAG: retrieve context, build prompt, call LLM, format response with citations.
Tracks retrieval_time and generation_time for UI.
"""
import time
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

from backend.config import HF_LLM_MODEL, HF_TOKEN, TOP_K
from backend.retriever import retrieve
from backend.utils import logger
def _build_context(chunks: List[Dict[str, Any]]) -> str:
lines = []
for i, c in enumerate(chunks, 1):
doc = c.get("document", "")
meta = c.get("metadata", {})
name = meta.get("source_name", "Source")
page = meta.get("page_or_slide", "")
ref = f"[{i}] {name}"
if page:
ref += f" (p.{page})"
lines.append(f"{ref}:\n{doc}")
return "\n\n---\n\n".join(lines)
def _citations_block(chunks: List[Dict[str, Any]]) -> str:
lines = ["Citations:"]
for i, c in enumerate(chunks, 1):
meta = c.get("metadata", {})
name = meta.get("source_name", "Source")
page = meta.get("page_or_slide", "")
if page:
lines.append(f"[{i}] {name} (p.{page})")
else:
lines.append(f"[{i}] {name}")
return "\n".join(lines)
def call_llm(prompt: str, system: Optional[str] = None) -> Tuple[str, float]:
    """Generate a completion for *prompt*, optionally guided by *system*.

    Public entry point for other modules; all work is delegated to
    ``_call_llm``.

    Returns:
        A ``(response_text, generation_time)`` pair. The time is in seconds.
    """
    response_text, elapsed = _call_llm(prompt, system=system)
    return response_text, elapsed
# Character budget for the prompt given to the local fallback model (same
# 2000-character limit the original truncation used).
_LOCAL_PROMPT_CHAR_LIMIT = 2000
# Marker inserted where the middle of an over-long prompt was cut out.
_LOCAL_PROMPT_ELISION = "\n...\n"


@lru_cache(maxsize=1)
def _load_local_pipeline() -> Any:
    """Build the local flan-t5-small text2text pipeline once and reuse it.

    Any error from ``transformers`` (e.g. ImportError) propagates. lru_cache
    does not cache exceptions, so a failed load is retried on the next call.
    """
    from transformers import pipeline

    return pipeline("text2text-generation", model="google/flan-t5-small", max_length=256)


def _fit_local_prompt(prompt: str, limit: int = _LOCAL_PROMPT_CHAR_LIMIT) -> str:
    """Trim *prompt* to at most *limit* characters, keeping both its head and tail.

    ``answer`` builds prompts that end with "Question: ...". A plain
    ``prompt[:limit]`` would therefore drop the user's question whenever the
    context is long. Keeping both ends preserves any leading instructions as
    well as the trailing question.
    """
    if len(prompt) <= limit:
        return prompt
    keep = max(limit - len(_LOCAL_PROMPT_ELISION), 0)
    head = keep // 2
    tail = keep - head
    # Guard tail == 0: prompt[-0:] would be the whole string.
    return prompt[:head] + _LOCAL_PROMPT_ELISION + (prompt[-tail:] if tail else "")


def _call_llm(prompt: str, system: Optional[str] = None) -> Tuple[str, float]:
    """Generate a completion via the HF Inference API, with a local fallback.

    Attempts, in order:
      1. If HF_TOKEN is set, the Hugging Face chat-completions API with
         HF_LLM_MODEL. It receives *system* (if given) as a system message
         and *prompt* as a user message, with max_tokens=1024.
      2. A local google/flan-t5-small pipeline, loaded once and then cached.
         It receives only *prompt*, trimmed to fit its small budget;
         *system* is not passed to it.
      3. A fixed help message when both attempts fail.

    Failures in steps 1 and 2 are logged as warnings rather than raised.

    Returns:
        ``(response_text, generation_time)``. The time is in seconds and
        covers every attempt made, including failed ones.
    """
    t0 = time.perf_counter()
    if HF_TOKEN:
        try:
            from huggingface_hub import InferenceClient

            client = InferenceClient(token=HF_TOKEN)
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})
            out = client.chat.completions.create(
                model=HF_LLM_MODEL,
                messages=messages,
                max_tokens=1024,
            )
            text = (out.choices[0].message.content or "").strip()
            return text, time.perf_counter() - t0
        except Exception as e:
            # Best-effort: any API/network failure falls through to the local model.
            logger.warning("HF chat API failed: %s", e)
    # Local fallback: minimal model, loaded once and reused across calls.
    try:
        pipe = _load_local_pipeline()
        out = pipe(_fit_local_prompt(prompt))
        text = (out[0].get("generated_text") or "").strip()
        return text or "(No response from local model)", time.perf_counter() - t0
    except Exception as e:
        logger.warning("Local LLM fallback failed: %s", e)
    return (
        "I couldn't generate a response. Set HF_TOKEN for Hugging Face Inference API, or install transformers + a small model for local use.",
        time.perf_counter() - t0,
    )
def answer(
    username: str,
    notebook_id: str,
    query: str,
    strategy: str = "similarity",
    top_k: int = TOP_K,
) -> Tuple[str, List[Dict[str, Any]], float, float]:
    """Answer *query* from the notebook's sources with retrieval plus an LLM.

    Steps:
      1. Retrieve up to *top_k* chunks using *strategy*.
      2. Number the chunks into a context block.
      3. Ask the LLM for an answer with [N] citations.
      4. Append a "Citations:" list if the model did not write one itself.

    Returns:
        ``(answer_text, citations_list, retrieval_time, generation_time)``.
        citations_list is the chunk list exactly as ``retrieve`` returned it
        (presumably dicts with document/metadata/id for the UI). If nothing is
        retrieved, returns a fixed hint message, ``[]``, the retrieval time and
        0.0, without calling the LLM.
    """
    chunks, retrieval_time = retrieve(username, notebook_id, query, top_k=top_k, strategy=strategy)
    if not chunks:
        # Nothing to ground an answer on, so skip generation entirely.
        no_sources_msg = "I don't have any sources in this notebook yet. Add PDFs, slides, or URLs and try again."
        return no_sources_msg, [], retrieval_time, 0.0

    system_prompt = " ".join(
        (
            "You are a helpful assistant. Answer based only on the provided context.",
            "When you use information from the context, cite it with the corresponding number in brackets, e.g. [1].",
            "At the end of your response, list Citations: with each [N] source name (page/slide if available).",
        )
    )
    user_prompt = f"Context:\n{_build_context(chunks)}\n\nQuestion: {query}\n\nAnswer (with citations):"
    reply, generation_time = _call_llm(user_prompt, system=system_prompt)

    # Models don't always follow the instruction to list their sources;
    # chunks is non-empty here, so a footer can always be built.
    if "Citations:" not in reply:
        reply = f"{reply.rstrip()}\n\n{_citations_block(chunks)}"
    return reply, chunks, retrieval_time, generation_time