Fraser commited on
Commit
74bc6c7
·
1 Parent(s): 92871f7

passing most tests

Browse files
model/config.py CHANGED
@@ -10,7 +10,7 @@ from model.utils import assertEqual, assertIn
10
  logger = logging.get_logger(__name__)
11
 
12
 
13
- class T5_VAE_Config(PretrainedConfig):
14
  r"""
15
  This is the configuration class to store the configuration of :class:`FlaxT5VAE`.
16
  It is used to instantiate a T5-VAE model according to the specified arguments, defining the model architecture.
@@ -22,8 +22,8 @@ class T5_VAE_Config(PretrainedConfig):
22
  outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
23
 
24
  Arguments:
25
- latent_size (:obj:`int`, `optional`, defaults to 1,000):
26
- Number of dimensions to use for the sequences latent code.
27
  t5_name (:obj:`str`, `optional`, defaults to t5-base):
28
  Name of the Transformer model to use as a decoder.
29
  block_size (:obj:`int`, `optional`, defaults to 60):
@@ -37,8 +37,8 @@ class T5_VAE_Config(PretrainedConfig):
37
  def __init__(
38
  self,
39
  t5_model_name_or_path="t5-base",
40
- n_latent_tokens=5, # set to -1 for full sequence
41
- latent_size=1_000,
42
  vae_encoder_model='',
43
  vae_decoder_model='',
44
  block_size=60,
@@ -51,7 +51,6 @@ class T5_VAE_Config(PretrainedConfig):
51
  num_layers=0,
52
  num_heads=0,
53
  tie_word_embeddings=True,
54
- skip_upsample=False,
55
  **kwargs,
56
  ):
57
  assertIn(vae_encoder_model, VAE_ENCODER_MODELS.keys(), "Unexpected VAE encoder.")
@@ -65,28 +64,35 @@ class T5_VAE_Config(PretrainedConfig):
65
  self.vae_encoder_model = vae_encoder_model
66
  self.vae_decoder_model = vae_decoder_model
67
 
 
68
  assert(n_latent_tokens <= self.set_seq_size, 'Cannot use more latent tokens than input tokens.')
69
- self.latent_size = latent_size
70
  self.n_latent_tokens = n_latent_tokens
71
- self.skip_upsample = skip_upsample
72
 
73
  # T5
74
  if 't5' not in kwargs:
75
  self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
 
76
  if num_layers:
77
  self.t5.num_layers = num_layers
78
  if num_heads:
79
  self.t5.num_heads = num_heads
80
  self.t5.decoder_start_token_id = decoder_start_token_id
81
  self.t5.n_positions = self.set_seq_size
82
- assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
83
  else:
84
- self.t5 = T5Config(**kwargs.pop('t5'))
 
 
 
 
 
 
 
85
 
86
- # misc
87
  self.tie_word_embeddings = tie_word_embeddings
88
  self.t5.tie_word_embeddings = self.tie_word_embeddings
89
- self.use_cache = getattr(self.t5, "use_cache", False)
 
90
 
91
  def to_dict(self):
92
  """
 
10
  logger = logging.get_logger(__name__)
11
 
12
 
13
+ class T5VaeConfig(PretrainedConfig):
14
  r"""
15
  This is the configuration class to store the configuration of :class:`FlaxT5VAE`.
16
  It is used to instantiate a T5-VAE model according to the specified arguments, defining the model architecture.
 
22
  outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
23
 
24
  Arguments:
25
+ latent_token_size (:obj:`int`, `optional`, defaults to 1,000):
26
+ Number of dimensions to use for each latent token.
27
  t5_name (:obj:`str`, `optional`, defaults to t5-base):
28
  Name of the Transformer model to use as a decoder.
29
  block_size (:obj:`int`, `optional`, defaults to 60):
 
37
  def __init__(
38
  self,
39
  t5_model_name_or_path="t5-base",
40
+ n_latent_tokens=6, # set to -1 for full sequence
41
+ latent_token_size=768,
42
  vae_encoder_model='',
43
  vae_decoder_model='',
44
  block_size=60,
 
51
  num_layers=0,
52
  num_heads=0,
53
  tie_word_embeddings=True,
 
54
  **kwargs,
55
  ):
56
  assertIn(vae_encoder_model, VAE_ENCODER_MODELS.keys(), "Unexpected VAE encoder.")
 
64
  self.vae_encoder_model = vae_encoder_model
65
  self.vae_decoder_model = vae_decoder_model
66
 
67
+ self.latent_token_size = latent_token_size
68
  assert(n_latent_tokens <= self.set_seq_size, 'Cannot use more latent tokens than input tokens.')
 
69
  self.n_latent_tokens = n_latent_tokens
 
70
 
71
  # T5
72
  if 't5' not in kwargs:
73
  self.t5 = AutoConfig.from_pretrained(t5_model_name_or_path, cache_dir=cache_dir)
74
+ assertEqual(self.t5.model_type, "t5", "Need t5 model type for transformer_decoder.")
75
  if num_layers:
76
  self.t5.num_layers = num_layers
77
  if num_heads:
78
  self.t5.num_heads = num_heads
79
  self.t5.decoder_start_token_id = decoder_start_token_id
80
  self.t5.n_positions = self.set_seq_size
 
81
  else:
82
+ self.t5 = T5Config(
83
+ num_layers=num_layers, num_heads=num_heads,
84
+ decoder_start_token_id=decoder_start_token_id,
85
+ n_positions=self.set_seq_size, **kwargs
86
+ )
87
+
88
+ if self.t5.d_model < self.latent_token_size:
89
+ raise Exception('Using larger latent token dimension then T5 hidden dimension.')
90
 
91
+ # Add t5 config options
92
  self.tie_word_embeddings = tie_word_embeddings
93
  self.t5.tie_word_embeddings = self.tie_word_embeddings
94
+ for attr in 'vocab_size hidden_size num_attention_heads num_hidden_layers use_cache'.split():
95
+ setattr(self, attr, getattr(self.t5, attr))
96
 
97
  def to_dict(self):
98
  """
model/encoders.py CHANGED
@@ -9,12 +9,12 @@ class Encoder(nn.Module):
9
  '''
10
  Converts N hidden tokens into N seperate latent codes.
11
  '''
12
- latent_size: int
13
  n_latent_tokens: int
14
 
15
  @nn.compact
16
  def __call__(self, encoding):
17
- latent_tokens = nn.Dense(self.latent_size)(encoding)
18
  raw_latent_code = latent_tokens[:, : self.n_latent_tokens, :]
19
  # TODO does this just apply tanh to each latent token? Or across the whole batch
20
  latent_code = jnp.tanh(raw_latent_code)
 
9
  '''
10
  Converts N hidden tokens into N seperate latent codes.
11
  '''
12
+ latent_token_size: int
13
  n_latent_tokens: int
14
 
15
  @nn.compact
16
  def __call__(self, encoding):
17
+ latent_tokens = nn.Dense(self.latent_token_size)(encoding)
18
  raw_latent_code = latent_tokens[:, : self.n_latent_tokens, :]
19
  # TODO does this just apply tanh to each latent token? Or across the whole batch
20
  latent_code = jnp.tanh(raw_latent_code)
model/outputs.py CHANGED
@@ -12,7 +12,7 @@ class TransformerVAE_Output(ModelOutput):
12
  Base class for a Transformer-VAE's outputs.
13
 
14
  Args:
15
- latent_codes (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, latent_size)`):
16
  Latent codes representing encoded sequences.
17
  remade_encoder_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_tokens, model_dim)`):
18
  Reconstructed encoder hidden states representing sequences.
 
12
  Base class for a Transformer-VAE's outputs.
13
 
14
  Args:
15
+ latent_codes (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_latent_tokens, latent_token_size)`):
16
  Latent codes representing encoded sequences.
17
  remade_encoder_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, n_tokens, model_dim)`):
18
  Reconstructed encoder hidden states representing sequences.
model/t5_vae.py CHANGED
@@ -13,17 +13,23 @@ from transformers.models.t5.modeling_flax_t5 import FlaxT5ForConditionalGenerati
13
 
14
  from model.vae import VAE
15
  from model.outputs import TransformerVAE_Output
16
- from model.config import T5_VAE_Config
17
 
18
 
19
  @add_start_docstrings("""T5 Model with a `language modeling` head on top converted into a VAE.""")
20
- class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
21
- config: T5_VAE_Config
22
  dtype: jnp.dtype = jnp.float32 # the dtype of the computation
23
 
24
  def _get_encoder_module(self):
25
  return self.t5.encoder
26
 
 
 
 
 
 
 
27
  def _get_decoder_module(self):
28
  return self.t5.decoder
29
 
@@ -42,7 +48,6 @@ class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
42
  return_dict=None,
43
  deterministic: bool = True,
44
  ):
45
- # TODO should I use None args when everything has to be computed anyway?
46
  """
47
  Adapted from `FlaxT5ForConditionalGenerationModule`
48
  """
@@ -104,19 +109,19 @@ class FlaxT5_VAE_ForAutoencodingModule(nn.Module):
104
  )
105
 
106
 
107
- class FlaxT5_VAE_PreTrainedModel(FlaxPreTrainedModel):
108
  """
109
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
110
  models.
111
  """
112
 
113
- config_class = T5_VAE_Config
114
  base_model_prefix = "transformer"
115
  module_class: nn.Module = None
116
 
