File size: 5,454 Bytes
40075ee
 
 
d874c9e
40075ee
d874c9e
796f2d6
40075ee
d874c9e
40075ee
 
 
 
 
 
d874c9e
40075ee
 
 
 
d874c9e
40075ee
 
 
 
 
d874c9e
f77e72b
d874c9e
40075ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59b89ed
40075ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import re
from typing import List, Tuple, Dict
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import gradio as gr

import math

# Model selection is overridable via the SUMM_MODEL environment variable;
# defaults to BART fine-tuned on CNN/DailyMail.
MODEL_NAME = os.environ.get("SUMM_MODEL", "facebook/bart-large-cnn")
# Generation length bounds for each summarization call.
MAX_NEW_TOKENS = 180
MIN_NEW_TOKENS = 40
# Target token count per chunk when splitting long inputs, and the number of
# tokens repeated between consecutive chunks to preserve context.
CHUNK_TOKEN_TARGET = 900
CHUNK_OVERLAP = 120
# (lo, hi) bounds on the number of sentences in the final summary.
TARGET_SENTENCES = (3, 5)

# Eagerly load tokenizer + model at import time (may download weights).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
# device 0 = first CUDA GPU, -1 = CPU (transformers pipeline convention).
device = 0 if torch.cuda.is_available() else -1
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=device)

# Some tokenizers report a sentinel "very large" model_max_length (e.g. 1e30)
# when the limit is unknown; treat anything missing or absurdly large as
# unreliable and fall back to a conservative cap.
_raw_model_max = getattr(tokenizer, "model_max_length", None)
if not _raw_model_max or _raw_model_max > 100000:
    SAFE_MAX_INPUT_LEN = min(CHUNK_TOKEN_TARGET, 4096)
else:
    SAFE_MAX_INPUT_LEN = min(CHUNK_TOKEN_TARGET, _raw_model_max)




def sent_split(text: str) -> List[str]:
    """Split *text* into sentences at terminal punctuation (. ! ?) followed
    by whitespace, dropping empty fragments."""
    pieces = re.split(r"(?<=[.!?])\s+", text)
    sentences = []
    for piece in pieces:
        cleaned = piece.strip()
        if cleaned:
            sentences.append(cleaned)
    return sentences


def enforce_sentence_count(text: str, lo: int, hi: int) -> str:
    """Return *text* re-joined as single-space-separated sentences, trimmed
    to at most *hi* sentences; if *text* has no detectable sentences it is
    returned stripped but otherwise unchanged."""
    sentences = sent_split(text)
    if not sentences:
        return text.strip()
    if len(sentences) > hi:
        keep = max(lo, min(hi, len(sentences)))
        sentences = sentences[:keep]
    return " ".join(sentences)


