# pdfSumAndQnA / app.py
# Aroy1997's picture
# Update app.py
# 3ea0da6 verified
# KEEPING YOUR ORIGINAL IMPORTS
import gradio as gr
import PyPDF2
import io
from transformers import pipeline, AutoTokenizer
import torch
import re
from typing import List, Tuple
import warnings
warnings.filterwarnings("ignore")
# QUESTION-ANSWERING ADDITION
# Question-answering pipeline built once at import time from the
# "deepset/roberta-base-squad2" checkpoint; answer_question_interface calls it
# with the user's question and the stored PDF text as context.
qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2")
# === SUMMARIZER CLASS ===
class PDFSummarizer:
    """Extract text from a PDF and summarize it with a summarization pipeline.

    The workflow implemented by :meth:`process_pdf` is: extract page text,
    clean it, split it into word-bounded chunks, summarize each chunk, and
    (when there are more than two chunks) summarize the combined summaries.
    """

    def __init__(self):
        """Load the summarization pipeline and its tokenizer.

        ``sshleifer/distilbart-cnn-12-6`` is tried first (on GPU with float16
        weights when CUDA is available, otherwise CPU with float32). If that
        raises, ``facebook/bart-large-cnn`` is loaded on CPU instead.
        """
        self.model_name = "sshleifer/distilbart-cnn-12-6"
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        try:
            self.summarizer = pipeline(
                "summarization",
                model=self.model_name,
                device=0 if self.device == "cuda" else -1,
                framework="pt",
                model_kwargs={"torch_dtype": torch.float16 if self.device == "cuda" else torch.float32}
            )
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            print("Model loaded successfully")
        except Exception as e:
            # Best-effort fallback: any failure loading the primary model
            # switches to the larger BART checkpoint on CPU.
            print(f"Error loading model: {e}")
            self.model_name = "facebook/bart-large-cnn"
            self.summarizer = pipeline("summarization", model=self.model_name, device=-1)
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            print("Fallback model loaded")

    def extract_text_from_pdf(self, pdf_file) -> str:
        """Return the text of all non-empty pages, each preceded by a page marker.

        Args:
            pdf_file: Raw bytes of the PDF document.

        Returns:
            The concatenated page texts, each prefixed with
            ``"\\n--- Page N ---\\n"``, stripped of surrounding whitespace.

        Raises:
            Exception: If reading or extracting fails; the original error is
                chained as ``__cause__``.
        """
        try:
            pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_file))
            parts = []
            for page_num, page in enumerate(pdf_reader.pages, start=1):
                # Guard against a page yielding None instead of a string.
                page_text = page.extract_text() or ""
                if page_text.strip():
                    parts.append(f"\n--- Page {page_num} ---\n{page_text}")
            return "".join(parts).strip()
        except Exception as e:
            raise Exception(f"Error extracting text from PDF: {str(e)}") from e

    def clean_text(self, text: str) -> str:
        """Normalize whitespace, drop unusual symbols and remove page markers.

        Args:
            text: Raw text, possibly containing ``--- Page N ---`` markers.

        Returns:
            Single-spaced text containing only word characters, whitespace and
            the punctuation ``.,!?;:()-"``.
        """
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'[^\w\s.,!?;:()\-"]', ' ', text)
        text = re.sub(r'--- Page \d+ ---', '', text)
        # The two substitutions above leave runs of spaces behind; collapse
        # them so character counts and sentence splitting see clean text.
        text = re.sub(r'\s+', ' ', text)
        return text.strip()

    def chunk_text(self, text: str, max_chunk_length: int = 512, max_chunks: int = 5) -> List[str]:
        """Split text into chunks of whole sentences bounded by a word count.

        Sentences are delimited by ``". "``. A sentence that alone exceeds
        ``max_chunk_length`` words becomes its own (oversized) chunk.

        Args:
            text: Text to split.
            max_chunk_length: Maximum number of whitespace-separated words per
                chunk.
            max_chunks: Maximum number of chunks returned (later chunks are
                dropped); ``None`` returns every chunk.

        Returns:
            The list of chunks; empty when ``text`` contains no sentences.
        """
        chunks: List[str] = []
        current: List[str] = []
        current_words = 0
        pieces = text.split('. ')
        last_index = len(pieces) - 1
        for index, piece in enumerate(pieces):
            sentence = piece.strip()
            if not sentence:
                continue
            # Splitting removed the period from every piece except the last,
            # which keeps whatever terminator the input already had.
            if not (index == last_index and sentence.endswith(('.', '!', '?'))):
                sentence += '.'
            n_words = len(sentence.split())
            if current and current_words + n_words > max_chunk_length:
                chunks.append(' '.join(current))
                current, current_words = [], 0
            current.append(sentence)
            current_words += n_words
        if current:
            chunks.append(' '.join(current))
        return chunks if max_chunks is None else chunks[:max_chunks]

    def summarize_chunk(self, chunk: str, max_length: int = 100, min_length: int = 30) -> str:
        """Summarize one chunk of text with the loaded pipeline.

        Args:
            chunk: Text to summarize.
            max_length: Upper length limit passed to the pipeline.
            min_length: Lower length limit passed to the pipeline.

        Returns:
            The summary text, or ``"Error summarizing chunk: ..."`` when the
            pipeline raises.
        """
        try:
            summary = self.summarizer(
                chunk,
                # Length limits must be integers; callers may compute them
                # with float arithmetic.
                max_length=int(max_length),
                min_length=int(min_length),
                do_sample=False,
                truncation=True,
                early_stopping=True,
                num_beams=2
            )
            return summary[0]['summary_text']
        except Exception as e:
            return f"Error summarizing chunk: {str(e)}"

    def process_pdf(self, pdf_file, summary_type: str) -> Tuple[str, str, str]:
        """Summarize a PDF and report statistics about the result.

        Args:
            pdf_file: Raw bytes of the PDF document.
            summary_type: ``"Brief (Quick)"``, ``"Detailed"``, or anything else
                for the longest setting.

        Returns:
            ``(summary, stats_markdown, status_message)``. On failure the first
            element carries an error message and the other two are empty.
        """
        try:
            raw_text = self.extract_text_from_pdf(pdf_file)
            if not raw_text.strip():
                return "❌ Error: No text could be extracted from the PDF.", "", ""
            cleaned_text = self.clean_text(raw_text)
            word_count = len(cleaned_text.split())
            char_count = len(cleaned_text)
            if word_count < 50:
                return "❌ Error: PDF contains too little text to summarize.", "", ""
            chunks = self.chunk_text(cleaned_text)
            if summary_type == "Brief (Quick)":
                max_len, min_len = 60, 20
            elif summary_type == "Detailed":
                max_len, min_len = 100, 40
            else:
                max_len, min_len = 150, 60
            chunk_summaries = []
            for i, chunk in enumerate(chunks):
                print(f"Processing chunk {i+1}/{len(chunks)}")
                chunk_summaries.append(self.summarize_chunk(chunk, max_len, min_len))
            combined_summary = " ".join(chunk_summaries)
            if len(chunks) <= 2:
                final_summary = combined_summary
            else:
                # Second pass condenses the per-chunk summaries; the limit is
                # kept an integer because max_len * 1.5 is a float.
                final_summary = self.summarize_chunk(
                    combined_summary,
                    max_length=min(200, int(max_len * 1.5)),
                    min_length=min_len
                )
            summary_words = len(final_summary.split())
            # An empty summary must not turn the whole result into an error.
            ratio = f"{word_count / summary_words:.1f}:1" if summary_words else "N/A"
            summary_stats = "\n".join([
                "",
                "πŸ“Š **Document Statistics:**",
                f"- Original word count: {word_count:,}",
                f"- Original character count: {char_count:,}",
                # The value is the number of text chunks, not PDF pages.
                f"- Chunks processed: {len(chunks)}",
                f"- Summary word count: {summary_words:,}",
                f"- Compression ratio: {ratio}",
                "",
            ])
            return final_summary, summary_stats, "βœ… Summary generated successfully!"
        except Exception as e:
            return f"❌ Error processing PDF: {str(e)}", "", ""
