File size: 11,341 Bytes
b1748d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6506759
b1748d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import os, io, json, math, pickle, textwrap, shutil, re
from typing import List, Dict, Any, Tuple
import numpy as np, faiss, fitz  # pymupdf
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer
import gradio as gr
from groq import Groq

# ---------- Config ----------
EMBED_MODEL_NAME = "intfloat/multilingual-e5-small"  # E5 family: expects "query:"/"passage:" prefixes
CHUNK_SIZE = 1200       # max characters per chunk
CHUNK_OVERLAP = 200     # characters shared between consecutive chunks
TOP_K_DEFAULT = 5       # default number of passages retrieved per question
MAX_CONTEXT_CHARS = 12000  # character budget for context packed into the LLM prompt

# On-disk persistence for the built index.
INDEX_PATH = "rag_index.faiss"
STORE_PATH = "rag_store.pkl"

# Groq chat models offered in the UI dropdown.
MODEL_CHOICES = [
    "llama-3.3-70b-versatile",
    "llama-3.1-8b-instant",
    "mixtral-8x7b-32768",
]

device = "cuda" if torch.cuda.is_available() else "cpu"  # embedder device
# Lazily-initialized module-level state shared by the handlers below.
embedder = None
faiss_index = None
docstore: List[Dict[str, Any]] = []

# ---------- PDF utils ----------
def extract_text_from_pdf(pdf_path: str) -> List[Tuple[int, str]]:
    """Return a list of (1-based page number, text) pairs for a PDF.

    Uses plain-text extraction first; when a page yields nothing, falls
    back to block extraction and joins the block text fields.
    """
    result: List[Tuple[int, str]] = []
    with fitz.open(pdf_path) as doc:
        page_no = 0
        for page in doc:
            page_no += 1
            text = page.get_text("text") or ""
            if not text.strip():
                blocks = page.get_text("blocks")
                if isinstance(blocks, list):
                    pieces = [
                        b[4]
                        for b in blocks
                        if isinstance(b, (list, tuple)) and len(b) > 4
                    ]
                    text = "\n".join(pieces)
            result.append((page_no, text or ""))
    return result

def chunk_text(text: str, chunk_size=CHUNK_SIZE, overlap=CHUNK_OVERLAP) -> List[str]:
    """Split text into overlapping windows of at most chunk_size characters.

    NUL bytes are replaced before chunking. Consecutive windows advance by
    chunk_size - overlap characters (always at least 1, so short inputs
    cannot loop forever). Empty input yields an empty list.
    """
    cleaned = text.replace("\x00", " ").strip()
    if not cleaned:
        return []
    if len(cleaned) <= chunk_size:
        return [cleaned]
    chunks: List[str] = []
    pos = 0
    total = len(cleaned)
    while pos < total:
        chunks.append(cleaned[pos:pos + chunk_size])
        # Step forward with overlap, but always make progress.
        pos = max(pos + chunk_size - overlap, pos + 1)
    return chunks

# ---------- Embeddings / FAISS ----------
def load_embedder():
    """Return the module-wide SentenceTransformer, creating it on first use."""
    global embedder
    if embedder is not None:
        return embedder
    embedder = SentenceTransformer(EMBED_MODEL_NAME, device=device)
    return embedder

def _normalize(vecs: np.ndarray) -> np.ndarray:
    norms = np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12
    return (vecs / norms).astype("float32")

def embed_passages(texts: List[str]) -> np.ndarray:
    """Embed document chunks with the E5 'passage:' prefix; rows are unit-norm."""
    model = load_embedder()
    prefixed = [f"passage: {t}" for t in texts]
    raw = model.encode(prefixed, batch_size=64, show_progress_bar=False, convert_to_numpy=True)
    return _normalize(raw)

def embed_query(q: str) -> np.ndarray:
    """Embed a search query with the E5 'query:' prefix; row is unit-norm."""
    model = load_embedder()
    prefixed = f"query: {q}"
    raw = model.encode([prefixed], convert_to_numpy=True)
    return _normalize(raw)

def build_faiss(embs: np.ndarray):
    """Build a flat inner-product index (cosine similarity on unit-norm rows)."""
    dim = embs.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embs)
    return index

def save_index(index, store_list: List[Dict[str, Any]]):
    """Persist the FAISS index and the chunk metadata (docstore) to disk."""
    faiss.write_index(index, INDEX_PATH)
    payload = {"docstore": store_list, "embed_model": EMBED_MODEL_NAME}
    with open(STORE_PATH, "wb") as f:
        pickle.dump(payload, f)

def load_index() -> bool:
    """Restore the saved index and docstore from disk.

    Returns True on success, False when either file is missing. Also warms
    the embedder so the first query does not pay model-load latency.
    """
    global faiss_index, docstore
    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return False
    faiss_index = faiss.read_index(INDEX_PATH)
    with open(STORE_PATH, "rb") as f:
        saved = pickle.load(f)
    docstore = saved["docstore"]
    load_embedder()
    return True

