Azidan committed on
Commit
487a5d4
·
verified ·
1 Parent(s): e7e069e

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +250 -0
main.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import math
3
+ from typing import List, Tuple, Optional
4
+
5
+ import gradio as gr
6
+ from transformers import AutoTokenizer, pipeline
7
+ import PyPDF2
8
+ import docx
9
+
10
# -----------------------------
# Configuration
# -----------------------------
MODEL_NAME = "sshleifer/distilbart-cnn-12-6" # lightweight, works on free tier
DEVICE = -1 # force CPU (Spaces free tier); pipeline(device=-1) means no GPU
CHUNK_STRIDE = 128 # overlap tokens between chunks (keeps context)
SECOND_PASS = True # run final summarization on joined chunk summaries

# Summary length presets (max tokens in generated summary).
# Each preset maps directly to the `max_length`/`min_length` kwargs
# passed to the transformers summarization pipeline.
SUMMARY_PRESETS = {
    "short": {"max_length": 60, "min_length": 20},
    "medium": {"max_length": 120, "min_length": 40},
    "long": {"max_length": 200, "min_length": 80},
}
24
+
25
# -----------------------------
# Load tokenizer & pipeline
# -----------------------------
# Loaded once at import time so every request reuses the same weights
# (model download/initialization is the slow part on Spaces).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
summarizer = pipeline("summarization", model=MODEL_NAME, tokenizer=tokenizer, device=DEVICE)
30
+
31
+
32
# -----------------------------
# Helpers: file reading
# -----------------------------
def read_pdf_bytes(file_bytes: bytes) -> str:
    """Extract text from a PDF supplied as raw bytes.

    Returns the page texts joined by newlines (pages with no extractable
    text are skipped), or an error-marker string starting with
    "[Error reading PDF:" on any failure.
    """
    try:
        reader = PyPDF2.PdfReader(io.BytesIO(file_bytes))
        page_texts = [page.extract_text() for page in reader.pages]
        return "\n".join(t for t in page_texts if t)
    except Exception as exc:
        return f"[Error reading PDF: {exc}]"
46
+
47
+
48
def read_docx_bytes(file_bytes: bytes) -> str:
    """Extract text from a DOCX supplied as raw bytes.

    Non-empty paragraph texts are joined with newlines; on any failure an
    error-marker string starting with "[Error reading DOCX:" is returned.
    """
    try:
        document = docx.Document(io.BytesIO(file_bytes))
        kept = []
        for para in document.paragraphs:
            if para.text and para.text.strip():
                kept.append(para.text)
        return "\n".join(kept)
    except Exception as exc:
        return f"[Error reading DOCX: {exc}]"
