hrezaei committed · verified
Commit 4095fba · Parent(s): b5eb10e

Upload t5la_adapter.py

Files changed (1): t5la_adapter.py (+360 lines, new file)
import warnings
from dataclasses import dataclass
from typing import Optional, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import T5ForConditionalGeneration, T5Config, Cache
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput

# Warning emitted when a single `head_mask` is passed and reused as `decoder_head_mask`.
_HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers, num_heads)`.
"""


class T5LaAdapterConfig(T5Config):
    model_type = "t5la_adapter"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "hidden_size": "d_model",
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
        "head_dim": "d_kv",
    }
    auto_map = {
        "AutoConfig": "t5la_adapter.T5LaAdapterConfig",
        "AutoModel": "t5la_adapter.T5LaAdapterForConditionalGeneration",
        "AutoModelForSeq2SeqLM": "t5la_adapter.T5LaAdapterForConditionalGeneration",
        "AutoTokenizer": [
            "transformers.T5TokenizerFast",
            "transformers.T5Tokenizer"
        ]
    }

    def __init__(
        self,
        is_encoder_decoder=True,
        pad_token_id=0,
        eos_token_id=1,
        lookahead_type="la",
        lookahead_size=0,
        freeze_base=True,
        **kwargs,
    ):
        self.lookahead_type = lookahead_type
        self.lookahead_size = lookahead_size
        self.freeze_base = freeze_base
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            is_encoder_decoder=is_encoder_decoder,
            **kwargs,
        )
        self.auto_map = {
            "AutoConfig": "t5la_adapter.T5LaAdapterConfig",
            "AutoModel": "t5la_adapter.T5LaAdapterForConditionalGeneration",
            "AutoModelForSeq2SeqLM": "t5la_adapter.T5LaAdapterForConditionalGeneration",
            "AutoTokenizer": [
                "transformers.T5TokenizerFast",
                "transformers.T5Tokenizer"
            ]
        }


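# Note: `lookahead_type` selects how lookahead predictions are produced (see
# `T5LaAdapterForConditionalGeneration.forward` below):
#   - "la":   `lookahead_size` extra linear heads applied to every decoder hidden state,
#   - "laa":  a single extra head applied to the encoder's last position, repeated
#             `lookahead_size` times,
#   - "laa2": a single extra head applied to the last `lookahead_size` encoder positions,
#   - "lae":  no extra heads; the decoder input is extended by `lookahead_size`
#             placeholder positions and the regular LM head scores them.
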
@dataclass
class Seq2SeqLMOutputLA(Seq2SeqLMOutput):
    lookahead_logits: Optional[torch.FloatTensor] = None
    lookahead_loss: Optional[torch.FloatTensor] = None
    base_loss: Optional[torch.FloatTensor] = None
    decoder_last_hidden_state: Optional[torch.FloatTensor] = None


class LookAheadHeads(nn.Module):
    def __init__(self, config: T5LaAdapterConfig, k: int) -> None:
        super().__init__()
        self.k = k
        self.heads = nn.ModuleList(
            [
                # K heads for LA positions:
                nn.Linear(config.d_model, config.vocab_size, bias=False)
                for _ in range(self.k)
            ]
        )

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        # Apply each head to the shared features
        logits = [head(x) for head in self.heads]

        # Stack logits along a new dimension to create a tensor of shape [batch_size, num_heads, output_size]
        if self.k > 0:
            logits = torch.stack(logits, dim=1)
        else:
            logits = logits[0]
        return logits


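# Shape note: for an input `x` of shape [batch, seq_len, d_model], each head produces
# [batch, seq_len, vocab_size]; with k heads the stacked result from `LookAheadHeads`
# is [batch, k, seq_len, vocab_size] (one distribution per lookahead head and position).
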
class T5LaAdapterForConditionalGeneration(T5ForConditionalGeneration):
    config_class = T5LaAdapterConfig

    def __init__(self, config: T5LaAdapterConfig):
        super().__init__(config)
        if config.lookahead_type == "la":
            self.la_heads = LookAheadHeads(config, config.lookahead_size)
        elif config.lookahead_type in ["laa", "laa2"]:
            self.la_heads = LookAheadHeads(config, 1)

        # Freeze all parameters except the new head
        if config.freeze_base:
            for param in self.parameters():
                param.requires_grad = False
            if hasattr(self, "la_heads"):  # the "lae" variant adds no extra heads
                for param in self.la_heads.parameters():
                    param.requires_grad = True  # unfreeze the extra head

    def freeze_base(self):
        # Freeze all parameters except the new head
        for param in self.parameters():
            param.requires_grad = False
        if hasattr(self, "la_heads"):  # the "lae" variant adds no extra heads
            for param in self.la_heads.parameters():
                param.requires_grad = True  # unfreeze the extra head

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        lookahead_targets: Optional[torch.LongTensor] = None,
    ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutputLA]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. T5LA is a model with relative position embeddings so
            you should be able to pad the inputs on both the right and the left.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

            To know more on how to prepare `input_ids` for pretraining take a look at [T5LA Training](./t5la#training).
        decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Indices of decoder input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are decoder input IDs?](../glossary#decoder-input-ids)

            T5LA uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
            is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).

            To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5LA
            Training](./t5la#training).
        decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
            be used by default.
        decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
            1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
            `[0, 1]`:

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked); the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.
        lookahead_targets (`torch.LongTensor`, *optional*):
            Labels for computing the loss of the LA heads or LA positions (models of type `la`, `laa`, and `laa2`
            have LA heads; `lae` has LA positions).

        Examples:

        ```python
        >>> from transformers import AutoTokenizer

        >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
        >>> config = T5LaAdapterConfig.from_pretrained("google-t5/t5-small", lookahead_size=2)
        >>> model = T5LaAdapterForConditionalGeneration.from_pretrained("google-t5/t5-small", config=config)

        >>> # training
        >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
        >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
        >>> outputs = model(input_ids=input_ids, labels=labels)
        >>> loss = outputs.loss
        >>> logits = outputs.logits

        >>> # inference
        >>> input_ids = tokenizer(
        ...     "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
        ... ).input_ids  # Batch size 1
        >>> outputs = model.generate(input_ids)
        >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        >>> # studies have shown that owning a dog is good for you.
        ```"""
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
        if head_mask is not None and decoder_head_mask is None:
            if self.config.num_layers == self.config.num_decoder_layers:
                warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
                decoder_head_mask = head_mask

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            # Convert encoder inputs in embeddings if needed
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        if self.config.lookahead_type == "lae":
            # Extend decoder input with lookahead_size extra positions filled by zero as special tokens:
            zeros_to_add = torch.zeros(
                decoder_input_ids.shape[0],
                self.config.lookahead_size,
                device=decoder_input_ids.device,
                dtype=decoder_input_ids.dtype,
            )
            decoder_input_ids = torch.cat((decoder_input_ids, zeros_to_add), dim=1)
            if decoder_attention_mask is not None:
                ones_to_add = torch.ones(
                    decoder_attention_mask.shape[0],
                    self.config.lookahead_size,
                    device=decoder_attention_mask.device,
                    dtype=decoder_attention_mask.dtype,
                )
                decoder_attention_mask = torch.cat((decoder_attention_mask, ones_to_add), dim=1)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim**-0.5)

        lm_logits = self.lm_head(sequence_output)

        lookahead_logits = None
        if self.config.lookahead_type == "la":
            lookahead_logits = self.la_heads(sequence_output)
        elif self.config.lookahead_type == "laa":
            la_input = torch.repeat_interleave(hidden_states[:, [-1]], self.config.lookahead_size, dim=1)
            lookahead_logits = self.la_heads(la_input)
        elif self.config.lookahead_type == "laa2":
            lookahead_logits = self.la_heads(hidden_states[:, -self.config.lookahead_size :])
        elif self.config.lookahead_type == "lae":
            lookahead_logits = lm_logits[:, -self.config.lookahead_size :].contiguous()
            lm_logits = lm_logits[:, : -self.config.lookahead_size].contiguous()

        lookahead_loss = None
        loss = None
        base_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            # move labels to correct device to enable PP
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            base_loss = loss.clone()
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
            if self.config.lookahead_size > 0 and lookahead_targets is not None:
                lookahead_loss = loss_fct(
                    lookahead_logits.reshape(-1, lookahead_logits.size(-1)),
                    lookahead_targets.view(-1),
                    # vocab_size=self.config.vocab_size,
                )
                if self.config.lookahead_type == "la":
                    # If we simply add, the loss will be larger than a non-LA T5 model because
                    # in a normal T5, the number of tokens is much lower:
                    loss = (loss + lookahead_loss) / (1 + self.config.lookahead_size)
                else:
                    loss = (loss * lm_logits.shape[1] + lookahead_loss * self.config.lookahead_size) / (
                        lm_logits.shape[1] + self.config.lookahead_size
                    )

        if not return_dict:
            output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutputLA(
            loss=loss,
            base_loss=base_loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_last_hidden_state=decoder_outputs.last_hidden_state,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
            lookahead_logits=lookahead_logits,
            lookahead_loss=lookahead_loss,
        )
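
Since the config registers the classes under `auto_map`, a checkpoint that ships this `t5la_adapter.py` can be loaded through the Auto classes with `trust_remote_code=True`. A minimal sketch; the repository id and the example inputs are illustrative assumptions, not part of this commit:

```python
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer

repo_id = "hrezaei/t5la-adapter"  # hypothetical repo containing t5la_adapter.py + weights

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, config=config, trust_remote_code=True)

# Training-style forward pass; lookahead_targets (supervision for the LA heads/positions)
# is optional and its shape depends on config.lookahead_type.
inputs = tokenizer("translate English to German: The house is small.", return_tensors="pt")
labels = tokenizer("Das Haus ist klein.", return_tensors="pt").input_ids

outputs = model(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    labels=labels,
    # lookahead_targets=...,
)
print(outputs.loss, outputs.logits.shape)
```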