#!/usr/bin/env python3
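"""Build and query a simple vector database over a scraped Atlan knowledge base.

Pages from the scraped knowledge-base JSON are split into overlapping chunks,
embedded with sentence-transformers (with TF-IDF and plain keyword-overlap
fallbacks when those libraries are unavailable), pickled to disk, and searched
by cosine similarity to assemble context and source URLs for a RAG pipeline.
"""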

import json
import numpy as np
from typing import List, Dict, Tuple, Optional
import pickle
from pathlib import Path
import logging
from dataclasses import dataclass
import re

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class Document:
    id: str
    title: str
    content: str
    url: str
    source: str
    embedding: Optional[np.ndarray] = None

class SimpleVectorDB:
    
    def __init__(self, model_name: str = "paraphrase-MiniLM-L3-v2"):
        self.model_name = model_name
        self.model = None
        self.documents: List[Document] = []
        self.embeddings: Optional[np.ndarray] = None
        self.db_file = "atlan_vector_db.pkl"
        
    def _load_embedding_model(self):
        """Load the sentence transformer model"""
        try:
            from sentence_transformers import SentenceTransformer
            self.model = SentenceTransformer(self.model_name)
            logger.info(f"Loaded embedding model: {self.model_name}")
        except ImportError:
            logger.error("sentence-transformers not installed. Using fallback TF-IDF method.")
            self._init_tfidf_fallback()
    
    def _init_tfidf_fallback(self):
        """Fallback to TF-IDF if sentence-transformers is not available"""
        try:
            from sklearn.feature_extraction.text import TfidfVectorizer
            from sklearn.metrics.pairwise import cosine_similarity
            self.tfidf_vectorizer = TfidfVectorizer(
                max_features=1000,
                stop_words='english',
                ngram_range=(1, 2)
            )
            self.use_tfidf = True
            logger.info("Using TF-IDF fallback for embeddings")
        except ImportError:
            logger.error("scikit-learn not available. Using simple text matching.")
            self.use_simple_search = True
    
    def chunk_text(self, text: str, chunk_size: int = 500, overlap: int = 50) -> List[str]:
        """Split text into overlapping chunks for better retrieval"""
        if len(text) <= chunk_size:
            return [text]
        
        chunks = []
        start = 0
        
        while start < len(text):
            end = start + chunk_size
            
            # Try to break at sentence boundary
            if end < len(text):
                # Look for sentence ending within the next 100 chars
                sentence_end = text.rfind('.', end, min(end + 100, len(text)))
                if sentence_end > start:
                    end = sentence_end + 1
            
            chunk = text[start:end].strip()
            if chunk:
                chunks.append(chunk)
            
            start = end - overlap
            
            # Avoid infinite loop
            if start >= len(text):
                break
        
        return chunks
    
    def load_knowledge_base(self, filename: str = "atlan_knowledge_base.json") -> bool:
        """Load knowledge base and create document chunks"""
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                kb_data = json.load(f)
            
            logger.info(f"Loading {len(kb_data)} pages from knowledge base...")
            
            # Process each page and create document chunks
            doc_id = 0
            for page in kb_data:
                title = page.get('title', 'Untitled')
                content = page.get('content', '')
                url = page.get('url', '')
                source = page.get('source', 'unknown')
                
                if not content:
                    continue
                
                # Split content into chunks for better retrieval
                chunks = self.chunk_text(content)
                
                for i, chunk in enumerate(chunks):
                    if len(chunk.strip()) < 100:  # Skip very short chunks
                        continue
                        
                    doc = Document(
                        id=f"{doc_id}_{i}",
                        title=f"{title} (Part {i+1})" if len(chunks) > 1 else title,
                        content=chunk,
                        url=url,
                        source=source
                    )
                    self.documents.append(doc)
                
                doc_id += 1
            
            logger.info(f"Created {len(self.documents)} document chunks")
            return True
            
        except FileNotFoundError:
            logger.error(f"Knowledge base file {filename} not found")
            return False
        except Exception as e:
            logger.error(f"Error loading knowledge base: {str(e)}")
            return False
    
    def create_embeddings(self):
        """Create embeddings for all documents"""
        if not self.documents:
            logger.error("No documents loaded")
            return
        
        if not self.model:
            self._load_embedding_model()
        
        logger.info("Creating embeddings for documents...")
        
        texts = [doc.content for doc in self.documents]
        
        if hasattr(self, 'use_tfidf') and self.use_tfidf:
            # Use TF-IDF fallback
            self.embeddings = self.tfidf_vectorizer.fit_transform(texts)
            logger.info("Created TF-IDF embeddings")
        elif hasattr(self, 'use_simple_search'):
            # Simple keyword matching fallback
            logger.info("Using simple keyword matching")
            return
        else:
            # Use sentence transformers
            embeddings = self.model.encode(texts, show_progress_bar=True)
            self.embeddings = np.array(embeddings)
            
            # Store embeddings in documents
            for i, doc in enumerate(self.documents):
                doc.embedding = embeddings[i]
            
            logger.info(f"Created {self.embeddings.shape[0]} embeddings with dimension {self.embeddings.shape[1]}")
    
    def save_database(self):
        """Save the vector database to disk"""
        db_data = {
            'documents': self.documents,
            'embeddings': self.embeddings,
            'model_name': self.model_name
        }
        # Persist the TF-IDF vectorizer (if the fallback was used) so that
        # searches still work after the database is reloaded
        if hasattr(self, 'tfidf_vectorizer'):
            db_data['tfidf_vectorizer'] = self.tfidf_vectorizer
        
        with open(self.db_file, 'wb') as f:
            pickle.dump(db_data, f)
        
        logger.info(f"Vector database saved to {self.db_file}")
    
    def load_database(self) -> bool:
        """Load the vector database from disk"""
        try:
            with open(self.db_file, 'rb') as f:
                db_data = pickle.load(f)
            
            self.documents = db_data['documents']
            self.embeddings = db_data['embeddings']
            # Restore the TF-IDF vectorizer if it was saved with the database
            if 'tfidf_vectorizer' in db_data:
                self.tfidf_vectorizer = db_data['tfidf_vectorizer']
                self.use_tfidf = True
            # Keep the current model_name (don't overwrite with old saved model)
            # This allows us to use a different model than what was saved
            saved_model = db_data.get('model_name', 'unknown')
            logger.info(f"Loaded vector database with {len(self.documents)} documents (original model: {saved_model}, using: {self.model_name})")
            
            # If the saved model is different from current, regenerate embeddings
            if saved_model != self.model_name:
                logger.warning(f"Model mismatch: saved={saved_model}, current={self.model_name}. Regenerating embeddings with new model.")
                # Force regeneration of embeddings with new model
                self._load_embedding_model()
                self.create_embeddings()
                self.save_database()  # Save with new model
                logger.info(f"Embeddings regenerated and saved with new model: {self.model_name}")
            
            return True
            
        except FileNotFoundError:
            logger.warning(f"Vector database file {self.db_file} not found")
            return False
        except Exception as e:
            logger.error(f"Error loading vector database: {str(e)}")
            return False
    
    def simple_keyword_search(self, query: str, top_k: int = 5) -> List[Tuple[Document, float]]:
        """Fallback keyword-based search"""
        query_words = set(query.lower().split())
        if not query_words:
            return []
        results = []
        
        for doc in self.documents:
            content_words = set(doc.content.lower().split())
            title_words = set(doc.title.lower().split())
            
            # Calculate simple overlap score
            content_overlap = len(query_words.intersection(content_words))
            title_overlap = len(query_words.intersection(title_words)) * 2  # Weight title higher
            
            score = (content_overlap + title_overlap) / len(query_words)
            
            if score > 0:
                results.append((doc, score))
        
        # Sort by score and return top k
        results.sort(key=lambda x: x[1], reverse=True)
        return results[:top_k]
    
    def search(self, query: str, top_k: int = 5, source_filter: Optional[str] = None) -> List[Tuple[Document, float]]:
        """Search for relevant documents"""
        if not self.documents:
            logger.error("No documents in database")
            return []
        
        # Fallback to simple search if no embeddings
        if hasattr(self, 'use_simple_search'):
            return self.simple_keyword_search(query, top_k)
        
        # Load model if not loaded
        if not self.model and not hasattr(self, 'use_tfidf'):
            self._load_embedding_model()
        
        # Create query embedding
        if hasattr(self, 'use_tfidf') and self.use_tfidf:
            query_embedding = self.tfidf_vectorizer.transform([query])
            from sklearn.metrics.pairwise import cosine_similarity
            similarities = cosine_similarity(query_embedding, self.embeddings).flatten()
        else:
            query_embedding = self.model.encode([query])
            # Calculate cosine similarity
            similarities = np.dot(self.embeddings, query_embedding.T).flatten()
            norms = np.linalg.norm(self.embeddings, axis=1) * np.linalg.norm(query_embedding)
            similarities = similarities / norms
        
        # Get top k results
        top_indices = np.argsort(similarities)[::-1][:top_k * 2]  # Get more to filter
        
        results = []
        for idx in top_indices:
            doc = self.documents[idx]
            score = similarities[idx]
            
            # Apply source filter if specified
            if source_filter and doc.source != source_filter:
                continue
                
            results.append((doc, float(score)))
            
            if len(results) >= top_k:
                break
        
        return results
    
    def get_context_for_query(self, query: str, max_chars: int = 3000) -> Tuple[str, List[str]]:
        """Get relevant context for a query with source URLs"""
        
        # Determine source filter based on query content
        source_filter = None
        query_lower = query.lower()
        
        if any(keyword in query_lower for keyword in ['api', 'sdk', 'endpoint', 'programming', 'code']):
            source_filter = 'developer'
        elif any(keyword in query_lower for keyword in ['how to', 'setup', 'configure', 'guide', 'tutorial']):
            source_filter = 'docs'
        
        # Search for relevant documents
        results = self.search(query, top_k=10, source_filter=source_filter)
        
        if not results:
            return "", []
        
        # Combine relevant content
        context_parts = []
        sources = []
        total_chars = 0
        
        for doc, score in results:
            # Only include high-relevance results
            if score < 0.1:  # Threshold for relevance
                continue
                
            content = f"Title: {doc.title}\nContent: {doc.content}\n\n"
            
            if total_chars + len(content) > max_chars:
                # Add partial content if we're near the limit
                remaining_chars = max_chars - total_chars
                if remaining_chars > 200:  # Only if we have reasonable space left
                    context_parts.append(content[:remaining_chars] + "...")
                    if doc.url not in sources:
                        sources.append(doc.url)
                break
            
            context_parts.append(content)
            if doc.url not in sources:
                sources.append(doc.url)
            
            total_chars += len(content)
        
        context = "".join(context_parts)
        return context, sources
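
# A minimal usage sketch for the class above (assuming the knowledge-base
# JSON and/or pickle file exist on disk):
#
#   db = SimpleVectorDB()
#   if not db.load_database():
#       db.load_knowledge_base()
#       db.create_embeddings()
#       db.save_database()
#   context, sources = db.get_context_for_query("How do I set up SSO with SAML?")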

def build_vector_database():
    """Build the vector database from scraped knowledge base"""
    print("๐Ÿ”ง Building Vector Database...")
    print("=" * 40)
    
    # Initialize vector database
    vector_db = SimpleVectorDB()
    
    # Check if database already exists
    if vector_db.load_database():
        print(f"โœ… Loaded existing vector database with {len(vector_db.documents)} documents")
        response = input("Do you want to rebuild? (y/N): ").strip().lower()
        if response != 'y':
            return vector_db
    
    # Load knowledge base
    if not vector_db.load_knowledge_base():
        print("โŒ Failed to load knowledge base. Run scraper first.")
        return None
    
    # Create embeddings
    print("๐Ÿงฎ Creating embeddings...")
    vector_db.create_embeddings()
    
    # Save database
    vector_db.save_database()
    
    print(f"โœ… Vector database built successfully!")
    print(f"๐Ÿ“Š Documents: {len(vector_db.documents)}")
    
    return vector_db

def test_search(vector_db: SimpleVectorDB):
    """Test the search functionality"""
    print("\n๐Ÿ” Testing Search Functionality...")
    print("=" * 40)
    
    test_queries = [
        "How to connect Snowflake to Atlan?",
        "API documentation for creating assets",
        "Data lineage configuration",
        "SSO setup with SAML",
        "Troubleshooting connector issues"
    ]
    
    for query in test_queries:
        print(f"\nQuery: {query}")
        context, sources = vector_db.get_context_for_query(query, max_chars=500)
        print(f"Context length: {len(context)} characters")
        print(f"Sources: {len(sources)}")
        for i, source in enumerate(sources[:3]):
            print(f"  {i+1}. {source}")

if __name__ == "__main__":
    # Build vector database
    vector_db = build_vector_database()
    
    if vector_db:
        # Test search
        test_search(vector_db)
        
        print(f"\n๐ŸŽ‰ Vector database ready for RAG pipeline!")
    else:
        print("โŒ Failed to build vector database")