"""
Pure RAG utilities – no Gradio code here.
"""
import re, requests, numpy as np, faiss, nltk, torch
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer, CrossEncoder
from transformers import AutoTokenizer, pipeline, set_seed
from sklearn.metrics.pairwise import cosine_similarity
from nltk.translate.bleu_score import sentence_bleu
from nltk.tokenize import word_tokenize
from rouge_score import rouge_scorer
import nltk
import os
nltk.download('punkt_tab')
# Force NLTK to use the bundled data
NLTK_LOCAL = os.path.join(os.path.dirname(__file__), "nltk_data")
nltk.data.path.insert(0, NLTK_LOCAL)

DEFAULT_LLMS = {
    "gpt2": "gpt2",
    "distilgpt2": "distilgpt2",
    "flan-t5-small": "google/flan-t5-small",
    "flan-t5-base": "google/flan-t5-base"
}

class RagEngine:
    def __init__(self):
        self.bi_encoder = SentenceTransformer("all-MiniLM-L6-v2")
        self.cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
        self.faiss_idx = None
        self.chunks = []
        self.embeddings = None
        self.pipes = {}  # cache for generation pipelines
        self.tok_t5 = None  # cache for t5 tokenizer

    # ---------- 1. scrape ----------
    def scrape_bbc_article(self, url: str) -> str:
        """Return the plain text of a BBC Culture / Travel article."""
        resp = requests.get(url, timeout=15)
        resp.raise_for_status()  # fail loudly on 4xx/5xx instead of parsing an error page
        soup = BeautifulSoup(resp.text, "html.parser")
        # drop scripts / styles so they don't leak into the extracted text
        for s in soup(["script", "style", "noscript"]):
            s.decompose()
        # BBC puts the article body in <article> or <div data-testid="article-body">
        article = soup.find("article") or soup.find("div", attrs={"data-testid": "article-body"})
        if not article:
            raise ValueError("Could not locate article body.")
        text = article.get_text(separator="\n")
        return re.sub(r"\n+", "\n", text.strip())

    # ---------- 2. chunk ----------
    def chunk_text(self, text: str, chunk_size: int, overlap: int) -> list[str]:
        """Greedy paragraph packer; `chunk_size` and `overlap` are word counts."""
        paragraphs = text.split("\n")
        chunks, buf = [], []
        buf_len = 0
        for p in paragraphs:
            p_len = len(p.split())
            if buf_len + p_len > chunk_size and buf:
                chunks.append(" ".join(buf))
                # carry the last `overlap` *words* (not whole paragraphs)
                # into the next chunk
                if overlap:
                    buf = [" ".join(" ".join(buf).split()[-int(overlap):])]
                else:
                    buf = []
                buf_len = sum(len(x.split()) for x in buf)
            buf.append(p)
            buf_len += p_len
        if buf:
            chunks.append(" ".join(buf))
        return chunks
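
    # Worked example (hypothetical input), with chunk_size=5 and overlap=2:
    #   chunk_text("a b c d e\nf g h i j", 5, 2)
    #   -> ["a b c d e", "d e f g h i j"]
    # The last `overlap` words of each finished chunk seed the next chunk.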

    # ---------- 3. embed + faiss ----------
    def build_index(self, chunks: list[str]):
        self.chunks = chunks
        self.embeddings = self.bi_encoder.encode(chunks, normalize_embeddings=True, show_progress_bar=True)
        d = self.embeddings.shape[1]
        self.faiss_idx = faiss.IndexFlatIP(d)  # inner product == cosine for normalized vectors
        self.faiss_idx.add(np.asarray(self.embeddings, dtype="float32"))

    # ---------- 4. retrieve + re-rank ----------
    def retrieve(self, query: str, k: int = 5, rerank_top: int = 3):
        q_emb = self.bi_encoder.encode([query], normalize_embeddings=True)
        scores, idxs = self.faiss_idx.search(np.asarray(q_emb, dtype="float32"), k)
        # faiss pads with -1 when fewer than k vectors are indexed; skip those slots
        hits = [{"chunk": self.chunks[i], "score": float(scores[0][j]), "idx": int(i)}
                for j, i in enumerate(idxs[0]) if i != -1]
        # re-rank with the cross-encoder, which scores (query, chunk) pairs jointly
        cross_in = [[query, h["chunk"]] for h in hits]
        cross_scores = self.cross_encoder.predict(cross_in)
        for h, sc in zip(hits, cross_scores):
            h["cross_score"] = float(sc)
        return sorted(hits, key=lambda x: x["cross_score"], reverse=True)[:rerank_top]

    # ---------- 5. generation ----------
    def generate(self, query: str, retrieved: list[dict], llm_name: str,
                 max_new_tokens: int = 120, temperature: float = 0.0) -> str:
        # only the top re-ranked chunk is used as context, to keep the prompt
        # short enough for the small models' context windows
        context = "\n".join(h["chunk"] for h in retrieved[:1])
        prompt = (f"Use only the following passages to answer the question. "
                  f"If no single sentence gives a full answer, write what is stated.\n\n"
                  f"Passages:\n{context}\n\nQuestion: {query}\nAnswer:")

        model_id = DEFAULT_LLMS.get(llm_name, llm_name)

        # ---------- choose task & truncate ----------
        if "flan-t5" in llm_name:  # matches both the short alias and a full HF id
            task = "text2text-generation"
            if self.tok_t5 is None:
                self.tok_t5 = AutoTokenizer.from_pretrained(model_id)
            tok = self.tok_t5
            tokens = tok.tokenize(prompt)
            limit = 512 - max_new_tokens - 5   # stay well inside T5's 512-token window
            if len(tokens) > limit:
                prompt = tok.convert_tokens_to_string(tokens[:limit])
        else:                                  # decoder-only models (gpt2, distilgpt2)
            task = "text-generation"
            prompt = self._clip_to_max_len(prompt, model_id, max_new_tokens)

        # ---------- cached pipeline ----------
        if (task, model_id) not in self.pipes:
            self.pipes[(task, model_id)] = pipeline(
                task, model=model_id,
                device=0 if torch.cuda.is_available() else -1)

        generator = self.pipes[(task, model_id)]
        gen_kw = dict(max_new_tokens=max_new_tokens, do_sample=temperature > 0)
        if temperature > 0:
            gen_kw["temperature"] = temperature  # only meaningful when sampling
        if task == "text-generation":
            gen_kw["return_full_text"] = False   # strip the prompt from the output

        out = generator(prompt, **gen_kw)[0]["generated_text"]
        return out.strip()
    
    # ---------- 6. metrics ----------
    def compute_metrics(self, reference: str, hypothesis: str):
        ref_tok = word_tokenize(reference.lower())
        hyp_tok = word_tokenize(hypothesis.lower())
        # smoothing keeps BLEU from collapsing to 0 when short answers
        # have no 4-gram overlap with the reference
        bleu = sentence_bleu([ref_tok], hyp_tok,
                             smoothing_function=SmoothingFunction().method1)

        scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
        rs = scorer.score(reference, hypothesis)
        return {
            "bleu": bleu,
            "rouge-1": rs["rouge1"].fmeasure,
            "rouge-2": rs["rouge2"].fmeasure,
            "rouge-l": rs["rougeL"].fmeasure,
        }

    # tokenizer cache for _clip_to_max_len (from_pretrained is slow to call per request)
    _tok_cache: dict = {}

    @staticmethod
    def _clip_to_max_len(text: str, model_name: str, max_new_tokens: int) -> str:
        """Truncate `text` so that prompt + generation fit the context window."""
        tok = RagEngine._tok_cache.get(model_name)
        if tok is None:
            tok = AutoTokenizer.from_pretrained(model_name, use_fast=True)
            if tok.pad_token is None:
                tok.pad_token = tok.eos_token
            RagEngine._tok_cache[model_name] = tok
        max_pos = tok.model_max_length            # 1024 for gpt2
        allowance = max_pos - max_new_tokens - 2  # leave room for the generation
        tokens = tok.encode(text)
        if len(tokens) > allowance:
            tokens = tokens[:allowance]
        return tok.decode(tokens, skip_special_tokens=True)
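
# ---------- usage sketch ----------
# Minimal end-to-end example of the pipeline above, assuming network access
# for the model downloads; the article URL is a placeholder, not a real link.
if __name__ == "__main__":
    engine = RagEngine()
    text = engine.scrape_bbc_article(
        "https://www.bbc.com/travel/article/example")   # placeholder URL
    chunks = engine.chunk_text(text, chunk_size=150, overlap=30)
    engine.build_index(chunks)

    question = "What is the main subject of the article?"
    hits = engine.retrieve(question, k=5, rerank_top=3)
    print(engine.generate(question, hits, "flan-t5-small"))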