# Source: MultiModalRag/utils/rag_engine.py (10 kB)
# Commit 6c21523 by irajkoohi — "chore: update app [space deploy]"
"""
RAG engine: retrieves relevant context from vector store,
builds a strict prompt, and queries the LLM.
Backend priority:
1. GROQ_API_KEY set → Groq API (fast, 100K tokens/day free tier)
2. USE_HF_LLM=1 + HF_TOKEN set → HuggingFace Inference API
3. Otherwise → Ollama (local)
"""
import os
import logging
from typing import List, Dict, Any, Generator
from utils.vector_store import VectorStoreManager
from utils.memory import ConversationMemory
logger = logging.getLogger(__name__)

# Model used by each backend when no override is supplied.
DEFAULT_OLLAMA_MODEL = "llama3.2"
DEFAULT_GROQ_MODEL = "llama-3.3-70b-versatile"
DEFAULT_HF_MODEL = "meta-llama/Llama-3.1-8B-Instruct"

# Credentials and switches read once from the environment at import time.
# HF_TOKEN falls back to the "MultiModalRag_Token" variable, then to "".
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("MultiModalRag_Token", "")
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
USE_HF_LLM = os.getenv("USE_HF_LLM", "").lower() in {"1", "true", "yes"}

# Backend precedence: Groq whenever a key is present, then HF Inference
# (requires both the opt-in flag and a token), otherwise a local Ollama server.
BACKEND = (
    "groq"
    if GROQ_API_KEY
    else ("hf" if USE_HF_LLM and HF_TOKEN else "ollama")
)

OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434")
# System prompt for document-grounded answers: restricts the model to the
# [CONTEXT] section that RAGEngine._build_messages places in the user turn.
SYSTEM_PROMPT = """You are a document assistant. Answer questions using ONLY the [CONTEXT] provided.
Rules:
1. If the context contains information relevant to the question, answer from it — even if only partially relevant.
2. Combine information from multiple context chunks if needed.
3. Only say "I DON'T KNOW" if the context truly contains NO relevant information at all.
4. Be concise. Cite source and page when available.
5. Do NOT make up information that is not in the context.
6. When answering questions about tables or structured data, apply ALL filter conditions from the question. Only include rows that match every condition — do not display or reference rows that do not match.
7. Give ONLY the final answer. Do NOT show reasoning steps, intermediate calculations, excluded rows, or any explanation of how you arrived at the answer.
"""
# System prompt for the general-assistant path, used by
# RAGEngine._build_general_messages when a query is judged off-topic.
GENERAL_PROMPT = """You are a helpful AI assistant. Answer directly and concisely — final answer only, no reasoning steps.
If you don't know the answer, say so honestly.
"""
# Cosine distance threshold (0=identical, 2=opposite).
# A query counts as off-topic only when every retrieved chunk's distance
# is strictly greater than this value (see RAGEngine._is_off_topic).
RELEVANCE_THRESHOLD = 1.2
def _make_hf_client():
    """Return a Hugging Face ``InferenceClient`` authenticated with ``HF_TOKEN``.

    ``huggingface_hub`` is imported inside the function, so the package is
    only needed when this factory is actually called.
    """
    import huggingface_hub

    return huggingface_hub.InferenceClient(token=HF_TOKEN)
def _make_groq_client():
    """Return a ``Groq`` API client authenticated with ``GROQ_API_KEY``.

    ``groq`` is imported inside the function, so the package is only needed
    when this factory is actually called.
    """
    import groq

    return groq.Groq(api_key=GROQ_API_KEY)
def _make_ollama_client():
    """Return an Ollama ``Client`` pointed at ``OLLAMA_HOST``.

    ``ollama`` is imported inside the function, so the package is only needed
    when this factory is actually called.
    """
    from ollama import Client

    return Client(host=OLLAMA_HOST)
