Spaces:
Build error
Build error
"""
generation.py — LLM answer generation using the Groq API (free tier).

Retrieved document chunks are packed into a grounded prompt and sent to a
Groq-hosted chat model. If retrieval is judged irrelevant, or no usable context
is available, a canned fallback answer is returned without calling the API.

Groq runs open-weight models such as Llama 3 and Mixtral on its custom LPU
hardware. Free tier (as quoted in this module): 14,400 requests/day, with
responses typically in ~200-400ms.

The model is chosen with the GROQ_MODEL environment variable and defaults to
llama-3.1-8b-instant. The API key is read from GROQ_API_KEY.

Model options (all free at the time of writing):
    llama-3.1-8b-instant     — fastest, good for most Q&A
    llama-3.1-70b-versatile  — best quality, use for complex documents
    mixtral-8x7b-32768       — 32k context, good for long documents
    gemma2-9b-it             — Google model, good structured extraction

NOTE(review): Groq periodically retires models and changes free-tier limits.
Confirm that the IDs and limits above still match Groq's current model and
deprecation pages.
"""
| import logging | |
| from groq import Groq | |
| from src.utils import get_required_env, get_env, count_tokens_estimate, Timer | |
# Module logger; uses a dotted name under the "enterprise-rag" logger hierarchy.
logger = logging.getLogger("enterprise-rag.generation")

# Canned answer that generate_answer returns (with fallback_used=True) in
# three cases: retrieval is flagged irrelevant, there is no context to send,
# or the model replies with empty text.
WEAK_CONTEXT_RESPONSE = (
    "I was unable to find sufficient information in the uploaded document "
    "to answer this question confidently. Please ensure the document contains "
    "relevant content, or try rephrasing your question."
)

# System message sent with every chat completion request. It confines the
# model to the supplied context, to limit hallucinated answers.
# The wording and newlines are runtime prompt text: edits here change model
# behaviour.
SYSTEM_PROMPT = """You are an enterprise document assistant. Answer questions strictly based on the provided document context.
Rules:
1. Answer ONLY using information from the provided context sections.
2. If the context does not contain enough information, say so clearly — do not guess or fabricate.
3. Be concise and precise.
4. Reproduce numbers, dates, and names exactly as they appear in the context.
5. Never use outside knowledge — only the provided context."""
def _describe_generation_error(exc: Exception) -> str:
    """Turn an exception raised during generation into a user-facing answer string.

    The HTTP status code is checked first. Groq SDK API errors expose it as
    ``status_code``; it is read with ``getattr`` so that other exceptions
    (network errors, configuration errors) are handled too. The message
    substring checks used previously are kept as a fallback.

    Args:
        exc: The exception caught in ``generate_answer``.

    Returns:
        A short explanation suitable for showing to the end user.
    """
    error_msg = str(exc)
    lowered = error_msg.lower()
    status = getattr(exc, "status_code", None)

    if status == 429 or "rate_limit" in lowered or "429" in error_msg:
        return (
            "⚠️ Groq API rate limit reached. "
            "Free tier allows 14,400 requests/day and 6,000 tokens/minute. "
            "Please wait a moment and try again."
        )

    # "api_key" also matches configuration errors that mention GROQ_API_KEY
    # (presumably what get_required_env raises when the variable is unset —
    # confirm in src.utils).
    if (
        status == 401
        or "authentication" in lowered
        or "401" in error_msg
        or "api_key" in lowered
    ):
        return (
            "⚠️ Invalid GROQ_API_KEY. "
            "Check your key in HF Space secrets. "
            "Get a free key at console.groq.com → API Keys."
        )

    # Unknown or retired model IDs can surface as a 404 or as an error message
    # containing one of these phrases; the old check for "not found" alone
    # missed them.
    model_missing_markers = (
        "not found",
        "model_not_found",
        "does not exist",
        "decommissioned",
    )
    if status == 404 or (
        "model" in lowered and any(marker in lowered for marker in model_missing_markers)
    ):
        return (
            "⚠️ Model not found. Valid GROQ_MODEL options: "
            "llama-3.1-8b-instant, llama-3.1-70b-versatile, "
            "mixtral-8x7b-32768, gemma2-9b-it"
        )

    return f"⚠️ Generation error: {error_msg}"


