Ilyakk commited on
Commit
296eba9
·
verified ·
1 Parent(s): df6c3a8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -38
app.py CHANGED
@@ -1,53 +1,64 @@
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- import nltk
4
- import torch
5
- import math
6
 
7
- model_name = "Ilyakk/t5-summarization"
8
- tokenizer = AutoTokenizer.from_pretrained(model_name)
9
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
10
- nltk.download("punkt")
11
 
12
- def generate_titles(text, num_titles=3, temperature=0.7):
13
- # tokenize text
14
- inputs = ["summarize: " + text]
15
- inputs = tokenizer(inputs, return_tensors="pt")
16
 
17
- num_tokens = len(inputs["input_ids"][0])
18
- max_input_length = 512
19
- num_spans = math.ceil(num_tokens / max_input_length)
20
- overlap = math.ceil((num_spans * max_input_length - num_tokens) / max(num_spans - 1, 1))
21
 
22
- spans_boundaries = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  start = 0
24
  for i in range(num_spans):
25
- spans_boundaries.append([start + max_input_length * i, start + max_input_length * (i + 1)])
 
 
26
  start -= overlap
27
 
28
- spans_boundaries_selected = []
29
- j = 0
30
- for _ in range(num_titles):
31
- spans_boundaries_selected.append(spans_boundaries[j])
32
- j += 1
33
- if j == len(spans_boundaries):
34
- j = 0
35
-
36
- tensor_ids = [inputs["input_ids"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
37
- tensor_masks = [inputs["attention_mask"][0][boundary[0]:boundary[1]] for boundary in spans_boundaries_selected]
38
 
39
- inputs = {
40
- "input_ids": torch.stack(tensor_ids),
41
- "attention_mask": torch.stack(tensor_masks),
42
- }
43
 
44
- outputs = model.generate(**inputs, do_sample=True, temperature=temperature)
45
- decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
46
- predicted_titles = [nltk.sent_tokenize(decoded_output.strip())[0] for decoded_output in decoded_outputs]
 
 
 
 
 
47
 
48
- return predicted_titles
 
 
49
 
50
- # Gradio interface
51
  demo = gr.Interface(
52
  fn=generate_titles,
53
  inputs=[
@@ -56,8 +67,8 @@ demo = gr.Interface(
56
  gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature"),
57
  ],
58
  outputs=gr.List(label="Generated titles"),
59
- title="📰 T5 Title Generator",
60
- description="Generate candidate titles for articles using a fine-tuned T5 model."
61
  )
62
 
63
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
+ import nltk, math, torch
 
 
4
 
5
+ MODEL_ID = "ilyakk/t5-summarization" \
6
+ MAX_INPUT_LEN = 512
 
 
7
 
8
+ \
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
10
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
 
11
 
 
 
 
 
12
 
13
+ try:
14
+ nltk.data.find("tokenizers/punkt")
15
+ except LookupError:
16
+ nltk.download("punkt")
17
+
18
+ def generate_titles(text: str, num_titles: int = 3, temperature: float = 0.7):
19
+ text = (text or "").strip()
20
+ if not text:
21
+ return ["Введите текст статьи выше."]
22
+
23
+
24
+ enc = tokenizer(["summarize: " + text], return_tensors="pt")
25
+ ids = enc["input_ids"][0]
26
+ mask = enc["attention_mask"][0]
27
+
28
+
29
+ num_tokens = len(ids)
30
+ num_spans = max(1, math.ceil(num_tokens / MAX_INPUT_LEN))
31
+ overlap = math.ceil((num_spans * MAX_INPUT_LEN - num_tokens) / max(num_spans - 1, 1)) if num_spans > 1 else 0
32
+
33
+
34
+ spans = []
35
  start = 0
36
  for i in range(num_spans):
37
+ b0 = start + MAX_INPUT_LEN * i
38
+ b1 = start + MAX_INPUT_LEN * (i + 1)
39
+ spans.append([max(0, b0), min(num_tokens, b1)])
40
  start -= overlap
41
 
42
+
43
+ chosen = [spans[i % len(spans)] for i in range(num_titles)]
 
 
 
 
 
 
 
 
44
 
45
+ batch_ids = [ids[b0:b1] for (b0, b1) in chosen]
46
+ batch_mask = [mask[b0:b1] for (b0, b1) in chosen]
47
+ batch = {"input_ids": torch.stack(batch_ids), "attention_mask": torch.stack(batch_mask)}
 
48
 
49
+ with torch.no_grad():
50
+ outputs = model.generate(
51
+ **batch,
52
+ do_sample=True,
53
+ temperature=float(temperature),
54
+ max_length=64,
55
+ num_beams=1
56
+ )
57
 
58
+ decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
59
+ titles = [ (nltk.sent_tokenize(d.strip())[0] if d.strip() else "").strip() for d in decoded ]
60
+ return titles
61
 
 
62
  demo = gr.Interface(
63
  fn=generate_titles,
64
  inputs=[
 
67
  gr.Slider(0.1, 1.5, value=0.7, step=0.05, label="Temperature"),
68
  ],
69
  outputs=gr.List(label="Generated titles"),
70
+ title="T5 Title Generator",
71
+ description="Generate candidate titles for articles using your fine-tuned T5 model."
72
  )
73
 
74
  if __name__ == "__main__":