"""
Multi-Source Retrieval Agent
Classifies each query, then searches vector and BM25 sources and merges the ranked results
"""
import os
from dotenv import load_dotenv
from groq import Groq
from hybrid_search import HybridSearch
from sentence_transformers import SentenceTransformer

load_dotenv()

class RetrievalAgent:
    def __init__(self, chromadb_collection, groq_api_key=None):
        """Initialize Retrieval Agent"""
        print("πŸ” Initializing Multi-Source Retrieval Agent...\n")
        
        self.groq_client = Groq(api_key=groq_api_key)
        self.model_name = "llama-3.3-70b-versatile"
        self.collection = chromadb_collection
        self.embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
        
        # Initialize retrieval sources
        all_docs = self._get_all_documents()
        self.hybrid_search = HybridSearch(all_docs)
        
        self.classification_prompt = """Analyze this query and classify it:

QUERY: "{query}"

Determine:
1. Query Type: factual, conceptual, procedural, comparative
2. Information Need: general knowledge, specific details, step-by-step guide, comparison
3. Search Strategy: broad (many results), narrow (specific results), mixed

Respond in this format ONLY:
TYPE: [type]
NEED: [need]
STRATEGY: [strategy]"""
        
        print("βœ… Retrieval Agent ready!\n")
    
    def _get_all_documents(self):
        """Get all documents from the ChromaDB collection"""
        try:
            results = self.collection.get()
            return results.get('documents') or []
        except Exception as e:
            print(f"❌ Could not load documents from ChromaDB: {e}")
            return []
    
    def classify_query(self, query):
        """Use LLM to classify query for optimal retrieval strategy"""
        print(f"πŸ“Š Classifying query: '{query}'")
        
        try:
            response = self.groq_client.chat.completions.create(
                messages=[
                    {
                        "role": "user",
                        "content": self.classification_prompt.format(query=query)
                    }
                ],
                model=self.model_name,
                temperature=0.3,
                max_tokens=100
            )
            
            classification = response.choices[0].message.content.strip()
            print(f"βœ… Classification:\n{classification}\n")
            
            return classification
        
        except Exception as e:
            print(f"❌ Classification error: {e}\n")
            return "TYPE: mixed\nNEED: general\nSTRATEGY: mixed"
    
    def vector_search(self, query, top_k=5):
        """Search using vector embeddings (semantic similarity)"""
        print(f"  πŸ“Œ Performing vector search...")
        
        try:
            query_embedding = self.embedding_model.encode([query])[0]
            
            results = self.collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=top_k
            )
            
            vector_results = []
            if results and results['documents']:
                for i, doc in enumerate(results['documents'][0]):
                    vector_results.append({
                        'index': i,
                        'content': doc,
                        'source': results['metadatas'][0][i].get('source_file', 'unknown'),
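                        # Note: 1 - distance is a true cosine similarity only
                        # if the collection was built with the cosine space;
                        # ChromaDB defaults to L2, where it is a rough proxy.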
                        'score': 1 - results['distances'][0][i],
                        'method': 'vector_search'
                    })
            
            print(f"     βœ“ Found {len(vector_results)} results via vector search")
            return vector_results
        
        except Exception as e:
            print(f"     βœ— Vector search error: {e}")
            return []
    
    def bm25_search(self, query, top_k=5):
        """Search using BM25 (keyword matching)"""
        print(f"  πŸ“Œ Performing BM25 search...")
        
        try:
            bm25_results = self.hybrid_search.bm25_search(query, top_k)
            
            # Map each document's first 50 chars to its source file so BM25
            # hits (which carry no metadata) can be attributed to a source
            all_results = self.collection.get()
            doc_to_source = {}
            if all_results and all_results.get('metadatas'):
                for doc, metadata in zip(all_results['documents'],
                                         all_results['metadatas']):
                    doc_to_source[doc[:50]] = metadata.get('source_file', 'unknown')
            
            formatted_results = []
            for result in bm25_results:
                # Squash the raw BM25 score into [0, 1]; BM25 scores are
                # unbounded, so the 100.0 divisor is just a heuristic cap
                normalized_score = min(result['score'] / 100.0, 1.0)
                
                # Look up the source via the 50-char prefix key built above
                source = doc_to_source.get(result['content'][:50], 'unknown')
                
                formatted_results.append({
                    'index': result['index'],
                    'content': result['content'],
                    'source': source,
                    'score': normalized_score,
                    'method': 'bm25_search'
                })
            
            print(f"     βœ“ Found {len(formatted_results)} results via BM25")
            return formatted_results
        
        except Exception as e:
            print(f"     βœ— BM25 search error: {e}")
            return []
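    
    # The /100 normalization above is a crude heuristic: raw BM25 scores are
    # unbounded, so a batch-relative min-max rescale (sketched below) is one
    # common alternative. Illustrative only; nothing calls it yet.
    @staticmethod
    def _minmax_normalize(scores):
        """Rescale raw scores into [0, 1] relative to the current batch"""
        if not scores:
            return []
        lo, hi = min(scores), max(scores)
        if hi == lo:
            return [1.0 for _ in scores]
        return [(s - lo) / (hi - lo) for s in scores]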
    
    def retrieve(self, query, top_k=5):
        """
        Main retrieval method: intelligently combines multiple sources
        """
        print(f"\nπŸ” RETRIEVING FOR QUERY: '{query}'")
        print("-" * 70)
        
        # Step 1: Classify query
        classification = self.classify_query(query)
        
        # Step 2: Decide which sources to use. Both are always queried for
        # now; the classification above could gate or weight them later.
        use_vector = True
        use_bm25 = True
        
        all_results = []
        
        print("πŸ”Ž Searching sources:")
        
        # Step 3: Search vector database
        if use_vector:
            vector_results = self.vector_search(query, top_k)
            all_results.extend(vector_results)
        
        # Step 4: Search BM25
        if use_bm25:
            bm25_results = self.bm25_search(query, top_k)
            all_results.extend(bm25_results)
        
        # Step 5: Deduplicate and rank
        seen = set()
        unique_results = []
        
        for result in all_results:
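            # Key on the first 100 chars: a chunk returned by both searches
            # carries identical text, so a prefix hash is enough to spot it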
            content_hash = hash(result['content'][:100])
            if content_hash not in seen:
                seen.add(content_hash)
                unique_results.append(result)
        
        # Sort by score (descending)
        unique_results.sort(key=lambda x: x['score'], reverse=True)
        final_results = unique_results[:top_k]
        
        print(f"\nβœ… Retrieved {len(final_results)} unique documents")
        print("-" * 70 + "\n")
        
        return final_results
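    
    # Sorting the merged list by raw score assumes vector and BM25 scores are
    # comparable after normalization, which is only roughly true. Reciprocal
    # Rank Fusion merges by rank instead; this sketch is an illustrative
    # alternative that retrieve() does not currently call.
    @staticmethod
    def _rrf_merge(result_lists, k=60, top_k=5):
        """Merge ranked result lists with RRF: score = sum of 1 / (k + rank)"""
        fused = {}
        for results in result_lists:
            for rank, result in enumerate(results, start=1):
                key = result['content'][:100]
                entry = fused.setdefault(key, {'result': result, 'rrf': 0.0})
                entry['rrf'] += 1.0 / (k + rank)
        ranked = sorted(fused.values(), key=lambda e: e['rrf'], reverse=True)
        return [e['result'] for e in ranked[:top_k]]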


# Test the agent
if __name__ == "__main__":
    import chromadb
    
    # load_dotenv() already ran at module import, so just read the key
    api_key = os.getenv("GROQ_API_KEY")
    
    # Connect to ChromaDB
    client = chromadb.PersistentClient(path="data/vectordb")
    collection = client.get_collection(name="documents")
    
    # Initialize agent
    agent = RetrievalAgent(collection, groq_api_key=api_key)
    
    # Test queries
    test_queries = [
        "How do I create a FastAPI endpoint?",
        "What is the leave policy?",
        "Remote work guidelines"
    ]
    
    print("=" * 70)
    print("πŸ” MULTI-SOURCE RETRIEVAL AGENT TEST")
    print("=" * 70)
    
    for query in test_queries:
        results = agent.retrieve(query, top_k=3)
        print(f"Results for '{query}':")
        for i, result in enumerate(results, 1):
            print(f"  {i}. [{result['method']}] Score: {result['score']:.2f}")
        print()