Text-Summarizer / app.py
mbalvi's picture
Update app.py
60b19a4 verified
"""
Simple Gradio app for text summarization.
- Supports two pretrained summarization models from the Hugging Face Hub.
- Performs chunking for long inputs and then optionally performs a final
summary pass over chunk summaries to improve coherence.
- Suitable for local use or deployment on Hugging Face Spaces.
Save as: app.py
"""
from functools import lru_cache
from typing import List
import math
import time
import gradio as gr
from transformers import pipeline, Pipeline
import torch
# -------------------------
# Utilities
# -------------------------
def has_gpu():
    """Report whether PyTorch sees a usable CUDA device.

    Any exception raised while probing CUDA is treated as "no GPU", so this
    never raises.
    """
    try:
        available = torch.cuda.is_available()
    except Exception:
        available = False
    return available
def chunk_text(text: str, max_chars: int = 1000, overlap: int = 200) -> List[str]:
    """
    Split text into chunks of at most ``max_chars`` characters.

    Each chunk ends at the last space inside its window when one exists after
    the window start; otherwise it is hard-cut at ``max_chars``. The next chunk
    starts ``overlap`` characters before the previous cut, so neighbouring
    chunks share some context. Chunks are whitespace-stripped and empty chunks
    are dropped.

    Text no longer than ``max_chars`` is returned as a single stripped chunk.

    Raises:
        ValueError: if ``max_chars`` < 1 or ``overlap`` < 0 (both previously
            led to an endless loop or skipped text).
    """
    if max_chars < 1:
        raise ValueError(f"max_chars must be >= 1, got {max_chars}")
    if overlap < 0:
        raise ValueError(f"overlap must be >= 0, got {overlap}")
    if len(text) <= max_chars:
        return [text.strip()]

    chunks: List[str] = []
    start = 0
    n = len(text)
    while start < n:
        end = start + max_chars
        if end >= n:
            # Final window reaches the end of the text.
            chunk = text[start:].strip()
            if chunk:
                chunks.append(chunk)
            break
        # Prefer cutting at the last space in the window for cleaner boundaries.
        cut = text.rfind(" ", start, end)
        if cut <= start:
            cut = end  # no usable space: hard cut
        chunk = text[start:cut].strip()
        if chunk:
            chunks.append(chunk)
        # Step back by `overlap` for shared context, but always make progress:
        # if the cut lies within `overlap` chars of `start` (or overlap >=
        # max_chars), stepping back would revisit the same window forever.
        next_start = cut - overlap
        start = next_start if next_start > start else cut
    return chunks
# -------------------------
# Model loading (cached)
# -------------------------
@lru_cache(maxsize=4)
def get_summarizer(model_name: str) -> Pipeline:
    """Build (and memoize per model name) a "summarization" pipeline.

    The pipeline is placed on GPU 0 when ``has_gpu()`` is true, otherwise on
    the CPU (device -1). Up to four distinct models are kept cached.
    """
    return pipeline(
        "summarization",
        model=model_name,
        device=0 if has_gpu() else -1,
    )
