# Custom-LLM-Chat / rag / chain.py
# Author: Bhaskar Ram
# fix: sentence-aware chunking, score threshold, DOCX tables, streaming error handling, LLM_MODEL env var
# Commit: 2623b17
"""
chain.py
Calls the LLM via HF Inference API with a strict RAG prompt.
Only answers from the retrieved context — never from general knowledge.
Upgrades vs original:
• answer_stream() — yields token-by-token for real-time Gradio streaming
• tenacity retry (3 attempts, exponential back-off) on transient API errors
• Hard input length guard (query ≤ 2000 chars, history capped at 6 messages)
"""
from __future__ import annotations
import os
from typing import Generator
from huggingface_hub import InferenceClient
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
# System prompt template; {context} is filled with the formatted retrieved
# chunks by _build_messages(). The wording enforces grounded-only answers
# with per-document citations and a fixed "not found" response.
SYSTEM_PROMPT = """You are an enterprise document assistant. Your ONLY job is to answer questions using the provided document context below.
STRICT RULES:
1. Answer ONLY using information explicitly found in the provided context.
2. Do NOT use any outside knowledge or assumptions.
3. If the answer is not found in the context, respond EXACTLY with: "I don't have that information in the uploaded documents."
4. Always cite the source document name(s) in your answer using [Source: <filename>].
5. Be concise and professional.
Context from uploaded documents:
---
{context}
---
"""
# Model repo id; overridable via the LLM_MODEL environment variable.
LLM_MODEL = os.environ.get("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
MAX_NEW_TOKENS = 1024  # Upper bound on tokens generated per reply
TEMPERATURE = 0.1 # Low temperature for factual, grounded responses
MAX_QUERY_CHARS = 2000  # Hard guard: longer queries are truncated, not rejected
MAX_HISTORY_TURNS = 6 # Caps the last N *messages* kept (despite the name, not full user+assistant turn pairs)
def build_context(chunks: list[dict]) -> str:
    """Render retrieved chunks as a numbered, source-attributed context block.

    Each chunk becomes an entry of the form ``[n] (Source: <name>)\\n<text>``;
    entries are joined with blank lines. An empty chunk list yields "".
    """
    return "\n\n".join(
        f"[{position}] (Source: {entry['source']})\n{entry['text']}"
        for position, entry in enumerate(chunks, start=1)
    )
def _build_messages(query: str, context_chunks: list[dict], chat_history: list[dict] | None) -> list[dict]:
    """Assemble the full message list for the LLM call.

    Args:
        query: The user's question; truncated to MAX_QUERY_CHARS as a hard
            input-length guard.
        context_chunks: Retrieved chunks ({"source", "text"} dicts) formatted
            into the system prompt via build_context().
        chat_history: Optional prior messages as {"role", "content"} dicts.
            Only the last MAX_HISTORY_TURNS *valid* entries are kept.

    Returns:
        [system message with context] + capped history + [user query].
    """
    context = build_context(context_chunks)
    system_msg = SYSTEM_PROMPT.format(context=context)
    messages: list[dict] = [{"role": "system", "content": system_msg}]
    if chat_history:
        # Filter out malformed entries FIRST, then cap. The previous order
        # (slice, then filter) let invalid messages inside the window silently
        # shrink the usable history below the intended cap.
        valid = [
            {"role": msg["role"], "content": msg["content"]}
            for msg in chat_history
            if msg.get("role") in ("user", "assistant") and msg.get("content")
        ]
        messages.extend(valid[-MAX_HISTORY_TURNS:])
    # Guard: truncate excessively long queries to bound prompt size.
    messages.append({"role": "user", "content": query[:MAX_QUERY_CHARS]})
    return messages
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    # NOTE(review): retry_if_exception_type(Exception) retries *every* error,
    # including non-transient ones (invalid token / 401, unknown model name),
    # so a permanent failure still costs up to 3 attempts with back-off before
    # surfacing. Consider narrowing to the hub's HTTP/timeout exception types.
    retry=retry_if_exception_type(Exception),
    reraise=True,  # after the final attempt, propagate the original exception
)
def _open_stream(client: InferenceClient, messages: list[dict]):
    """
    Open a streaming connection to the LLM.
    The @retry decorator governs ONLY this connection phase (handshake / auth /
    transient 5xx). Mid-stream token errors are handled separately in answer_stream().
    """
    # Returns the chunk iterator produced by stream=True; consumed by the caller.
    return client.chat_completion(
        model=LLM_MODEL,
        messages=messages,
        max_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        stream=True,
    )
def answer_stream(
query: str,
context_chunks: list[dict],
hf_token: str,
chat_history: list[dict] | None = None,
) -> Generator[str, None, None]:
"""
Stream the LLM answer token-by-token.
Yields the progressively-growing reply string so Gradio can update in real time.
Error handling:
• Connection failures → retried up to 3× before yielding an error message.
• Mid-stream failures → partial response is preserved; error notice appended.
"""
if not context_chunks:
yield "I don't have that information in the uploaded documents."
return
messages = _build_messages(query, context_chunks, chat_history)
client = InferenceClient(token=hf_token)
# Phase 1: open stream (retried automatically by _open_stream)
try:
stream = _open_stream(client, messages)
except Exception as e:
yield f"❌ Could not reach the LLM after 3 attempts: {e}"
return
# Phase 2: consume the stream token-by-token
accumulated = ""
try:
for chunk in stream:
delta = chunk.choices[0].delta.content
if delta:
accumulated += delta
yield accumulated
except Exception as e:
# Surface whatever was streamed so far alongside the error.
error_notice = f"\n\n⚠️ *Streaming interrupted: {e}*"
yield (accumulated + error_notice) if accumulated else f"❌ Streaming error: {e}"