# ttt / rag_engine.py
# Rajhuggingface4253's picture
# Update rag_engine.py
# dda6857 verified
# rag_engine.py
import os
import pickle
import numpy as np
import faiss
import requests
import trafilatura
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from flashrank import Ranker, RerankRequest
import logging
import time
import re
# BM25 for keyword-based lexical search
from rank_bm25 import BM25Okapi
# Setup basic logging.
# This must happen BEFORE the first log call: emitting a record through the
# module-level logging functions on an unconfigured root logger installs a
# default WARNING-level handler, after which basicConfig(level=INFO) is
# silently ignored and every logger.info() in this module would be dropped.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Playwright for SPA/JavaScript-rendered pages (optional dependency).
try:
    from playwright.sync_api import sync_playwright
    PLAYWRIGHT_AVAILABLE = True
except ImportError:
    # Keep the name bound so references fail predictably; callers are expected
    # to check PLAYWRIGHT_AVAILABLE before using it.
    sync_playwright = None
    PLAYWRIGHT_AVAILABLE = False
    logger.warning("⚠️ Playwright not installed. SPA scraping will be limited.")
def tokenize(text: str) -> list:
    """Split *text* into lowercase word tokens for BM25.

    The text is lowercased first, then every maximal run of word characters
    (letters, digits and underscore) becomes one token, in order of
    appearance. Punctuation and whitespace act only as separators.
    """
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r'\w+', lowered)]
class KnowledgeBase:
    """
    High-Performance Hybrid RAG Engine.

    Combines: FAISS (Semantic) + BM25 (Lexical) + RRF (Fusion) + FlashRank (Reranking).

    Three parallel structures are kept aligned by position: vector ``i`` of
    the FAISS index, ``metadata[i]`` (the chunk text) and
    ``tokenized_corpus[i]`` (the chunk's BM25 tokens) describe the same chunk.
    Document indices returned by the search methods are positions into
    ``metadata``.
    """

    # Fallback vector width, used only when the embedder cannot report its own.
    DEFAULT_EMBEDDING_DIM = 384

    def __init__(self, index_path="faiss_index.bin", metadata_path="metadata.pkl", bm25_path="bm25_corpus.pkl"):
        """
        Load the embedding and reranking models, then load the saved indexes
        from disk or start with empty ones.

        Args:
            index_path: File holding the serialized FAISS index.
            metadata_path: Pickle file holding the list of chunk texts.
            bm25_path: Pickle file holding the tokenized BM25 corpus.
        """
        self.index_path = index_path
        self.metadata_path = metadata_path
        self.bm25_path = bm25_path
        self.metadata = []
        self.tokenized_corpus = []  # For BM25
        self.bm25 = None

        logger.info("πŸ“š Initializing Hybrid Knowledge Base (FAISS + BM25 + RRF)...")

        # 1. Embedding model, pinned to CPU.
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        logger.info(f"βœ… Loaded SentenceTransformer: all-MiniLM-L6-v2 (dim={self.embedder.get_sentence_embedding_dimension()})")

        # 2. Cross-encoder reranker used as the final search stage.
        self.ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2", cache_dir="./flashrank_cache")

        # 3. Load/Create Indexes (FAISS + BM25)
        if os.path.exists(self.index_path) and os.path.exists(self.metadata_path):
            try:
                self._load_indexes()
            except Exception as e:
                # Any unreadable/corrupt file: start over with empty indexes.
                logger.error(f"❌ Index load failed: {e}. Resetting.")
                self.create_new_index()
        else:
            self.create_new_index()

    def _load_indexes(self):
        """Read FAISS, metadata and BM25 state from disk, repairing BM25 if it is out of step."""
        self.index = faiss.read_index(self.index_path)

        # NOTE(security): pickle.load can execute arbitrary code. These files
        # must only ever be ones this application wrote itself.
        with open(self.metadata_path, "rb") as f:
            self.metadata = pickle.load(f)

        if self.index.ntotal != len(self.metadata):
            logger.warning(
                f"FAISS index holds {self.index.ntotal} vectors but metadata holds "
                f"{len(self.metadata)} chunks; hits without matching text are skipped."
            )

        corpus = []
        if os.path.exists(self.bm25_path):
            with open(self.bm25_path, "rb") as f:
                corpus = pickle.load(f)

        if len(corpus) != len(self.metadata):
            # BM25 document numbers are used as positions into metadata. A
            # missing or shorter corpus would make every later ingest append
            # at the wrong offset, so re-derive it from the chunk texts.
            logger.warning("BM25 corpus missing or out of sync with metadata; rebuilding it from stored chunks.")
            corpus = [tokenize(chunk) for chunk in self.metadata]

        self.tokenized_corpus = corpus
        self.bm25 = BM25Okapi(corpus) if corpus else None
        logger.info(f"βœ… Loaded BM25 Index ({len(self.tokenized_corpus)} docs).")
        logger.info(f"βœ… Loaded FAISS Index ({self.index.ntotal} chunks).")

    def create_new_index(self):
        """Discard all in-memory state and start with empty FAISS and BM25 indexes."""
        dim = self.DEFAULT_EMBEDDING_DIM
        embedder = getattr(self, "embedder", None)
        if embedder is not None:
            # Size the index from the model so a different embedder still works.
            dim = embedder.get_sentence_embedding_dimension() or dim
        self.index = faiss.IndexFlatL2(dim)
        self.metadata = []
        self.tokenized_corpus = []
        self.bm25 = None
        logger.info("πŸ†• Created new empty Hybrid Index (FAISS + BM25).")

    def rebuild_bm25(self):
        """Rebuild BM25 index from tokenized corpus (or clear it when the corpus is empty)."""
        if not self.tokenized_corpus:
            # Never leave a BM25 object describing documents that are gone.
            self.bm25 = None
            return
        self.bm25 = BM25Okapi(self.tokenized_corpus)
        logger.info(f"πŸ”„ Rebuilt BM25 index with {len(self.tokenized_corpus)} documents.")

    def _render_with_playwright(self, url, settle_ms):
        """
        Render ``url`` in headless Chromium and return the resulting HTML.

        Waits for network idle, then ``settle_ms`` more milliseconds so
        client-side frameworks can finish rendering. Exceptions propagate to
        the caller; the browser is closed either way.
        """
        with sync_playwright() as p:
            browser = p.chromium.launch(headless=True)
            try:
                page = browser.new_page()
                page.goto(url, wait_until="networkidle", timeout=30000)
                page.wait_for_timeout(settle_ms)
                return page.content()
            finally:
                browser.close()

    def fetch_page_content(self, url):
        """
        Robust Fetcher with 3-Stage Fallback:
        1. Trafilatura (fast, for static pages)
        2. Requests (fallback for simple pages)
        3. Playwright (headless browser for SPAs/JS-rendered pages)

        Returns:
            The raw HTML from the first stage that yields more than 500
            characters, or None when every stage fails.
        """
        # Method A: Trafilatura Fetch (fastest)
        downloaded = trafilatura.fetch_url(url)
        if downloaded and len(downloaded) > 500:  # Must have substantial content
            return downloaded

        # Method B: Requests Fallback (User-Agent Spoofing)
        try:
            logger.info("βš™οΈ Trafilatura insufficient, trying requests...")
            headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"}
            resp = requests.get(url, timeout=15, headers=headers)
            if resp.status_code == 200 and len(resp.text) > 500:
                return resp.text
        except Exception as e:
            logger.warning(f"⚠️ Requests failed for {url}: {e}")

        # Method C: Playwright Headless Browser (for SPAs like React/Vue/Angular)
        if not PLAYWRIGHT_AVAILABLE:
            logger.warning("⚠️ Playwright not available. Cannot render SPA.")
            return None

        logger.info(f"🎭 Using Playwright for JavaScript-rendered page: {url}")
        try:
            html_content = self._render_with_playwright(url, 2000)
            if html_content and len(html_content) > 500:
                logger.info(f"βœ… Playwright successfully rendered page ({len(html_content)} chars)")
                return html_content
        except Exception as e:
            logger.error(f"❌ Playwright failed for {url}: {e}")
        return None

    @staticmethod
    def _chunk_text(text, url):
        """
        Split extracted text into source-tagged chunks.

        Paragraphs are separated by blank lines; paragraphs of 50 characters
        or fewer are dropped, and paragraphs longer than 1000 characters are
        cut into consecutive 800-character windows.
        """
        chunks = []
        for paragraph in text.split('\n\n'):
            paragraph = paragraph.strip()
            if len(paragraph) <= 50:
                continue
            if len(paragraph) > 1000:
                pieces = [paragraph[i:i + 800] for i in range(0, len(paragraph), 800)]
            else:
                pieces = [paragraph]
            chunks.extend(f"Source: {url} | Content: {piece}" for piece in pieces)
        return chunks

    def _persist(self):
        """Write the FAISS index, chunk texts and BM25 corpus to their files."""
        faiss.write_index(self.index, self.index_path)
        with open(self.metadata_path, "wb") as f:
            pickle.dump(self.metadata, f)
        with open(self.bm25_path, "wb") as f:
            pickle.dump(self.tokenized_corpus, f)

    def ingest_url(self, url: str):
        """
        Scrape ``url``, chunk its text and add the chunks to both indexes.

        Extraction uses a multi-stage fallback:
        1. Trafilatura fetch + extract
        2. If extraction < 100 chars -> Playwright headless browser + extract
        3. BeautifulSoup greedy as final fallback

        Returns:
            A human-readable status string, for success and failure alike;
            this method does not raise on scraping or indexing errors.
        """
        logger.info(f"πŸ•·οΈ Scraping: {url}")
        text = None

        # Stage 1: Try Trafilatura (fast, for static pages)
        html_content = trafilatura.fetch_url(url)
        if html_content:
            text = trafilatura.extract(html_content, include_comments=False, include_tables=True, no_fallback=False)

        # Stage 2: If extraction failed, use Playwright for SPA rendering
        if not text or len(text) < 100:
            if PLAYWRIGHT_AVAILABLE:
                logger.info(f"🎭 Trafilatura extraction empty. Using Playwright for SPA: {url}")
                try:
                    # On failure html_content keeps the Stage 1 value, so
                    # Stage 3 still has something to work with.
                    html_content = self._render_with_playwright(url, 3000)
                    logger.info(f"βœ… Playwright rendered {len(html_content)} chars of HTML")
                    # Try extraction again
                    text = trafilatura.extract(html_content, include_comments=False, include_tables=True, no_fallback=False)
                except Exception as e:
                    logger.error(f"❌ Playwright failed: {e}")
            else:
                logger.warning("⚠️ Playwright not available. SPA content cannot be rendered.")

        # Stage 3: BeautifulSoup greedy fallback
        if (not text or len(text) < 100) and html_content:
            logger.info("⚠️ Extraction still empty. Using Greedy BeautifulSoup.")
            soup = BeautifulSoup(html_content, 'html.parser')
            for element in soup(['script', 'style', 'noscript', 'svg', 'header', 'footer', 'nav']):
                element.decompose()
            text = soup.get_text(separator='\n\n', strip=True)

        if not text or len(text) < 50:
            return "No readable text found after all extraction methods."
        logger.info(f"πŸ“„ Extracted {len(text)} chars.")

        # Chunking: paragraph split, with oversized paragraphs windowed.
        final_chunks = self._chunk_text(text, url)
        if not final_chunks:
            return "Text was too short to chunk."

        # Vectorize & Store (FAISS + BM25)
        try:
            # Do all fallible computation first so the three parallel
            # structures are only mutated once everything is ready.
            embeddings = np.array(self.embedder.encode(final_chunks)).astype('float32')
            faiss.normalize_L2(embeddings)
            tokenized = [tokenize(chunk) for chunk in final_chunks]

            self.index.add(embeddings)
            self.metadata.extend(final_chunks)
            self.tokenized_corpus.extend(tokenized)
            self.rebuild_bm25()

            # Save all indexes to disk
            self._persist()
            return f"βœ… Ingested {len(final_chunks)} chunks (FAISS + BM25)."
        except Exception as e:
            return f"Error vectorizing: {e}"

    def bm25_search(self, query: str, top_k: int = 15) -> list:
        """
        BM25 lexical search for keyword precision.

        Returns:
            Up to ``top_k`` ``(doc_idx, score)`` tuples, best first; documents
            whose score is not positive are left out.
        """
        if not self.bm25 or not self.tokenized_corpus:
            return []
        scores = self.bm25.get_scores(tokenize(query))
        # Get top-k indices sorted by score descending
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [(int(idx), float(scores[idx])) for idx in top_indices if scores[idx] > 0]

    def faiss_search(self, query: str, top_k: int = 15) -> list:
        """
        FAISS semantic search for meaning-based retrieval.

        Returns:
            Up to ``top_k`` ``(doc_idx, score)`` tuples in the order FAISS
            returns them, where ``score = 1 / (1 + L2 distance)`` so a smaller
            distance gives a larger score.
        """
        if top_k <= 0 or self.index.ntotal == 0:
            return []
        query_vec = np.array(self.embedder.encode([query])).astype('float32')
        faiss.normalize_L2(query_vec)
        # Asking for more neighbours than stored vectors only yields -1 padding.
        k = min(top_k, self.index.ntotal)
        distances, indices = self.index.search(query_vec, k)

        results = []
        for distance, idx in zip(distances[0], indices[0]):
            # -1 marks "no neighbour"; also skip vectors with no stored text.
            if idx != -1 and idx < len(self.metadata):
                results.append((int(idx), float(1.0 / (1.0 + distance))))
        return results

    def rrf_fusion(self, bm25_results: list, faiss_results: list, k: int = 60) -> list:
        """
        Reciprocal Rank Fusion (RRF) to combine BM25 and FAISS results.

        Each list contributes ``1 / (k + rank)`` per document, with ``rank``
        starting at 1; only the position in each list matters, the raw scores
        are ignored.

        Returns:
            ``(doc_idx, rrf_score)`` tuples sorted by score, highest first.
        """
        rrf_scores = {}
        for ranking in (bm25_results, faiss_results):
            for rank, (doc_idx, _) in enumerate(ranking, start=1):
                rrf_scores[doc_idx] = rrf_scores.get(doc_idx, 0.0) + 1.0 / (k + rank)
        return sorted(rrf_scores.items(), key=lambda item: item[1], reverse=True)

    def search(self, query: str, top_k: int = 15, final_k: int = 3) -> str:
        """
        Hybrid Search Pipeline:
        1. BM25 for keyword/lexical matching
        2. FAISS for semantic/meaning matching
        3. RRF fusion to combine rankings
        4. FlashRank cross-encoder reranking for final precision

        Args:
            query: Natural-language question.
            top_k: Candidates taken from each retriever and from the fused list.
            final_k: Number of reranked chunks returned (default 3).

        Returns:
            The selected chunks, each prefixed with ``[Hybrid RAG]`` and
            joined by blank lines, or ``""`` when there is nothing to return.
        """
        # A blank query has no lexical or semantic signal to retrieve on.
        if not query or not query.strip():
            return ""
        if self.index.ntotal == 0:
            return ""

        # Stage 1: Retrieval from both indexes
        bm25_results = self.bm25_search(query, top_k)
        faiss_results = self.faiss_search(query, top_k)
        logger.info(f"πŸ” BM25: {len(bm25_results)} hits | FAISS: {len(faiss_results)} hits")

        # Stage 2: RRF Fusion
        if not (bm25_results or faiss_results):
            return ""
        fused_results = self.rrf_fusion(bm25_results, faiss_results)

        # Stage 3: Prepare candidates for reranking
        candidates = [
            {"id": int(doc_idx), "text": self.metadata[doc_idx], "rrf_score": rrf_score}
            for doc_idx, rrf_score in fused_results[:top_k]
            if doc_idx < len(self.metadata)
        ]
        if not candidates:
            return ""
        logger.info(f"πŸ”„ RRF Fusion: {len(candidates)} unique candidates")

        # Stage 4: Cross-encoder reranking (FlashRank)
        rerank_request = RerankRequest(query=query, passages=candidates)
        results = self.ranker.rerank(rerank_request)

        # Keep only the best few chunks to hold the model context small.
        top_results = results[:max(final_k, 0)]

        # DETAILED LOGGING: Show actual results being sent to model
        logger.info(f"βœ… Reranked: Returning top {len(top_results)} results")
        logger.info("=" * 60)
        logger.info("πŸ“‹ TOP RANKED RESULTS BEING SENT TO MODEL CONTEXT:")
        logger.info("=" * 60)
        for i, result in enumerate(top_results, 1):
            text = result.get('text', '')
            score = result.get('score', 0)
            # Truncate for logging (first 200 chars)
            preview = text[:200].replace('\n', ' ').strip()
            if len(text) > 200:
                preview += "..."
            logger.info(f" [{i}] Score: {score:.4f} | Preview: {preview}")
        logger.info("=" * 60)

        return "\n\n".join([f"[Hybrid RAG] {r['text']}" for r in top_results])