117
  def __init__(
118
  self,
119
- config: T5_VAE_Config,
120
  input_shape: Tuple[int] = (1, 1),
121
  seed: int = 0,
122
  dtype: jnp.dtype = jnp.float32,
@@ -208,19 +213,21 @@ class FlaxT5_VAE_PreTrainedModel(FlaxPreTrainedModel):
208
  decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
209
  decoder_attention_mask = jnp.ones_like(decoder_input_ids)
210
 
211
- def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
 
212
  decoder_module = module._get_decoder_module()
213
  return decoder_module(
214
  decoder_input_ids,
215
  decoder_attention_mask,
 
216
  **kwargs,
217
  )
218
 
219
  init_variables = self.module.init(
220
  jax.random.PRNGKey(0),
221
  decoder_input_ids=decoder_input_ids,
222
- decoder_attention_mask=decoder_attention_mask,
223
  latent_codes=latent_codes,
 
224
  init_cache=True,
225
  method=_decoder_forward, # we only need to call the decoder to init the cache
226
  )
@@ -256,8 +263,8 @@ class FlaxT5_VAE_PreTrainedModel(FlaxPreTrainedModel):
256
  raise NotImplementedError()
257
 
258
 
259
- class FlaxT5_VAE_ForAutoencoding(FlaxT5_VAE_PreTrainedModel):
260
- module_class = FlaxT5_VAE_ForAutoencodingModule
261
 
262
  def __call__(
263
  self,
@@ -308,18 +315,6 @@ class FlaxT5_VAE_ForAutoencoding(FlaxT5_VAE_PreTrainedModel):
308
  params: dict = None,
309
  dropout_rng: PRNGKey = None,
310
  ):
311
- r"""
312
- Returns:
313
-
314
- Example::
315
-
316
- >>> model = FlaxT5_VAE_ForAutoencoding.from_pretrained('t5-small')
317
- >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
318
-
319
- >>> text = "My friends are cool but they eat too many carbs."
320
- >>> inputs = tokenizer(text, max_length=512, return_tensors='jax')
321
- >>> latent_codes = model.encode(**inputs)
322
- """
323
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
324
  output_hidden_states = (
325
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -335,20 +330,9 @@ class FlaxT5_VAE_ForAutoencoding(FlaxT5_VAE_PreTrainedModel):
335
  rngs["dropout"] = dropout_rng
336
 
337
  def _encoder_forward(module, input_ids, attention_mask, **kwargs):
338
- # Encode
339
- encoder_outputs = self.t5.encoder(
340
- input_ids=input_ids,
341
- attention_mask=attention_mask,
342
- output_attentions=output_attentions,
343
- output_hidden_states=output_hidden_states,
344
- return_dict=return_dict,
345
- deterministic=not train,
346
- )
347
-
348
- hidden_states = encoder_outputs[0]
349
-
350
- # Autoencode
351
- return self.vae(hidden_states, kwargs.get('latent_codes'))
352
 
353
  return self.module.apply(
354
  {"params": params or self.params},
@@ -381,7 +365,7 @@ class FlaxT5_VAE_ForAutoencoding(FlaxT5_VAE_PreTrainedModel):
381
 
382
  Example::
383
 
384
- >>> model = FlaxT5_VAE_ForAutoencoding.from_pretrained('t5-small')
385
  >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
386
 
387
  >>> text = "My friends are cool but they eat too many carbs."
@@ -400,10 +384,8 @@ class FlaxT5_VAE_ForAutoencoding(FlaxT5_VAE_PreTrainedModel):
400
  )
401
  return_dict = return_dict if return_dict is not None else self.config.return_dict
402
 
403
- # TODO match latent_codes to encoder hidden states size
404
- encoder_hidden_states = latent_codes
405
  if encoder_attention_mask is None:
406
- batch_size, sequence_length = encoder_hidden_states.shape[:2]
407
  encoder_attention_mask = jnp.ones((batch_size, sequence_length))
408
 
409
  batch_size, sequence_length = decoder_input_ids.shape
@@ -426,14 +408,15 @@ class FlaxT5_VAE_ForAutoencoding(FlaxT5_VAE_PreTrainedModel):
426
  else:
427
  mutable = False
428
 
429
- def _decoder_forward(module, decoder_input_ids, decoder_attention_mask, **kwargs):
 
430
  decoder_module = module._get_decoder_module()
431
  decoder_outputs = decoder_module(
432
  decoder_input_ids,
433
  decoder_attention_mask,
 
434
  **kwargs,
435
  )
436
-
437
  sequence_output = decoder_outputs[0]
438
 
439
  if self.config.tie_word_embeddings:
@@ -442,18 +425,18 @@ class FlaxT5_VAE_ForAutoencoding(FlaxT5_VAE_PreTrainedModel):
442
  sequence_output = sequence_output * (self.config.d_model ** -0.5)
443
 
444
  if self.config.tie_word_embeddings:
445
- shared_embedding = module.shared.variables["params"]["embedding"]
446
- lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
447
  else:
448
- lm_logits = module.lm_head(sequence_output)
449
 
450
  return lm_logits, decoder_outputs
451
 
452
  outputs = self.module.apply(
453
  inputs,
454
  decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
 
455
  decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
456
- encoder_hidden_states=encoder_hidden_states,
457
  encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
458
  output_attentions=output_attentions,
459
  output_hidden_states=output_hidden_states,
 
13
 
14
  from model.vae import VAE
15
  from model.outputs import TransformerVAE_Output
16
+ from model.config import T5VaeConfig
17
 
18
 
19
  @add_start_docstrings("""T5 Model with a `language modeling` head on top converted into a VAE.""")
20
+ class FlaxT5VaeForAutoencodingModule(nn.Module):
21
+ config: T5VaeConfig
22
  dtype: jnp.dtype = jnp.float32 # the dtype of the computation
23
 
24
  def _get_encoder_module(self):
25
  return self.t5.encoder
26
 
27
+ def _get_vae_encoder_module(self):
28
+ return self.vae.encoder
29
+
30
+ def _get_vae_decoder_module(self):
31
+ return self.vae.decoder
32
+
33
  def _get_decoder_module(self):
34
  return self.t5.decoder
35
 
 
48
  return_dict=None,
49
  deterministic: bool = True,
50
  ):
 
51
  """
52
  Adapted from `FlaxT5ForConditionalGenerationModule`
53
  """
 
109
  )
110
 
111
 
112
+ class FlaxT5VaePreTrainedModel(FlaxPreTrainedModel):
113
  """
114
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
115
  models.
116
  """
117
 
118
+ config_class = T5VaeConfig
119
  base_model_prefix = "transformer"
120
  module_class: nn.Module = None
121
 
122
  def __init__(
123
  self,
124
+ config: T5VaeConfig,
125
  input_shape: Tuple[int] = (1, 1),
126
  seed: int = 0,
127
  dtype: jnp.dtype = jnp.float32,
 
213
  decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
214
  decoder_attention_mask = jnp.ones_like(decoder_input_ids)
215
 
216
+ def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
217
+ vae_decoder_module = module._get_vae_decoder_module()
218
  decoder_module = module._get_decoder_module()
219
  return decoder_module(
220
  decoder_input_ids,
221
  decoder_attention_mask,
222
+ encoder_hidden_states=vae_decoder_module(latent_codes),
223
  **kwargs,
224
  )
225
 
226
  init_variables = self.module.init(
227
  jax.random.PRNGKey(0),
228
  decoder_input_ids=decoder_input_ids,
 
229
  latent_codes=latent_codes,
230
+ decoder_attention_mask=decoder_attention_mask,
231
  init_cache=True,
232
  method=_decoder_forward, # we only need to call the decoder to init the cache
233
  )
 
263
  raise NotImplementedError()
264
 
265
 
266
+ class FlaxT5VaeForAutoencoding(FlaxT5VaePreTrainedModel):
267
+ module_class = FlaxT5VaeForAutoencodingModule
268
 
269
  def __call__(
270
  self,
 
315
  params: dict = None,
316
  dropout_rng: PRNGKey = None,
317
  ):
 
 
 
 
 
 
 
 
 
 
 
 
318
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
319
  output_hidden_states = (
320
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
330
  rngs["dropout"] = dropout_rng
331
 
332
  def _encoder_forward(module, input_ids, attention_mask, **kwargs):
333
+ encode_module = module._get_encoder_module()
334
+ vae_encoder_module = module._get_vae_encoder_module()
335
+ return vae_encoder_module(encode_module(input_ids, attention_mask, **kwargs)[0])
 
 
 
 
 
 
 
 
 
 
 
336
 
337
  return self.module.apply(
338
  {"params": params or self.params},
 
365
 
366
  Example::
367
 
368
+ >>> model = FlaxT5VaeForAutoencoding.from_pretrained('t5-small')
369
  >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
370
 
371
  >>> text = "My friends are cool but they eat too many carbs."
 
384
  )
385
  return_dict = return_dict if return_dict is not None else self.config.return_dict
386
 
 
 
387
  if encoder_attention_mask is None:
388
+ batch_size, sequence_length = latent_codes.shape[:2]
389
  encoder_attention_mask = jnp.ones((batch_size, sequence_length))
390
 
391
  batch_size, sequence_length = decoder_input_ids.shape
 
408
  else:
409
  mutable = False
410
 
411
+ def _decoder_forward(module, decoder_input_ids, latent_codes, decoder_attention_mask, **kwargs):
412
+ vae_decoder_module = module._get_vae_decoder_module()
413
  decoder_module = module._get_decoder_module()
414
  decoder_outputs = decoder_module(
415
  decoder_input_ids,
416
  decoder_attention_mask,
417
+ encoder_hidden_states=vae_decoder_module(latent_codes),
418
  **kwargs,
419
  )
 
420
  sequence_output = decoder_outputs[0]
421
 
422
  if self.config.tie_word_embeddings:
 
425
  sequence_output = sequence_output * (self.config.d_model ** -0.5)
426
 
427
  if self.config.tie_word_embeddings:
428
+ shared_embedding = module.t5.shared.variables["params"]["embedding"]
429
+ lm_logits = module.t5.lm_head.apply({"params": {"kernel": shared_embedding.T}}, sequence_output)
430
  else:
431
+ lm_logits = module.t5.lm_head(sequence_output)
432
 
433
  return lm_logits, decoder_outputs
434
 
435
  outputs = self.module.apply(
436
  inputs,
437
  decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
438
+ latent_codes=latent_codes,
439
  decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
 
440
  encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
441
  output_attentions=output_attentions,
442
  output_hidden_states=output_hidden_states,
model/vae.py CHANGED
@@ -3,7 +3,7 @@ import flax.linen as nn
3
 
4
  from model.encoders import VAE_ENCODER_MODELS
5
  from model.decoders import VAE_DECODER_MODELS
6
- from model.config import T5_VAE_Config
7
 
8
 
9
  class VAE(nn.Module):
@@ -12,11 +12,11 @@ class VAE(nn.Module):
12
  An MMD-VAE used with encoder-decoder models.
13
  Encodes all token encodings into a single latent & spits them back out.
14
  """
15
- config: T5_VAE_Config
16
  dtype: jnp.dtype = jnp.float32 # the dtype of the computation
17
 
18
  def setup(self):
19
- self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_size, self.config.n_latent_tokens)
20
  self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model, self.config.n_latent_tokens)
21
 
22
  def __call__(self, encoding=None, latent_codes=None):
 
3
 
4
  from model.encoders import VAE_ENCODER_MODELS
5
  from model.decoders import VAE_DECODER_MODELS
6
+ from model.config import T5VaeConfig
7
 
8
 
9
  class VAE(nn.Module):
 
12
  An MMD-VAE used with encoder-decoder models.
13
  Encodes all token encodings into a single latent & spits them back out.
14
  """
15
+ config: T5VaeConfig
16
  dtype: jnp.dtype = jnp.float32 # the dtype of the computation
17
 
18
  def setup(self):
19
+ self.encoder = VAE_ENCODER_MODELS[self.config.vae_encoder_model](self.config.latent_token_size, self.config.n_latent_tokens)
20
  self.decoder = VAE_DECODER_MODELS[self.config.vae_decoder_model](self.config.t5.d_model, self.config.n_latent_tokens)
21
 
22
  def __call__(self, encoding=None, latent_codes=None):
run_clm_flax.py CHANGED
@@ -405,7 +405,7 @@ def main():
405
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
406
  for k, t in concatenated_examples.items()
407
  }
408
- result["label"] = result["input_ids"].copy()
409
  return result
410
 
411
  # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
@@ -421,8 +421,6 @@ def main():
421
  num_proc=data_args.preprocessing_num_workers,
422
  load_from_cache_file=not data_args.overwrite_cache,
423
  )
424
- import pdb
425
- pdb.set_trace()
426
 
427
  if training_args.do_train:
428
  if "train" not in tokenized_datasets:
@@ -624,7 +622,6 @@ def main():
624
 
625
  # Save metrics
626
  if has_tensorboard and jax.process_index() == 0:
627
- cur_step = epoch * (len(train_dataset) // train_batch_size)
628
  write_eval_metric(summary_writer, eval_metrics, cur_step)
629
 
630
  if cur_step % training_args.save_steps == 0 and cur_step > 0:
 
405
  k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
406
  for k, t in concatenated_examples.items()
407
  }
408
+ result["labels"] = result["input_ids"].copy()
409
  return result
410
 
411
  # Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
 
421
  num_proc=data_args.preprocessing_num_workers,
422
  load_from_cache_file=not data_args.overwrite_cache,
423
  )
 
 
424
 
425
  if training_args.do_train:
426
  if "train" not in tokenized_datasets:
 
622
 
623
  # Save metrics
624
  if has_tensorboard and jax.process_index() == 0:
 
625
  write_eval_metric(summary_writer, eval_metrics, cur_step)
626
 
627
  if cur_step % training_args.save_steps == 0 and cur_step > 0:
tests/__init__.py ADDED
File without changes
tests/test_configuration_common.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2019 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+
17
+ import json
18
+ import os
19
+ import tempfile
20
+ import unittest
21
+
22
+ from huggingface_hub import HfApi
23
+ from requests.exceptions import HTTPError
24
+ from transformers import BertConfig, GPT2Config
25
+ from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
26
+
27
+
28
+ class ConfigTester(object):
29
+ def __init__(self, parent, config_class=None, has_text_modality=True, **kwargs):
30
+ self.parent = parent
31
+ self.config_class = config_class
32
+ self.has_text_modality = has_text_modality
33
+ self.inputs_dict = kwargs
34
+
35
+ def create_and_test_config_common_properties(self):
36
+ config = self.config_class(**self.inputs_dict)
37
+ if self.has_text_modality:
38
+ self.parent.assertTrue(hasattr(config, "vocab_size"))
39
+ self.parent.assertTrue(hasattr(config, "hidden_size"))
40
+ self.parent.assertTrue(hasattr(config, "num_attention_heads"))
41
+ self.parent.assertTrue(hasattr(config, "num_hidden_layers"))
42
+
43
+ def create_and_test_config_to_json_string(self):
44
+ config = self.config_class(**self.inputs_dict)
45
+ obj = json.loads(config.to_json_string())
46
+ for key, value in self.inputs_dict.items():
47
+ self.parent.assertEqual(obj[key], value)
48
+
49
+ def create_and_test_config_to_json_file(self):
50
+ config_first = self.config_class(**self.inputs_dict)
51
+
52
+ with tempfile.TemporaryDirectory() as tmpdirname:
53
+ json_file_path = os.path.join(tmpdirname, "config.json")
54
+ config_first.to_json_file(json_file_path)
55
+ config_second = self.config_class.from_json_file(json_file_path)
56
+
57
+ self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
58
+
59
+ def create_and_test_config_from_and_save_pretrained(self):
60
+ config_first = self.config_class(**self.inputs_dict)
61
+
62
+ with tempfile.TemporaryDirectory() as tmpdirname:
63
+ config_first.save_pretrained(tmpdirname)
64
+ config_second = self.config_class.from_pretrained(tmpdirname)
65
+
66
+ self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
67
+
68
+ def create_and_test_config_with_num_labels(self):
69
+ config = self.config_class(**self.inputs_dict, num_labels=5)
70
+ self.parent.assertEqual(len(config.id2label), 5)
71
+ self.parent.assertEqual(len(config.label2id), 5)
72
+
73
+ config.num_labels = 3
74
+ self.parent.assertEqual(len(config.id2label), 3)
75
+ self.parent.assertEqual(len(config.label2id), 3)
76
+
77
+ def check_config_can_be_init_without_params(self):
78
+ if self.config_class.is_composition:
79
+ return
80
+ config = self.config_class()
81
+ self.parent.assertIsNotNone(config)
82
+
83
+ def run_common_tests(self):
84
+ self.create_and_test_config_common_properties()
85
+ self.create_and_test_config_to_json_string()
86
+ self.create_and_test_config_to_json_file()
87
+ self.create_and_test_config_from_and_save_pretrained()
88
+ self.create_and_test_config_with_num_labels()
89
+ self.check_config_can_be_init_without_params()
90
+
91
+
92
+ @is_staging_test
93
+ class ConfigPushToHubTester(unittest.TestCase):
94
+ @classmethod
95
+ def setUpClass(cls):
96
+ cls._api = HfApi(endpoint=ENDPOINT_STAGING)
97
+ cls._token = cls._api.login(username=USER, password=PASS)
98
+
99
+ @classmethod
100
+ def tearDownClass(cls):
101
+ try:
102
+ cls._api.delete_repo(token=cls._token, name="test-config")
103
+ except HTTPError:
104
+ pass
105
+
106
+ try:
107
+ cls._api.delete_repo(token=cls._token, name="test-config-org", organization="valid_org")
108
+ except HTTPError:
109
+ pass
110
+
111
+ def test_push_to_hub(self):
112
+ config = BertConfig(
113
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
114
+ )
115
+ with tempfile.TemporaryDirectory() as tmp_dir:
116
+ config.save_pretrained(os.path.join(tmp_dir, "test-config"), push_to_hub=True, use_auth_token=self._token)
117
+
118
+ new_config = BertConfig.from_pretrained(f"{USER}/test-config")
119
+ for k, v in config.__dict__.items():
120
+ if k != "transformers_version":
121
+ self.assertEqual(v, getattr(new_config, k))
122
+
123
+ def test_push_to_hub_in_organization(self):
124
+ config = BertConfig(
125
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
126
+ )
127
+
128
+ with tempfile.TemporaryDirectory() as tmp_dir:
129
+ config.save_pretrained(
130
+ os.path.join(tmp_dir, "test-config-org"),
131
+ push_to_hub=True,
132
+ use_auth_token=self._token,
133
+ organization="valid_org",
134
+ )
135
+
136
+ new_config = BertConfig.from_pretrained("valid_org/test-config-org")
137
+ for k, v in config.__dict__.items():
138
+ if k != "transformers_version":
139
+ self.assertEqual(v, getattr(new_config, k))
140
+
141
+
142
+ class ConfigTestUtils(unittest.TestCase):
143
+ def test_config_from_string(self):
144
+ c = GPT2Config()
145
+
146
+ # attempt to modify each of int/float/bool/str config records and verify they were updated
147
+ n_embd = c.n_embd + 1 # int
148
+ resid_pdrop = c.resid_pdrop + 1.0 # float
149
+ scale_attn_weights = not c.scale_attn_weights # bool
150
+ summary_type = c.summary_type + "foo" # str
151
+ c.update_from_string(
152
+ f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
153
+ )
154
+ self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
155
+ self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
156
+ self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
157
+ self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
tests/test_generation_flax_utils.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2021 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+
17
+ import numpy as np
18
+
19
+ from transformers import is_flax_available
20
+ from transformers.testing_utils import require_flax
21
+
22
+
23
+ if is_flax_available():
24
+ import os
25
+
26
+ import jax
27
+ import jax.numpy as jnp
28
+ from jax import jit
29
+
30
+ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
31
+
32
+
33
+ def ids_tensor(shape, vocab_size, rng=None):
34
+ """Creates a random int32 tensor of the shape within the vocab size."""
35
+ if rng is None:
36
+ rng = random.Random()
37
+
38
+ total_dims = 1
39
+ for dim in shape:
40
+ total_dims *= dim
41
+
42
+ values = []
43
+ for _ in range(total_dims):
44
+ values.append(rng.randint(0, vocab_size - 1))
45
+
46
+ output = np.array(values, dtype=jnp.int32).reshape(shape)
47
+
48
+ return output
49
+
50
+
51
+ def random_attention_mask(shape, rng=None):
52
+ attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
53
+ # make sure that at least one token is attended to for each batch
54
+ attn_mask[:, -1] = 1
55
+ return attn_mask
56
+
57
+
58
+ @require_flax
59
+ class FlaxGenerationTesterMixin:
60
+ model_tester = None
61
+ all_generative_model_classes = ()
62
+
63
+ def _get_input_ids_and_config(self):
64
+ config, inputs = self.model_tester.prepare_config_and_inputs_for_common()
65
+
66
+ # cut to half length & take max batch_size 3
67
+ max_batch_size = 2
68
+ sequence_length = inputs["input_ids"].shape[-1] // 2
69
+ input_ids = inputs["input_ids"][:max_batch_size, :sequence_length]
70
+
71
+ attention_mask = jnp.ones_like(input_ids)
72
+ attention_mask = attention_mask[:max_batch_size, :sequence_length]
73
+
74
+ # generate max 5 tokens
75
+ max_length = input_ids.shape[-1] + 5
76
+ if config.eos_token_id is not None and config.pad_token_id is None:
77
+ # hack to allow generate for models such as GPT2 as is done in `generate()`
78
+ config.pad_token_id = config.eos_token_id
79
+ return config, input_ids, attention_mask, max_length
80
+
81
+ def test_greedy_generate(self):
82
+ config, input_ids, _, max_length = self._get_input_ids_and_config()
83
+ config.do_sample = False
84
+ config.max_length = max_length
85
+
86
+ for model_class in self.all_generative_model_classes:
87
+ model = model_class(config)
88
+
89
+ generation_outputs = model.generate(input_ids).sequences
90
+ self.assertEqual(generation_outputs.shape[-1], max_length)
91
+
92
+ jit_generate = jit(model.generate)
93
+ jit_generation_outputs = jit_generate(input_ids).sequences
94
+
95
+ self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
96
+
97
+ def test_sample_generate(self):
98
+ config, input_ids, _, max_length = self._get_input_ids_and_config()
99
+ config.do_sample = True
100
+ config.max_length = max_length
101
+
102
+ for model_class in self.all_generative_model_classes:
103
+ model = model_class(config)
104
+
105
+ generation_outputs = model.generate(input_ids).sequences
106
+ self.assertEqual(generation_outputs.shape[-1], max_length)
107
+
108
+ jit_generate = jit(model.generate)
109
+ jit_generation_outputs = jit_generate(input_ids).sequences
110
+
111
+ self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
112
+
113
+ def test_beam_search_generate(self):
114
+ config, input_ids, _, max_length = self._get_input_ids_and_config()
115
+ config.do_sample = False
116
+ config.max_length = max_length
117
+ config.num_beams = 2
118
+
119
+ for model_class in self.all_generative_model_classes:
120
+ model = model_class(config)
121
+
122
+ generation_outputs = model.generate(input_ids).sequences
123
+ self.assertEqual(generation_outputs.shape[-1], max_length)
124
+
125
+ jit_generate = jit(model.generate)
126
+ jit_generation_outputs = jit_generate(input_ids).sequences
127
+
128
+ self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
129
+
130
+ def test_sample_generate_logits_warper(self):
131
+ config, input_ids, _, max_length = self._get_input_ids_and_config()
132
+ config.do_sample = True
133
+ config.max_length = max_length
134
+ config.temperature = 0.8
135
+ config.top_k = 10
136
+ config.top_p = 0.3
137
+ config.min_length = 1
138
+ config.forced_bos_token_id = 8
139
+ config.forced_eos_token_id = 9
140
+
141
+ for model_class in self.all_generative_model_classes:
142
+ model = model_class(config)
143
+
144
+ generation_outputs = model.generate(input_ids).sequences
145
+ self.assertEqual(generation_outputs.shape[-1], max_length)
146
+
147
+ jit_generate = jit(model.generate)
148
+ jit_generation_outputs = jit_generate(input_ids).sequences
149
+
150
+ self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
151
+
152
+ def test_greedy_generate_logits_warper(self):
153
+ config, input_ids, _, max_length = self._get_input_ids_and_config()
154
+ config.max_length = max_length
155
+ config.min_length = 1
156
+ config.forced_bos_token_id = 8
157
+ config.forced_eos_token_id = 9
158
+
159
+ for model_class in self.all_generative_model_classes:
160
+ model = model_class(config)
161
+
162
+ generation_outputs = model.generate(input_ids).sequences
163
+ self.assertEqual(generation_outputs.shape[-1], max_length)
164
+
165
+ jit_generate = jit(model.generate)
166
+ jit_generation_outputs = jit_generate(input_ids).sequences
167
+
168
+ self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
169
+
170
+ def test_beam_search_generate_logits_warper(self):
171
+ config, input_ids, _, max_length = self._get_input_ids_and_config()
172
+ config.max_length = max_length
173
+ config.num_beams = 2
174
+ config.min_length = 1
175
+ config.forced_bos_token_id = 8
176
+ config.forced_eos_token_id = 9
177
+
178
+ for model_class in self.all_generative_model_classes:
179
+ model = model_class(config)
180
+
181
+ generation_outputs = model.generate(input_ids).sequences
182
+ self.assertEqual(generation_outputs.shape[-1], max_length)
183
+
184
+ jit_generate = jit(model.generate)
185
+ jit_generation_outputs = jit_generate(input_ids).sequences
186
+
187
+ self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
188
+
189
+ def test_greedy_generate_attn_mask(self):
190
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
191
+
192
+ # pad attention mask on the left
193
+ attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
194
+
195
+ config.do_sample = False
196
+ config.max_length = max_length
197
+
198
+ for model_class in self.all_generative_model_classes:
199
+ model = model_class(config)
200
+
201
+ generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
202
+ self.assertEqual(generation_outputs.shape[-1], max_length)
203
+
204
+ jit_generate = jit(model.generate)
205
+ jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
206
+
207
+ self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
208
+
209
+ def test_sample_generate_attn_mask(self):
210
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
211
+
212
+ # pad attention mask on the left
213
+ attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
214
+
215
+ config.do_sample = True
216
+ config.max_length = max_length
217
+
218
+ for model_class in self.all_generative_model_classes:
219
+ model = model_class(config)
220
+
221
+ generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
222
+ self.assertEqual(generation_outputs.shape[-1], max_length)
223
+
224
+ jit_generate = jit(model.generate)
225
+ jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
226
+
227
+ self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
228
+
229
+ def test_beam_search_generate_attn_mask(self):
230
+ config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
231
+
232
+ # pad attention mask on the left
233
+ attention_mask = jax.ops.index_update(attention_mask, (0, 0), 0)
234
+
235
+ config.num_beams = 2
236
+ config.max_length = max_length
237
+
238
+ for model_class in self.all_generative_model_classes:
239
+ model = model_class(config)
240
+
241
+ generation_outputs = model.generate(input_ids, attention_mask=attention_mask).sequences
242
+ self.assertEqual(generation_outputs.shape[-1], max_length)
243
+
244
+ jit_generate = jit(model.generate)
245
+ jit_generation_outputs = jit_generate(input_ids, attention_mask=attention_mask).sequences
246
+
247
+ self.assertListEqual(generation_outputs.tolist(), jit_generation_outputs.tolist())
tests/test_modeling_flax_common.py ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import copy
16
+ import inspect
17
+ import random
18
+ import tempfile
19
+ import unittest
20
+ from typing import List, Tuple
21
+
22
+ import numpy as np
23
+
24
+ import transformers
25
+ from huggingface_hub import HfApi
26
+ from requests.exceptions import HTTPError
27
+ from transformers import BertConfig, FlaxBertModel, is_flax_available, is_torch_available
28
+ from transformers.models.auto import get_values
29
+ from transformers.testing_utils import (
30
+ ENDPOINT_STAGING,
31
+ PASS,
32
+ USER,
33
+ is_pt_flax_cross_test,
34
+ is_staging_test,
35
+ require_flax,
36
+ slow,
37
+ )
38
+
39
+
40
+ if is_flax_available():
41
+ import os
42
+
43
+ import jax
44
+ import jax.numpy as jnp
45
+ import jaxlib.xla_extension as jax_xla
46
+ from flax.core.frozen_dict import unfreeze
47
+ from flax.traverse_util import flatten_dict
48
+ from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING
49
+ from transformers.modeling_flax_pytorch_utils import (
50
+ convert_pytorch_state_dict_to_flax,
51
+ load_flax_weights_in_pytorch_model,
52
+ )
53
+
54
+ os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
55
+
56
+ if is_torch_available():
57
+ import torch
58
+
59
+
60
+ def _config_zero_init(config):
61
+ configs_no_init = copy.deepcopy(config)
62
+ for key in configs_no_init.__dict__.keys():
63
+ if "_range" in key or "_std" in key or "initializer_factor" in key:
64
+ setattr(configs_no_init, key, 1e-10)
65
+ return configs_no_init
66
+
67
+
68
+ def ids_tensor(shape, vocab_size, rng=None):
69
+ """Creates a random int32 tensor of the shape within the vocab size."""
70
+ if rng is None:
71
+ rng = random.Random()
72
+
73
+ total_dims = 1
74
+ for dim in shape:
75
+ total_dims *= dim
76
+
77
+ values = []
78
+ for _ in range(total_dims):
79
+ values.append(rng.randint(0, vocab_size - 1))
80
+
81
+ output = np.array(values, dtype=jnp.int32).reshape(shape)
82
+
83
+ return output
84
+
85
+
86
+ def floats_tensor(shape, scale=1.0, rng=None, name=None):
87
+ """Creates a random float32 tensor"""
88
+ if rng is None:
89
+ rng = random.Random()
90
+
91
+ total_dims = 1
92
+ for dim in shape:
93
+ total_dims *= dim
94
+
95
+ values = []
96
+ for _ in range(total_dims):
97
+ values.append(rng.random() * scale)
98
+
99
+ return np.array(values, dtype=jnp.float32).reshape(shape)
100
+
101
+
102
+ def random_attention_mask(shape, rng=None):
103
+ attn_mask = ids_tensor(shape, vocab_size=2, rng=rng)
104
+ # make sure that at least one token is attended to for each batch
105
+ attn_mask[:, -1] = 1
106
+ return attn_mask
107
+
108
+
109
+ @require_flax
110
+ class FlaxModelTesterMixin:
111
+ model_tester = None
112
+ all_model_classes = ()
113
+ is_encoder_decoder = False
114
+
115
+ def _prepare_for_class(self, inputs_dict, model_class):
116
+ inputs_dict = copy.deepcopy(inputs_dict)
117
+
118
+ # hack for now until we have AutoModel classes
119
+ if "ForMultipleChoice" in model_class.__name__:
120
+ inputs_dict = {
121
+ k: jnp.broadcast_to(v[:, None], (v.shape[0], self.model_tester.num_choices, v.shape[-1]))
122
+ if isinstance(v, (jax_xla.DeviceArray, np.ndarray))
123
+ else v
124
+ for k, v in inputs_dict.items()
125
+ }
126
+
127
+ return inputs_dict
128
+
129
+ def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
130
+ diff = np.abs((a - b)).max()
131
+ self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).")
132
+
133
+ def test_model_outputs_equivalence(self):
134
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
135
+
136
+ def set_nan_tensor_to_zero(t):
137
+ t[t != t] = 0
138
+ return t
139
+
140
+ def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
141
+ tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
142
+ dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
143
+
144
+ def recursive_check(tuple_object, dict_object):
145
+ if isinstance(tuple_object, (List, Tuple)):
146
+ for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
147
+ recursive_check(tuple_iterable_value, dict_iterable_value)
148
+ elif tuple_object is None:
149
+ return
150
+ else:
151
+ self.assert_almost_equals(
152
+ set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), 1e-5
153
+ )
154
+
155
+ recursive_check(tuple_output, dict_output)
156
+
157
+ for model_class in self.all_model_classes:
158
+ model = model_class(config)
159
+
160
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
161
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
162
+ check_equivalence(model, tuple_inputs, dict_inputs)
163
+
164
+ tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
165
+ dict_inputs = self._prepare_for_class(inputs_dict, model_class)
166
+ check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
167
+
168
+ @is_pt_flax_cross_test
169
+ def test_equivalence_pt_to_flax(self):
170
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
171
+
172
+ for model_class in self.all_model_classes:
173
+ with self.subTest(model_class.__name__):
174
+ # prepare inputs
175
+ prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
176
+ pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
177
+
178
+ # load corresponding PyTorch class
179
+ pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
180
+ pt_model_class = getattr(transformers, pt_model_class_name)
181
+
182
+ pt_model = pt_model_class(config).eval()
183
+ # Flax models don't use the `use_cache` option and cache is not returned as a default.
184
+ # So we disable `use_cache` here for PyTorch model.
185
+ pt_model.config.use_cache = False
186
+ fx_model = model_class(config, dtype=jnp.float32)
187
+
188
+ fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
189
+ fx_model.params = fx_state
190
+
191
+ with torch.no_grad():
192
+ pt_outputs = pt_model(**pt_inputs).to_tuple()
193
+
194
+ fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
195
+ self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
196
+ for fx_output, pt_output in zip(fx_outputs, pt_outputs):
197
+ self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
198
+
199
+ with tempfile.TemporaryDirectory() as tmpdirname:
200
+ pt_model.save_pretrained(tmpdirname)
201
+ fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
202
+
203
+ fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
204
+ self.assertEqual(
205
+ len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
206
+ )
207
+ for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
208
+ self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
209
+
210
+ @is_pt_flax_cross_test
211
+ def test_equivalence_flax_to_pt(self):
212
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
213
+
214
+ for model_class in self.all_model_classes:
215
+ with self.subTest(model_class.__name__):
216
+ # prepare inputs
217
+ prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
218
+ pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
219
+
220
+ # load corresponding PyTorch class
221
+ pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
222
+ pt_model_class = getattr(transformers, pt_model_class_name)
223
+
224
+ pt_model = pt_model_class(config).eval()
225
+ # Flax models don't use the `use_cache` option and cache is not returned as a default.
226
+ # So we disable `use_cache` here for PyTorch model.
227
+ pt_model.config.use_cache = False
228
+ fx_model = model_class(config, dtype=jnp.float32)
229
+
230
+ pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params)
231
+
232
+ # make sure weights are tied in PyTorch
233
+ pt_model.tie_weights()
234
+
235
+ with torch.no_grad():
236
+ pt_outputs = pt_model(**pt_inputs).to_tuple()
237
+
238
+ fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
239
+ self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
240
+
241
+ for fx_output, pt_output in zip(fx_outputs, pt_outputs):
242
+ self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
243
+
244
+ with tempfile.TemporaryDirectory() as tmpdirname:
245
+ fx_model.save_pretrained(tmpdirname)
246
+ pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
247
+
248
+ with torch.no_grad():
249
+ pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
250
+
251
+ self.assertEqual(
252
+ len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
253
+ )
254
+ for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
255
+ self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
256
+
257
+ def test_from_pretrained_save_pretrained(self):
258
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
259
+
260
+ for model_class in self.all_model_classes:
261
+ with self.subTest(model_class.__name__):
262
+ model = model_class(config)
263
+
264
+ prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
265
+ outputs = model(**prepared_inputs_dict).to_tuple()
266
+
267
+ # verify that normal save_pretrained works as expected
268
+ with tempfile.TemporaryDirectory() as tmpdirname:
269
+ model.save_pretrained(tmpdirname)
270
+ model_loaded = model_class.from_pretrained(tmpdirname)
271
+
272
+ outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
273
+ for output_loaded, output in zip(outputs_loaded, outputs):
274
+ self.assert_almost_equals(output_loaded, output, 1e-3)
275
+
276
+ # verify that save_pretrained for distributed training
277
+ # with `params=params` works as expected
278
+ with tempfile.TemporaryDirectory() as tmpdirname:
279
+ model.save_pretrained(tmpdirname, params=model.params)
280
+ model_loaded = model_class.from_pretrained(tmpdirname)
281
+
282
+ outputs_loaded = model_loaded(**prepared_inputs_dict).to_tuple()
283
+ for output_loaded, output in zip(outputs_loaded, outputs):
284
+ self.assert_almost_equals(output_loaded, output, 1e-3)
285
+
286
+ def test_save_load_from_base(self):
287
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
288
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
289
+
290
+ for model_class in self.all_model_classes:
291
+ if model_class == base_class:
292
+ continue
293
+
294
+ model = base_class(config)
295
+ base_params = flatten_dict(unfreeze(model.params))
296
+
297
+ # check that all base model weights are loaded correctly
298
+ with tempfile.TemporaryDirectory() as tmpdirname:
299
+ model.save_pretrained(tmpdirname)
300
+ head_model = model_class.from_pretrained(tmpdirname)
301
+
302
+ base_param_from_head = flatten_dict(unfreeze(head_model.params[head_model.base_model_prefix]))
303
+
304
+ for key in base_param_from_head.keys():
305
+ max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
306
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
307
+
308
+ def test_save_load_to_base(self):
309
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
310
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
311
+
312
+ for model_class in self.all_model_classes:
313
+ if model_class == base_class:
314
+ continue
315
+
316
+ model = model_class(config)
317
+ base_params_from_head = flatten_dict(unfreeze(model.params[model.base_model_prefix]))
318
+
319
+ # check that all base model weights are loaded correctly
320
+ with tempfile.TemporaryDirectory() as tmpdirname:
321
+ model.save_pretrained(tmpdirname)
322
+ base_model = base_class.from_pretrained(tmpdirname)
323
+
324
+ base_params = flatten_dict(unfreeze(base_model.params))
325
+
326
+ for key in base_params_from_head.keys():
327
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
328
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
329
+
330
+ @slow
331
+ def test_jit_compilation(self):
332
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
333
+
334
+ for model_class in self.all_model_classes:
335
+ with self.subTest(model_class.__name__):
336
+ prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
337
+ model = model_class(config)
338
+
339
+ @jax.jit
340
+ def model_jitted(input_ids, attention_mask=None, **kwargs):
341
+ return model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
342
+
343
+ with self.subTest("JIT Enabled"):
344
+ jitted_outputs = model_jitted(**prepared_inputs_dict).to_tuple()
345
+
346
+ with self.subTest("JIT Disabled"):
347
+ with jax.disable_jit():
348
+ outputs = model_jitted(**prepared_inputs_dict).to_tuple()
349
+
350
+ self.assertEqual(len(outputs), len(jitted_outputs))
351
+ for jitted_output, output in zip(jitted_outputs, outputs):
352
+
353
+ self.assertEqual(jitted_output.shape, output.shape)
354
+
355
+ def test_forward_signature(self):
356
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
357
+
358
+ for model_class in self.all_model_classes:
359
+ model = model_class(config)
360
+ signature = inspect.signature(model.__call__)
361
+ # signature.parameters is an OrderedDict => so arg_names order is deterministic
362
+ arg_names = [*signature.parameters.keys()]
363
+
364
+ if model.config.is_encoder_decoder:
365
+ expected_arg_names = [
366
+ "input_ids",
367
+ "attention_mask",
368
+ "decoder_input_ids",
369
+ "decoder_attention_mask",
370
+ ]
371
+ self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
372
+ else:
373
+ expected_arg_names = ["input_ids", "attention_mask"]
374
+ self.assertListEqual(arg_names[:2], expected_arg_names)
375
+
376
+ def test_naming_convention(self):
377
+ for model_class in self.all_model_classes:
378
+ model_class_name = model_class.__name__
379
+ module_class_name = (
380
+ model_class_name[:-5] + "Module" if model_class_name[-5:] == "Model" else model_class_name + "Module"
381
+ )
382
+ bert_modeling_flax_module = __import__(model_class.__module__, fromlist=[module_class_name])
383
+ module_cls = getattr(bert_modeling_flax_module, module_class_name)
384
+
385
+ self.assertIsNotNone(module_cls)
386
+
387
+ def test_hidden_states_output(self):
388
+ def check_hidden_states_output(inputs_dict, config, model_class):
389
+ model = model_class(config)
390
+
391
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
392
+ hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
393
+
394
+ expected_num_layers = getattr(
395
+ self.model_tester, "expected_num_hidden_layers", self.model_tester.num_hidden_layers + 1
396
+ )
397
+ self.assertEqual(len(hidden_states), expected_num_layers)
398
+
399
+ if hasattr(self.model_tester, "encoder_seq_length"):
400
+ seq_length = self.model_tester.encoder_seq_length
401
+ else:
402
+ seq_length = self.model_tester.seq_length
403
+
404
+ self.assertListEqual(
405
+ list(hidden_states[0].shape[-2:]),
406
+ [seq_length, self.model_tester.hidden_size],
407
+ )
408
+
409
+ if config.is_encoder_decoder:
410
+ hidden_states = outputs.decoder_hidden_states
411
+
412
+ self.assertIsInstance(hidden_states, (list, tuple))
413
+ self.assertEqual(len(hidden_states), expected_num_layers)
414
+ seq_len = getattr(self.model_tester, "seq_length", None)
415
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
416
+
417
+ self.assertListEqual(
418
+ list(hidden_states[0].shape[-2:]),
419
+ [decoder_seq_length, self.model_tester.hidden_size],
420
+ )
421
+
422
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
423
+
424
+ for model_class in self.all_model_classes:
425
+ inputs_dict["output_hidden_states"] = True
426
+ check_hidden_states_output(inputs_dict, config, model_class)
427
+
428
+ # check that output_hidden_states also work using config
429
+ del inputs_dict["output_hidden_states"]
430
+ config.output_hidden_states = True
431
+
432
+ check_hidden_states_output(inputs_dict, config, model_class)
433
+
434
+ def test_attention_outputs(self):
435
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
436
+ config.return_dict = True
437
+
438
+ seq_length = getattr(self.model_tester, "seq_length", None)
439
+ decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_length)
440
+ encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_length)
441
+ decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
442
+ encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
443
+
444
+ for model_class in self.all_model_classes:
445
+ inputs_dict["output_attentions"] = True
446
+ inputs_dict["output_hidden_states"] = False
447
+ model = model_class(config)
448
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
449
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
450
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
451
+
452
+ # check that output_attentions also work using config
453
+ del inputs_dict["output_attentions"]
454
+ config.output_attentions = True
455
+ model = model_class(config)
456
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
457
+ attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
458
+ self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
459
+
460
+ self.assertListEqual(
461
+ list(attentions[0].shape[-3:]),
462
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
463
+ )
464
+ out_len = len(outputs)
465
+
466
+ if self.is_encoder_decoder:
467
+ correct_outlen = 5
468
+
469
+ # Question Answering model returns start_logits and end_logits
470
+ if model_class in get_values(FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING):
471
+ correct_outlen += 1 # start_logits and end_logits instead of only 1 output
472
+
473
+ self.assertEqual(out_len, correct_outlen)
474
+
475
+ # decoder attentions
476
+ decoder_attentions = outputs.decoder_attentions
477
+ self.assertIsInstance(decoder_attentions, (list, tuple))
478
+ self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
479
+ self.assertListEqual(
480
+ list(decoder_attentions[0].shape[-3:]),
481
+ [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
482
+ )
483
+
484
+ # cross attentions
485
+ cross_attentions = outputs.cross_attentions
486
+ self.assertIsInstance(cross_attentions, (list, tuple))
487
+ self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
488
+ self.assertListEqual(
489
+ list(cross_attentions[0].shape[-3:]),
490
+ [
491
+ self.model_tester.num_attention_heads,
492
+ decoder_seq_length,
493
+ encoder_key_length,
494
+ ],
495
+ )
496
+
497
+ # Check attention is always last and order is fine
498
+ inputs_dict["output_attentions"] = True
499
+ inputs_dict["output_hidden_states"] = True
500
+ model = model_class(config)
501
+ outputs = model(**self._prepare_for_class(inputs_dict, model_class))
502
+
503
+ if hasattr(self.model_tester, "num_hidden_states_types"):
504
+ added_hidden_states = self.model_tester.num_hidden_states_types
505
+ elif self.is_encoder_decoder:
506
+ added_hidden_states = 2
507
+ else:
508
+ added_hidden_states = 1
509
+ self.assertEqual(out_len + added_hidden_states, len(outputs))
510
+
511
+ self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
512
+ self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
513
+
514
+ self.assertListEqual(
515
+ list(self_attentions[0].shape[-3:]),
516
+ [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
517
+ )
518
+
519
+
520
+ @require_flax
521
+ @is_staging_test
522
+ class FlaxModelPushToHubTester(unittest.TestCase):
523
+ @classmethod
524
+ def setUpClass(cls):
525
+ cls._api = HfApi(endpoint=ENDPOINT_STAGING)
526
+ cls._token = cls._api.login(username=USER, password=PASS)
527
+
528
+ @classmethod
529
+ def tearDownClass(cls):
530
+ try:
531
+ cls._api.delete_repo(token=cls._token, name="test-model-flax")
532
+ except HTTPError:
533
+ pass
534
+
535
+ try:
536
+ cls._api.delete_repo(token=cls._token, name="test-model-flax-org", organization="valid_org")
537
+ except HTTPError:
538
+ pass
539
+
540
+ def test_push_to_hub(self):
541
+ config = BertConfig(
542
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
543
+ )
544
+ model = FlaxBertModel(config)
545
+ with tempfile.TemporaryDirectory() as tmp_dir:
546
+ model.save_pretrained(
547
+ os.path.join(tmp_dir, "test-model-flax"), push_to_hub=True, use_auth_token=self._token
548
+ )
549
+
550
+ new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
551
+
552
+ base_params = flatten_dict(unfreeze(model.params))
553
+ new_params = flatten_dict(unfreeze(new_model.params))
554
+
555
+ for key in base_params.keys():
556
+ max_diff = (base_params[key] - new_params[key]).sum().item()
557
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
558
+
559
+ def test_push_to_hub_in_organization(self):
560
+ config = BertConfig(
561
+ vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
562
+ )
563
+ model = FlaxBertModel(config)
564
+ with tempfile.TemporaryDirectory() as tmp_dir:
565
+ model.save_pretrained(
566
+ os.path.join(tmp_dir, "test-model-flax-org"),
567
+ push_to_hub=True,
568
+ use_auth_token=self._token,
569
+ organization="valid_org",
570
+ )
571
+
572
+ new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
573
+
574
+ base_params = flatten_dict(unfreeze(model.params))
575
+ new_params = flatten_dict(unfreeze(new_model.params))
576
+
577
+ for key in base_params.keys():
578
+ max_diff = (base_params[key] - new_params[key]).sum().item()
579
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
tests/test_t5_vae.py ADDED
@@ -0,0 +1,491 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import unittest
3
+
4
+ import numpy as np
5
+
6
+ from transformers import is_flax_available
7
+ from transformers.models.t5.modeling_flax_t5 import shift_tokens_right
8
+ from transformers.testing_utils import require_flax
9
+
10
+ from tests.test_configuration_common import ConfigTester
11
+ from tests.test_generation_flax_utils import FlaxGenerationTesterMixin
12
+ from tests.test_modeling_flax_common import FlaxModelTesterMixin, ids_tensor
13
+
14
+
15
+ if is_flax_available():
16
+ import os
17
+
18
+ # The slow tests are often failing with OOM error on GPU
19
+ # This makes JAX allocate exactly what is needed on demand, and deallocate memory that is no longer needed
20
+ # but will be slower as stated here https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html
21
+ os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
22
+
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import unfreeze
26
+ from flax.traverse_util import flatten_dict
27
+ from transformers import FLAX_MODEL_MAPPING
28
+ from model.t5_vae import FlaxT5VaeForAutoencoding, T5VaeConfig
29
+
30
+
31
+ class FlaxVaeModelTester:
32
+ def __init__(
33
+ self,
34
+ parent,
35
+ vocab_size=99,
36
+ batch_size=13,
37
+ seq_length=7,
38
+ latent_token_size=10,
39
+ n_latent_tokens=3,
40
+ # For common tests
41
+ is_training=True,
42
+ use_attention_mask=True,
43
+ use_labels=True,
44
+ hidden_size=32,
45
+ num_hidden_layers=5,
46
+ num_attention_heads=4,
47
+ d_ff=37,
48
+ relative_attention_num_buckets=8,
49
+ dropout_rate=0.1,
50
+ initializer_factor=0.002,
51
+ eos_token_id=1,
52
+ pad_token_id=0,
53
+ decoder_start_token_id=0,
54
+ scope=None,
55
+ decoder_layers=None,
56
+ ):
57
+
58
+ self.parent = parent
59
+ self.batch_size = batch_size
60
+ self.latent_token_size = latent_token_size
61
+ self.n_latent_tokens = n_latent_tokens
62
+ # For common tests
63
+ self.seq_length = seq_length
64
+ self.is_training = is_training
65
+ self.use_attention_mask = use_attention_mask
66
+ self.use_labels = use_labels
67
+ self.vocab_size = vocab_size
68
+ self.hidden_size = hidden_size
69
+ self.num_hidden_layers = num_hidden_layers
70
+ self.num_attention_heads = num_attention_heads
71
+ self.d_ff = d_ff
72
+ self.relative_attention_num_buckets = relative_attention_num_buckets
73
+ self.dropout_rate = dropout_rate
74
+ self.initializer_factor = initializer_factor
75
+ self.eos_token_id = eos_token_id
76
+ self.pad_token_id = pad_token_id
77
+ self.decoder_start_token_id = decoder_start_token_id
78
+ self.scope = None
79
+ self.decoder_layers = decoder_layers
80
+
81
+ def prepare_config_and_inputs(self):
82
+ input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
83
+ decoder_input_ids = shift_tokens_right(input_ids, self.pad_token_id, self.pad_token_id)
84
+
85
+ attention_mask = None
86
+ decoder_attention_mask = None
87
+ if self.use_attention_mask:
88
+ attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
89
+ decoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
90
+
91
+ config = T5VaeConfig(
92
+ latent_token_size=self.latent_token_size,
93
+ n_latent_tokens=self.n_latent_tokens,
94
+ vocab_size=self.vocab_size,
95
+ d_model=self.hidden_size,
96
+ block_size=self.seq_length,
97
+ d_ff=self.d_ff,
98
+ d_kv=self.hidden_size // self.num_attention_heads,
99
+ num_layers=self.num_hidden_layers,
100
+ num_decoder_layers=self.decoder_layers,
101
+ num_heads=self.num_attention_heads,
102
+ relative_attention_num_buckets=self.relative_attention_num_buckets,
103
+ dropout_rate=self.dropout_rate,
104
+ initializer_factor=self.initializer_factor,
105
+ eos_token_id=self.eos_token_id,
106
+ bos_token_id=self.pad_token_id,
107
+ pad_token_id=self.pad_token_id,
108
+ decoder_start_token_id=self.decoder_start_token_id,
109
+ )
110
+
111
+ return (
112
+ config,
113
+ input_ids,
114
+ decoder_input_ids,
115
+ attention_mask,
116
+ decoder_attention_mask,
117
+ )
118
+
119
+ def create_and_check_model(
120
+ self,
121
+ config,
122
+ input_ids,
123
+ decoder_input_ids,
124
+ attention_mask,
125
+ decoder_attention_mask,
126
+ ):
127
+ model = FlaxT5VaeForAutoencoding(config=config)
128
+ result = model(
129
+ input_ids=input_ids,
130
+ decoder_input_ids=decoder_input_ids,
131
+ attention_mask=attention_mask,
132
+ decoder_attention_mask=decoder_attention_mask,
133
+ )
134
+ result = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
135
+ decoder_output = result.last_hidden_state
136
+ encoder_output = result.encoder_last_hidden_state
137
+
138
+ self.parent.assertEqual(encoder_output.shape, (self.batch_size, self.seq_length, self.hidden_size))
139
+ self.parent.assertEqual(decoder_output.shape, (self.batch_size, self.decoder_seq_length+1, self.hidden_size))
140
+
141
+ def check_use_cache_forward_with_attn_mask(
142
+ self,
143
+ model_class_name,
144
+ config,
145
+ input_ids,
146
+ decoder_input_ids,
147
+ attention_mask,
148
+ decoder_attention_mask,
149
+ ):
150
+ max_decoder_length = 20
151
+ model = model_class_name(config)
152
+
153
+ latent_codes = model.encode(input_ids)
154
+
155
+ # prevent fully zero'd out attention mask
156
+ decoder_attention_mask = jnp.ones_like(decoder_attention_mask)
157
+
158
+ decoder_attention_mask_cache = jnp.concatenate(
159
+ [
160
+ decoder_attention_mask,
161
+ jnp.zeros((decoder_attention_mask.shape[0], max_decoder_length - decoder_attention_mask.shape[1])),
162
+ ],
163
+ axis=-1,
164
+ )
165
+
166
+ past_key_values = model.init_cache(decoder_input_ids.shape[0], max_decoder_length, encoder_outputs)
167
+
168
+ outputs_cache = model.decode(
169
+ decoder_input_ids[:, :-1],
170
+ latent_codes,
171
+ decoder_attention_mask=decoder_attention_mask_cache,
172
+ past_key_values=past_key_values,
173
+ )
174
+ outputs_cache_next = model.decode(
175
+ decoder_input_ids[:, -1:],
176
+ latent_codes,
177
+ past_key_values=outputs_cache.past_key_values,
178
+ decoder_attention_mask=decoder_attention_mask_cache,
179
+ )
180
+
181
+ outputs = model.decode(decoder_input_ids, latent_codes, decoder_attention_mask=decoder_attention_mask)
182
+
183
+ diff = np.max(np.abs((outputs_cache_next[0][:, -1, :5] - outputs[0][:, -1, :5])))
184
+ self.parent.assertTrue(diff < 1e-3, msg=f"Max diff is {diff}")
185
+
186
+ def prepare_config_and_inputs_for_common(self):
187
+ config_and_inputs = self.prepare_config_and_inputs()
188
+ (
189
+ config,
190
+ input_ids,
191
+ decoder_input_ids,
192
+ attention_mask,
193
+ decoder_attention_mask,
194
+ ) = config_and_inputs
195
+
196
+ inputs_dict = {
197
+ "input_ids": input_ids,
198
+ "attention_mask": attention_mask,
199
+ "decoder_input_ids": decoder_input_ids,
200
+ "decoder_attention_mask": decoder_attention_mask,
201
+ }
202
+ return config, inputs_dict
203
+
204
+
205
+ @require_flax
206
+ class FlaxT5VaeModelTest(FlaxModelTesterMixin, FlaxGenerationTesterMixin, unittest.TestCase):
207
+
208
+ all_model_classes = (FlaxT5VaeForAutoencoding,) if is_flax_available() else ()
209
+ is_encoder_decoder = True
210
+
211
+ def setUp(self):
212
+ self.model_tester = FlaxVaeModelTester(self)
213
+ self.config_tester = ConfigTester(self, config_class=T5VaeConfig, d_model=37)
214
+
215
+ def test_config(self):
216
+ self.config_tester.run_common_tests()
217
+
218
+ def test_model(self):
219
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
220
+ self.model_tester.create_and_check_model(*config_and_inputs)
221
+
222
+ def test_model_v1_1(self):
223
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
224
+ # check that gated gelu feed forward and different word embeddings work
225
+ config = config_and_inputs[0]
226
+ config.tie_word_embeddings = False
227
+ config.feed_forward_proj = "gated-gelu"
228
+ self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
229
+
230
+ def test_use_cache_forward_with_attn_mask(self):
231
+ config_and_inputs = self.model_tester.prepare_config_and_inputs()
232
+ for model_class in self.all_model_classes:
233
+ self.model_tester.check_use_cache_forward_with_attn_mask(model_class, *config_and_inputs)
234
+
235
+ def test_encode(self):
236
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
237
+
238
+ for model_class in self.all_model_classes:
239
+ with self.subTest(model_class.__name__):
240
+ prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
241
+ model = model_class(config)
242
+
243
+ @jax.jit
244
+ def encode_jitted(input_ids, attention_mask=None, **kwargs):
245
+ return model.encode(input_ids=input_ids, attention_mask=attention_mask)
246
+
247
+ with self.subTest("JIT Enabled"):
248
+ jitted_outputs = encode_jitted(**prepared_inputs_dict)
249
+
250
+ with self.subTest("JIT Disabled"):
251
+ with jax.disable_jit():
252
+ outputs = encode_jitted(**prepared_inputs_dict)
253
+
254
+ self.assertEqual(outputs.shape, (inputs_dict['input_ids'].shape[0], config.n_latent_tokens, config.latent_token_size))
255
+ self.assertEqual(jitted_outputs.shape, (inputs_dict['input_ids'].shape[0], config.n_latent_tokens, config.latent_token_size))
256
+
257
+ self.assertEqual(len(outputs), len(jitted_outputs))
258
+ for jitted_output, output in zip(jitted_outputs, outputs):
259
+ self.assertEqual(jitted_output.shape, output.shape)
260
+
261
+ def test_decode(self):
262
+ config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
263
+
264
+ for model_class in self.all_model_classes:
265
+ with self.subTest(model_class.__name__):
266
+ model = model_class(config)
267
+ latent_codes = model.encode(inputs_dict["input_ids"], inputs_dict["attention_mask"])
268
+
269
+ prepared_inputs_dict = {
270
+ "decoder_input_ids": inputs_dict["decoder_input_ids"],
271
+ "decoder_attention_mask": inputs_dict["decoder_attention_mask"],
272
+ "latent_codes": latent_codes,
273
+ }
274
+
275
+ @jax.jit
276
+ def decode_jitted(decoder_input_ids, decoder_attention_mask, latent_codes):
277
+ return model.decode(
278
+ decoder_input_ids=decoder_input_ids,
279
+ latent_codes=latent_codes,
280
+ decoder_attention_mask=decoder_attention_mask,
281
+ )
282
+
283
+ with self.subTest("JIT Enabled"):
284
+ jitted_outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
285
+
286
+ with self.subTest("JIT Disabled"):
287
+ with jax.disable_jit():
288
+ outputs = decode_jitted(**prepared_inputs_dict).to_tuple()
289
+
290
+ self.assertEqual(len(outputs), len(jitted_outputs))
291
+ for jitted_output, output in zip(jitted_outputs, outputs):
292
+ self.assertEqual(jitted_output.shape, output.shape)
293
+
294
+ # overwrite since special base model prefix is used
295
+ def test_save_load_from_base(self):
296
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
297
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
298
+
299
+ for model_class in self.all_model_classes:
300
+ if model_class == base_class:
301
+ continue
302
+
303
+ model = base_class(config)
304
+ base_params = flatten_dict(unfreeze(model.params))
305
+
306
+ # check that all base model weights are loaded correctly
307
+ with tempfile.TemporaryDirectory() as tmpdirname:
308
+ model.save_pretrained(tmpdirname)
309
+ head_model = model_class.from_pretrained(tmpdirname)
310
+
311
+ base_param_from_head = flatten_dict(unfreeze(head_model.params))
312
+
313
+ for key in base_param_from_head.keys():
314
+ max_diff = (base_params[key] - base_param_from_head[key]).sum().item()
315
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
316
+
317
+ # overwrite since special base model prefix is used
318
+ def test_save_load_to_base(self):
319
+ config, _ = self.model_tester.prepare_config_and_inputs_for_common()
320
+ base_class = FLAX_MODEL_MAPPING[config.__class__]
321
+
322
+ for model_class in self.all_model_classes:
323
+ if model_class == base_class:
324
+ continue
325
+
326
+ model = model_class(config)
327
+ base_params_from_head = flatten_dict(unfreeze(model.params))
328
+
329
+ # check that all base model weights are loaded correctly
330
+ with tempfile.TemporaryDirectory() as tmpdirname:
331
+ model.save_pretrained(tmpdirname)
332
+ base_model = base_class.from_pretrained(tmpdirname)
333
+
334
+ base_params = flatten_dict(unfreeze(base_model.params))
335
+
336
+ for key in base_params_from_head.keys():
337
+ max_diff = (base_params[key] - base_params_from_head[key]).sum().item()
338
+ self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
339
+
340
+
341
+ '''
342
+ # Not using for now.
343
+
344
+ @require_sentencepiece
345
+ @require_tokenizers
346
+ @require_flax
347
+ class FlaxT5ModelIntegrationTests(unittest.TestCase):
348
+ @slow
349
+ def test_small_integration_test(self):
350
+ """
351
+ For comparision run:
352
+ >>> import t5 # pip install t5==0.7.1
353
+ >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
354
+
355
+ >>> path_to_mtf_small_t5_checkpoint = '<fill_in>'
356
+ >>> path_to_mtf_small_spm_model_path = '<fill_in>'
357
+ >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
358
+ >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
359
+ >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
360
+ """
361
+
362
+ model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")
363
+ tokenizer = T5Tokenizer.from_pretrained("t5-small")
364
+
365
+ input_ids = tokenizer("Hello there", return_tensors="np").input_ids
366
+ labels = tokenizer("Hi I am", return_tensors="np").input_ids
367
+
368
+ decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id)
369
+
370
+ logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits
371
+
372
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
373
+ mtf_score = -(labels.shape[-1] * loss.item())
374
+
375
+ EXPECTED_SCORE = -19.0845
376
+ self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
377
+
378
+ @slow
379
+ def test_small_v1_1_integration_test(self):
380
+ """
381
+ For comparision run:
382
+ >>> import t5 # pip install t5==0.7.1
383
+ >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
384
+
385
+ >>> path_to_mtf_small_t5_v1_1_checkpoint = '<fill_in>'
386
+ >>> path_to_mtf_small_spm_model_path = '<fill_in>'
387
+ >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_v1_1_checkpoint, batch_size=1, tpu=None)
388
+ >>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
389
+ >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
390
+ """
391
+
392
+ model = FlaxT5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small")
393
+ tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
394
+
395
+ input_ids = tokenizer("Hello there", return_tensors="np").input_ids
396
+ labels = tokenizer("Hi I am", return_tensors="np").input_ids
397
+
398
+ decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id)
399
+
400
+ logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits
401
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
402
+
403
+ mtf_score = -(labels.shape[-1] * loss.item())
404
+
405
+ EXPECTED_SCORE = -59.0293
406
+ self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
407
+
408
+ @slow
409
+ def test_small_byt5_integration_test(self):
410
+ """
411
+ For comparision run:
412
+ >>> import t5 # pip install t5==0.9.1
413
+
414
+ >>> path_to_byt5_small_checkpoint = '<fill_in>'
415
+ >>> t5_model = t5.models.MtfModel(model_dir=path_to_tf_checkpoint, batch_size=1, tpu=None)
416
+ >>> vocab = t5.data.ByteVocabulary()
417
+ >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
418
+ """
419
+
420
+ model = FlaxT5ForConditionalGeneration.from_pretrained("google/byt5-small")
421
+ tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
422
+
423
+ input_ids = tokenizer("Hello there", return_tensors="np").input_ids
424
+ labels = tokenizer("Hi I am", return_tensors="np").input_ids
425
+
426
+ decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id)
427
+
428
+ logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits
429
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()
430
+
431
+ mtf_score = -(labels.shape[-1] * loss.item())
432
+
433
+ EXPECTED_SCORE = -60.7397
434
+ self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
435
+
436
+ @slow
437
+ def test_small_generation(self):
438
+ model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")
439
+ model.config.max_length = 8
440
+ model.config.num_beams = 1
441
+ model.config.do_sample = False
442
+ tokenizer = T5Tokenizer.from_pretrained("t5-small")
443
+
444
+ input_ids = tokenizer("summarize: Hello there", return_tensors="np").input_ids
445
+
446
+ sequences = model.generate(input_ids).sequences
447
+
448
+ output_str = tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
449
+ self.assertTrue(output_str == "Hello there!")
450
+
451
+ @slow
452
+ def test_summarization(self):
453
+ model = FlaxT5ForConditionalGeneration.from_pretrained("t5-base")
454
+ tok = T5Tokenizer.from_pretrained("t5-base")
455
+
456
+ FRANCE_ARTICLE = 'Marseille, France (CNN)The French prosecutor leading an investigation into the crash of Germanwings Flight 9525 insisted Wednesday that he was not aware of any video footage from on board the plane. Marseille prosecutor Brice Robin told CNN that "so far no videos were used in the crash investigation." He added, "A person who has such a video needs to immediately give it to the investigators." Robin\'s comments follow claims by two magazines, German daily Bild and French Paris Match, of a cell phone video showing the harrowing final seconds from on board Germanwings Flight 9525 as it crashed into the French Alps. All 150 on board were killed. Paris Match and Bild reported that the video was recovered from a phone at the wreckage site. The two publications described the supposed video, but did not post it on their websites. The publications said that they watched the video, which was found by a source close to the investigation. "One can hear cries of \'My God\' in several languages," Paris Match reported. "Metallic banging can also be heard more than three times, perhaps of the pilot trying to open the cockpit door with a heavy object. Towards the end, after a heavy shake, stronger than the others, the screaming intensifies. Then nothing." "It is a very disturbing scene," said Julian Reichelt, editor-in-chief of Bild online. An official with France\'s accident investigation agency, the BEA, said the agency is not aware of any such video. Lt. Col. Jean-Marc Menichini, a French Gendarmerie spokesman in charge of communications on rescue efforts around the Germanwings crash site, told CNN that the reports were "completely wrong" and "unwarranted." Cell phones have been collected at the site, he said, but that they "hadn\'t been exploited yet." Menichini said he believed the cell phones would need to be sent to the Criminal Research Institute in Rosny sous-Bois, near Paris, in order to be analyzed by specialized technicians working hand-in-hand with investigators. But none of the cell phones found so far have been sent to the institute, Menichini said. Asked whether staff involved in the search could have leaked a memory card to the media, Menichini answered with a categorical "no." Reichelt told "Erin Burnett: Outfront" that he had watched the video and stood by the report, saying Bild and Paris Match are "very confident" that the clip is real. He noted that investigators only revealed they\'d recovered cell phones from the crash site after Bild and Paris Match published their reports. "That is something we did not know before. ... Overall we can say many things of the investigation weren\'t revealed by the investigation at the beginning," he said. What was mental state of Germanwings co-pilot? German airline Lufthansa confirmed Tuesday that co-pilot Andreas Lubitz had battled depression years before he took the controls of Germanwings Flight 9525, which he\'s accused of deliberately crashing last week in the French Alps. Lubitz told his Lufthansa flight training school in 2009 that he had a "previous episode of severe depression," the airline said Tuesday. Email correspondence between Lubitz and the school discovered in an internal investigation, Lufthansa said, included medical documents he submitted in connection with resuming his flight training. The announcement indicates that Lufthansa, the parent company of Germanwings, knew of Lubitz\'s battle with depression, allowed him to continue training and ultimately put him in the cockpit. Lufthansa, whose CEO Carsten Spohr previously said Lubitz was 100% fit to fly, described its statement Tuesday as a "swift and seamless clarification" and said it was sharing the information and documents -- including training and medical records -- with public prosecutors. Spohr traveled to the crash site Wednesday, where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside. He saw the crisis center set up in Seyne-les-Alpes, laid a wreath in the village of Le Vernet, closer to the crash site, where grieving families have left flowers at a simple stone memorial. Menichini told CNN late Tuesday that no visible human remains were left at the site but recovery teams would keep searching. French President Francois Hollande, speaking Tuesday, said that it should be possible to identify all the victims using DNA analysis by the end of the week, sooner than authorities had previously suggested. In the meantime, the recovery of the victims\' personal belongings will start Wednesday, Menichini said. Among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board. Check out the latest from our correspondents . The details about Lubitz\'s correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and Lubitz\'s possible motive for downing the jet. A Lufthansa spokesperson told CNN on Tuesday that Lubitz had a valid medical certificate, had passed all his examinations and "held all the licenses required." Earlier, a spokesman for the prosecutor\'s office in Dusseldorf, Christoph Kumpa, said medical records reveal Lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot\'s license. Kumpa emphasized there\'s no evidence suggesting Lubitz was suicidal or acting aggressively before the crash. Investigators are looking into whether Lubitz feared his medical condition would cause him to lose his pilot\'s license, a European government official briefed on the investigation told CNN on Tuesday. While flying was "a big part of his life," the source said, it\'s only one theory being considered. Another source, a law enforcement official briefed on the investigation, also told CNN that authorities believe the primary motive for Lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems. Lubitz\'s girlfriend told investigators he had seen an eye doctor and a neuropsychologist, both of whom deemed him unfit to work recently and concluded he had psychological issues, the European government official said. But no matter what details emerge about his previous mental health struggles, there\'s more to the story, said Brian Russell, a forensic psychologist. "Psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they weren\'t going to keep doing their job and they\'re upset about that and so they\'re suicidal," he said. "But there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person\'s problems." Germanwings crash compensation: What we know . Who was the captain of Germanwings Flight 9525? CNN\'s Margot Haddad reported from Marseille and Pamela Brown from Dusseldorf, while Laura Smith-Spark wrote from London. CNN\'s Frederik Pleitgen, Pamela Boykoff, Antonia Mortensen, Sandrine Amiel and Anna-Maja Rappard contributed to this report.' # @noqa
457
+ SHORTER_ARTICLE = '(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC\'s founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians\' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki, speaking at Wednesday\'s ceremony, said it was a move toward greater justice. "As Palestine formally becomes a State Party to the Rome Statute today, the world is also a step closer to ending a long era of impunity and injustice," he said, according to an ICC news release. "Indeed, today brings us closer to our shared goals of justice and peace." Judge Kuniko Ozaki, a vice president of the ICC, said acceding to the treaty was just the first step for the Palestinians. "As the Rome Statute today enters into force for the State of Palestine, Palestine acquires all the rights as well as responsibilities that come with being a State Party to the Statute. These are substantive commitments, which cannot be taken lightly," she said. Rights group Human Rights Watch welcomed the development. "Governments seeking to penalize Palestine for joining the ICC should immediately end their pressure, and countries that support universal acceptance of the court\'s treaty should speak out to welcome its membership," said Balkees Jarrah, international justice counsel for the group. "What\'s objectionable is the attempts to undermine international justice, not Palestine\'s decision to join a treaty to which over 100 countries around the world are members." In January, when the preliminary ICC examination was opened, Israeli Prime Minister Benjamin Netanyahu described it as an outrage, saying the court was overstepping its boundaries. The United States also said it "strongly" disagreed with the court\'s decision. "As we have said repeatedly, we do not believe that Palestine is a state and therefore we do not believe that it is eligible to join the ICC," the State Department said in a statement. It urged the warring sides to resolve their differences through direct negotiations. "We will continue to oppose actions against Israel at the ICC as counterproductive to the cause of peace," it said. But the ICC begs to differ with the definition of a state for its purposes and refers to the territories as "Palestine." While a preliminary examination is not a formal investigation, it allows the court to review evidence and determine whether to investigate suspects on both sides. Prosecutor Fatou Bensouda said her office would "conduct its analysis in full independence and impartiality." The war between Israel and Hamas militants in Gaza last summer left more than 2,000 people dead. The inquiry will include alleged war crimes committed since June. The International Criminal Court was set up in 2002 to prosecute genocide, crimes against humanity and war crimes. CNN\'s Vasco Cotovio, Kareem Khadder and Faith Karimi contributed to this report.'
458
+ IRAN_ARTICLE = "(CNN)The United States and its negotiating partners reached a very strong framework agreement with Iran in Lausanne, Switzerland, on Thursday that limits Iran's nuclear program in such a way as to effectively block it from building a nuclear weapon. Expect pushback anyway, if the recent past is any harbinger. Just last month, in an attempt to head off such an agreement, House Speaker John Boehner invited Israeli Prime Minister Benjamin Netanyahu to preemptively blast it before Congress, and 47 senators sent a letter to the Iranian leadership warning them away from a deal. The debate that has already begun since the announcement of the new framework will likely result in more heat than light. It will not be helped by the gathering swirl of dubious assumptions and doubtful assertions. Let us address some of these: . The most misleading assertion, despite universal rejection by experts, is that the negotiations' objective at the outset was the total elimination of any nuclear program in Iran. That is the position of Netanyahu and his acolytes in the U.S. Congress. But that is not and never was the objective. If it had been, there would have been no Iranian team at the negotiating table. Rather, the objective has always been to structure an agreement or series of agreements so that Iran could not covertly develop a nuclear arsenal before the United States and its allies could respond. The new framework has exceeded expectations in achieving that goal. It would reduce Iran's low-enriched uranium stockpile, cut by two-thirds its number of installed centrifuges and implement a rigorous inspection regime. Another dubious assumption of opponents is that the Iranian nuclear program is a covert weapons program. Despite sharp accusations by some in the United States and its allies, Iran denies having such a program, and U.S. intelligence contends that Iran has not yet made the decision to build a nuclear weapon. Iran's continued cooperation with International Atomic Energy Agency inspections is further evidence on this point, and we'll know even more about Iran's program in the coming months and years because of the deal. In fact, the inspections provisions that are part of this agreement are designed to protect against any covert action by the Iranians. What's more, the rhetoric of some members of Congress has implied that the negotiations have been between only the United States and Iran (i.e., the 47 senators' letter warning that a deal might be killed by Congress or a future president). This of course is not the case. The talks were between Iran and the five permanent members of the U.N. Security Council (United States, United Kingdom, France, China and Russia) plus Germany, dubbed the P5+1. While the United States has played a leading role in the effort, it negotiated the terms alongside its partners. If the agreement reached by the P5+1 is rejected by Congress, it could result in an unraveling of the sanctions on Iran and threaten NATO cohesion in other areas. Another questionable assertion is that this agreement contains a sunset clause, after which Iran will be free to do as it pleases. Again, this is not the case. Some of the restrictions on Iran's nuclear activities, such as uranium enrichment, will be eased or eliminated over time, as long as 15 years. But most importantly, the framework agreement includes Iran's ratification of the Additional Protocol, which allows IAEA inspectors expanded access to nuclear sites both declared and nondeclared. This provision will be permanent. It does not sunset. Thus, going forward, if Iran decides to enrich uranium to weapons-grade levels, monitors will be able to detect such a move in a matter of days and alert the U.N. Security Council. Many in Congress have said that the agreement should be a formal treaty requiring the Senate to \"advise and consent.\" But the issue is not suited for a treaty. Treaties impose equivalent obligations on all signatories. For example, the New START treaty limits Russia and the United States to 1,550 deployed strategic warheads. But any agreement with Iran will not be so balanced. The restrictions and obligations in the final framework agreement will be imposed almost exclusively on Iran. The P5+1 are obligated only to ease and eventually remove most but not all economic sanctions, which were imposed as leverage to gain this final deal. Finally some insist that any agreement must address Iranian missile programs, human rights violations or support for Hamas or Hezbollah. As important as these issues are, and they must indeed be addressed, they are unrelated to the most important aim of a nuclear deal: preventing a nuclear Iran. To include them in the negotiations would be a poison pill. This agreement should be judged on its merits and on how it affects the security of our negotiating partners and allies, including Israel. Those judgments should be fact-based, not based on questionable assertions or dubious assumptions."
459
+ ARTICLE_SUBWAY = 'New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County, New York. A year later, she got married again in Westchester County, but to a different man and without divorcing her first husband. Only 18 days after that marriage, she got hitched yet again. Then, Barrientos declared "I do" five more times, sometimes only within two weeks of each other. In 2010, she married once more, this time in the Bronx. In an application for a marriage license, she stated it was her "first and only" marriage. Barrientos, now 39, is facing two criminal counts of "offering a false instrument for filing in the first degree," referring to her false statements on the 2010 marriage license application, according to court documents. Prosecutors said the marriages were part of an immigration scam. On Friday, she pleaded not guilty at State Supreme Court in the Bronx, according to her attorney, Christopher Wright, who declined to comment further. After leaving court, Barrientos was arrested and charged with theft of service and criminal trespass for allegedly sneaking into the New York subway through an emergency exit, said Detective Annette Markowski, a police spokeswoman. In total, Barrientos has been married 10 times, with nine of her marriages occurring between 1999 and 2002. All occurred either in Westchester County, Long Island, New Jersey or the Bronx. She is believed to still be married to four men, and at one time, she was married to eight men at once, prosecutors say. Prosecutors said the immigration scam involved some of her husbands, who filed for permanent residence status shortly after the marriages. Any divorces happened only after such filings were approved. It was unclear whether any of the men will be prosecuted. The case was referred to the Bronx District Attorney\'s Office by Immigration and Customs Enforcement and the Department of Homeland Security\'s Investigation Division. Seven of the men are from so-called "red-flagged" countries, including Egypt, Turkey, Georgia, Pakistan and Mali. Her eighth husband, Rashid Rajput, was deported in 2006 to his native Pakistan after an investigation by the Joint Terrorism Task Force. If convicted, Barrientos faces up to four years in prison. Her next court appearance is scheduled for May 18.'
460
+
461
+ expected_summaries = [
462
+ 'prosecutor: "so far no videos were used in the crash investigation" two magazines claim to have found a cell phone video of the final seconds . "one can hear cries of \'My God\' in several languages," the magazine says . all 150 on board the germanwings flight were killed .',
463
+ "the Palestinians become the 123rd member of the international criminal court . the accession was marked by a ceremony at the Hague, where the court is based . as members of the court, Palestinians may be subject to counter-charges as well .",
464
+ "the u.s. and its negotiating partners reached a very strong framework agreement with Iran . aaron miller: the debate that has already begun since the announcement of the new framework will likely result in more heat than light . he says the new framework would reduce Iran's low-enriched uranium stockpile and cut centrifuges . miller: if it had been, there would have been no Iranian team at the table .",
465
+ 'prosecutors say the marriages were part of an immigration scam . if convicted, barrientos faces two criminal counts of "offering a false instrument for filing in the first degree" she has been married 10 times, with nine of her marriages occurring between 1999 and 2002 .',
466
+ ]
467
+
468
+ dct = tok(
469
+ ["summarize: " + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
470
+ padding="max_length",
471
+ truncation=True,
472
+ return_tensors="np",
473
+ )
474
+ self.assertEqual(512, dct["input_ids"].shape[1])
475
+
476
+ hypotheses_batch = model.generate(
477
+ **dct,
478
+ num_beams=4,
479
+ length_penalty=2.0,
480
+ max_length=142,
481
+ min_length=56,
482
+ do_sample=False,
483
+ early_stopping=True,
484
+ ).sequences
485
+
486
+ decoded = tok.batch_decode(hypotheses_batch, skip_special_tokens=True, clean_up_tokenization_spaces=False)
487
+ self.assertListEqual(
488
+ expected_summaries,
489
+ decoded,
490
+ )
491
+ '''
train.py CHANGED
@@ -47,8 +47,8 @@ from transformers import (
47
  )
48
  from transformers.testing_utils import CaptureLogger
49
 
50
- from model.t5_vae import FlaxT5_VAE_ForAutoencoding
51
- from model.config import T5_VAE_Config
52
 
53
 
54
  logger = logging.getLogger(__name__)
@@ -316,15 +316,15 @@ def main():
316
  # download model & vocab.
317
 
318
  if model_args.config_path:
319
- config = T5_VAE_Config.from_pretrained(
320
  model_args.config_path, cache_dir=model_args.cache_dir
321
  )
322
  elif model_args.model_name_or_path:
323
- config = T5_VAE_Config.from_pretrained(
324
  model_args.model_name_or_path, cache_dir=model_args.cache_dir
325
  )
326
  else:
327
- config = T5_VAE_Config(**model_args.__dict__)
328
  logger.warning("You are instantiating a new config instance from scratch.")
329
 
330
  if model_args.tokenizer_name:
@@ -346,7 +346,7 @@ def main():
346
  )
347
 
348
  if model_args.model_name_or_path:
349
- model = FlaxT5_VAE_ForAutoencoding.from_pretrained(
350
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
351
  )
352
  # TODO assert token embedding size == len(tokenizer)
@@ -355,7 +355,7 @@ def main():
355
  config.t5.vocab_size = vocab_size
356
  config.vocab_size = vocab_size
357
  logger.info("Training new model from scratch.")
358
- model = FlaxT5_VAE_ForAutoencoding(
359
  config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
360
  )
361
 
@@ -402,7 +402,7 @@ def main():
402
  # Limits each input sequence to size block_size.
403
  pad_token_id = tokenizer.pad_token_id
404
 
405
- def limit_length(examples):
406
  examples["labels"] = examples["input_ids"].copy()
407
 
408
  for i, input_ids in enumerate(examples["input_ids"]):
@@ -425,7 +425,7 @@ def main():
425
  # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
426
 
427
  lm_datasets = tokenized_datasets.map(
428
- limit_length,
429
  batched=True,
430
  num_proc=data_args.preprocessing_num_workers,
431
  load_from_cache_file=not data_args.overwrite_cache,
@@ -536,22 +536,23 @@ def main():
536
  true_samples = jax.random.normal(rng, latent_codes.shape())
537
  return compute_mmd(true_samples, latent_codes)
538
 
539
- def loss_fn(logits, labels, latent_codes, rng: jax.random.PRNGKey):
540
  shift_logits = logits[..., :-1, :]
541
  shift_labels = labels[..., 1:]
542
  loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
543
 
544
- reg_loss = regulariser_loss(latent_codes, rng)
545
  return loss.mean() + reg_loss.mean()
546
 
547
  # Define gradient update step fn
548
  def train_step(state, batch):
549
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
550
 
551
  def compute_loss(params):
552
  labels = batch.pop("labels")
553
  outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
554
- loss = loss_fn(outputs.logits, labels, outputs.latent_codes, state.dropout_rng)
555
  return loss
556
 
557
  grad_fn = jax.value_and_grad(compute_loss)
 
47
  )
48
  from transformers.testing_utils import CaptureLogger
49
 
50
+ from model.t5_vae import FlaxT5VaeForAutoencoding
51
+ from model.config import T5VaeConfig
52
 
53
 
54
  logger = logging.getLogger(__name__)
 
316
  # download model & vocab.
317
 
318
  if model_args.config_path:
319
+ config = T5VaeConfig.from_pretrained(
320
  model_args.config_path, cache_dir=model_args.cache_dir
321
  )
322
  elif model_args.model_name_or_path:
323
+ config = T5VaeConfig.from_pretrained(
324
  model_args.model_name_or_path, cache_dir=model_args.cache_dir
325
  )
326
  else:
327
+ config = T5VaeConfig(**model_args.__dict__)
328
  logger.warning("You are instantiating a new config instance from scratch.")
329
 
330
  if model_args.tokenizer_name:
 
346
  )
