File size: 4,429 Bytes
8a6a35d 4cd37a6 8a6a35d 4cd37a6 8a6a35d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | import os
import re
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# ==== настройки ====
# Можно задать через переменную окружения MODEL_ID в Settings → Repository secrets.
MODEL_ID = os.environ.get("MODEL_ID", "Spyspook/my-t5-medium-summarizer")
MAX_INPUT_LENGTH = 512 # вход в токенах
MAX_TARGET_LENGTH = 64 # длина заголовка
DEFAULT_NUM_TITLES = 3
DEFAULT_TEMPERATURE = 0.7
DEFAULT_BEAMS = 4 # для стабильности метрик можно 4; для разнообразия ставь do_sample=True
device = "cuda" if torch.cuda.is_available() else "cpu"
# ==== загрузка модели/токенизатора один раз ====
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device)
model.eval()
# Простая функция выделения первой фразы без NLTK (чтобы не тянуть ресурсы в Space)
_SENT_END_RE = re.compile(r"([.!?])\s+")
def first_sentence(text: str) -> str:
text = text.strip()
if not text:
return text
parts = _SENT_END_RE.split(text, maxsplit=1)
# parts = [before, sep, after] или просто [text]
if len(parts) >= 2:
return (parts[0] + parts[1]).strip()
return text
def generate_titles(article_text, num_titles, temperature, beams, do_sample):
if not article_text or not article_text.strip():
return []
# Префикс для T5
prefixed = "summarize: " + article_text.strip()
# Токенизация и обрезка по контексту
inputs = tokenizer(
prefixed,
return_tensors="pt",
truncation=True,
max_length=MAX_INPUT_LENGTH,
)
inputs = {k: v.to(device) for k, v in inputs.items()}
gen_kwargs = dict(
max_length=MAX_TARGET_LENGTH,
num_return_sequences=int(num_titles),
early_stopping=True,
)
# Логика генерации:
# - Если do_sample=True → семплирование (temperature, top_p),
# - иначе — детерминированный beam search.
if do_sample:
gen_kwargs.update(
dict(
do_sample=True,
temperature=float(temperature),
top_p=0.95,
num_beams=1,
)
)
else:
gen_kwargs.update(
dict(
do_sample=False,
num_beams=int(beams),
length_penalty=1.0,
)
)
with torch.no_grad():
outputs = model.generate(**inputs, **gen_kwargs)
decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
# Берем первую фразу и убираем дубликаты, сохраняя порядок
seen = set()
titles = []
for t in decoded:
t1 = first_sentence(t)
if t1 and t1 not in seen:
seen.add(t1)
titles.append(t1)
# Вернем как список списков для удобной таблицы
return [[t] for t in titles]
# ==== интерфейс Gradio ====
with gr.Blocks() as demo:
gr.Markdown("## T5 Article Title Generator")
with gr.Row():
text_in = gr.Textbox(
label="Article text",
placeholder="Paste article text here…",
lines=14,
)
with gr.Row():
num_titles = gr.Slider(1, 10, value=DEFAULT_NUM_TITLES, step=1, label="Number of titles")
temperature = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature (sampling)")
with gr.Row():
beams = gr.Slider(1, 8, value=DEFAULT_BEAMS, step=1, label="Beams (if sampling is OFF)")
do_sample = gr.Checkbox(value=True, label="Use sampling (ON) / Beam search (OFF)")
generate_btn = gr.Button("Generate")
out_table = gr.Dataframe(headers=["Title"], row_count=(0, "dynamic"), wrap=True)
generate_btn.click(
fn=generate_titles,
inputs=[text_in, num_titles, temperature, beams, do_sample],
outputs=out_table,
api_name="generate",
)
# Для HF Spaces достаточно экспортировать переменную приложения
app = demo
if __name__ == "__main__":
demo.launch()
|