| """ |
| Convert a Flax training state to HF Transformers Whisper weights. |
| """ |
|
|
import logging
import os
import sys
from dataclasses import field
from pathlib import Path
from typing import Callable, Optional
|
|
import flax
import jax
import jax.numpy as jnp
import optax
from flax import jax_utils, traverse_util
from flax.serialization import from_bytes
from flax.training import train_state
from flax.training.common_utils import shard_prng_key
from huggingface_hub import Repository, create_repo
from optax._src import linear_algebra
from transformers import (
    AutoConfig,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
)
from transformers.file_utils import get_full_repo_name
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
|
|
from distil_whisper import FlaxWhisperForConditionalGeneration
|
|
|
|
# Initialise the JAX distributed system so the conversion also runs on multi-host setups.
jax.distributed.initialize()
|
|
# Will error if the minimal version of Transformers is not installed.
check_min_version("4.27.0.dev0")
|
|
require_version(
    "datasets>=1.18.0",
    "To fix: pip install -r examples/flax/speech-recognition/requirements.txt",
)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
@flax.struct.dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": ("Path to pretrained student model or model identifier from huggingface.co/models")}
    )
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": ("Where to store the pretrained models downloaded from huggingface.co")},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": ("Whether to use one of the fast tokenizers (backed by the tokenizers library) or not.")},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": ("The specific model version to use (can be a branch name, tag name or commit id).")},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `transformers-cli login`"
                " (necessary to use this script with private models)."
            )
        },
    )
    dtype: Optional[str] = field(
        default="float32",
        metadata={
            "help": (
                "Floating-point format in which the model weights should be initialized"
                " and trained. Choose one of `[float32, float16, bfloat16]`."
            )
        },
    )
    load_with_scan_weights: bool = field(
        default=False,
        metadata={
            "help": "Whether the pre-trained checkpoint has its weights stored in scan format. Set to True for scanned "
            "weights, defaults to False for non-scan (unrolled) weights."
        },
    )
    use_scan: bool = field(
        default=True,
        metadata={"help": ("Whether or not to use `scan_with_axes` over the encoder and decoder blocks.")},
    )
|
|
|
|
def create_learning_rate_fn(
    num_train_steps: int, lr_scheduler_type: str, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
    """Returns a linear warmup, linear_decay learning rate function."""
    lr_scheduler_types = ("linear", "constant_with_warmup")

    if lr_scheduler_type not in lr_scheduler_types:
        raise ValueError(
            f"lr_scheduler_type of type {lr_scheduler_type} not supported, choose from {lr_scheduler_types}."
        )

    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
    decay_fn = optax.linear_schedule(
        init_value=learning_rate,
        end_value=0 if lr_scheduler_type == "linear" else learning_rate,
        transition_steps=num_train_steps - num_warmup_steps,
    )
    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
    return schedule_fn
|
|
|
|
class TrainState(train_state.TrainState):
    dropout_rng: jnp.ndarray
    max_grad_norm: float

    def apply_gradients(self, *, grads, **kwargs):
        """Updates `step`, `params`, `opt_state` and `**kwargs` in return value, clipping the
        gradients by the maximum grad norm.

        Note that internally this function calls `.tx.update()` followed by a call
        to `optax.apply_updates()` to update `params` and `opt_state`.

        Args:
            grads: Gradients that have the same pytree structure as `.params`.
            **kwargs: Additional dataclass attributes that should be `.replace()`-ed.

        Returns:
            An updated instance of `self` with `step` incremented by one, `params`
            and `opt_state` updated by applying `grads`, and additional attributes
            replaced as specified by `kwargs`.
        """
        # Clip by global norm: gradients are scaled by max_grad_norm / max(max_grad_norm, g_norm),
        # which leaves them unchanged when their global norm is already below the threshold.
        g_norm = linear_algebra.global_norm(grads)
        g_norm = jnp.maximum(self.max_grad_norm, g_norm)
        grads = jax.tree_map(lambda t: (t / g_norm) * self.max_grad_norm, grads)

        updates, new_opt_state = self.tx.update(grads, self.opt_state, self.params)
        new_params = optax.apply_updates(self.params, updates)

        return self.replace(
            step=self.step + 1,
            params=new_params,
            opt_state=new_opt_state,
            **kwargs,
        )
|
|
    def replicate(self):
        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))

    def unreplicate(self):
        return jax_utils.unreplicate(self)
|
|
|
|
def main():
    # Parse input arguments into the ModelArguments and Seq2SeqTrainingArguments dataclasses.
    parser = HfArgumentParser(
        (
            ModelArguments,
            Seq2SeqTrainingArguments,
        )
    )
|
|
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # parse it to get our arguments.
        model_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, training_args = parser.parse_args_into_dataclasses()
|
|
    # Handle the Hub repository creation if the converted weights are to be pushed.
    if training_args.push_to_hub:
        if training_args.hub_model_id is None:
            repo_name = get_full_repo_name(
                Path(training_args.output_dir).absolute().name,
                token=training_args.hub_token,
            )
        else:
            repo_name = training_args.hub_model_id
        create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
        repo = Repository(
            training_args.output_dir,
            clone_from=repo_name,
            token=training_args.hub_token,
        )
|
|
    # Load the pretrained config and Flax Whisper model (with `_do_init=False`, the
    # parameters are returned separately from the model object).
    config = AutoConfig.from_pretrained(
        (model_args.config_name if model_args.config_name else model_args.model_name_or_path),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )
    student_model, student_params = FlaxWhisperForConditionalGeneration.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        dtype=getattr(jnp, model_args.dtype),
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
        _do_init=False,
        use_scan=model_args.load_with_scan_weights,
    )
|
|
    # Enable scan over the encoder/decoder blocks and convert the unrolled weights to scan format.
    if model_args.use_scan:
        student_model.enable_scan()
        student_params = student_model.convert_unroll_to_scan(student_params)
|
|
    # Initialise the PRNG and split off a dropout key (only needed to build the TrainState).
    rng = jax.random.PRNGKey(training_args.seed)
    rng, dropout_rng = jax.random.split(rng)

    total_train_steps = int(training_args.max_steps)
|
|
    # Create the learning rate schedule (required to rebuild the optimizer state for deserialization).
    linear_decay_lr_schedule_fn = create_learning_rate_fn(
        total_train_steps,
        training_args.lr_scheduler_type,
        training_args.warmup_steps,
        training_args.learning_rate,
    )
|
|
    # We use Optax's "masking" functionality to exclude bias and LayerNorm scale
    # parameters from weight decay in the AdamW optimizer below.
    def decay_mask_fn(params):
        flat_params = traverse_util.flatten_dict(params)
        # Find the (parent, leaf) name pairs that belong to LayerNorm modules.
        layer_norm_candidates = [
            "layer_norm",
            "self_attn_layer_norm",
            "final_layer_norm",
            "encoder_attn_layer_norm",
        ]
        layer_norm_named_params = {
            layer[-2:]
            for layer_norm_name in layer_norm_candidates
            for layer in flat_params.keys()
            if layer_norm_name in "".join(layer).lower()
        }
        flat_mask = {path: path[-1] != "bias" and path[-2:] not in layer_norm_named_params for path in flat_params}
        return traverse_util.unflatten_dict(flat_mask)
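    # Illustration of the mask produced above (parameter names assumed from the unrolled
    # Whisper layout, for example only): a flattened key ending in ("self_attn_layer_norm",
    # "scale") or in "bias" maps to False (no weight decay), whereas one ending in
    # ("q_proj", "kernel") maps to True and is decayed by AdamW.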
|
|
    # Create the AdamW optimizer with the weight decay mask defined above.
    adamw = optax.adamw(
        learning_rate=linear_decay_lr_schedule_fn,
        b1=training_args.adam_beta1,
        b2=training_args.adam_beta2,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
        mask=decay_mask_fn,
    )
|
|
    # Set up the training state so the serialized checkpoint can be deserialized into it.
    student_state = TrainState.create(
        apply_fn=student_model.__call__,
        params=student_params,
        tx=adamw,
        dropout_rng=dropout_rng,
        max_grad_norm=training_args.max_grad_norm,
    )
|
|
    if training_args.resume_from_checkpoint is not None:
        if os.path.isfile(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")):
            logger.info(
                f"Checkpoint detected, resuming training at {training_args.resume_from_checkpoint}. To avoid "
                "this behavior, omit the resume_from_checkpoint argument."
            )
            with Path(os.path.join(training_args.resume_from_checkpoint, "train_state.msgpack")).open("rb") as f:
                student_state = from_bytes(student_state, f.read())
        else:
            logger.warning(
                f"Checkpoint {training_args.resume_from_checkpoint} not detected, training from scratch. Ensure "
                f"you pass the path to a folder with a valid checkpoint for your model."
            )

    cur_step = int(jax.device_get(student_state.step))
|
|
    # Convert the weights back to unrolled (non-scan) format and save them in HF Transformers
    # format on the main process only.
    if jax.process_index() == 0:
        student_model.disable_scan()
        student_state_params = student_model.convert_scan_to_unroll(student_state.params)
        student_params = jax.device_get(student_state_params)
        student_model.save_pretrained(
            os.path.join(training_args.output_dir, f"checkpoint-{cur_step}"), params=student_params
        )
        if training_args.push_to_hub:
            repo.push_to_hub(
                commit_message=f"Saving weights of step {cur_step}",
                blocking=False,
            )
|
|
|
|
if __name__ == "__main__":
    main()
|
|