mbalvi commited on
Commit
4012cea
·
verified ·
1 Parent(s): 20bd76a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +174 -0
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Simple Hugging Face Text Summarizer (Flask)
3
+
4
+ Endpoints:
5
+ - GET / -> basic info
6
+ - POST /summarize -> JSON input: {"text": "...", "model": "...", "max_length": 130, "min_length": 30}
7
+
8
+ How to run:
9
+ 1. pip install -r requirements.txt
10
+ 2. python app.py
11
+ 3. POST JSON to http://127.0.0.1:8000/summarize
12
+
13
+ Notes:
14
+ - Default model: "facebook/bart-large-cnn". You may change to any summarization-capable HF model.
15
+ - For very long texts, the app chunks the text and summarizes each chunk, then summarizes the concatenated chunk-summaries.
16
+ """
17
+ from flask import Flask, request, jsonify
18
+ from transformers import pipeline, Pipeline
19
+ from typing import List, Optional
20
+ import threading
21
+ import math
22
+ import textwrap
23
+ import os
24
+
25
app = Flask(__name__)

# Default model (good general-purpose summarizer); overridable via the
# SUMMARIZER_MODEL environment variable.
DEFAULT_MODEL = os.getenv("SUMMARIZER_MODEL", "facebook/bart-large-cnn")

# Global pipeline cache so models are loaded once per process, not per
# request. Guarded by a lock because Flask may serve requests from
# multiple threads.
_PIPELINES = {}
_PIPELINES_LOCK = threading.Lock()
33
+
34
def get_summarizer(model_name: str = DEFAULT_MODEL) -> Pipeline:
    """
    Return a cached summarization pipeline for *model_name*, creating it on
    first use.

    Loading a model can take a long time; the original held the global lock
    for the whole load, which serialized every request — even ones hitting an
    already-cached model. The lock-free fast path below avoids that.
    """
    # Fast path: plain dict reads are atomic under the GIL, so no lock is
    # needed for a cache hit.
    cached = _PIPELINES.get(model_name)
    if cached is not None:
        return cached

    with _PIPELINES_LOCK:
        # Double-check inside the lock: another thread may have finished
        # loading this model while we waited.
        if model_name not in _PIPELINES:
            # Use the default device; if torch detects a GPU it'll use it.
            _PIPELINES[model_name] = pipeline("summarization", model=model_name)
        return _PIPELINES[model_name]
43
+
44
def chunk_text(text: str, max_chars: int = 1000, overlap: int = 200) -> List[str]:
    """
    Split *text* into chunks of at most ~max_chars characters on whitespace
    boundaries, carrying up to *overlap* trailing characters of one chunk
    into the next for context.

    Bug fixed vs. the original: a single word longer than max_chars made the
    loop spin forever appending empty chunks (the flush branch never advanced
    the word index). Such words are now emitted as their own chunk.

    Returns:
        A list of non-empty chunk strings.
    """
    if len(text) <= max_chars:
        return [text]

    words = text.split()
    chunks: List[str] = []
    current: List[str] = []
    current_len = 0
    i = 0
    while i < len(words):
        w = words[i]
        sep = 1 if current_len > 0 else 0  # +1 for the joining space
        if current_len + len(w) + sep <= max_chars:
            current.append(w)
            current_len += len(w) + sep
            i += 1
        elif not current:
            # Oversized single word: it can never fit in a chunk, so emit it
            # alone and move on (prevents the infinite-empty-chunk loop).
            chunks.append(w)
            i += 1
        else:
            chunks.append(" ".join(current))
            # Take up to `overlap` characters' worth of trailing words as
            # overlapping context for the next chunk.
            ov: List[str] = []
            ov_len = 0
            if overlap > 0:
                while current and ov_len + len(current[-1]) + (1 if ov_len > 0 else 0) <= overlap:
                    ov.insert(0, current.pop())
                    ov_len += len(ov[0]) + (1 if ov_len > 0 else 0)
            # Drop the overlap when the next word could never fit alongside
            # it — otherwise we'd flush the same overlap forever.
            if ov_len + len(w) + (1 if ov_len > 0 else 0) > max_chars:
                ov, ov_len = [], 0
            current = ov
            current_len = ov_len
    if current:
        chunks.append(" ".join(current))
    return chunks
85
+
86
def summarize_text(text: str, model_name: str = DEFAULT_MODEL,
                   max_length: int = 130, min_length: int = 30,
                   chunk_max_chars: int = 1000, chunk_overlap: int = 200) -> str:
    """
    Summarize a (possibly long) text.

    Strategy: chunk the input, summarize each chunk, then — when more than
    one chunk was produced — summarize the concatenated chunk summaries to
    obtain a single short result.

    Args:
        text: input text to summarize.
        model_name: HF model id supporting the "summarization" task.
        max_length / min_length: token-length bounds passed to the pipeline.
        chunk_max_chars / chunk_overlap: character budget and overlap used
            when splitting long inputs (see chunk_text).

    Returns:
        The final summary string.
    """
    summarizer = get_summarizer(model_name)

    # Chunk the input so each piece fits comfortably in the model's context.
    chunks = chunk_text(text, max_chars=chunk_max_chars, overlap=chunk_overlap)

    # Summarize each chunk; the pipeline returns a list of dicts with
    # 'summary_text'.
    chunk_summaries = []
    for chunk in chunks:
        try:
            out = summarizer(chunk, max_length=max_length,
                             min_length=min_length, truncation=True)
        except Exception:
            # Fallback: retry without length constraints — some model/length
            # combinations are rejected (e.g. min_length > input length).
            out = summarizer(chunk, truncation=True)
        chunk_summaries.append(out[0]["summary_text"].strip())

    if len(chunk_summaries) == 1:
        return chunk_summaries[0]

    # Second pass: compress the concatenated chunk summaries into one final
    # summary, capped at 180 tokens so it stays short for very large inputs.
    combined = "\n".join(chunk_summaries)
    final_out = summarizer(combined, max_length=min(max_length, 180),
                           min_length=25, truncation=True)
    return final_out[0]["summary_text"].strip()
121
+
122
@app.route("/", methods=["GET"])
def index():
    """Return service metadata describing the /summarize endpoint."""
    summarize_schema = {
        "text": "string (required)",
        "model": "optional HF model id (default facebook/bart-large-cnn)",
        "max_length": "optional int (summary max tokens, default 130)",
        "min_length": "optional int (summary min tokens, default 30)",
    }
    info = {
        "service": "hf-text-summarizer",
        "endpoints": {
            "POST /summarize": {"json": summarize_schema},
        },
    }
    return jsonify(info)
137
+
138
@app.route("/summarize", methods=["POST"])
def summarize_route():
    """
    Summarize the text in the JSON request body.

    Expected JSON: {"text": str, "model": str?, "max_length": int?,
    "min_length": int?}. Responds with JSON containing the model used,
    the summary, and the input length; errors are returned as JSON with
    an appropriate HTTP status (400/413/500).
    """
    data = request.get_json(force=True, silent=True)
    if not data or "text" not in data:
        return jsonify({"error": "JSON body required with 'text' field"}), 400

    text = data["text"]
    model = data.get("model", DEFAULT_MODEL)

    # Coerce length parameters defensively: previously a non-numeric value
    # raised ValueError and surfaced as an unhandled HTML 500 instead of the
    # JSON errors this API otherwise returns.
    try:
        max_length = int(data.get("max_length", 130))
        min_length = int(data.get("min_length", 30))
    except (TypeError, ValueError):
        return jsonify({"error": "max_length and min_length must be integers"}), 400

    # Basic input validation.
    if not isinstance(text, str) or len(text.strip()) == 0:
        return jsonify({"error": "text must be a non-empty string"}), 400
    if max_length <= 0 or min_length < 0 or min_length > max_length:
        return jsonify({"error": "invalid min_length/max_length"}), 400

    # Safety: prevent extremely huge single requests from crashing the server.
    if len(text) > 500_000:  # ~500k chars
        return jsonify({"error": "input text too large (limit 500k chars)"}), 413

    try:
        summary = summarize_text(text, model_name=model,
                                 max_length=max_length, min_length=min_length)
    except Exception as e:
        # Present a helpful message without leaking internals (no traceback).
        return jsonify({"error": "failed to summarize text", "detail": str(e)}), 500

    return jsonify({
        "model": model,
        "summary": summary,
        "input_length": len(text),
    })
171
+
172
if __name__ == "__main__":
    # Run Flask's development server on port 8000.
    # NOTE(review): host="0.0.0.0" exposes the service on every network
    # interface; for production, front this with a WSGI server
    # (gunicorn/uwsgi) instead of app.run.
    app.run(host="0.0.0.0", port=8000, debug=False)