""" chain.py Calls the LLM via HF Inference API with a strict RAG prompt. Only answers from the retrieved context — never from general knowledge. Upgrades vs original: • answer_stream() — yields token-by-token for real-time Gradio streaming • tenacity retry (3 attempts, exponential back-off) on transient API errors • Hard input length guard (query ≤ 2000 chars, history capped at 6 messages) """ from __future__ import annotations import os from typing import Generator from huggingface_hub import InferenceClient from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type SYSTEM_PROMPT = """You are an enterprise document assistant. Your ONLY job is to answer questions using the provided document context below. STRICT RULES: 1. Answer ONLY using information explicitly found in the provided context. 2. Do NOT use any outside knowledge or assumptions. 3. If the answer is not found in the context, respond EXACTLY with: "I don't have that information in the uploaded documents." 4. Always cite the source document name(s) in your answer using [Source: ]. 5. Be concise and professional. Context from uploaded documents: --- {context} --- """ LLM_MODEL = os.environ.get("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct") MAX_NEW_TOKENS = 1024 TEMPERATURE = 0.1 # Low temperature for factual, grounded responses MAX_QUERY_CHARS = 2000 MAX_HISTORY_TURNS = 6 # Keep last N messages (each turn = 1 user + 1 assistant) def build_context(chunks: list[dict]) -> str: """Format retrieved chunks into a readable context block.""" parts = [] for i, chunk in enumerate(chunks, 1): parts.append(f"[{i}] (Source: {chunk['source']})\n{chunk['text']}") return "\n\n".join(parts) def _build_messages(query: str, context_chunks: list[dict], chat_history: list[dict] | None) -> list[dict]: """Assemble the full message list for the LLM call.""" context = build_context(context_chunks) system_msg = SYSTEM_PROMPT.format(context=context) messages: list[dict] = [{"role": "system", "content": system_msg}] if chat_history: # Cap history to avoid overflow for msg in chat_history[-MAX_HISTORY_TURNS:]: if msg.get("role") in ("user", "assistant") and msg.get("content"): messages.append({"role": msg["role"], "content": msg["content"]}) # Guard: truncate excessively long queries query = query[:MAX_QUERY_CHARS] messages.append({"role": "user", "content": query}) return messages @retry( stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=2, max=10), retry=retry_if_exception_type(Exception), reraise=True, ) def _open_stream(client: InferenceClient, messages: list[dict]): """ Open a streaming connection to the LLM. The @retry decorator governs ONLY this connection phase (handshake / auth / transient 5xx). Mid-stream token errors are handled separately in answer_stream(). """ return client.chat_completion( model=LLM_MODEL, messages=messages, max_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE, stream=True, ) def answer_stream( query: str, context_chunks: list[dict], hf_token: str, chat_history: list[dict] | None = None, ) -> Generator[str, None, None]: """ Stream the LLM answer token-by-token. Yields the progressively-growing reply string so Gradio can update in real time. Error handling: • Connection failures → retried up to 3× before yielding an error message. • Mid-stream failures → partial response is preserved; error notice appended. """ if not context_chunks: yield "I don't have that information in the uploaded documents." return messages = _build_messages(query, context_chunks, chat_history) client = InferenceClient(token=hf_token) # Phase 1: open stream (retried automatically by _open_stream) try: stream = _open_stream(client, messages) except Exception as e: yield f"❌ Could not reach the LLM after 3 attempts: {e}" return # Phase 2: consume the stream token-by-token accumulated = "" try: for chunk in stream: delta = chunk.choices[0].delta.content if delta: accumulated += delta yield accumulated except Exception as e: # Surface whatever was streamed so far alongside the error. error_notice = f"\n\n⚠️ *Streaming interrupted: {e}*" yield (accumulated + error_notice) if accumulated else f"❌ Streaming error: {e}"