# NOTE: the original upload carried a Hugging Face Spaces page header
# ("Spaces: / Sleeping / Sleeping") — scraping residue, not program code.
| # app.py | |
| import os | |
| import math | |
| from typing import List | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM | |
| import torch | |
| import PyPDF2 | |
| import gradio as gr | |
| from tqdm.auto import tqdm | |
# ---------- Configuration ----------
# Model choices: a long-context model for big documents and a light,
# fast model for everything else.
LONG_MODEL_ID = "allenai/led-base-16384"          # handles very long docs (~16k tokens)
SHORT_MODEL_ID = "sshleifer/distilbart-cnn-12-6"  # quick summarizer for shorter texts

# Token-count threshold above which the long-document workflow kicks in.
LONG_MODEL_SWITCH_TOKENS = 2000  # tunable

# Default generation lengths applied per summarization call.
DEFAULT_SUMMARY_MIN_LENGTH = 60
DEFAULT_SUMMARY_MAX_LENGTH = 256

# Global preference: favor the long-document model when the input warrants it.
prefer_long_model = True
# -----------------------------------
def extract_text_from_pdf(fileobj) -> str:
    """
    Pull the text content out of an uploaded PDF file-like object.

    Uses PyPDF2's ``PdfReader``; pages whose extraction raises or yields
    nothing are silently skipped, so the result is a best-effort
    concatenation of per-page text separated by blank lines.
    """
    pages_text = []
    for page in PyPDF2.PdfReader(fileobj).pages:
        try:
            extracted = page.extract_text()
        except Exception:
            extracted = ""
        if extracted:
            pages_text.append(extracted)
    return "\n\n".join(pages_text)
def choose_device():
    """Return the transformers pipeline device index: 0 for the first CUDA GPU, -1 for CPU."""
    if torch.cuda.is_available():
        return 0
    return -1
def load_pipeline_for_model(model_id: str, device: int):
    """
    Build a (tokenizer, summarization pipeline) pair for *model_id*.

    *device* follows the transformers convention: 0 selects the first
    CUDA GPU, -1 runs on CPU.
    """
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    summarizer = pipeline("summarization", model=seq2seq, tokenizer=tok, device=device)
    return tok, summarizer
def chunk_text_by_tokens(text: str, tokenizer, max_tokens: int) -> List[str]:
    """
    Split *text* into chunks whose token counts (per *tokenizer*) stay
    at or below *max_tokens*.

    Strategy: paragraphs (split on blank lines) are aggregated greedily
    until adding the next one would exceed the budget. A paragraph that
    is by itself over the limit falls back to a ". " sentence split.

    Note: a single sentence longer than *max_tokens* still becomes its
    own (oversized) chunk — downstream models may truncate it.

    Fix vs. previous version: removed a dead statement that tokenized
    the ENTIRE text with ``return_tensors="pt"`` and discarded the
    result — pure wasted work (and warning spam) on long documents.
    """
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0
    for p in paragraphs:
        p_tokens = len(tokenizer.encode(p, add_special_tokens=False))
        if p_tokens > max_tokens:
            # Paragraph alone exceeds the budget: fall back to sentences.
            for s in p.split(". "):
                s = s.strip()
                if not s:
                    continue
                s_tokens = len(tokenizer.encode(s, add_special_tokens=False))
                if current_len + s_tokens <= max_tokens:
                    current.append(s)
                    current_len += s_tokens
                else:
                    if current:
                        chunks.append(". ".join(current))
                    current = [s]
                    current_len = s_tokens
            continue
        if current_len + p_tokens <= max_tokens:
            current.append(p)
            current_len += p_tokens
        else:
            if current:
                chunks.append("\n\n".join(current))
            current = [p]
            current_len = p_tokens
    if current:
        chunks.append("\n\n".join(current))
    return chunks
