| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
"""
Initialise a student Whisper model from a pre-trained teacher model for
teacher-student distillation.
"""
| |
|
| | import argparse |
| | import copy |
| | import logging |
| |
|
| | import numpy as np |
| | import torch |
| | from transformers import GenerationConfig, WhisperForConditionalGeneration, WhisperProcessor |
| |
|
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
def parse_args():
    """Parse the command-line options used to initialise a student model.

    Returns:
        argparse.Namespace: namespace with the attributes ``teacher_checkpoint``,
        ``subfolder``, ``encoder_layers``, ``decoder_layers``,
        ``decoder_layers_numbers``, ``save_dir``, ``push_to_hub`` and ``cache_dir``.
    """

    def str_to_bool(value):
        """Convert a textual flag value ("true"/"false", "1"/"0", ...) to a bool."""
        if isinstance(value, bool):
            return value
        normalised = value.strip().lower()
        if normalised in ("true", "t", "yes", "y", "1"):
            return True
        # The empty string is kept falsy, as it was under the previous `type=bool`.
        if normalised in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")

    parser = argparse.ArgumentParser(
        description="Initialise a student Whisper model from a teacher model, copying the relevant layer weights and adjusting the processor as necessary."
    )
    parser.add_argument(
        "--teacher_checkpoint",
        type=str,
        required=True,
        help="The HF Hub ID of the teacher checkpoint.",
    )
    parser.add_argument(
        "--subfolder",
        type=str,
        default="",
        help="In case the relevant teacher weights are located inside a subfolder of the model repo on huggingface.co, you "
        "can specify the folder name here.",
    )
    parser.add_argument(
        "--encoder_layers",
        type=int,
        default=None,
        help="Number of encoder layers to use in the student model. Defaults to all layers from the teacher.",
    )
    parser.add_argument(
        "--decoder_layers",
        type=int,
        default=2,
        help="Number of decoder layers to use in the student model. Defaults to 2 layers.",
    )
    parser.add_argument(
        "--decoder_layers_numbers",
        type=int,
        nargs="*",
        help="Layers numbers of the decoder teacher to use in the student model. Defaults to None, equivalent to taking first and last layer (and equivalent to `--decoder_layers_numbers 0 -1`).",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        required=True,
        help="Where to save the student weights and processor.",
    )
    # `type=bool` would turn every non-empty string (including "False") into True,
    # so the value is parsed explicitly. `nargs="?"` with `const=True` additionally
    # lets the flag be passed on its own (`--push_to_hub`).
    parser.add_argument(
        "--push_to_hub",
        type=str_to_bool,
        nargs="?",
        const=True,
        required=False,
        default=False,
        help="Whether to push the student weights and processor to the Hub.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="Where to store the pretrained models downloaded from huggingface.co",
    )

    return parser.parse_args()
| |
|
| |
|
def _select_teacher_layers(num_teacher_layers, num_student_layers, layer_numbers=None):
    """Return, for each student layer in order, the teacher layer index to copy from."""
    if layer_numbers is None:
        if num_student_layers < 1:
            raise ValueError(f"The student needs at least one layer, got {num_student_layers}.")
        # Spread the student layers evenly over the teacher and always keep the last teacher layer.
        spaced = np.linspace(0, num_teacher_layers - 1, num_student_layers, dtype=int)
        spaced[-1] = num_teacher_layers - 1
        return spaced.tolist()

    resolved = []
    for number in layer_numbers:
        # Negative numbers count from the end of the teacher stack, like Python indexing.
        index = number + num_teacher_layers if number < 0 else number
        if not 0 <= index < num_teacher_layers:
            raise ValueError(f"Layer number {number} is out of range for a teacher with {num_teacher_layers} layers.")
        resolved.append(int(index))
    return resolved


