| """ | |
| LLM Client using LiteLLM for Groq and other providers | |
| """ | |
| from typing import Optional, Dict, Any, List | |
| import os | |
| from dotenv import load_dotenv | |
| from litellm import completion | |
| load_dotenv() | |
class LLMClient:
    """LLM client backed by LiteLLM."""

    def __init__(
        self,
        model: str = "groq/llama-3.1-8b-instant",
        api_key: Optional[str] = None,
        temperature: float = 0.1,
    ):
        """
        Initialize the LLM client.

        Args:
            model: Model identifier (e.g., "groq/llama-3.1-8b-instant").
            api_key: API key (if None, uses the GROQ_API_KEY env var).
            temperature: Sampling temperature.
        """
        self.model = model
        self.temperature = temperature

        # Prefer an explicitly passed key; otherwise require one in the environment.
        if api_key:
            os.environ["GROQ_API_KEY"] = api_key
        elif not os.getenv("GROQ_API_KEY"):
            raise ValueError(
                "GROQ_API_KEY not found. Set it as an environment variable "
                "or pass it via the api_key parameter. Get a free key from "
                "https://console.groq.com/"
            )
    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        system_prompt: Optional[str] = None,
    ) -> str:
        """
        Generate text with the configured LLM.

        Args:
            prompt: User prompt.
            max_tokens: Maximum number of tokens to generate.
            system_prompt: Optional system prompt.

        Returns:
            The generated text.
        """
        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": prompt})

        try:
            response = completion(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=max_tokens,
            )
            return response.choices[0].message.content
        except Exception as e:
            raise RuntimeError(f"Error calling LLM: {e}") from e
    def answer_question(
        self,
        question: str,
        context_chunks: List[str],
        use_citations: bool = True,
    ) -> Dict[str, Any]:
        """
        Answer a question using RAG context.

        Args:
            question: User question.
            context_chunks: List of relevant context chunks.
            use_citations: Whether to ask for and extract citations.

        Returns:
            Dict with 'answer' and, when use_citations is True,
            'citations' and 'cited_chunks'.
        """
        # Number each chunk so the model can cite it as [1], [2], ...
        context = "\n\n".join(
            f"[{i + 1}] {chunk}" for i, chunk in enumerate(context_chunks)
        )

        # Build the prompt.
        system_prompt = (
            "You are a helpful assistant that answers questions based on the "
            "provided context. Use only the information from the context to "
            "answer. If the context doesn't contain enough information, say so. "
        )
        if use_citations:
            system_prompt += (
                "When referencing information from the context, cite the source "
                "using the format [1], [2], etc., corresponding to the chunk "
                "numbers."
            )

        user_prompt = f"""Context:
{context}
Question: {question}
Answer:"""

        answer = self.generate(user_prompt, system_prompt=system_prompt)
| result = {"answer": answer} | |
| if use_citations: | |
| # Extract citation numbers from answer | |
| import re | |
| citations = list(set(re.findall(r'\[(\d+)\]', answer))) | |
| citations = [int(c) for c in citations if int(c) <= len(context_chunks)] | |
| result["citations"] = sorted(citations) | |
| result["cited_chunks"] = [context_chunks[i-1] for i in citations if 1 <= i <= len(context_chunks)] | |
| return result | |
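

# Minimal usage sketch (illustrative, not part of the original module). It
# assumes GROQ_API_KEY is available via the environment or a .env file and
# that the default groq/llama-3.1-8b-instant model is reachable; the question
# and context chunks below are made-up examples.
if __name__ == "__main__":
    client = LLMClient()
    chunks = [
        "The Eiffel Tower is located in Paris, France.",
        "Construction of the tower was completed in 1889 for the World's Fair.",
    ]
    result = client.answer_question("When was the Eiffel Tower completed?", chunks)
    print(result["answer"])
    if result.get("citations"):
        print("Citations:", result["citations"])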