# llm-project/rag_system/llm_client.py
# Initial commit: RAG Q&A system for agricultural research (commit 779b4bd)
"""
LLM Client using LiteLLM for Groq and other providers
"""
import os
import re
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv
from litellm import completion

load_dotenv()
class LLMClient:
    """LLM client using LiteLLM.

    Wraps ``litellm.completion`` for chat-style generation and provides a
    simple RAG question-answering helper with optional numeric citations.
    """

    def __init__(
        self,
        model: str = "groq/llama-3.1-8b-instant",
        api_key: Optional[str] = None,
        temperature: float = 0.1
    ):
        """
        Initialize LLM client.

        Args:
            model: Model identifier (e.g., "groq/llama-3.1-8b-instant")
            api_key: API key (if None, uses GROQ_API_KEY env var)
            temperature: Sampling temperature

        Raises:
            ValueError: If no api_key is given and GROQ_API_KEY is unset/empty.
        """
        self.model = model
        self.temperature = temperature
        if api_key:
            # litellm reads credentials from the environment, so export it.
            os.environ["GROQ_API_KEY"] = api_key
        elif not os.environ.get("GROQ_API_KEY"):
            # No key passed and none (or an empty one) in the environment.
            raise ValueError(
                "GROQ_API_KEY not found. Please set it as environment variable "
                "or pass as api_key parameter. Get free key from https://console.groq.com/"
            )

    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        system_prompt: Optional[str] = None
    ) -> str:
        """
        Generate text using the LLM.

        Args:
            prompt: User prompt
            max_tokens: Maximum tokens to generate
            system_prompt: Optional system prompt

        Returns:
            Generated text

        Raises:
            RuntimeError: If the underlying LLM call fails (original
                exception is chained as the cause).
        """
        messages: List[Dict[str, str]] = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})
        try:
            response = completion(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=max_tokens
            )
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # in the traceback instead of being flattened into a string.
            raise RuntimeError(f"Error calling LLM: {str(e)}") from e
        return response.choices[0].message.content

    def answer_question(
        self,
        question: str,
        context_chunks: List[str],
        use_citations: bool = True
    ) -> Dict[str, Any]:
        """
        Answer a question using RAG context.

        Args:
            question: User question
            context_chunks: List of relevant context chunks
            use_citations: Whether to ask for and extract [n] citations

        Returns:
            Dict with 'answer' and, when use_citations is True, sorted
            'citations' (1-based chunk numbers) plus the matching
            'cited_chunks' in the same order.
        """
        # Number each chunk [1], [2], ... so the model can cite them.
        context = "\n\n".join(
            f"[{i}] {chunk}" for i, chunk in enumerate(context_chunks, start=1)
        )

        system_prompt = (
            "You are a helpful assistant that answers questions based on the provided context. "
            "Use only the information from the context to answer. "
            "If the context doesn't contain enough information, say so. "
        )
        if use_citations:
            system_prompt += (
                "When referencing information from the context, cite the source using "
                "the format [1], [2], etc. corresponding to the chunk numbers."
            )

        user_prompt = f"""Context:
{context}
Question: {question}
Answer:"""
        answer = self.generate(user_prompt, system_prompt=system_prompt)

        result: Dict[str, Any] = {"answer": answer}
        if use_citations:
            # Collect unique in-range citation numbers (1-based); sorting
            # once keeps 'citations' and 'cited_chunks' aligned index-for-index.
            cited = sorted({
                int(m)
                for m in re.findall(r'\[(\d+)\]', answer)
                if 1 <= int(m) <= len(context_chunks)
            })
            result["citations"] = cited
            result["cited_chunks"] = [context_chunks[i - 1] for i in cited]
        return result