File size: 10,041 Bytes
6c21523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
"""
RAG engine: retrieves relevant context from vector store,
builds a strict prompt, and queries the LLM.

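Backend priority:
1. GROQ_API_KEY set → Groq API (fast, 100K tokens/day free tier)
2. USE_HF_LLM=1 + HF_TOKEN set → HuggingFace Inference API
3. Otherwise → Ollama (local)

If Groq reports a rate limit and HF_TOKEN is set, the request falls back to
the HuggingFace Inference API.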
"""
import os
import logging
from typing import List, Dict, Any, Generator

from utils.vector_store import VectorStoreManager
from utils.memory import ConversationMemory

logger = logging.getLogger(__name__)

# Default model per backend; each can be overridden through the environment
# (OLLAMA_MODEL, GROQ_MODEL, HF_MODEL) when RAGEngine is constructed.
DEFAULT_OLLAMA_MODEL = "llama3.2"
DEFAULT_GROQ_MODEL   = "llama-3.3-70b-versatile"
DEFAULT_HF_MODEL     = "meta-llama/Llama-3.1-8B-Instruct"

# Credentials and feature flags, read once at import time.
# "MultiModalRag_Token" is accepted as an alternate name for the HF token.
HF_TOKEN     = os.environ.get("HF_TOKEN") or os.environ.get("MultiModalRag_Token", "")
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
USE_HF_LLM   = os.environ.get("USE_HF_LLM", "").lower() in ("1", "true", "yes")

# Pick backend: Groq first (fast), then HF Inference (only if USE_HF_LLM=1), then Ollama.
# Empty-string credentials are falsy and therefore count as unset.
if GROQ_API_KEY:
    BACKEND = "groq"
elif USE_HF_LLM and HF_TOKEN:
    BACKEND = "hf"
else:
    BACKEND = "ollama"

OLLAMA_HOST = os.environ.get("OLLAMA_HOST", "http://localhost:11434")

# System prompt for questions answered from retrieved document chunks.
SYSTEM_PROMPT = """You are a document assistant. Answer questions using ONLY the [CONTEXT] provided.
Rules:
1. If the context contains information relevant to the question, answer from it — even if only partially relevant.
2. Combine information from multiple context chunks if needed.
3. Only say "I DON'T KNOW" if the context truly contains NO relevant information at all.
4. Be concise. Cite source and page when available.
5. Do NOT make up information that is not in the context.
6. When answering questions about tables or structured data, apply ALL filter conditions from the question. Only include rows that match every condition — do not display or reference rows that do not match.
7. Give ONLY the final answer. Do NOT show reasoning steps, intermediate calculations, excluded rows, or any explanation of how you arrived at the answer.
"""

# System prompt used when retrieval finds nothing relevant (see RELEVANCE_THRESHOLD).
GENERAL_PROMPT = """You are a helpful AI assistant. Answer directly and concisely — final answer only, no reasoning steps.
If you don't know the answer, say so honestly.
"""

# Cosine distance threshold (0=identical, 2=opposite). A query is treated as
# off-topic when every retrieved chunk is farther away than this.
RELEVANCE_THRESHOLD = 1.2


def _make_hf_client():
    """Build a HuggingFace ``InferenceClient`` authenticated with ``HF_TOKEN``.

    ``huggingface_hub`` is imported here rather than at module level, so the
    package is only needed when this factory is actually called.
    """
    from huggingface_hub import InferenceClient as _InferenceClient

    client = _InferenceClient(token=HF_TOKEN)
    return client


def _make_groq_client():
    """Build a Groq SDK client authenticated with ``GROQ_API_KEY``.

    The ``groq`` import is deferred so the package is only required when
    this factory runs.
    """
    from groq import Groq as _Groq

    client = _Groq(api_key=GROQ_API_KEY)
    return client


def _make_ollama_client():
    """Build an Ollama client that talks to the server at ``OLLAMA_HOST``.

    The ``ollama`` import is deferred so the package is only required when
    this factory runs.
    """
    from ollama import Client as _OllamaClient

    client = _OllamaClient(host=OLLAMA_HOST)
    return client


def _is_rate_limit(exc: Exception) -> bool:
    msg = str(exc).lower()
    return "429" in msg or "rate_limit" in msg or "rate limit" in msg


class RAGEngine:
    """Answer questions over a vector store with a retrieval-augmented LLM call.

    The backend ("groq", "hf" or "ollama") is fixed at import time by the
    module-level ``BACKEND``. Every answer produced is appended to the
    caller's ``ConversationMemory`` as a user/assistant pair.
    """

    def __init__(self, vector_store: VectorStoreManager, model: str = None):
        """Bind the engine to ``vector_store`` and create the backend client.

        Args:
            vector_store: store queried for context chunks in ``query``.
            model: model-name override. Only the Ollama backend honours it;
                the HF and Groq backends read ``HF_MODEL`` / ``GROQ_MODEL``
                from the environment (falling back to the module defaults).
        """
        self.vs = vector_store
        # HF client used only for the Groq rate-limit fallback; built on first use.
        self._hf_fallback_client = None
        if BACKEND == "hf":
            self.model = os.environ.get("HF_MODEL", DEFAULT_HF_MODEL)
            self._client = _make_hf_client()
            logger.info("LLM backend: HuggingFace Inference (%s)", self.model)
        elif BACKEND == "groq":
            self.model = os.environ.get("GROQ_MODEL", DEFAULT_GROQ_MODEL)
            self._client = _make_groq_client()
            logger.info("LLM backend: Groq (%s)", self.model)
        else:
            self.model = model or os.environ.get("OLLAMA_MODEL", DEFAULT_OLLAMA_MODEL)
            self._client = _make_ollama_client()
            logger.info("LLM backend: Ollama (%s)", self.model)

    # ------------------------------------------------------------------
    # Prompt construction
    # ------------------------------------------------------------------

    def _build_context(self, results: List[Dict[str, Any]]) -> str:
        """Render retrieved chunks as numbered ``[Doc i — source, Page p]`` sections.

        Each result is read as a dict with a ``"text"`` key and an optional
        ``"metadata"`` dict (``source``, ``page``, ``type``). The type tag is
        shown only for non-"text" chunks, and a falsy page is omitted.
        Sections are separated by a ``---`` rule.
        """
        if not results:
            return "No relevant documents found."
        parts = []
        for i, r in enumerate(results, 1):
            # Tolerate a missing or None metadata entry instead of crashing.
            meta = r.get("metadata") or {}
            source = meta.get("source", "unknown")
            page = meta.get("page", "")
            doc_type = meta.get("type", "text")
            page_str = f", Page {page}" if page else ""
            type_str = f" [{doc_type}]" if doc_type != "text" else ""
            parts.append(f"[Doc {i} — {source}{page_str}{type_str}]\n{r['text']}")
        return "\n\n---\n\n".join(parts)

    def _build_messages(self, question: str, context: str, memory: ConversationMemory):
        """Build the grounded chat: system rules, prior turns, then context + question."""
        user_message = (
            f"[CONTEXT]\n{context}\n\n"
            f"[QUESTION]\n{question}\n\n"
            "Remember: Answer ONLY from the context above. If not found, say \"I DON'T KNOW\"."
        )
        return [
            {"role": "system", "content": SYSTEM_PROMPT},
            *memory.get_history_for_prompt(),
            {"role": "user", "content": user_message},
        ]

    def _is_off_topic(self, results: List[Dict[str, Any]]) -> bool:
        """Return True when there are no results or every one exceeds RELEVANCE_THRESHOLD.

        A result without a distance (missing key or None) counts as distance 1.0.
        """
        if not results:
            return True

        def _distance(r: Dict[str, Any]) -> float:
            d = r.get("distance")
            return 1.0 if d is None else d

        return all(_distance(r) > RELEVANCE_THRESHOLD for r in results)

    def _build_general_messages(self, question: str, memory: ConversationMemory):
        """Build an ungrounded chat (general-assistant prompt) for off-topic questions."""
        return [
            {"role": "system", "content": GENERAL_PROMPT},
            *memory.get_history_for_prompt(),
            {"role": "user", "content": question},
        ]

    # ------------------------------------------------------------------
    # Public entry point
    # ------------------------------------------------------------------

    def query(
        self,
        question: str,
        memory: ConversationMemory,
        n_results: int = 8,
        temperature: float = 0.0,
        stream: bool = False,
        source_filter: list = None,
        pre_fetched_results: list = None,
    ) -> Generator[str, None, None]:
        """Answer ``question``, yielding the reply as one or more text pieces.

        Args:
            question: the user's question.
            memory: conversation history; prior turns go into the prompt and
                the new question/answer pair is appended once answered.
            n_results: number of chunks to request from the vector store.
            temperature: sampling temperature passed to the backend.
            stream: yield tokens as they arrive (Groq / Ollama). The HF
                backend always yields a single complete answer.
            source_filter: forwarded to ``self.vs.query``.
            pre_fetched_results: if not None, used instead of querying the store.

        Yields:
            Answer text. An ``Exception`` raised by the backend call is logged
            and yielded as a single ``"Error: ..."`` string rather than raised.
            Errors from the vector-store lookup still propagate.
        """
        if pre_fetched_results is not None:
            results = pre_fetched_results
        else:
            results = self.vs.query(question, n_results=n_results, source_filter=source_filter)

        if self._is_off_topic(results):
            # Nothing relevant retrieved: answer as a general assistant instead.
            logger.info("Off-topic query (no relevant chunks): '%s'", question[:60])
            messages = self._build_general_messages(question, memory)
        else:
            context = self._build_context(results)
            messages = self._build_messages(question, context, memory)

        try:
            if BACKEND == "hf":
                yield from self._query_hf(messages, memory, question, temperature, stream)
            elif BACKEND == "groq":
                yield from self._query_groq(messages, memory, question, temperature, stream)
            else:
                yield from self._query_ollama(messages, memory, question, temperature, stream)
        except Exception as e:
            error_msg = f"Error: {e}"
            logger.exception(error_msg)
            yield error_msg

    # ------------------------------------------------------------------
    # Backend helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _remember(memory: ConversationMemory, question: str, answer: str) -> None:
        """Append the user question and the assistant answer to ``memory``."""
        memory.add("user", question)
        memory.add("assistant", answer)

    @staticmethod
    def _hf_chat(client, model: str, messages, temperature: float) -> str:
        """Run one non-streaming chat completion on the HF Inference API and return the text."""
        resp = client.chat_completion(
            model=model,
            messages=messages,
            # Clamped above zero; presumably the endpoint rejects temperature == 0 — TODO confirm.
            temperature=max(temperature, 0.01),
            max_tokens=2048,
        )
        return resp.choices[0].message.content

    def _query_hf(self, messages, memory, question, temperature, stream) -> Generator[str, None, None]:
        """Query the HF Inference backend; ``stream`` is ignored (one full answer)."""
        answer = self._hf_chat(self._client, self.model, messages, temperature)
        self._remember(memory, question, answer)
        yield answer

    def _hf_fallback(self, messages, memory, question, temperature) -> Generator[str, None, None]:
        """Answer via HF Inference after Groq was rate-limited.

        The client is created once and reused for later fallbacks.
        """
        client = getattr(self, "_hf_fallback_client", None)
        if client is None:
            client = self._hf_fallback_client = _make_hf_client()
        model = os.environ.get("HF_MODEL", DEFAULT_HF_MODEL)
        answer = self._hf_chat(client, model, messages, temperature)
        self._remember(memory, question, answer)
        yield answer

    def _query_groq(self, messages, memory, question, temperature, stream) -> Generator[str, None, None]:
        """Query Groq, falling back to HF (or a notice) on a rate-limit error.

        The fallback only happens if no streamed text has been yielded yet;
        otherwise the partial Groq answer and a complete HF answer would be
        concatenated, so the error is re-raised instead.
        """
        emitted = False  # whether any non-empty text already reached the caller
        try:
            if stream:
                pieces = []
                with self._client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                    stream=True,
                ) as stream_resp:
                    for chunk in stream_resp:
                        # Skip chunks that carry no choices instead of raising IndexError.
                        if not chunk.choices:
                            continue
                        token = chunk.choices[0].delta.content or ""
                        pieces.append(token)
                        if token:
                            emitted = True
                        yield token
                self._remember(memory, question, "".join(pieces))
            else:
                resp = self._client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                )
                answer = resp.choices[0].message.content
                self._remember(memory, question, answer)
                yield answer
        except Exception as e:
            if emitted or not _is_rate_limit(e):
                raise
            if HF_TOKEN:
                logger.warning("Groq rate limit reached — falling back to HF Inference")
                yield from self._hf_fallback(messages, memory, question, temperature)
            else:
                yield "⚠️ Groq daily token limit reached (100K/day free tier). Please try again in a few hours, or upgrade at https://console.groq.com/settings/billing"

    def _query_ollama(self, messages, memory, question, temperature, stream) -> Generator[str, None, None]:
        """Query the Ollama server, streaming tokens or yielding one full answer."""
        options = {"temperature": temperature}
        if stream:
            pieces = []
            for chunk in self._client.chat(
                model=self.model,
                messages=messages,
                stream=True,
                options=options,
            ):
                token = chunk["message"]["content"]
                pieces.append(token)
                yield token
            self._remember(memory, question, "".join(pieces))
        else:
            response = self._client.chat(
                model=self.model,
                messages=messages,
                options=options,
            )
            answer = response["message"]["content"]
            self._remember(memory, question, answer)
            yield answer

    def list_available_models(self) -> List[str]:
        """Return the single model this engine is configured to use."""
        return [self.model]