# Hugging Face Spaces app header (Space status: Sleeping).
# AI Text Summarizer — Gradio front-end over a BART summarization pipeline.
| import os | |
| import re | |
| from typing import List, Tuple, Dict | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline | |
| import gradio as gr | |
| import math | |
# --- Configuration ---------------------------------------------------------
# Summarization checkpoint; override via the SUMM_MODEL environment variable.
MODEL_NAME = os.environ.get("SUMM_MODEL", "facebook/bart-large-cnn")
MAX_NEW_TOKENS = 180      # upper bound on generated summary tokens
MIN_NEW_TOKENS = 40       # lower bound on generated summary tokens
CHUNK_TOKEN_TARGET = 900  # token budget per chunk when splitting long inputs
CHUNK_OVERLAP = 120       # tokens shared between consecutive chunks
TARGET_SENTENCES = (3, 5)  # (min, max) sentences enforced on the final summary
# --- Model / pipeline setup -------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
device = 0 if torch.cuda.is_available() else -1  # GPU 0 when available, else CPU
summarizer = pipeline("summarization", model=model, tokenizer=tokenizer, device=device)

# Some tokenizers report a sentinel "huge" model_max_length; treat a missing
# or absurdly large value (> 100000) as unknown and fall back to 4096.
_raw_model_max = getattr(tokenizer, "model_max_length", None)
_effective_max = (
    4096 if not _raw_model_max or _raw_model_max > 100000 else _raw_model_max
)
SAFE_MAX_INPUT_LEN = min(CHUNK_TOKEN_TARGET, _effective_max)
# Sentence boundary: whitespace preceded by terminal punctuation (., !, ?).
_SENT_BOUNDARY = re.compile(r"(?<=[.!?])\s+")


def sent_split(text: str) -> List[str]:
    """Split *text* into stripped, non-empty sentences on terminal punctuation."""
    pieces = _SENT_BOUNDARY.split(text)
    return [piece.strip() for piece in pieces if piece.strip()]
def enforce_sentence_count(text: str, lo: int, hi: int) -> str:
    """Trim *text* to at most *hi* sentences (keeping at least *lo* when trimming).

    Falls back to returning the stripped input unchanged when no sentences
    can be detected at all.
    """
    sents = sent_split(text)
    if not sents:
        return text.strip()
    # Within budget: keep everything; over budget: clamp to [lo, hi].
    keep = len(sents) if len(sents) <= hi else max(lo, min(hi, len(sents)))
    return " ".join(sents[:keep])
