File size: 8,227 Bytes
e537c46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0cf3bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# app.py
import os
import math
from typing import List
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import PyPDF2
import gradio as gr
from tqdm.auto import tqdm

# ---------- Configuration ----------
# Models
LONG_MODEL_ID = "allenai/led-base-16384"           # good for very long docs (up to ~16k tokens)
SHORT_MODEL_ID = "sshleifer/distilbart-cnn-12-6"  # faster, efficient summarizer for shorter texts

# Threshold (in tokens) when we'll switch to long-document model workflow
LONG_MODEL_SWITCH_TOKENS = 2000  # you can tweak this

# Per-chunk generation defaults (token counts passed to the summarization pipeline)
DEFAULT_SUMMARY_MIN_LENGTH = 60
DEFAULT_SUMMARY_MAX_LENGTH = 256

# Whether to prefer the long model whenever possible
# NOTE(review): this flag appears unused in this file — the UI checkbox and the
# `prefer_long` parameter of summarize_text_pipeline govern instead; verify callers.
prefer_long_model = True
# -----------------------------------

def extract_text_from_pdf(fileobj) -> str:
    """
    Pull the text out of an uploaded PDF (any file-like object).

    Pages whose extraction raises are skipped, and pages that yield no
    text contribute nothing. PyPDF2 handles many, but not all, PDFs.
    """
    reader = PyPDF2.PdfReader(fileobj)
    pages = []
    for page in reader.pages:
        try:
            extracted = page.extract_text()
        except Exception:
            extracted = ""
        if extracted:
            pages.append(extracted)
    return "\n\n".join(pages)

def choose_device():
    """Return the pipeline device index: 0 for the first CUDA GPU, -1 for CPU."""
    if torch.cuda.is_available():
        return 0
    return -1

def load_pipeline_for_model(model_id: str, device: int):
    """
    Build a (tokenizer, summarization pipeline) pair for *model_id*.

    Args:
        model_id: HuggingFace model identifier.
        device: Pipeline device index (0 = first CUDA GPU, -1 = CPU).

    Returns:
        Tuple of (tokenizer, summarizer pipeline).
    """
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    seq2seq = AutoModelForSeq2SeqLM.from_pretrained(model_id)
    # device: 0 selects CUDA, -1 selects CPU (transformers pipeline convention)
    summarizer = pipeline(
        "summarization",
        model=seq2seq,
        tokenizer=tok,
        device=device,
    )
    return tok, summarizer

def chunk_text_by_tokens(text: str, tokenizer, max_tokens: int) -> List[str]:
    """
    Split *text* into chunks whose token counts (per *tokenizer*) stay
    within *max_tokens*.

    Strategy: aggregate paragraphs (split on blank lines) until the token
    budget is reached; a paragraph that alone exceeds the budget is split
    further on sentence boundaries (". ") as a fallback.

    Fix: the original first tokenized the ENTIRE document into a tensor
    (``tokenizer.encode(text, return_tensors="pt")``) and discarded the
    result — pure wasted work on long documents (and a source of
    model-max-length warnings). That dead statement is removed; the
    chunking logic is unchanged.

    Args:
        text: The full document text.
        tokenizer: Anything with an ``encode`` method accepting
            ``add_special_tokens=False`` (a HuggingFace tokenizer).
        max_tokens: Maximum tokens allowed per chunk.

    Returns:
        List of text chunks. NOTE: a single sentence longer than the budget
        is kept whole and may still exceed *max_tokens*.
    """
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()]
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0
    for p in paragraphs:
        p_tokens = len(tokenizer.encode(p, add_special_tokens=False))
        if p_tokens > max_tokens:
            # Paragraph alone exceeds the budget: fall back to sentences.
            for s in p.split(". "):
                s = s.strip()
                if not s:
                    continue
                s_tokens = len(tokenizer.encode(s, add_special_tokens=False))
                if current_len + s_tokens <= max_tokens:
                    current.append(s)
                    current_len += s_tokens
                else:
                    if current:
                        chunks.append(". ".join(current))
                    current = [s]
                    current_len = s_tokens
            continue
        if current_len + p_tokens <= max_tokens:
            current.append(p)
            current_len += p_tokens
        else:
            if current:
                chunks.append("\n\n".join(current))
            current = [p]
            current_len = p_tokens
    if current:
        chunks.append("\n\n".join(current))
    return chunks

