mbalvi committed
Commit 60b19a4 · verified · 1 Parent(s): 7fc6c2f

Update app.py

Files changed (1)
  1. app.py +161 -149
app.py CHANGED
@@ -1,174 +1,186 @@
  """
- Simple Hugging Face Text Summarizer (Flask)

- Endpoints:
- - GET / -> basic info
- - POST /summarize -> JSON input: {"text": "...", "model": "...", "max_length": 130, "min_length": 30}

- How to run:
- 1. pip install -r requirements.txt
- 2. python app.py
- 3. POST JSON to http://127.0.0.1:8000/summarize
-
- Notes:
- - Default model: "facebook/bart-large-cnn". You may change to any summarization-capable HF model.
- - For very long texts, the app chunks the text and summarizes each chunk, then summarizes the concatenated chunk-summaries.
  """
- from flask import Flask, request, jsonify
- from transformers import pipeline, Pipeline
- from typing import List, Optional
- import threading
- import math
- import textwrap
- import os

- app = Flask(__name__)
-
- # Default model (good general-purpose summarizer)
- DEFAULT_MODEL = os.getenv("SUMMARIZER_MODEL", "facebook/bart-large-cnn")

- # Global pipeline cache to avoid reloading between requests
- _PIPELINES = {}
- _PIPELINES_LOCK = threading.Lock()

- def get_summarizer(model_name: str = DEFAULT_MODEL) -> Pipeline:
-     """
-     Return a cached summarization pipeline for model_name or create one.
-     """
-     with _PIPELINES_LOCK:
-         if model_name not in _PIPELINES:
-             # Create pipeline (use default device; if you have GPU and torch detects it, it'll use it)
-             _PIPELINES[model_name] = pipeline("summarization", model=model_name)
-         return _PIPELINES[model_name]

  def chunk_text(text: str, max_chars: int = 1000, overlap: int = 200) -> List[str]:
      """
-     Chunk text into pieces of at most max_chars (approx) with specified overlap.
-     This is a simple, robust chunker using whitespace boundaries.
      """
      if len(text) <= max_chars:
-         return [text]

-     words = text.split()
      chunks = []
-     current = []
-     current_len = 0
-     i = 0
-     while i < len(words):
-         w = words[i]
-         # +1 for a space when joined
-         if current_len + len(w) + (1 if current_len > 0 else 0) <= max_chars:
-             current.append(w)
-             current_len += len(w) + (1 if current_len > 0 else 0)
-             i += 1
-         else:
-             chunks.append(" ".join(current))
-             # move pointer back by `overlap` words for overlapping context
-             # calculate how many words correspond to overlap characters approx
-             # (simple heuristic: take last K words)
-             if overlap <= 0:
-                 current = []
-                 current_len = 0
-             else:
-                 # take some words from the end as overlap
-                 overlap_chars = overlap
-                 ov = []
-                 ov_len = 0
-                 while current and ov_len + len(current[-1]) + (1 if ov_len > 0 else 0) <= overlap_chars:
-                     ov.insert(0, current.pop())
-                     ov_len += len(ov[0]) + (1 if ov_len > 0 else 0)
-                 current = ov
-                 current_len = ov_len
-     if current:
-         chunks.append(" ".join(current))
      return chunks

- def summarize_text(text: str, model_name: str = DEFAULT_MODEL,
-                    max_length: int = 130, min_length: int = 30,
-                    chunk_max_chars: int = 1000, chunk_overlap: int = 200) -> str:
-     """
-     Summarize a (possibly long) text by chunking -> summarizing chunks -> summarizing combined.

-     Returns a final summary string.
-     """
      summarizer = get_summarizer(model_name)

-     # Chunk input
-     chunks = chunk_text(text, max_chars=chunk_max_chars, overlap=chunk_overlap)
-
-     # Summarize each chunk
      chunk_summaries = []
-     for idx, chunk in enumerate(chunks):
-         # The pipeline returns a list of dicts with 'summary_text'
          try:
-             out = summarizer(chunk, max_length=max_length, min_length=min_length, truncation=True)
-             s = out[0]["summary_text"].strip()
          except Exception as e:
-             # Fallback simpler call without length constraints
-             out = summarizer(chunk, truncation=True)
-             s = out[0]["summary_text"].strip()
-         chunk_summaries.append(s)
-
-     # If multiple chunk summaries, summarize them again to produce final short summary
-     if len(chunk_summaries) == 1:
-         final = chunk_summaries[0]
-     else:
-         combined = "\n".join(chunk_summaries)
-         # adjust lengths for final summary (shorter)
-         final_out = summarizer(combined, max_length=min(max_length, 180), min_length=25, truncation=True)
-         final = final_out[0]["summary_text"].strip()
-     return final
-
- @app.route("/", methods=["GET"])
- def index():
-     return jsonify({
-         "service": "hf-text-summarizer",
-         "endpoints": {
-             "POST /summarize": {
-                 "json": {
-                     "text": "string (required)",
-                     "model": "optional HF model id (default facebook/bart-large-cnn)",
-                     "max_length": "optional int (summary max tokens, default 130)",
-                     "min_length": "optional int (summary min tokens, default 30)"
-                 }
-             }
-         }
-     })
-
- @app.route("/summarize", methods=["POST"])
- def summarize_route():
-     data = request.get_json(force=True, silent=True)
-     if not data or "text" not in data:
-         return jsonify({"error": "JSON body required with 'text' field"}), 400
-
-     text = data["text"]
-     model = data.get("model", DEFAULT_MODEL)
-     max_length = int(data.get("max_length", 130))
-     min_length = int(data.get("min_length", 30))
-
-     # Basic input validation
-     if not isinstance(text, str) or len(text.strip()) == 0:
-         return jsonify({"error": "text must be a non-empty string"}), 400
-     if max_length <= 0 or min_length < 0:
-         return jsonify({"error": "invalid min_length/max_length"}), 400
-
-     # Safety: prevent extremely huge single requests from crashing the server
-     if len(text) > 500_000:  # ~500k chars
-         return jsonify({"error": "input text too large (limit 500k chars)"}), 413

-     try:
-         summary = summarize_text(text, model_name=model,
-                                  max_length=max_length, min_length=min_length)
-     except Exception as e:
-         # Try to present a helpful message (avoid leaking internals)
-         return jsonify({"error": "failed to summarize text", "detail": str(e)}), 500
-
-     return jsonify({
-         "model": model,
-         "summary": summary,
-         "input_length": len(text),
-     })

  if __name__ == "__main__":
-     # Run Flask on port 8000
-     app.run(host="0.0.0.0", port=8000, debug=False)

  """
+ Simple Gradio app for text summarization.

+ - Supports two pretrained summarization models from the Hugging Face Hub.
+ - Performs chunking for long inputs and then optionally performs a final
+   summary pass over chunk summaries to improve coherence.
+ - Suitable for local use or deployment on Hugging Face Spaces.

+ Save as: app.py
  """

+ from functools import lru_cache
+ from typing import List
+ import math
+ import time

+ import gradio as gr
+ from transformers import pipeline, Pipeline
+ import torch

+ # -------------------------
+ # Utilities
+ # -------------------------
+ def has_gpu():
+     try:
+         return torch.cuda.is_available()
+     except Exception:
+         return False

  def chunk_text(text: str, max_chars: int = 1000, overlap: int = 200) -> List[str]:
      """
+     Split text into chunks of roughly max_chars with given overlap.
+     Splits at whitespace boundaries for nicer chunks.
      """
      if len(text) <= max_chars:
+         return [text.strip()]

      chunks = []
+     start = 0
+     n = len(text)
+     while start < n:
+         end = start + max_chars
+         if end >= n:
+             chunk = text[start:n].strip()
+             if chunk:
+                 chunks.append(chunk)
+             break
+
+         # try to back up to nearest space for cleaner boundary
+         backup = text.rfind(" ", start, end)
+         if backup <= start:
+             backup = end  # no space found, hard cut
+         chunk = text[start:backup].strip()
+         if chunk:
+             chunks.append(chunk)
+         # move start forward with overlap, always advancing so the loop
+         # cannot stall when overlap is larger than the step taken
+         next_start = backup - overlap
+         if next_start <= start:
+             next_start = backup
+         start = next_start
+
  return chunks

+ # -------------------------
+ # Model loading (cached)
+ # -------------------------
+ @lru_cache(maxsize=4)
+ def get_summarizer(model_name: str) -> Pipeline:
+     device = 0 if has_gpu() else -1
+     # Create pipeline
+     summarizer = pipeline("summarization", model=model_name, device=device)
+     return summarizer
+
+ # -------------------------
+ # Summarization logic
+ # -------------------------
+ def summarize_text(
+     text: str,
+     model_name: str = "facebook/bart-large-cnn",
+     min_length: int = 30,
+     max_length: int = 200,
+     chunk_max_chars: int = 1000,
+     do_final_pass: bool = True,
+ ):
+     if not text or not text.strip():
+         return "No input text provided."

      summarizer = get_summarizer(model_name)

+     # Chunk text if long
+     chunks = chunk_text(text, max_chars=chunk_max_chars, overlap=200)
      chunk_summaries = []
+
+     for i, ch in enumerate(chunks, start=1):
+         # Each chunk summarized individually
+         # We pass conservative lengths proportional to chunk size
+         proportion = min(1.0, len(ch) / chunk_max_chars)
+         min_l = max(5, int(min_length * proportion))
+         max_l = max(20, int(max_length * proportion))
+
          try:
+             res = summarizer(
+                 ch,
+                 min_length=min_l,
+                 max_length=max_l,
+                 truncation=True,
+             )
+             summary_text = res[0]["summary_text"].strip()
          except Exception as e:
+             summary_text = f"[Error summarizing chunk {i}: {str(e)}]"
+         chunk_summaries.append(summary_text)

+     # If multiple chunk summaries, optionally summarize them together again
+     if do_final_pass and len(chunk_summaries) > 1:
+         joined = " ".join(chunk_summaries)
+         try:
+             final = summarizer(
+                 joined,
+                 min_length=max(20, min_length // 2),
+                 max_length=max_length,
+                 truncation=True,
+             )
+             final_summary = final[0]["summary_text"].strip()
+         except Exception as e:
+             final_summary = " ".join(chunk_summaries)
+             final_summary += f"\n\n[Final-pass error: {str(e)}]"
+         return final_summary
+     else:
+         # if single chunk or not doing final pass, return joined chunk summaries
+         return "\n\n".join(chunk_summaries)
+
+
+ # -------------------------
+ # Gradio UI
+ # -------------------------
+ model_choices = [
+     ("Facebook BART (cnn) — good general summarizer", "facebook/bart-large-cnn"),
+     ("DistilBART (faster) — light-weight", "sshleifer/distilbart-cnn-12-6"),
+ ]
+
+ examples = [
+     ["In 2023, the world of AI advanced rapidly. Companies released larger and more capable language models, while researchers focused on safety, alignment, and practical applications. Governments started to craft regulations for responsible deployment. Meanwhile, startups found new ways to apply summarization, code generation, and retrieval-augmented systems. The long-term effects of these developments remain to be seen, but short-term productivity gains were highly visible across many industries."],
+     ["Machine learning models require careful tuning of hyperparameters. Learning rate, batch size, and optimizer choice can dramatically affect convergence and final performance. Regularization techniques such as dropout, weight decay, and data augmentation improve generalization. Practitioners routinely combine validation curves and cross-validation to find the best configuration."],
+ ]
+
+ with gr.Blocks(title="Text Summarizer (Hugging Face)") as demo:
+     gr.Markdown("# 🧾 Text Summarizer\nSimple Gradio app using Hugging Face summarization pipelines.\n\nEnter text on the left and press **Summarize**.")
+     with gr.Row():
+         with gr.Column(scale=2):
+             inp = gr.Textbox(lines=12, label="Input Text", placeholder="Paste article, long text, or notes here...", value=examples[0][0])
+             model = gr.Dropdown([m[0] for m in model_choices], label="Model", value=model_choices[0][0])
+             min_len = gr.Slider(5, 200, value=30, step=1, label="Min summary length (tokens / words approx.)")
+             max_len = gr.Slider(20, 600, value=150, step=1, label="Max summary length (tokens / words approx.)")
+             chunk_size = gr.Slider(500, 4000, value=1000, step=100, label="Chunk size (characters) — for long texts")
+             final_pass = gr.Checkbox(value=True, label="Do final-pass summarization (recommended for long inputs)")
+             btn = gr.Button("Summarize")
+         with gr.Column(scale=1):
+             out = gr.Textbox(lines=12, label="Summary")
+             gr.Markdown("### Examples")
+             ex = gr.Examples(examples=examples, inputs=inp, examples_per_page=6)
+
+     def _wrap_and_run(text, selected_model_label, min_length, max_length, chunk_size, do_final):
+         # map label to model name
+         model_map = {m[0]: m[1] for m in model_choices}
+         model_name = model_map.get(selected_model_label, model_choices[0][1])
+         start = time.time()
+         result = summarize_text(
+             text=text,
+             model_name=model_name,
+             min_length=min_length,
+             max_length=max_length,
+             chunk_max_chars=chunk_size,
+             do_final_pass=do_final,
+         )
+         took = time.time() - start
+         footer = f"\n\n---\nModel: {model_name} — Time: {took:.1f}s"
+         return result + footer
+
+     btn.click(
+         _wrap_and_run,
+         inputs=[inp, model, min_len, max_len, chunk_size, final_pass],
+         outputs=[out],
+     )

  if __name__ == "__main__":
+     # Launch locally on port 7860
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
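
A quick way to sanity-check the new version, as a minimal sketch (not part of the commit): import the module and call summarize_text directly, skipping the UI. This assumes app.py sits in the working directory with gradio, transformers, and torch installed; the helper name, sample text, and the choice of the lighter DistilBART checkpoint are illustrative.

# smoke_test.py (hypothetical helper, not included in the repo)
from app import chunk_text, summarize_text

sample = ("Transformers process whole sequences in parallel with attention, "
          "which makes them effective for abstractive summarization. ") * 30

# An input of a few thousand characters should split into several overlapping chunks.
print(len(chunk_text(sample, max_chars=1000, overlap=200)))

# Summarize with the lighter model so the check stays fast.
print(summarize_text(sample, model_name="sshleifer/distilbart-cnn-12-6",
                     min_length=20, max_length=80))

Importing app builds the Gradio Blocks but does not start a server, since demo.launch() sits behind the __main__ guard.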