| """ |
| RAG: retrieve context, build prompt, call LLM, format response with citations. |
| Tracks retrieval_time and generation_time for UI. |
| """ |
| import time |
| from typing import Any, Dict, List, Optional, Tuple |
|
|
| from backend.config import HF_LLM_MODEL, HF_TOKEN, TOP_K |
| from backend.retriever import retrieve |
| from backend.utils import logger |
|
|
|
|
| def _build_context(chunks: List[Dict[str, Any]]) -> str: |
| lines = [] |
| for i, c in enumerate(chunks, 1): |
| doc = c.get("document", "") |
| meta = c.get("metadata", {}) |
| name = meta.get("source_name", "Source") |
| page = meta.get("page_or_slide", "") |
| ref = f"[{i}] {name}" |
| if page: |
| ref += f" (p.{page})" |
| lines.append(f"{ref}:\n{doc}") |
| return "\n\n---\n\n".join(lines) |
|
|
|
|
| def _citations_block(chunks: List[Dict[str, Any]]) -> str: |
| lines = ["Citations:"] |
| for i, c in enumerate(chunks, 1): |
| meta = c.get("metadata", {}) |
| name = meta.get("source_name", "Source") |
| page = meta.get("page_or_slide", "") |
| if page: |
| lines.append(f"[{i}] {name} (p.{page})") |
| else: |
| lines.append(f"[{i}] {name}") |
| return "\n".join(lines) |
|
|
|
|
def call_llm(prompt: str, system: Optional[str] = None) -> Tuple[str, float]:
    """Public entry point for a single LLM call.

    Delegates unchanged to ``_call_llm`` and hands back its result, a
    ``(response_text, generation_time)`` pair.
    """
    result = _call_llm(prompt, system)
    return result
|
|
|
|
# Local fallback pipeline, built lazily. Constructing it loads the model, so it
# is created once on first use and reused by later calls instead of being
# rebuilt for every request.
_LOCAL_PIPELINE = None


def _get_local_pipeline():
    """Return the shared local text2text pipeline, creating it on first use."""
    global _LOCAL_PIPELINE
    if _LOCAL_PIPELINE is None:
        from transformers import pipeline

        # Only cached after successful construction, so a failed load is
        # retried on the next call.
        _LOCAL_PIPELINE = pipeline(
            "text2text-generation", model="google/flan-t5-small", max_length=256
        )
    return _LOCAL_PIPELINE


def _call_llm(prompt: str, system: Optional[str] = None) -> Tuple[str, float]:
    """Generate a response via the HF Inference API, else a local model.

    Order of attempts:
      1. If ``HF_TOKEN`` is set, a chat completion against ``HF_LLM_MODEL``
         (``system`` is sent as a system message when given).
      2. If the token is unset or that call raises, a local
         ``google/flan-t5-small`` pipeline. This path receives only the first
         2000 characters of ``prompt`` and does not use ``system``.
      3. If the local path also raises, a fixed explanatory message.

    Failures are logged as warnings and never propagate to the caller.

    Args:
        prompt: User prompt text.
        system: Optional system instruction (hosted path only).

    Returns:
        ``(response_text, generation_time)`` where the time is the seconds
        elapsed since this function was entered, including failed attempts.
    """
    t0 = time.perf_counter()
    if HF_TOKEN:
        try:
            from huggingface_hub import InferenceClient

            client = InferenceClient(token=HF_TOKEN)
            messages = []
            if system:
                messages.append({"role": "system", "content": system})
            messages.append({"role": "user", "content": prompt})
            out = client.chat.completions.create(
                model=HF_LLM_MODEL,
                messages=messages,
                max_tokens=1024,
            )
            text = (out.choices[0].message.content or "").strip()
            return text, time.perf_counter() - t0
        except Exception as e:
            # Deliberate best-effort: fall through to the local model.
            logger.warning("HF chat API failed: %s", e)

    try:
        pipe = _get_local_pipeline()
        out = pipe(prompt[:2000])
        text = (out[0].get("generated_text") or "").strip()
        return text or "(No response from local model)", time.perf_counter() - t0
    except Exception as e:
        logger.warning("Local LLM fallback failed: %s", e)
        return (
            "I couldn't generate a response. Set HF_TOKEN for Hugging Face Inference API, or install transformers + a small model for local use.",
            time.perf_counter() - t0,
        )
|
|
|
|
def answer(
    username: str,
    notebook_id: str,
    query: str,
    strategy: str = "similarity",
    top_k: int = TOP_K,
) -> Tuple[str, List[Dict[str, Any]], float, float]:
    """
    Answer ``query`` from the notebook's sources (retrieve, prompt, generate).

    Returns (answer_text, citations_list, retrieval_time, generation_time).
    citations_list is the list of retrieved chunks, passed through for the UI.
    When retrieval yields nothing, a fixed hint is returned with an empty
    citations list and a generation_time of 0.0; the LLM is not called.
    """
    retrieved, retrieval_time = retrieve(
        username, notebook_id, query, top_k=top_k, strategy=strategy
    )

    # Guard clause: nothing retrieved means there is nothing to ground on.
    if not retrieved:
        no_sources_msg = "I don't have any sources in this notebook yet. Add PDFs, slides, or URLs and try again."
        return no_sources_msg, [], retrieval_time, 0.0

    system_instruction = (
        "You are a helpful assistant. Answer based only on the provided context. "
        "When you use information from the context, cite it with the corresponding number in brackets, e.g. [1]. "
        "At the end of your response, list Citations: with each [N] source name (page/slide if available)."
    )
    context_text = _build_context(retrieved)
    full_prompt = (
        f"Context:\n{context_text}\n\nQuestion: {query}\n\nAnswer (with citations):"
    )
    reply, generation_time = _call_llm(full_prompt, system=system_instruction)

    # The model is asked to list its sources; if it did not, append ours.
    # (``retrieved`` is known to be non-empty here because of the guard above.)
    if "Citations:" not in reply:
        reply = reply.rstrip() + "\n\n" + _citations_block(retrieved)

    return reply, retrieved, retrieval_time, generation_time
|
|