File size: 4,429 Bytes
8a6a35d
 
4cd37a6
8a6a35d
 
4cd37a6
8a6a35d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import os
import re
import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ==== настройки ====
# Можно задать через переменную окружения MODEL_ID в Settings → Repository secrets.
MODEL_ID = os.environ.get("MODEL_ID", "Spyspook/my-t5-medium-summarizer")

MAX_INPUT_LENGTH = 512     # вход в токенах
MAX_TARGET_LENGTH = 64     # длина заголовка
DEFAULT_NUM_TITLES = 3
DEFAULT_TEMPERATURE = 0.7
DEFAULT_BEAMS = 4          # для стабильности метрик можно 4; для разнообразия ставь do_sample=True

device = "cuda" if torch.cuda.is_available() else "cpu"

# ==== загрузка модели/токенизатора один раз ====
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device)
model.eval()

# Простая функция выделения первой фразы без NLTK (чтобы не тянуть ресурсы в Space)
_SENT_END_RE = re.compile(r"([.!?])\s+")

def first_sentence(text: str) -> str:
    text = text.strip()
    if not text:
        return text
    parts = _SENT_END_RE.split(text, maxsplit=1)
    # parts = [before, sep, after] или просто [text]
    if len(parts) >= 2:
        return (parts[0] + parts[1]).strip()
    return text

def generate_titles(article_text, num_titles, temperature, beams, do_sample):
    if not article_text or not article_text.strip():
        return []

    # Префикс для T5
    prefixed = "summarize: " + article_text.strip()

    # Токенизация и обрезка по контексту
    inputs = tokenizer(
        prefixed,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_LENGTH,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    gen_kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        num_return_sequences=int(num_titles),
        early_stopping=True,
    )

    # Логика генерации:
    # - Если do_sample=True → семплирование (temperature, top_p),
    # - иначе — детерминированный beam search.
    if do_sample:
        gen_kwargs.update(
            dict(
                do_sample=True,
                temperature=float(temperature),
                top_p=0.95,
                num_beams=1,
            )
        )
    else:
        gen_kwargs.update(
            dict(
                do_sample=False,
                num_beams=int(beams),
                length_penalty=1.0,
            )
        )

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Берем первую фразу и убираем дубликаты, сохраняя порядок
    seen = set()
    titles = []
    for t in decoded:
        t1 = first_sentence(t)
        if t1 and t1 not in seen:
            seen.add(t1)
            titles.append(t1)

    # Вернем как список списков для удобной таблицы
    return [[t] for t in titles]

# ==== интерфейс Gradio ====
with gr.Blocks() as demo:
    gr.Markdown("## T5 Article Title Generator")

    with gr.Row():
        text_in = gr.Textbox(
            label="Article text",
            placeholder="Paste article text here…",
            lines=14,
        )

    with gr.Row():
        num_titles = gr.Slider(1, 10, value=DEFAULT_NUM_TITLES, step=1, label="Number of titles")
        temperature = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature (sampling)")
    with gr.Row():
        beams = gr.Slider(1, 8, value=DEFAULT_BEAMS, step=1, label="Beams (if sampling is OFF)")
        do_sample = gr.Checkbox(value=True, label="Use sampling (ON) / Beam search (OFF)")

    generate_btn = gr.Button("Generate")
    out_table = gr.Dataframe(headers=["Title"], row_count=(0, "dynamic"), wrap=True)

    generate_btn.click(
        fn=generate_titles,
        inputs=[text_in, num_titles, temperature, beams, do_sample],
        outputs=out_table,
        api_name="generate",
    )

# Для HF Spaces достаточно экспортировать переменную приложения
app = demo

if __name__ == "__main__":
    demo.launch()