# textSum / main.py
# NOTE(review): the lines below are Hugging Face file-viewer page residue
# (author, commit, "raw / history / blame", size) that was scraped along with
# the source; kept as comments so the module parses.
#   Azidan's picture | Create main.py | 487a5d4 verified | raw | history blame | 9.55 kB
import io
import math
from typing import List, Tuple, Optional
import gradio as gr
from transformers import AutoTokenizer, pipeline
import PyPDF2
import docx
# -----------------------------
# Configuration
# -----------------------------
MODEL_NAME = "sshleifer/distilbart-cnn-12-6" # lightweight, works on free tier
DEVICE = -1 # force CPU (Spaces free tier)
CHUNK_STRIDE = 128 # overlap tokens between chunks (keeps context)
SECOND_PASS = True # run final summarization on joined chunk summaries
# Summary length presets (max tokens in generated summary)
SUMMARY_PRESETS = {
    "short": {"max_length": 60, "min_length": 20},
    "medium": {"max_length": 120, "min_length": 40},
    "long": {"max_length": 200, "min_length": 80},
}
# -----------------------------
# Load tokenizer & pipeline
# -----------------------------
# Loaded once at import time so every request reuses the same model instance
# (model download happens here, on Space startup, not per request).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
summarizer = pipeline("summarization", model=MODEL_NAME, tokenizer=tokenizer, device=DEVICE)
# -----------------------------
# Helpers: file reading
# -----------------------------
def read_pdf_bytes(file_bytes: bytes) -> str:
    """Extract the text of every page of a PDF given as raw bytes.

    Pages with no extractable text are skipped. On any failure an error
    marker string is returned instead of raising, so the caller can surface
    it as part of the extracted text.
    """
    try:
        document = PyPDF2.PdfReader(io.BytesIO(file_bytes))
        page_texts = [
            extracted
            for extracted in (page.extract_text() for page in document.pages)
            if extracted
        ]
        return "\n".join(page_texts)
    except Exception as e:
        return f"[Error reading PDF: {e}]"
def read_docx_bytes(file_bytes: bytes) -> str:
    """Extract the paragraph text of a DOCX document given as raw bytes.

    Empty / whitespace-only paragraphs are dropped. On any failure an error
    marker string is returned instead of raising.
    """
    try:
        document = docx.Document(io.BytesIO(file_bytes))
        lines = []
        for para in document.paragraphs:
            if para.text and para.text.strip():
                lines.append(para.text)
        return "\n".join(lines)
    except Exception as e:
        return f"[Error reading DOCX: {e}]"
