# TextSummarizer / app.py
# Author: vivekreddy1105 — commit 59b89ed (verified)
import os
import re
from typing import List, Tuple, Dict
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import gradio as gr
import math
# --- Configuration ---------------------------------------------------------
# Summarization model; override with the SUMM_MODEL environment variable.
MODEL_NAME = os.environ.get("SUMM_MODEL", "facebook/bart-large-cnn")
MAX_NEW_TOKENS = 180      # upper bound on generated summary length (tokens)
MIN_NEW_TOKENS = 40       # lower bound on generated summary length (tokens)
CHUNK_TOKEN_TARGET = 900  # target size of each input chunk, in tokens
CHUNK_OVERLAP = 120       # tokens of context shared between consecutive chunks
TARGET_SENTENCES = (3, 5) # (min, max) sentences kept in the final summary

# --- Model / pipeline setup (runs once at import; downloads weights) -------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
device = 0 if torch.cuda.is_available() else -1  # GPU 0 if present, else CPU
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=device)

# Some tokenizers report a huge sentinel value for model_max_length; treat a
# missing or absurdly large (>100000) value as "unknown" and cap at 4096.
_raw_model_max = getattr(tokenizer, "model_max_length", None)
if not _raw_model_max or _raw_model_max > 100000:
    SAFE_MAX_INPUT_LEN = min(CHUNK_TOKEN_TARGET, 4096)
else:
    SAFE_MAX_INPUT_LEN = min(CHUNK_TOKEN_TARGET, _raw_model_max)
def sent_split(text: str) -> List[str]:
    """Split *text* into sentences on whitespace that follows '.', '!' or '?'."""
    pieces = re.split(r"(?<=[.!?])\s+", text)
    sentences: List[str] = []
    for piece in pieces:
        stripped = piece.strip()
        if stripped:
            sentences.append(stripped)
    return sentences
def enforce_sentence_count(text: str, lo: int, hi: int) -> str:
    """Rejoin *text*'s sentences with single spaces, truncating to at most *hi* sentences."""
    sentences = sent_split(text)
    if not sentences:
        # Nothing recognizable as a sentence: return the trimmed input as-is.
        return text.strip()
    if len(sentences) > hi:
        keep = max(lo, min(hi, len(sentences)))
        sentences = sentences[:keep]
    return " ".join(sentences)
