"""PDF Summarizer — Gradio + Transformers demo (Hugging Face Space).

NOTE: scraped page chrome removed here — it was not valid Python
(status "Sleeping", file size 8,227 bytes, commits e537c46/b0cf3bb,
and a pasted line-number gutter).
"""
# app.py
import os
import math
from typing import List
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import PyPDF2
import gradio as gr
from tqdm.auto import tqdm
# ---------- Configuration ----------
# Models
LONG_MODEL_ID = "allenai/led-base-16384" # good for very long docs (up to ~16k tokens)
SHORT_MODEL_ID = "sshleifer/distilbart-cnn-12-6" # faster, efficient summarizer for shorter texts
# Threshold (in tokens) when we'll switch to long-document model workflow
LONG_MODEL_SWITCH_TOKENS = 2000 # you can tweak this
# Per-chunk generation defaults
DEFAULT_SUMMARY_MIN_LENGTH = 60
DEFAULT_SUMMARY_MAX_LENGTH = 256
# Whether to prefer the long model whenever possible
prefer_long_model = True
# -----------------------------------
def extract_text_from_pdf(fileobj) -> str:
    """
    Pull the text out of an uploaded PDF (a binary file-like object).

    Relies on PyPDF2, which copes with many (but not all) PDF layouts.
    Pages whose extraction raises or yields nothing are skipped, so the
    result is best-effort rather than guaranteed-complete.
    """
    pages_text = []
    for page in PyPDF2.PdfReader(fileobj).pages:
        try:
            page_text = page.extract_text()
        except Exception:
            # A single unreadable page should not abort the whole document.
            page_text = ""
        if page_text:
            pages_text.append(page_text)
    return "\n\n".join(pages_text)
def choose_device():
    """Return the transformers pipeline device id: 0 (first CUDA GPU) if available, else -1 (CPU)."""
    if torch.cuda.is_available():
        return 0
    return -1
def load_pipeline_for_model(model_id: str, device: int):
    """
    Build a summarization pipeline for *model_id*.

    *device* follows the transformers convention: 0 selects the first
    CUDA device, -1 selects the CPU.

    Returns a (tokenizer, summarization_pipeline) pair.
    """
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    summarizer = pipeline(
        "summarization",
        model=seq2seq,
        tokenizer=tok,
        device=device,
    )
    return tok, summarizer
def chunk_text_by_tokens(text: str, tokenizer, max_tokens: int) -> List[str]:
    """
    Split *text* into chunks whose token count is at most *max_tokens*.

    Paragraphs (separated by blank lines) are greedily packed into chunks;
    a paragraph that alone exceeds the limit is further split on ". "
    sentence boundaries and those sentences are packed the same way.
    Token counts use ``tokenizer.encode(..., add_special_tokens=False)``,
    so they are approximate relative to what the pipeline finally sees.

    NOTE(review): a single sentence longer than *max_tokens* is still kept
    whole, so a chunk can exceed the limit in that pathological case.

    Fix vs. original: the original first tokenized the ENTIRE text with
    ``return_tensors="pt"`` and discarded the result — wasted work that can
    also emit model-max-length warnings on long inputs; removed.
    """
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks: List[str] = []
    current: List[str] = []   # pieces accumulated for the chunk being built
    current_len = 0           # token count of `current`
    for p in paragraphs:
        p_tokens = len(tokenizer.encode(p, add_special_tokens=False))
        if p_tokens > max_tokens:
            # Paragraph alone is over the limit: fall back to sentence packing.
            for s in p.split(". "):
                s = s.strip()
                if not s:
                    continue
                s_tokens = len(tokenizer.encode(s, add_special_tokens=False))
                if current_len + s_tokens <= max_tokens:
                    current.append(s)
                    current_len += s_tokens
                else:
                    if current:
                        chunks.append(". ".join(current))
                    current = [s]
                    current_len = s_tokens
            continue
        if current_len + p_tokens <= max_tokens:
            current.append(p)
            current_len += p_tokens
        else:
            if current:
                chunks.append("\n\n".join(current))
            current = [p]
            current_len = p_tokens
    if current:
        chunks.append("\n\n".join(current))
    return chunks