347
 
348
  if model_args.model_name_or_path:
349
+ model = FlaxT5VaeForAutoencoding.from_pretrained(
350
  model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
351
  )
352
  # TODO assert token embedding size == len(tokenizer)
 
355
  config.t5.vocab_size = vocab_size
356
  config.vocab_size = vocab_size
357
  logger.info("Training new model from scratch.")
358
+ model = FlaxT5VaeForAutoencoding(
359
  config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
360
  )
361
 
 
402
  # Limits each input sequence to size block_size.
403
  pad_token_id = tokenizer.pad_token_id
404
 
405
+ def clip_texts(examples):
406
  examples["labels"] = examples["input_ids"].copy()
407
 
408
  for i, input_ids in enumerate(examples["input_ids"]):
 
425
  # https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
426
 
427
  lm_datasets = tokenized_datasets.map(
428
+ clip_texts,
429
  batched=True,
430
  num_proc=data_args.preprocessing_num_workers,
431
  load_from_cache_file=not data_args.overwrite_cache,
 
536
  true_samples = jax.random.normal(rng, latent_codes.shape())
537
  return compute_mmd(true_samples, latent_codes)
538
 
539
+ def loss_fn(logits, labels, latent_codes, regulariser_rng: jax.random.PRNGKey):
540
  shift_logits = logits[..., :-1, :]
541
  shift_labels = labels[..., 1:]
542
  loss = optax.softmax_cross_entropy(shift_logits, onehot(shift_labels, shift_logits.shape[-1]))
543
 
544
+ reg_loss = regulariser_loss(latent_codes, regulariser_rng)
545
  return loss.mean() + reg_loss.mean()
546
 
547
  # Define gradient update step fn
548
  def train_step(state, batch):
549
  dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
550
+ new_dropout_rng, regulariser_rng = jax.random.split(new_dropout_rng)
551
 
552
  def compute_loss(params):
553
  labels = batch.pop("labels")
554
  outputs = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)
555
+ loss = loss_fn(outputs.logits, labels, outputs.latent_codes, regulariser_rng)
556
  return loss
557
 
558
  grad_fn = jax.value_and_grad(compute_loss)