def chunk_by_tokens(text: str, token_budget: int, overlap: int) -> List[str]:
    """Split *text* into overlapping chunks of at most *token_budget* tokens.

    Tokens come from the module-level `tokenizer` (no special tokens added);
    each window of ids is decoded back into a text chunk. Consecutive chunks
    share *overlap* tokens so content at chunk edges is not lost.

    Args:
        text: Raw input text to split.
        token_budget: Maximum tokens per chunk; must be positive.
        overlap: Tokens shared between consecutive chunks.

    Returns:
        List of decoded text chunks covering the whole input (empty for
        empty input).
    """
    input_ids = tokenizer.encode(text, add_special_tokens=False)
    n = len(input_ids)
    chunks: List[str] = []
    i = 0
    while i < n:
        j = min(i + token_budget, n)
        chunks.append(
            tokenizer.decode(
                input_ids[i:j],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )
        if j >= n:
            break
        # Step back by `overlap` tokens for context, but always advance by at
        # least one token: with overlap >= token_budget the original
        # `i = j - overlap` never moved forward and looped forever,
        # re-emitting the same window. (Also keeps i non-negative.)
        i = max(j - overlap, i + 1)
    return chunks
def safe_truncate_text(text: str, max_tokens: int) -> str:
    """Return *text* unchanged if it fits in *max_tokens* tokens, else the decoded prefix of the first *max_tokens* tokens."""
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    if len(token_ids) > max_tokens:
        return tokenizer.decode(
            token_ids[:max_tokens],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    return text
def safe_summarize_single_input(text: str) -> str:
    """Summarize *text* via the module pipeline, never raising.

    Returns the pipeline's summary text, "" when the pipeline yields nothing,
    or a bracketed error tag on failure — deliberately best-effort so one bad
    chunk cannot abort a whole multi-chunk run.
    """
    try:
        outputs = summarizer(
            text,
            max_new_tokens=MAX_NEW_TOKENS,
            min_new_tokens=MIN_NEW_TOKENS,
            do_sample=False,
        )
        if not outputs:
            return ""
        return outputs[0].get("summary_text", "") or ""
    except Exception as e:  # intentional broad catch: fold errors into output
        return f"[Summarization error: {type(e).__name__}: {str(e)}]"
def summarize_long(text: str) -> str:
    """Summarize *text* into the target sentence range, chunking when long.

    Inputs of at most CHUNK_TOKEN_TARGET tokens are summarized in one pass.
    Longer inputs take a map/reduce path: each overlapping chunk is
    summarized, the chunk summaries are fused, and the fusion is summarized
    again. Returns "" for blank input.
    """
    text = text.strip()
    if not text:
        return ""
    try:
        total_tokens = len(tokenizer.encode(text, add_special_tokens=False))
    except Exception:
        # Unknown length: assume "long" so we take the chunked path.
        total_tokens = math.inf
    if total_tokens <= CHUNK_TOKEN_TARGET:
        # Single pass; pre-truncate to sidestep ambiguous pipeline truncation.
        direct = safe_summarize_single_input(
            safe_truncate_text(text, SAFE_MAX_INPUT_LEN)
        )
        return enforce_sentence_count(direct, *TARGET_SENTENCES)
    # Map: summarize every overlapping chunk independently.
    partials = [
        safe_summarize_single_input(safe_truncate_text(chunk, SAFE_MAX_INPUT_LEN))
        for chunk in chunk_by_tokens(text, CHUNK_TOKEN_TARGET, CHUNK_OVERLAP)
    ]
    # Reduce: fuse the non-empty chunk summaries and summarize once more.
    fused = "\n".join(part for part in partials if part)
    final = safe_summarize_single_input(safe_truncate_text(fused, SAFE_MAX_INPUT_LEN))
    return enforce_sentence_count(final, *TARGET_SENTENCES)
def summarize_batch(texts: List[str]) -> List[str]:
    """Summarize each entry of *texts* independently; one summary per input."""
    return [summarize_long(entry) for entry in texts]
# Articles are separated by a line consisting of three or more dashes.
_ARTICLE_SEP = re.compile(r"\n-{3,}\n")


def summarize_multi(input_blob: str) -> str:
    """Split *input_blob* on `---` separator lines, summarize each article, and join the summaries with blank lines."""
    articles = [
        part.strip()
        for part in _ARTICLE_SEP.split(input_blob.strip())
        if part.strip()
    ]
    return "\n\n".join(summarize_batch(articles))
# --- Gradio UI --------------------------------------------------------------
with gr.Blocks(title="AI Text Summarizer") as demo:
    gr.Markdown(
        "# AI Text Summarizer\nSummarizes long documents "
    )

    # Single-document tab: one textbox in, one summary out.
    with gr.Tab("Single"):
        inp = gr.Textbox(label="Input text to summarize", lines=12, placeholder="Paste an article or paper section…")
        out = gr.Textbox(label="Summarized text (3–5 sentences)", lines=6)
        btn = gr.Button("Summarize")
        # summarize_long already matches the one-in/one-out click signature.
        btn.click(fn=summarize_long, inputs=inp, outputs=out)

    # Batch tab: several articles separated by `---` lines, one summary each.
    with gr.Tab("Batch"):
        gr.Markdown("Paste multiple articles separated by a line with `---` on its own.")
        inp_b = gr.Textbox(label="Multiple articles", lines=18, placeholder="Article 1...\n---\nArticle 2...")
        out_b = gr.Textbox(label="Summaries (each 3–5 sentences)", lines=18)
        btn_b = gr.Button("Summarize Batch")
        btn_b.click(fn=summarize_multi, inputs=inp_b, outputs=out_b)

    # Collapsed panel exposing the effective runtime limits for debugging.
    with gr.Accordion("Debug / Warnings", open=False):
        gr.Markdown(
            f"- SAFE_MAX_INPUT_LEN={SAFE_MAX_INPUT_LEN}\n"
            f"- CHUNK_TOKEN_TARGET={CHUNK_TOKEN_TARGET}\n"
            f"- CHUNK_OVERLAP={CHUNK_OVERLAP}\n"
            "- Truncation warning: inputs longer than SAFE_MAX_INPUT_LEN are pre-truncated using the model tokenizer to avoid ambiguous pipeline truncation behavior.\n"
            "- Potential runtime issues: large inputs can still cause high memory usage; reduce CHUNK_TOKEN_TARGET or batch size, or run on a GPU with sufficient memory.\n\n"
            "---\n"
            "© **Vivek Reddy** \n"
            "🔗 [GitHub](https://github.com/vivekreddy1105) | [LinkedIn](https://linkedin.com/in/vivekreddy1105)"
        )

if __name__ == "__main__":
    demo.launch()