# rag_engine.py
import os
import pickle
import numpy as np
import faiss
import requests
import trafilatura
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
from flashrank import Ranker, RerankRequest
import logging
import time
# Setup basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class KnowledgeBase:
"""
High-Performance Local RAG Engine.
    Fusion logic: robust multi-strategy extraction + FAISS vector storage.
"""
def __init__(self, index_path="faiss_index.bin", metadata_path="metadata.pkl"):
self.index_path = index_path
self.metadata_path = metadata_path
self.metadata = []
logger.info("πŸ“š Initializing Knowledge Base (RAG Engine)...")
        # 1. Embedding model (384-dim, fast; dimension must match the FAISS index)
        self.embedder = SentenceTransformer('all-MiniLM-L6-v2', device='cpu')
        # 2. Reranker (lightweight cross-encoder; model is cached in ./flashrank_cache)
        self.ranker = Ranker(model_name="ms-marco-MiniLM-L-12-v2", cache_dir="./flashrank_cache")
# 3. Load/Create FAISS Index
if os.path.exists(self.index_path) and os.path.exists(self.metadata_path):
try:
self.index = faiss.read_index(self.index_path)
with open(self.metadata_path, "rb") as f:
self.metadata = pickle.load(f)
logger.info(f"βœ… Loaded RAG Index ({self.index.ntotal} chunks).")
except Exception as e:
logger.error(f"❌ Index load failed: {e}. Resetting.")
self.create_new_index()
else:
self.create_new_index()
    def create_new_index(self):
        # Flat 384-dim L2 index, matching the embedder's output dimension.
        # Vectors are L2-normalized before add/search, so L2 distance ranks
        # results identically to cosine similarity.
        self.index = faiss.IndexFlatL2(384)
        self.metadata = []
        logger.info("πŸ†• Created new empty RAG Index.")
def fetch_page_content(self, url):
"""
Robust Fetcher: Tries Trafilatura first, falls back to Requests.
Logic borrowed from user's successful script.
"""
# Method A: Trafilatura Fetch
downloaded = trafilatura.fetch_url(url)
if downloaded:
return downloaded
# Method B: Requests Fallback (User-Agent Spoofing)
try:
logger.warning(f"⚠️ Trafilatura fetch failed for {url}, trying requests...")
headers = {"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) ToolboxesAI-Bot/1.0"}
            resp = requests.get(url, timeout=15, headers=headers)
            if resp.status_code == 200:
                return resp.text
            logger.warning(f"⚠️ Requests returned HTTP {resp.status_code} for {url}")
except Exception as e:
logger.error(f"❌ Requests failed for {url}: {e}")
return None
def ingest_url(self, url: str):
"""
Extracts text using Trafilatura (Primary) -> BeautifulSoup (Backup).
        Chunks by paragraph (split on double newlines).
"""
logger.info(f"πŸ•·οΈ Scraping: {url}")
# 1. Fetch
html_content = self.fetch_page_content(url)
if not html_content:
return f"Failed to fetch {url}"
# 2. Extract Text
# Strategy A: Trafilatura (Best for articles/docs)
text = trafilatura.extract(html_content, include_comments=False, include_tables=True, no_fallback=False)
# Strategy B: BeautifulSoup Greedy (If Trafilatura thinks page is empty/nav-only)
if not text or len(text) < 100:
logger.info("⚠️ Trafilatura extraction empty. Using Greedy BeautifulSoup.")
soup = BeautifulSoup(html_content, 'html.parser')
for element in soup(['script', 'style', 'noscript', 'svg']):
element.decompose()
text = soup.get_text(separator='\n\n', strip=True)
if not text:
return "No readable text found."
logger.info(f"πŸ“„ Extracted {len(text)} chars.")
# 3. Chunking Strategy (User's Logic: Paragraph Split)
# We split by double newline to preserve paragraph structure.
raw_chunks = [c.strip() for c in text.split('\n\n') if len(c.strip()) > 50]
# Additional processing: If a paragraph is HUGE (>1000 chars), split it further
final_chunks = []
for chunk in raw_chunks:
if len(chunk) > 1000:
                # Simple fixed-width split for massive blocks (may cut mid-sentence)
for i in range(0, len(chunk), 800):
sub_chunk = chunk[i:i+800]
final_chunks.append(f"Source: {url} | Content: {sub_chunk}")
else:
final_chunks.append(f"Source: {url} | Content: {chunk}")
if not final_chunks:
return "Text was too short to chunk."
# 4. Vectorize & Store
try:
            embeddings = np.asarray(self.embedder.encode(final_chunks), dtype='float32')
            faiss.normalize_L2(embeddings)  # in-place; requires float32, hence the cast above
            self.index.add(embeddings)
self.metadata.extend(final_chunks)
# Save to disk
faiss.write_index(self.index, self.index_path)
with open(self.metadata_path, "wb") as f:
pickle.dump(self.metadata, f)
return f"βœ… Ingested {len(final_chunks)} chunks."
except Exception as e:
return f"Error vectorizing: {e}"
def search(self, query: str, top_k: int = 10) -> str:
"""
Retrieves docs using FAISS -> Reranks using FlashRank.
"""
if self.index.ntotal == 0:
return ""
# 1. Coarse Search (FAISS)
        query_vec = np.asarray(self.embedder.encode([query]), dtype='float32')
        faiss.normalize_L2(query_vec)
        distances, indices = self.index.search(query_vec, top_k)
candidates = []
for idx in indices[0]:
if idx != -1 and idx < len(self.metadata):
candidates.append({"id": int(idx), "text": self.metadata[idx]})
if not candidates:
return ""
# 2. Rerank (FlashRank)
rerank_request = RerankRequest(query=query, passages=candidates)
results = self.ranker.rerank(rerank_request)
        # Return only the single best re-ranked chunk
        top_results = results[:1]
return "\n\n".join([f"[Local RAG] {r['text']}" for r in top_results])
local_kb = KnowledgeBase()
# ==========================================
# πŸ› οΈ INGESTION ZONE (RUN THIS TO BUILD DB)
# ==========================================
if __name__ == "__main__":
kb = local_kb
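    # Note: ingest_url always appends to the existing index, so re-running this
    # script duplicates chunks. Delete faiss_index.bin and metadata.pkl first
    # for a clean rebuild.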
# Your Verified URLs
urls = [
"https://toolboxesai.com",
"https://toolboxesai.com/hub",
"https://toolboxesai.com/app",
"https://toolboxesai.com/app-guide",
"https://toolboxesai.com/privacy",
"https://toolboxesai.com/terms",
"https://toolboxesai.com/contact",
"https://toolboxesai.com/about",
"https://compressorpro.toolboxesai.com",
"https://compressorpro.toolboxesai.com/resizer",
"https://compressorpro.toolboxesai.com/enhancer",
"https://compressorpro.toolboxesai.com/color-grader",
"https://compressorpro.toolboxesai.com/print",
"https://compressorpro.toolboxesai.com/user-guide"
]
print("\nπŸš€ Starting ROBUST Knowledge Ingestion...")
print("="*50)
for url in urls:
result = kb.ingest_url(url)
print(f"Result: {result}")
time.sleep(1) # Polite delay
print("="*50)
# Test
print("\nπŸ§ͺ Testing Retrieval...")
test_query = "What is toolboxesai?"
print(f"Query: {test_query}")
answer = kb.search(test_query)
print("-" * 20)
print(answer)
print("-" * 20)