# ---------- Ingest ----------
def ingest_pdfs(paths: List[str]) -> Tuple[Any, List[Dict[str, Any]]]:
    """Parse, chunk, and embed PDFs; return (FAISS index, chunk metadata).

    Files that fail to parse are skipped with a warning rather than
    aborting the whole build. Raises RuntimeError when no text at all
    could be extracted (e.g. image-only scans).
    """
    entries: List[Dict[str, Any]] = []
    for pdf in tqdm(paths, total=len(paths), desc="Parsing PDFs"):
        try:
            base = os.path.basename(pdf)
            for pno, ptxt in extract_text_from_pdf(pdf):
                if not ptxt.strip():
                    continue  # page had no extractable text
                for ci, chunk in enumerate(chunk_text(ptxt)):
                    entries.append({
                        "text": chunk,
                        "source": base,
                        "page_start": pno,
                        "page_end": pno,
                        "chunk_id": f"{base}::p{pno}::c{ci}",
                    })
        except Exception as e:
            print(f"[WARN] Failed to parse {pdf}: {e}")
    if not entries:
        raise RuntimeError("No text extracted. If PDFs are scanned images, run OCR before indexing.")
    embs = embed_passages([e["text"] for e in entries])
    return build_faiss(embs), entries

# ---------- Retrieval (supports required keywords) ----------
def retrieve(query: str, top_k=5, must_contain: str = ""):
    """Vector-search the docstore, optionally filtering by required keywords.

    Over-fetches a candidate pool (10*k, at least 200, capped at the store
    size) so keyword filtering still has material to keep. Keywords are
    comma-separated and matched case-insensitively; all must appear. If the
    filter would empty the results, it is ignored so something is returned.
    """
    global faiss_index, docstore
    if faiss_index is None or not docstore:
        raise RuntimeError("Index not built or loaded. Use 'Build Index' or 'Reload Saved Index' first.")
    k = int(top_k) if top_k else TOP_K_DEFAULT

    pool_size = min(max(10 * k, 200), len(docstore))
    scores, ids = faiss_index.search(embed_query(query), pool_size)
    candidates = [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0]

    keywords = [w.strip().lower() for w in must_contain.split(",") if w.strip()]
    if keywords:
        kept = [
            (idx, sc)
            for idx, sc in candidates
            if all(word in docstore[idx]["text"].lower() for word in keywords)
        ]
        if kept:
            candidates = kept

    results = []
    for idx, sc in candidates[:k]:
        entry = dict(docstore[idx])
        entry["score"] = sc
        results.append(entry)
    return results

