import math

import gradio as gr
import nltk
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

MODEL_ID = "Ilyakk/t5-summarization"
MAX_INPUT_LEN = 512
GEN_MAX_LEN = 64
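
# The encoder sees at most MAX_INPUT_LEN tokens at a time, so generate_titles()
# below splits long articles into overlapping MAX_INPUT_LEN-token windows and
# generates one candidate title per window; GEN_MAX_LEN caps each title's length.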


def ensure_nltk():
    try:
        nltk.data.find("tokenizers/punkt")
    except LookupError:
        nltk.download("punkt")
    try:
        nltk.data.find("tokenizers/punkt_tab")
    except LookupError:
        nltk.download("punkt_tab")


ensure_nltk()
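# Newer NLTK releases (3.9+) look up "punkt_tab" in place of the pickle-based
# "punkt", so ensure_nltk() requests both to stay compatible across versions.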

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()  # disable dropout so generation is not degraded by training-mode noise


def _first_sentence(text: str) -> str:
    text = (text or "").strip()
    if not text:
        return ""
    try:
        sents = nltk.sent_tokenize(text)
        return sents[0].strip() if sents else text
    except Exception:
        for sep in [".", "!", "?"]:
            if sep in text:
                return text.split(sep)[0].strip()
        return text


def generate_titles(text: str, num_titles: int = 3, temperature: float = 0.7):
    text = (text or "").strip()
    if not text:
        return ["Enter the article text above."]
    enc = tokenizer(["summarize: " + text], return_tensors="pt", truncation=False)
    ids = enc["input_ids"][0]
    mask = enc["attention_mask"][0]
    num_tokens = len(ids)
    num_spans = max(1, math.ceil(num_tokens / MAX_INPUT_LEN))
    overlap = math.ceil((num_spans * MAX_INPUT_LEN - num_tokens) / max(num_spans - 1, 1)) if num_spans > 1 else 0
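    # Worked example (illustrative numbers, not from the app): with
    # num_tokens = 1200, num_spans = ceil(1200 / 512) = 3 and
    # overlap = ceil((3 * 512 - 1200) / 2) = 168, the loop below yields spans
    # [0, 512], [344, 856], [688, 1200]: adjacent windows share 168 tokens
    # and together cover every input token.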
    spans = []
    start = 0
    for i in range(num_spans):
        b0 = start + MAX_INPUT_LEN * i
        b1 = start + MAX_INPUT_LEN * (i + 1)
        spans.append([max(0, b0), min(num_tokens, b1)])
        start -= overlap
    chosen = [spans[i % len(spans)] for i in range(num_titles)]
    batch_ids = [ids[b0:b1] for (b0, b1) in chosen]
    batch_mask = [mask[b0:b1] for (b0, b1) in chosen]
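    # torch.stack requires equal-length rows; by construction every span above
    # is exactly MAX_INPUT_LEN tokens (or the whole sequence when num_spans == 1),
    # and chosen only repeats those spans, so no padding is needed before batching.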
    batch = {
        "input_ids": torch.stack(batch_ids).to(device),
        "attention_mask": torch.stack(batch_mask).to(device),
    }
    with torch.no_grad():
        outputs = model.generate(
            **batch,
            do_sample=True,
            temperature=float(temperature),
            max_length=GEN_MAX_LEN,
            num_beams=1,
        )
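    # Sampling (do_sample=True, num_beams=1) keeps decoding stochastic, so spans
    # repeated in the batch still tend to produce distinct candidate titles.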
    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    titles = [_first_sentence(d) for d in decoded]
    return titles
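
# Quick smoke test without the UI (hypothetical article text):
#   print(generate_titles("Transformer models process text as sequences of tokens ...", num_titles=2))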

demo = gr.Interface(
    fn=generate_titles,
    inputs=[
        gr.Textbox(label="Article text", lines=10, placeholder="Paste your article text here"),
        gr.Slider(1, 10, value=3, step=1, label="Number of titles"),
        gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature"),
    ],
    outputs=gr.List(label="Generated titles"),
    title="T5 Title Generator",
    description="Generate candidate titles for articles using your fine-tuned T5 model.",
)

if __name__ == "__main__":
    demo.launch()