def summarize_text_pipeline(text: str,
                            prefer_long: bool = True,
                            min_length: int = DEFAULT_SUMMARY_MIN_LENGTH,
                            max_length: int = DEFAULT_SUMMARY_MAX_LENGTH):
    """
    Summarize *text*, choosing between a long-document model (LED) and a
    faster short-document model (DistilBART).

    Flow:
    - Estimate the token count with the lightweight short-model tokenizer.
    - Use LED when *prefer_long* is set and the text exceeds
      LONG_MODEL_SWITCH_TOKENS; otherwise use DistilBART.
    - If the text fits the chosen model's input window, summarize directly;
      otherwise chunk it, summarize each chunk, then condense the combined
      chunk summaries into a final summary.

    Args:
        text: Raw document text.
        prefer_long: Allow switching to the long-document model.
        min_length: Minimum generated summary length (tokens, approx).
        max_length: Maximum generated summary length (tokens, approx).

    Returns:
        The final summary string.
    """
    device = choose_device()
    # Fast token estimate via the lightweight short-model tokenizer.
    short_tok = AutoTokenizer.from_pretrained(SHORT_MODEL_ID, use_fast=True)
    total_tokens = len(short_tok.encode(text, add_special_tokens=False))
    use_long_model = prefer_long and total_tokens > LONG_MODEL_SWITCH_TOKENS

    model_id = LONG_MODEL_ID if use_long_model else SHORT_MODEL_ID
    tokenizer, summarizer = load_pipeline_for_model(model_id, device)

    # Input window of the chosen model, minus headroom, capped at 16k (LED).
    model_max = getattr(tokenizer, "model_max_length", 1024)
    chunk_input_limit = min(model_max - 64, 16000)

    # NOTE(review): total_tokens was counted with the SHORT tokenizer while
    # the limit comes from the chosen model's tokenizer; counts can differ
    # slightly — the 64-token headroom above is meant to absorb that.
    if total_tokens <= chunk_input_limit:
        summary = summarizer(
            text,
            min_length=min_length,
            max_length=max_length,
            do_sample=False,
        )
        return summary[0]["summary_text"]

    # Too long for one pass: chunk, then summarize each chunk.
    chunks = chunk_text_by_tokens(text, tokenizer, chunk_input_limit)
    chunk_summaries = []
    for chunk in tqdm(chunks, desc="Summarizing chunks"):
        out = summarizer(chunk, min_length=max(20, min_length // 2), max_length=max_length, do_sample=False)
        chunk_summaries.append(out[0]["summary_text"])

    combined = "\n\n".join(chunk_summaries)

    # Condense the combined chunk summaries into one final summary.
    if len(tokenizer.encode(combined, add_special_tokens=False)) > chunk_input_limit:
        # Still too long for the chosen model: condense with the fast
        # DistilBART pipeline. (The returned tokenizer is unused here, so
        # it is discarded explicitly — the original bound it to a name.)
        _, short_summarizer = load_pipeline_for_model(SHORT_MODEL_ID, device)
        final = short_summarizer(combined, min_length=min_length, max_length=max_length, do_sample=False)
    else:
        final = summarizer(combined, min_length=min_length, max_length=max_length, do_sample=False)
    return final[0]["summary_text"]

# ---------- Gradio UI ----------
def summarize_pdf(file, min_len, max_len, prefer_long):
    """Gradio handler: read the uploaded PDF, extract its text, summarize it."""
    if file is None:
        return "No file uploaded."
    # Gradio may hand us a tempfile wrapper exposing .name, or a plain
    # file-like object — try the filesystem path first, then the object.
    try:
        with open(file.name, "rb") as fh:
            text = extract_text_from_pdf(fh)
    except Exception:
        try:
            file.seek(0)
            text = extract_text_from_pdf(file)
        except Exception as e:
            return f"Failed to read PDF: {e}"

    if not text.strip():
        return "No extractable text found in PDF."

    # May take a while depending on the model and available hardware.
    return summarize_text_pipeline(text, prefer_long, int(min_len), int(max_len))

def create_ui():
    """Assemble and return the Gradio Blocks interface for the summarizer."""
    with gr.Blocks() as ui:
        gr.Markdown(
            "# PDF Summarizer (Gradio + Transformers)\n"
            "Upload a PDF and get an abstractive summary. Uses LED for long docs "
            "and DistilBART for faster shorter summarization."
        )
        with gr.Row():
            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
            with gr.Column():
                min_slider = gr.Slider(
                    10, 400,
                    value=DEFAULT_SUMMARY_MIN_LENGTH, step=10,
                    label="Minimum summary length (tokens approx)",
                )
                max_slider = gr.Slider(
                    50, 1024,
                    value=DEFAULT_SUMMARY_MAX_LENGTH, step=10,
                    label="Maximum summary length (tokens approx)",
                )
                long_checkbox = gr.Checkbox(
                    value=True,
                    label="Prefer long-document model for long PDFs (recommended)",
                )
                summarize_btn = gr.Button("Summarize")
        summary_box = gr.Textbox(label="Summary", lines=20)
        summarize_btn.click(
            fn=summarize_pdf,
            inputs=[pdf_input, min_slider, max_slider, long_checkbox],
            outputs=summary_box,
        )
    return ui

if __name__ == "__main__":
    # Build the UI and launch with a public share link.
    create_ui().launch(share=True)