| import argparse |
| import os |
| import json |
| import logging |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import fairseq |
| from transformers import ( |
| Wav2Vec2Config, |
| Wav2Vec2CTCTokenizer, |
| Wav2Vec2FeatureExtractor, |
| Wav2Vec2ForCTC, |
| Wav2Vec2ForPreTraining, |
| Wav2Vec2Processor, |
| PreTrainedModel, |
| logging as transformers_logging, |
| ) |
| from transformers.models.wav2vec2.modeling_wav2vec2 import ( |
| Wav2Vec2ForSequenceClassification, |
| Wav2Vec2PreTrainedModel, |
| Wav2Vec2Model |
| ) |
|
|
| |
# Route all conversion progress through the transformers logging facility at INFO level.
transformers_logging.set_verbosity_info()
logger = transformers_logging.get_logger(__name__)
|
|
| |
class DinosrAudioConfig(Wav2Vec2Config):
    """Wav2Vec2 configuration extended with DinoSR-specific hyperparameters.

    Adds the knobs used by DinoSR self-distillation: online codebooks
    (discrete targets), EMA teacher scheduling, target-layer normalization
    options, and a deeper positional convolution stack. All extra values are
    stored as plain attributes so they serialize with the base config.
    """

    def __init__(
        self,
        discrete=False,
        codebook_size=256,
        normal_init_codebook=False,
        codebook_init_decay=0.9,
        codebook_end_decay=0.9,
        codebook_end_decay_step=0,
        freeze_teacher_step=200001,
        freeze_pre_enc_modules=True,
        loss_beta=0,
        loss_scale=None,
        average_top_k_layers=8,
        layer_norm_target_layer=False,
        instance_norm_target_layer=False,
        instance_norm_targets=False,
        layer_norm_targets=False,
        batch_norm_target_layer=False,
        group_norm_target_layer=False,
        ema_decay=0.999,
        ema_end_decay=0.9999,
        ema_anneal_end_step=None,
        ema_transformer_only=True,
        ema_layers_only=True,
        max_update=None,
        min_target_var=0.1,
        min_pred_var=0.01,
        pos_conv_depth=5,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Codebook / discrete-target options.
        self.discrete = discrete
        self.codebook_size = codebook_size
        self.normal_init_codebook = normal_init_codebook
        self.codebook_init_decay = codebook_init_decay
        self.codebook_end_decay = codebook_end_decay
        self.codebook_end_decay_step = codebook_end_decay_step
        # Teacher freezing / loss options.
        self.freeze_teacher_step = freeze_teacher_step
        self.freeze_pre_enc_modules = freeze_pre_enc_modules
        self.loss_beta = loss_beta
        self.loss_scale = loss_scale
        # Target construction: how many top layers to average, and which
        # normalization is applied to the per-layer targets.
        self.average_top_k_layers = average_top_k_layers
        self.layer_norm_target_layer = layer_norm_target_layer
        self.instance_norm_target_layer = instance_norm_target_layer
        self.instance_norm_targets = instance_norm_targets
        self.layer_norm_targets = layer_norm_targets
        self.batch_norm_target_layer = batch_norm_target_layer
        self.group_norm_target_layer = group_norm_target_layer
        # EMA teacher schedule options.
        self.ema_decay = ema_decay
        self.ema_end_decay = ema_end_decay
        self.ema_anneal_end_step = ema_anneal_end_step
        self.ema_transformer_only = ema_transformer_only
        self.ema_layers_only = ema_layers_only
        # Misc training options, variance floors, and positional conv depth.
        self.max_update = max_update
        self.min_target_var = min_target_var
        self.min_pred_var = min_pred_var
        self.pos_conv_depth = pos_conv_depth
|
|
| |
|
|
class DinosrPositionalConvEmbedding(nn.Module):
    """
    Extended positional convolutional embeddings with multiple stacked layers.

    The first convolution uses the full ``conv_pos_kernel_size``; each of the
    ``pos_conv_depth - 1`` extra layers uses a half-width kernel. Every
    convolution is followed by a GELU.
    """

    def __init__(self, config):
        super().__init__()
        self.conv_layers = nn.ModuleList()

        width = config.hidden_size
        kernel = config.conv_pos_kernel_size
        groups = config.num_conv_pos_embedding_groups

        # First (full-width) positional convolution.
        self.conv = nn.Conv1d(
            width,
            width,
            kernel_size=kernel,
            padding=kernel // 2,
            groups=groups,
        )

        # Remaining depth: half-width kernels.
        for _ in range(config.pos_conv_depth - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    width,
                    width,
                    kernel_size=kernel // 2,
                    padding=(kernel // 2) // 2,
                    groups=groups,
                )
            )

        # NOTE(review): constructed but never applied in forward() — confirm intent.
        self.padding = Wav2Vec2SamePadLayer(kernel)
        self.activation = nn.GELU()

    def forward(self, hidden_states):
        # (batch, time, dim) -> (batch, dim, time) for Conv1d.
        hidden_states = hidden_states.transpose(1, 2)

        for conv in (self.conv, *self.conv_layers):
            hidden_states = self.activation(conv(hidden_states))

        return hidden_states.transpose(1, 2)
|
|
| |
class DinosrCodebook(nn.Module):
    """Online codebooks and prediction heads for DinoSR.

    Keeps one codebook (plus one linear prediction head) per averaged teacher
    layer (``config.average_top_k_layers``). Codebook vectors and usage
    counters are registered as buffers, so they are saved/loaded with the
    state dict but not trained by the optimizer.
    """

    def __init__(self, config):
        super().__init__()
        self.n_codebooks = config.average_top_k_layers
        self.codebook_size = config.codebook_size
        self.embed_dim = config.hidden_size

        # Codeword vectors and per-codeword usage counters (buffers, not params).
        self.register_buffer("codebooks", torch.zeros(self.n_codebooks, self.codebook_size, self.embed_dim))
        self.register_buffer("codebook_cnts", torch.ones(self.n_codebooks, self.codebook_size))

        # One linear prediction head per codebook.
        self.heads = nn.ModuleList([
            nn.Linear(self.embed_dim, self.codebook_size) for _ in range(self.n_codebooks)
        ])

    def forward(self, hidden_states, mask_indices=None):
        """Return a list of per-codebook logits over codewords.

        Args:
            hidden_states: features of shape (..., embed_dim).
            mask_indices: optional boolean index restricting scoring to
                masked positions.
        """
        if mask_indices is not None:
            hidden_states = hidden_states[mask_indices]

        logits = [head(hidden_states) for head in self.heads]
        return logits

    def compute_ppl(self, y, input_onehot=False, tokenwise=False):
        """Compute perplexity for codebook stats.

        Args:
            y: logits (or probabilities when ``input_onehot``) with the
                codeword dimension last.
            input_onehot: if True, ``y`` is already a probability distribution.
            tokenwise: if True, return per-token perplexities (one value per
                row) instead of the batch-level perplexity of the mean
                distribution.
        """
        if not input_onehot:
            y = y.softmax(dim=-1)
        if tokenwise:
            # Fix: return the per-token perplexity directly. Previously this
            # branch fell through and re-applied 2**entropy to values that were
            # already perplexities, producing a meaningless result.
            return 2**(- y * (y+1e-8).log2()).sum(-1)
        y = y.mean(0)
        return 2**(- y * (y+1e-8).log2()).sum()
|
|
| |
class Wav2Vec2SamePadLayer(nn.Module):
    """Reflect-pads the last (time) axis by ``kernel_size - 1`` elements total,
    so a following valid convolution with that kernel preserves length.

    Odd kernels pad symmetrically; even kernels place the extra element on the
    right.
    """

    def __init__(self, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size

    def forward(self, hidden_states):
        right = self.kernel_size // 2
        # Odd kernel: symmetric padding. Even kernel: one fewer on the left.
        left = right if self.kernel_size % 2 == 1 else right - 1
        return nn.functional.pad(hidden_states, (left, right), mode="reflect")
|
|
| |
class DinosrForPreTraining(Wav2Vec2PreTrainedModel):
    """DinoSR pre-training model: a Wav2Vec2 backbone plus either online
    codebooks (``config.discrete``) or a pair of projection heads.
    """

    def __init__(self, config):
        super().__init__(config)
        self.wav2vec2 = Wav2Vec2Model(config)

        # Swap in the deeper positional-conv stack when the config asks for it.
        if hasattr(config, "pos_conv_depth") and config.pos_conv_depth > 1:
            self.wav2vec2.encoder.pos_conv_embed = DinosrPositionalConvEmbedding(config)

        if config.discrete:
            self.codebook = DinosrCodebook(config)
        else:
            self.project_hid = nn.Linear(config.hidden_size, config.hidden_size)
            self.project_q = nn.Linear(config.hidden_size, config.hidden_size)

        self.init_weights()

    def forward(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the backbone and return either codebook logits (discrete mode)
        or the projected student/teacher states (continuous mode).

        Note: hidden states and a dict result are always requested from the
        backbone regardless of ``output_hidden_states``/``return_dict``.
        """
        encoder_out = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )
        features = encoder_out.last_hidden_state

        # Discrete variant: score features against each per-layer head.
        if hasattr(self, "codebook"):
            return {
                "logits": self.codebook(features, mask_time_indices),
                "hidden_states": encoder_out.hidden_states,
            }

        # Continuous variant: two projections, optionally masked positions only.
        projected = self.project_hid(features)
        projected_q = self.project_q(features)
        if mask_time_indices is not None:
            projected = projected[mask_time_indices]
            projected_q = projected_q[mask_time_indices]

        return {
            "projected_states": projected,
            "projected_quantized_states": projected_q,
            "hidden_states": encoder_out.hidden_states,
        }
|
|
| |
# fairseq parameter-name fragment -> HF module path. A leading "wav2vec2." is
# prepended at load time unless the target appears in TOP_LEVEL_KEYS. A "*"
# is replaced with the transformer layer index parsed from the fairseq name.
MAPPING = {
    "post_extract_proj": "feature_projection.projection",
    # Multi-layer positional convolution: fairseq stores encoder.pos_conv.<i>;
    # layer 0 maps to .conv, the remaining layers to .conv_layers.<i-1>.
    "encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
    "encoder.pos_conv.1": "encoder.pos_conv_embed.conv_layers.0",
    "encoder.pos_conv.2": "encoder.pos_conv_embed.conv_layers.1",
    "encoder.pos_conv.3": "encoder.pos_conv_embed.conv_layers.2",
    "encoder.pos_conv.4": "encoder.pos_conv_embed.conv_layers.3",
    # Transformer encoder internals.
    "self_attn.k_proj": "encoder.layers.*.attention.k_proj",
    "self_attn.v_proj": "encoder.layers.*.attention.v_proj",
    "self_attn.q_proj": "encoder.layers.*.attention.q_proj",
    "self_attn.out_proj": "encoder.layers.*.attention.out_proj",
    "self_attn_layer_norm": "encoder.layers.*.layer_norm",
    "fc1": "encoder.layers.*.feed_forward.intermediate_dense",
    "fc2": "encoder.layers.*.feed_forward.output_dense",
    "final_layer_norm": "encoder.layers.*.final_layer_norm",
    "encoder.layer_norm": "encoder.layer_norm",
    "adapter_layer": "encoder.layers.*.adapter_layer",
    "w2v_model.layer_norm": "feature_projection.layer_norm",
    # Quantizer / projection heads / task-specific heads.
    "quantizer.weight_proj": "quantizer.weight_proj",
    "quantizer.vars": "quantizer.codevectors",
    "project_q": "project_q",
    "final_proj": "project_hid",
    "w2v_encoder.proj": "lm_head",
    "mask_emb": "masked_spec_embed",
    "pooling_layer.linear": "projector",
    "pooling_layer.projection": "classifier",
    # DinoSR-specific tensors (see DinosrCodebook).
    "heads": "codebook.heads",
    "_codebook": "codebook.codebooks",
    "_codebook_cnts": "codebook.codebook_cnts",
}

# Targets that live directly on the top-level model (no "wav2vec2." prefix).
TOP_LEVEL_KEYS = [
    "lm_head",
    "quantizer.weight_proj",
    "quantizer.codevectors",
    "project_q",
    "project_hid",
    "projector",
    "classifier",
    "codebook.heads",
    "codebook.codebooks",
    "codebook.codebook_cnts",
]

# Adapter parameter-name suffixes -> HF parameter paths. Used by
# set_recursively()/rename_dict() when a fairseq name ends with one of these.
PARAM_MAPPING = {
    "W_a": "linear_1.weight",
    "W_b": "linear_2.weight",
    "b_a": "linear_1.bias",
    "b_b": "linear_2.bias",
    "ln_W": "norm.weight",
    "ln_b": "norm.bias",
}
|
|
def read_txt_into_dict(filename):
    """Build an id->label mapping from a text file.

    Each non-empty line contributes one entry: the key is the 0-based line
    number (blank lines are skipped but still counted) and the value is the
    line's first whitespace-separated token.
    """
    with open(filename, "r") as handle:
        return {
            line_no: text.split()[0]
            for line_no, text in enumerate(handle)
            if text.strip()
        }
|
|
def set_recursively(key, value, full_name, weight_type, hf_pointer):
    """Copy one fairseq tensor ``value`` into the HF model.

    Args:
        key: dotted attribute path inside the HF model (already translated).
        value: tensor taken from the fairseq state dict.
        full_name: original fairseq parameter name; used to dispatch the
            DinoSR-specific special cases below.
        weight_type: one of "weight", "weight_g", "weight_v", "bias" or None;
            may be promoted to "param" for adapter tensors (PARAM_MAPPING).
        hf_pointer: the HF model, walked down attribute by attribute.

    DinoSR-specific tensors (positional conv weights, codebooks, counters,
    prediction heads) are handled by early-return special cases; everything
    else falls through to a generic shape-checked assignment with a
    best-effort partial copy on mismatch.
    """
    # --- Special case: first positional conv weight (may need resizing). ---
    if "encoder.pos_conv_embed.conv" in key and weight_type == "weight":
        logger.info(f"Handling special case for {key} (pos_conv_embed)")

        # Walk down to the conv module itself.
        for attribute in key.split("."):
            hf_pointer = getattr(hf_pointer, attribute)

        hf_shape = hf_pointer.weight.shape
        value_shape = value.shape

        if hf_shape != value_shape:
            logger.info(f"Reshaping position embedding from {value_shape} to {hf_shape}")

            # Zero-initialize, then copy the overlapping kernel slice (dim 2).
            new_weight = torch.zeros(hf_shape, device=value.device, dtype=value.dtype)

            min_kernel_size = min(hf_shape[2], value_shape[2])

            new_weight[:, :, :min_kernel_size] = value[:, :, :min_kernel_size]

            hf_pointer.weight.data = new_weight
            logger.info(f"Position embedding resized and initialized")
        # NOTE(review): when the shapes already match, nothing is copied before
        # this return — looks like a dropped assignment; confirm against the
        # source repository.
        return

    import re  # local import, only needed for index parsing below

    # --- Special case: codebook usage counters ("_codebook_cnts<i>"). The
    # caller routes every "_codebook*" name here with key "codebook.codebooks".
    if "_codebook_cnts" in full_name and "codebook.codebooks" in key:
        logger.info(f"Handling special case for {full_name}, {key} (DinoSR codebook counts)")

        print(full_name)  # debug leftover
        codebook_part = full_name.split("_codebook")[1]
        print(codebook_part)  # debug leftover
        if codebook_part.startswith('_cnts'):
            # "_cnts<i>" -> i
            codebook_idx = int(codebook_part.split('_cnts')[1])
        else:
            # Fallback: first integer anywhere in the remainder.
            codebook_idx = int(re.search(r'\d+', codebook_part).group())

        # NOTE(review): this walk follows ``key`` ("codebook.codebooks"), whose
        # components never equal "codebook_cnts", so ``codebook_cnts`` stays
        # None and the counters are silently dropped — likely a bug; confirm.
        codebook_cnts = None
        for attribute in key.split("."):
            if hasattr(hf_pointer, attribute):
                hf_pointer = getattr(hf_pointer, attribute)
                if attribute == "codebook_cnts":
                    codebook_cnts = hf_pointer

        if codebook_cnts is not None:
            # Write this codebook's counter row into the registered buffer.
            codebook_cnts[codebook_idx] = value
            logger.info(f"Codebook counts {codebook_idx} initialized from {full_name}")
        return

    # --- Special case: codebook vectors ("_codebook<i>"). ---
    if "_codebook" in full_name and "codebook.codebooks" in key:
        logger.info(f"Handling special case for {full_name} , {key} (DinoSR codebook)")

        # Single digit following "_codebook" selects the codebook row.
        codebook_idx = int(full_name.split("_codebook")[1][0])

        # Walk down to the "codebooks" buffer, remembering it on the way.
        codebooks = None
        for attribute in key.split("."):
            if hasattr(hf_pointer, attribute):
                hf_pointer = getattr(hf_pointer, attribute)
                if attribute == "codebooks":
                    codebooks = hf_pointer

        if codebooks is not None:
            codebooks[codebook_idx] = value
            logger.info(f"Codebook {codebook_idx} initialized from {full_name}")
        return

    # --- Special case: per-codebook prediction heads ("heads.<i>.{weight,bias}"). ---
    if "heads." in full_name and "codebook.heads" in key:
        logger.info(f"Handling special case for {full_name} (DinoSR prediction head)")

        parts = full_name.split(".")
        head_idx = int(parts[-2])
        param_type = parts[-1]

        # Walk down to the ModuleList of heads.
        heads = None
        for attribute in key.split("."):
            if hasattr(hf_pointer, attribute):
                hf_pointer = getattr(hf_pointer, attribute)
                if attribute == "heads":
                    heads = hf_pointer

        if heads is not None and head_idx < len(heads):
            if param_type == "weight":
                heads[head_idx].weight.data = value
            elif param_type == "bias":
                heads[head_idx].bias.data = value
            logger.info(f"Head {head_idx} {param_type} initialized from {full_name}")
        return

    # --- Special case: additional positional conv layers (pos_conv.1 .. .4). ---
    if "encoder.pos_conv." in full_name and any(f".{i}." in full_name for i in range(1, 5)) and "encoder.pos_conv_embed.conv_layers" in key:
        logger.info(f"Handling special case for {full_name} (additional pos conv layer)")

        # fairseq layer i maps to HF conv_layers[i - 1] (layer 0 is .conv).
        layer_idx = int(full_name.split("encoder.pos_conv.")[1].split(".")[0]) - 1
        param_type = "weight" if "weight" in full_name else "bias"

        # Walk down to the ModuleList of extra conv layers.
        conv_layers = None
        for attribute in key.split("."):
            if hasattr(hf_pointer, attribute):
                hf_pointer = getattr(hf_pointer, attribute)
                if attribute == "conv_layers":
                    conv_layers = hf_pointer

        if conv_layers is not None and layer_idx < len(conv_layers):
            if param_type == "weight":
                conv_layers[layer_idx].weight.data = value
            elif param_type == "bias":
                conv_layers[layer_idx].bias.data = value
            logger.info(f"Pos conv layer {layer_idx} {param_type} initialized from {full_name}")
        return

    # --- Generic path: walk the dotted key down to the target module/param. ---
    for attribute in key.split("."):
        if hasattr(hf_pointer, attribute):
            hf_pointer = getattr(hf_pointer, attribute)
        else:
            logger.warning(f"Attribute {attribute} not found in model. Skipping parameter: {full_name}")
            return

    # Adapter tensors (W_a, b_a, ...) are renamed and treated as "param".
    hf_param_name = None
    for param_key in PARAM_MAPPING.keys():
        if full_name.endswith(param_key):
            hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
            weight_type = "param"

    # Determine the expected shape of the HF-side target.
    if weight_type is not None and weight_type != "param":
        if weight_type == "weight_g" and not hasattr(hf_pointer, "weight_g"):
            # Newer torch stores weight-norm params under parametrizations.
            if hasattr(hf_pointer, "parametrizations") and hasattr(hf_pointer.parametrizations, "weight") and hasattr(hf_pointer.parametrizations.weight, "original0"):
                hf_shape = hf_pointer.parametrizations.weight.original0.shape
            else:
                logger.warning(f"Could not find weight_g parametrization for {full_name}. Skipping.")
                return
        elif weight_type == "weight_v" and not hasattr(hf_pointer, "weight_v"):
            if hasattr(hf_pointer, "parametrizations") and hasattr(hf_pointer.parametrizations, "weight") and hasattr(hf_pointer.parametrizations.weight, "original1"):
                hf_shape = hf_pointer.parametrizations.weight.original1.shape
            else:
                logger.warning(f"Could not find weight_v parametrization for {full_name}. Skipping.")
                return
        else:
            if hasattr(hf_pointer, weight_type):
                hf_shape = getattr(hf_pointer, weight_type).shape
            else:
                logger.warning(f"Attribute {weight_type} not found in model. Skipping parameter: {full_name}")
                return
    elif weight_type is not None and weight_type == "param":
        # Walk the renamed adapter parameter path for its shape.
        shape_pointer = hf_pointer
        for attribute in hf_param_name.split("."):
            if hasattr(shape_pointer, attribute):
                shape_pointer = getattr(shape_pointer, attribute)
            else:
                logger.warning(f"Attribute {attribute} not found in parameter. Skipping: {full_name}")
                return
        hf_shape = shape_pointer.shape

        # presumably fairseq wraps these tensors in an extra leading dim —
        # TODO confirm (rename_dict() does the same unwrap).
        value = value[0]
    else:
        if hasattr(hf_pointer, "shape"):
            hf_shape = hf_pointer.shape
        else:
            logger.warning(f"Shape attribute not found. Skipping parameter: {full_name}")
            return

    # Best-effort adaptation when the fairseq and HF shapes disagree.
    if hf_shape != value.shape:
        logger.warning(f"Shape mismatch for {full_name}: HF shape {hf_shape}, Fairseq shape {value.shape}")
        if "encoder.pos_conv" in full_name:
            logger.info(f"Skipping pos_conv layer due to shape mismatch")
            return
        else:
            logger.warning(f"Shape mismatch for {full_name}. Attempting to adapt...")

            try:
                if len(hf_shape) == len(value.shape):
                    # Zero tensor of the HF shape, then copy the overlapping block.
                    new_value = torch.zeros(hf_shape, device=value.device, dtype=value.dtype)

                    min_dims = [min(hf_shape[i], value.shape[i]) for i in range(len(hf_shape))]

                    slices = tuple(slice(0, d) for d in min_dims)

                    if len(slices) == 1:
                        new_value[slices[0]] = value[slices[0]]
                    elif len(slices) == 2:
                        new_value[slices[0], slices[1]] = value[slices[0], slices[1]]
                    elif len(slices) == 3:
                        new_value[slices[0], slices[1], slices[2]] = value[slices[0], slices[1], slices[2]]
                    else:
                        raise ValueError("Unsupported number of dimensions")

                    value = new_value
                    logger.info(f"Successfully adapted shape for {full_name}")
                else:
                    logger.warning(f"Cannot adapt shapes with different dimensionality. Skipping {full_name}")
                    return
            except Exception as e:
                logger.warning(f"Failed to adapt shape: {e}. Skipping {full_name}")
                return

    # Final assignment into the matching HF attribute.
    if weight_type == "weight":
        hf_pointer.weight.data = value
    elif weight_type == "weight_g":
        if hasattr(hf_pointer, "weight_g"):
            hf_pointer.weight_g.data = value
        else:
            hf_pointer.parametrizations.weight.original0.data = value
    elif weight_type == "weight_v":
        if hasattr(hf_pointer, "weight_v"):
            hf_pointer.weight_v.data = value
        else:
            hf_pointer.parametrizations.weight.original1.data = value
    elif weight_type == "bias":
        hf_pointer.bias.data = value
    elif weight_type == "param":
        for attribute in hf_param_name.split("."):
            hf_pointer = getattr(hf_pointer, attribute)
        hf_pointer.data = value
    else:
        # ``hf_pointer`` itself is the tensor/parameter (e.g. a buffer).
        hf_pointer.data = value

    logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
|
|
def rename_dict(key, value, full_name, weight_type, hf_dict):
    """Store ``value`` in ``hf_dict`` under the fully qualified HF key.

    Adapter-style fairseq suffixes (W_a, b_a, ...) are renamed via
    PARAM_MAPPING and force ``weight_type`` to "param".
    """
    hf_param_name = None
    for param_key in PARAM_MAPPING:
        if full_name.endswith(param_key):
            hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
            weight_type = "param"

    if weight_type is None:
        full_key = key
    elif weight_type == "param":
        full_key = f"{key}.{hf_param_name}"
    else:
        full_key = f"{key}.{weight_type}"

    # NOTE(review): non-lm_head tensors are unwrapped via value[0] —
    # presumably fairseq stores them with an extra leading dim; confirm.
    hf_dict[full_key] = value if "lm_head" in full_key else value[0]
|
|
def load_wav2vec2_layer(name, value, hf_model=None, hf_dict=None):
    """Route one fairseq tensor (``name`` -> ``value``) into the HF model or
    a plain state dict.

    Returns True when the tensor was consumed, False when no mapping matched.
    DinoSR-specific tensors (codebooks, counters, prediction heads, extra
    positional conv layers) are dispatched to set_recursively() special cases
    before the generic MAPPING table is consulted.
    """
    is_used = False

    # DinoSR codebook buffers: names like "_codebook0", "_codebook1", ...
    # NOTE(review): this prefix test also matches "_codebook_cnts*" names, so
    # counter tensors are routed through this branch with key
    # "codebook.codebooks" — confirm set_recursively() handles that as intended.
    if name.startswith("_codebook") and hf_model is not None and hasattr(hf_model, "codebook"):
        codebook_idx = int(name.split("_codebook")[1][0])

        set_recursively("codebook.codebooks", value, name, None, hf_model)
        return True

    # DinoSR codebook usage counters: "_codebook_cnts<i>". Unreachable in
    # practice — the branch above already matches these names first.
    if name.startswith("_codebook_cnts") and hf_model is not None and hasattr(hf_model, "codebook"):
        codebook_idx = int(name.split("_codebook_cnts")[1][0])

        set_recursively("codebook.codebook_cnts", value, name, None, hf_model)
        return True

    # DinoSR per-layer prediction heads: "heads.<idx>.{weight,bias}".
    if name.startswith("heads.") and hf_model is not None and hasattr(hf_model, "codebook"):
        head_parts = name.split(".")
        head_idx = int(head_parts[1])
        param_type = head_parts[2]

        set_recursively("codebook.heads", value, name, param_type, hf_model)
        return True

    # Additional positional conv layers: "encoder.pos_conv.1" .. ".4".
    if name.startswith("encoder.pos_conv.") and any(f".{i}." in name for i in range(1, 5)):
        layer_idx = int(name.split("encoder.pos_conv.")[1].split(".")[0])
        if layer_idx > 0 and hf_model is not None and hasattr(hf_model.wav2vec2.encoder, "pos_conv_embed"):
            # The exact layer index is re-derived from ``name`` inside
            # set_recursively(); only the target ModuleList path is passed here.
            param_type = "weight" if "weight" in name else "bias"
            set_recursively(f"wav2vec2.encoder.pos_conv_embed.conv_layers", value, name, None, hf_model)
            return True

    # Generic fairseq -> HF name translation via the MAPPING table.
    for key, mapped_key in MAPPING.items():
        # Non-top-level targets live inside the "wav2vec2" submodule.
        mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
        if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
            is_used = True
            if "*" in mapped_key:
                # The transformer layer index sits just before the matched fragment.
                layer_index = name.split(key)[0].split(".")[-2]
                mapped_key = mapped_key.replace("*", layer_index)
            # Classify the parameter kind from the fairseq name suffix.
            if "weight_g" in name:
                weight_type = "weight_g"
            elif "weight_v" in name:
                weight_type = "weight_v"
            elif "bias" in name:
                weight_type = "bias"
            elif "weight" in name:
                weight_type = "weight"
            else:
                weight_type = None
            if hf_dict is not None:
                rename_dict(mapped_key, value, name, weight_type, hf_dict)
            else:
                set_recursively(mapped_key, value, name, weight_type, hf_model)
            # Only the first matching MAPPING entry is applied.
            return is_used
    return is_used
|
|
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
    """Copy one feature-extractor conv tensor into the HF feature extractor.

    fairseq names look like "...conv_layers.<layer_id>.<type_id>.{weight,bias}";
    type_id 0 is the convolution itself and type_id 2 is its norm layer (only
    layer 0 when group norm is used). Anything else is recorded as unused.
    """
    suffix = full_name.split("conv_layers.")[-1]
    pieces = suffix.split(".")
    layer_id = int(pieces[0])
    type_id = int(pieces[1])

    def _copy(param, log_text):
        # Shape-check, then copy the fairseq tensor into the HF parameter.
        if value.shape != param.data.shape:
            raise ValueError(
                f"{full_name} has size {value.shape}, but"
                f" {param.data.shape} was found."
            )
        param.data = value
        logger.info(log_text)

    layer = feature_extractor.conv_layers[layer_id]
    if type_id == 0:
        if "bias" in suffix:
            _copy(layer.conv.bias, f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
        elif "weight" in suffix:
            _copy(layer.conv.weight, f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
    elif type_id == 2 and (not use_group_norm or layer_id == 0):
        if "bias" in suffix:
            _copy(layer.layer_norm.bias, f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
        elif "weight" in suffix:
            _copy(layer.layer_norm.weight, f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
    else:
        unused_weights.append(full_name)
|
|
def recursively_load_weights(fairseq_model, hf_model, is_headless):
    """Walk the fairseq state dict and copy every tensor into ``hf_model``.

    Feature-extractor conv tensors go through load_conv_layer(); everything
    else through load_wav2vec2_layer(). Tensors no handler claims are
    collected and logged at the end. (``is_headless`` is currently unused.)
    """
    unused_weights = []
    feature_extractor = hf_model.wav2vec2.feature_extractor

    for name, value in fairseq_model.state_dict().items():
        if "conv_layers" in name:
            load_conv_layer(
                name,
                value,
                feature_extractor,
                unused_weights,
                hf_model.config.feat_extract_norm == "group",
            )
        elif not load_wav2vec2_layer(name, value, hf_model):
            unused_weights.append(name)

    logger.warning(f"Unused weights: {unused_weights}")
|
|
@torch.no_grad()
def convert_wav2vec2_checkpoint(code_path="dinosr", checkpoint_path="dinosr.ckpt",
                                config_path=None, dict_path=None,
                                is_finetuned=False, is_seq_class=False,
                                pytorch_dump_folder_path="converted_model"):
    """
    Convert a fairseq DinoSR checkpoint to the HuggingFace Wav2Vec2 format.

    Args:
        code_path: directory with the DinoSR fairseq user module (registers the task/model).
        checkpoint_path: fairseq checkpoint to convert.
        config_path: optional pretrained HF config; otherwise a default DinosrAudioConfig is used.
        dict_path: fairseq dictionary (CTC fine-tuned) or id->label text file (sequence classification).
        is_finetuned: checkpoint is a CTC fine-tuned model.
        is_seq_class: checkpoint is a sequence-classification model.
        pytorch_dump_folder_path: output directory for the converted model.

    Returns:
        The converted HuggingFace model (also saved to ``pytorch_dump_folder_path``).
    """
    # Make the DinoSR task/model classes visible to fairseq before loading.
    fairseq.utils.import_user_module(argparse.Namespace(user_dir=code_path))

    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])

    if config_path is not None:
        config = Wav2Vec2Config.from_pretrained(config_path)
    else:
        config = DinosrAudioConfig()

    # Architecture hyperparameters hard-coded to the released DinoSR-base checkpoint.
    config.mask_prob = 0.8
    config.mask_length = 10
    config.conv_pos = 95
    config.encoder_embed_dim = 768
    config.hidden_size = 768
    config.encoder_layers = 12
    config.num_hidden_layers = 12
    config.feat_extract_norm = "layer"

    # Positional convolution settings (see DinosrPositionalConvEmbedding).
    config.conv_pos_kernel_size = 95
    config.num_conv_pos_embeddings = 128
    config.hidden_dropout = 0.0
    config.activation_dropout = 0.0
    config.feat_proj_dropout = 0.0
    config.layerdrop = 0.05

    # Codebook geometry: 8 groups of 256 codevectors.
    config.num_codevectors_per_group = 256
    config.num_codevector_groups = 8

    if is_seq_class:
        id2label = read_txt_into_dict(dict_path)
        config.id2label = id2label
        hf_wav2vec = Wav2Vec2ForSequenceClassification(config)
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0,
            do_normalize=True,
            return_attention_mask=True,
        )
        feature_extractor.save_pretrained(pytorch_dump_folder_path)

    elif is_finetuned:
        if dict_path:
            # Fix: ``Dictionary`` was referenced without ever being imported,
            # raising NameError on this path.
            from fairseq.data import Dictionary

            target_dict = Dictionary.load(dict_path)

            # Important: swap bos & pad token ids since the CTC blank symbol
            # is <pad> and not <s> as in fairseq.
            config.bos_token_id = target_dict.pad_index
            config.pad_token_id = target_dict.bos_index
            config.eos_token_id = target_dict.eos_index
            config.vocab_size = len(target_dict.symbols)
            vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
            # NOTE(review): this rejects a not-yet-existing folder even though
            # the makedirs below could create it — confirm this is intended.
            if not os.path.isdir(pytorch_dump_folder_path):
                logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
                return
            os.makedirs(pytorch_dump_folder_path, exist_ok=True)
            vocab_dict = target_dict.indices

            # fairseq has the <pad> and <s> ids switched; force the HF convention.
            vocab_dict["<pad>"] = 0
            vocab_dict["<s>"] = 1
            with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
                json.dump(vocab_dict, vocab_handle)
            tokenizer = Wav2Vec2CTCTokenizer(
                vocab_path,
                unk_token=target_dict.unk_word,
                pad_token=target_dict.pad_word,
                bos_token=target_dict.bos_word,
                eos_token=target_dict.eos_word,
                word_delimiter_token="|",
                do_lower_case=False,
            )
            return_attention_mask = True if config.feat_extract_norm == "layer" else False
            feature_extractor = Wav2Vec2FeatureExtractor(
                feature_size=1,
                sampling_rate=16000,
                padding_value=0,
                do_normalize=True,
                return_attention_mask=return_attention_mask,
            )
            processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
            processor.save_pretrained(pytorch_dump_folder_path)

        hf_wav2vec = Wav2Vec2ForCTC(config)
    else:
        # NOTE(review): this instantiates the stock Wav2Vec2ForPreTraining, which
        # has no ``codebook`` attribute, so the DinoSR codebook tensors handled
        # in load_wav2vec2_layer() are never loaded — should this be
        # DinosrForPreTraining(config)? Confirm.
        hf_wav2vec = Wav2Vec2ForPreTraining(config)

    os.makedirs(pytorch_dump_folder_path, exist_ok=True)

    # load_model_ensemble_and_task returns a list of models; take the first.
    model = model[0].eval()

    # Copy every fairseq tensor into the HF model.
    recursively_load_weights(model, hf_wav2vec, not is_finetuned)

    hf_wav2vec.save_pretrained(pytorch_dump_folder_path)

    # Pretraining checkpoints still need a feature-extractor config on disk.
    if not is_finetuned and not is_seq_class:
        feature_extractor = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,
            padding_value=0,
            do_normalize=True,
            return_attention_mask=True if config.feat_extract_norm == "layer" else False,
        )
        feature_extractor.save_pretrained(pytorch_dump_folder_path)

    logger.info(f"Model saved to {pytorch_dump_folder_path}")
    return hf_wav2vec
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--code_path", type=str, default="dinosr", help="Path to the DinoSR code") |
| parser.add_argument("--checkpoint_path", type=str, default="dinosr.ckpt", help="Path to the checkpoint") |
| parser.add_argument("--config_path", type=str, default=None, help="Path to the config file") |
| parser.add_argument("--dict_path", type=str, default=None, help="Path to the dictionary") |
| parser.add_argument("--is_finetuned", action="store_true", help="Whether the model is finetuned") |
| parser.add_argument("--is_seq_class", action="store_true", help="Whether the model is for sequence classification") |
| parser.add_argument("--pytorch_dump_folder_path", type=str, default="converted_model", help="Path to save the converted model") |
| args = parser.parse_args() |
| |
| convert_wav2vec2_checkpoint( |
| code_path=args.code_path, |
| checkpoint_path=args.checkpoint_path, |
| config_path=args.config_path, |
| dict_path=args.dict_path, |
| is_finetuned=args.is_finetuned, |
| is_seq_class=args.is_seq_class, |
| pytorch_dump_folder_path=args.pytorch_dump_folder_path, |
| ) |