# myraggradio/rag_engine.py
"""
Pure RAG utilities – no Gradio code here.
"""
import os
import re

import faiss
import nltk
import numpy as np
import requests
import torch
from bs4 import BeautifulSoup
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from sentence_transformers import CrossEncoder, SentenceTransformer
from transformers import AutoTokenizer, pipeline

# Point NLTK at the bundled data first, then download the tokenizer tables
# into the same directory so the runtime never falls back to ~/nltk_data.
NLTK_LOCAL = os.path.join(os.path.dirname(__file__), "nltk_data")
nltk.data.path.insert(0, NLTK_LOCAL)
nltk.download("punkt_tab", download_dir=NLTK_LOCAL, quiet=True)
DEFAULT_LLMS = {
"gpt2": "gpt2",
"distilgpt2": "distilgpt2",
"flan-t5-small": "google/flan-t5-small",
"flan-t5-base": "google/flan-t5-base"
}
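# Names not listed above pass through unchanged (DEFAULT_LLMS.get(name, name)
# in generate()), so any Hugging Face model id works, e.g. "gpt2-medium".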
class RagEngine:
def __init__(self):
self.bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
self.faiss_idx = None
self.chunks = []
self.embeddings = None
self.pipes = {} # cache for generation pipelines
self.tok_t5 = None # cache for t5 tokenizer
# ---------- 1. scrape ----------
def scrape_bbc_article(self, url: str) -> str:
"""Return plain text of a BBC Culture / Travel article."""
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()  # surface 4xx/5xx instead of parsing an error page
        soup = BeautifulSoup(resp.text, "html.parser")
# kill scripts / style
for s in soup(["script", "style", "noscript"]):
s.decompose()
        # BBC renders the body in <article> or a <div data-testid="article-body">
article = soup.find("article") or soup.find("div", attrs={"data-testid": "article-body"})
if not article:
raise ValueError("Could not locate article body.")
text = article.get_text(separator="\n")
text = re.sub(r"\n+", "\n", text.strip())
return text
# ---------- 2. chunk ----------
def chunk_text(self, text: str, chunk_size: int, overlap: int) -> list[str]:
"""Recursive character splitter mimic."""
splitter = "\n"
paragraphs = text.split(splitter)
chunks, buf = [], []
buf_len = 0
for p in paragraphs:
p_len = len(p.split())
if buf_len + p_len > chunk_size and buf:
chunks.append(" ".join(buf))
                # overlap is counted in trailing paragraphs, not words
                overlap_paras = int(overlap)
                buf = buf[-overlap_paras:] if overlap_paras else []
                buf_len = sum(len(x.split()) for x in buf)
buf.append(p)
buf_len += p_len
if buf:
chunks.append(" ".join(buf))
return chunks
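    # Illustrative behaviour (hypothetical numbers): with chunk_size=100 and
    # overlap=2, a chunk closes once adding a paragraph would pass ~100 words,
    # and its last two paragraphs are repeated at the start of the next chunk.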
# ---------- 3. embed + faiss ----------
def build_index(self, chunks: list[str]):
self.chunks = chunks
self.embeddings = self.bi_encoder.encode(chunks, normalize_embeddings=True, show_progress_bar=True)
d = self.embeddings.shape[1]
self.faiss_idx = faiss.IndexFlatIP(d) # inner-product = cosine with normed vecs
        self.faiss_idx.add(self.embeddings.astype("float32"))  # encode() already returns an ndarray
# ---------- 4. retrieve + re-rank ----------
def retrieve(self, query: str, k: int = 5, rerank_top: int = 3):
        if self.faiss_idx is None:
            raise ValueError("Call build_index() before retrieve().")
        q_emb = self.bi_encoder.encode([query], normalize_embeddings=True)
        scores, idxs = self.faiss_idx.search(q_emb.astype("float32"), k)
hits = [{"chunk": self.chunks[i], "score": float(scores[0][j]), "idx": int(idxs[0][j])}
for j, i in enumerate(idxs[0])]
# re-rank
cross_in = [[query, h["chunk"]] for h in hits]
cross_scores = self.cross_encoder.predict(cross_in)
for h, sc in zip(hits, cross_scores):
h["cross_score"] = float(sc)
hits = sorted(hits, key=lambda x: x["cross_score"], reverse=True)[:rerank_top]
return hits
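    # Each hit is shaped like {"chunk": "...", "score": 0.71, "idx": 3, "cross_score": 5.2}
    # (values here are made up): "score" is the bi-encoder cosine similarity,
    # "cross_score" the cross-encoder relevance logit used for the final ordering.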
# ---------- 5. generation ----------
def generate(self, query: str, retrieved: list[dict], llm_name: str,
max_new_tokens: int = 120, temperature: float = 0.0) -> str:
        # only the single best re-ranked chunk is used, keeping the prompt short
        context = "\n".join(h["chunk"] for h in retrieved[:1])
prompt = (f"Use only the following passages to answer the question. "
f"If no single sentence gives a full answer, write what is stated.\n\n"
f"Passages:\n{context}\n\nQuestion: {query}\nAnswer:")
model_id = DEFAULT_LLMS.get(llm_name, llm_name)
# ---------- choose task & truncate ----------
        if "flan-t5" in model_id:  # matches both short names and full hub ids
task = "text2text-generation"
if self.tok_t5 is None:
self.tok_t5 = AutoTokenizer.from_pretrained(model_id)
tok = self.tok_t5
tokens = tok.tokenize(prompt)
limit = 512 - max_new_tokens - 5
if len(tokens) > limit:
prompt = tok.convert_tokens_to_string(tokens[:limit])
else: # gpt2
task = "text-generation"
prompt = self._clip_to_max_len(prompt, model_id, max_new_tokens)
# ---------- cached pipeline ----------
if (task, model_id) not in self.pipes:
self.pipes[(task, model_id)] = pipeline(
task, model=model_id,
device=0 if torch.cuda.is_available() else -1)
generator = self.pipes[(task, model_id)]
        gen_kw = dict(max_new_tokens=max_new_tokens, do_sample=(temperature > 0))
        if temperature > 0:
            # only pass temperature when sampling; temperature=None can trip
            # generation-config validation in some transformers versions
            gen_kw["temperature"] = temperature
if task == "text-generation":
gen_kw["return_full_text"] = False
out = generator(prompt, **gen_kw)[0]["generated_text"]
return out.strip()
# ---------- 6. metrics ----------
def compute_metrics(self, reference: str, hypothesis: str):
ref_tok = word_tokenize(reference.lower())
hyp_tok = word_tokenize(hypothesis.lower())
        # smoothing avoids a hard 0.0 BLEU when short answers share no 4-grams
        bleu = sentence_bleu([ref_tok], hyp_tok, smoothing_function=SmoothingFunction().method1)
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
rs = scorer.score(reference, hypothesis)
return {
"bleu": bleu,
"rouge-1": rs["rouge1"].fmeasure,
"rouge-2": rs["rouge2"].fmeasure,
"rouge-l": rs["rougeL"].fmeasure,
}
@staticmethod
def _clip_to_max_len(text: str, model_name: str, max_new_tokens: int) -> str:
        # reloads the tokenizer on every call; acceptable for the small models used here
        tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
max_pos = tok.model_max_length # 1024 for gpt2
allowance = max_pos - max_new_tokens - 2
tokens = tok.encode(text)
if len(tokens) > allowance:
tokens = tokens[:allowance]
return tok.decode(tokens, skip_special_tokens=True)
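
# ---------- demo ----------
# Minimal end-to-end sketch, not exercised by the Gradio app. The URL below is
# a placeholder; substitute a real BBC Culture / Travel article before running.
if __name__ == "__main__":
    engine = RagEngine()
    text = engine.scrape_bbc_article("https://www.bbc.com/culture/article/example")
    chunks = engine.chunk_text(text, chunk_size=150, overlap=2)
    engine.build_index(chunks)
    question = "What is the article about?"
    hits = engine.retrieve(question, k=5, rerank_top=3)
    print(engine.generate(question, hits, "flan-t5-small"))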