# app.py
import os
import math
from typing import List

from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import PyPDF2
import gradio as gr
from tqdm.auto import tqdm

# ---------- Configuration ----------
# Models
LONG_MODEL_ID = "allenai/led-base-16384"  # good for very long docs (up to ~16k tokens)
SHORT_MODEL_ID = "sshleifer/distilbart-cnn-12-6"  # faster, efficient summarizer for shorter texts

# Threshold (in tokens) when we'll switch to long-document model workflow
LONG_MODEL_SWITCH_TOKENS = 2000  # you can tweak this

# Per-chunk generation defaults
DEFAULT_SUMMARY_MIN_LENGTH = 60
DEFAULT_SUMMARY_MAX_LENGTH = 256

# Whether to prefer the long model whenever possible
prefer_long_model = True

# Some HF tokenizers report a "no limit" sentinel (~1e30) as model_max_length;
# anything above this is treated as that sentinel, not a real limit.
_MODEL_MAX_SENTINEL = 1_000_000
# -----------------------------------


def extract_text_from_pdf(fileobj) -> str:
    """Extract text from an uploaded PDF file object (a file-like object).

    Uses PyPDF2 (works for many PDFs). Pages whose extraction raises
    (e.g. malformed or image-only pages) are skipped rather than
    aborting the whole document.

    Returns the page texts joined with blank lines; empty string when
    nothing could be extracted.
    """
    reader = PyPDF2.PdfReader(fileobj)
    all_text: List[str] = []
    for page in reader.pages:
        try:
            txt = page.extract_text()
        except Exception:
            # Best-effort extraction: a single bad page should not kill the run.
            txt = ""
        if txt:
            all_text.append(txt)
    return "\n\n".join(all_text)


def choose_device() -> int:
    """Return the device index for transformers pipelines: 0 (first GPU) or -1 (CPU)."""
    return 0 if torch.cuda.is_available() else -1


