Update rag_engine.py
rag_engine.py  +26 -91  CHANGED
@@ -1,119 +1,54 @@
 import re
-import textwrap
 import faiss
 import numpy as np
 import requests
 import spacy
-from typing import Any
 from bs4 import BeautifulSoup
 from huggingface_hub import InferenceClient
 from pypdf import PdfReader
-from sentence_transformers import …
-
-# ── Config ──────────────────────────────────────────────────────────────────
-EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
-LLM_MODEL = "HuggingFaceH4/zephyr-7b-beta"
-CHUNK_SIZE = 400
-CHUNK_OVERLAP = 80
-TOP_K = 4
-
-INTENT_MAP = {
-    "summarise": ["summarise", "summarize", "summary", "overview", "brief", "key points"],
-    "explain": ["explain", "what is", "what are", "define", "describe", "tell me about"],
-    "list": ["list", "enumerate", "give me a list", "what are the types"],
-}
+from sentence_transformers import SentenceTransformer
 
 class RAGEngine:
     def __init__(self):
-        self.embed_model = SentenceTransformer(EMBED_MODEL)
+        self.embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
         self.hf_client = InferenceClient()
-
-        # Robust model loading for HuggingFace
         try:
             self.nlp = spacy.load("en_core_web_sm")
         except:
-            os.system("python -m spacy download en_core_web_sm")
-            self.nlp = spacy.load("en_core_web_sm")
-
+            self.nlp = None  # Fallback if model fails to load
        self.reset()
 
     def reset(self):
-        self.chunks = []
-        self.index = None
-        self.ready = False
+        self.chunks, self.index, self.ready = [], None, False
 
-    def …
+    def _index_text(self, text):
         text = re.sub(r'\s+', ' ', text).strip()
-        new_chunks = []
-        for i in range(0, len(text), CHUNK_SIZE - CHUNK_OVERLAP):
-            chunk = text[i : i + CHUNK_SIZE]
-            if len(chunk) > 20:
-                new_chunks.append(chunk)
-
-        self.chunks = new_chunks
+        self.chunks = [text[i:i+400] for i in range(0, len(text), 320) if len(text[i:i+400]) > 20]
         embeddings = self.embed_model.encode(self.chunks)
         self.index = faiss.IndexFlatL2(embeddings.shape[1])
         self.index.add(np.array(embeddings).astype("float32"))
         self.ready = True
 
-    def load_pdf(self, path…
-        …
-        self.…
-        return f"✅ Loaded PDF: {len(self.chunks)} chunks indexed."
+    def load_pdf(self, path):
+        text = "".join([p.extract_text() for p in PdfReader(path).pages])
+        self._index_text(text)
+        return f"✅ Indexed {len(self.chunks)} PDF sections."
 
-    def load_url(self, url…
+    def load_url(self, url):
         res = requests.get(url, timeout=10)
         soup = BeautifulSoup(res.text, "html.parser")
-        …
-        self.…
-        …
-        return "…
-
-    def extract_entities(self, text: str) -> dict:
-        doc = self.nlp(text)
-        entities = {}
-        for ent in doc.ents:
-            if ent.label_ not in entities:
-                entities[ent.label_] = []
-            if ent.text not in entities[ent.label_]:
-                entities[ent.label_].append(ent.text)
-        return entities
-
-    def _retrieve(self, query: str) -> list:
-        query_vec = self.embed_model.encode([query]).astype("float32")
-        _, indices = self.index.search(query_vec, TOP_K)
-        return [self.chunks[i] for i in indices[0]]
-
-    def answer(self, query: str) -> str:
-        if not self.ready:
-            return "⚠️ No knowledge source loaded."
-        chunks = self._retrieve(query)
-        context = "\n".join([f"- {c}" for c in chunks])
-        prompt = f"<|system|>\nUse context to answer accurately.\n</s>\n<|user|>\nContext: {context}\nQuestion: {query}\n</s>\n<|assistant|>"
-
-        try:
-            response = self.hf_client.text_generation(prompt, model=LLM_MODEL, max_new_tokens=512)
-            return response.strip()
-        except Exception as e:
-            return f"⚠️ LLM Error. Top match: {chunks[0]}"
-
-    def answer_with_nlp(self, query: str) -> tuple:
-        ans = self.answer(query)
-        info = {
-            "intent": self.detect_intent(query),
-            "entities_found": self.extract_entities(ans)
-        }
-        return ans, info
+        self._index_text(soup.get_text())
+        return f"✅ Indexed {len(self.chunks)} URL sections."
+
+    def answer_with_nlp(self, query):
+        if not self.ready: return "Please upload data first.", {}
+        # Retrieval
+        q_vec = self.embed_model.encode([query]).astype("float32")
+        _, idx = self.index.search(q_vec, 3)
+        context = "\n".join([self.chunks[i] for i in idx[0]])
+        # Generation
+        prompt = f"<|system|>Answer based on context.</s><|user|>Context: {context}\nQuestion: {query}</s><|assistant|>"
+        ans = self.hf_client.text_generation(prompt, model="HuggingFaceH4/zephyr-7b-beta", max_new_tokens=512).strip()
+        # NLP Analysis
+        entities = {ent.label_: ent.text for ent in self.nlp(ans).ents} if self.nlp else {}
+        return ans, {"entities": entities}
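
For context, a minimal usage sketch of the class as it stands after this commit. Everything here is hypothetical (the import path, the URL, and the question are placeholders); it also assumes a Hugging Face token is available to InferenceClient, network access for the fetch and for hosted inference, and an installed en_core_web_sm model (entity extraction silently degrades to an empty dict without it):

# Hypothetical driver script; not part of the commit.
from rag_engine import RAGEngine

engine = RAGEngine()

# A source must be indexed first; answer_with_nlp() refuses to answer otherwise.
print(engine.load_url("https://example.com/some-article"))  # placeholder URL

answer, info = engine.answer_with_nlp("What is this page about?")
print(answer)
print(info["entities"])  # {} when the spaCy model is unavailable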
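
One detail worth noting in _index_text: the comprehension slides a 400-character window with a stride of 320, so consecutive chunks overlap by 80 characters (matching CHUNK_SIZE - CHUNK_OVERLAP = 400 - 80 from the removed constants). A quick sketch of that behavior on toy input:

# Sliding-window chunking as in _index_text: window 400, stride 320,
# so adjacent chunks share 80 characters; fragments of 20 chars or fewer are dropped.
text = "x" * 1000
chunks = [text[i:i+400] for i in range(0, len(text), 320) if len(text[i:i+400]) > 20]
print([len(c) for c in chunks])  # [400, 400, 360, 40]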