Spaces:
Sleeping
Sleeping
| # rag_engine.py | |
| import os | |
| import pickle | |
| import numpy as np | |
| import faiss | |
| import requests | |
| import trafilatura | |
| from bs4 import BeautifulSoup | |
| from sentence_transformers import SentenceTransformer | |
| from flashrank import Ranker, RerankRequest | |
| import logging | |
| import time | |
| import re | |
| # BM25 for keyword-based lexical search | |
| from rank_bm25 import BM25Okapi | |
| # Playwright for SPA/JavaScript-rendered pages | |
| try: | |
| from playwright.sync_api import sync_playwright | |
| PLAYWRIGHT_AVAILABLE = True | |
| except ImportError: | |
| PLAYWRIGHT_AVAILABLE = False | |
| logging.warning("β οΈ Playwright not installed. SPA scraping will be limited.") | |
| # Setup basic logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
def tokenize(text: str) -> list:
    """Lowercase *text* and return every alphanumeric word token (BM25 input)."""
    return [match.group(0) for match in re.finditer(r'\w+', text.lower())]
class KnowledgeBase:
    """
    High-Performance Hybrid RAG Engine.

    Retrieval combines four stages:
      1. FAISS     - dense/semantic vector search,
      2. BM25      - sparse/lexical keyword search,
      3. RRF       - Reciprocal Rank Fusion of both rankings,
      4. FlashRank - cross-encoder reranking of the fused candidates.
    """
    def __init__(self, index_path="faiss_index.bin", metadata_path="metadata.pkl", bm25_path="bm25_corpus.pkl"):
        """Load persisted indexes if present, else start empty; load models eagerly.

        Args:
            index_path: File path of the persisted FAISS index.
            metadata_path: Pickle file holding chunk texts (parallel to FAISS rows).
            bm25_path: Pickle file holding the tokenized corpus used to rebuild BM25.
        """
        self.index_path = index_path
        self.metadata_path = metadata_path
        self.bm25_path = bm25_path
        self.metadata = []  # chunk texts; metadata[i] corresponds to FAISS row i
        self.tokenized_corpus = []  # For BM25: one token list per chunk
        self.bm25 = None  # built lazily; stays None until a corpus exists
        logger.info("π Initializing Hybrid Knowledge Base (FAISS + BM25 + RRF)...")
        # 1. Embedding Model (384 dim, Fast) - Verified MiniLM
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        logger.info(f"β Loaded SentenceTransformer: all-MiniLM-L6-v2 (dim={self.embedder.get_sentence_embedding_dimension()})")
        # 2. Reranker (Lightweight MiniLM for cross-encoder reranking)
        self.ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2", cache_dir="./flashrank_cache")
        # 3. Load/Create Indexes (FAISS + BM25)
        if os.path.exists(self.index_path) and os.path.exists(self.metadata_path):
            try:
                self.index = faiss.read_index(self.index_path)
                with open(self.metadata_path, "rb") as f:
                    self.metadata = pickle.load(f)
                # Load BM25 corpus if exists
                # NOTE(review): pickle.load assumes these files are trusted local
                # artifacts written by ingest_url; never point them at untrusted data.
                if os.path.exists(self.bm25_path):
                    with open(self.bm25_path, "rb") as f:
                        self.tokenized_corpus = pickle.load(f)
                    self.bm25 = BM25Okapi(self.tokenized_corpus) if self.tokenized_corpus else None
                    logger.info(f"β Loaded BM25 Index ({len(self.tokenized_corpus)} docs).")
                logger.info(f"β Loaded FAISS Index ({self.index.ntotal} chunks).")
            except Exception as e:
                # Any corruption / version mismatch falls back to a fresh empty index.
                logger.error(f"β Index load failed: {e}. Resetting.")
                self.create_new_index()
        else:
            self.create_new_index()