55
+
56
+
57
# -----------------------------
# Helpers: token-aware chunking
# -----------------------------
def chunk_text_by_tokens(text: str, max_tokens: Optional[int] = None, stride: int = CHUNK_STRIDE) -> List[str]:
    """
    Split text into chunks no longer than `max_tokens` tokens each.

    Consecutive chunks overlap by `stride` tokens to preserve context
    across chunk boundaries.

    Args:
        text: Raw input text (may be empty or whitespace-only).
        max_tokens: Per-chunk token budget; defaults to the tokenizer's
            `model_max_length` (typically 1024 for distilbart-cnn).
        stride: Overlap in tokens between consecutive chunks.

    Returns:
        List of decoded chunk strings; empty list for empty input.
    """
    if not text or not text.strip():
        return []

    if max_tokens is None:
        max_tokens = tokenizer.model_max_length  # typically 1024 for this model

    # Encode without special tokens so index-based slicing is exact.
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    n = len(token_ids)
    if n <= max_tokens:
        return [text.strip()]

    # Guard: if stride >= max_tokens, `start = end - stride` would never
    # advance and the loop below would spin forever. Clamp so each step
    # moves forward by at least one token.
    if stride >= max_tokens:
        stride = max_tokens - 1

    chunks = []
    start = 0
    while start < n:
        end = min(start + max_tokens, n)
        chunk_ids = token_ids[start:end]
        chunk_text = tokenizer.decode(chunk_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        chunks.append(chunk_text.strip())
        if end == n:
            break
        start = end - stride  # overlap with the previous chunk
    return chunks
89
+
90
+
91
# -----------------------------
# Summarization logic
# -----------------------------
def _summarize_text(text: str, max_len: int, min_len: int) -> str:
    """Run the shared summarization pipeline on one piece of text."""
    out = summarizer(
        text,
        max_length=max_len,
        min_length=min_len,
        do_sample=False,
        truncation=True,
    )
    return out[0]["summary_text"].strip()


def summarize_chunks(chunks: List[str], preset: str, progress: Optional[gr.Progress] = None) -> Tuple[List[str], str]:
    """
    Summarize each chunk and return (list_of_chunk_summaries, final_summary).

    If SECOND_PASS is True and more than one chunk was produced, a second
    summarization pass compresses the concatenated chunk summaries into a
    single final summary (map-reduce style).
    """
    if preset not in SUMMARY_PRESETS:
        preset = "medium"  # fall back to a sane default for unknown presets
    max_len = SUMMARY_PRESETS[preset]["max_length"]
    min_len = SUMMARY_PRESETS[preset]["min_length"]

    chunk_summaries = []
    total = len(chunks)
    for idx, chunk in enumerate(chunks, start=1):
        # Each chunk is already within the model's token limit; keep the call
        # guarded so one bad chunk doesn't abort the whole job.
        try:
            summary_text = _summarize_text(chunk, max_len, min_len)
        except Exception as e:
            summary_text = f"[Chunk summarization error: {e}]"
        chunk_summaries.append(summary_text)

        if progress:
            # The per-chunk pass accounts for ~70% of the progress bar.
            progress((idx / total) * 0.7, desc=f"Summarizing chunk {idx}/{total}...")

    # Second pass: summarize combined chunk summaries into the final summary.
    final_summary = ""
    if SECOND_PASS and len(chunk_summaries) > 1:
        joined = "\n\n".join(chunk_summaries)
        # The joined summaries may still exceed the model input limit, so
        # re-chunk them before the reduce step.
        joined_chunks = chunk_text_by_tokens(joined, max_tokens=tokenizer.model_max_length, stride=CHUNK_STRIDE)
        try:
            if len(joined_chunks) == 1:
                final_summary = _summarize_text(joined_chunks[0], max_len, min_len)
            else:
                # Reduce: compress each joined chunk, then summarize the
                # concatenation once more for the final answer.
                intermediate = [_summarize_text(jc, max_len, min_len) for jc in joined_chunks]
                final_summary = _summarize_text("\n\n".join(intermediate), max_len, min_len)
        except Exception as e:
            final_summary = f"[Final summarization error: {e}]"
    else:
        # Second pass disabled or only one chunk: final summary is the chunk
        # summaries joined (or the single summary / empty string).
        final_summary = "\n\n".join(chunk_summaries) if len(chunk_summaries) > 1 else (chunk_summaries[0] if chunk_summaries else "")

    if progress:
        progress(1.0, desc="Done")

    return chunk_summaries, final_summary
162
+
163
+
164
# -----------------------------
# Gradio processing function
# -----------------------------
def process(text_input: str, uploaded_file, preset: str, show_intermediate: bool, progress=gr.Progress()):
    """
    Gradio handler: extract text from the pasted box and/or uploaded file,
    chunk it by tokens, summarize, and return a 3-tuple of
    (final_summary, intermediate_markdown, stats).
    """
    progress(0.0, desc="Extracting text...")

    # Extract text from the uploaded file, if any.
    extracted = ""
    if uploaded_file is not None:
        try:
            # Depending on the Gradio version, gr.File hands the handler
            # either a file-like object (with .read()/.name) or a plain
            # filepath string — support both so uploads don't crash with
            # AttributeError on newer Gradio.
            # NOTE(review): confirm against the installed gradio version.
            if hasattr(uploaded_file, "read"):
                file_bytes = uploaded_file.read()
                fname = uploaded_file.name.lower()
            else:
                fname = str(uploaded_file).lower()
                with open(str(uploaded_file), "rb") as fh:
                    file_bytes = fh.read()
            if fname.endswith(".pdf"):
                extracted = read_pdf_bytes(file_bytes)
            elif fname.endswith(".docx"):
                extracted = read_docx_bytes(file_bytes)
            else:
                # Fallback: try to decode unknown types as plain text.
                try:
                    extracted = file_bytes.decode("utf-8", errors="replace")
                except Exception:
                    extracted = "[Unsupported file type]"
        except Exception as e:
            return f"[File read error: {e}]", "", ""

    # Combine pasted text with file text (file text first).
    if text_input and text_input.strip():
        combined = (extracted + "\n\n" + text_input.strip()).strip()
    else:
        combined = extracted.strip()

    if not combined:
        return "No text found. Paste text or upload a PDF/DOCX file.", "", ""

    # Chunk text by tokens so each piece fits the model's input window.
    progress(0.05, desc="Splitting into chunks...")
    max_tokens = tokenizer.model_max_length  # model input limit
    chunks = chunk_text_by_tokens(combined, max_tokens=max_tokens, stride=CHUNK_STRIDE)

    # Safety: extraction may have produced only whitespace.
    if not chunks:
        return "No text extracted from the file or input.", "", ""

    # Summarize chunks (progress updates happen inside).
    chunk_summaries, final_summary = summarize_chunks(chunks, preset, progress=progress)

    # Prepare intermediate per-chunk summary markdown.
    intermediate_md_lines = []
    for i, s in enumerate(chunk_summaries, start=1):
        intermediate_md_lines.append(f"### Chunk {i} Summary\n\n{s}\n")
    intermediate_md = "\n".join(intermediate_md_lines)

    stats = f"Input tokens (approx): {sum(len(tokenizer.encode(c, add_special_tokens=False)) for c in chunks)} | Chunks: {len(chunks)}"

    if show_intermediate:
        return final_summary, intermediate_md, stats
    else:
        return final_summary, "", stats
221
+
222
+
223
# -----------------------------
# Gradio UI
# -----------------------------
# Single-function Interface: all four inputs map positionally onto
# process(text_input, uploaded_file, preset, show_intermediate).
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(lines=12, placeholder="Paste text here (optional)...", label="Paste text (optional)"),
        gr.File(label="Upload PDF or DOCX (optional)"),
        gr.Radio(choices=["short", "medium", "long"], value="medium", label="Summary length (preset)"),
        gr.Checkbox(value=False, label="Show intermediate chunk summaries")
    ],
    outputs=[
        gr.Textbox(label="Final Summary"),
        gr.Markdown(label="Intermediate Chunk Summaries (if enabled)"),
        gr.Textbox(label="Stats")
    ],
    title="Hierarchical Long-Text Summarizer (token-aware, free-tier)",
    description=(
        "Paste text or upload a PDF/DOCX. The system splits long input by tokens, summarizes each chunk,"
        " then optionally performs a 2nd-pass summarization to produce a concise final summary."
    ),
    # NOTE(review): allow_flagging is deprecated in newer Gradio in favor of
    # flagging_mode — confirm the installed version still accepts it.
    allow_flagging="never",
    examples=[],
)
247
+
248
# Script entry point: launch the local Gradio server.
if __name__ == "__main__":
    # on Spaces this will be ignored and Gradio will serve automatically
    demo.launch()