|
|
import os |
|
|
import io |
|
|
import gradio as gr |
|
|
import faiss |
|
|
import numpy as np |
|
|
from pypdf import PdfReader |
|
|
from docx import Document |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from transformers import pipeline |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Model choices: small checkpoints suitable for CPU-only / free-tier hosting
# (the UI's "free tier" notice below suggests that is the deployment target).
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

GEN_MODEL_NAME = "MBZUAI/LaMini-Flan-T5-248M"

# Loaded eagerly at import time so the first request doesn't pay the
# download/initialization cost; weights are cached by huggingface_hub.
embedder = SentenceTransformer(EMBED_MODEL_NAME)

generator = pipeline("text2text-generation", model=GEN_MODEL_NAME)
|
|
|
|
|
|
|
|
def read_pdf_from_path_or_bytes(file_obj_or_path):
    """Extract text from a PDF given a filesystem path or an upload object.

    Accepts either a path string, an object with a ``.path`` attribute,
    a file-like object with ``.read()``, or an object carrying raw
    ``.bytes``. Returns the concatenated page text, or "" when nothing
    readable is found.
    """

    def _pages_text(reader):
        # Pages with no extractable text contribute an empty string.
        return "\n".join((page.extract_text() or "") for page in reader.pages)

    # First preference: a real on-disk path.
    candidate = getattr(file_obj_or_path, "path", None)
    if isinstance(file_obj_or_path, str) and os.path.exists(file_obj_or_path):
        candidate = file_obj_or_path
    if candidate and os.path.exists(candidate):
        return _pages_text(PdfReader(candidate))

    # Fallback: pull raw bytes out of the upload object.
    raw = None
    if hasattr(file_obj_or_path, "read"):
        raw = file_obj_or_path.read()
    elif hasattr(file_obj_or_path, "bytes"):
        raw = file_obj_or_path.bytes
    if raw:
        return _pages_text(PdfReader(io.BytesIO(raw)))

    return ""
|
|
|
|
|
|
|
|
def read_docx_text(path):
    """Return all paragraph text from the .docx at *path*, newline-joined."""
    paragraphs = Document(path).paragraphs
    return "\n".join(paragraph.text for paragraph in paragraphs)
