# Hugging Face Space: T5-based article title generator (page status: Sleeping)
import html
import math
import re

import gradio as gr
import nltk
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Hugging Face Hub ID of the fine-tuned T5 checkpoint used for title generation.
model_name = "AGIvan/t5-base-title-generation"
# Maximum number of tokens fed to the model per input span.
max_input_length = 512
def extract_first_sentence(text: str) -> str:
    """Extract the first sentence from *text* without relying on NLTK.

    A sentence ends at the first '.', '!' or '?' that is followed by
    whitespace or the end of the string. When no such terminator exists,
    the first line of the stripped text is returned instead.
    """
    stripped = text.strip()
    if not stripped:
        return stripped
    # Non-greedy scan up to the first sentence-ending punctuation mark.
    found = re.search(r'(.*?[.!?])(?:\s|$)', stripped)
    if found:
        return found.group(1)
    return stripped.split('\n')[0]
# Load the tokenizer and seq2seq model once at import time so every
# request reuses the same weights (downloads from the Hub on first run).
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def generate_titles(text, num_titles, temperature):
    """Generate candidate titles for *text* with the T5 model.

    Args:
        text: Article body to generate titles from.
        num_titles: Number of title variants to produce. Gradio sliders
            can deliver this as a float even with ``step=1``, so it is
            coerced to ``int`` before use.
        temperature: Sampling temperature passed to ``model.generate``.

    Returns:
        An HTML fragment listing the generated titles, or a plain
        prompt message (in Russian) when *text* is empty.
    """
    if not text.strip():
        return "Пожалуйста, введите текст для генерации заголовков"

    # Slider values may arrive as floats; range() below needs an int.
    num_titles = int(num_titles)

    inputs = tokenizer(["summarize: " + text], return_tensors="pt")
    num_tokens = inputs["input_ids"].shape[1]

    # Split long inputs into overlapping spans of at most max_input_length
    # tokens so nothing past the model's context window is silently dropped.
    num_spans = math.ceil(num_tokens / max_input_length)
    if num_spans > 1:
        overlap = math.ceil(
            (num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1)
        )
    else:
        overlap = 0

    spans_boundaries = []
    start = 0
    for _ in range(num_spans):
        end = start + max_input_length
        spans_boundaries.append([start, end])
        start = end - overlap

    # Cycle through the spans so each requested title gets an input chunk.
    spans_boundaries_selected = [
        spans_boundaries[j % len(spans_boundaries)] for j in range(num_titles)
    ]

    # Slice the tokenized input into one batch row per requested title.
    tensor_ids = [
        inputs["input_ids"][0][lo:hi].unsqueeze(0)
        for lo, hi in spans_boundaries_selected
    ]
    tensor_masks = [
        inputs["attention_mask"][0][lo:hi].unsqueeze(0)
        for lo, hi in spans_boundaries_selected
    ]
    model_inputs = {
        "input_ids": torch.cat(tensor_ids, dim=0),
        "attention_mask": torch.cat(tensor_masks, dim=0),
    }

    # Inference only — disable gradient tracking to save memory/time.
    with torch.no_grad():
        outputs = model.generate(
            **model_inputs,
            do_sample=True,
            temperature=temperature,
            max_length=50,
        )

    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    # The first sentence of each decoded output serves as the title
    # (custom helper instead of NLTK sentence tokenization).
    predicted_titles = [extract_first_sentence(out) for out in decoded_outputs]

    # Escape model output before embedding it in HTML — it is untrusted
    # text rendered via gr.HTML.
    parts = ["<div style='font-size:16px; padding:10px;'>"]
    for i, title in enumerate(predicted_titles, 1):
        parts.append(f"<p><b>Вариант {i}:</b> {html.escape(title)}</p>")
    parts.append("</div>")
    return "".join(parts)
# Build the Gradio UI: a settings column (title count, temperature)
# beside the article input, with HTML results rendered underneath.
with gr.Blocks(title="Генератор заголовков статей") as interface:
    gr.Markdown("# 🗞 Генератор заголовков для статей")
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("## ⚙ Настройки модели")
            # How many title variants to sample per click (1..10).
            num_titles = gr.Slider(1, 10, value=5, step=1, label="Количество заголовков")
            # Sampling temperature: higher values yield more varied titles.
            temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Случайность (temperature)")
            gr.Markdown("🔺 Большее значение = более креативные результаты")
        with gr.Column(scale=3):
            text_input = gr.Textbox(label="Введите текст статьи", lines=15, placeholder="Вставьте текст статьи здесь...")
            generate_btn = gr.Button("Сгенерировать заголовки", variant="primary")
            output = gr.HTML(label="Результаты генерации")
    # Wire the button to the generation function.
    generate_btn.click(
        generate_titles,
        inputs=[text_input, num_titles, temperature],
        outputs=output
    )
# Launch the app only when executed as a script (not on import).
if __name__ == "__main__":
    interface.launch()