import os
import io
import tempfile

import gradio as gr
import faiss
import numpy as np
from pypdf import PdfReader
from docx import Document
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# ---- Models (CPU-friendly) ----
# We're running on Hugging Face's free tier, which provides only 2 virtual
# CPU cores and 16 GB of RAM, so both models need to stay lightweight and CPU-only.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" # small & fast on CPU
GEN_MODEL_NAME = "MBZUAI/LaMini-Flan-T5-248M" # text2text model that runs on CPU
embedder = SentenceTransformer(EMBED_MODEL_NAME)
generator = pipeline("text2text-generation", model=GEN_MODEL_NAME)
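# Note: all-MiniLM-L6-v2 emits 384-dimensional sentence embeddings. The FAISS
# index dimension below is taken from the embedding array's shape rather than
# hard-coded, so swapping in another embedding model needs no other changes.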
# ---- PDF to text ----
def read_pdf_from_path_or_bytes(file_obj_or_path):
    """Extract text from a PDF given a filesystem path, a path-like upload, or a bytes-backed file object."""
    # Gradio uploads may expose a .path attribute; plain strings may be paths themselves.
    path = getattr(file_obj_or_path, "path", None)
    if isinstance(file_obj_or_path, str) and os.path.exists(file_obj_or_path):
        path = file_obj_or_path
    if path and os.path.exists(path):
        reader = PdfReader(path)
        return "\n".join((p.extract_text() or "") for p in reader.pages)
    # Fall back to raw bytes for in-memory uploads.
    data = None
    if hasattr(file_obj_or_path, "read"):
        data = file_obj_or_path.read()
    elif hasattr(file_obj_or_path, "bytes"):
        data = file_obj_or_path.bytes
    if data:
        reader = PdfReader(io.BytesIO(data))
        return "\n".join((p.extract_text() or "") for p in reader.pages)
    return ""
def read_docx_text(path):
    doc = Document(path)
    return "\n".join(p.text for p in doc.paragraphs)
def load_files_to_texts(files):
    """
    Accepts mixed uploads (.pdf, .docx, .txt).
    Returns a list[str] of raw texts (one per file).
    """
    texts = []
    for f in files or []:
        path = getattr(f, "path", None) or getattr(f, "name", None)
        name = (path or str(f)).lower()
        if name.endswith(".pdf"):
            texts.append(read_pdf_from_path_or_bytes(f if path is None else path))
        elif name.endswith(".docx"):
            if path:
                texts.append(read_docx_text(path))
            else:
                # python-docx needs a real path, so spool the bytes to a temp file.
                data = f.read() if hasattr(f, "read") else getattr(f, "bytes", b"")
                with tempfile.NamedTemporaryFile(delete=False, suffix=".docx") as tf:
                    tf.write(data)
                    tmp_path = tf.name
                texts.append(read_docx_text(tmp_path))
                os.unlink(tmp_path)
        elif name.endswith(".txt"):
            if path and os.path.exists(path):
                with open(path, "r", encoding="utf-8", errors="ignore") as fh:
                    texts.append(fh.read())
            else:
                data = f.read().decode("utf-8", errors="ignore") if hasattr(f, "read") else ""
                texts.append(data)
        else:
            # Skip unsupported file types silently.
            continue
    return texts
# ---- Chunking ----
def chunk_text(text, chunk_size=600, overlap=120):
    """Split text into word-based chunks of ~chunk_size words, with `overlap` words shared between neighbors."""
    words = text.split()
    chunks = []
    step = max(1, chunk_size - overlap)  # guard against a zero/negative step if overlap >= chunk_size
    i = 0
    while i < len(words):
        chunks.append(" ".join(words[i:i + chunk_size]))
        i += step
    return chunks
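# Worked example (illustrative): chunk_text("a b c d e", chunk_size=3, overlap=1)
# steps 3 - 1 = 2 words at a time and yields ["a b c", "c d e", "e"].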
# ---- Build FAISS index from uploaded PDFs ----
index = None
corpus_chunks = []
def build_index(files, progress=gr.Progress()):
    global index, corpus_chunks
    try:
        texts = load_files_to_texts(files)
        corpus_chunks = []
        for t in texts:
            if t and t.strip():
                corpus_chunks += chunk_text(t)
        if not corpus_chunks:
            return "No text extracted from files.", 0
        progress(0.3, desc="Embedding chunks…")
        embeddings = embedder.encode(corpus_chunks, convert_to_numpy=True, show_progress_bar=False)
        d = embeddings.shape[1]
        # Normalize for cosine sim with inner product
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True) + 1e-10
        embeddings = embeddings / norms
        progress(0.6, desc="Creating FAISS index…")
        index = faiss.IndexFlatIP(d)
        index.add(embeddings.astype(np.float32))
        return f"Indexed {len(corpus_chunks)} chunks.", len(corpus_chunks)
    except Exception as e:
        return f"Build failed: {e}", 0
# ---- RAG query -> retrieve -> generate ----
def answer_question(question, top_k=5, max_new_tokens=256, progress=gr.Progress()):
    if index is None or not corpus_chunks:
        return "Index not built yet. Upload PDFs and click **Build Index** first."
    # Embed the query (normalized, to match the index's inner-product/cosine setup)
    q = embedder.encode([question], convert_to_numpy=True)
    q = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-10)
    D, I = index.search(q.astype(np.float32), int(top_k))
    # FAISS pads missing results with -1, which would silently index from the end
    # of the list in Python, so keep only valid non-negative ids.
    retrieved = [corpus_chunks[i] for i in I[0] if 0 <= i < len(corpus_chunks)]
    context = "\n\n".join(retrieved)
    prompt = (
        "You are a helpful study assistant. Using ONLY the context, answer the question.\n"
        "If the answer isn't in the context, say you don't have enough information.\n\n"
        f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"
    )
    # temperature is ignored unless sampling is enabled, and truncation keeps long
    # contexts within the model's input limit instead of raising an error.
    out = generator(prompt, max_new_tokens=int(max_new_tokens), do_sample=True, temperature=0.2, truncation=True)
    return out[0]["generated_text"].strip()
# Everything is stored in RAM only and resets when the Space
# sleeps or restarts. In case a new user arrives before that
# happens, we provide a "reset" ability so they aren't stuck
# with the previous user's documents.
def reset_app():
    """Wipe in-memory state and return cleared UI values."""
    global index, corpus_chunks
    index = None
    corpus_chunks = []
    # status, chunk_count, answer, question, files
    return "Reset: memory cleared. Ready.", 0, "", "", None
# ---- Gradio v5 UI (Blocks) ----
with gr.Blocks(title="Group 5 Study Helper (RAG)") as demo:
    gr.Markdown("# Group 5 Study Helper (RAG)\nUpload PDFs → Build Index → Ask questions.")
    with gr.Row():
        file_in = gr.File(file_count="multiple", file_types=[".pdf", ".docx", ".txt"], label="Upload PDF/DOCX/TXT files")
    with gr.Row():
        build_btn = gr.Button("Build Index", variant="primary")
        status = gr.Markdown()
        chunk_count = gr.Number(label="Chunk count", interactive=False)
    with gr.Row():
        question = gr.Textbox(label="Your question")
    with gr.Row():
        topk = gr.Slider(1, 10, value=5, step=1, label="Top-K passages")
        max_tokens = gr.Slider(64, 512, value=256, step=16, label="Max new tokens")
    with gr.Row():
        ask_btn = gr.Button("Ask", variant="primary")
    with gr.Row():
        answer = gr.Markdown(label="Answer")
    with gr.Row():
        reset_btn = gr.Button("Reset (clear memory & UI)")
        # ClearButton clears UI components
        gr.ClearButton([file_in, question, answer, status])

    def _build(files):
        msg, n = build_index(files)
        return msg, n or 0

    build_btn.click(_build, inputs=[file_in], outputs=[status, chunk_count])
    # Show a quick placeholder, then swap in the real answer once generation finishes.
    evt = ask_btn.click(lambda: "⏳ Processing … this might take a minute (we're on the free tier)", inputs=None, outputs=answer)
    evt.then(answer_question, inputs=[question, topk, max_tokens], outputs=answer)
    reset_btn.click(
        reset_app,
        inputs=None,
        outputs=[status, chunk_count, answer, question, file_in],
    )

demo.launch()
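# On a Hugging Face Space using the Gradio SDK this file is executed directly,
# so demo.launch() starts the server; locally, run `python app.py` and open
# the printed localhost URL.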