crazy_bot / rag_engine.py
Wall06's picture
Update rag_engine.py
65562e0 verified
raw
history blame
3.95 kB
from __future__ import annotations
import os
import re
import textwrap
from pathlib import Path
from typing import Any
import faiss
import numpy as np
import requests
import spacy
from bs4 import BeautifulSoup
from huggingface_hub import InferenceClient
from pypdf import PdfReader
from sentence_transformers import SentenceTransformer
# ── Config ─────────────────────────────────────────────
# Sentence-embedding model applied to both document chunks and queries.
EMBED_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
# Model id sent to the Hugging Face Inference API for answer generation.
LLM_MODEL: str = "HuggingFaceH4/zephyr-7b-beta"

# Sliding-window chunking, measured in characters (the chunker slices the
# string directly, not tokens). CHUNK_OVERLAP must stay below CHUNK_SIZE,
# otherwise the window start never advances.
CHUNK_SIZE: int = 400
CHUNK_OVERLAP: int = 80

# How many nearest chunks are retrieved per question.
TOP_K: int = 4

# Intent label -> trigger phrases.
# NOTE(review): RAGEngine in this module never reads this; presumably a
# caller uses it for intent detection — verify before changing.
INTENT_MAP: dict[str, list[str]] = dict(
    summarise=["summarise", "summary"],
    explain=["explain", "what is"],
    list=["list"],
    compare=["compare"],
    find=["find"],
)
# ── Engine ─────────────────────────────────────────────
class RAGEngine:
    """Retrieval-augmented question answering over one loaded document.

    A document (PDF, web page, or raw text) is split into overlapping
    character windows, embedded with ``EMBED_MODEL``, and stored in an
    in-memory FAISS inner-product index over L2-normalised vectors (so the
    score is cosine similarity). ``answer`` retrieves the most similar chunks
    and asks ``LLM_MODEL`` via the Hugging Face Inference API to answer from
    them. Loading a new document replaces the previous one.
    """

    def __init__(self):
        print("Loading embedding model...")
        self.embedder = SentenceTransformer(EMBED_MODEL)
        # os.getenv returns None when HF_TOKEN is unset; InferenceClient then
        # applies its own default token handling.
        self.hf_client = InferenceClient(token=os.getenv("HF_TOKEN"))
        self._load_spacy()
        self.reset()

    def _load_spacy(self):
        """Load spaCy's ``en_core_web_sm`` into ``self.nlp``, downloading it if missing."""
        try:
            self.nlp = spacy.load("en_core_web_sm")
        except OSError:
            # spaCy raises OSError when the model package is not installed.
            # Catching only that keeps Ctrl-C and genuine bugs visible.
            import subprocess
            import sys

            subprocess.run(
                [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
                check=True,  # fail here, not with a confusing second OSError
            )
            self.nlp = spacy.load("en_core_web_sm")

    def reset(self):
        """Forget any loaded document; ``ready`` becomes False."""
        self.chunks = []
        self.index = None

    @property
    def ready(self):
        """True once a document has been indexed into at least one chunk."""
        return self.index is not None and bool(self.chunks)

    # ── Loaders ─────────────────────────────────────
    def load_pdf(self, path):
        """Index the text of every page of the PDF at *path*; return a status message."""
        reader = PdfReader(path)
        # `or ""` guards pages that yield no text (e.g. image-only pages).
        text = " ".join(page.extract_text() or "" for page in reader.pages)
        return self._index_and_report(text, "PDF")

    def load_url(self, url, timeout=30):
        """Fetch *url*, index its visible text, and return a status message.

        *timeout* (seconds) bounds the HTTP request so a stalled server cannot
        hang the caller. Raises ``requests.RequestException`` subclasses on
        network errors, timeouts, and HTTP error statuses (4xx/5xx), instead
        of indexing an error page.
        """
        resp = requests.get(url, timeout=timeout)
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        # Script/style contents are not readable page text.
        for tag in soup(["script", "style", "noscript"]):
            tag.decompose()
        # A separator keeps text from adjacent elements from fusing into one
        # word; _chunk collapses the extra whitespace afterwards.
        text = soup.get_text(separator=" ")
        return self._index_and_report(text, "URL")

    def load_text(self, text):
        """Index a raw text string; return a status message."""
        return self._index_and_report(text, "text")

    def _index_and_report(self, text, label):
        """Build the index from *text* and return a user-facing status for *label*."""
        self._build_index(text)
        if not self.ready:
            return f"⚠️ No text found to index ({label})."
        return f"✅ Loaded {label} ({len(self.chunks)} chunks)"

    # ── Chunk + Index ───────────────────────────────
    def _chunk(self, text, size=None, overlap=None):
        """Split *text* into overlapping fixed-width character windows.

        Whitespace runs are collapsed to single spaces and the ends stripped.
        Each window is *size* characters (default ``CHUNK_SIZE``) and a new one
        starts every ``size - overlap`` characters (default overlap
        ``CHUNK_OVERLAP``). Splitting stops at the first window that reaches
        the end of the text, so no trailing window is a pure subset of the
        previous one. Returns ``[]`` for empty or whitespace-only input.

        Raises ValueError unless ``size > 0`` and ``0 <= overlap < size``:
        a non-positive step would loop forever and a negative overlap would
        skip text.
        """
        size = CHUNK_SIZE if size is None else size
        overlap = CHUNK_OVERLAP if overlap is None else overlap
        if size <= 0 or not 0 <= overlap < size:
            raise ValueError(
                f"need size > 0 and 0 <= overlap < size, got size={size}, overlap={overlap}"
            )
        text = re.sub(r"\s+", " ", text).strip()
        step = size - overlap
        chunks = []
        for start in range(0, len(text), step):
            chunks.append(text[start:start + size])
            if start + size >= len(text):
                break
        return chunks

    def _embed(self, texts):
        """Encode *texts* into a C-contiguous float32 matrix with unit-L2 rows.

        Normalising makes inner product equal cosine similarity, which is what
        the ``IndexFlatIP`` built in ``_build_index`` scores.
        """
        emb = np.ascontiguousarray(self.embedder.encode(texts), dtype=np.float32)
        faiss.normalize_L2(emb)  # in place
        return emb

    def _build_index(self, text):
        """Replace the current document with *text*'s chunks and a fresh index.

        State is swapped only after embedding succeeds, so a failure part-way
        through never pairs new chunks with a stale index. Input with no
        indexable text clears the engine instead of crashing.
        """
        chunks = self._chunk(text)
        if not chunks:
            self.reset()
            return
        emb = self._embed(chunks)
        index = faiss.IndexFlatIP(emb.shape[1])
        index.add(emb)
        self.chunks, self.index = chunks, index

    # ── Query ───────────────────────────────────────
    def _retrieve(self, q, k=None):
        """Return up to *k* (default ``TOP_K``) chunks most similar to *q*, best first."""
        k = min(TOP_K if k is None else k, len(self.chunks))
        if k <= 0:
            return []
        _, idx = self.index.search(self._embed([q]), k)
        # FAISS pads missing results with label -1; without this filter that
        # would silently return self.chunks[-1] as a duplicate hit.
        return [self.chunks[i] for i in idx[0] if i >= 0]

    def answer(self, query):
        """Answer *query* from the loaded document using the hosted LLM.

        Returns a warning string if nothing is loaded. If the Inference API
        call fails, returns the error text followed by the top retrieved chunk
        so the user still sees the most relevant passage.
        """
        if not self.ready:
            return "⚠️ Load data first."
        chunks = self._retrieve(query)
        # Join as plain text: formatting the list itself would inject its
        # Python repr (brackets, quotes, escaped newlines) into the prompt.
        context = "\n\n".join(chunks)
        prompt = (
            "Answer using only this context:\n"
            f"{context}\n"
            f"Question: {query}\n"
        )
        try:
            res = self.hf_client.text_generation(
                prompt,
                model=LLM_MODEL,
                max_new_tokens=300,
            )
            return res.strip()
        except Exception as e:  # UI boundary: report the failure, don't raise
            fallback = chunks[0] if chunks else ""
            return f"⚠️ API error: {e}\n\n{fallback}"