File size: 6,328 Bytes
939a9f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
"""
LLM Module
----------
Purpose: Query Groq LLM with context for RAG answers
"""
import logging
import os
from typing import List, Optional

from dotenv import load_dotenv
from groq import Groq

logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)

# Candidate .env locations, searched in order; only the first one that exists
# is loaded.
env_paths = [
    # Two directories above this file (presumably the project root).
    os.path.join(os.path.dirname(__file__), '../..', '.env'),
    # Same directory as this file.
    os.path.join(os.path.dirname(__file__), '.env'),
]

for env_path in env_paths:
    if os.path.exists(env_path):
        load_dotenv(env_path)
        # Report through the module logger instead of print(), so importing
        # this module does not write to stdout.
        logger.info("Loaded .env from: %s", env_path)
        break


class GroqLLMClient:
    """
    Client for querying a Groq-hosted LLM with retrieved context (RAG answers).

    Requires a Groq API key, passed explicitly or read from the GROQ_API_KEY
    environment variable.
    Default model: llama-3.1-8b-instant -> check available models using
    client.models.list()
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        model_name: str = "llama-3.1-8b-instant",
        max_tokens: int = 1024,
        temperature: float = 0.7,
    ):
        """
        Initialize Groq LLM client
        Args:
            api_key (str | None): Groq API key. If falsy (None or ""), the
                GROQ_API_KEY environment variable is used instead.
            model_name (str): Groq model name
            max_tokens (int): Maximum number of tokens to generate
            temperature (float): Sampling temperature; higher values give
                more creative output.
        Raises:
            ValueError: If no key is passed and GROQ_API_KEY is unset or empty.
        """
        self.api_key = api_key or os.getenv("GROQ_API_KEY")

        if not self.api_key:
            raise ValueError("GROQ_API_KEY not found in environment variables")

        self.client = Groq(api_key=self.api_key)
        self.model_name = model_name
        self.max_tokens = max_tokens
        self.temperature = temperature

        logger.info("Groq LLM client initialized with model: %s", self.model_name)

    def _build_prompt(
        self,
        context: str,
        question: str,
    ) -> str:
        """
        Build the final prompt for LLM
        Args:
            context (str): Retrieved chunks
            question (str): Question to ask
        Returns:
            str: Prompt for LLM
        """
        # Assembled from explicit lines so that no source-code indentation
        # leaks into the text sent to the model (an indented triple-quoted
        # literal would embed the leading spaces on every line).
        return (
            "You are a helpful assistant. Answer the question based ONLY on the provided context.\n"
            "If the context doesn't contain enough information to answer, say so explicitly.\n"
            "Do not make up information.\n"
            "\n"
            f"Context: {context}\n"
            "\n"
            f"Question: {question}\n"
            "\n"
            "Answer:"
        )

    def query(
        self,
        context: str,
        query: str,
    ) -> str:
        """
        Query the Groq LLM with context
        Args:
            context (str): Retrieved context from vector store
            query: User's question

        Returns:
            LLM's answer as string

        Raises:
            RuntimeError: If building the request or the Groq API call fails.
                The underlying exception is chained as ``__cause__``.
        """
        try:
            prompt = self._build_prompt(context, query)
            logger.debug("Querying Groq with %d chars context", len(context))

            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "user", "content": prompt}
                ],
                max_tokens=self.max_tokens,
                temperature=self.temperature,
            )
            # NOTE(review): the message content may be None for some
            # responses (e.g. no text produced) — confirm against the Groq API.
            answer = response.choices[0].message.content
            logger.debug("Groq API response: %s", answer)
            return answer
        except Exception as e:
            logger.error("Groq query failed: %s", e)
            # Chain the original error so its traceback is not lost.
            raise RuntimeError(f"LLM query failed: {e}") from e

    def query_with_sources(
        self,
        context: str,
        query: str,
        sources: Optional[List[str]] = None
    ) -> dict:
        """
        Query LLM and return answer with source attribution.

        Args:
            context: Retrieved context
            query: User's question
            sources: Optional list of source identifiers (chunk IDs, URLs, etc.)

        Returns:
            Dict with 'answer' and 'sources' keys; 'sources' is [] when no
            sources were given.

        Raises:
            RuntimeError: Propagated from query() if the LLM call fails.

        Example:
            >>> result = client.query_with_sources(
            ...     context="...",
            ...     query="What is ML?",
            ...     sources=["doc1_chunk_0", "doc1_chunk_2"]
            ... )
            >>> print(result["answer"])
            >>> print(result["sources"])
        """
        answer = self.query(context, query)

        return {
            "answer": answer,
            "sources": sources or []
        }

def build_context_string(
    retrieved_results: List,
    include_scores: bool = True
) -> str:
    """
    Render retrieved chunks as one context string for the LLM prompt.

    Each result becomes a "[Chunk N]" header (1-based N), optionally
    extended with its relevance as a percentage, followed by the chunk text
    on the next line. Chunks are separated by a blank line.

    Args:
        retrieved_results: Results exposing `.text` and, when scores are
            included, `.similarity`
        include_scores: Append "- Relevance: XX.X%" to each header
    Returns:
        The joined context string ("" for no results)
    """
    def _header(position, item):
        # `.similarity` is only read when scores are requested.
        if not include_scores:
            return f"[Chunk {position}]"
        return f"[Chunk {position} - Relevance: {item.similarity:.1%}]"

    return "\n\n".join(
        f"{_header(position, item)}\n{item.text}"
        for position, item in enumerate(retrieved_results, start=1)
    )
            
# ============ TESTS ============

def test_build_context_string():
    """Context output contains every chunk's text and its formatted score."""
    from .vector_store import RetrievalResult

    fixtures = [("chunk1", "Text 1", 0.95), ("chunk2", "Text 2", 0.87)]
    results = [RetrievalResult(*row) for row in fixtures]

    rendered = build_context_string(results)

    for expected in ("Text 1", "Text 2", "95.0%"):
        assert expected in rendered


if __name__ == "__main__":
    # Manual smoke test against the live Groq API (needs GROQ_API_KEY).
    # NOTE(review): the relative import below only resolves when this module
    # is run as part of its package (e.g. `python -m <package>.<module>`), not
    # as a plain script — confirm the intended invocation.
    import sys

    PREVIEW_CHARS = 200

    try:
        # Test Groq client
        client = GroqLLMClient(api_key=os.getenv("GROQ_API_KEY"))

        # Test context string
        from .vector_store import RetrievalResult

        results = [
            RetrievalResult("chunk1", "Machine learning is AI", 0.95),
            RetrievalResult("chunk2", "Deep learning uses neural networks", 0.87),
        ]

        context = build_context_string(results)

        # Query
        answer = client.query(
            context=context,
            query="What is machine learning?",
        )

        # Guard against an empty/None completion before slicing, and only add
        # an ellipsis when the answer was actually truncated.
        answer = answer or ""
        preview = answer[:PREVIEW_CHARS]
        if len(answer) > PREVIEW_CHARS:
            preview += "..."

        print("✓ Groq query successful!")
        print(f"Answer: {preview}")

    except Exception as e:
        print(f"✗ Error: {e}", file=sys.stderr)
        # Non-zero exit status so callers/CI can detect the failure.
        sys.exit(1)