# Source: Hugging Face Space "PDF" by oluinioluwa814 — app.py (commit b0cf3bb, verified)
# app.py
import os
import math
from typing import List
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import PyPDF2
import gradio as gr
from tqdm.auto import tqdm
# ---------- Configuration ----------
# Models
# LED handles very long inputs (window up to ~16k tokens).
LONG_MODEL_ID = "allenai/led-base-16384" # good for very long docs (up to ~16k tokens)
# Distilled BART: much faster, suited to shorter documents.
SHORT_MODEL_ID = "sshleifer/distilbart-cnn-12-6" # faster, efficient summarizer for shorter texts
# Threshold (in tokens) when we'll switch to long-document model workflow
LONG_MODEL_SWITCH_TOKENS = 2000 # you can tweak this
# Per-chunk generation defaults
DEFAULT_SUMMARY_MIN_LENGTH = 60
DEFAULT_SUMMARY_MAX_LENGTH = 256
# Whether to prefer the long model whenever possible
# NOTE(review): this module-level flag appears unused — the Gradio checkbox
# value is what actually reaches summarize_text_pipeline(); confirm before
# relying on it.
prefer_long_model = True
# -----------------------------------
def extract_text_from_pdf(fileobj) -> str:
    """
    Pull the text content out of a PDF given as a binary file-like object.

    Pages whose extraction fails or yields nothing are skipped; the
    remaining page texts are joined with blank lines. Relies on PyPDF2,
    which copes with many (not all) PDFs.
    """
    pages_text = []
    for page in PyPDF2.PdfReader(fileobj).pages:
        try:
            extracted = page.extract_text()
        except Exception:
            # Extraction of a single page can fail on odd PDFs; skip it.
            extracted = ""
        if extracted:
            pages_text.append(extracted)
    return "\n\n".join(pages_text)
def choose_device():
    """Return the transformers pipeline device index: 0 for the first CUDA GPU, -1 for CPU."""
    if torch.cuda.is_available():
        return 0
    return -1
def load_pipeline_for_model(model_id: str, device: int):
    """
    Build a summarization pipeline for *model_id*.

    Parameters
    ----------
    model_id : Hugging Face model identifier.
    device : pipeline device index (0 = first CUDA GPU, -1 = CPU).

    Returns
    -------
    A ``(tokenizer, summarizer)`` pair, where *summarizer* is a
    transformers "summarization" pipeline bound to the requested device.
    """
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    summ = pipeline("summarization", model=seq2seq, tokenizer=tok, device=device)
    return tok, summ
def chunk_text_by_tokens(text: str, tokenizer, max_tokens: int) -> List[str]:
    """
    Split *text* into chunks whose token counts (per *tokenizer*) stay at or
    below *max_tokens*.

    Paragraphs (blank-line separated) are greedily packed into chunks. A
    paragraph that alone exceeds the limit falls back to sentence-level
    packing (split on ". "); a single sentence longer than the limit is
    emitted as its own (oversized) chunk rather than being truncated.

    Parameters
    ----------
    text : the document to split.
    tokenizer : object with an ``encode(text, add_special_tokens=False)``
        method returning a token sequence (a Hugging Face tokenizer works).
    max_tokens : soft upper bound on tokens per chunk.

    Returns
    -------
    List of chunk strings; paragraphs within a chunk are re-joined with
    blank lines, sentences with ". ".
    """
    # NOTE(fix): the original encoded the *entire* text into an unused
    # variable here — an expensive no-op that could also trigger tokenizer
    # max-length warnings; that dead call has been removed.
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0
    for p in paragraphs:
        p_tokens = len(tokenizer.encode(p, add_special_tokens=False))
        if p_tokens > max_tokens:
            # Paragraph alone exceeds the limit: pack sentence by sentence.
            for s in p.split(". "):
                s = s.strip()
                if not s:
                    continue
                s_tokens = len(tokenizer.encode(s, add_special_tokens=False))
                if current_len + s_tokens <= max_tokens:
                    current.append(s)
                    current_len += s_tokens
                else:
                    if current:
                        chunks.append(". ".join(current))
                    current = [s]
                    current_len = s_tokens
            continue
        if current_len + p_tokens <= max_tokens:
            current.append(p)
            current_len += p_tokens
        else:
            if current:
                chunks.append("\n\n".join(current))
            current = [p]
            current_len = p_tokens
    if current:
        chunks.append("\n\n".join(current))
    return chunks