def _is_rate_limit(exc: Exception) -> bool:
msg = str(exc).lower()
return "429" in msg or "rate_limit" in msg or "rate limit" in msg
class RAGEngine:
    """Retrieval-augmented question answering on top of a vector store.

    For each question the engine retrieves chunks from the vector store,
    decides whether the question is on-topic, builds the chat messages
    (document-grounded or general), and sends them to the LLM backend selected
    by the module-level ``BACKEND`` ("groq", "hf" or "ollama").
    """

    def __init__(self, vector_store: "VectorStoreManager", model: str = None):
        """Create the engine and the client for the active backend.

        Args:
            vector_store: Store queried for context chunks.
            model: Model name override. It is honoured only by the Ollama
                backend; the HF and Groq backends take their model from the
                ``HF_MODEL`` / ``GROQ_MODEL`` environment variables (or the
                module defaults).
        """
        self.vs = vector_store
        if BACKEND == "hf":
            self.model = os.environ.get("HF_MODEL", DEFAULT_HF_MODEL)
            self._client = _make_hf_client()
            logger.info("LLM backend: HuggingFace Inference (%s)", self.model)
        elif BACKEND == "groq":
            self.model = os.environ.get("GROQ_MODEL", DEFAULT_GROQ_MODEL)
            self._client = _make_groq_client()
            logger.info("LLM backend: Groq (%s)", self.model)
        else:
            self.model = model or os.environ.get("OLLAMA_MODEL", DEFAULT_OLLAMA_MODEL)
            self._client = _make_ollama_client()
            logger.info("LLM backend: Ollama (%s)", self.model)

    def _build_context(self, results: List[Dict[str, Any]]) -> str:
        """Render retrieved chunks as one context string for the prompt.

        Each chunk gets a header ``[Doc N: source, Page P [type]]`` followed by
        its text; the page part is omitted when falsy and the type part when
        it is "text". Chunks are separated by a ``---`` rule.

        Args:
            results: Retrieved chunks; each needs a ``"text"`` entry and may
                carry a ``"metadata"`` dict with ``source``/``page``/``type``.

        Returns:
            The joined context, or a fixed placeholder when ``results`` is empty.
        """
        if not results:
            return "No relevant documents found."
        parts = []
        for i, r in enumerate(results, 1):
            # Tolerate chunks stored without metadata instead of raising.
            meta = r.get("metadata") or {}
            source = meta.get("source", "unknown")
            page = meta.get("page", "")
            doc_type = meta.get("type", "text")
            page_str = f", Page {page}" if page else ""
            type_str = f" [{doc_type}]" if doc_type != "text" else ""
            # The ": " keeps the index and the source name apart; without it
            # the header reads "[Doc 1report.pdf ...]" and cannot be cited.
            parts.append(f"[Doc {i}: {source}{page_str}{type_str}]\n{r['text']}")
        return "\n\n---\n\n".join(parts)

    def _build_messages(self, question: str, context: str, memory: "ConversationMemory"):
        """Build the document-grounded chat messages.

        Returns the system prompt, the prior turns from ``memory`` and a final
        user turn holding the context, the question and a short reminder.
        """
        user_message = (
            f"[CONTEXT]\n{context}\n\n"
            f"[QUESTION]\n{question}\n\n"
            "Remember: Answer ONLY from the context above. If not found, say \"I DON'T KNOW\"."
        )
        return [
            {"role": "system", "content": SYSTEM_PROMPT},
            *memory.get_history_for_prompt(),
            {"role": "user", "content": user_message},
        ]

    def _is_off_topic(self, results: List[Dict[str, Any]]) -> bool:
        """Return True when there are no results or none is close enough.

        A chunk without a ``"distance"`` entry is treated as distance 1.0.
        """
        if not results:
            return True
        return all(r.get("distance", 1.0) > RELEVANCE_THRESHOLD for r in results)

    def _build_general_messages(self, question: str, memory: "ConversationMemory"):
        """Build chat messages for the general (no document context) path."""
        return [
            {"role": "system", "content": GENERAL_PROMPT},
            *memory.get_history_for_prompt(),
            {"role": "user", "content": question},
        ]

    def query(
        self,
        question: str,
        memory: "ConversationMemory",
        n_results: int = 8,
        temperature: float = 0.0,
        stream: bool = False,
        source_filter: list = None,
        pre_fetched_results: list = None,
    ) -> Generator[str, None, None]:
        """Answer ``question`` and yield the response text.

        Args:
            question: The user's question.
            memory: Conversation history; the new turn is appended to it once
                the backend call completes.
            n_results: Number of chunks to retrieve from the vector store.
            temperature: Sampling temperature passed to the backend.
            stream: Yield tokens as they arrive when the backend supports it
                (Groq and Ollama); otherwise a single full answer is yielded.
            source_filter: Optional filter forwarded to the vector store query.
            pre_fetched_results: Chunks to use instead of querying the store.
                An empty list is respected (no retrieval is performed).

        Yields:
            Response tokens or the full answer. Backend failures are not
            raised: they are logged and yielded as an ``"Error: ..."`` string.
        """
        if pre_fetched_results is not None:
            results = pre_fetched_results
        else:
            results = self.vs.query(question, n_results=n_results, source_filter=source_filter)

        if self._is_off_topic(results):
            logger.info("Off-topic query (no relevant chunks): '%s'", question[:60])
            messages = self._build_general_messages(question, memory)
        else:
            context = self._build_context(results)
            messages = self._build_messages(question, context, memory)

        try:
            if BACKEND == "hf":
                yield from self._query_hf(messages, memory, question, temperature, stream)
            elif BACKEND == "groq":
                yield from self._query_groq(messages, memory, question, temperature, stream)
            else:
                yield from self._query_ollama(messages, memory, question, temperature, stream)
        except Exception as e:
            # Top-level boundary: surface the failure to the caller as text.
            error_msg = f"Error: {str(e)}"
            logger.error(error_msg, exc_info=True)
            yield error_msg

    def _hf_chat(self, client, model, messages, memory, question, temperature):
        """Run one non-streaming HF chat completion and record the turn.

        Shared by ``_query_hf`` and ``_hf_fallback`` so both paths use the same
        request parameters. Returns the answer text ("" if the API returned
        no content).
        """
        resp = client.chat_completion(
            model=model,
            messages=messages,
            # Temperature is floored at 0.01; presumably the endpoint rejects
            # an exact 0 — TODO confirm against the HF Inference API.
            temperature=max(temperature, 0.01),
            max_tokens=2048,
        )
        answer = resp.choices[0].message.content or ""
        memory.add("user", question)
        memory.add("assistant", answer)
        return answer

    def _query_hf(self, messages, memory, question, temperature, stream):
        """Query the HuggingFace Inference API (OpenAI-compatible chat completions).

        ``stream`` is accepted for signature parity with the other backends
        but ignored: the full answer is yielded once.
        """
        yield self._hf_chat(self._client, self.model, messages, memory, question, temperature)

    def _hf_fallback(self, messages, memory, question, temperature):
        """Answer through HF Inference when Groq is rate-limited.

        The engine's own client is a Groq client on this path, so a fresh HF
        client is built and the model is read from ``HF_MODEL`` or the default.
        """
        model = os.environ.get("HF_MODEL", DEFAULT_HF_MODEL)
        yield self._hf_chat(_make_hf_client(), model, messages, memory, question, temperature)

    def _query_groq(self, messages, memory, question, temperature, stream):
        """Query Groq, optionally streaming tokens.

        On a rate-limit error the request is retried through HF Inference when
        ``HF_TOKEN`` is set; otherwise a user-facing notice is yielded. Any
        other exception propagates to the caller.
        """
        try:
            if stream:
                tokens = []
                with self._client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                    stream=True,
                ) as stream_resp:
                    for chunk in stream_resp:
                        # Defensive: skip a chunk that carries no choices
                        # rather than failing on choices[0].
                        if not chunk.choices:
                            continue
                        token = chunk.choices[0].delta.content or ""
                        tokens.append(token)
                        yield token
                memory.add("user", question)
                memory.add("assistant", "".join(tokens))
            else:
                resp = self._client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    temperature=temperature,
                )
                answer = resp.choices[0].message.content or ""
                memory.add("user", question)
                memory.add("assistant", answer)
                yield answer
        except Exception as e:
            if not _is_rate_limit(e):
                raise
            if HF_TOKEN:
                logger.warning("Groq rate limit reached — falling back to HF Inference")
                yield from self._hf_fallback(messages, memory, question, temperature)
            else:
                yield "⚠️ Groq daily token limit reached (100K/day free tier). Please try again in a few hours, or upgrade at https://console.groq.com/settings/billing"

    def _query_ollama(self, messages, memory, question, temperature, stream):
        """Query the Ollama server, optionally streaming tokens.

        The turn is written to ``memory`` after the response is complete.
        """
        if stream:
            tokens = []
            stream_resp = self._client.chat(
                model=self.model,
                messages=messages,
                stream=True,
                options={"temperature": temperature},
            )
            for chunk in stream_resp:
                token = chunk["message"]["content"]
                tokens.append(token)
                yield token
            memory.add("user", question)
            memory.add("assistant", "".join(tokens))
        else:
            response = self._client.chat(
                model=self.model,
                messages=messages,
                options={"temperature": temperature},
            )
            answer = response["message"]["content"]
            memory.add("user", question)
            memory.add("assistant", answer)
            yield answer

    def list_available_models(self) -> List[str]:
        """Return the selectable models; currently only the configured one."""
        return [self.model]