t5_demo / app.py
Степан
fix
604a021
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import nltk
import math
import torch
import re
model_name = "AGIvan/t5-base-title-generation"
max_input_length = 512
def extract_first_sentence(text: str) -> str:
"""Извлекает первое предложение из текста без NLTK."""
text = text.strip()
if not text:
return text
# Ищем конец первого предложения: точка/воскл./вопр. знак + пробел или конец строки
match = re.search(r'(.*?[.!?])(?:\s|$)', text)
return match.group(1) if match else text.split('\n')[0]
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def generate_titles(text, num_titles, temperature):
if not text.strip():
return "Пожалуйста, введите текст для генерации заголовков"
inputs = ["summarize: " + text]
inputs = tokenizer(inputs, return_tensors="pt")
num_tokens = inputs["input_ids"].shape[1]
num_spans = math.ceil(num_tokens / max_input_length)
overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1)) if num_spans > 1 else 0
spans_boundaries = []
start = 0
for i in range(num_spans):
end = start + max_input_length
spans_boundaries.append([start, end])
start = end - overlap
spans_boundaries_selected = []
j = 0
for _ in range(num_titles):
spans_boundaries_selected.append(spans_boundaries[j % len(spans_boundaries)])
j += 1
tensor_ids = []
tensor_masks = []
for boundary in spans_boundaries_selected:
span_ids = inputs["input_ids"][0][boundary[0]:boundary[1]].unsqueeze(0)
span_mask = inputs["attention_mask"][0][boundary[0]:boundary[1]].unsqueeze(0)
tensor_ids.append(span_ids)
tensor_masks.append(span_mask)
model_inputs = {
"input_ids": torch.cat(tensor_ids, dim=0),
"attention_mask": torch.cat(tensor_masks, dim=0)
}
outputs = model.generate(
**model_inputs,
do_sample=True,
temperature=temperature,
max_length=50
)
decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
predicted_titles = []
for output in decoded_outputs:
# Заменяем NLTK на кастомную функцию
title = extract_first_sentence(output)
predicted_titles.append(title)
html_output = "<div style='font-size:16px; padding:10px;'>"
for i, title in enumerate(predicted_titles, 1):
html_output += f"<p><b>Вариант {i}:</b> {title}</p>"
html_output += "</div>"
return html_output
# Создание интерфейса
with gr.Blocks(title="Генератор заголовков статей") as interface:
gr.Markdown("# 🗞 Генератор заголовков для статей")
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("## ⚙ Настройки модели")
num_titles = gr.Slider(1, 10, value=5, step=1, label="Количество заголовков")
temperature = gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Случайность (temperature)")
gr.Markdown("🔺 Большее значение = более креативные результаты")
with gr.Column(scale=3):
text_input = gr.Textbox(label="Введите текст статьи", lines=15, placeholder="Вставьте текст статьи здесь...")
generate_btn = gr.Button("Сгенерировать заголовки", variant="primary")
output = gr.HTML(label="Результаты генерации")
generate_btn.click(
generate_titles,
inputs=[text_input, num_titles, temperature],
outputs=output
)
# Запуск приложения
if __name__ == "__main__":
interface.launch()