def init_student_model_from_teacher(
    teacher_checkpoint,
    encoder_layers=None,
    decoder_layers=2,
    decoder_layers_numbers=None,
    save_dir=None,
    push_to_hub=None,
    cache_dir=None,
    subfolder="",
):
    """Build a student Whisper model whose layers are copied from a teacher checkpoint.

    The teacher is loaded, its config is copied with the requested encoder/decoder
    layer counts, and the selected teacher layers are copied into the student. The
    student is optionally saved (with the processor and generation config), run once
    forward on a dummy input as a sanity check, and optionally pushed to the Hub.

    Args:
        teacher_checkpoint: identifier passed to ``from_pretrained`` for the teacher
            model, processor and generation config.
        encoder_layers: number of student encoder layers; ``None`` keeps the teacher's count.
        decoder_layers: number of student decoder layers.
        decoder_layers_numbers: optional teacher decoder layer indices to copy, one per
            student decoder layer. Negative values count from the last teacher layer.
            ``None`` selects evenly spaced layers ending with the last teacher layer.
        save_dir: directory to save the student, processor and generation config to.
            When ``None`` nothing is written and the in-memory student is checked instead.
        push_to_hub: when truthy, push model, processor and generation config using
            ``save_dir`` as the repository argument.
        cache_dir: cache directory forwarded to the ``from_pretrained`` calls for the teacher.
        subfolder: subfolder forwarded to the teacher model's ``from_pretrained`` call.

    Returns:
        The student model (re-loaded from ``save_dir`` when one is given).

    Raises:
        ValueError: if ``decoder_layers_numbers`` does not have ``decoder_layers`` entries,
            contains an out-of-range index, a linspace-selected layer count is below one,
            or ``push_to_hub`` is requested without ``save_dir``.
        RuntimeError: if teacher weights are missing for the student, or unexpected layer
            keys appear although the layer counts match.
    """
    if decoder_layers_numbers is not None and len(decoder_layers_numbers) != decoder_layers:
        raise ValueError(
            f"Got {len(decoder_layers_numbers)} layers number for {decoder_layers} decoder layers."
        )
    # Fail before any download: `save_dir` doubles as the repository argument of `push_to_hub`.
    if push_to_hub and save_dir is None:
        raise ValueError("`save_dir` must be provided when `push_to_hub` is enabled.")

    teacher_model = WhisperForConditionalGeneration.from_pretrained(
        teacher_checkpoint,
        cache_dir=cache_dir,
        subfolder=subfolder,
        low_cpu_mem_usage=True,
    )
    processor = WhisperProcessor.from_pretrained(teacher_checkpoint, cache_dir=cache_dir)
    generation_config = GenerationConfig.from_pretrained(teacher_checkpoint, cache_dir=cache_dir)
    generation_config.forced_decoder_ids = None

    teacher_config = teacher_model.config
    teacher_encoder_layers = teacher_config.encoder_layers
    teacher_decoder_layers = teacher_config.decoder_layers

    student_config = copy.deepcopy(teacher_config)
    student_config.update(
        {
            "encoder_layers": encoder_layers if encoder_layers is not None else teacher_encoder_layers,
            "decoder_layers": decoder_layers,
        }
    )

    # Position i of each mapping holds the teacher layer copied into student layer i.
    # Keeping the student -> teacher direction preserves repeated teacher indices.
    encoder_mapping = _select_teacher_layers(teacher_encoder_layers, student_config.encoder_layers)
    decoder_mapping = _select_teacher_layers(
        teacher_decoder_layers, student_config.decoder_layers, decoder_layers_numbers
    )

    # Load every weight whose name exists in the student; surplus teacher layers show up as unexpected keys.
    student_model = WhisperForConditionalGeneration(student_config)
    missing_keys, unexpected_keys = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
    if len(missing_keys) > 0:
        raise RuntimeError(
            "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
            f"Missing key(s) in state_dict: {missing_keys}"
        )
    if decoder_layers == teacher_decoder_layers:
        decoder_keys = [key for key in unexpected_keys if "model.decoder.layers" in key]
        if len(decoder_keys) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
                f"Unexpected key(s) in state_dict: {decoder_keys}"
            )
    # Compare the resolved student count so the check also runs when `encoder_layers` is None.
    if student_config.encoder_layers == teacher_encoder_layers:
        encoder_keys = [key for key in unexpected_keys if "model.encoder.layers" in key]
        if len(encoder_keys) > 0:
            raise RuntimeError(
                "Error(s) in loading state_dict for WhisperForConditionalGeneration. \n"
                f"Unexpected key(s) in state_dict: {encoder_keys}"
            )

    # Overwrite each student decoder layer with its selected teacher layer.
    for student_layer, teacher_layer in enumerate(decoder_mapping):
        student_model.model.decoder.layers[student_layer].load_state_dict(
            teacher_model.model.decoder.layers[teacher_layer].state_dict()
        )

    # With `encoder_layers=None` the student keeps all teacher encoder layers, already loaded above.
    if encoder_layers is not None:
        for student_layer, teacher_layer in enumerate(encoder_mapping):
            student_model.model.encoder.layers[student_layer].load_state_dict(
                teacher_model.model.encoder.layers[teacher_layer].state_dict()
            )

    # Drop the reference to the teacher; it is not needed past this point.
    del teacher_model

    if save_dir is not None:
        student_model.save_pretrained(save_dir)
        # The processor and generation config are saved next to the weights.
        processor.save_pretrained(save_dir)
        generation_config.save_pretrained(save_dir)

        # Reload from disk so the forward check below exercises the saved artefacts.
        logger.info("Checking we can load the saved model...")
        student_model = WhisperForConditionalGeneration.from_pretrained(
            save_dir,
            low_cpu_mem_usage=True,
        )
        processor = WhisperProcessor.from_pretrained(save_dir)

    # Dummy inputs: 16000 samples of ones at 16 kHz and a single decoder start token per batch item.
    input_features = processor(np.ones(16000), sampling_rate=16000, return_tensors="pt").input_features
    decoder_start_token_id = student_model.config.decoder_start_token_id
    decoder_input_ids = torch.ones((input_features.shape[0], 1), dtype=torch.long) * decoder_start_token_id

    logger.info("Checking we can run the converted model forward...")
    # Only a smoke test, so gradient tracking is not needed.
    with torch.no_grad():
        _ = student_model(input_features, decoder_input_ids=decoder_input_ids).logits
    logger.info("Conversion successful!")

    if push_to_hub:
        student_model.push_to_hub(save_dir)
        processor.push_to_hub(save_dir)
        generation_config.push_to_hub(save_dir)

    return student_model
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: parse the CLI options and forward each one by keyword.
    args = parse_args()

    forwarded_options = (
        "teacher_checkpoint",
        "encoder_layers",
        "decoder_layers",
        "decoder_layers_numbers",
        "save_dir",
        "push_to_hub",
        "cache_dir",
        "subfolder",
    )
    init_student_model_from_teacher(**{option: getattr(args, option) for option in forwarded_options})
| |
|