| """ | |
| Pure RAG utilities – no Gradio code here. | |
| """ | |
| import re, requests, numpy as np, faiss, nltk, torch | |
| from bs4 import BeautifulSoup | |
| from sentence_transformers import SentenceTransformer, CrossEncoder | |
| from transformers import AutoTokenizer, pipeline, set_seed | |
| from sklearn.metrics.pairwise import cosine_similarity | |
| from nltk.translate.bleu_score import sentence_bleu | |
| from nltk.tokenize import word_tokenize | |
| from rouge_score import rouge_scorer | |
| import nltk | |
| import os | |
| nltk.download('punkt_tab') | |
| # Force NLTK to use the bundled data | |
| NLTK_LOCAL = os.path.join(os.path.dirname(__file__), "nltk_data") | |
| nltk.data.path.insert(0, NLTK_LOCAL) | |

DEFAULT_LLMS = {
    "gpt2": "gpt2",
    "distilgpt2": "distilgpt2",
    "flan-t5-small": "google/flan-t5-small",
    "flan-t5-base": "google/flan-t5-base",
}


class RagEngine:
    def __init__(self):
        self.bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
        self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        self.faiss_idx = None
        self.chunks = []
        self.embeddings = None
        self.pipes = {}      # cache for generation pipelines
        self.tok_t5 = None   # cache for the flan-t5 tokenizer

    # ---------- 1. scrape ----------
    def scrape_bbc_article(self, url: str) -> str:
        """Return the plain text of a BBC Culture / Travel article."""
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()
        soup = BeautifulSoup(resp.text, "html.parser")
        # drop scripts / styles
        for s in soup(["script", "style", "noscript"]):
            s.decompose()
        # BBC puts the article body in <article> or <div data-testid="article-body">
        article = soup.find("article") or soup.find("div", attrs={"data-testid": "article-body"})
        if not article:
            raise ValueError("Could not locate article body.")
        text = article.get_text(separator="\n")
        text = re.sub(r"\n+", "\n", text.strip())
        return text

    # ---------- 2. chunk ----------
    def chunk_text(self, text: str, chunk_size: int, overlap: int) -> list[str]:
        """Greedy paragraph packer that mimics a recursive character splitter.

        Both `chunk_size` and `overlap` are counted in whitespace-separated words.
        """
        paragraphs = text.split("\n")
        chunks, buf = [], []
        buf_len = 0
        for p in paragraphs:
            p_len = len(p.split())
            if buf_len + p_len > chunk_size and buf:
                chunks.append(" ".join(buf))
                # carry the last `overlap` words over into the next chunk
                carry = " ".join(buf).split()[-int(overlap):] if overlap else []
                buf = [" ".join(carry)] if carry else []
                buf_len = len(carry)
            buf.append(p)
            buf_len += p_len
        if buf:
            chunks.append(" ".join(buf))
        return chunks

    # ---------- 3. embed + faiss ----------
    def build_index(self, chunks: list[str]):
        self.chunks = chunks
        self.embeddings = self.bi_encoder.encode(chunks, normalize_embeddings=True, show_progress_bar=True)
        d = self.embeddings.shape[1]
        self.faiss_idx = faiss.IndexFlatIP(d)  # inner product == cosine for normalised vectors
        self.faiss_idx.add(self.embeddings.astype("float32"))

    # ---------- 4. retrieve + re-rank ----------
    def retrieve(self, query: str, k: int = 5, rerank_top: int = 3):
        q_emb = self.bi_encoder.encode([query], normalize_embeddings=True)
        scores, idxs = self.faiss_idx.search(q_emb.astype("float32"), k)
        hits = [{"chunk": self.chunks[i], "score": float(scores[0][j]), "idx": int(idxs[0][j])}
                for j, i in enumerate(idxs[0])]
        # re-rank the FAISS hits with the cross-encoder
        cross_in = [[query, h["chunk"]] for h in hits]
        cross_scores = self.cross_encoder.predict(cross_in)
        for h, sc in zip(hits, cross_scores):
            h["cross_score"] = float(sc)
        hits = sorted(hits, key=lambda x: x["cross_score"], reverse=True)[:rerank_top]
        return hits

    # ---------- 5. generation ----------
    def generate(self, query: str, retrieved: list[dict], llm_name: str,
                 max_new_tokens: int = 120, temperature: float = 0.0) -> str:
        context = "\n".join(h["chunk"] for h in retrieved[:1])
        # stricter prompt variant kept for reference:
        # prompt = (f"Use only the following passages to answer the question. "
        #           f"If the answer cannot be found, say 'I don't have enough information'.\n\n"
        #           f"Passages:\n{context}\n\nQuestion: {query}\nAnswer:")
        prompt = (f"Use only the following passages to answer the question. "
                  f"If no single sentence gives a full answer, write what is stated.\n\n"
                  f"Passages:\n{context}\n\nQuestion: {query}\nAnswer:")
        model_id = DEFAULT_LLMS.get(llm_name, llm_name)

        # ---------- choose task & truncate ----------
        if llm_name.startswith("flan-t5"):
            task = "text2text-generation"
            if self.tok_t5 is None:
                self.tok_t5 = AutoTokenizer.from_pretrained(model_id)
            tok = self.tok_t5
            tokens = tok.tokenize(prompt)
            limit = 512 - max_new_tokens - 5
            if len(tokens) > limit:
                prompt = tok.convert_tokens_to_string(tokens[:limit])
        else:  # gpt2-style causal models
            task = "text-generation"
            prompt = self._clip_to_max_len(prompt, model_id, max_new_tokens)

        # ---------- cached pipeline ----------
        if (task, model_id) not in self.pipes:
            self.pipes[(task, model_id)] = pipeline(
                task, model=model_id,
                device=0 if torch.cuda.is_available() else -1)
        generator = self.pipes[(task, model_id)]

        gen_kw = dict(max_new_tokens=max_new_tokens,
                      do_sample=(temperature > 0))
        if temperature > 0:
            gen_kw["temperature"] = temperature  # only pass temperature when sampling
        if task == "text-generation":
            gen_kw["return_full_text"] = False
        out = generator(prompt, **gen_kw)[0]["generated_text"]
        return out.strip()

    # ---------- 6. metrics ----------
    def compute_metrics(self, reference: str, hypothesis: str):
        ref_tok = word_tokenize(reference.lower())
        hyp_tok = word_tokenize(hypothesis.lower())
        bleu = sentence_bleu([ref_tok], hyp_tok)
        scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
        rs = scorer.score(reference, hypothesis)
        return {
            "bleu": bleu,
            "rouge-1": rs["rouge1"].fmeasure,
            "rouge-2": rs["rouge2"].fmeasure,
            "rouge-l": rs["rougeL"].fmeasure,
        }

    def _clip_to_max_len(self, text: str, model_name: str, max_new_tokens: int) -> str:
        """Trim `text` so that prompt + generation fit inside the model's context window."""
        tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        max_pos = tok.model_max_length          # 1024 for gpt2
        allowance = max_pos - max_new_tokens - 2
        tokens = tok.encode(text)
        if len(tokens) > allowance:
            tokens = tokens[:allowance]
        return tok.decode(tokens, skip_special_tokens=True)
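

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# The URL below is a placeholder – substitute any BBC Culture / Travel
# article; the chunk_size / overlap / model choices are example values.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    engine = RagEngine()
    text = engine.scrape_bbc_article("https://www.bbc.com/culture/article/example")  # placeholder URL
    chunks = engine.chunk_text(text, chunk_size=200, overlap=30)
    engine.build_index(chunks)

    question = "What is the article about?"
    hits = engine.retrieve(question, k=5, rerank_top=3)
    answer = engine.generate(question, hits, llm_name="flan-t5-small")
    print(answer)

    # Optional sanity check: score the answer against the top retrieved chunk.
    print(engine.compute_metrics(reference=hits[0]["chunk"], hypothesis=answer))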