IlyaGusev committed
Commit 69f1662 · Parent(s): d771c97

Update README.md

Files changed (1): README.md (+158 −20)
README.md CHANGED
@@ -18,48 +18,32 @@ Example model for [Headline generation competition](https://competitions.codalab
 #### How to use
 
 ```python
-
-model_name = "IlyaGusev/rubert_telegram_headlines"
-
 from transformers import AutoTokenizer, EncoderDecoderModel
 
+model_name = "IlyaGusev/rubert_telegram_headlines"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
-
 hg_model = EncoderDecoderModel.from_pretrained(model_name)
 
 article_text = "..."
 
 input_ids = tokenizer.prepare_seq2seq_batch(
-
     [article_text],
-
     return_tensors="pt",
-
     padding="max_length",
-
     truncation=True,
-
     max_length=256
-
 )["input_ids"]
 
 output_ids = hg_model.generate(
-
     input_ids=input_ids,
-
     max_length=64,
-
     no_repeat_ngram_size=3,
-
     num_beams=10,
-
     top_p=0.95
-)
-
-headline = tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True)
+)[0]
+
+headline = tokenizer.decode(output_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
 print(headline)
-
 ```
 
 ## Training data
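A note on the inference snippet in the hunk above: `tokenizer.prepare_seq2seq_batch` was deprecated in later `transformers` releases. A minimal sketch of the equivalent input preparation with the plain tokenizer call, carrying over the same settings and reusing only names already present in the snippet:

```python
# Equivalent input preparation on newer transformers versions,
# where prepare_seq2seq_batch is deprecated; same settings as above.
inputs = tokenizer(
    [article_text],
    return_tensors="pt",
    padding="max_length",
    truncation=True,
    max_length=256
)
input_ids = inputs["input_ids"]
```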
@@ -68,4 +52,158 @@ print(headline)
 
 ## Training procedure
 
-TBA
+```python
+import json
+import os
+import random
+import shutil
+import torch
+from torch.utils.data import Dataset
+
+from tqdm import tqdm
+from transformers import BertTokenizer, EncoderDecoderModel, Trainer, TrainingArguments, logging
+
+
+def convert_to_tensors(
+    tokenizer,
+    text,
+    max_text_tokens_count,
+    max_title_tokens_count=None,
+    title=None
+):
+    inputs = tokenizer(
+        text,
+        add_special_tokens=True,
+        max_length=max_text_tokens_count,
+        padding="max_length",
+        truncation=True
+    )
+    result = {
+        "input_ids": torch.tensor(inputs["input_ids"]),
+        "attention_mask": torch.tensor(inputs["attention_mask"]),
+    }
+
+    if title is not None:
+        outputs = tokenizer(
+            title,
+            add_special_tokens=True,
+            max_length=max_title_tokens_count,
+            padding="max_length",
+            truncation=True
+        )
+        decoder_input_ids = torch.tensor(outputs["input_ids"])
+        decoder_attention_mask = torch.tensor(outputs["attention_mask"])
+        labels = decoder_input_ids.clone()
+        # Mask padded positions: label -100 is ignored by the cross-entropy loss
+        labels[decoder_attention_mask == 0] = -100
+        result.update({
+            "labels": labels,
+            "decoder_input_ids": decoder_input_ids,
+            "decoder_attention_mask": decoder_attention_mask
+        })
+    return result
+
+
+class GetTitleDataset(Dataset):
+    def __init__(
+        self,
+        original_records,
+        sample_rate,
+        tokenizer,
+        max_text_tokens_count,
+        max_title_tokens_count
+    ):
+        self.original_records = original_records
+        self.sample_rate = sample_rate
+        self.tokenizer = tokenizer
+        self.max_text_tokens_count = max_text_tokens_count
+        self.max_title_tokens_count = max_title_tokens_count
+        self.records = []
+        for record in tqdm(original_records):
+            # Keep each record with probability sample_rate
+            if random.random() > self.sample_rate:
+                continue
+            tensors = convert_to_tensors(
+                tokenizer=tokenizer,
+                title=record["title"],
+                text=record["text"],
+                max_title_tokens_count=self.max_title_tokens_count,
+                max_text_tokens_count=self.max_text_tokens_count
+            )
+            self.records.append(tensors)
+
+    def __len__(self):
+        return len(self.records)
+
+    def __getitem__(self, index):
+        return self.records[index]
+
+
+def train(
+    config_file,
+    train_records,
+    val_records,
+    pretrained_model_path,
+    train_sample_rate=1.0,
+    val_sample_rate=1.0,
+    output_model_path="models",
+    checkpoint=None,
+    max_text_tokens_count=256,
+    max_title_tokens_count=64,
+    batch_size=8,
+    logging_steps=1000,
+    eval_steps=10000,
+    save_steps=10000,
+    learning_rate=0.00003,
+    warmup_steps=2000,
+    num_train_epochs=3
+):
+    logging.set_verbosity_info()
+    tokenizer = BertTokenizer.from_pretrained(
+        pretrained_model_path,
+        do_lower_case=False,
+        do_basic_tokenize=False,
+        strip_accents=False
+    )
+    train_dataset = GetTitleDataset(
+        train_records,
+        train_sample_rate,
+        tokenizer,
+        max_text_tokens_count=max_text_tokens_count,
+        max_title_tokens_count=max_title_tokens_count
+    )
+    val_dataset = GetTitleDataset(
+        val_records,
+        val_sample_rate,
+        tokenizer,
+        max_text_tokens_count=max_text_tokens_count,
+        max_title_tokens_count=max_title_tokens_count
+    )
+
+    # Initialize both encoder and decoder from the same pretrained BERT checkpoint
+    model = EncoderDecoderModel.from_encoder_decoder_pretrained(pretrained_model_path, pretrained_model_path)
+    training_args = TrainingArguments(
+        output_dir=output_model_path,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        do_train=True,
+        do_eval=True,
+        overwrite_output_dir=False,
+        logging_steps=logging_steps,
+        eval_steps=eval_steps,
+        evaluation_strategy="steps",
+        save_steps=save_steps,
+        learning_rate=learning_rate,
+        warmup_steps=warmup_steps,
+        num_train_epochs=num_train_epochs,
+        max_steps=-1,
+        save_total_limit=1,
+    )
+
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=val_dataset
+    )
+    trainer.train(checkpoint)
+    model.save_pretrained(output_model_path)
+```
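The training snippet defines `train()` but does not show how `train_records` and `val_records` are produced. A minimal driver sketch, assuming JSONL files with one `{"title": ..., "text": ...}` object per line (the two fields the dataset class reads); the file names and the base checkpoint are illustrative assumptions, not part of the card:

```python
import json

def read_records(path):
    # Assumed format: one JSON object per line with "title" and "text" keys.
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f]

train_records = read_records("train.jsonl")  # hypothetical file name
val_records = read_records("val.jsonl")      # hypothetical file name

train(
    config_file=None,  # accepted but unused by the function body above
    train_records=train_records,
    val_records=val_records,
    pretrained_model_path="DeepPavlov/rubert-base-cased",  # assumed base checkpoint
)
```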