# -------------------------
# Summarization logic
# -------------------------
def summarize_text(
    text: str,
    model_name: str = "facebook/bart-large-cnn",
    min_length: int = 30,
    max_length: int = 200,
    chunk_max_chars: int = 1000,
    do_final_pass: bool = True,
    chunk_overlap: int = 200,
):
    """
    Summarize ``text`` with a cached Hugging Face summarization pipeline.

    The input is split with ``chunk_text``. Each chunk is summarized on its
    own, with length limits scaled by how full the chunk is. When there are
    several successful chunk summaries and ``do_final_pass`` is true, they
    are concatenated and summarized once more.

    Args:
        text: Text to summarize.
        model_name: Model id handed to ``get_summarizer``.
        min_length: Minimum summary length passed to the pipeline.
        max_length: Maximum summary length passed to the pipeline.
        chunk_max_chars: Maximum characters per chunk.
        do_final_pass: Whether to merge multiple chunk summaries with a
            second summarization pass.
        chunk_overlap: Characters shared between consecutive chunks.

    Returns:
        The summary string, or "No input text provided." for empty input.
        Failures are not raised: a failed chunk becomes an
        "[Error summarizing chunk i: ...]" marker, and a failed final pass
        falls back to the chunk summaries plus a "[Final-pass error: ...]" note.
    """
    if not text or not text.strip():
        return "No input text provided."

    # UI widgets may hand over floats; string slicing and generation need ints.
    min_length = int(min_length)
    max_length = int(max_length)
    chunk_max_chars = int(chunk_max_chars)

    summarizer = get_summarizer(model_name)
    chunks = chunk_text(text, max_chars=chunk_max_chars, overlap=chunk_overlap)

    chunk_summaries = []  # one entry per chunk, in order: summary or error marker
    successful = []       # real summaries only; the final pass must not see error text
    errors = []
    for i, ch in enumerate(chunks, start=1):
        # Scale the requested lengths to the chunk size so shorter chunks
        # are not asked for long summaries.
        proportion = min(1.0, len(ch) / chunk_max_chars)
        max_l = max(20, int(max_length * proportion))
        # A minimum above the maximum is contradictory (e.g. the min slider
        # set above the max slider); cap it at the maximum.
        min_l = min(max(5, int(min_length * proportion)), max_l)
        try:
            res = summarizer(
                ch,
                min_length=min_l,
                max_length=max_l,
                truncation=True,
            )
            summary_text = res[0]["summary_text"].strip()
        except Exception as e:  # best-effort: report and continue with other chunks
            summary_text = f"[Error summarizing chunk {i}: {str(e)}]"
            errors.append(summary_text)
        else:
            successful.append(summary_text)
        chunk_summaries.append(summary_text)

    # Nothing to merge: return the per-chunk results as separate paragraphs.
    if not (do_final_pass and len(successful) > 1):
        return "\n\n".join(chunk_summaries)

    final_max = max_length
    final_min = min(max(20, min_length // 2), final_max)
    try:
        final = summarizer(
            " ".join(successful),
            min_length=final_min,
            max_length=final_max,
            truncation=True,
        )
        final_summary = final[0]["summary_text"].strip()
    except Exception as e:
        # Fall back to the individual chunk results (error markers included).
        return "\n\n".join(chunk_summaries) + f"\n\n[Final-pass error: {str(e)}]"
    if errors:
        # Report chunk failures after the merged summary instead of letting
        # the model paraphrase the error text.
        final_summary += "\n\n" + "\n".join(errors)
    return final_summary
# -------------------------
# Gradio UI
# -------------------------
# Model picker entries as (dropdown label, Hugging Face model id) pairs.
# The first entry is the dropdown's initial value and the fallback model.
model_choices = [
    (
        "Facebook BART (cnn) — good general summarizer",
        "facebook/bart-large-cnn",
    ),
    (
        "DistilBART (faster) — light-weight",
        "sshleifer/distilbart-cnn-12-6",
    ),
]
# Sample inputs for gr.Examples: each row holds one value for the input
# textbox. The first sample also pre-fills that textbox.
examples = [
    [
        "In 2023, the world of AI advanced rapidly. "
        "Companies released larger and more capable language models, while "
        "researchers focused on safety, alignment, and practical applications. "
        "Governments started to craft regulations for responsible deployment. "
        "Meanwhile, startups found new ways to apply summarization, code "
        "generation, and retrieval-augmented systems. "
        "The long-term effects of these developments remain to be seen, but "
        "short-term productivity gains were highly visible across many industries."
    ],
    [
        "Machine learning models require careful tuning of hyperparameters. "
        "Learning rate, batch size, and optimizer choice can dramatically "
        "affect convergence and final performance. "
        "Regularization techniques such as dropout, weight decay, and data "
        "augmentation improve generalization. "
        "Practitioners routinely combine validation curves and cross-validation "
        "to find the best configuration."
    ],
]
with gr.Blocks(title="Text Summarizer (Hugging Face)") as demo:
    gr.Markdown("# 🧾 Text Summarizer\nSimple Gradio app using Hugging Face summarization pipelines.\n\nEnter text on the left and press **Summarize**.")
    with gr.Row():
        # Left column: input text and generation settings.
        with gr.Column(scale=2):
            inp = gr.Textbox(lines=12, label="Input Text", placeholder="Paste article, long text, or notes here...", value=examples[0][0])
            model = gr.Dropdown([m[0] for m in model_choices], label="Model", value=model_choices[0][0])
            min_len = gr.Slider(5, 200, value=30, step=1, label="Min summary length (tokens / words approx.)")
            max_len = gr.Slider(20, 600, value=150, step=1, label="Max summary length (tokens / words approx.)")
            chunk_size = gr.Slider(500, 4000, value=1000, step=100, label="Chunk size (characters) — for long texts")
            final_pass = gr.Checkbox(value=True, label="Do final-pass summarization (recommended for long inputs)")
            btn = gr.Button("Summarize")
        # Right column: the result.
        with gr.Column(scale=1):
            out = gr.Textbox(lines=12, label="Summary")
    gr.Markdown("### Examples")
    ex = gr.Examples(examples=examples, inputs=inp, examples_per_page=6)

    # Dropdown label -> Hugging Face model id, built once rather than per click.
    _model_ids_by_label = dict(model_choices)

    def _wrap_and_run(text, selected_model_label, min_length, max_length, chunk_chars, do_final):
        """Click handler: resolve the model id, summarize, and append a timing footer.

        ``chunk_chars`` is the chunk-size slider value; it is named differently
        from the ``chunk_size`` component so the component is not shadowed.
        Unknown labels fall back to the first model in ``model_choices``.
        Exceptions are re-raised as ``gr.Error`` so the UI shows a readable message.
        """
        model_name = _model_ids_by_label.get(selected_model_label, model_choices[0][1])
        started = time.perf_counter()  # monotonic clock for measuring durations
        try:
            result = summarize_text(
                text=text,
                model_name=model_name,
                # Slider values may arrive as floats; the summarizer needs ints.
                min_length=int(min_length),
                max_length=int(max_length),
                chunk_max_chars=int(chunk_chars),
                do_final_pass=do_final,
            )
        except Exception as e:
            raise gr.Error(f"Summarization failed ({model_name}): {e}") from e
        took = time.perf_counter() - started
        footer = f"\n\n---\nModel: {model_name} — Time: {took:.1f}s"
        return result + footer

    btn.click(
        _wrap_and_run,
        inputs=[inp, model, min_len, max_len, chunk_size, final_pass],
        outputs=[out],
    )
if __name__ == "__main__":
    # Serve on every network interface (0.0.0.0) at port 7860, without
    # creating a public Gradio share link.
    launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7860, "share": False}
    demo.launch(**launch_kwargs)