def load_pipeline_for_model(model_id: str, device: int):
    """Load tokenizer, model and summarization pipeline for given model_id.

    Returns (tokenizer, summarizer) so callers can both count tokens and
    generate summaries with a consistent vocabulary.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    # pipeline: specify device (0 for cuda, -1 for cpu)
    summarizer = pipeline(
        "summarization",
        model=model,
        tokenizer=tokenizer,
        device=device,
    )
    return tokenizer, summarizer


def chunk_text_by_tokens(text: str, tokenizer, max_tokens: int) -> List[str]:
    """Chunk text into a list of strings, each at most ``max_tokens`` tokens.

    Splits on paragraph boundaries and greedily packs paragraphs into
    chunks; a paragraph that alone exceeds the limit is further split on
    sentence boundaries, and a single sentence that still exceeds the
    limit is hard-truncated at the token level so no chunk can overflow
    the model's input window.
    """

    def _truncate_to_limit(s: str) -> str:
        # Hard cap: decode back the first max_tokens tokens of an oversized sentence.
        ids = tokenizer.encode(s, add_special_tokens=False)
        if len(ids) <= max_tokens:
            return s
        return tokenizer.decode(ids[:max_tokens], skip_special_tokens=True)

    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0

    for p in paragraphs:
        p_tokens = len(tokenizer.encode(p, add_special_tokens=False))
        if p_tokens > max_tokens:
            # paragraph itself larger than limit: split by sentences (fallback)
            sentences = p.split(". ")
            for s in sentences:
                s = s.strip()
                if not s:
                    continue
                s = _truncate_to_limit(s)
                s_tokens = len(tokenizer.encode(s, add_special_tokens=False))
                if current_len + s_tokens <= max_tokens:
                    current.append(s)
                    current_len += s_tokens
                else:
                    if current:
                        chunks.append(". ".join(current))
                    current = [s]
                    current_len = s_tokens
            continue
        if current_len + p_tokens <= max_tokens:
            current.append(p)
            current_len += p_tokens
        else:
            if current:
                chunks.append("\n\n".join(current))
            current = [p]
            current_len = p_tokens

    if current:
        chunks.append("\n\n".join(current))
    return chunks


def summarize_text_pipeline(text: str,
                            prefer_long: bool = True,
                            min_length: int = DEFAULT_SUMMARY_MIN_LENGTH,
                            max_length: int = DEFAULT_SUMMARY_MAX_LENGTH) -> str:
    """Summarize ``text``, choosing a model based on its token length.

    Main logic:
    - Tokenize the entire text to estimate length.
    - If very long and prefer_long: use LED model with large max input length.
    - Otherwise use the faster DistilBART.
    - For models with small max_input_length, chunk and summarize each chunk,
      then combine chunk summaries and optionally summarize the combined summary.
    """
    device = choose_device()

    # Keep generation bounds sane even if the UI sends min >= max.
    max_length = max(max_length, min_length + 1)

    # Quick decision: use the lightweight SHORT_MODEL tokenizer for a fast
    # token estimate before committing to loading a (possibly large) model.
    short_tok = AutoTokenizer.from_pretrained(SHORT_MODEL_ID, use_fast=True)
    total_tokens = len(short_tok.encode(text, add_special_tokens=False))

    use_long_model = prefer_long and total_tokens > LONG_MODEL_SWITCH_TOKENS
    model_id = LONG_MODEL_ID if use_long_model else SHORT_MODEL_ID

    tokenizer, summarizer = load_pipeline_for_model(model_id, device)

    model_max = getattr(tokenizer, "model_max_length", 1024)
    if not isinstance(model_max, int) or model_max > _MODEL_MAX_SENTINEL:
        # Tokenizer reported the "unlimited" sentinel; fall back to a real window.
        model_max = 16384 if use_long_model else 1024
    # keep some headroom below the model's true input window
    chunk_input_limit = min(model_max - 64, 16000)

    # If the whole input fits, summarize directly.
    if total_tokens <= chunk_input_limit:
        summary = summarizer(
            text,
            min_length=min_length,
            max_length=max_length,
            do_sample=False,
        )
        return summary[0]["summary_text"]

    # Otherwise chunk + summarize each part.
    chunks = chunk_text_by_tokens(text, tokenizer, chunk_input_limit)
    chunk_summaries = []
    for chunk in tqdm(chunks, desc="Summarizing chunks"):
        out = summarizer(
            chunk,
            min_length=max(20, min_length // 2),
            max_length=max_length,
            do_sample=False,
        )
        chunk_summaries.append(out[0]["summary_text"])

    # Combine chunk summaries into one text for the final pass.
    combined = "\n\n".join(chunk_summaries)

    if len(tokenizer.encode(combined, add_special_tokens=False)) > chunk_input_limit:
        # Combined summary still too long for this model: condense with the
        # fast DistilBART summarizer instead of recursing on chunks.
        _, short_summarizer = load_pipeline_for_model(SHORT_MODEL_ID, device)
        final = short_summarizer(
            combined,
            min_length=min_length,
            max_length=max_length,
            do_sample=False,
        )
    else:
        # Current summarizer can condense the combined text directly.
        final = summarizer(
            combined,
            min_length=min_length,
            max_length=max_length,
            do_sample=False,
        )
    return final[0]["summary_text"]


# ---------- Gradio UI ----------
def summarize_pdf(file, min_len, max_len, prefer_long):
    """Gradio callback: read the uploaded PDF and return its summary (or an error string)."""
    if file is None:
        return "No file uploaded."
    # Gradio may hand us a plain filepath string, a tempfile-like object with
    # .name, or a raw file-like object; handle all three.
    try:
        fpath = file if isinstance(file, str) else file.name
        with open(fpath, "rb") as fh:
            text = extract_text_from_pdf(fh)
    except Exception:
        # maybe file-like
        try:
            file.seek(0)
            text = extract_text_from_pdf(file)
        except Exception as e:
            return f"Failed to read PDF: {e}"
    if not text.strip():
        return "No extractable text found in PDF."
    # Run summarization (this may take a while depending on model & GPU).
    summary = summarize_text_pipeline(text, prefer_long, int(min_len), int(max_len))
    return summary


def create_ui():
    """Build and return the Gradio Blocks interface for the summarizer."""
    with gr.Blocks() as demo:
        gr.Markdown("# PDF Summarizer (Gradio + Transformers)\nUpload a PDF and get an abstractive summary. Uses LED for long docs and DistilBART for faster shorter summarization.")
        with gr.Row():
            file_in = gr.File(label="Upload PDF", file_types=[".pdf"])
            with gr.Column():
                min_len = gr.Slider(10, 400, value=DEFAULT_SUMMARY_MIN_LENGTH, step=10, label="Minimum summary length (tokens approx)")
                max_len = gr.Slider(50, 1024, value=DEFAULT_SUMMARY_MAX_LENGTH, step=10, label="Maximum summary length (tokens approx)")
                prefer_long_cb = gr.Checkbox(value=True, label="Prefer long-document model for long PDFs (recommended)")
        run_btn = gr.Button("Summarize")
        output = gr.Textbox(label="Summary", lines=20)
        run_btn.click(fn=summarize_pdf, inputs=[file_in, min_len, max_len, prefer_long_cb], outputs=output)
    return demo


if __name__ == "__main__":
    ui = create_ui()
    ui.launch(share=True)