|
|
|
|
|
|
|
|
def load_files_to_texts(files):
    """Extract raw text from mixed uploads (.pdf, .docx, .txt).

    Each item in *files* may be a plain filesystem path (str) — which is
    what modern gradio ``gr.File`` delivers — or an upload object exposing
    ``.path``/``.name`` and/or ``.read()``/``.bytes``.

    Returns
    -------
    list[str]
        One raw-text string per recognized file, in input order.
        Files with unrecognized extensions are skipped.
    """
    import tempfile  # only needed for in-memory .docx uploads

    texts = []
    for f in files or []:
        # Resolve a filesystem path. A plain string IS the path (the
        # original getattr-only probe left it None for str inputs, which
        # broke the .txt and .docx branches); upload objects expose
        # .path (gradio FileData) or .name (tempfile-backed wrappers).
        if isinstance(f, str):
            path = f
        else:
            path = getattr(f, "path", None) or getattr(f, "name", None)
        name = (path or str(f)).lower()

        if name.endswith(".pdf"):
            # The PDF reader handles both paths and byte-carrying objects.
            texts.append(read_pdf_from_path_or_bytes(f if path is None else path))

        elif name.endswith(".docx"):
            if path:
                texts.append(read_docx_text(path))
            else:
                # In-memory upload: python-docx needs a real file, so
                # spill the bytes to a temporary .docx.
                data = f.read() if hasattr(f, "read") else getattr(f, "bytes", b"")
                with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as tf:
                    tf.write(data)
                    tmp_path = tf.name
                try:
                    texts.append(read_docx_text(tmp_path))
                finally:
                    # Always remove the temp file, even if parsing fails.
                    os.unlink(tmp_path)

        elif name.endswith(".txt"):
            if path and os.path.exists(path):
                with open(path, "r", errors="ignore") as fh:
                    texts.append(fh.read())
            elif hasattr(f, "read"):
                raw = f.read()
                # .read() may yield bytes or already-decoded text.
                if isinstance(raw, (bytes, bytearray)):
                    raw = raw.decode("utf-8", errors="ignore")
                texts.append(raw)
            else:
                texts.append("")

        # Unknown extensions are silently skipped.
    return texts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def chunk_text(text, chunk_size=600, overlap=120):
    """Split *text* into overlapping word-window chunks.

    Parameters
    ----------
    text : str
        Raw document text; split on whitespace.
    chunk_size : int
        Maximum number of words per chunk. Must be positive.
    overlap : int
        Words shared between consecutive chunks. Must be < chunk_size.

    Returns
    -------
    list[str]
        Space-joined word windows; [] for empty/whitespace-only text.

    Raises
    ------
    ValueError
        If the step (chunk_size - overlap) is not positive — the original
        while-loop would spin forever in that case.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    step = chunk_size - overlap
    if step <= 0:
        raise ValueError("overlap must be smaller than chunk_size")
    words = text.split()
    # Same windows as the original while-loop: start indices 0, step, 2*step, ...
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), step)]
|
|
|
|
|
|
|
|
|
|
|
# Module-level retrieval state, (re)built by build_index() and wiped by reset_app().
index = None  # FAISS inner-product index over normalized chunk embeddings, or None
corpus_chunks = []  # list[str] of text chunks; positions align with index row ids
|
|
|
|
|
def build_index(files, progress=gr.Progress()):
    """(Re)build the global FAISS index from uploaded files.

    Populates the module-level ``index`` and ``corpus_chunks`` and returns
    a ``(status_message, chunk_count)`` tuple for the UI. Any failure is
    caught and reported as a status string rather than raised.
    """
    global index, corpus_chunks
    try:
        raw_texts = load_files_to_texts(files)
        corpus_chunks = []
        for text in raw_texts:
            if text and text.strip():
                corpus_chunks.extend(chunk_text(text))

        if not corpus_chunks:
            return "No text extracted from files.", 0

        progress(0.3, desc="Embedding chunks…")
        embeddings = embedder.encode(corpus_chunks, convert_to_numpy=True, show_progress_bar=False)
        dim = embeddings.shape[1]

        # L2-normalize so inner product behaves as cosine similarity
        # (epsilon guards against division by zero on all-zero rows).
        row_norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-10
        embeddings = embeddings / row_norms

        progress(0.6, desc="Creating FAISS index…")
        index = faiss.IndexFlatIP(dim)
        index.add(embeddings.astype(np.float32))

        return f"Indexed {len(corpus_chunks)} chunks.", len(corpus_chunks)
    except Exception as e:
        # UI boundary: surface the error as a status message instead of crashing.
        return f"Build failed: {e}", 0
|
|
|
|
|
|
|
|
|
|
|
def answer_question(question, top_k=5, max_new_tokens=256, progress=gr.Progress()):
    """Retrieve the top-k chunks for *question* and generate a grounded answer.

    Returns a markdown string: either the model's answer or a prompt to
    build the index first.
    """
    if index is None or not corpus_chunks:
        return "Index not built yet. Upload PDFs and click **Build Index** first."

    # Embed and L2-normalize the query the same way the corpus was embedded,
    # so FAISS inner-product scores are cosine similarities.
    query_vec = embedder.encode([question], convert_to_numpy=True)
    query_vec = query_vec / (np.linalg.norm(query_vec, axis=1, keepdims=True) + 1e-10)

    scores, hit_ids = index.search(query_vec.astype(np.float32), int(top_k))
    # Guard against out-of-range ids (e.g. FAISS padding when k > corpus size).
    passages = [corpus_chunks[hit] for hit in hit_ids[0] if hit < len(corpus_chunks)]

    context = "\n\n".join(passages)
    prompt = (
        "You are a helpful study assistant. Using ONLY the context, answer the question.\n"
        "If the answer isn't in the context, say you don't have enough information.\n\n"
        f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )

    result = generator(prompt, max_new_tokens=int(max_new_tokens), temperature=0.2)
    return result[0]["generated_text"].strip()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reset_app():
    """Wipe the in-memory retrieval state and return cleared UI values.

    The returned tuple maps onto (status, chunk_count, answer, question,
    file_in) in the reset_btn.click wiring.
    """
    global index, corpus_chunks
    index = None
    corpus_chunks = []
    return "Reset: memory cleared. Ready.", 0, "", "", None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---- Gradio UI wiring (runs at import time) ----
with gr.Blocks(title="Group 5 Study Helper (RAG)") as demo:
    gr.Markdown("# Group 5 Study Helper (RAG)\nUpload PDFs → Build Index → Ask questions.")

    # Upload + indexing controls.
    with gr.Row():
        file_in = gr.File(file_count="multiple", file_types=[".pdf", ".docx", ".txt"], label="Upload PDF/DOCX/TXT files")
    with gr.Row():
        build_btn = gr.Button("Build Index", variant="primary")
        status = gr.Markdown()
        chunk_count = gr.Number(label="Chunk count", interactive=False)

    # Question + generation controls.
    with gr.Row():
        question = gr.Textbox(label="Your question")
    with gr.Row():
        topk = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")
        max_tokens = gr.Slider(64, 512, value=256, step=16, label="Max new tokens")
    with gr.Row():
        ask_btn = gr.Button("Ask", variant="primary")
    with gr.Row():
        answer = gr.Markdown(label="Answer")

    with gr.Row():
        reset_btn = gr.Button("Reset (clear memory & UI)")

    # ClearButton only empties the listed widgets; reset_btn below also
    # wipes the server-side index state via reset_app().
    gr.ClearButton([file_in, question, answer, status])

    def _build(files):
        # Thin adapter: coerce a falsy chunk count to 0 for the Number widget.
        msg, n = build_index(files)
        return msg, n or 0

    build_btn.click(_build, inputs=[file_in], outputs=[status, chunk_count])
    # Two-step click: show a placeholder immediately, then run
    # answer_question and replace it with the real answer.
    evt = ask_btn.click(lambda: "⏳ Processing … this might take a minute (we're on the free tier)", inputs=None, outputs=answer)
    evt.then(answer_question, inputs=[question, topk, max_tokens], outputs=answer)

    # Output order must match reset_app()'s return tuple.
    reset_btn.click(
        reset_app,
        inputs=None,
        outputs=[status, chunk_count, answer, question, file_in],
    )

# NOTE(review): launch() executes on import; consider an
# `if __name__ == "__main__":` guard — confirm the deployment target
# (some hosts execute the module as a script and expect this).
demo.launch()
|
|
|