"""Gradio app: generate candidate titles for an article with a fine-tuned T5 model.

The article is tokenized once, split into (possibly overlapping) spans of at
most ``MAX_INPUT_LEN`` tokens, and each requested title is sampled from one of
those spans — cycling through them when more titles than spans are requested,
so every part of a long article can seed a title.
"""

import math

import gradio as gr
import nltk
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_ID = "Ilyakk/t5-summarization"
MAX_INPUT_LEN = 512  # encoder span size in tokens (T5 context window)
GEN_MAX_LEN = 64     # maximum token length of a generated title


def ensure_nltk():
    """Download the NLTK sentence-tokenizer data if it is not already present.

    ``punkt`` serves older NLTK releases; newer ones additionally require
    ``punkt_tab``.
    """
    for resource in ("punkt", "punkt_tab"):
        try:
            nltk.data.find(f"tokenizers/{resource}")
        except LookupError:
            nltk.download(resource)


ensure_nltk()

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def _first_sentence(text: str) -> str:
    """Return the first sentence of *text* (stripped), or "" for empty input.

    Falls back to a naive punctuation split if NLTK tokenization fails
    (e.g. missing tokenizer data at runtime).
    """
    text = (text or "").strip()
    if not text:
        return ""
    try:
        sents = nltk.sent_tokenize(text)
        return sents[0].strip() if sents else text
    except Exception:
        for sep in (".", "!", "?"):
            if sep in text:
                return text.split(sep)[0].strip()
        return text


def _make_spans(num_tokens: int) -> list:
    """Split *num_tokens* into [start, end) spans of at most MAX_INPUT_LEN tokens.

    When more than one span is needed, the slack
    (num_spans * MAX_INPUT_LEN - num_tokens) is distributed as overlap between
    consecutive windows, so all spans have equal length — a requirement for
    batching them with ``torch.stack``.
    """
    num_spans = max(1, math.ceil(num_tokens / MAX_INPUT_LEN))
    if num_spans == 1:
        return [(0, num_tokens)]
    overlap = math.ceil((num_spans * MAX_INPUT_LEN - num_tokens) / (num_spans - 1))
    step = MAX_INPUT_LEN - overlap
    return [
        (max(0, i * step), min(num_tokens, i * step + MAX_INPUT_LEN))
        for i in range(num_spans)
    ]


def generate_titles(text: str, num_titles: int = 3, temperature: float = 0.7):
    """Generate *num_titles* candidate titles for *text*.

    Args:
        text: Article body. Long texts are split into overlapping
            MAX_INPUT_LEN-token spans so every part can seed a title.
        num_titles: How many candidates to return (Gradio sliders may deliver
            this as a float; it is coerced to an int here).
        temperature: Sampling temperature; higher values give more variety.

    Returns:
        A list of candidate title strings, or a single prompt string asking
        for input when *text* is empty.
    """
    text = (text or "").strip()
    if not text:
        # User-facing message (Russian): "Enter the article text above."
        return ["Введите текст статьи выше."]

    # Gradio sliders can pass floats; range() would raise on a float.
    num_titles = max(1, int(num_titles))

    # Tokenize once without truncation; chunking is done manually below.
    enc = tokenizer(["summarize: " + text], return_tensors="pt", truncation=False)
    ids = enc["input_ids"][0]
    mask = enc["attention_mask"][0]

    spans = _make_spans(len(ids))
    # Cycle through the spans so each title is seeded by a different part
    # of the article whenever possible.
    chosen = [spans[i % len(spans)] for i in range(num_titles)]
    batch = {
        "input_ids": torch.stack([ids[b0:b1] for b0, b1 in chosen]).to(device),
        "attention_mask": torch.stack([mask[b0:b1] for b0, b1 in chosen]).to(device),
    }

    with torch.no_grad():
        outputs = model.generate(
            **batch,
            do_sample=True,  # sampling (not beams) so repeated spans still differ
            temperature=float(temperature),
            max_length=GEN_MAX_LEN,
            num_beams=1,
        )

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # Keep only the first sentence of each generation as the title.
    return [_first_sentence(d) for d in decoded]


demo = gr.Interface(
    fn=generate_titles,
    inputs=[
        gr.Textbox(label="Article text", lines=10, placeholder="Paste your article text here"),
        gr.Slider(1, 10, value=3, step=1, label="Number of titles"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature"),
    ],
    outputs=gr.List(label="Generated titles"),
    title="T5 Title Generator",
    description="Generate candidate titles for articles using your fine-tuned T5 model.",
)

if __name__ == "__main__":
    demo.launch()