def chunk_by_tokens(text: str, token_budget: int, overlap: int) -> List[str]:
    """Split *text* into chunks of at most *token_budget* tokens, with
    consecutive chunks sharing up to *overlap* tokens of context.

    Tokens come from the module-level ``tokenizer`` (no special tokens);
    each chunk is decoded back to a string.

    Raises:
        ValueError: if *token_budget* is not positive.
    """
    if token_budget <= 0:
        raise ValueError("token_budget must be positive")
    # BUGFIX: the original advanced with `i = j - overlap`, which never makes
    # progress when overlap >= token_budget and loops forever. Clamp the
    # overlap so every iteration advances by at least one token.
    safe_overlap = max(0, min(overlap, token_budget - 1))
    input_ids = tokenizer.encode(text, add_special_tokens=False)
    n = len(input_ids)
    chunks: List[str] = []
    i = 0
    while i < n:
        j = min(i + token_budget, n)
        chunks.append(
            tokenizer.decode(
                input_ids[i:j],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )
        if j >= n:
            break
        i = j - safe_overlap
    return chunks


def safe_truncate_text(text: str, max_tokens: int) -> str:
    """Clip *text* to at most *max_tokens* tokens using the model tokenizer;
    text already within the budget is returned untouched."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) > max_tokens:
        return tokenizer.decode(
            token_ids[:max_tokens],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    return text


def safe_summarize_single_input(text: str) -> str:
    """Run one summarization pass over *text*.

    Returns the generated summary, an empty string when the pipeline yields
    nothing, or a bracketed error description instead of raising so callers
    can keep processing remaining chunks.
    """
    try:
        outputs = summarizer(
            text,
            max_new_tokens=MAX_NEW_TOKENS,
            min_new_tokens=MIN_NEW_TOKENS,
            do_sample=False,
        )
    except Exception as e:
        # Deliberate best-effort: surface the failure inline, never crash.
        return f"[Summarization error: {type(e).__name__}: {str(e)}]"
    if not outputs:
        return ""
    return outputs[0].get("summary_text", "") or ""


def summarize_long(text: str) -> str:
    """Summarize *text* of arbitrary length into 3-5 sentences.

    Short inputs (within CHUNK_TOKEN_TARGET tokens) are summarized in one
    pass. Longer inputs follow a map-reduce scheme: summarize each
    overlapping chunk, then summarize the concatenated partial summaries.
    """
    cleaned = text.strip()
    if not cleaned:
        return ""

    try:
        n_tokens = len(tokenizer.encode(cleaned, add_special_tokens=False))
    except Exception:
        # If tokenization fails, assume the worst and take the chunked path.
        n_tokens = math.inf

    # Fast path: the whole input fits in a single model call.
    if n_tokens <= CHUNK_TOKEN_TARGET:
        one_shot = safe_summarize_single_input(
            safe_truncate_text(cleaned, SAFE_MAX_INPUT_LEN)
        )
        return enforce_sentence_count(one_shot, *TARGET_SENTENCES)

    # Map: summarize every chunk independently.
    partials = [
        safe_summarize_single_input(safe_truncate_text(chunk, SAFE_MAX_INPUT_LEN))
        for chunk in chunk_by_tokens(cleaned, CHUNK_TOKEN_TARGET, CHUNK_OVERLAP)
    ]

    # Reduce: fuse the non-empty partial summaries and summarize once more.
    fused = "\n".join([p for p in partials if p])
    final = safe_summarize_single_input(
        safe_truncate_text(fused, SAFE_MAX_INPUT_LEN)
    )
    return enforce_sentence_count(final, *TARGET_SENTENCES)


def summarize_batch(texts: List[str]) -> List[str]:
    """Summarize each text in *texts* independently, preserving order."""
    return [summarize_long(t) for t in texts]


def summarize_multi(input_blob: str) -> str:
    """Split *input_blob* into articles on separator lines of three or more
    dashes, summarize each, and join the summaries with blank lines."""
    articles = []
    for part in re.split(r"\n-{3,}\n", input_blob.strip()):
        part = part.strip()
        if part:
            articles.append(part)
    return "\n\n".join(summarize_batch(articles))


# --- Gradio UI ---------------------------------------------------------------
# Two tabs: single-article summarization and batch mode (articles separated
# by `---` lines), plus a collapsed debug accordion showing the effective
# token limits.
with gr.Blocks(title="AI Text Summarizer") as demo:
    gr.Markdown(
        "# AI Text Summarizer\nSummarizes long documents "
    )
    with gr.Tab("Single"):
        inp = gr.Textbox(label="Input text to summarize", lines=12, placeholder="Paste an article or paper section…")
        out = gr.Textbox(label="Summarized text (3–5 sentences)", lines=6)
        btn = gr.Button("Summarize")
        # Pass the function directly; the original lambda wrapper added nothing.
        btn.click(fn=summarize_long, inputs=inp, outputs=out)
    with gr.Tab("Batch"):
        gr.Markdown("Paste multiple articles separated by a line with `---` on its own.")
        inp_b = gr.Textbox(label="Multiple articles", lines=18, placeholder="Article 1...\n---\nArticle 2...")
        out_b = gr.Textbox(label="Summaries (each 3–5 sentences)", lines=18)
        btn_b = gr.Button("Summarize Batch")
        btn_b.click(fn=summarize_multi, inputs=inp_b, outputs=out_b)

    with gr.Accordion("Debug / Warnings", open=False):
        gr.Markdown(
            f"- SAFE_MAX_INPUT_LEN={SAFE_MAX_INPUT_LEN}\n"
            f"- CHUNK_TOKEN_TARGET={CHUNK_TOKEN_TARGET}\n"
            f"- CHUNK_OVERLAP={CHUNK_OVERLAP}\n"
            "- Truncation warning: inputs longer than SAFE_MAX_INPUT_LEN are pre-truncated using the model tokenizer to avoid ambiguous pipeline truncation behavior.\n"
            "- Potential runtime issues: large inputs can still cause high memory usage; reduce CHUNK_TOKEN_TARGET or batch size, or run on a GPU with sufficient memory.\n\n"
            "---\n"
            "© **Vivek Reddy**  \n"
            "🔗 [GitHub](https://github.com/vivekreddy1105) | [LinkedIn](https://linkedin.com/in/vivekreddy1105)"
        )


# Launch the Gradio server only when run as a script (not when imported).
if __name__ == "__main__":
    demo.launch()