| def create_new_index(self): | |
| self.index = faiss.IndexFlatL2(384) | |
| self.metadata = [] | |
| self.tokenized_corpus = [] | |
| self.bm25 = None | |
| logger.info("π Created new empty Hybrid Index (FAISS + BM25).") | |
| def rebuild_bm25(self): | |
| """Rebuild BM25 index from tokenized corpus.""" | |
| if self.tokenized_corpus: | |
| self.bm25 = BM25Okapi(self.tokenized_corpus) | |
| logger.info(f"π Rebuilt BM25 index with {len(self.tokenized_corpus)} documents.") | |
    def fetch_page_content(self, url):
        """
        Robust Fetcher with 3-Stage Fallback:
        1. Trafilatura (fast, for static pages)
        2. Requests (fallback for simple pages)
        3. Playwright (headless browser for SPAs/JS-rendered pages)

        Returns the raw page payload as a string, or None when every stage
        fails or yields 500 characters or fewer.

        NOTE(review): not referenced by any other code visible in this module
        (ingest_url does its own fetching) -- confirm external callers exist.
        """
        # Method A: Trafilatura Fetch (fastest)
        downloaded = trafilatura.fetch_url(url)
        if downloaded and len(downloaded) > 500:  # Must have substantial content
            return downloaded
        # Method B: Requests Fallback (User-Agent Spoofing)
        try:
            logger.info(f"βοΈ Trafilatura insufficient, trying requests...")
            headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
            resp = requests.get(url, timeout=15, headers=headers)
            # 500-char threshold filters out error shells / empty bodies.
            if resp.status_code == 200 and len(resp.text) > 500:
                return resp.text
        except Exception as e:
            logger.warning(f"β οΈ Requests failed for {url}: {e}")
        # Method C: Playwright Headless Browser (for SPAs like React/Vue/Angular)
        if PLAYWRIGHT_AVAILABLE:
            logger.info(f"π Using Playwright for JavaScript-rendered page: {url}")
            try:
                with sync_playwright() as p:
                    browser = p.chromium.launch(headless=True)
                    page = browser.new_page()
                    # "networkidle" waits for network quiescence; 30 s cap.
                    page.goto(url, wait_until="networkidle", timeout=30000)
                    # Wait for React to render
                    page.wait_for_timeout(2000)
                    html_content = page.content()
                    browser.close()
                    if html_content and len(html_content) > 500:
                        logger.info(f"β Playwright successfully rendered page ({len(html_content)} chars)")
                        return html_content
            except Exception as e:
                logger.error(f"β Playwright failed for {url}: {e}")
        else:
            logger.warning("β οΈ Playwright not available. Cannot render SPA.")
        return None
    def ingest_url(self, url: str):
        """
        Extracts text with multi-stage fallback:
        1. Trafilatura fetch + extract
        2. If extraction < 100 chars β Playwright headless browser + extract
        3. BeautifulSoup greedy as final fallback

        Then chunks the text, embeds it into FAISS, extends the BM25 corpus,
        and persists all three artifacts to disk.

        Returns a human-readable status string (success or failure reason).
        """
        logger.info(f"π·οΈ Scraping: {url}")
        text = None
        html_content = None
        # Stage 1: Try Trafilatura (fast, for static pages)
        html_content = trafilatura.fetch_url(url)
        if html_content:
            text = trafilatura.extract(html_content, include_comments=False, include_tables=True, no_fallback=False)
        # Stage 2: If extraction failed, use Playwright for SPA rendering
        if (not text or len(text) < 100) and PLAYWRIGHT_AVAILABLE:
            logger.info(f"π Trafilatura extraction empty. Using Playwright for SPA: {url}")
            try:
                with sync_playwright() as p:
                    browser = p.chromium.launch(headless=True)
                    page = browser.new_page()
                    page.goto(url, wait_until="networkidle", timeout=30000)
                    # Wait for React/Vue/Angular to render
                    page.wait_for_timeout(3000)
                    html_content = page.content()
                    browser.close()
                    logger.info(f"β Playwright rendered {len(html_content)} chars of HTML")
                    # Try extraction again
                    text = trafilatura.extract(html_content, include_comments=False, include_tables=True, no_fallback=False)
            except Exception as e:
                logger.error(f"β Playwright failed: {e}")
        elif not text or len(text) < 100:
            if not PLAYWRIGHT_AVAILABLE:
                logger.warning("β οΈ Playwright not available. SPA content cannot be rendered.")
        # Stage 3: BeautifulSoup greedy fallback
        if (not text or len(text) < 100) and html_content:
            logger.info("β οΈ Extraction still empty. Using Greedy BeautifulSoup.")
            soup = BeautifulSoup(html_content, 'html.parser')
            # Strip non-content elements before dumping visible text.
            for element in soup(['script', 'style', 'noscript', 'svg', 'header', 'footer', 'nav']):
                element.decompose()
            text = soup.get_text(separator='\n\n', strip=True)
        if not text or len(text) < 50:
            return "No readable text found after all extraction methods."
        logger.info(f"π Extracted {len(text)} chars.")
        # 3. Chunking Strategy (User's Logic: Paragraph Split)
        # We split by double newline to preserve paragraph structure.
        raw_chunks = [c.strip() for c in text.split('\n\n') if len(c.strip()) > 50]
        # Additional processing: If a paragraph is HUGE (>1000 chars), split it further
        # NOTE(review): fixed 800-char offsets with no overlap can cut words
        # mid-token -- presumably acceptable for this corpus; confirm.
        final_chunks = []
        for chunk in raw_chunks:
            if len(chunk) > 1000:
                # Simple split for massive blocks
                for i in range(0, len(chunk), 800):
                    sub_chunk = chunk[i:i+800]
                    final_chunks.append(f"Source: {url} | Content: {sub_chunk}")
            else:
                final_chunks.append(f"Source: {url} | Content: {chunk}")
        if not final_chunks:
            return "Text was too short to chunk."
        # 4. Vectorize & Store (FAISS + BM25)
        try:
            # FAISS: Semantic embeddings (L2-normalized so L2 distance tracks cosine)
            embeddings = self.embedder.encode(final_chunks)
            faiss.normalize_L2(embeddings)
            self.index.add(np.array(embeddings).astype('float32'))
            self.metadata.extend(final_chunks)
            # BM25: Tokenized corpus for lexical search
            for chunk in final_chunks:
                self.tokenized_corpus.append(tokenize(chunk))
            self.rebuild_bm25()
            # Save all indexes to disk
            faiss.write_index(self.index, self.index_path)
            with open(self.metadata_path, "wb") as f:
                pickle.dump(self.metadata, f)
            with open(self.bm25_path, "wb") as f:
                pickle.dump(self.tokenized_corpus, f)
            return f"β Ingested {len(final_chunks)} chunks (FAISS + BM25)."
        except Exception as e:
            return f"Error vectorizing: {e}"
