"""
chain.py

Calls the LLM via HF Inference API with a strict RAG prompt.
Only answers from the retrieved context — never from general knowledge.

Upgrades vs original:
  • answer_stream() — yields token-by-token for real-time Gradio streaming
  • tenacity retry (3 attempts, exponential back-off) on transient API errors
  • Hard input length guard (query ≤ 2000 chars, history capped at 6 messages)
"""
from __future__ import annotations

import os
from typing import Generator

from huggingface_hub import InferenceClient
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
# Grounding prompt: the {context} placeholder is filled per-request by
# _build_messages() with the retrieved chunks from build_context().
SYSTEM_PROMPT = """You are an enterprise document assistant. Your ONLY job is to answer questions using the provided document context below.
STRICT RULES:
1. Answer ONLY using information explicitly found in the provided context.
2. Do NOT use any outside knowledge or assumptions.
3. If the answer is not found in the context, respond EXACTLY with: "I don't have that information in the uploaded documents."
4. Always cite the source document name(s) in your answer using [Source: <filename>].
5. Be concise and professional.
Context from uploaded documents:
---
{context}
---
"""
# Model is overridable via the LLM_MODEL env var; falls back to Llama 3.1 8B Instruct.
LLM_MODEL = os.environ.get("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
MAX_NEW_TOKENS = 1024  # Upper bound on generated tokens per answer
TEMPERATURE = 0.1  # Low temperature for factual, grounded responses
MAX_QUERY_CHARS = 2000  # Hard guard: user queries are truncated to this length
# NOTE(review): despite the name, this caps raw *messages* in the history slice
# (6 messages = 3 user/assistant turns) — see _build_messages().
MAX_HISTORY_TURNS = 6  # Keep last N messages (each turn = 1 user + 1 assistant)
def build_context(chunks: list[dict]) -> str:
    """Render retrieved chunks as a numbered, source-attributed context block.

    Each chunk dict must provide 'source' (filename) and 'text' (chunk body);
    entries are numbered from 1 and separated by blank lines.
    """
    return "\n\n".join(
        f"[{idx}] (Source: {chunk['source']})\n{chunk['text']}"
        for idx, chunk in enumerate(chunks, start=1)
    )
def _build_messages(query: str, context_chunks: list[dict], chat_history: list[dict] | None) -> list[dict]:
    """Assemble the system/history/user message list for the LLM call.

    The system message embeds the formatted retrieval context; history is
    capped to the last MAX_HISTORY_TURNS entries to avoid context overflow,
    and the query is hard-truncated to MAX_QUERY_CHARS.
    """
    system_content = SYSTEM_PROMPT.format(context=build_context(context_chunks))
    out: list[dict] = [{"role": "system", "content": system_content}]
    # Keep only the tail of the history; malformed entries are dropped.
    for entry in (chat_history or [])[-MAX_HISTORY_TURNS:]:
        role, content = entry.get("role"), entry.get("content")
        if role in ("user", "assistant") and content:
            out.append({"role": role, "content": content})
    # Guard: truncate excessively long queries before appending.
    out.append({"role": "user", "content": query[:MAX_QUERY_CHARS]})
    return out
@retry(
    # NOTE(review): Exception is broad — narrowing to HfHubHTTPError would
    # avoid pointless retries on auth/validation errors; confirm desired scope.
    retry=retry_if_exception_type(Exception),
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    reraise=True,  # surface the original error so answer_stream() can report it
)
def _open_stream(client: InferenceClient, messages: list[dict]):
    """
    Open a streaming connection to the LLM.
    The @retry decorator governs ONLY this connection phase (handshake / auth /
    transient 5xx). Mid-stream token errors are handled separately in answer_stream().

    Fix: the decorator was documented (here and in answer_stream) and tenacity
    was imported for it, but it was never actually applied — connection
    failures previously got zero retries.
    """
    return client.chat_completion(
        model=LLM_MODEL,
        messages=messages,
        max_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        stream=True,
    )
def answer_stream(
    query: str,
    context_chunks: list[dict],
    hf_token: str,
    chat_history: list[dict] | None = None,
) -> Generator[str, None, None]:
    """
    Stream the LLM answer token-by-token.

    Yields the progressively-growing reply string so Gradio can update in real time.

    Error handling:
      • Connection failures → retried up to 3× before yielding an error message.
      • Mid-stream failures → partial response is preserved; error notice appended.

    Fix: chunks with an empty `choices` list (keep-alive / usage-only events
    from some providers) are now skipped instead of raising IndexError, which
    previously appended a spurious "Streaming interrupted" notice.
    """
    # No retrieval hits → grounded refusal, no API call at all.
    if not context_chunks:
        yield "I don't have that information in the uploaded documents."
        return
    messages = _build_messages(query, context_chunks, chat_history)
    client = InferenceClient(token=hf_token)
    # Phase 1: open stream (retried automatically by _open_stream)
    try:
        stream = _open_stream(client, messages)
    except Exception as e:
        yield f"❌ Could not reach the LLM after 3 attempts: {e}"
        return
    # Phase 2: consume the stream token-by-token
    accumulated = ""
    try:
        for chunk in stream:
            # Skip events that carry no choices (e.g. usage/keep-alive frames).
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta.content
            if delta:
                accumulated += delta
                yield accumulated
    except Exception as e:
        # Surface whatever was streamed so far alongside the error.
        error_notice = f"\n\n⚠️ *Streaming interrupted: {e}*"
        yield (accumulated + error_notice) if accumulated else f"❌ Streaming error: {e}"