"""
RAG: retrieve context, build prompt, call LLM, format response with citations.
Tracks retrieval_time and generation_time for UI.
"""
import time
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple

from backend.config import HF_LLM_MODEL, HF_TOKEN, TOP_K
from backend.retriever import retrieve
from backend.utils import logger

# Model used by the local fallback when the HF Inference API is unavailable.
_LOCAL_MODEL = "google/flan-t5-small"
# Character budget for prompts passed to the local fallback model.
_LOCAL_PROMPT_CHARS = 2000
# Characters reserved at the END of a clipped local prompt. The prompt built by
# `answer` ends with the question and the answer cue, so those must survive
# truncation; plain head-truncation would silently drop the question.
_LOCAL_PROMPT_TAIL_CHARS = 500


def _source_label(chunk: Dict[str, Any], index: int) -> str:
    """Return the "[index] name" label (plus " (p.page)" when known) for a chunk.

    Shared by the prompt context and the citations block so both number and
    name sources identically.
    """
    # Metadata may be missing or explicitly None; treat both as empty.
    meta = chunk.get("metadata") or {}
    label = f"[{index}] {meta.get('source_name', 'Source')}"
    page = meta.get("page_or_slide", "")
    if page:
        label += f" (p.{page})"
    return label


def _build_context(chunks: List[Dict[str, Any]]) -> str:
    """Render retrieved chunks as numbered, labelled passages for the prompt.

    Each passage is "<label>:\\n<document text>"; passages are separated by a
    "---" divider. Numbering starts at 1 and matches `_citations_block`.
    """
    return "\n\n---\n\n".join(
        # A None document would otherwise be rendered as the text "None".
        f"{_source_label(chunk, i)}:\n{chunk.get('document') or ''}"
        for i, chunk in enumerate(chunks, 1)
    )


def _citations_block(chunks: List[Dict[str, Any]]) -> str:
    """Return a "Citations:" header followed by one label line per chunk."""
    lines = ["Citations:"]
    lines.extend(_source_label(chunk, i) for i, chunk in enumerate(chunks, 1))
    return "\n".join(lines)


def call_llm(prompt: str, system: Optional[str] = None) -> Tuple[str, float]:
    """Public wrapper for LLM call. Returns (response_text, generation_time)."""
    return _call_llm(prompt, system)


@lru_cache(maxsize=1)
def _local_pipeline():
    """Load the local text2text pipeline once and reuse it across calls.

    Loading a model is expensive, so it must not happen per request.
    `lru_cache` does not store results when the call raises, so a failed load
    (e.g. transformers not installed) is simply retried on the next call.
    """
    from transformers import pipeline

    return pipeline("text2text-generation", model=_LOCAL_MODEL, max_length=256)


def _clip_prompt(
    text: str,
    limit: int = _LOCAL_PROMPT_CHARS,
    tail: int = _LOCAL_PROMPT_TAIL_CHARS,
) -> str:
    """Shorten `text` to at most `limit` chars, keeping its start and last `tail` chars.

    Text that already fits is returned unchanged.
    """
    if len(text) <= limit:
        return text
    tail = min(tail, limit)
    # Slice the tail by explicit index: text[-0:] would return the whole string.
    return text[: limit - tail] + text[len(text) - tail:]


def _call_llm(prompt: str, system: Optional[str] = None) -> Tuple[str, float]:
    """Call HF Inference API or local fallback.

    Tries the Hugging Face chat API when HF_TOKEN is set, then a local
    transformers pipeline, then returns a fixed help message. Errors from
    each backend are logged as warnings and trigger the next fallback.

    Args:
        prompt: User message text.
        system: Optional system message (only sent to the HF chat API; the
            local text2text model receives the prompt alone).

    Returns:
        (response_text, generation_time) where generation_time is seconds
        elapsed since this call started, including any failed attempts.
    """
    t0 = time.perf_counter()
    if HF_TOKEN:
        try:
            from huggingface_hub import InferenceClient

            client = InferenceClient(token=HF_TOKEN)
            messages: List[Dict[str, str]] = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})
            out = client.chat.completions.create(
                model=HF_LLM_MODEL,
                messages=messages,
                max_tokens=1024,
            )
            text = (out.choices[0].message.content or "").strip()
            return text, time.perf_counter() - t0
        except Exception as e:  # best-effort: fall through to the local model
            logger.warning("HF chat API failed: %s", e)

    # Local fallback: small model, cached pipeline, clipped prompt that still
    # contains the trailing question.
    try:
        pipe = _local_pipeline()
        out = pipe(_clip_prompt(prompt))
        text = (out[0].get("generated_text") or "").strip()
        return text or "(No response from local model)", time.perf_counter() - t0
    except Exception as e:  # best-effort: fall through to the help message
        logger.warning("Local LLM fallback failed: %s", e)

    return (
        "I couldn't generate a response. Set HF_TOKEN for Hugging Face Inference API, or install transformers + a small model for local use.",
        time.perf_counter() - t0,
    )


def answer(
    username: str,
    notebook_id: str,
    query: str,
    strategy: str = "similarity",
    top_k: int = TOP_K,
) -> Tuple[str, List[Dict[str, Any]], float, float]:
    """
    RAG answer. Returns (answer_text, citations_list, retrieval_time, generation_time).
    citations_list: list of {document, metadata, id} for UI.

    Args:
        username, notebook_id: Passed to `retrieve` to select the notebook.
        query: The user's question; used for retrieval and in the prompt.
        strategy, top_k: Passed to `retrieve`.

    When retrieval returns no chunks, a fixed "no sources" message is
    returned with an empty citations list and generation_time 0.0. When the
    model's answer lacks a "Citations:" section, one is appended.
    """
    chunks, retrieval_time = retrieve(username, notebook_id, query, top_k=top_k, strategy=strategy)
    if not chunks:
        return (
            "I don't have any sources in this notebook yet. Add PDFs, slides, or URLs and try again.",
            [],
            retrieval_time,
            0.0,
        )

    context = _build_context(chunks)
    system = (
        "You are a helpful assistant. Answer based only on the provided context. "
        "When you use information from the context, cite it with the corresponding number in brackets, e.g. [1]. "
        "At the end of your response, list Citations: with each [N] source name (page/slide if available)."
    )
    # The question is placed last; the local fallback clips prompts so that
    # this tail is preserved.
    prompt = f"Context:\n{context}\n\nQuestion: {query}\n\nAnswer (with citations):"
    answer_text, generation_time = _call_llm(prompt, system=system)

    # Ensure citations block if model didn't add it (chunks is non-empty here).
    if "Citations:" not in answer_text:
        answer_text = answer_text.rstrip() + "\n\n" + _citations_block(chunks)
    return answer_text, chunks, retrieval_time, generation_time