| def bm25_search(self, query: str, top_k: int = 15) -> list: | |
| """ | |
| BM25 lexical search for keyword precision. | |
| Returns list of (doc_idx, score) tuples. | |
| """ | |
| if not self.bm25 or not self.tokenized_corpus: | |
| return [] | |
| tokenized_query = tokenize(query) | |
| scores = self.bm25.get_scores(tokenized_query) | |
| # Get top-k indices sorted by score descending | |
| top_indices = np.argsort(scores)[::-1][:top_k] | |
| return [(int(idx), float(scores[idx])) for idx in top_indices if scores[idx] > 0] | |
| def faiss_search(self, query: str, top_k: int = 15) -> list: | |
| """ | |
| FAISS semantic search for meaning-based retrieval. | |
| Returns list of (doc_idx, score) tuples. | |
| """ | |
| if self.index.ntotal == 0: | |
| return [] | |
| query_vec = self.embedder.encode([query]) | |
| faiss.normalize_L2(query_vec) | |
| distances, indices = self.index.search(np.array(query_vec).astype('float32'), top_k) | |
| # Convert L2 distance to similarity score (lower distance = higher score) | |
| results = [] | |
| for i, idx in enumerate(indices[0]): | |
| if idx != -1 and idx < len(self.metadata): | |
| # L2 distance to similarity: 1 / (1 + distance) | |
| score = 1.0 / (1.0 + distances[0][i]) | |
| results.append((int(idx), float(score))) | |
| return results | |
| def rrf_fusion(self, bm25_results: list, faiss_results: list, k: int = 60) -> list: | |
| """ | |
| Reciprocal Rank Fusion (RRF) to combine BM25 and FAISS results. | |
| RRF score = sum(1 / (k + rank)) for each result list. | |
| k=60 is the standard RRF constant. | |
| """ | |
| rrf_scores = {} | |
| # Score from BM25 rankings | |
| for rank, (doc_idx, _) in enumerate(bm25_results): | |
| if doc_idx not in rrf_scores: | |
| rrf_scores[doc_idx] = 0.0 | |
| rrf_scores[doc_idx] += 1.0 / (k + rank + 1) | |
| # Score from FAISS rankings | |
| for rank, (doc_idx, _) in enumerate(faiss_results): | |
| if doc_idx not in rrf_scores: | |
| rrf_scores[doc_idx] = 0.0 | |
| rrf_scores[doc_idx] += 1.0 / (k + rank + 1) | |
| # Sort by RRF score descending | |
| sorted_results = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True) | |
| return sorted_results | |
    def search(self, query: str, top_k: int = 15) -> str:
        """
        Hybrid Search Pipeline:
        1. BM25 for keyword/lexical matching
        2. FAISS for semantic/meaning matching
        3. RRF fusion to combine rankings
        4. FlashRank cross-encoder reranking for final precision

        Returns up to the top 3 reranked chunks joined by blank lines, each
        prefixed with "[Hybrid RAG] "; empty string when nothing is indexed
        or nothing matches.
        """
        if self.index.ntotal == 0:
            return ""
        # Stage 1: Parallel retrieval from both indexes
        bm25_results = self.bm25_search(query, top_k)
        faiss_results = self.faiss_search(query, top_k)
        logger.info(f"π BM25: {len(bm25_results)} hits | FAISS: {len(faiss_results)} hits")
        # Stage 2: RRF Fusion
        if bm25_results or faiss_results:
            fused_results = self.rrf_fusion(bm25_results, faiss_results)
        else:
            return ""
        # Stage 3: Prepare candidates for reranking
        candidates = []
        for doc_idx, rrf_score in fused_results[:top_k]:
            # Guard against indexes that outlived the metadata list.
            if doc_idx < len(self.metadata):
                candidates.append({"id": int(doc_idx), "text": self.metadata[doc_idx], "rrf_score": rrf_score})
        if not candidates:
            return ""
        logger.info(f"π RRF Fusion: {len(candidates)} unique candidates")
        # Stage 4: Cross-encoder reranking (FlashRank)
        rerank_request = RerankRequest(query=query, passages=candidates)
        results = self.ranker.rerank(rerank_request)
        # Return Top 3 re-ranked chunks (balanced context)
        top_results = results[:3]
        # DETAILED LOGGING: Show actual results being sent to model
        logger.info(f"β Reranked: Returning top {len(top_results)} results")
        logger.info("=" * 60)
        logger.info("π TOP RANKED RESULTS BEING SENT TO MODEL CONTEXT:")
        logger.info("=" * 60)
        for i, result in enumerate(top_results, 1):
            # NOTE(review): assumes ranker.rerank returns dict-like items with
            # 'text' and 'score' keys -- confirm against flashrank's API.
            text = result.get('text', '')
            score = result.get('score', 0)
            # Truncate for logging (first 200 chars)
            preview = text[:200].replace('\n', ' ').strip()
            if len(text) > 200:
                preview += "..."
            logger.info(f" [{i}] Score: {score:.4f} | Preview: {preview}")
        logger.info("=" * 60)
        return "\n\n".join([f"[Hybrid RAG] {r['text']}" for r in top_results])
