"""
Pure RAG utilities – no Gradio code here.
"""
import os
import re

import faiss
import nltk
import numpy as np
import requests
import torch
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import AutoTokenizer, pipeline, set_seed
from sklearn.metrics.pairwise import cosine_similarity
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.tokenize import word_tokenize
from rouge_score import rouge_scorer

# Force NLTK to use the bundled data: register the local path *before*
# downloading, so the tokenizer models land in (and resolve from) ./nltk_data.
NLTK_LOCAL = os.path.join(os.path.dirname(__file__), "nltk_data")
nltk.data.path.insert(0, NLTK_LOCAL)
nltk.download("punkt_tab", download_dir=NLTK_LOCAL, quiet=True)
DEFAULT_LLMS = {
"gpt2": "gpt2",
"distilgpt2": "distilgpt2",
"flan-t5-small": "google/flan-t5-small",
"flan-t5-base": "google/flan-t5-base"
}
class RagEngine:
def __init__(self):
self.bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
self.faiss_idx = None
self.chunks = []
self.embeddings = None
self.pipes = {} # cache for generation pipelines
self.tok_t5 = None # cache for t5 tokenizer
# ---------- 1. scrape ----------
def scrape_bbc_article(self, url: str) -> str:
"""Return plain text of a BBC Culture / Travel article."""
html = requests.get(url, timeout=15).text
soup = BeautifulSoup(html, "html.parser")
# kill scripts / style
for s in soup(["script", "style", "noscript"]):
s.decompose()
        # BBC renders the body inside <article> or a <div data-testid="article-body"> wrapper
article = soup.find("article") or soup.find("div", attrs={"data-testid": "article-body"})
if not article:
raise ValueError("Could not locate article body.")
text = article.get_text(separator="\n")
text = re.sub(r"\n+", "\n", text.strip())
return text
# ---------- 2. chunk ----------
    def chunk_text(self, text: str, chunk_size: int, overlap: int) -> list[str]:
        """Greedy paragraph packer (a rough mimic of a recursive character splitter).

        Splits on newlines and packs whole paragraphs into chunks of at most
        ~`chunk_size` words; the last `overlap` paragraphs of each chunk are
        carried over into the next one.
        """
        paragraphs = text.split("\n")
        chunks, buf = [], []
        buf_len = 0
        for p in paragraphs:
            p_len = len(p.split())
            if buf_len + p_len > chunk_size and buf:
                chunks.append(" ".join(buf))
                # overlap counts trailing paragraphs, not words
                keep = int(overlap)
                buf = buf[-keep:] if keep else []
                buf_len = sum(len(x.split()) for x in buf)
buf.append(p)
buf_len += p_len
if buf:
chunks.append(" ".join(buf))
return chunks
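    # Worked example of the packer above (a sketch, values chosen for
    # illustration): with three 30-word paragraphs A, B, C, chunk_size=50 and
    # overlap=1, this yields ["A", "A B", "B C"] -- each chunk repeats the
    # previous chunk's last paragraph, which is what gives retrieval its
    # overlap context.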
# ---------- 3. embed + faiss ----------
def build_index(self, chunks: list[str]):
self.chunks = chunks
self.embeddings = self.bi_encoder.encode(chunks, normalize_embeddings=True, show_progress_bar=True)
d = self.embeddings.shape[1]
self.faiss_idx = faiss.IndexFlatIP(d) # inner-product = cosine with normed vecs
        self.faiss_idx.add(self.embeddings.astype("float32"))  # FAISS expects float32
# ---------- 4. retrieve + re-rank ----------
    def retrieve(self, query: str, k: int = 5, rerank_top: int = 3):
        k = min(k, len(self.chunks))  # FAISS pads with -1 indices when k > ntotal
        q_emb = self.bi_encoder.encode([query], normalize_embeddings=True)
        scores, idxs = self.faiss_idx.search(q_emb.astype("float32"), k)
hits = [{"chunk": self.chunks[i], "score": float(scores[0][j]), "idx": int(idxs[0][j])}
for j, i in enumerate(idxs[0])]
# re-rank
cross_in = [[query, h["chunk"]] for h in hits]
cross_scores = self.cross_encoder.predict(cross_in)
for h, sc in zip(hits, cross_scores):
h["cross_score"] = float(sc)
hits = sorted(hits, key=lambda x: x["cross_score"], reverse=True)[:rerank_top]
return hits
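    # Return shape (illustrative values only, not real output):
    #   [{"chunk": "...passage text...", "score": 0.61, "idx": 12,
    #     "cross_score": 7.8}, ...]   # sorted by cross_score, len == rerank_top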
# ---------- 5. generation ----------
def generate(self, query: str, retrieved: list[dict], llm_name: str,
max_new_tokens: int = 120, temperature: float = 0.0) -> str:
        # only the single best re-ranked chunk is used as context
        context = "\n".join(h["chunk"] for h in retrieved[:1])
        prompt = (f"Use only the following passages to answer the question. "
                  f"If no single sentence gives a full answer, write what is stated.\n\n"
                  f"Passages:\n{context}\n\nQuestion: {query}\nAnswer:")
model_id = DEFAULT_LLMS.get(llm_name, llm_name)
# ---------- choose task & truncate ----------
if llm_name.startswith("flan-t5"):
task = "text2text-generation"
if self.tok_t5 is None:
self.tok_t5 = AutoTokenizer.from_pretrained(model_id)
tok = self.tok_t5
tokens = tok.tokenize(prompt)
            limit = 512 - max_new_tokens - 5  # conservative: T5's 512-token cap applies to the encoder input alone
if len(tokens) > limit:
prompt = tok.convert_tokens_to_string(tokens[:limit])
else: # gpt2
task = "text-generation"
prompt = self._clip_to_max_len(prompt, model_id, max_new_tokens)
# ---------- cached pipeline ----------
if (task, model_id) not in self.pipes:
self.pipes[(task, model_id)] = pipeline(
task, model=model_id,
device=0 if torch.cuda.is_available() else -1)
generator = self.pipes[(task, model_id)]
        gen_kw = dict(max_new_tokens=max_new_tokens,
                      do_sample=(temperature > 0))
        if temperature > 0:  # only pass temperature when sampling; temperature=None trips generation-config validation
            gen_kw["temperature"] = temperature
if task == "text-generation":
gen_kw["return_full_text"] = False
out = generator(prompt, **gen_kw)[0]["generated_text"]
return out.strip()
# ---------- 6. metrics ----------
    def compute_metrics(self, reference: str, hypothesis: str):
        ref_tok = word_tokenize(reference.lower())
        hyp_tok = word_tokenize(hypothesis.lower())
        # smoothing keeps BLEU from collapsing to 0 on short answers with no 4-gram overlap
        bleu = sentence_bleu([ref_tok], hyp_tok,
                             smoothing_function=SmoothingFunction().method1)
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
rs = scorer.score(reference, hypothesis)
return {
"bleu": bleu,
"rouge-1": rs["rouge1"].fmeasure,
"rouge-2": rs["rouge2"].fmeasure,
"rouge-l": rs["rougeL"].fmeasure,
}
@staticmethod
def _clip_to_max_len(text: str, model_name: str, max_new_tokens: int) -> str:
tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
max_pos = tok.model_max_length # 1024 for gpt2
allowance = max_pos - max_new_tokens - 2
tokens = tok.encode(text)
if len(tokens) > allowance:
tokens = tokens[:allowance]
        return tok.decode(tokens, skip_special_tokens=True)
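

# ---------- usage sketch ----------
# Minimal end-to-end example (not part of the engine). The URL is a
# placeholder and the chunking numbers are arbitrary assumptions; swap in a
# real BBC Culture/Travel article and tune chunk_size/overlap for actual use.
if __name__ == "__main__":
    engine = RagEngine()
    text = engine.scrape_bbc_article("https://www.bbc.com/travel/article/example")  # placeholder URL
    chunks = engine.chunk_text(text, chunk_size=150, overlap=1)
    engine.build_index(chunks)
    question = "What place does the article describe?"
    hits = engine.retrieve(question, k=5, rerank_top=3)
    answer = engine.generate(question, hits, "flan-t5-small")
    print(answer)
    # optional: score against a hand-written reference answer
    print(engine.compute_metrics("your reference answer here", answer))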