Spyspook committed on
Commit
8a6a35d
·
verified ·
1 Parent(s): 4cd37a6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +127 -99
  2. requirements.txt +3 -2
app.py CHANGED
@@ -1,101 +1,129 @@
1
- import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- import nltk
4
- import math
5
  import torch
 
 
6
 
7
- model_name = "AGIvan/t5-base-title-generation"
8
- max_input_length = 512
9
-
10
- st.header("Generate candidate titles for articles")
11
-
12
- st_model_load = st.text('Loading title generator model...')
13
-
14
- @st.cache_data()
15
- def load_model():
16
- print("Loading model...")
17
- tokenizer = AutoTokenizer.from_pretrained(model_name)
18
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
19
- nltk.download('punkt')
20
- print("Model loaded!")
21
- return tokenizer, model
22
-
23
- tokenizer, model = load_model()
24
- st.success('Model loaded!')
25
- st_model_load.text("")
26
-
27
- with st.sidebar:
28
- st.header("Model parameters")
29
- if 'num_titles' not in st.session_state:
30
- st.session_state.num_titles = 5
31
- def on_change_num_titles():
32
- st.session_state.num_titles = num_titles
33
- num_titles = st.slider("Number of titles to generate", min_value=1, max_value=10, value=1, step=1, on_change=on_change_num_titles)
34
- if 'temperature' not in st.session_state:
35
- st.session_state.temperature = 0.7
36
- def on_change_temperatures():
37
- st.session_state.temperature = temperature
38
- temperature = st.slider("Temperature", min_value=0.1, max_value=1.5, value=0.6, step=0.05, on_change=on_change_temperatures)
39
- st.markdown("_High temperature means that results are more random_")
40
-
41
- if 'text' not in st.session_state:
42
- st.session_state.text = ""
43
- st_text_area = st.text_area('Text to generate the title for', value=st.session_state.text, height=500)
44
-
45
- def generate_title():
46
- st.session_state.text = st_text_area
47
-
48
- # tokenize text
49
- inputs = ["summarize: " + st_text_area]
50
- inputs = tokenizer(inputs, return_tensors="pt")
51
-
52
- # compute span boundaries
53
- num_tokens = len(inputs["input_ids"][0])
54
- print(f"Input has {num_tokens} tokens")
55
- max_input_length = 512
56
- num_spans = math.ceil(num_tokens / max_input_length)
57
- print(f"Input has {num_spans} spans")
58
- overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
59
- spans_boundaries = []
60
- start = 0
61
- for i in range(num_spans):
62
- spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
63
- start -= overlap
64
- print(f"Span boundaries are {spans_boundaries}")
65
- spans_boundaries_selected = []
66
- j = 0
67
- for _ in range(num_titles):
68
- spans_boundaries_selected.append(spans_boundaries[j])
69
- j += 1
70
- if j == len(spans_boundaries):
71
- j = 0
72
- print(f"Selected span boundaries are {spans_boundaries_selected}")
73
-
74
- # transform input with spans
75
- tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
76
- tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
77
-
78
- inputs = {
79
- "input_ids": torch.stack(tensor_ids),
80
- "attention_mask": torch.stack(tensor_masks)
81
- }
82
-
83
- # compute predictions
84
- outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
85
- decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
86
- predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
87
-
88
- st.session_state.titles = predicted_titles
89
-
90
- # generate title button
91
- st_generate_button = st.button('Generate title', on_click=generate_title)
92
-
93
- # title generation labels
94
- if 'titles' not in st.session_state:
95
- st.session_state.titles = []
96
-
97
- if len(st.session_state.titles) > 0:
98
- with st.container():
99
- st.subheader("Generated titles")
100
- for title in st.session_state.titles:
101
- st.markdown("__" + title + "__")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import re

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# ==== settings ====
# Can be overridden via the MODEL_ID environment variable in Settings → Repository secrets.
MODEL_ID = os.environ.get("MODEL_ID", "Spyspook/my-t5-medium-summarizer")

MAX_INPUT_LENGTH = 512  # input length in tokens
MAX_TARGET_LENGTH = 64  # max title length in tokens
DEFAULT_NUM_TITLES = 3
DEFAULT_TEMPERATURE = 0.7
DEFAULT_BEAMS = 4  # 4 gives stable metrics; for diversity enable do_sample=True

