"""
Simple Gradio app for text summarization.
- Supports two pretrained summarization models from the Hugging Face Hub.
- Performs chunking for long inputs and then optionally performs a final
summary pass over chunk summaries to improve coherence.
- Suitable for local use or deployment on Hugging Face Spaces.
Save as: app.py
"""
from functools import lru_cache
from typing import List
import math
import time
import gradio as gr
from transformers import pipeline, Pipeline
import torch
# -------------------------
# Utilities
# -------------------------
def has_gpu():
    """Return True when a CUDA-capable GPU is visible to torch, else False."""
    try:
        available = torch.cuda.is_available()
    except Exception:
        # torch can raise on broken/partial CUDA installs; treat as "no GPU".
        return False
    return available
def chunk_text(text: str, max_chars: int = 1000, overlap: int = 200) -> List[str]:
    """
    Split *text* into chunks of roughly ``max_chars`` characters with
    ``overlap`` characters shared between consecutive chunks.

    Splits at whitespace boundaries where possible for cleaner chunks.

    Args:
        text: Input text to split.
        max_chars: Approximate maximum chunk size, in characters.
        overlap: Characters of context carried over between chunks.

    Returns:
        A list of non-empty, stripped chunk strings. Empty or
        whitespace-only input yields an empty list.
    """
    stripped = text.strip()
    if not stripped:
        # Bug fix: previously returned [""] for empty/whitespace input.
        return []
    if len(text) <= max_chars:
        return [stripped]
    chunks: List[str] = []
    start = 0
    n = len(text)
    while start < n:
        end = start + max_chars
        if end >= n:
            chunk = text[start:n].strip()
            if chunk:
                chunks.append(chunk)
            break
        # try to back up to nearest space for cleaner boundary
        backup = text.rfind(" ", start, end)
        if backup <= start:
            backup = end  # no space found, hard cut
        chunk = text[start:backup].strip()
        if chunk:
            chunks.append(chunk)
        # Bug fix: move start forward with overlap, but ALWAYS make progress.
        # The original `start = backup - overlap` could move start at/behind
        # its current value (when the only space in the window sits near the
        # window's beginning), causing an infinite loop. When overlap would
        # not advance us, skip the overlap for this step instead.
        next_start = backup - overlap
        if next_start <= start:
            next_start = backup
        start = next_start
    return chunks
# -------------------------
# Model loading (cached)
# -------------------------
@lru_cache(maxsize=4)
def get_summarizer(model_name: str) -> Pipeline:
    """Build (and memoize) a summarization pipeline for *model_name*.

    The pipeline runs on GPU 0 when CUDA is available, otherwise on CPU.
    Up to four distinct models are kept alive by the LRU cache.
    """
    target_device = 0 if has_gpu() else -1
    return pipeline("summarization", model=model_name, device=target_device)