def summarize_text_pipeline(text: str,
                            prefer_long: bool = True,
                            min_length: int = DEFAULT_SUMMARY_MIN_LENGTH,
                            max_length: int = DEFAULT_SUMMARY_MAX_LENGTH):
    """
    Summarize *text*, choosing a model based on document length.

    Workflow:
      - Estimate the document's token count with the lightweight DistilBART
        tokenizer.
      - Use the LED long-document model when *prefer_long* is set and the
        estimate exceeds LONG_MODEL_SWITCH_TOKENS; otherwise DistilBART.
      - If the input fits within the chosen model's window, summarize in one
        pass. Otherwise summarize chunk by chunk, join the chunk summaries,
        and condense them into a final summary (falling back to DistilBART
        for the condensing pass when the joined summaries are still too long
        for the chosen model).

    Parameters
    ----------
    text : the document text to summarize.
    prefer_long : prefer the LED model for long inputs.
    min_length / max_length : per-generation summary length bounds (tokens).

    Returns
    -------
    The summary string.
    """
    device = choose_device()
    # Fast token estimate using the short model's tokenizer (cheap to load).
    short_tok = AutoTokenizer.from_pretrained(SHORT_MODEL_ID, use_fast=True)
    total_tokens = len(short_tok.encode(text, add_special_tokens=False))
    use_long_model = prefer_long and total_tokens > LONG_MODEL_SWITCH_TOKENS
    model_id = LONG_MODEL_ID if use_long_model else SHORT_MODEL_ID
    tokenizer, summarizer = load_pipeline_for_model(model_id, device)
    # Cap at 16000 (some tokenizers report a huge sentinel for
    # model_max_length) and keep 64 tokens of headroom for special tokens.
    model_max = getattr(tokenizer, "model_max_length", 1024)
    chunk_input_limit = min(model_max - 64, 16000)
    # Single-pass summarization when the whole document fits.
    if total_tokens <= chunk_input_limit:
        summary = summarizer(text, min_length=min_length, max_length=max_length, do_sample=False)
        return summary[0]["summary_text"]
    # Otherwise: chunk, summarize each chunk, then condense.
    chunks = chunk_text_by_tokens(text, tokenizer, chunk_input_limit)
    chunk_summaries = []
    for chunk in tqdm(chunks, desc="Summarizing chunks"):
        out = summarizer(chunk, min_length=max(20, min_length // 2), max_length=max_length, do_sample=False)
        chunk_summaries.append(out[0]["summary_text"])
    combined = "\n\n".join(chunk_summaries)
    if len(tokenizer.encode(combined, add_special_tokens=False)) > chunk_input_limit:
        # Combined chunk summaries still exceed the chosen model's window:
        # condense with the fast DistilBART summarizer instead.
        # NOTE(fix): the original bound the returned tokenizer to an unused
        # variable (short_tok2); it is discarded here.
        _, short_summarizer = load_pipeline_for_model(SHORT_MODEL_ID, device)
        final = short_summarizer(combined, min_length=min_length, max_length=max_length, do_sample=False)
    else:
        # The chosen model itself can condense the combined summaries.
        final = summarizer(combined, min_length=min_length, max_length=max_length, do_sample=False)
    return final[0]["summary_text"]
# ---------- Gradio UI ----------
def summarize_pdf(file, min_len, max_len, prefer_long):
    """
    Gradio callback: read an uploaded PDF and return its summary (or an
    error message string).

    Parameters
    ----------
    file : the uploaded file — newer Gradio versions pass a plain filesystem
        path (str), older ones a tempfile-like object with a ``.name``
        attribute; a raw binary file-like object also works as a fallback.
    min_len / max_len : approximate token bounds for the summary.
    prefer_long : whether to prefer the LED long-document model.
    """
    if file is None:
        return "No file uploaded."
    try:
        # NOTE(fix): newer Gradio passes a plain path string, on which
        # `.name` raises; handle both the str and tempfile-object cases.
        fpath = file if isinstance(file, str) else file.name
        with open(fpath, "rb") as fh:
            text = extract_text_from_pdf(fh)
    except Exception:
        # Fall back to treating the input as an already-open file object.
        try:
            file.seek(0)
            text = extract_text_from_pdf(file)
        except Exception as e:
            return f"Failed to read PDF: {e}"
    if not text.strip():
        return "No extractable text found in PDF."
    # Run summarization (this may take a while depending on model & GPU).
    return summarize_text_pipeline(text, prefer_long, int(min_len), int(max_len))
def create_ui():
    """Assemble and return the Gradio Blocks interface for the PDF summarizer."""
    with gr.Blocks() as demo:
        gr.Markdown(
            "# PDF Summarizer (Gradio + Transformers)\nUpload a PDF and get an abstractive summary. Uses LED for long docs and DistilBART for faster shorter summarization."
        )
        with gr.Row():
            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
            with gr.Column():
                min_slider = gr.Slider(
                    10, 400,
                    value=DEFAULT_SUMMARY_MIN_LENGTH, step=10,
                    label="Minimum summary length (tokens approx)",
                )
                max_slider = gr.Slider(
                    50, 1024,
                    value=DEFAULT_SUMMARY_MAX_LENGTH, step=10,
                    label="Maximum summary length (tokens approx)",
                )
                long_model_toggle = gr.Checkbox(
                    value=True,
                    label="Prefer long-document model for long PDFs (recommended)",
                )
        summarize_btn = gr.Button("Summarize")
        summary_box = gr.Textbox(label="Summary", lines=20)
        summarize_btn.click(
            fn=summarize_pdf,
            inputs=[pdf_input, min_slider, max_slider, long_model_toggle],
            outputs=summary_box,
        )
    return demo
if __name__ == "__main__":
    # Build the interface and expose it with a public Gradio share link.
    create_ui().launch(share=True)