def summarize_text_pipeline(text: str,
                            prefer_long: bool = True,
                            min_length: int = DEFAULT_SUMMARY_MIN_LENGTH,
                            max_length: int = DEFAULT_SUMMARY_MAX_LENGTH):
    """
    Summarize *text*, choosing the model by estimated token length.

    Workflow:
      1. Estimate total tokens with the lightweight short-model tokenizer.
      2. Pick LED for long inputs (when *prefer_long*), DistilBART otherwise.
      3. If the input fits the model's window, summarize in one pass;
         otherwise chunk, summarize each chunk, and condense the combined
         chunk summaries (falling back to the fast short model if the
         combined text is still over the limit).

    Returns the final summary string.
    """
    device = choose_device()
    # Fast token estimate — the short model's tokenizer is cheap to load.
    short_tok = AutoTokenizer.from_pretrained(SHORT_MODEL_ID, use_fast=True)
    total_tokens = len(short_tok.encode(text, add_special_tokens=False))

    use_long_model = prefer_long and total_tokens > LONG_MODEL_SWITCH_TOKENS
    model_id = LONG_MODEL_ID if use_long_model else SHORT_MODEL_ID
    tokenizer, summarizer = load_pipeline_for_model(model_id, device)

    # getattr covers tokenizers that lack model_max_length; keep some
    # headroom below the model limit and hard-cap at 16k for LED.
    model_max = getattr(tokenizer, "model_max_length", 1024)
    chunk_input_limit = min(model_max - 64, 16000)

    # Whole document fits: single-pass summarization.
    if total_tokens <= chunk_input_limit:
        summary = summarizer(
            text,
            min_length=min_length,
            max_length=max_length,
            do_sample=False
        )
        return summary[0]["summary_text"]

    # Otherwise: chunk, summarize each part, then condense.
    chunks = chunk_text_by_tokens(text, tokenizer, chunk_input_limit)
    chunk_summaries = []
    for chunk in tqdm(chunks, desc="Summarizing chunks"):
        out = summarizer(chunk, min_length=max(20, min_length // 2), max_length=max_length, do_sample=False)
        chunk_summaries.append(out[0]["summary_text"])
    combined = "\n\n".join(chunk_summaries)

    if len(tokenizer.encode(combined, add_special_tokens=False)) > chunk_input_limit:
        # Combined summaries still exceed this model's window: condense
        # with the fast DistilBART model (its tokenizer is not needed).
        _, short_summarizer = load_pipeline_for_model(SHORT_MODEL_ID, device)
        final = short_summarizer(combined, min_length=min_length, max_length=max_length, do_sample=False)
    else:
        # The current summarizer can condense the combined text itself.
        final = summarizer(combined, min_length=min_length, max_length=max_length, do_sample=False)
    return final[0]["summary_text"]
| # ---------- Gradio UI ---------- | |
def summarize_pdf(file, min_len, max_len, prefer_long):
    """
    Gradio click handler: extract text from the uploaded PDF and summarize it.

    *file* may be a plain path string (newer Gradio / type="filepath"),
    an object exposing a ``.name`` temp path, or a raw file-like object —
    all three are handled. Returns the summary string or a human-readable
    error message.
    """
    if file is None:
        return "No file uploaded."
    try:
        # Newer Gradio versions may hand us a plain path string instead of
        # a tempfile wrapper with a .name attribute.
        fpath = file if isinstance(file, str) else file.name
        with open(fpath, "rb") as fh:
            text = extract_text_from_pdf(fh)
    except Exception:
        # Fall back to treating `file` itself as a file-like object.
        try:
            file.seek(0)
            text = extract_text_from_pdf(file)
        except Exception as e:
            return f"Failed to read PDF: {e}"
    if not text.strip():
        return "No extractable text found in PDF."
    # May take a while depending on model size and available hardware.
    return summarize_text_pipeline(text, prefer_long, int(min_len), int(max_len))
def create_ui():
    """Assemble and return the Gradio Blocks interface for the summarizer."""
    with gr.Blocks() as demo:
        gr.Markdown(
            "# PDF Summarizer (Gradio + Transformers)\nUpload a PDF and get an abstractive summary. Uses LED for long docs and DistilBART for faster shorter summarization."
        )
        with gr.Row():
            file_in = gr.File(label="Upload PDF", file_types=[".pdf"])
            with gr.Column():
                min_len = gr.Slider(
                    10, 400,
                    value=DEFAULT_SUMMARY_MIN_LENGTH,
                    step=10,
                    label="Minimum summary length (tokens approx)",
                )
                max_len = gr.Slider(
                    50, 1024,
                    value=DEFAULT_SUMMARY_MAX_LENGTH,
                    step=10,
                    label="Maximum summary length (tokens approx)",
                )
                prefer_long_cb = gr.Checkbox(
                    value=True,
                    label="Prefer long-document model for long PDFs (recommended)",
                )
        run_btn = gr.Button("Summarize")
        output = gr.Textbox(label="Summary", lines=20)
        run_btn.click(
            fn=summarize_pdf,
            inputs=[file_in, min_len, max_len, prefer_long_cb],
            outputs=output,
        )
    return demo
if __name__ == "__main__":
    # Build the interface and expose it with a public share link.
    demo = create_ui()
    demo.launch(share=True)