enterprise-rag-system / src /generation.py
Faraz618's picture
Update src/generation.py
d1e9091 verified
Raw
History Blame Contribute Delete
6.08 kB
"""
generation.py — LLM answer generation using Groq API (free tier).
Groq runs Llama 3 and Mixtral on custom LPU hardware.
Free tier: 14,400 requests/day, responses in ~200-400ms.
Model options (all free):
llama-3.1-8b-instant — fastest, good for most Q&A
llama-3.1-70b-versatile — best quality, use for complex documents
mixtral-8x7b-32768 — 32k context, good for long documents
gemma2-9b-it — Google model, good structured extraction
"""
import logging
from groq import Groq
from src.utils import get_required_env, get_env, count_tokens_estimate, Timer
# Named logger for this module.
logger = logging.getLogger("enterprise-rag.generation")

# Canned reply used instead of an LLM answer when retrieval is judged too weak
# (and, in generate_answer, when the model returns an empty answer).
WEAK_CONTEXT_RESPONSE = " ".join(
    (
        "I was unable to find sufficient information in the uploaded document",
        "to answer this question confidently. Please ensure the document contains",
        "relevant content, or try rephrasing your question.",
    )
)

# System message sent with every Groq chat request: restricts the model to
# the supplied document context. One element per prompt line.
SYSTEM_PROMPT = "\n".join(
    (
        "You are an enterprise document assistant. Answer questions strictly "
        "based on the provided document context.",
        "Rules:",
        "1. Answer ONLY using information from the provided context sections.",
        "2. If the context does not contain enough information, say so clearly "
        "— do not guess or fabricate.",
        "3. Be concise and precise.",
        "4. Reproduce numbers, dates, and names exactly as they appear in the context.",
        "5. Never use outside knowledge — only the provided context.",
    )
)
def _clean_chunks(context_chunks) -> list:
    # Strip every chunk and drop None/blank ones, so whitespace-only context
    # is treated the same as no context at all.
    if not context_chunks:
        return []
    stripped = (str(chunk).strip() for chunk in context_chunks if chunk is not None)
    return [text for text in stripped if text]


def _extract_answer_text(response) -> str:
    # Return the first choice's text, stripped; "" when the response carries
    # no choices or the message content is None.
    choices = getattr(response, "choices", None) or []
    if not choices:
        return ""
    content = choices[0].message.content
    return (content or "").strip()


def _friendly_error_message(error_msg: str) -> str:
    # Map an exception message (config lookup or Groq API) to a user-facing hint
    # via substring matching; unknown errors are echoed verbatim.
    lowered = error_msg.lower()
    if "rate_limit" in lowered or "rate limit" in lowered or "429" in error_msg:
        return (
            "⚠️ Groq API rate limit reached. "
            "Free tier allows 14,400 requests/day and 6,000 tokens/minute. "
            "Please wait a moment and try again."
        )
    if "authentication" in lowered or "401" in error_msg or "api_key" in lowered:
        return (
            "⚠️ Invalid GROQ_API_KEY. "
            "Check your key in HF Space secrets. "
            "Get a free key at console.groq.com → API Keys."
        )
    # The API may phrase an unknown model as "does not exist" / "model_not_found"
    # rather than "not found", so accept all three spellings.
    if "model" in lowered and any(
        marker in lowered for marker in ("not found", "not_found", "does not exist")
    ):
        # NOTE(review): Groq retires models over time; confirm this list against
        # the current Groq model catalogue.
        return (
            "⚠️ Model not found. Valid GROQ_MODEL options: "
            "llama-3.1-8b-instant, llama-3.1-70b-versatile, "
            "mixtral-8x7b-32768, gemma2-9b-it"
        )
    return f"⚠️ Generation error: {error_msg}"