# Single shared summarizer instance; constructing it loads the summarization
# model when this module is imported.
pdf_summarizer = PDFSummarizer()
# Cleaned text of the most recently uploaded PDF. Written by
# summarize_pdf_interface and read by answer_question_interface.
global_pdf_text = "" # used for QA
def summarize_pdf_interface(pdf_file, summary_type):
    """Gradio callback: summarize the uploaded PDF and store its text for QA.

    Args:
        pdf_file: Filesystem path of the uploaded PDF, or ``None`` when no
            file is selected.
        summary_type: Summary length option forwarded to
            ``pdf_summarizer.process_pdf``.

    Returns:
        ``(summary, stats_markdown, status_message)``. On failure the first
        element carries an error message and the other two are empty strings.
    """
    global global_pdf_text
    # Forget the previous document first, so that a cleared upload or a failed
    # read cannot leave question answering pointed at stale text.
    global_pdf_text = ""
    if pdf_file is None:
        return "❌ Please upload a PDF file.", "", ""
    try:
        with open(pdf_file, 'rb') as f:
            pdf_content = f.read()
        global_pdf_text = pdf_summarizer.clean_text(
            pdf_summarizer.extract_text_from_pdf(pdf_content)
        )
        summary, stats, status = pdf_summarizer.process_pdf(pdf_content, summary_type)
        return summary, stats, status
    except Exception as e:
        return f"❌ Error: {str(e)}", "", ""