# -----------------------------
# Helpers: token-aware chunking
# -----------------------------
def chunk_text_by_tokens(text: str, max_tokens: Optional[int] = None, stride: int = CHUNK_STRIDE) -> List[str]:
    """
    Split text into chunks no longer than `max_tokens` tokens each.

    Consecutive chunks overlap by `stride` tokens so context at chunk
    boundaries is preserved.

    Args:
        text: Raw input text; may be empty or whitespace-only.
        max_tokens: Per-chunk token budget. Defaults to the tokenizer's
            model_max_length (1024 for distilbart-cnn).
        stride: Overlap (in tokens) between consecutive chunks.

    Returns:
        List of decoded chunk strings (empty list for empty input).
    """
    if not text or not text.strip():
        return []
    if max_tokens is None:
        max_tokens = tokenizer.model_max_length # typically 1024 for this model
    # Some tokenizers report a huge sentinel value (~1e30) for model_max_length
    # when the limit is unknown; fall back to this model family's 1024 budget
    # so chunking still happens.
    if max_tokens is None or max_tokens > 1_000_000:
        max_tokens = 1024
    # The overlap must be strictly smaller than the window, otherwise
    # `start = end - stride` below never advances and the loop spins forever.
    if stride < 0:
        stride = 0
    if stride >= max_tokens:
        stride = max_tokens - 1
    # encode without special tokens to control slicing precisely
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    n = len(token_ids)
    if n <= max_tokens:
        return [text.strip()]
    chunks = []
    start = 0
    while start < n:
        end = min(start + max_tokens, n)
        chunk_text = tokenizer.decode(
            token_ids[start:end],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        chunks.append(chunk_text.strip())
        if end == n:
            break
        start = end - stride # step back to create the overlap window
    return chunks
# -----------------------------
# Summarization logic
# -----------------------------
def summarize_chunks(chunks: List[str], preset: str, progress: Optional[gr.Progress] = None) -> Tuple[List[str], str]:
    """
    Map/reduce summarization over token-bounded chunks.

    Map step: summarize each chunk independently. Reduce step (only when
    SECOND_PASS is enabled and more than one chunk summary exists):
    summarize the concatenation of the chunk summaries into one final
    summary, re-chunking first if the concatenation exceeds the model's
    input limit.

    Args:
        chunks: text chunks, each already within the model token limit.
        preset: key into SUMMARY_PRESETS; unknown values fall back to "medium".
        progress: optional Gradio progress callback.

    Returns:
        (list_of_chunk_summaries, final_summary)
    """
    settings = SUMMARY_PRESETS.get(preset, SUMMARY_PRESETS["medium"])
    max_len = settings["max_length"]
    min_len = settings["min_length"]

    def _summarize(piece: str) -> str:
        # single pipeline call; truncation guards against slight token overshoot
        result = summarizer(piece, max_length=max_len, min_length=min_len, do_sample=False, truncation=True)
        return result[0]["summary_text"].strip()

    chunk_summaries = []
    total = len(chunks)
    for idx, chunk in enumerate(chunks, start=1):
        try:
            chunk_summaries.append(_summarize(chunk))
        except Exception as e:
            # keep going: a failed chunk becomes an inline error marker
            chunk_summaries.append(f"[Chunk summarization error: {e}]")
        if progress:
            # map step owns the first 70% of the progress bar
            progress((idx / total) * 0.7, desc=f"Summarizing chunk {idx}/{total}...")

    if SECOND_PASS and len(chunk_summaries) > 1:
        # Reduce step: compress the joined chunk summaries into one summary.
        joined = "\n\n".join(chunk_summaries)
        joined_chunks = chunk_text_by_tokens(joined, max_tokens=tokenizer.model_max_length, stride=CHUNK_STRIDE)
        try:
            if len(joined_chunks) == 1:
                final_summary = _summarize(joined_chunks[0])
            else:
                # still too long: summarize each piece, then compress once more
                intermediate = [_summarize(jc) for jc in joined_chunks]
                final_summary = _summarize("\n\n".join(intermediate))
        except Exception as e:
            final_summary = f"[Final summarization error: {e}]"
    elif chunk_summaries:
        # second pass disabled (or single chunk): join or pass through
        final_summary = "\n\n".join(chunk_summaries) if len(chunk_summaries) > 1 else chunk_summaries[0]
    else:
        final_summary = ""

    if progress:
        progress(1.0, desc="Done")
    return chunk_summaries, final_summary
# -----------------------------
# Gradio processing function
# -----------------------------
def _read_upload_bytes(uploaded_file) -> Tuple[bytes, str]:
    """Return (file_bytes, lowercase_name) for a Gradio upload value.

    Gradio delivers uploads either as a filepath string (type="filepath",
    the default in Gradio 4.x) or as a tempfile-like object exposing `.name`
    (Gradio 3.x) — calling `.read()` directly on the value is not reliable
    across versions, so we always open the backing path ourselves.
    """
    path = uploaded_file if isinstance(uploaded_file, str) else getattr(uploaded_file, "name", None)
    if path is not None:
        with open(path, "rb") as fh:
            return fh.read(), path.lower()
    # last resort: object only exposes read() (no name to sniff an extension from)
    return uploaded_file.read(), ""


def process(text_input: str, uploaded_file, preset: str, show_intermediate: bool, progress=gr.Progress()):
    """
    Gradio handler: extract text from the pasted input and/or uploaded file,
    chunk it by tokens, and run the hierarchical summarizer.

    Returns:
        (final_summary, intermediate_markdown, stats) — intermediate markdown
        is empty unless `show_intermediate` is checked. Errors are returned
        as marker strings in the first slot rather than raised.
    """
    progress(0.0, desc="Extracting text...")
    # Extract text
    extracted = ""
    if uploaded_file is not None:
        try:
            file_bytes, fname = _read_upload_bytes(uploaded_file)
            if fname.endswith(".pdf"):
                extracted = read_pdf_bytes(file_bytes)
            elif fname.endswith(".docx"):
                extracted = read_docx_bytes(file_bytes)
            else:
                # fallback: try to decode as text
                try:
                    extracted = file_bytes.decode("utf-8", errors="replace")
                except Exception:
                    extracted = "[Unsupported file type]"
        except Exception as e:
            return f"[File read error: {e}]", "", ""
    # combine pasted text with file text (file first)
    if text_input and text_input.strip():
        combined = (extracted + "\n\n" + text_input.strip()).strip()
    else:
        combined = extracted.strip()
    if not combined:
        return "No text found. Paste text or upload a PDF/DOCX file.", "", ""
    # chunk text by tokens
    progress(0.05, desc="Splitting into chunks...")
    max_tokens = tokenizer.model_max_length # model input limit
    chunks = chunk_text_by_tokens(combined, max_tokens=max_tokens, stride=CHUNK_STRIDE)
    # safety: if still empty
    if not chunks:
        return "No text extracted from the file or input.", "", ""
    # Summarize chunks (progress updates included)
    chunk_summaries, final_summary = summarize_chunks(chunks, preset, progress=progress)
    # Prepare intermediate summary output
    intermediate_md = "\n".join(
        f"### Chunk {i} Summary\n\n{s}\n" for i, s in enumerate(chunk_summaries, start=1)
    )
    stats = f"Input tokens (approx): {sum(len(tokenizer.encode(c, add_special_tokens=False)) for c in chunks)} | Chunks: {len(chunks)}"
    if show_intermediate:
        return final_summary, intermediate_md, stats
    return final_summary, "", stats
# -----------------------------
# Gradio UI
# -----------------------------
# Single-function Interface: the four inputs map positionally onto
# process(...), the three outputs onto its (final, intermediate, stats) tuple.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(lines=12, placeholder="Paste text here (optional)...", label="Paste text (optional)"),
        gr.File(label="Upload PDF or DOCX (optional)"),
        gr.Radio(choices=["short", "medium", "long"], value="medium", label="Summary length (preset)"),
        gr.Checkbox(value=False, label="Show intermediate chunk summaries")
    ],
    outputs=[
        gr.Textbox(label="Final Summary"),
        gr.Markdown(label="Intermediate Chunk Summaries (if enabled)"),
        gr.Textbox(label="Stats")
    ],
    title="Hierarchical Long-Text Summarizer (token-aware, free-tier)",
    description=(
        "Paste text or upload a PDF/DOCX. The system splits long input by tokens, summarizes each chunk,"
        " then optionally performs a 2nd-pass summarization to produce a concise final summary."
    ),
    # NOTE(review): allow_flagging was renamed to flagging_mode in Gradio 4.x —
    # confirm against the pinned Gradio version before changing this keyword.
    allow_flagging="never",
    examples=[],
)
if __name__ == "__main__":
    # on Spaces this will be ignored and Gradio will serve automatically
    demo.launch()