def generate_answer(
    query: str,
    context_chunks: list,
    scores: list,
    is_relevant: bool,
    max_new_tokens: int = 512,
) -> dict:
    """Generate an answer grounded in the retrieved chunks, using the Groq API.

    Args:
        query: The user's question.
        context_chunks: Retrieved document chunks (strings). ``None``,
            non-string and whitespace-only entries are ignored.
        scores: Retrieval scores for the chunks. Accepted for interface
            compatibility but not used here.
        is_relevant: Caller's verdict on retrieval quality. When False, no API
            call is made and the fallback answer is returned.
        max_new_tokens: Upper bound on generated tokens (``max_tokens``).

    Returns:
        A dict with these keys:
            answer — generated text, fallback text, or a "⚠️ ..." error message
            prompt_tokens — input token count (Groq usage if reported, else an estimate)
            response_tokens — output token count (Groq usage if reported, else an estimate)
            generation_latency_ms — API call time in milliseconds
            model_used — model identifier (taken from GROQ_MODEL)
            fallback_used — True if the fallback or an error message was returned
            error — error message string, or None on success

    Exceptions are not propagated. Any failure during setup or the API call is
    reported through ``answer`` and ``error``.
    """
    result = {
        "answer": "",
        "prompt_tokens": 0,
        "response_tokens": 0,
        "generation_latency_ms": 0,
        "model_used": "",
        "fallback_used": False,
        "error": None,
    }

    # Remove unusable chunks before the relevance check. Otherwise blank
    # chunks would count as "context" and become empty numbered sections.
    chunks = [
        chunk.strip()
        for chunk in (context_chunks or [])
        if isinstance(chunk, str) and chunk.strip()
    ]

    # Return the fallback immediately if retrieval quality is too low.
    # Never send weak context to the LLM — it will fill gaps with hallucinations.
    if not is_relevant or not chunks:
        result["answer"] = WEAK_CONTEXT_RESPONSE
        result["fallback_used"] = True
        logger.warning("Fallback triggered: low retrieval relevance or empty chunks")
        return result

    try:
        api_key = get_required_env("GROQ_API_KEY")
        model_id = get_env("GROQ_MODEL", "llama-3.1-8b-instant")
        result["model_used"] = model_id
        client = Groq(api_key=api_key)

        # Numbered context sections let the model refer to distinct passages.
        context_block = "\n\n".join(
            f"[Document Section {number}]:\n{chunk}"
            for number, chunk in enumerate(chunks, start=1)
        )
        messages = [
            {
                "role": "system",
                "content": SYSTEM_PROMPT,
            },
            {
                "role": "user",
                "content": (
                    f"Here is the relevant document context:\n\n"
                    f"{context_block}\n\n"
                    f"Question: {query}\n\n"
                    f"Answer based only on the context above:"
                ),
            },
        ]

        # Estimate from the exact text being sent. This is overwritten below
        # by Groq's reported usage when the response includes it.
        result["prompt_tokens"] = count_tokens_estimate(
            "\n".join(message["content"] for message in messages)
        )

        with Timer() as t:
            response = client.chat.completions.create(
                model=model_id,
                messages=messages,
                max_tokens=max_new_tokens,
                temperature=0.1,
                top_p=0.9,
                stream=False,
            )
        result["generation_latency_ms"] = round(t.elapsed_ms, 2)

        # message.content may be None, and choices may be empty. Treat both as
        # an empty answer (fallback) rather than crashing on .strip().
        content = response.choices[0].message.content if response.choices else None
        answer_text = (content or "").strip()

        if not answer_text:
            result["answer"] = WEAK_CONTEXT_RESPONSE
            result["fallback_used"] = True
        else:
            result["answer"] = answer_text
            result["response_tokens"] = count_tokens_estimate(answer_text)

        # Use Groq's actual token counts when available — more accurate than estimates.
        usage = getattr(response, "usage", None)
        if usage:
            if usage.prompt_tokens:
                result["prompt_tokens"] = usage.prompt_tokens
            if usage.completion_tokens:
                result["response_tokens"] = usage.completion_tokens

        logger.info(
            "Generated: %s tokens | %.0fms | model=%s",
            result["response_tokens"],
            t.elapsed_ms,
            model_id,
        )
    except Exception as e:  # API boundary: report every failure in the result dict
        error_msg = str(e)
        logger.error("Groq generation error: %s", error_msg)
        result["answer"] = _describe_generation_error(e)
        result["error"] = error_msg
        result["fallback_used"] = True

    return result