Spyspook commited on
Commit
06a4696
·
verified ·
1 Parent(s): 6599a18

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +129 -0
app.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import torch
4
+ import gradio as gr
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
+
7
+ # ==== настройки ====
8
+ # Можно задать через переменную окружения MODEL_ID в Settings → Repository secrets.
9
+ MODEL_ID = os.environ.get("MODEL_ID", "Spyspook/my-t5-medium-summarizer")
10
+
11
+ MAX_INPUT_LENGTH = 512 # вход в токенах
12
+ MAX_TARGET_LENGTH = 64 # длина заголовка
13
+ DEFAULT_NUM_TITLES = 3
14
+ DEFAULT_TEMPERATURE = 0.7
15
+ DEFAULT_BEAMS = 4 # для стабильности метрик можно 4; для разнообразия ставь do_sample=True
16
+
17
+ device = "cuda" if torch.cuda.is_available() else "cpu"
18
+
19
+ # ==== загрузка модели/токенизатора один раз ====
20
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
21
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device)
22
+ model.eval()
23
+
24
+ # Простая функция выделения первой фразы без NLTK (чтобы не тянуть ресурсы в Space)
25
+ _SENT_END_RE = re.compile(r"([.!?])\s+")
26
+
27
+ def first_sentence(text: str) -> str:
28
+ text = text.strip()
29
+ if not text:
30
+ return text
31
+ parts = _SENT_END_RE.split(text, maxsplit=1)
32
+ # parts = [before, sep, after] или просто [text]
33
+ if len(parts) >= 2:
34
+ return (parts[0] + parts[1]).strip()
35
+ return text
36
+
37
+ def generate_titles(article_text, num_titles, temperature, beams, do_sample):
38
+ if not article_text or not article_text.strip():
39
+ return []
40
+
41
+ # Префикс для T5
42
+ prefixed = "summarize: " + article_text.strip()
43
+
44
+ # Токенизация и обрезка по контексту
45
+ inputs = tokenizer(
46
+ prefixed,
47
+ return_tensors="pt",
48
+ truncation=True,
49
+ max_length=MAX_INPUT_LENGTH,
50
+ )
51
+ inputs = {k: v.to(device) for k, v in inputs.items()}
52
+
53
+ gen_kwargs = dict(
54
+ max_length=MAX_TARGET_LENGTH,
55
+ num_return_sequences=int(num_titles),
56
+ early_stopping=True,
57
+ )
58
+
59
+ # Логика генерации:
60
+ # - Если do_sample=True → семплирование (temperature, top_p),
61
+ # - иначе — детерминированный beam search.
62
+ if do_sample:
63
+ gen_kwargs.update(
64
+ dict(
65
+ do_sample=True,
66
+ temperature=float(temperature),
67
+ top_p=0.95,
68
+ num_beams=1,
69
+ )
70
+ )
71
+ else:
72
+ gen_kwargs.update(
73
+ dict(
74
+ do_sample=False,
75
+ num_beams=int(beams),
76
+ length_penalty=1.0,
77
+ )
78
+ )
79
+
80
+ with torch.no_grad():
81
+ outputs = model.generate(**inputs, **gen_kwargs)
82
+
83
+ decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
84
+
85
+ # Берем первую фразу и убираем дубликаты, сохраняя порядок
86
+ seen = set()
87
+ titles = []
88
+ for t in decoded:
89
+ t1 = first_sentence(t)
90
+ if t1 and t1 not in seen:
91
+ seen.add(t1)
92
+ titles.append(t1)
93
+
94
+ # Вернем как список списков для удобной таблицы
95
+ return [[t] for t in titles]
96
+
97
+ # ==== интерфейс Gradio ====
98
+ with gr.Blocks() as demo:
99
+ gr.Markdown("## T5 Article Title Generator")
100
+
101
+ with gr.Row():
102
+ text_in = gr.Textbox(
103
+ label="Article text",
104
+ placeholder="Paste article text here…",
105
+ lines=14,
106
+ )
107
+
108
+ with gr.Row():
109
+ num_titles = gr.Slider(1, 10, value=DEFAULT_NUM_TITLES, step=1, label="Number of titles")
110
+ temperature = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature (sampling)")
111
+ with gr.Row():
112
+ beams = gr.Slider(1, 8, value=DEFAULT_BEAMS, step=1, label="Beams (if sampling is OFF)")
113
+ do_sample = gr.Checkbox(value=True, label="Use sampling (ON) / Beam search (OFF)")
114
+
115
+ generate_btn = gr.Button("Generate")
116
+ out_table = gr.Dataframe(headers=["Title"], row_count=(0, "dynamic"), wrap=True)
117
+
118
+ generate_btn.click(
119
+ fn=generate_titles,
120
+ inputs=[text_in, num_titles, temperature, beams, do_sample],
121
+ outputs=out_table,
122
+ api_name="generate",
123
+ )
124
+
125
+ # Для HF Spaces достаточно экспортировать переменную приложения
126
+ app = demo
127
+
128
+ if __name__ == "__main__":
129
+ demo.launch()