def chunk_by_tokens(text: str, token_budget: int, overlap: int) -> List[str]:
    """Split *text* into decoded chunks of at most *token_budget* tokens.

    Consecutive chunks share *overlap* tokens of context so sentences cut at
    a boundary still appear whole in one chunk.

    Fix: the original set the next start to ``end - overlap`` (clamped to 0),
    so any ``overlap >= token_budget`` moved the cursor backwards or reset it
    to 0 and the loop never terminated. Progress is now guaranteed; when the
    requested overlap cannot fit, chunks simply become non-overlapping.
    """
    input_ids = tokenizer.encode(text, add_special_tokens=False)
    chunks: List[str] = []
    start = 0
    total = len(input_ids)
    while start < total:
        end = min(start + token_budget, total)
        piece = input_ids[start:end]
        chunks.append(
            tokenizer.decode(piece, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        )
        if end >= total:
            break
        # Step back by `overlap` for context, but never to (or before) the
        # current start — that would stall or re-emit the same chunk.
        next_start = end - overlap
        start = next_start if next_start > start else end
    return chunks
def safe_truncate_text(text: str, max_tokens: int) -> str:
    """Return *text* unchanged if it fits in *max_tokens* tokens, else a decoded prefix."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) > max_tokens:
        head = token_ids[:max_tokens]
        return tokenizer.decode(head, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return text
def safe_summarize_single_input(text: str) -> str:
    """Run the summarization pipeline on one input without ever raising.

    Returns the generated summary, "" when the pipeline yields nothing, or a
    bracketed error-marker string if the pipeline throws.
    """
    try:
        outputs = summarizer(
            text,
            max_new_tokens=MAX_NEW_TOKENS,
            min_new_tokens=MIN_NEW_TOKENS,
            do_sample=False,
        )
        if not outputs:
            return ""
        return outputs[0].get("summary_text", "") or ""
    except Exception as exc:  # deliberate catch-all: surface the error as text
        return f"[Summarization error: {type(exc).__name__}: {str(exc)}]"
def summarize_long(text: str) -> str:
    """Summarize *text* of any length into a 3–5 sentence summary.

    Short inputs take a single pass through the model; long inputs are
    map-reduced: summarize each chunk, fuse the partial summaries, then
    summarize the fusion.
    """
    stripped = text.strip()
    if not stripped:
        return ""
    try:
        token_count = len(tokenizer.encode(stripped, add_special_tokens=False))
    except Exception:
        token_count = math.inf  # unknown length: force the chunked path
    lo, hi = TARGET_SENTENCES
    if token_count <= CHUNK_TOKEN_TARGET:
        # Short document: one pass, pre-truncated to the safe input length.
        summary = safe_summarize_single_input(
            safe_truncate_text(stripped, SAFE_MAX_INPUT_LEN)
        )
        return enforce_sentence_count(summary, lo, hi)
    # Long document: summarize every chunk, then summarize the fused result.
    partials = [
        safe_summarize_single_input(safe_truncate_text(chunk, SAFE_MAX_INPUT_LEN))
        for chunk in chunk_by_tokens(stripped, CHUNK_TOKEN_TARGET, CHUNK_OVERLAP)
    ]
    fused = "\n".join(p for p in partials if p)
    final = safe_summarize_single_input(safe_truncate_text(fused, SAFE_MAX_INPUT_LEN))
    return enforce_sentence_count(final, lo, hi)
def summarize_batch(texts: List[str]) -> List[str]:
    """Summarize each text independently, preserving input order."""
    return [summarize_long(item) for item in texts]
def summarize_multi(input_blob: str) -> str:
    """Split *input_blob* on `---` separator lines, summarize each article, join with blank lines."""
    articles: List[str] = []
    for segment in re.split(r"\n-{3,}\n", input_blob.strip()):
        cleaned = segment.strip()
        if cleaned:
            articles.append(cleaned)
    summaries = summarize_batch(articles)
    return "\n\n".join(summaries)
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="AI Text Summarizer") as demo:
    gr.Markdown(
        "# AI Text Summarizer\nSummarizes long documents "
    )
    with gr.Tab("Single"):
        inp = gr.Textbox(label="Input text to summarize", lines=12, placeholder="Paste an article or paper section…")
        out = gr.Textbox(label="Summarized text (3–5 sentences)", lines=6)
        btn = gr.Button("Summarize")
        # Pass the handler directly: the `lambda x: summarize_long(x)`
        # wrapper added nothing and obscured the actual callback.
        btn.click(fn=summarize_long, inputs=inp, outputs=out)
    with gr.Tab("Batch"):
        gr.Markdown("Paste multiple articles separated by a line with `---` on its own.")
        inp_b = gr.Textbox(label="Multiple articles", lines=18, placeholder="Article 1...\n---\nArticle 2...")
        out_b = gr.Textbox(label="Summaries (each 3–5 sentences)", lines=18)
        btn_b = gr.Button("Summarize Batch")
        btn_b.click(fn=summarize_multi, inputs=inp_b, outputs=out_b)
    with gr.Accordion("Debug / Warnings", open=False):
        # Surfaces the effective runtime limits so users can diagnose truncation.
        gr.Markdown(
            f"- SAFE_MAX_INPUT_LEN={SAFE_MAX_INPUT_LEN}\n"
            f"- CHUNK_TOKEN_TARGET={CHUNK_TOKEN_TARGET}\n"
            f"- CHUNK_OVERLAP={CHUNK_OVERLAP}\n"
            "- Truncation warning: inputs longer than SAFE_MAX_INPUT_LEN are pre-truncated using the model tokenizer to avoid ambiguous pipeline truncation behavior.\n"
            "- Potential runtime issues: large inputs can still cause high memory usage; reduce CHUNK_TOKEN_TARGET or batch size, or run on a GPU with sufficient memory.\n\n"
            "---\n"
            "© **Vivek Reddy** \n"
            "🔗 [GitHub](https://github.com/vivekreddy1105) | [LinkedIn](https://linkedin.com/in/vivekreddy1105)"
        )
if __name__ == "__main__":
    # Launch the Gradio web server when run as a script (not when imported).
    demo.launch()