def summarize_text_pipeline(text: str,
                            prefer_long: bool = True,
                            min_length: int = DEFAULT_SUMMARY_MIN_LENGTH,
                            max_length: int = DEFAULT_SUMMARY_MAX_LENGTH):
    """
    Summarize *text*, choosing a model by estimated length.

    - Estimates token count with the lightweight short-model tokenizer.
    - Uses the LED long-document model when *prefer_long* is set and the
      text exceeds LONG_MODEL_SWITCH_TOKENS; otherwise uses DistilBART.
    - If the text fits the chosen model's input window, summarizes it in
      one call; otherwise chunks it, summarizes each chunk, then condenses
      the concatenated chunk summaries into a final summary.

    Fixes vs. original: the unused tokenizer returned when loading the
    condenser is discarded explicitly, and the short model is no longer
    reloaded when it is already the active pipeline.

    Returns the summary string.
    """
    device = choose_device()
    # Fast, cheap token estimate — the short tokenizer loads quickly.
    short_tok = AutoTokenizer.from_pretrained(SHORT_MODEL_ID, use_fast=True)
    total_tokens = len(short_tok.encode(text, add_special_tokens=False))
    use_long_model = prefer_long and total_tokens > LONG_MODEL_SWITCH_TOKENS
    model_id = LONG_MODEL_ID if use_long_model else SHORT_MODEL_ID
    tokenizer, summarizer = load_pipeline_for_model(model_id, device)
    model_max = getattr(tokenizer, "model_max_length", 1024)
    # Leave headroom below the model's window; cap for LED's huge sentinel.
    chunk_input_limit = min(model_max - 64, 16000)
    # If the whole input fits, summarize directly.
    if total_tokens <= chunk_input_limit:
        summary = summarizer(
            text,
            min_length=min_length,
            max_length=max_length,
            do_sample=False
        )
        return summary[0]["summary_text"]
    # Otherwise chunk and summarize each part.
    chunks = chunk_text_by_tokens(text, tokenizer, chunk_input_limit)
    chunk_summaries = []
    for chunk in tqdm(chunks, desc="Summarizing chunks"):
        # Shorter per-chunk minimum so chunk summaries stay compact.
        out = summarizer(chunk, min_length=max(20, min_length // 2), max_length=max_length, do_sample=False)
        chunk_summaries.append(out[0]["summary_text"])
    combined = "\n\n".join(chunk_summaries)
    # Condense the combined chunk summaries into one final summary.
    if len(tokenizer.encode(combined, add_special_tokens=False)) > chunk_input_limit:
        # Too long for the current model: condense with the fast short model,
        # reusing the already-loaded pipeline when it IS the short model.
        if model_id == SHORT_MODEL_ID:
            condenser = summarizer
        else:
            _, condenser = load_pipeline_for_model(SHORT_MODEL_ID, device)
        final = condenser(combined, min_length=min_length, max_length=max_length, do_sample=False)
    else:
        final = summarizer(combined, min_length=min_length, max_length=max_length, do_sample=False)
    return final[0]["summary_text"]
# ---------- Gradio UI ----------
def summarize_pdf(file, min_len, max_len, prefer_long):
    """
    Gradio callback: extract text from the uploaded PDF and summarize it.

    Returns either the summary or a human-readable error message string.
    """
    if file is None:
        return "No file uploaded."
    # Gradio may hand us an object exposing a temp-file path via .name, or a
    # file-like object directly; try the path first, then fall back.
    try:
        with open(file.name, "rb") as handle:
            text = extract_text_from_pdf(handle)
    except Exception:
        try:
            file.seek(0)
            text = extract_text_from_pdf(file)
        except Exception as e:
            return f"Failed to read PDF: {e}"
    if not text.strip():
        return "No extractable text found in PDF."
    # May be slow depending on model choice and whether a GPU is present.
    return summarize_text_pipeline(text, prefer_long, int(min_len), int(max_len))
def create_ui():
    """Assemble and return the Gradio Blocks interface (not launched here)."""
    with gr.Blocks() as demo:
        gr.Markdown("# PDF Summarizer (Gradio + Transformers)\nUpload a PDF and get an abstractive summary. Uses LED for long docs and DistilBART for faster shorter summarization.")
        with gr.Row():
            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
            with gr.Column():
                min_slider = gr.Slider(10, 400, value=DEFAULT_SUMMARY_MIN_LENGTH, step=10, label="Minimum summary length (tokens approx)")
                max_slider = gr.Slider(50, 1024, value=DEFAULT_SUMMARY_MAX_LENGTH, step=10, label="Maximum summary length (tokens approx)")
                long_checkbox = gr.Checkbox(value=True, label="Prefer long-document model for long PDFs (recommended)")
        summarize_btn = gr.Button("Summarize")
        summary_box = gr.Textbox(label="Summary", lines=20)
        summarize_btn.click(
            fn=summarize_pdf,
            inputs=[pdf_input, min_slider, max_slider, long_checkbox],
            outputs=summary_box,
        )
    return demo
if __name__ == "__main__":
    # Fix vs. original: removed a stray trailing "|" (scraped-page residue)
    # that made this line a syntax error.
    ui = create_ui()
    # share=True publishes a temporary public gradio.live URL.
    ui.launch(share=True)