Ilyakk committed on
Commit
24911b1
·
verified ·
1 Parent(s): 6dbe5b7

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +45 -20
  2. requirements.txt +1 -1
app.py CHANGED
@@ -1,36 +1,58 @@
 
 
1
  import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
- import nltk, math, torch
4
 
5
- MODEL_ID = "Ilyakk/t5-summarization"
6
- MAX_INPUT_LEN = 512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
 
9
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
10
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
 
 
11
 
12
-
13
- try:
14
- nltk.data.find("tokenizers/punkt")
15
- except LookupError:
16
- nltk.download("punkt")
 
 
 
 
 
 
 
17
 
18
  def generate_titles(text: str, num_titles: int = 3, temperature: float = 0.7):
19
  text = (text or "").strip()
20
  if not text:
21
  return ["Введите текст статьи выше."]
22
 
23
-
24
- enc = tokenizer(["summarize: " + text], return_tensors="pt")
25
- ids = enc["input_ids"][0]
26
  mask = enc["attention_mask"][0]
27
-
28
-
29
  num_tokens = len(ids)
 
30
  num_spans = max(1, math.ceil(num_tokens / MAX_INPUT_LEN))
31
  overlap = math.ceil((num_spans * MAX_INPUT_LEN - num_tokens) / max(num_spans - 1, 1)) if num_spans > 1 else 0
32
 
33
-
34
  spans = []
35
  start = 0
36
  for i in range(num_spans):
@@ -39,24 +61,26 @@ def generate_titles(text: str, num_titles: int = 3, temperature: float = 0.7):
39
  spans.append([max(0, b0), min(num_tokens, b1)])
40
  start -= overlap
41
 
42
-
43
  chosen = [spans[i % len(spans)] for i in range(num_titles)]
44
 
45
- batch_ids = [ids[b0:b1] for (b0, b1) in chosen]
46
  batch_mask = [mask[b0:b1] for (b0, b1) in chosen]
47
- batch = {"input_ids": torch.stack(batch_ids), "attention_mask": torch.stack(batch_mask)}
 
 
 
48
 
49
  with torch.no_grad():
50
  outputs = model.generate(
51
  **batch,
52
  do_sample=True,
53
  temperature=float(temperature),
54
- max_length=64,
55
- num_beams=1
56
  )
57
 
58
  decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
59
- titles = [ (nltk.sent_tokenize(d.strip())[0] if d.strip() else "").strip() for d in decoded ]
60
  return titles
61
 
62
  demo = gr.Interface(
@@ -72,4 +96,5 @@ demo = gr.Interface(
72
  )
73
 
74
  if __name__ == "__main__":
 
75
  demo.launch()
 
1
+ import math
2
+ import torch
3
  import gradio as gr
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
+ import nltk
6
 
7
+
8
+ MODEL_ID = "Ilyakk/t5-summarization"
9
+ MAX_INPUT_LEN = 512
10
+ GEN_MAX_LEN = 64
11
+
12
def ensure_nltk():
    """Make sure the NLTK sentence-tokenizer data is available.

    Each required resource is looked up with ``nltk.data.find``; when the
    lookup raises ``LookupError`` the matching package is fetched with
    ``nltk.download``. Resources are handled in a fixed order: ``punkt``
    first, then ``punkt_tab``.
    """
    # (lookup path, download package) pairs, checked in this order.
    required = (
        ("tokenizers/punkt", "punkt"),
        ("tokenizers/punkt_tab", "punkt_tab"),
    )
    for resource_path, package in required:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)
21
+
22
+ ensure_nltk()
23
 
24
 
25
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
26
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+ model.to(device)
29
 
30
+ def _first_sentence(text: str) -> str:
31
+ text = (text or "").strip()
32
+ if not text:
33
+ return ""
34
+ try:
35
+ sents = nltk.sent_tokenize(text)
36
+ return sents[0].strip() if sents else text
37
+ except Exception:
38
+ for sep in [".", "!", "?"]:
39
+ if sep in text:
40
+ return text.split(sep)[0].strip()
41
+ return text
42
 
43
  def generate_titles(text: str, num_titles: int = 3, temperature: float = 0.7):
44
  text = (text or "").strip()
45
  if not text:
46
  return ["Введите текст статьи выше."]
47
 
48
+ enc = tokenizer(["summarize: " + text], return_tensors="pt", truncation=False)
49
+ ids = enc["input_ids"][0]
 
50
  mask = enc["attention_mask"][0]
 
 
51
  num_tokens = len(ids)
52
+
53
  num_spans = max(1, math.ceil(num_tokens / MAX_INPUT_LEN))
54
  overlap = math.ceil((num_spans * MAX_INPUT_LEN - num_tokens) / max(num_spans - 1, 1)) if num_spans > 1 else 0
55
 
 
56
  spans = []
57
  start = 0
58
  for i in range(num_spans):
 
61
  spans.append([max(0, b0), min(num_tokens, b1)])
62
  start -= overlap
63
 
 
64
  chosen = [spans[i % len(spans)] for i in range(num_titles)]
65
 
66
+ batch_ids = [ids[b0:b1] for (b0, b1) in chosen]
67
  batch_mask = [mask[b0:b1] for (b0, b1) in chosen]
68
+ batch = {
69
+ "input_ids": torch.stack(batch_ids).to(device),
70
+ "attention_mask": torch.stack(batch_mask).to(device),
71
+ }
72
 
73
  with torch.no_grad():
74
  outputs = model.generate(
75
  **batch,
76
  do_sample=True,
77
  temperature=float(temperature),
78
+ max_length=GEN_MAX_LEN,
79
+ num_beams=1
80
  )
81
 
82
  decoded = tokenizer.batch_decode(outputs, skip_special_tokens=True)
83
+ titles = [_first_sentence(d) for d in decoded]
84
  return titles
85
 
86
  demo = gr.Interface(
 
96
  )
97
 
98
  if __name__ == "__main__":
99
+
100
  demo.launch()
requirements.txt CHANGED
@@ -1,5 +1,5 @@
1
  transformers
2
  torch
3
- nltk
4
  gradio
5
  sentencepiece
 
 
1
  transformers
2
  torch
 
3
  gradio
4
  sentencepiece
5
+ nltk>=3.8.1