# Hugging Face Spaces page residue (not code), preserved as comments:
# author: adityagofi — commit: "Update title and Top k related query" (82e5ffa, verified)
import PyPDF2
import numpy as np
import faiss
from transformers import BertTokenizer, BertForMaskedLM, BertForQuestionAnswering
import torch
# BERT tokenizer + QA-finetuned BERT; only its `.bert` encoder is used
# downstream (see encode_text) to mean-pool chunk/query embeddings.
tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
qa_model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
from transformers import AutoTokenizer, AutoModelForCausalLM
# Small causal LM that generates free-form answers from a retrieved chunk.
gen_model_id = "distilgpt2"
gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_id)
# The generator moves to GPU when available; the BERT encoder stays on CPU
# (encode_text calls .numpy() directly, which requires CPU tensors).
gen_model = AutoModelForCausalLM.from_pretrained(gen_model_id).to("cuda" if torch.cuda.is_available() else "cpu")
def read_pdf(file):
    """Extract the full text of a PDF.

    Parameters
    ----------
    file : str | file-like
        Path or binary file object accepted by ``PyPDF2.PdfReader``.

    Returns
    -------
    str
        Concatenated text of every page. Pages with no extractable text
        (e.g. scanned images) contribute an empty string instead of raising.
    """
    reader = PyPDF2.PdfReader(file)
    parts = []
    for page in reader.pages:
        # extract_text() may return None for image-only pages; the original
        # `text += page.extract_text()` would then raise TypeError.
        parts.append(page.extract_text() or "")
    # join instead of repeated += (avoids quadratic string building)
    return "".join(parts)
def split_text(text, chunk_size=500):
    """Split *text* into consecutive chunks of at most *chunk_size* characters.

    The final chunk may be shorter; an empty string yields an empty list.
    """
    chunks = []
    start = 0
    while start < len(text):
        chunks.append(text[start:start + chunk_size])
        start += chunk_size
    return chunks
def encode_text(text):
    """Embed *text* as a single float32 row vector.

    Runs the BERT encoder of the module-level QA model and mean-pools the
    last hidden states over the token dimension. Returns shape (1, hidden_dim).
    """
    encoded = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        hidden = qa_model.bert(**encoded).last_hidden_state
    # Mean over the sequence axis -> one fixed-size vector per input.
    pooled = hidden.mean(dim=1)
    return pooled.numpy().astype(np.float32)
def create_faiss_index(chunks):
    """Build a flat L2 FAISS index over the embeddings of *chunks*.

    Returns
    -------
    (faiss.IndexFlatL2, np.ndarray)
        The populated index and the (len(chunks), dim) float32 embedding matrix.
    """
    vectors = [encode_text(chunk) for chunk in chunks]
    embeddings = np.vstack(vectors)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatL2(dim)
    index.add(embeddings)
    return index, embeddings
def search_faq(query, index, k=3):
    """Return the ids of the *k* chunks nearest to *query* in *index*.

    Parameters
    ----------
    query : str
        Question text, embedded via encode_text.
    index : faiss.Index
        Index built by create_faiss_index.
    k : int
        Number of neighbours requested.

    Returns
    -------
    np.ndarray
        (1, k) int array of neighbour ids.

    Notes
    -----
    *k* is clamped to ``index.ntotal``: FAISS pads missing neighbours with
    -1, and a caller that uses those ids as Python list indices would
    silently read the *last* chunk instead of failing.
    """
    query_emb = encode_text(query)
    if index.ntotal > 0:
        k = min(k, index.ntotal)
    D, I = index.search(query_emb, k)  # distances D are intentionally unused
    return I
def generate_distilgpt2_answer(context, question):
    """Generate an answer to *question* grounded in *context* using distilgpt2.

    Builds a "Context / Question / Answer:" prompt, greedily decodes up to 64
    new tokens, and returns only the text after the final "Answer:" marker.
    """
    prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
    encoded = gen_tokenizer(prompt, return_tensors="pt").to(gen_model.device)
    generation_kwargs = dict(
        max_new_tokens=64,
        pad_token_id=gen_tokenizer.eos_token_id,
        eos_token_id=gen_tokenizer.eos_token_id,
        repetition_penalty=1.3,
    )
    with torch.no_grad():
        output_ids = gen_model.generate(**encoded, **generation_kwargs)
    decoded = gen_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # The decoded text echoes the prompt; keep only the model's reply.
    return decoded.split("Answer:")[-1].strip()
def process_pdf_and_answer(pdf_file, query, top_k=1):
    """End-to-end pipeline: read the PDF, retrieve top_k chunks, answer.

    Parameters
    ----------
    pdf_file : str | file-like
        PDF as accepted by read_pdf (Gradio passes the uploaded file).
    query : str
        The user's question.
    top_k : int
        Number of chunks to retrieve and answer from.

    Returns
    -------
    str
        One generated answer per retrieved chunk, separated by "---", or an
        explanatory message when the PDF has no extractable text.
    """
    text = read_pdf(pdf_file)
    chunks = split_text(text)
    if not chunks:
        # e.g. a scanned/image-only PDF: np.vstack([]) would raise otherwise.
        return "No text could be extracted from the PDF."
    faiss_index, _ = create_faiss_index(chunks)
    indices = search_faq(query, faiss_index, k=top_k)
    answers = []
    for idx in indices[0]:
        # FAISS pads with -1 when fewer than k neighbours exist; skip those
        # rather than letting chunks[-1] silently select the last chunk.
        if idx < 0:
            continue
        context = chunks[idx]
        answers.append(generate_distilgpt2_answer(context, query))
    return "\n\n---\n\n".join(answers)
import gradio as gr

# Gradio front-end for process_pdf_and_answer. A "Top K" slider is wired to
# the function's existing top_k parameter, which the previous UI never
# exposed (it was always left at its default of 1).
interface = gr.Interface(
    fn=process_pdf_and_answer,
    inputs=[
        gr.File(label="Upload PDF"),
        gr.Textbox(label="Your Question"),
        gr.Slider(minimum=1, maximum=5, step=1, value=1, label="Top K chunks"),
    ],
    outputs=gr.Textbox(label="Generated Answer(s)"),
    title="📄 PDF Question Answering",
    description="Upload a PDF and ask a question about its content. The model will try to answer based on the most relevant chunks.",
)

if __name__ == "__main__":
    # Guarded so importing this module (e.g. for tests) doesn't start a server.
    interface.launch()