import os

# Placed ahead of the transformers / sentence_transformers imports further
# down so the variable is already in the environment when those packages load.
# NOTE(review): presumably tells transformers to skip importing torchvision —
# confirm against the installed transformers version.
os.environ.update(TRANSFORMERS_NO_TORCHVISION="1")
|
|
| import gradio as gr |
| import torch |
| import faiss |
| import numpy as np |
| import re |
| from pypdf import PdfReader |
| from sentence_transformers import SentenceTransformer |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM |
|
|
|
|
# --- Models ------------------------------------------------------------------
# Seq2seq generator that writes the final answer from the retrieved context.
model_name = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Sentence-embedding model, used for both the PDF chunks and the user question.
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# --- Retrieval state -----------------------------------------------------------
# Assigned by process_pdf() and read by ask_question(). `index` stays None until
# a PDF has been processed; chunks[i] is the text behind vector i in `index`.
index = None
chunks = []
|
|
|
|
def split_text(text, chunk_size=600, overlap=100):
    """Split *text* into overlapping fixed-size character windows.

    Consecutive windows start ``chunk_size - overlap`` characters apart, so each
    window repeats the last ``overlap`` characters of the previous one. Windows
    whose stripped text is 50 characters or fewer, or is made up only of digits
    (e.g. a bare page number), are dropped as noise.

    Args:
        text: The text to split.
        chunk_size: Window length in characters; must be positive.
        overlap: Characters shared by consecutive windows; must be smaller
            than ``chunk_size`` so the window actually advances.

    Returns:
        The kept chunks in document order (an empty list for empty text).

    Raises:
        ValueError: If ``chunk_size`` is not positive or ``overlap`` is not
            smaller than ``chunk_size`` (either would never advance).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if overlap >= chunk_size:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    step = chunk_size - overlap
    pieces = []
    for start in range(0, len(text), step):
        chunk = text[start:start + chunk_size]
        stripped = chunk.strip()
        # Skip near-empty windows and windows that are only a number.
        if len(stripped) > 50 and not stripped.isdigit():
            pieces.append(chunk)
        # Once a window reaches the end of the text, every later window would
        # be a strict suffix of this one and only add a duplicate chunk.
        if start + chunk_size >= len(text):
            break
    return pieces
|
|
|
|
# A line holding only a "References" heading, optionally numbered
# ("7 References", "3.2. REFERENCES", "VII. References").
_REFERENCES_HEADING = re.compile(
    r'^[ \t]*(?:(?:\d+(?:\.\d+)*|[IVXLC]+)\.?[ \t]*)?'
    r'(?:References|REFERENCES)[ \t]*$',
    re.MULTILINE,
)
# arXiv identifiers such as "arXiv:2101.00001" or "arXiv:2101.00001v2".
_ARXIV_ID = re.compile(r'arXiv:\d+\.\d+(?:v\d+)?')
# Numeric citations: "[1]", "[1, 2]", "[3-5]", "[3\u20135]".
_NUMERIC_CITATION = re.compile(r'\[\d+(?:\s*[,\-\u2013]\s*\d+)*\]')
# Lines containing nothing but a number (typically page numbers).
_PAGE_NUMBER_LINE = re.compile(r'^\d+[ \t]*$', re.MULTILINE)


def clean_text(text):
    """Strip the reference section and citation noise from extracted paper text.

    Steps, in order:
      1. Drop everything from the last line that is only a "References" /
         "REFERENCES" heading onward. Text is left whole when no such heading
         line exists.
      2. Remove arXiv identifiers.
      3. Remove numeric citation markers like ``[1]``, ``[1, 2]``, ``[3-5]``.
      4. Blank out lines that contain only a number.

    Args:
        text: Raw text extracted from a PDF.

    Returns:
        The cleaned text.
    """
    # Only a heading-style line marks the reference section; an inline mention
    # ("References to prior work ...") must not truncate the paper. The last
    # heading is used so an earlier outline/TOC entry cannot cut the body.
    headings = list(_REFERENCES_HEADING.finditer(text))
    if headings:
        text = text[:headings[-1].start()]

    text = _ARXIV_ID.sub('', text)
    text = _NUMERIC_CITATION.sub('', text)
    # [ \t]* rather than \s* so the match never eats the following newline.
    text = _PAGE_NUMBER_LINE.sub('', text)
    return text
|
|
|
|
def _extract_pdf_text(file):
    """Return the extracted text of every page in *file*, one page per newline.

    Pages for which ``extract_text()`` returns nothing are skipped.
    """
    reader = PdfReader(file)
    page_texts = (page.extract_text() for page in reader.pages)
    # Join with a newline: plain concatenation glued the last word of one page
    # to the first word of the next and merged page-top headings into the
    # previous line.
    return "\n".join(text for text in page_texts if text)


def process_pdf(file):
    """Extract, clean, chunk, embed and index an uploaded PDF.

    On success, replaces the module-level ``chunks`` list and FAISS ``index``
    that ask_question() reads. The two globals are assigned together, only
    after the new index has been fully built. If any step fails, the
    previously processed PDF stays usable and its vectors are never paired
    with another PDF's chunk texts.

    Args:
        file: The value of the Gradio ``gr.File`` component, passed straight to
            ``PdfReader``; None when nothing was uploaded.

    Returns:
        A status message for the UI.
    """
    global index, chunks

    if file is None:
        return "Please upload a PDF."

    full_text = _extract_pdf_text(file)
    if not full_text.strip():
        return "PDF has no extractable text."

    new_chunks = split_text(clean_text(full_text))
    if not new_chunks:
        # Nothing to embed, and building an index from zero vectors would leave
        # ask_question() nothing to retrieve, so report it instead.
        return "PDF text was too short to index after cleaning."

    embeddings = np.asarray(embedder.encode(new_chunks), dtype="float32")
    new_index = faiss.IndexFlatL2(embeddings.shape[1])
    new_index.add(embeddings)

    # Publish both at once so chunks[i] always matches vector i in index.
    chunks, index = new_chunks, new_index
    return "PDF processed successfully! Ask your question."
|
|
|
|
def _retrieve_chunks(question, top_k):
    """Return up to *top_k* indexed chunks nearest to *question*, best first."""
    k = min(top_k, index.ntotal)
    if k < 1:
        return []
    query = np.asarray(embedder.encode([question]), dtype="float32")
    _, ids = index.search(query, k)
    # FAISS pads missing results with -1, which would silently wrap around
    # to chunks[-1]; drop any such placeholder.
    return [chunks[i] for i in ids[0] if i >= 0]


def _build_prompt(context, question):
    """Build the instruction prompt sent to the seq2seq model."""
    # With truncation=True the tokenizer drops tokens from the END of the
    # prompt (default truncation_side="right") once it exceeds the model's max
    # length. The question therefore comes before the long retrieved context
    # so it is never the part that gets cut.
    return f"""