# Module-level singleton: importing this module eagerly loads the embedding
# model, the reranker, and any persisted indexes (heavyweight import side effect).
local_kb = KnowledgeBase()
# ==========================================
# π οΈ INGESTION ZONE (RUN THIS TO BUILD DB)
# ==========================================
if __name__ == "__main__":
    kb = local_kb
    # Site pages to ingest into the knowledge base (verified URLs).
    seed_urls = (
        "https://www.azoneinstituteoftechnology.co.in/",
        "https://www.azoneinstituteoftechnology.co.in/courses",
        "https://www.azoneinstituteoftechnology.co.in/about",
        "https://www.azoneinstituteoftechnology.co.in/services",
        "https://www.azoneinstituteoftechnology.co.in/career",
        "https://www.azoneinstituteoftechnology.co.in/contact",
        "https://www.azoneinstituteoftechnology.co.in/contact#",
    )
    print("\nπ Starting ROBUST Knowledge Ingestion...")
    print("=" * 50)
    for page_url in seed_urls:
        status = kb.ingest_url(page_url)
        print(f"Result: {status}")
        time.sleep(1)  # polite crawl delay between requests
    print("=" * 50)
    # Quick retrieval smoke test against the freshly built index.
    print("\nπ§ͺ Testing Retrieval...")
    test_query = "What is Azone Institute of Technology?"
    print(f"Query: {test_query}")
    answer = kb.search(test_query)
    print("-" * 20)
    print(answer)
    print("-" * 20)