# ---------- Groq LLM ----------
def groq_answer(query: str, contexts, model_name="llama-3.3-70b-versatile", temperature=0.2, max_tokens=1000):
    """Answer `query` with a Groq chat model grounded only in `contexts`.

    Packs whole context chunks (tagged "[source p.N]") greedily until the
    MAX_CONTEXT_CHARS budget is reached, then prompts the model to answer
    from those snippets alone with a References section. Returns the model
    text, or a human-readable error string (never raises).

    Fix: the default model now matches MODEL_CHOICES[0]; the previous
    default "llama-3.1-70b-versatile" was decommissioned by Groq, so every
    call relying on the default failed.
    """
    try:
        if not os.environ.get("GROQ_API_KEY"):
            return "GROQ_API_KEY is not set. Add it in your host's environment/secrets."
        client = Groq(api_key=os.environ["GROQ_API_KEY"])

        # Greedily pack whole chunks until the character budget is exhausted.
        packed, used = [], 0
        for c in contexts:
            tag = f"[{c['source']} p.{c['page_start']}]"
            piece = f"{tag}\n{c['text'].strip()}\n"
            if used + len(piece) > MAX_CONTEXT_CHARS:
                break
            packed.append(piece)
            used += len(piece)
        context_str = "\n---\n".join(packed)

        system_prompt = (
            "You are a scholarly assistant. Answer using ONLY the provided context. "
            "If the answer is not present, say so. Always include a 'References' section with sources and page numbers."
        )
        user_prompt = (
            f"Question:\n{query}\n\n"
            f"Context snippets (use these only):\n{context_str}\n\n"
            "Write a precise answer. Keep claims traceable to the snippets."
        )

        resp = client.chat.completions.create(
            model=model_name,
            temperature=float(temperature),
            max_tokens=int(max_tokens),
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        return resp.choices[0].message.content.strip()
    except Exception as e:
        import traceback
        return f"Groq API error: {e}\n```\n{traceback.format_exc()}\n```"

# ---------- Helpers for UI ----------
def build_index_from_uploads(paths: List[str]) -> str:
    """Build, persist, and activate an index from uploaded PDF paths.

    Returns a status message for the UI; rejects empty uploads and
    over-large batches up front.
    """
    global faiss_index, docstore
    if not paths:
        return "Please upload at least one PDF."
    if len(paths) > 120:
        return "Please limit to ~100 PDFs per build."

    index, entries = ingest_pdfs(paths)
    faiss_index = index
    save_index(index, entries)
    docstore = entries
    return f"Index built with {len(entries)} chunks from {len(paths)} PDFs. Saved to disk."

def reload_index() -> str:
    """Reload a previously saved index from disk and report the outcome."""
    if load_index():
        return f"Index reloaded. Chunks: {len(docstore)}"
    return "No saved index found."

def ask_rag(query: str, top_k, model_name: str, temperature: float, must_contain: str):
    """Gradio handler: retrieve context, query Groq, and build the sources table.

    Returns (answer markdown, rows of [source, page, score, snippet]).
    Errors are rendered into the answer pane rather than raised.
    """
    try:
        if not query.strip():
            return "Please enter a question.", []
        k = int(top_k) if top_k else TOP_K_DEFAULT
        contexts = retrieve(query, top_k=k, must_contain=must_contain)
        answer = groq_answer(query, contexts, model_name=model_name, temperature=temperature)
        table = []
        for c in contexts:
            snippet = c["text"][:200].replace("\n", " ")
            if len(c["text"]) > 200:
                snippet += "..."
            table.append([c["source"], str(c["page_start"]), f"{c['score']:.3f}", snippet])
        return answer, table
    except Exception as e:
        import traceback
        return f"**Error:** {e}\n```\n{traceback.format_exc()}\n```", []

def set_api_key(k: str):
    """Store the Groq API key in this process's environment for the session."""
    key = k.strip() if k else ""
    if not key:
        return "No key provided."
    os.environ["GROQ_API_KEY"] = key
    return "API key set in runtime."

def download_index_zip():
    """Bundle the persisted index files into one zip for the UI to serve.

    Returns the zip path, or None when no saved index exists.

    Fixes two defects: `shutil.make_archive` returns a path *string*, so
    the old `with shutil.make_archive(...):` raised AttributeError on every
    call; and the first make_archive call zipped the entire working
    directory (including, potentially, the output zip itself) instead of
    just the two index files.
    """
    if not (os.path.exists(INDEX_PATH) and os.path.exists(STORE_PATH)):
        return None
    import zipfile
    zip_path = "rag_index_bundle.zip"
    with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as z:
        z.write(INDEX_PATH)
        z.write(STORE_PATH)
    return zip_path

# ---------- Gradio UI ----------
with gr.Blocks(title="RAG over PDFs (Groq)") as demo:
    gr.Markdown("## RAG over your PDFs using Groq\nUpload PDFs, build an index, then ask questions with cited answers.")
    # Session API key entry: set_api_key writes it into os.environ so
    # groq_answer can pick it up without restarting the app.
    with gr.Row():
        api_box = gr.Textbox(label="(Optional) Set GROQ_API_KEY for this session", type="password", placeholder="sk_...") 
        set_btn = gr.Button("Set Key")
        set_out = gr.Markdown()
        set_btn.click(set_api_key, inputs=[api_box], outputs=[set_out])

    # Tab 1: index management — build from uploads, reload from disk, export zip.
    with gr.Tab("1) Build or Load Index"):
        file_u = gr.Files(label="Upload PDFs", file_types=[".pdf"], type="filepath")
        with gr.Row():
            build_btn = gr.Button("Build Index")
            reload_btn = gr.Button("Reload Saved Index")
            download_btn = gr.Button("Download Index (.zip)")
        build_out = gr.Markdown()

        # Wrapper so indexing errors render in the UI instead of the console;
        # track_tqdm mirrors the ingest tqdm bar in the browser progress widget.
        def on_build(paths, progress=gr.Progress(track_tqdm=True)):
            try:
                return build_index_from_uploads(paths)
            except Exception as e:
                import traceback
                return f"**Error while building index:** {e}\n\n```\n{traceback.format_exc()}\n```"

        build_btn.click(on_build, inputs=[file_u], outputs=[build_out])
        reload_btn.click(fn=reload_index, outputs=[build_out])
        zpath = gr.File(label="Index zip", interactive=False)
        download_btn.click(fn=download_index_zip, outputs=[zpath])

    # Tab 2: question answering over the indexed chunks, with cited sources.
    with gr.Tab("2) Ask Questions"):
        q = gr.Textbox(label="Your question", lines=2, placeholder="Ask something present in the uploaded papers…")
        with gr.Row():
            topk = gr.Slider(1, 15, value=TOP_K_DEFAULT, step=1, label="Top-K passages")
            model_dd = gr.Dropdown(MODEL_CHOICES, value=MODEL_CHOICES[0], label="Groq model")
            temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="Temperature")
        must = gr.Textbox(label="Must contain (comma-separated keywords)", placeholder="camera, CMOS, frame rate")
        ask_btn = gr.Button("Answer")
        ans = gr.Markdown()
        src = gr.Dataframe(headers=["Source","Page","Score","Snippet"], wrap=True)
        ask_btn.click(ask_rag, inputs=[q, topk, model_dd, temp, must], outputs=[ans, src])

demo.queue()  # keep it simple for broad Gradio versions
if __name__ == "__main__":
    # 0.0.0.0 so the app is reachable inside containers; PORT env var overrides 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))