|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Progressive distillation for MobileBERT student model."""
|
| import dataclasses
|
| from typing import List, Optional
|
|
|
| from absl import logging
|
| import orbit
|
| import tensorflow as tf, tf_keras
|
| from official.core import base_task
|
| from official.core import config_definitions as cfg
|
| from official.modeling import optimization
|
| from official.modeling import tf_utils
|
| from official.modeling.fast_training.progressive import policies
|
| from official.modeling.hyperparams import base_config
|
| from official.nlp import modeling
|
| from official.nlp.configs import bert
|
| from official.nlp.configs import encoders
|
| from official.nlp.data import data_loader_factory
|
| from official.nlp.modeling import layers
|
| from official.nlp.modeling import models
|
|
|
|
|
@dataclasses.dataclass
class LayerWiseDistillConfig(base_config.Config):
  """Defines the behavior of layerwise distillation.

  One layer-wise stage is run per student transformer layer; each stage
  distills hidden features (MSE), feature statistics (beta/gamma) and,
  optionally, attention scores from the teacher.
  """
  # Number of training steps for each layer-wise stage.
  num_steps: int = 10000
  # Linear warmup steps for the per-stage optimizer.
  warmup_steps: int = 0
  # Polynomial learning-rate schedule: decays from `initial_learning_rate`
  # to `end_learning_rate` over `decay_steps` steps.
  initial_learning_rate: float = 1.5e-3
  end_learning_rate: float = 1.5e-3
  decay_steps: int = 10000
  # Multiplier on the layer-normalized hidden-feature MSE loss.
  hidden_distill_factor: float = 100.0
  # Multiplier on the feature-mean (beta) transfer loss.
  beta_distill_factor: float = 5000.0
  # Multiplier on the feature-variance (gamma) transfer loss.
  gamma_distill_factor: float = 5.0
  # Whether to also transfer attention scores via a KL-divergence loss.
  if_transfer_attention: bool = True
  # Multiplier on the attention transfer loss.
  attention_distill_factor: float = 1.0
  # Whether to freeze the student embedding layer and the already-distilled
  # layers while training later layer-wise stages.
  if_freeze_previous_layers: bool = False
  # Teacher layer index to distill from for each student layer. When None,
  # student layer i is distilled from teacher layer i (which requires the
  # teacher and student to have the same number of layers). When set, its
  # length must equal the number of student layers.
  transfer_teacher_layers: Optional[List[int]] = None
|
|
|
|
|
@dataclasses.dataclass
class PretrainDistillConfig(base_config.Config):
  """Defines the behavior of pretrain distillation (the final stage)."""
  # Number of training steps for the final pretraining-distillation stage.
  num_steps: int = 500000
  # Linear warmup steps for the stage optimizer.
  warmup_steps: int = 10000
  # Polynomial learning-rate schedule: decays from `initial_learning_rate`
  # to `end_learning_rate` over `decay_steps` steps.
  initial_learning_rate: float = 1.5e-3
  end_learning_rate: float = 1.5e-7
  decay_steps: int = 500000
  # NOTE(review): this flag is not read anywhere in this file; the NSP loss
  # is added whenever `next_sentence_labels` is present in the inputs.
  # Verify against callers before relying on it.
  if_use_nsp_loss: bool = True
  # Mixing weight in [0, 1] between the one-hot ground-truth MLM labels and
  # the teacher's softmax predictions: 1.0 means ground truth only.
  distill_ground_truth_ratio: float = 0.5
|
|
|
|
|
@dataclasses.dataclass
class BertDistillationProgressiveConfig(policies.ProgressiveConfig):
  """Defines the specific distillation behavior."""
  # NOTE(review): this flag is not consulted in this file —
  # `BertDistillationTask.initialize` always copies the teacher's word
  # embeddings into the student. Confirm intended semantics with callers.
  if_copy_embeddings: bool = True
  # Configuration for the per-layer distillation stages.
  layer_wise_distill_config: LayerWiseDistillConfig = dataclasses.field(
      default_factory=LayerWiseDistillConfig
  )
  # Configuration for the final end-to-end pretraining-distillation stage.
  pretrain_distill_config: PretrainDistillConfig = dataclasses.field(
      default_factory=PretrainDistillConfig
  )
|
|
|
|
|
@dataclasses.dataclass
class BertDistillationTaskConfig(cfg.TaskConfig):
  """Defines the teacher/student model architecture and training data."""
  # Pretrainer config for the (frozen) teacher; defaults to a MobileBERT
  # encoder.
  teacher_model: bert.PretrainerConfig = dataclasses.field(
      default_factory=lambda: bert.PretrainerConfig(
          encoder=encoders.EncoderConfig(type='mobilebert')
      )
  )

  # Pretrainer config for the student being distilled; defaults to a
  # MobileBERT encoder.
  student_model: bert.PretrainerConfig = dataclasses.field(
      default_factory=lambda: bert.PretrainerConfig(
          encoder=encoders.EncoderConfig(type='mobilebert')
      )
  )

  # Checkpoint path (file or directory) holding the pretrained teacher;
  # required by `BertDistillationTask.initialize`.
  teacher_model_init_checkpoint: str = ''
  train_data: cfg.DataConfig = dataclasses.field(default_factory=cfg.DataConfig)
  validation_data: cfg.DataConfig = dataclasses.field(
      default_factory=cfg.DataConfig
  )
|
|
|
|
|
def build_sub_encoder(encoder, target_layer_id):
  """Builds a truncated encoder running transformer layers 0..target_layer_id.

  Args:
    encoder: a MobileBERT-style encoder exposing `inputs` (word ids, mask,
      type ids), an `embedding_layer`, and `transformer_layers`.
    target_layer_id: index of the last transformer layer to include.

  Returns:
    A `tf_keras.Model` mapping the encoder inputs to `[hidden_states,
    attention_scores]` of layer `target_layer_id`.
  """
  word_ids, mask, segment_ids = (
      encoder.inputs[0], encoder.inputs[1], encoder.inputs[2])
  self_attention_mask = modeling.layers.SelfAttentionMask()(
      inputs=word_ids, to_mask=mask)

  hidden = encoder.embedding_layer(word_ids, segment_ids)
  scores = None
  # Run only the first `target_layer_id + 1` transformer blocks.
  for block in encoder.transformer_layers[:target_layer_id + 1]:
    hidden, scores = block(
        hidden, self_attention_mask, return_attention_scores=True)

  return tf_keras.Model(
      inputs=[word_ids, mask, segment_ids],
      outputs=[hidden, scores])
|
|
|
|
|
class BertDistillationTask(policies.ProgressivePolicy, base_task.Task):
  """Distillation language modeling task progressively.

  Runs one layer-wise distillation stage per student transformer layer,
  followed by one final pretraining-distillation stage (MLM soft/hard labels
  plus optional next-sentence loss). Stage scheduling (`_stage_id`,
  `_volatiles`) is provided by `policies.ProgressivePolicy`.
  """

  def __init__(self,
               strategy,
               progressive: BertDistillationProgressiveConfig,
               optimizer_config: optimization.OptimizationConfig,
               task_config: BertDistillationTaskConfig,
               logging_dir=None):
    """Initializes the task and builds the teacher/student pretrainers.

    Args:
      strategy: a `tf.distribute.Strategy` used to distribute the datasets.
      progressive: distillation behavior for layer-wise and pretrain stages.
      optimizer_config: base optimizer config; learning rate and warmup are
        overridden per stage in `get_optimizer`.
      task_config: teacher/student architectures and data configs.
      logging_dir: optional logging directory forwarded to `base_task.Task`.

    Raises:
      ValueError: if `transfer_teacher_layers` is inconsistent with the
        student depth, or teacher/student depths mismatch when it is unset,
        or `distill_ground_truth_ratio` is outside [0, 1].
    """
    self._strategy = strategy
    self._task_config = task_config
    self._progressive_config = progressive
    self._optimizer_config = optimizer_config
    self._train_data_config = task_config.train_data
    self._eval_data_config = task_config.validation_data
    # Datasets are built lazily once and shared across all stages.
    self._the_only_train_dataset = None
    self._the_only_eval_dataset = None

    # Validate the layer mapping between teacher and student.
    layer_wise_config = self._progressive_config.layer_wise_distill_config
    transfer_teacher_layers = layer_wise_config.transfer_teacher_layers
    num_teacher_layers = (
        self._task_config.teacher_model.encoder.mobilebert.num_blocks)
    num_student_layers = (
        self._task_config.student_model.encoder.mobilebert.num_blocks)
    if transfer_teacher_layers and len(
        transfer_teacher_layers) != num_student_layers:
      raise ValueError('The number of `transfer_teacher_layers` %s does not '
                       'match the number of student layers. %d' %
                       (transfer_teacher_layers, num_student_layers))
    if not transfer_teacher_layers and (num_teacher_layers !=
                                        num_student_layers):
      raise ValueError('`transfer_teacher_layers` is not specified, and the '
                       'number of teacher layers does not match '
                       'the number of student layers.')

    ratio = progressive.pretrain_distill_config.distill_ground_truth_ratio
    if ratio < 0 or ratio > 1:
      raise ValueError('distill_ground_truth_ratio has to be within [0, 1].')

    # Non-trainable layer norm, used in `build_losses` only to normalize
    # teacher/student features before the hidden-feature MSE.
    self._layer_norm = tf_keras.layers.LayerNormalization(
        axis=-1,
        beta_initializer='zeros',
        gamma_initializer='ones',
        trainable=False)

    self._teacher_pretrainer = self._build_pretrainer(
        self._task_config.teacher_model, name='teacher')
    self._student_pretrainer = self._build_pretrainer(
        self._task_config.student_model, name='student')

    base_task.Task.__init__(
        self, params=task_config, logging_dir=logging_dir)
    policies.ProgressivePolicy.__init__(self)

  def _build_pretrainer(self, pretrainer_cfg: bert.PretrainerConfig, name: str):
    """Builds pretrainer from config and encoder."""
    encoder = encoders.build_encoder(pretrainer_cfg.encoder)
    if pretrainer_cfg.cls_heads:
      # NOTE: `cfg` here intentionally shadows the module-level
      # `config_definitions as cfg` alias, but only inside the comprehension.
      cls_heads = [
          layers.ClassificationHead(**cfg.as_dict())
          for cfg in pretrainer_cfg.cls_heads
      ]
    else:
      cls_heads = []

    # MobileBERT-specific masked LM head tied to the encoder's embedding
    # table.
    masked_lm = layers.MobileBertMaskedLM(
        embedding_table=encoder.get_embedding_table(),
        activation=tf_utils.get_activation(pretrainer_cfg.mlm_activation),
        initializer=tf_keras.initializers.TruncatedNormal(
            stddev=pretrainer_cfg.mlm_initializer_range),
        name='cls/predictions')

    pretrainer = models.BertPretrainerV2(
        encoder_network=encoder,
        classification_heads=cls_heads,
        customized_masked_lm=masked_lm,
        name=name)
    return pretrainer

  def num_stages(self) -> int:
    """One layer-wise stage per student block, plus the pretrain stage."""
    return self._task_config.student_model.encoder.mobilebert.num_blocks + 1

  def num_steps(self, stage_id) -> int:
    """Return the total number of steps in this stage."""
    # The last stage is pretrain distillation; all earlier stages are
    # layer-wise.
    if stage_id + 1 < self.num_stages():
      return self._progressive_config.layer_wise_distill_config.num_steps
    else:
      return self._progressive_config.pretrain_distill_config.num_steps

  def get_model(self, stage_id, old_model=None) -> tf_keras.Model:
    """Returns a freshly built model for `stage_id` (old model unused)."""
    del old_model
    return self.build_model(stage_id)

  def get_optimizer(self, stage_id):
    """Build optimizer for each stage."""
    if stage_id + 1 < self.num_stages():
      distill_config = self._progressive_config.layer_wise_distill_config
    else:
      distill_config = self._progressive_config.pretrain_distill_config

    # Override the base optimizer config with this stage's polynomial
    # learning-rate schedule and linear warmup.
    params = self._optimizer_config.replace(
        learning_rate={
            'polynomial': {
                'decay_steps':
                    distill_config.decay_steps,
                'initial_learning_rate':
                    distill_config.initial_learning_rate,
                'end_learning_rate':
                    distill_config.end_learning_rate,
            }
        },
        warmup={
            'linear':
                {'warmup_steps':
                     distill_config.warmup_steps,
                }
        })
    opt_factory = optimization.OptimizerFactory(params)
    optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
    # Downgrade experimental Keras optimizers to the legacy implementation.
    if isinstance(optimizer, tf_keras.optimizers.experimental.Optimizer):
      optimizer = tf_keras.__internal__.optimizers.convert_to_legacy_optimizer(
          optimizer)

    return optimizer

  def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
    """Return Dataset for this stage."""
    # The same distributed dataset is reused by every stage.
    del stage_id
    if self._the_only_train_dataset is None:
      self._the_only_train_dataset = orbit.utils.make_distributed_dataset(
          self._strategy, self.build_inputs, self._train_data_config)
    return self._the_only_train_dataset

  def get_eval_dataset(self, stage_id) -> tf.data.Dataset:
    """Return the (shared) evaluation dataset for this stage."""
    del stage_id
    if self._the_only_eval_dataset is None:
      self._the_only_eval_dataset = orbit.utils.make_distributed_dataset(
          self._strategy, self.build_inputs, self._eval_data_config)
    return self._the_only_eval_dataset

  def build_model(self, stage_id) -> tf_keras.Model:
    """Build teacher/student keras models with outputs for current stage."""
    # The teacher is never trained.
    self._teacher_pretrainer.trainable = False
    layer_wise_config = self._progressive_config.layer_wise_distill_config
    freeze_previous_layers = layer_wise_config.if_freeze_previous_layers
    student_encoder = self._student_pretrainer.encoder_network

    if stage_id != self.num_stages() - 1:
      # Layer-wise stage: output features/attention of layer `stage_id` from
      # both teacher and student sub-encoders.
      inputs = student_encoder.inputs
      student_sub_encoder = build_sub_encoder(
          encoder=student_encoder, target_layer_id=stage_id)
      student_output_feature, student_attention_score = student_sub_encoder(
          inputs)

      if layer_wise_config.transfer_teacher_layers:
        teacher_layer_id = layer_wise_config.transfer_teacher_layers[stage_id]
      else:
        teacher_layer_id = stage_id

      teacher_sub_encoder = build_sub_encoder(
          encoder=self._teacher_pretrainer.encoder_network,
          target_layer_id=teacher_layer_id)

      teacher_output_feature, teacher_attention_score = teacher_sub_encoder(
          inputs)

      if freeze_previous_layers:
        # Only the current layer remains trainable.
        student_encoder.embedding_layer.trainable = False
        for i in range(stage_id):
          student_encoder.transformer_layers[i].trainable = False

      return tf_keras.Model(
          inputs=inputs,
          outputs=dict(
              student_output_feature=student_output_feature,
              student_attention_score=student_attention_score,
              teacher_output_feature=teacher_output_feature,
              teacher_attention_score=teacher_attention_score))
    else:
      # Final stage: run the full pretrainers end to end.
      inputs = self._student_pretrainer.inputs
      student_pretrainer_output = self._student_pretrainer(inputs)
      teacher_pretrainer_output = self._teacher_pretrainer(inputs)

      if freeze_previous_layers:
        # Unfreeze everything that earlier stages may have frozen.
        student_encoder.embedding_layer.trainable = True
        for layer in student_encoder.transformer_layers:
          layer.trainable = True

      model = tf_keras.Model(
          inputs=inputs,
          outputs=dict(
              student_pretrainer_output=student_pretrainer_output,
              teacher_pretrainer_output=teacher_pretrainer_output,
          ))

      # Expose the student's checkpoint items for exporting.
      model.checkpoint_items = self._student_pretrainer.checkpoint_items
      return model

  def build_inputs(self, params, input_context=None):
    """Returns tf.data.Dataset for pretraining."""
    # 'dummy' input path produces an infinite stream of all-zero examples,
    # useful for testing/benchmarking without real data.
    if params.input_path == 'dummy':

      def dummy_data(_):
        dummy_ids = tf.zeros((1, params.seq_length), dtype=tf.int32)
        dummy_lm = tf.zeros((1, params.max_predictions_per_seq), dtype=tf.int32)
        return dict(
            input_word_ids=dummy_ids,
            input_mask=dummy_ids,
            input_type_ids=dummy_ids,
            masked_lm_positions=dummy_lm,
            masked_lm_ids=dummy_lm,
            masked_lm_weights=tf.cast(dummy_lm, dtype=tf.float32),
            next_sentence_labels=tf.zeros((1, 1), dtype=tf.int32))

      dataset = tf.data.Dataset.range(1)
      dataset = dataset.repeat()
      dataset = dataset.map(
          dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
      return dataset

    return data_loader_factory.get_data_loader(params).load(input_context)

  def _get_distribution_losses(self, teacher, student):
    """Return the beta and gamma distill losses for feature distribution."""
    # Per-position mean/variance over the feature axis.
    teacher_mean = tf.math.reduce_mean(teacher, axis=-1, keepdims=True)
    student_mean = tf.math.reduce_mean(student, axis=-1, keepdims=True)
    teacher_var = tf.math.reduce_variance(teacher, axis=-1, keepdims=True)
    student_var = tf.math.reduce_variance(student, axis=-1, keepdims=True)

    # beta: squared difference of means; gamma: absolute difference of
    # variances. Both reduced to scalars.
    beta_loss = tf.math.squared_difference(student_mean, teacher_mean)
    beta_loss = tf.math.reduce_mean(beta_loss, axis=None, keepdims=False)
    gamma_loss = tf.math.abs(student_var - teacher_var)
    gamma_loss = tf.math.reduce_mean(gamma_loss, axis=None, keepdims=False)

    return beta_loss, gamma_loss

  def _get_attention_loss(self, teacher_score, student_score):
    """Returns the mean KL(teacher || student) over attention distributions."""
    # Softmax over the last axis turns raw scores into attention weights;
    # KL is computed against the student's log-softmax for stability.
    teacher_weight = tf_keras.activations.softmax(teacher_score, axis=-1)
    student_log_weight = tf.nn.log_softmax(student_score, axis=-1)
    kl_divergence = -(teacher_weight * student_log_weight)
    kl_divergence = tf.math.reduce_sum(kl_divergence, axis=-1, keepdims=True)
    kl_divergence = tf.math.reduce_mean(kl_divergence, axis=None,
                                        keepdims=False)
    return kl_divergence

  def build_losses(self, labels, outputs, metrics) -> tf.Tensor:
    """Builds losses and update loss-related metrics for the current stage."""
    # The model built for the final stage is the only one that outputs
    # `student_pretrainer_output` (see `build_model`).
    last_stage = 'student_pretrainer_output' in outputs

    if not last_stage:
      # Layer-wise stage: hidden-feature MSE + beta/gamma distribution
      # losses (+ optional attention KL), each with its own scale factor.
      distill_config = self._progressive_config.layer_wise_distill_config
      teacher_feature = outputs['teacher_output_feature']
      student_feature = outputs['student_output_feature']

      feature_transfer_loss = tf_keras.losses.mean_squared_error(
          self._layer_norm(teacher_feature), self._layer_norm(student_feature))
      feature_transfer_loss *= distill_config.hidden_distill_factor
      beta_loss, gamma_loss = self._get_distribution_losses(teacher_feature,
                                                            student_feature)
      beta_loss *= distill_config.beta_distill_factor
      gamma_loss *= distill_config.gamma_distill_factor
      total_loss = feature_transfer_loss + beta_loss + gamma_loss

      if distill_config.if_transfer_attention:
        teacher_attention = outputs['teacher_attention_score']
        student_attention = outputs['student_attention_score']
        attention_loss = self._get_attention_loss(teacher_attention,
                                                  student_attention)
        attention_loss *= distill_config.attention_distill_factor
        total_loss += attention_loss

      # Scale down by the number of layers distilled so far (`_stage_id` is
      # maintained by policies.ProgressivePolicy).
      total_loss /= tf.cast((self._stage_id + 1), tf.float32)

    else:
      # Final stage: soft-label MLM distillation, optionally mixed with
      # ground truth, plus next-sentence loss when labels are present.
      distill_config = self._progressive_config.pretrain_distill_config
      lm_label = labels['masked_lm_ids']
      vocab_size = (
          self._task_config.student_model.encoder.mobilebert.word_vocab_size)

      # One-hot ground-truth labels over the vocabulary.
      lm_label = tf.one_hot(indices=lm_label, depth=vocab_size, on_value=1.0,
                            off_value=0.0, axis=-1, dtype=tf.float32)
      gt_ratio = distill_config.distill_ground_truth_ratio
      if gt_ratio != 1.0:
        # Blend ground truth with the teacher's soft predictions.
        teacher_mlm_logits = outputs['teacher_pretrainer_output']['mlm_logits']
        teacher_labels = tf.nn.softmax(teacher_mlm_logits, axis=-1)
        lm_label = gt_ratio * lm_label + (1-gt_ratio) * teacher_labels

      student_pretrainer_output = outputs['student_pretrainer_output']

      student_lm_log_probs = tf.nn.log_softmax(
          student_pretrainer_output['mlm_logits'], axis=-1)

      # Cross entropy between the (possibly blended) labels and the student
      # predictions, flattened over all masked positions.
      per_example_loss = tf.reshape(
          -tf.reduce_sum(student_lm_log_probs * lm_label, axis=[-1]), [-1])

      # Weighted mean over valid masked positions; divide_no_nan guards
      # against an all-zero weight batch.
      lm_label_weights = tf.reshape(labels['masked_lm_weights'], [-1])
      lm_numerator_loss = tf.reduce_sum(per_example_loss * lm_label_weights)
      lm_denominator_loss = tf.reduce_sum(lm_label_weights)
      mlm_loss = tf.math.divide_no_nan(lm_numerator_loss, lm_denominator_loss)
      total_loss = mlm_loss

      if 'next_sentence_labels' in labels:
        sentence_labels = labels['next_sentence_labels']
        sentence_outputs = tf.cast(
            student_pretrainer_output['next_sentence'], dtype=tf.float32)
        sentence_loss = tf.reduce_mean(
            tf_keras.losses.sparse_categorical_crossentropy(
                sentence_labels, sentence_outputs, from_logits=True))
        total_loss += sentence_loss

    # Re-index the metrics list by metric name for the updates below.
    metrics = dict([(metric.name, metric) for metric in metrics])

    if not last_stage:
      metrics['feature_transfer_mse'].update_state(feature_transfer_loss)
      metrics['beta_transfer_loss'].update_state(beta_loss)
      metrics['gamma_transfer_loss'].update_state(gamma_loss)
      layer_wise_config = self._progressive_config.layer_wise_distill_config
      if layer_wise_config.if_transfer_attention:
        metrics['attention_transfer_loss'].update_state(attention_loss)
    else:
      metrics['lm_example_loss'].update_state(mlm_loss)
      if 'next_sentence_labels' in labels:
        metrics['next_sentence_loss'].update_state(sentence_loss)
    metrics['total_loss'].update_state(total_loss)

    return total_loss

  def build_metrics(self, training=None):
    """Builds the metric set shared by all stages (training == evaluation)."""
    del training
    metrics = [
        tf_keras.metrics.Mean(name='feature_transfer_mse'),
        tf_keras.metrics.Mean(name='beta_transfer_loss'),
        tf_keras.metrics.Mean(name='gamma_transfer_loss'),
        tf_keras.metrics.SparseCategoricalAccuracy(name='masked_lm_accuracy'),
        tf_keras.metrics.Mean(name='lm_example_loss'),
        tf_keras.metrics.Mean(name='total_loss')]
    if self._progressive_config.layer_wise_distill_config.if_transfer_attention:
      metrics.append(tf_keras.metrics.Mean(name='attention_transfer_loss'))
    if self._task_config.train_data.use_next_sentence_label:
      metrics.append(tf_keras.metrics.SparseCategoricalAccuracy(
          name='next_sentence_accuracy'))
      metrics.append(tf_keras.metrics.Mean(name='next_sentence_loss'))

    return metrics

  def process_metrics(self, metrics, labels, student_pretrainer_output):
    """Updates accuracy metrics; no-op unless in the final pretrain stage.

    Args:
      metrics: the list of metrics from `build_metrics`.
      labels: the input dictionary (contains `masked_lm_*` and optionally
        `next_sentence_labels`).
      student_pretrainer_output: the student pretrainer's output dict, or
        None during layer-wise stages.
    """
    metrics = dict([(metric.name, metric) for metric in metrics])

    if student_pretrainer_output is not None:
      if 'masked_lm_accuracy' in metrics:
        metrics['masked_lm_accuracy'].update_state(
            labels['masked_lm_ids'], student_pretrainer_output['mlm_logits'],
            labels['masked_lm_weights'])
      if 'next_sentence_accuracy' in metrics:
        metrics['next_sentence_accuracy'].update_state(
            labels['next_sentence_labels'],
            student_pretrainer_output['next_sentence'])

  def train_step(self, inputs, model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer, metrics):
    """Does forward and backward.

    Args:
      inputs: a dictionary of input tensors.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    with tf.GradientTape() as tape:
      outputs = model(inputs, training=True)

      # Loss must be built inside the tape so its ops are recorded.
      loss = self.build_losses(
          labels=inputs,
          outputs=outputs,
          metrics=metrics)

    tvars = model.trainable_variables
    last_stage = 'student_pretrainer_output' in outputs

    grads = tape.gradient(loss, tvars)
    optimizer.apply_gradients(list(zip(grads, tvars)))
    # Accuracy metrics only apply in the final (pretrain) stage.
    self.process_metrics(
        metrics, inputs,
        outputs['student_pretrainer_output'] if last_stage else None)
    return {self.loss: loss}

  def validation_step(self, inputs, model: tf_keras.Model, metrics):
    """Validatation step.

    Args:
      inputs: a dictionary of input tensors.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs.
    """
    outputs = model(inputs, training=False)

    loss = self.build_losses(labels=inputs, outputs=outputs, metrics=metrics)
    last_stage = 'student_pretrainer_output' in outputs
    self.process_metrics(
        metrics, inputs,
        outputs['student_pretrainer_output'] if last_stage else None)
    return {self.loss: loss}

  @property
  def cur_checkpoint_items(self):
    """Checkpoints for model, stage_id, optimizer for preemption handling."""
    # `_stage_id` and `_volatiles` come from policies.ProgressivePolicy.
    return dict(
        stage_id=self._stage_id,
        volatiles=self._volatiles,
        student_pretrainer=self._student_pretrainer,
        teacher_pretrainer=self._teacher_pretrainer,
        encoder=self._student_pretrainer.encoder_network)

  def initialize(self, model):
    """Loads teacher's pretrained checkpoint and copy student's embedding."""
    # The per-stage model is rebuilt elsewhere; initialization targets the
    # pretrainers held by this task directly.
    del model
    logging.info('Begin to load checkpoint for teacher pretrainer model.')
    ckpt_dir_or_file = self._task_config.teacher_model_init_checkpoint
    if not ckpt_dir_or_file:
      raise ValueError('`teacher_model_init_checkpoint` is not specified.')

    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)

    # Run a forward pass so all teacher variables are created before the
    # checkpoint is restored.
    _ = self._teacher_pretrainer(self._teacher_pretrainer.inputs)
    teacher_checkpoint = tf.train.Checkpoint(
        **self._teacher_pretrainer.checkpoint_items)
    teacher_checkpoint.read(ckpt_dir_or_file).assert_existing_objects_matched()

    # Warm-start the student by copying the teacher's word embeddings.
    logging.info('Begin to copy word embedding from teacher model to student.')
    teacher_encoder = self._teacher_pretrainer.encoder_network
    student_encoder = self._student_pretrainer.encoder_network
    embedding_weights = teacher_encoder.embedding_layer.get_weights()
    student_encoder.embedding_layer.set_weights(embedding_weights)
|
|
|