# aankitdas — initial clean commit (939a9f4)
"""
LLM Module
----------
Purpose: Query Groq LLM with context for RAG answers
"""
from groq import Groq
from typing import List
import os
import logging
logging.basicConfig(level=logging.INFO)
from dotenv import load_dotenv
# Candidate locations for the .env file, checked in priority order.
env_paths = [
    os.path.join(os.path.dirname(__file__), '../..', '.env'),  # Project root
    os.path.join(os.path.dirname(__file__), '.env'),           # Script directory
]
for env_path in env_paths:
    if os.path.exists(env_path):
        load_dotenv(env_path)
        # Route through logging (configured above) instead of print, so the
        # message obeys the application's log level and handlers.
        logging.getLogger(__name__).info("Loaded .env from: %s", env_path)
        break
logger = logging.getLogger(__name__)
class GroqLLMClient:
    """
    Client for querying a Groq-hosted LLM with retrieved context (RAG answers).

    Requires a Groq API key, supplied directly or via the GROQ_API_KEY
    environment variable.
    Model: llama-3.1-8b-instant -> check available models using
    client.models.list()
    """

    def __init__(
        self,
        api_key: str = None,
        model_name: str = "llama-3.1-8b-instant",
        max_tokens: int = 1024,
        temperature: float = 0.7,
    ):
        """
        Initialize the Groq LLM client.

        Args:
            api_key (str): Groq API key; falls back to the GROQ_API_KEY
                environment variable when omitted or falsy
            model_name (str): Groq model name
            max_tokens (int): Maximum number of tokens to generate
            temperature (float): 0-1; higher values produce more creative
                (less deterministic) output

        Raises:
            ValueError: If no API key is provided and GROQ_API_KEY is unset
        """
        # Explicit argument wins; otherwise fall back to the environment.
        self.api_key = api_key or os.getenv("GROQ_API_KEY")
        if not self.api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")
        self.client = Groq(api_key=self.api_key)
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.temperature = temperature
        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info("Groq LLM client initialized with model: %s", self.model_name)

    def _build_prompt(
        self,
        context: str,
        question: str,
    ) -> str:
        """
        Build the final prompt for the LLM.

        Args:
            context (str): Retrieved chunks, already joined into one string
            question (str): Question to ask

        Returns:
            str: Prompt instructing the model to answer from context only
        """
        prompt = f"""You are a helpful assistant. Answer the question based ONLY on the provided context.
If the context doesn't contain enough information to answer, say so explicitly.
Do not make up information.
Context: {context}
Question: {question}
Answer:"""
        return prompt

    def query(
        self,
        context: str,
        query: str,
    ) -> str:
        """
        Query the Groq LLM with context.

        Args:
            context (str): Retrieved context from vector store
            query (str): User's question

        Returns:
            str: LLM's answer

        Raises:
            RuntimeError: If the Groq API call fails (original exception is
                chained as __cause__)
        """
        try:
            prompt = self._build_prompt(context, query)
            logger.debug("Querying Groq with %d chars context", len(context))
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "user", "content": prompt}
                ],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
            answer = response.choices[0].message.content
            logger.debug("Groq API response: %s", answer)
            return answer
        except Exception as e:
            logger.error("Groq query failed: %s", e)
            # Chain the original exception so the root cause is preserved.
            raise RuntimeError(f"LLM query failed: {e}") from e

    def query_with_sources(
        self,
        context: str,
        query: str,
        sources: List[str] = None
    ) -> dict:
        """
        Query the LLM and return the answer with source attribution.

        Args:
            context: Retrieved context
            query: User's question
            sources: Optional list of source identifiers (chunk IDs, URLs, etc.)

        Returns:
            Dict with 'answer' and 'sources' keys

        Example:
            >>> result = client.query_with_sources(
            ...     context="...",
            ...     query="What is ML?",
            ...     sources=["doc1_chunk_0", "doc1_chunk_2"]
            ... )
            >>> print(result["answer"])
            >>> print(result["sources"])
        """
        answer = self.query(context, query)
        return {
            "answer": answer,
            "sources": sources or []
        }
def build_context_string(
    retrieved_results: List,
    include_scores: bool = True
) -> str:
    """
    Join retrieved chunks into a single prompt-ready context string.

    Args:
        retrieved_results: Retrieval results exposing .text and .similarity
        include_scores: Whether to prepend each chunk's relevance percentage

    Returns:
        The chunks separated by blank lines, each under a [Chunk N] header
    """
    if include_scores:
        labeled = [
            f"[Chunk {idx} - Relevance: {item.similarity:.1%}]\n{item.text}"
            for idx, item in enumerate(retrieved_results, 1)
        ]
    else:
        labeled = [
            f"[Chunk {idx}]\n{item.text}"
            for idx, item in enumerate(retrieved_results, 1)
        ]
    return "\n\n".join(labeled)
# ============ TESTS ============
def test_build_context_string():
    """Context string must contain each chunk's text and formatted score."""
    from .vector_store import RetrievalResult

    chunks = [
        RetrievalResult("chunk1", "Text 1", 0.95),
        RetrievalResult("chunk2", "Text 2", 0.87),
    ]
    rendered = build_context_string(chunks)
    for expected in ("Text 1", "Text 2", "95.0%"):
        assert expected in rendered
if __name__ == "__main__":
    # Manual smoke test: build a context from fake retrieval results and run
    # one real query against the Groq API (requires GROQ_API_KEY).
    try:
        client = GroqLLMClient(api_key=os.getenv("GROQ_API_KEY"))

        from .vector_store import RetrievalResult
        sample_results = [
            RetrievalResult("chunk1", "Machine learning is AI", 0.95),
            RetrievalResult("chunk2", "Deep learning uses neural networks", 0.87)
        ]
        combined_context = build_context_string(sample_results)

        answer = client.query(
            context=combined_context,
            query="What is machine learning?"
        )
        print("✓ Groq query successful!")
        print(f"Answer: {answer[:200]}...")
    except Exception as e:
        print(f"✗ Error: {e}")