# -------------------------
# Summarization logic
# -------------------------
def summarize_text(
    text: str,
    model_name: str = "facebook/bart-large-cnn",
    min_length: int = 30,
    max_length: int = 200,
    chunk_max_chars: int = 1000,
    do_final_pass: bool = True,
):
    """
    Summarize *text* with the given Hugging Face summarization model.

    Long inputs are split into overlapping chunks (see chunk_text); each
    chunk is summarized individually, and — when ``do_final_pass`` is set
    and more than one chunk was produced — the chunk summaries are
    summarized once more for coherence.

    Args:
        text: Input text; empty/whitespace input yields a fixed message.
        model_name: Model id passed to the summarization pipeline.
        min_length: Requested minimum summary length (model tokens, approx.).
        max_length: Requested maximum summary length (model tokens, approx.).
        chunk_max_chars: Character budget per chunk for long inputs.
        do_final_pass: Re-summarize joined chunk summaries when True.

    Returns:
        The final summary string, or the chunk summaries joined with
        blank lines when no final pass is performed.
    """
    if not text or not text.strip():
        return "No input text provided."
    summarizer = get_summarizer(model_name)
    # Chunk text if long
    chunks = chunk_text(text, max_chars=chunk_max_chars, overlap=200)
    chunk_summaries = []
    for i, ch in enumerate(chunks, start=1):
        # Each chunk summarized individually; scale the requested lengths
        # down in proportion to the chunk size so short chunks are not
        # asked for disproportionately long summaries.
        proportion = min(1.0, len(ch) / chunk_max_chars)
        min_l = max(5, int(min_length * proportion))
        max_l = max(20, int(max_length * proportion))
        # Bug fix: the UI sliders allow min_length > max_length, which made
        # min_l >= max_l — an invalid generation request. Clamp min below max.
        if min_l >= max_l:
            min_l = max(1, max_l - 1)
        try:
            res = summarizer(
                ch,
                min_length=min_l,
                max_length=max_l,
                truncation=True,
            )
            summary_text = res[0]["summary_text"].strip()
        except Exception as e:
            # Best-effort: keep going on a per-chunk failure and surface
            # the error inline in the output.
            summary_text = f"[Error summarizing chunk {i}: {str(e)}]"
        chunk_summaries.append(summary_text)
    # If multiple chunk summaries, optionally summarize them together again
    if do_final_pass and len(chunk_summaries) > 1:
        joined = " ".join(chunk_summaries)
        # Bug fix: same clamp for the final pass — max(20, min_length // 2)
        # could exceed max_length (the max slider goes down to 20).
        final_min = min(max(20, min_length // 2), max(1, max_length - 1))
        try:
            final = summarizer(
                joined,
                min_length=final_min,
                max_length=max_length,
                truncation=True,
            )
            final_summary = final[0]["summary_text"].strip()
        except Exception as e:
            # Fall back to the un-merged chunk summaries and note the error.
            final_summary = " ".join(chunk_summaries)
            final_summary += f"\n\n[Final-pass error: {str(e)}]"
        return final_summary
    else:
        # if single chunk or not doing final pass, return joined chunk summaries
        return "\n\n".join(chunk_summaries)
# -------------------------
# Gradio UI
# -------------------------
# (display label, Hugging Face model id) pairs for the model dropdown; the
# label shown to the user is mapped back to the model id in _wrap_and_run.
model_choices = [
    ("Facebook BART (cnn) — good general summarizer", "facebook/bart-large-cnn"),
    ("DistilBART (faster) — light-weight", "sshleifer/distilbart-cnn-12-6"),
]
# Example inputs for the gr.Examples widget. Each entry is a one-element list
# because the examples feed a single input component (the text box).
examples = [
    ["In 2023, the world of AI advanced rapidly. Companies released larger and more capable language models, while researchers focused on safety, alignment, and practical applications. Governments started to craft regulations for responsible deployment. Meanwhile, startups found new ways to apply summarization, code generation, and retrieval-augmented systems. The long-term effects of these developments remain to be seen, but short-term productivity gains were highly visible across many industries."],
    ["Machine learning models require careful tuning of hyperparameters. Learning rate, batch size, and optimizer choice can dramatically affect convergence and final performance. Regularization techniques such as dropout, weight decay, and data augmentation improve generalization. Practitioners routinely combine validation curves and cross-validation to find the best configuration."],
]
# Build the Gradio UI: a two-column layout with input + controls on the left
# and the summary output on the right, wired together by the button click.
with gr.Blocks(title="Text Summarizer (Hugging Face)") as demo:
    gr.Markdown("# 🧾 Text Summarizer\nSimple Gradio app using Hugging Face summarization pipelines.\n\nEnter text on the left and press **Summarize**.")
    with gr.Row():
        # Left column: input text and all tuning controls.
        with gr.Column(scale=2):
            inp = gr.Textbox(lines=12, label="Input Text", placeholder="Paste article, long text, or notes here...", value=examples[0][0])
            # Dropdown holds human-readable labels; _wrap_and_run maps the
            # selected label back to the model id via model_choices.
            model = gr.Dropdown([m[0] for m in model_choices], label="Model", value=model_choices[0][0])
            min_len = gr.Slider(5, 200, value=30, step=1, label="Min summary length (tokens / words approx.)")
            max_len = gr.Slider(20, 600, value=150, step=1, label="Max summary length (tokens / words approx.)")
            chunk_size = gr.Slider(500, 4000, value=1000, step=100, label="Chunk size (characters) — for long texts")
            final_pass = gr.Checkbox(value=True, label="Do final-pass summarization (recommended for long inputs)")
            btn = gr.Button("Summarize")
        # Right column: summary output plus clickable example inputs.
        with gr.Column(scale=1):
            out = gr.Textbox(lines=12, label="Summary")
            gr.Markdown("### Examples")
            ex = gr.Examples(examples=examples, inputs=inp, examples_per_page=6)

    def _wrap_and_run(text, selected_model_label, min_length, max_length, chunk_size, do_final):
        """Click handler: resolve the dropdown label to a model id, run
        summarize_text, and append a model/timing footer to the result."""
        # map label to model name
        model_map = {m[0]: m[1] for m in model_choices}
        model_name = model_map.get(selected_model_label, model_choices[0][1])
        start = time.time()
        result = summarize_text(
            text=text,
            model_name=model_name,
            min_length=min_length,
            max_length=max_length,
            chunk_max_chars=chunk_size,
            do_final_pass=do_final,
        )
        took = time.time() - start
        footer = f"\n\n---\nModel: {model_name} — Time: {took:.1f}s"
        return result + footer

    # Wire the button to the handler: six inputs in, one text box out.
    btn.click(
        _wrap_and_run,
        inputs=[inp, model, min_len, max_len, chunk_size, final_pass],
        outputs=[out],
    )
if __name__ == "__main__":
    # Launch locally on port 7860, binding to all interfaces (the default
    # expected by Hugging Face Spaces). Fix: removed a stray trailing "|"
    # (web-scrape residue) that made this line a syntax error.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)