You are a research assistant.

Answer the question using only the context below.
Answer in 3-5 complete sentences.
Do not include citations or reference numbers.
If unclear, say the document does not clearly specify.

Question:
{question}

Context:
{context}

Answer:
"""


def ask_question(question, top_k=3):
    """Answer *question* from the most recently processed PDF.

    Embeds the question, retrieves the nearest chunks from the FAISS index
    built by process_pdf(), and has the seq2seq model answer from them.

    Args:
        question: The user's question.
        top_k: Maximum number of chunks to use as context.

    Returns:
        The generated answer. Returns a short message instead when no PDF has
        been processed or the question is blank.
    """
    if index is None or index.ntotal == 0:
        return "Please process a PDF first."
    if not question or not question.strip():
        return "Please enter a question."

    context = "\n\n".join(_retrieve_chunks(question, top_k))
    prompt = _build_prompt(context, question)

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    # Greedy decoding: temperature only takes effect with do_sample=True, so
    # it is not passed (passing it alone is ignored and triggers a warning).
    outputs = model.generate(**inputs, max_new_tokens=250)

    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer.strip()
|
|
|
|
# Gradio UI: upload and index a PDF, then ask questions about it.
with gr.Blocks() as demo:
    gr.Markdown("# 📄 Clean RAG Paper QA")

    file_input = gr.File(label="Upload Research PDF", file_types=[".pdf"])
    process_btn = gr.Button("Process PDF")
    status_output = gr.Textbox(label="Status")

    question_input = gr.Textbox(label="Ask a Question")
    ask_btn = gr.Button("Get Answer")
    answer_output = gr.Textbox(label="Answer")

    process_btn.click(process_pdf, inputs=file_input, outputs=status_output)
    ask_btn.click(ask_question, inputs=question_input, outputs=answer_output)
    # Pressing Enter in the question box behaves like clicking "Get Answer".
    question_input.submit(ask_question, inputs=question_input, outputs=answer_output)


demo.launch()