def generate_answer(
    query: str,
    context_chunks: list,
    scores: list,
    is_relevant: bool,
    max_new_tokens: int = 512,
) -> dict:
    """
    Generate a grounded answer using Groq API.

    The LLM is only called when retrieval was judged relevant AND at least one
    non-blank context chunk is available; otherwise WEAK_CONTEXT_RESPONSE is
    returned without any API call.

    Args:
        query: the user's question, inserted verbatim into the user message.
        context_chunks: retrieved document chunks. Each is stripped; None and
            blank entries are ignored.
        scores: retrieval scores for the chunks. Not used by this function;
            kept for interface compatibility.
        is_relevant: caller's verdict on retrieval quality; False forces the
            fallback answer.
        max_new_tokens: passed to Groq as ``max_tokens`` for the completion.

    Returns dict:
        answer — generated text, the weak-context fallback, or a "⚠️" error hint
        prompt_tokens — input token count (Groq usage when reported, else estimate)
        response_tokens — output token count (Groq usage when reported, else estimate)
        generation_latency_ms — API call time in milliseconds
        model_used — model identifier
        fallback_used — True if context was too weak, the model returned an
            empty answer, or the call failed
        error — error message or None

    Exceptions from configuration lookup or the API call are not raised; they
    are logged and turned into a user-facing message in ``answer`` with
    ``error`` set.
    """
    result = {
        "answer": "",
        "prompt_tokens": 0,
        "response_tokens": 0,
        "generation_latency_ms": 0,
        "model_used": "",
        "fallback_used": False,
        "error": None,
    }

    chunks = _clean_chunks(context_chunks)

    # Return fallback immediately if retrieval quality is too low.
    # Never send weak context to the LLM — it will fill gaps with hallucinations.
    # A list containing only blank chunks is as weak as an empty list.
    if not is_relevant or not chunks:
        result["answer"] = WEAK_CONTEXT_RESPONSE
        result["fallback_used"] = True
        logger.warning("Fallback triggered: low retrieval relevance or empty chunks")
        return result

    try:
        api_key = get_required_env("GROQ_API_KEY")
        model_id = get_env("GROQ_MODEL", "llama-3.1-8b-instant")
        result["model_used"] = model_id
        client = Groq(api_key=api_key)

        # Numbered context block so the model can refer to individual sections.
        context_block = "\n\n".join(
            f"[Document Section {number}]:\n{chunk}"
            for number, chunk in enumerate(chunks, start=1)
        )
        user_content = (
            f"Here is the relevant document context:\n\n"
            f"{context_block}\n\n"
            f"Question: {query}\n\n"
            f"Answer based only on the context above:"
        )
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]

        # Estimate over exactly what is sent; overwritten by Groq's own usage
        # figures below when the response reports them.
        result["prompt_tokens"] = sum(
            count_tokens_estimate(message["content"]) for message in messages
        )

        with Timer() as t:
            response = client.chat.completions.create(
                model=model_id,
                messages=messages,
                max_tokens=max_new_tokens,
                temperature=0.1,
                top_p=0.9,
                stream=False,
            )
        result["generation_latency_ms"] = round(t.elapsed_ms, 2)

        answer_text = _extract_answer_text(response)
        if answer_text:
            result["answer"] = answer_text
            result["response_tokens"] = count_tokens_estimate(answer_text)
        else:
            result["answer"] = WEAK_CONTEXT_RESPONSE
            result["fallback_used"] = True
            logger.warning("Model returned an empty answer; using fallback response")

        # Prefer Groq's actual token counts when available — more accurate than estimates.
        usage = getattr(response, "usage", None)
        if usage:
            if usage.prompt_tokens:
                result["prompt_tokens"] = usage.prompt_tokens
            if usage.completion_tokens:
                result["response_tokens"] = usage.completion_tokens

        logger.info(
            "Generated: %s tokens | %.0fms | model=%s",
            result["response_tokens"],
            t.elapsed_ms,
            model_id,
        )
    except Exception as e:  # top-level boundary: always hand the caller a result dict
        error_msg = str(e)
        logger.error("Groq generation error: %s", error_msg, exc_info=True)
        result["answer"] = _friendly_error_message(error_msg)
        result["error"] = error_msg
        result["fallback_used"] = True

    return result