File size: 4,644 Bytes
55953aa
 
 
 
a465955
 
 
 
 
55953aa
 
 
2623b17
a465955
 
55953aa
a465955
55953aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2623b17
55953aa
 
a465955
 
55953aa
 
 
 
 
 
 
 
 
 
a465955
 
55953aa
 
 
a465955
55953aa
a465955
 
e120e1f
 
a465955
 
 
55953aa
a465955
55953aa
a465955
 
 
 
 
 
 
2623b17
 
 
 
 
 
a465955
55953aa
 
 
 
a465955
55953aa
a465955
 
 
 
 
 
 
 
 
 
 
2623b17
 
 
 
a465955
 
 
 
 
 
 
 
2623b17
a465955
2623b17
a465955
2623b17
a465955
 
2623b17
a465955
2623b17
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
chain.py
Calls the LLM via HF Inference API with a strict RAG prompt.
Only answers from the retrieved context — never from general knowledge.

Upgrades vs original:
  • answer_stream() — yields token-by-token for real-time Gradio streaming
  • tenacity retry (3 attempts, exponential back-off) on transient API errors
  • Hard input length guard (query ≤ 2000 chars, history capped at 6 messages)
"""

from __future__ import annotations
import os
from typing import Generator

from huggingface_hub import InferenceClient
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type

# Grounding prompt injected as the system message. The refusal sentence in
# rule 3 is also emitted verbatim by answer_stream() when no chunks are
# retrieved — keep the two strings in sync if either ever changes.
SYSTEM_PROMPT = """You are an enterprise document assistant. Your ONLY job is to answer questions using the provided document context below.

STRICT RULES:
1. Answer ONLY using information explicitly found in the provided context.
2. Do NOT use any outside knowledge or assumptions.
3. If the answer is not found in the context, respond EXACTLY with: "I don't have that information in the uploaded documents."
4. Always cite the source document name(s) in your answer using [Source: <filename>].
5. Be concise and professional.

Context from uploaded documents:
---
{context}
---
"""

# Model id; overridable via the LLM_MODEL environment variable.
LLM_MODEL = os.environ.get("LLM_MODEL", "meta-llama/Llama-3.1-8B-Instruct")
MAX_NEW_TOKENS = 1024  # hard cap on generated tokens per reply
TEMPERATURE = 0.1   # Low temperature for factual, grounded responses
MAX_QUERY_CHARS = 2000  # user queries longer than this are silently truncated
# NOTE(review): despite the name, this caps individual *messages*, not turn
# pairs — chat_history[-6:] keeps 6 messages (~3 user/assistant turns).
MAX_HISTORY_TURNS = 6  # Keep last N messages (each turn = 1 user + 1 assistant)


def build_context(chunks: list[dict]) -> str:
    """Format retrieved chunks into a numbered, readable context block.

    Each chunk dict must provide a ``source`` (document filename) and a
    ``text`` (chunk content) key; a missing key raises ``KeyError``, which is
    intentional — malformed chunks should fail loudly rather than produce
    uncited context.

    Args:
        chunks: Retrieved chunks, in relevance order.

    Returns:
        Sections of the form ``[i] (Source: <name>)`` followed by the chunk
        text, separated by blank lines. An empty list yields an empty string.
    """
    # Generator expression replaces the manual append loop (same output,
    # idiomatic, avoids the throwaway `parts` list).
    return "\n\n".join(
        f"[{i}] (Source: {chunk['source']})\n{chunk['text']}"
        for i, chunk in enumerate(chunks, 1)
    )


def _build_messages(query: str, context_chunks: list[dict], chat_history: list[dict] | None) -> list[dict]:
    """Assemble the system/history/user message list for the LLM call.

    The system message embeds the formatted retrieval context. The most
    recent MAX_HISTORY_TURNS history entries are replayed (entries with an
    unknown role or empty content are skipped), and the user query —
    truncated to MAX_QUERY_CHARS — always comes last.
    """
    system_msg = SYSTEM_PROMPT.format(context=build_context(context_chunks))
    messages: list[dict] = [{"role": "system", "content": system_msg}]

    # Replay only well-formed recent history entries to bound prompt size.
    recent = chat_history[-MAX_HISTORY_TURNS:] if chat_history else []
    messages.extend(
        {"role": entry["role"], "content": entry["content"]}
        for entry in recent
        if entry.get("role") in ("user", "assistant") and entry.get("content")
    )

    # Length guard: silently truncate excessively long queries.
    messages.append({"role": "user", "content": query[:MAX_QUERY_CHARS]})
    return messages


@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    # NOTE(review): retrying on bare Exception also retries non-transient
    # failures (bad token, invalid model id), tripling their latency —
    # consider narrowing to the transport/HTTP errors huggingface_hub raises.
    retry=retry_if_exception_type(Exception),
    reraise=True,  # after the final attempt, surface the original exception
)
def _open_stream(client: InferenceClient, messages: list[dict]):
    """
    Open a streaming connection to the LLM.
    The @retry decorator governs ONLY this connection phase (handshake / auth /
    transient 5xx).  Mid-stream token errors are handled separately in answer_stream().
    """
    # Returns the provider's chunk iterator; tokens are consumed by the caller.
    return client.chat_completion(
        model=LLM_MODEL,
        messages=messages,
        max_tokens=MAX_NEW_TOKENS,
        temperature=TEMPERATURE,
        stream=True,
    )


def answer_stream(
    query: str,
    context_chunks: list[dict],
    hf_token: str,
    chat_history: list[dict] | None = None,
) -> Generator[str, None, None]:
    """
    Stream the LLM answer token-by-token.
    Yields the progressively-growing reply string so Gradio can update in real time.

    Error handling:
    • Connection failures → retried up to 3× before yielding an error message.
    • Mid-stream failures → partial response is preserved; error notice appended.
    """
    if not context_chunks:
        yield "I don't have that information in the uploaded documents."
        return

    messages = _build_messages(query, context_chunks, chat_history)
    client = InferenceClient(token=hf_token)

    # Phase 1: open stream (retried automatically by _open_stream)
    try:
        stream = _open_stream(client, messages)
    except Exception as e:
        yield f"❌ Could not reach the LLM after 3 attempts: {e}"
        return

    # Phase 2: consume the stream token-by-token
    accumulated = ""
    try:
        for chunk in stream:
            delta = chunk.choices[0].delta.content
            if delta:
                accumulated += delta
                yield accumulated
    except Exception as e:
        # Surface whatever was streamed so far alongside the error.
        error_notice = f"\n\n⚠️ *Streaming interrupted: {e}*"
        yield (accumulated + error_notice) if accumulated else f"❌ Streaming error: {e}"