File size: 7,245 Bytes
4012cea
60b19a4
4012cea
60b19a4
 
 
 
4012cea
60b19a4
4012cea
 
60b19a4
 
 
 
4012cea
60b19a4
 
 
4012cea
60b19a4
 
 
 
 
 
 
 
4012cea
 
 
60b19a4
 
4012cea
 
60b19a4
4012cea
 
60b19a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4012cea
 
60b19a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4012cea
 
 
60b19a4
 
4012cea
60b19a4
 
 
 
 
 
 
 
4012cea
60b19a4
 
 
 
 
 
 
4012cea
60b19a4
 
4012cea
60b19a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4012cea
 
60b19a4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""
Simple Gradio app for text summarization.

- Supports two pretrained summarization models from the Hugging Face Hub.
- Performs chunking for long inputs and then optionally performs a final
  summary pass over chunk summaries to improve coherence.
- Suitable for local use or deployment on Hugging Face Spaces.

Save as: app.py
"""

from functools import lru_cache
from typing import List
import math
import time

import gradio as gr
from transformers import pipeline, Pipeline
import torch

# -------------------------
# Utilities
# -------------------------
def has_gpu():
    """Return True when a CUDA device is usable, False otherwise.

    The check is wrapped in try/except so a broken or CPU-only torch
    installation can never crash the app at startup.
    """
    try:
        cuda_ok = torch.cuda.is_available()
    except Exception:
        cuda_ok = False
    return cuda_ok

def chunk_text(text: str, max_chars: int = 1000, overlap: int = 200) -> List[str]:
    """
    Split *text* into chunks of roughly *max_chars* characters, carrying
    *overlap* characters of context into the next chunk.

    Splits at whitespace boundaries when possible for cleaner chunks.

    Parameters:
        text: input string to split.
        max_chars: approximate maximum chunk length, in characters.
        overlap: characters re-included at the start of the following chunk.

    Returns:
        A list of stripped chunk strings covering the input in order.
    """
    if len(text) <= max_chars:
        return [text.strip()]

    chunks = []
    start = 0
    n = len(text)
    while start < n:
        end = start + max_chars
        if end >= n:
            # Last (possibly short) chunk.
            chunk = text[start:n].strip()
            if chunk:
                chunks.append(chunk)
            break

        # Try to back up to the nearest space for a cleaner boundary.
        backup = text.rfind(" ", start, end)
        if backup <= start:
            backup = end  # no space found, hard cut

        chunk = text[start:backup].strip()
        if chunk:
            chunks.append(chunk)

        # Move start forward with overlap.
        # BUGFIX: the original `start = backup - overlap` could move `start`
        # backwards (or leave it unchanged) whenever the chosen boundary was
        # within `overlap` characters of `start`, causing an infinite loop
        # that appended chunks forever. Guarantee strict forward progress;
        # when the boundary is too close, skip the overlap for this step.
        next_start = backup - overlap
        if next_start <= start:
            next_start = backup
        start = next_start

    return chunks

# -------------------------
# Model loading (cached)
# -------------------------
@lru_cache(maxsize=4)
def get_summarizer(model_name: str) -> Pipeline:
    """Build and memoize a Hugging Face summarization pipeline.

    The pipeline is placed on GPU (device 0) when CUDA is available and on
    CPU (device -1) otherwise. Up to four distinct models stay cached.
    """
    target_device = 0 if has_gpu() else -1
    return pipeline("summarization", model=model_name, device=target_device)

# -------------------------
# Summarization logic
# -------------------------
def summarize_text(
    text: str,
    model_name: str = "facebook/bart-large-cnn",
    min_length: int = 30,
    max_length: int = 200,
    chunk_max_chars: int = 1000,
    do_final_pass: bool = True,
):
    """Summarize *text* with a Hugging Face summarization pipeline.

    Long inputs are split into overlapping chunks; each chunk is summarized
    individually, and optionally the chunk summaries are summarized once
    more for coherence.

    Parameters:
        text: the input to summarize; empty/whitespace input short-circuits.
        model_name: Hugging Face model id for the summarization pipeline.
        min_length / max_length: requested summary length bounds (tokens).
        chunk_max_chars: character budget per chunk for long inputs.
        do_final_pass: when True and there are multiple chunks, run one more
            summarization pass over the joined chunk summaries.

    Returns:
        The summary string, or an explanatory message for empty input.
        Per-chunk failures are embedded as bracketed error markers rather
        than raised, so one bad chunk does not lose the rest.
    """
    if not text or not text.strip():
        return "No input text provided."

    summarizer = get_summarizer(model_name)

    # Chunk text if long (overlap keeps context across boundaries).
    chunks = chunk_text(text, max_chars=chunk_max_chars, overlap=200)
    chunk_summaries = []

    for i, ch in enumerate(chunks, start=1):
        # Scale requested lengths by the chunk's relative size so short
        # chunks get proportionally short summaries.
        proportion = min(1.0, len(ch) / chunk_max_chars)
        min_l = max(5, int(min_length * proportion))
        max_l = max(20, int(max_length * proportion))
        # BUGFIX: the UI allows min_length > max_length (min slider reaches
        # 200 while max slider starts at 20); the scaled bounds could then
        # invert and make generation warn or fail. Clamp min below max.
        if min_l >= max_l:
            min_l = max(1, max_l - 1)

        try:
            res = summarizer(
                ch,
                min_length=min_l,
                max_length=max_l,
                truncation=True,
            )
            summary_text = res[0]["summary_text"].strip()
        except Exception as e:
            # Best-effort: keep going, surface the failure inline.
            summary_text = f"[Error summarizing chunk {i}: {str(e)}]"
        chunk_summaries.append(summary_text)

    # If multiple chunk summaries, optionally summarize them together again.
    if do_final_pass and len(chunk_summaries) > 1:
        joined = " ".join(chunk_summaries)
        # Same clamp for the final pass: keep the lower bound below max_length.
        final_min = min(max(20, min_length // 2), max(1, max_length - 1))
        try:
            final = summarizer(
                joined,
                min_length=final_min,
                max_length=max_length,
                truncation=True,
            )
            final_summary = final[0]["summary_text"].strip()
        except Exception as e:
            # Fall back to the raw chunk summaries and report the failure.
            final_summary = " ".join(chunk_summaries)
            final_summary += f"\n\n[Final-pass error: {str(e)}]"
        return final_summary
    else:
        # Single chunk or final pass disabled: return joined chunk summaries.
        return "\n\n".join(chunk_summaries)


# -------------------------
# Gradio UI
# -------------------------
# (label shown in the dropdown, Hugging Face model id) pairs; the label is
# mapped back to the model id in _wrap_and_run before summarization.
model_choices = [
    ("Facebook BART (cnn) — good general summarizer", "facebook/bart-large-cnn"),
    ("DistilBART (faster) — light-weight", "sshleifer/distilbart-cnn-12-6"),
]

# Sample inputs for the Gradio Examples panel; the first one also seeds the
# input textbox's default value.
examples = [
    ["In 2023, the world of AI advanced rapidly. Companies released larger and more capable language models, while researchers focused on safety, alignment, and practical applications. Governments started to craft regulations for responsible deployment. Meanwhile, startups found new ways to apply summarization, code generation, and retrieval-augmented systems. The long-term effects of these developments remain to be seen, but short-term productivity gains were highly visible across many industries."],
    ["Machine learning models require careful tuning of hyperparameters. Learning rate, batch size, and optimizer choice can dramatically affect convergence and final performance. Regularization techniques such as dropout, weight decay, and data augmentation improve generalization. Practitioners routinely combine validation curves and cross-validation to find the best configuration."],
]

# Build the Gradio UI: input/controls on the left, summary output on the right.
with gr.Blocks(title="Text Summarizer (Hugging Face)") as demo:
    gr.Markdown("# 🧾 Text Summarizer\nSimple Gradio app using Hugging Face summarization pipelines.\n\nEnter text on the left and press **Summarize**.")
    with gr.Row():
        with gr.Column(scale=2):
            # Input column: text, model picker, and generation parameters.
            inp = gr.Textbox(lines=12, label="Input Text", placeholder="Paste article, long text, or notes here...", value=examples[0][0])
            model = gr.Dropdown([m[0] for m in model_choices], label="Model", value=model_choices[0][0])
            min_len = gr.Slider(5, 200, value=30, step=1, label="Min summary length (tokens / words approx.)")
            max_len = gr.Slider(20, 600, value=150, step=1, label="Max summary length (tokens / words approx.)")
            chunk_size = gr.Slider(500, 4000, value=1000, step=100, label="Chunk size (characters) — for long texts")
            final_pass = gr.Checkbox(value=True, label="Do final-pass summarization (recommended for long inputs)")
            btn = gr.Button("Summarize")
        with gr.Column(scale=1):
            # Output column: summary text plus clickable examples.
            out = gr.Textbox(lines=12, label="Summary")
            gr.Markdown("### Examples")
            ex = gr.Examples(examples=examples, inputs=inp, examples_per_page=6)

    def _wrap_and_run(text, selected_model_label, min_length, max_length, chunk_size, do_final):
        """Adapt UI widget values to summarize_text and append a timing footer.

        The dropdown yields the human-readable label, so it is mapped back to
        the Hugging Face model id here (falling back to the first choice if
        the label is somehow unknown).
        """
        # map label to model name
        model_map = {m[0]: m[1] for m in model_choices}
        model_name = model_map.get(selected_model_label, model_choices[0][1])
        start = time.time()
        result = summarize_text(
            text=text,
            model_name=model_name,
            min_length=min_length,
            max_length=max_length,
            chunk_max_chars=chunk_size,
            do_final_pass=do_final,
        )
        took = time.time() - start
        # Footer tells the user which model ran and how long it took.
        footer = f"\n\n---\nModel: {model_name} — Time: {took:.1f}s"
        return result + footer

    btn.click(
        _wrap_and_run,
        inputs=[inp, model, min_len, max_len, chunk_size, final_pass],
        outputs=[out],
    )

if __name__ == "__main__":
    # Launch locally on port 7860; 0.0.0.0 makes the app reachable from other
    # hosts (required for Docker / Hugging Face Spaces deployments).
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False)