# Shared module-level instance. Constructing it loads the models and any
# saved indexes as soon as this module is imported.
local_kb = KnowledgeBase()

# ==========================================
# INGESTION ZONE (RUN THIS FILE DIRECTLY TO BUILD THE DB)
# ==========================================
if __name__ == "__main__":
    kb = local_kb

    # Your Verified URLs
    urls = [
        "https://www.azoneinstituteoftechnology.co.in/",
        "https://www.azoneinstituteoftechnology.co.in/courses",
        "https://www.azoneinstituteoftechnology.co.in/about",
        "https://www.azoneinstituteoftechnology.co.in/services",
        "https://www.azoneinstituteoftechnology.co.in/career",
        "https://www.azoneinstituteoftechnology.co.in/contact",
        "https://www.azoneinstituteoftechnology.co.in/contact#",
    ]

    # A bare trailing "#" is an empty fragment and names the same document,
    # so ingesting both forms would store every chunk of that page twice in
    # FAISS and BM25. dict.fromkeys de-duplicates while keeping list order.
    unique_urls = list(dict.fromkeys(u.rstrip("#") for u in urls))

    print("\nπŸš€ Starting ROBUST Knowledge Ingestion...")
    print("=" * 50)

    for url in unique_urls:
        result = kb.ingest_url(url)
        print(f"Result: {result}")
        time.sleep(1)  # Polite delay

    print("=" * 50)

    # Test
    print("\nπŸ§ͺ Testing Retrieval...")
    test_query = "What is Azone Institute of Technology?"
    print(f"Query: {test_query}")
    answer = kb.search(test_query)
    print("-" * 20)
    print(answer)
    print("-" * 20)