# === NEW: QA FUNCTION ===
def answer_question_interface(question):
    """Gradio callback: answer a question using the stored PDF text as context.

    Args:
        question: The user's question.

    Returns:
        The answer span returned by ``qa_pipeline``, or a message starting
        with "❌" when no PDF text is stored, the question is blank, or the
        pipeline raises.
    """
    if not global_pdf_text:
        return "❌ Please upload and summarize a PDF first."
    # Reject blank input here instead of surfacing the pipeline's raw error.
    if not question or not question.strip():
        return "❌ Please enter a question."
    try:
        answer = qa_pipeline(question=question.strip(), context=global_pdf_text)
        return answer["answer"]
    except Exception as e:
        return f"❌ Error: {str(e)}"
# === GRADIO INTERFACE ===
def create_interface():
    """Assemble and return the Gradio Blocks app.

    The layout has an upload/controls column, a results column, and a
    question/answer row. Summarization runs both when the button is clicked
    and when the uploaded file changes; a question is answered when it is
    submitted from its textbox.
    """
    length_choices = ["Brief (Quick)", "Detailed", "Comprehensive"]

    with gr.Blocks(title="πŸ“„ AI PDF Summarizer & QA", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# πŸ“„ PDF Summarizer + πŸ’¬ Question Answering")

        with gr.Row():
            # Left column: inputs and status line.
            with gr.Column(scale=1):
                pdf_input = gr.File(label="πŸ“ Upload PDF", file_types=[".pdf"], type="filepath")
                summary_type = gr.Radio(
                    choices=length_choices,
                    value="Detailed",
                    label="πŸ“ Summary Length",
                )
                summarize_btn = gr.Button("πŸš€ Generate Summary", variant="primary")
                status_output = gr.Textbox(label="πŸ“‹ Status", interactive=False, max_lines=2)
            # Right column: summary text and statistics.
            with gr.Column(scale=2):
                summary_output = gr.Textbox(label="πŸ“ Summary", lines=15, interactive=False)
                stats_output = gr.Markdown(label="πŸ“Š Document Statistics")

        # Same handler for the button click and for a change of uploaded file,
        # registered in that order.
        for register in (summarize_btn.click, pdf_input.change):
            register(
                fn=summarize_pdf_interface,
                inputs=[pdf_input, summary_type],
                outputs=[summary_output, stats_output, status_output],
            )

        gr.Markdown("## πŸ’¬ Ask a Question About the PDF")
        with gr.Row():
            question_input = gr.Textbox(label="❓ Your Question", placeholder="e.g. What is the main finding?")
            answer_output = gr.Textbox(label="πŸ’‘ Answer", interactive=False)
        question_input.submit(fn=answer_question_interface, inputs=question_input, outputs=answer_output)

    return demo
# === MAIN ===
# Build the UI and start the Gradio server only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()