device = "cuda" if torch.cuda.is_available() else "cpu"

# ==== load model/tokenizer once at import time ====
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID).to(device)
model.eval()
# Lightweight first-sentence extraction without NLTK (avoids downloading
# tokenizer resources inside the Space).
_SENT_END_RE = re.compile(r"([.!?])\s+")

def first_sentence(text: str) -> str:
    """Return the first sentence of *text*, including its end punctuation.

    If no sentence terminator followed by whitespace is found, the whole
    stripped text is returned unchanged.
    """
    stripped = text.strip()
    if not stripped:
        return stripped
    match = _SENT_END_RE.search(stripped)
    if match is None:
        return stripped
    # Slice up to and including the captured punctuation character.
    return stripped[: match.end(1)].strip()
36
+
37
def generate_titles(article_text, num_titles, temperature, beams, do_sample):
    """Generate candidate titles for *article_text* with the loaded T5 model.

    Parameters:
        article_text: source article text; empty/whitespace input yields [].
        num_titles: number of candidate titles to request.
        temperature: sampling temperature (used only when do_sample is True).
        beams: beam count for beam search (used only when do_sample is False).
        do_sample: True → sampling (temperature, top_p); False → beam search.

    Returns:
        A list of single-element lists (rows for a gr.Dataframe), deduplicated
        while preserving generation order.
    """
    if not article_text or not article_text.strip():
        return []

    # T5 task prefix expected by the summarization checkpoint.
    prefixed = "summarize: " + article_text.strip()

    # Tokenize and truncate to the model's context window.
    inputs = tokenizer(
        prefixed,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_LENGTH,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    num_titles = int(num_titles)
    gen_kwargs = dict(
        max_length=MAX_TARGET_LENGTH,
        num_return_sequences=num_titles,
    )

    # Generation strategy:
    # - do_sample=True → stochastic sampling (temperature, top_p),
    # - otherwise → deterministic beam search.
    if do_sample:
        gen_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_p=0.95,
            num_beams=1,
        )
    else:
        # BUG FIX: transformers' generate() raises a ValueError when
        # num_return_sequences > num_beams, which happened whenever the
        # "titles" slider (up to 10) exceeded the beams slider (default 4).
        # Clamp the beam count up to the requested number of sequences.
        # early_stopping applies only to beam search, so it is set here
        # (with sampling it merely triggers a warning).
        gen_kwargs.update(
            do_sample=False,
            num_beams=max(int(beams), num_titles),
            length_penalty=1.0,
            early_stopping=True,
        )

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)

    # Keep the first sentence of each candidate and drop duplicates,
    # preserving generation order.
    seen = set()
    titles = []
    for t in decoded:
        t1 = first_sentence(t)
        if t1 and t1 not in seen:
            seen.add(t1)
            titles.append(t1)

    # Return rows ([title] per row) for a convenient gr.Dataframe table.
    return [[t] for t in titles]
96
+
97
# ==== Gradio interface ====
with gr.Blocks() as demo:
    gr.Markdown("## T5 Article Title Generator")

    with gr.Row():
        text_in = gr.Textbox(
            label="Article text",
            placeholder="Paste article text here…",
            lines=14,
        )

    with gr.Row():
        num_titles = gr.Slider(1, 10, value=DEFAULT_NUM_TITLES, step=1, label="Number of titles")
        temperature = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, step=0.05, label="Temperature (sampling)")
    with gr.Row():
        beams = gr.Slider(1, 8, value=DEFAULT_BEAMS, step=1, label="Beams (if sampling is OFF)")
        do_sample = gr.Checkbox(value=True, label="Use sampling (ON) / Beam search (OFF)")

    generate_btn = gr.Button("Generate")
    out_table = gr.Dataframe(headers=["Title"], row_count=(0, "dynamic"), wrap=True)

    generate_btn.click(
        fn=generate_titles,
        inputs=[text_in, num_titles, temperature, beams, do_sample],
        outputs=out_table,
        api_name="generate",
    )

# For HF Spaces it is enough to export the application variable
app = demo

if __name__ == "__main__":
    demo.launch()
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
- nltk
2
  torch
3
- transformers
 
 
 
 
1
  torch
2
+ transformers
3
+ gradio
4
+ sentencepiece