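"""Convert a fairseq DinoSR checkpoint into a Hugging Face Wav2Vec2-style checkpoint.

Example invocation (paths are placeholders for your local DinoSR checkout and checkpoint):

    python convert.py --code_path /path/to/dinosr --checkpoint_path dinosr.ckpt \
        --pytorch_dump_folder_path converted_model
"""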
import argparse
import os
import json
import logging
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
import fairseq
from fairseq.data import Dictionary
from transformers import (
Wav2Vec2Config,
Wav2Vec2CTCTokenizer,
Wav2Vec2FeatureExtractor,
Wav2Vec2ForCTC,
Wav2Vec2ForPreTraining,
Wav2Vec2Processor,
PreTrainedModel,
logging as transformers_logging,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2ForSequenceClassification,
Wav2Vec2PreTrainedModel,
Wav2Vec2Model
)
# Setup logging
transformers_logging.set_verbosity_info()
logger = transformers_logging.get_logger(__name__)
# Define a custom DinosrAudioConfig based on Wav2Vec2Config
class DinosrAudioConfig(Wav2Vec2Config):
def __init__(
self,
discrete=False,
codebook_size=256,
normal_init_codebook=False,
codebook_init_decay=0.9,
codebook_end_decay=0.9,
codebook_end_decay_step=0,
freeze_teacher_step=200001,
freeze_pre_enc_modules=True,
loss_beta=0,
loss_scale=None,
average_top_k_layers=8,
layer_norm_target_layer=False,
instance_norm_target_layer=False,
instance_norm_targets=False,
layer_norm_targets=False,
batch_norm_target_layer=False,
group_norm_target_layer=False,
ema_decay=0.999,
ema_end_decay=0.9999,
ema_anneal_end_step=None,
ema_transformer_only=True,
ema_layers_only=True,
max_update=None,
min_target_var=0.1,
min_pred_var=0.01,
        pos_conv_depth=5,  # Number of positional convolutional layers
        conv_pos_kernel_size=95,  # Kernel size used by the positional convolutions
        **kwargs,
):
super().__init__(**kwargs)
# Add DinoSR custom parameters
self.discrete = discrete
self.codebook_size = codebook_size
self.normal_init_codebook = normal_init_codebook
self.codebook_init_decay = codebook_init_decay
self.codebook_end_decay = codebook_end_decay
self.codebook_end_decay_step = codebook_end_decay_step
self.freeze_teacher_step = freeze_teacher_step
self.freeze_pre_enc_modules = freeze_pre_enc_modules
self.loss_beta = loss_beta
self.loss_scale = loss_scale
self.average_top_k_layers = average_top_k_layers
self.layer_norm_target_layer = layer_norm_target_layer
self.instance_norm_target_layer = instance_norm_target_layer
self.instance_norm_targets = instance_norm_targets
self.layer_norm_targets = layer_norm_targets
self.batch_norm_target_layer = batch_norm_target_layer
self.group_norm_target_layer = group_norm_target_layer
self.ema_decay = ema_decay
self.ema_end_decay = ema_end_decay
self.ema_anneal_end_step = ema_anneal_end_step
self.ema_transformer_only = ema_transformer_only
self.ema_layers_only = ema_layers_only
self.max_update = max_update
self.min_target_var = min_target_var
self.min_pred_var = min_pred_var
        self.pos_conv_depth = pos_conv_depth
        self.conv_pos_kernel_size = conv_pos_kernel_size
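# Example (sketch): building a config with DinoSR-style defaults. The values below mirror
# those set later in convert_wav2vec2_checkpoint and are illustrative, not read from any
# checkpoint:
#
#     config = DinosrAudioConfig(
#         discrete=True, codebook_size=256, average_top_k_layers=8,
#         hidden_size=768, num_hidden_layers=12, pos_conv_depth=5,
#     )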
# Define custom modules for DinoSR
class DinosrPositionalConvEmbedding(nn.Module):
"""
Extended positional convolutional embeddings with multiple layers
"""
def __init__(self, config):
super().__init__()
self.conv_layers = nn.ModuleList()
# First conv layer (equivalent to standard Wav2Vec2)
self.conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=config.conv_pos_kernel_size,
padding=config.conv_pos_kernel_size // 2,
groups=config.num_conv_pos_embedding_groups,
)
# Additional conv layers for DinoSR
num_additional_layers = config.pos_conv_depth - 1
for i in range(num_additional_layers):
additional_conv = nn.Conv1d(
config.hidden_size,
config.hidden_size,
kernel_size=config.conv_pos_kernel_size // 2, # Smaller kernel for deeper layers
padding=(config.conv_pos_kernel_size // 2) // 2,
groups=config.num_conv_pos_embedding_groups,
)
self.conv_layers.append(additional_conv)
self.padding = Wav2Vec2SamePadLayer(config.conv_pos_kernel_size)
self.activation = nn.GELU()
def forward(self, hidden_states):
hidden_states = hidden_states.transpose(1, 2)
# First conv (standard)
hidden_states = self.conv(hidden_states)
hidden_states = self.activation(hidden_states)
# Additional conv layers
for conv_layer in self.conv_layers:
hidden_states = conv_layer(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = hidden_states.transpose(1, 2)
return hidden_states
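# Shape note: DinosrPositionalConvEmbedding takes and returns (batch, time, hidden_size);
# internally it transposes to (batch, hidden_size, time) so the grouped Conv1d layers can
# run along the time axis, applying a GELU after every convolution.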
# DinoSR Codebook
class DinosrCodebook(nn.Module):
def __init__(self, config):
super().__init__()
self.n_codebooks = config.average_top_k_layers
self.codebook_size = config.codebook_size
self.embed_dim = config.hidden_size
# Initialize codebooks and counters
self.register_buffer("codebooks", torch.zeros(self.n_codebooks, self.codebook_size, self.embed_dim))
self.register_buffer("codebook_cnts", torch.ones(self.n_codebooks, self.codebook_size))
# Prediction heads
self.heads = nn.ModuleList([
nn.Linear(self.embed_dim, self.codebook_size) for _ in range(self.n_codebooks)
])
def forward(self, hidden_states, mask_indices=None):
if mask_indices is not None:
hidden_states = hidden_states[mask_indices]
logits = [head(hidden_states) for head in self.heads]
return logits
def compute_ppl(self, y, input_onehot=False, tokenwise=False):
"""Compute perplexity for codebook stats"""
# We track the avg. of 1-hot (argmax)
if not input_onehot:
y = y.softmax(dim=-1)
if tokenwise:
y = 2**(- y * (y+1e-8).log2()).sum(-1)
y = y.mean(0)
y = 2**(- y * (y+1e-8).log2()).sum()
return y
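# Note: compute_ppl returns 2 ** H, i.e. the perplexity of the (softmaxed) code
# distribution; a perfectly uniform distribution over codebook_size entries gives a value
# of codebook_size, while a collapsed codebook gives a value near 1.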
# Same pad layer (from Wav2Vec2)
class Wav2Vec2SamePadLayer(nn.Module):
def __init__(self, kernel_size):
super().__init__()
self.kernel_size = kernel_size
def forward(self, hidden_states):
if self.kernel_size % 2 == 1:
padding = (self.kernel_size - 1) // 2
hidden_states = nn.functional.pad(hidden_states, (padding, padding), mode="reflect")
else:
padding_right = self.kernel_size // 2
padding_left = self.kernel_size // 2 - 1
hidden_states = nn.functional.pad(
hidden_states, (padding_left, padding_right), mode="reflect"
)
return hidden_states
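# Note: despite the name, this differs from transformers' Wav2Vec2SamePadLayer (which only
# trims one trailing frame for even kernel sizes); here the input is reflect-padded instead.
# It is assigned in DinosrPositionalConvEmbedding.__init__ but not used in its forward pass.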
# Define custom DinoSR model for Hugging Face
class DinosrForPreTraining(Wav2Vec2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.wav2vec2 = Wav2Vec2Model(config)
# Replace positional embedding with DinoSR version
if hasattr(config, "pos_conv_depth") and config.pos_conv_depth > 1:
self.wav2vec2.encoder.pos_conv_embed = DinosrPositionalConvEmbedding(config)
# DinoSR specific components
if config.discrete:
self.codebook = DinosrCodebook(config)
else:
self.project_hid = nn.Linear(config.hidden_size, config.hidden_size)
self.project_q = nn.Linear(config.hidden_size, config.hidden_size)
# Initialize weights
self.init_weights()
def forward(
self,
input_values,
attention_mask=None,
mask_time_indices=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
outputs = self.wav2vec2(
input_values,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=True, # We need hidden states for DinoSR
return_dict=True,
)
# Get transformer outputs
hidden_states = outputs.last_hidden_state
# If discrete model (with codebook)
if hasattr(self, "codebook"):
logits = self.codebook(hidden_states, mask_time_indices)
return {"logits": logits, "hidden_states": outputs.hidden_states}
else:
# Regular contrastive learning path
projected_states = self.project_hid(hidden_states)
projected_quantized_states = self.project_q(hidden_states)
# Only return masked indices if specified
if mask_time_indices is not None:
projected_states = projected_states[mask_time_indices]
projected_quantized_states = projected_quantized_states[mask_time_indices]
return {
"projected_states": projected_states,
"projected_quantized_states": projected_quantized_states,
"hidden_states": outputs.hidden_states
}
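# Usage sketch (illustrative only; the random waveform below stands in for a real 16 kHz
# recording):
#
#     config = DinosrAudioConfig(discrete=True)
#     model = DinosrForPreTraining(config)
#     out = model(torch.randn(1, 16000))
#     # out["logits"] is a list of average_top_k_layers tensors, one per codebook,
#     # each holding codebook_size logits per frame.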
# Mapping between fairseq and transformers model parameters
MAPPING = {
"post_extract_proj": "feature_projection.projection",
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
"encoder.pos_conv.1": "encoder.pos_conv_embed.conv_layers.0",
"encoder.pos_conv.2": "encoder.pos_conv_embed.conv_layers.1",
"encoder.pos_conv.3": "encoder.pos_conv_embed.conv_layers.2",
"encoder.pos_conv.4": "encoder.pos_conv_embed.conv_layers.3",
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
"fc2": "encoder.layers.*.feed_forward.output_dense",
"final_layer_norm": "encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "encoder.layer_norm",
"adapter_layer": "encoder.layers.*.adapter_layer",
"w2v_model.layer_norm": "feature_projection.layer_norm",
"quantizer.weight_proj": "quantizer.weight_proj",
"quantizer.vars": "quantizer.codevectors",
"project_q": "project_q",
"final_proj": "project_hid",
"w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed",
"pooling_layer.linear": "projector",
"pooling_layer.projection": "classifier",
"heads": "codebook.heads", # Added for DinoSR
"_codebook": "codebook.codebooks", # Added for DinoSR
"_codebook_cnts": "codebook.codebook_cnts", # Added for DinoSR
}
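# Example: a fairseq parameter named
#     "encoder.layers.3.self_attn.k_proj.weight"
# matches the "self_attn.k_proj" entry above and is loaded into
#     "wav2vec2.encoder.layers.3.attention.k_proj.weight"
# ("*" is replaced with the layer index extracted in load_wav2vec2_layer, and the
# "wav2vec2." prefix is added for keys that are not in TOP_LEVEL_KEYS).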
TOP_LEVEL_KEYS = [
"lm_head",
"quantizer.weight_proj",
"quantizer.codevectors",
"project_q",
"project_hid",
"projector",
"classifier",
"codebook.heads", # Added for DinoSR
"codebook.codebooks", # Added for DinoSR
"codebook.codebook_cnts", # Added for DinoSR
]
PARAM_MAPPING = {
"W_a": "linear_1.weight",
"W_b": "linear_2.weight",
"b_a": "linear_1.bias",
"b_b": "linear_2.bias",
"ln_W": "norm.weight",
"ln_b": "norm.bias",
}
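# PARAM_MAPPING renames fairseq adapter parameters (W_a/W_b/b_a/b_b/ln_W/ln_b) to the
# linear_1/linear_2/norm parameters of the corresponding Hugging Face adapter layer;
# names ending in one of these keys are handled as weight_type == "param" below.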
def read_txt_into_dict(filename):
result = {}
with open(filename, "r") as file:
for line_number, line in enumerate(file):
line = line.strip()
if line:
words = line.split()
key = line_number
value = words[0]
result[key] = value
return result
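# Expected label-file format for --is_seq_class: one label per line, first word kept.
# For example, a file containing the two lines "happy" and "sad" yields
# {0: "happy", 1: "sad"}, which becomes config.id2label.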
def set_recursively(key, value, full_name, weight_type, hf_pointer):
# Special handling for pos_conv_embed layer because of shape differences
if "encoder.pos_conv_embed.conv" in key and weight_type == "weight":
# Handle the position embedding conv layer differently
logger.info(f"Handling special case for {key} (pos_conv_embed)")
for attribute in key.split("."):
hf_pointer = getattr(hf_pointer, attribute)
# Get the Hugging Face model shape
hf_shape = hf_pointer.weight.shape
value_shape = value.shape
if hf_shape != value_shape:
logger.info(f"Reshaping position embedding from {value_shape} to {hf_shape}")
# Create a new tensor with the right shape (initialized to zeros)
new_weight = torch.zeros(hf_shape, device=value.device, dtype=value.dtype)
# If we're going from smaller to larger, just copy what we have
# and zero-pad the rest
min_kernel_size = min(hf_shape[2], value_shape[2])
# Copy the values we have
new_weight[:, :, :min_kernel_size] = value[:, :, :min_kernel_size]
hf_pointer.weight.data = new_weight
logger.info(f"Position embedding resized and initialized")
return
    # Special handling for DinoSR codebook counters
    if "_codebook_cnts" in full_name and "codebook.codebooks" in key:
        logger.info(f"Handling special case for {full_name}, {key} (DinoSR codebook counts)")
        # Extract the codebook index
        codebook_part = full_name.split("_codebook")[1]
        if codebook_part.startswith("_cnts"):
            # Case: "_codebook_cnts0" → extract 0
            codebook_idx = int(codebook_part.split("_cnts")[1])
        else:
            # Case: "_codebook0" → extract 0
            codebook_idx = int(re.search(r"\d+", codebook_part).group())
# Navigate to the codebook_cnts buffer
codebook_cnts = None
for attribute in key.split("."):
if hasattr(hf_pointer, attribute):
hf_pointer = getattr(hf_pointer, attribute)
if attribute == "codebook_cnts":
codebook_cnts = hf_pointer
if codebook_cnts is not None:
# Update the appropriate codebook count
codebook_cnts[codebook_idx] = value
logger.info(f"Codebook counts {codebook_idx} initialized from {full_name}")
return
if "_codebook" in full_name and "codebook.codebooks" in key:
logger.info(f"Handling special case for {full_name} , {key} (DinoSR codebook)")
# Extract the codebook index
codebook_idx = int(full_name.split("_codebook")[1][0])
# Navigate to the codebooks buffer
codebooks = None
for attribute in key.split("."):
if hasattr(hf_pointer, attribute):
hf_pointer = getattr(hf_pointer, attribute)
if attribute == "codebooks":
codebooks = hf_pointer
if codebooks is not None:
# Update the appropriate codebook
codebooks[codebook_idx] = value
logger.info(f"Codebook {codebook_idx} initialized from {full_name}")
return
    # Special handling for DinoSR prediction heads
if "heads." in full_name and "codebook.heads" in key:
logger.info(f"Handling special case for {full_name} (DinoSR prediction head)")
# Extract the head index and parameter type
parts = full_name.split(".")
head_idx = int(parts[-2])
param_type = parts[-1] # weight or bias
# Navigate to the heads list
heads = None
for attribute in key.split("."):
if hasattr(hf_pointer, attribute):
hf_pointer = getattr(hf_pointer, attribute)
if attribute == "heads":
heads = hf_pointer
if heads is not None and head_idx < len(heads):
# Update the appropriate head parameter
if param_type == "weight":
heads[head_idx].weight.data = value
elif param_type == "bias":
heads[head_idx].bias.data = value
logger.info(f"Head {head_idx} {param_type} initialized from {full_name}")
return
# Special handling for additional positional conv layers
if "encoder.pos_conv." in full_name and any(f".{i}." in full_name for i in range(1, 5)) and "encoder.pos_conv_embed.conv_layers" in key:
logger.info(f"Handling special case for {full_name} (additional pos conv layer)")
# Extract the layer index and parameter type
layer_idx = int(full_name.split("encoder.pos_conv.")[1].split(".")[0]) - 1
param_type = "weight" if "weight" in full_name else "bias"
# Navigate to the conv_layers list
conv_layers = None
for attribute in key.split("."):
if hasattr(hf_pointer, attribute):
hf_pointer = getattr(hf_pointer, attribute)
if attribute == "conv_layers":
conv_layers = hf_pointer
if conv_layers is not None and layer_idx < len(conv_layers):
# Update the appropriate layer parameter
if param_type == "weight":
conv_layers[layer_idx].weight.data = value
elif param_type == "bias":
conv_layers[layer_idx].bias.data = value
logger.info(f"Pos conv layer {layer_idx} {param_type} initialized from {full_name}")
return
# Normal parameter handling for standard layers
for attribute in key.split("."):
if hasattr(hf_pointer, attribute):
hf_pointer = getattr(hf_pointer, attribute)
else:
logger.warning(f"Attribute {attribute} not found in model. Skipping parameter: {full_name}")
return
hf_param_name = None
for param_key in PARAM_MAPPING.keys():
if full_name.endswith(param_key):
hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
weight_type = "param"
# Handling for parametrizations weight norm
if weight_type is not None and weight_type != "param":
if weight_type == "weight_g" and not hasattr(hf_pointer, "weight_g"):
if hasattr(hf_pointer, "parametrizations") and hasattr(hf_pointer.parametrizations, "weight") and hasattr(hf_pointer.parametrizations.weight, "original0"):
hf_shape = hf_pointer.parametrizations.weight.original0.shape
else:
logger.warning(f"Could not find weight_g parametrization for {full_name}. Skipping.")
return
elif weight_type == "weight_v" and not hasattr(hf_pointer, "weight_v"):
if hasattr(hf_pointer, "parametrizations") and hasattr(hf_pointer.parametrizations, "weight") and hasattr(hf_pointer.parametrizations.weight, "original1"):
hf_shape = hf_pointer.parametrizations.weight.original1.shape
else:
logger.warning(f"Could not find weight_v parametrization for {full_name}. Skipping.")
return
else:
if hasattr(hf_pointer, weight_type):
hf_shape = getattr(hf_pointer, weight_type).shape
else:
logger.warning(f"Attribute {weight_type} not found in model. Skipping parameter: {full_name}")
return
elif weight_type is not None and weight_type == "param":
shape_pointer = hf_pointer
for attribute in hf_param_name.split("."):
if hasattr(shape_pointer, attribute):
shape_pointer = getattr(shape_pointer, attribute)
else:
logger.warning(f"Attribute {attribute} not found in parameter. Skipping: {full_name}")
return
hf_shape = shape_pointer.shape
# reduce dimension
value = value[0]
else:
if hasattr(hf_pointer, "shape"):
hf_shape = hf_pointer.shape
else:
logger.warning(f"Shape attribute not found. Skipping parameter: {full_name}")
return
    if hf_shape != value.shape:
        logger.warning(f"Shape mismatch for {full_name}: HF shape {hf_shape}, fairseq shape {value.shape}")
        if "encoder.pos_conv" in full_name:
            logger.info("Skipping pos_conv layer due to shape mismatch")
            return
        else:
            logger.warning(f"Attempting to adapt shape for {full_name}...")
# Try to adapt the shape (for non-critical layers)
try:
if len(hf_shape) == len(value.shape):
# Same dimensionality, try to broadcast/slice
new_value = torch.zeros(hf_shape, device=value.device, dtype=value.dtype)
# Get min dimensions for each axis
min_dims = [min(hf_shape[i], value.shape[i]) for i in range(len(hf_shape))]
# Create slices
slices = tuple(slice(0, d) for d in min_dims)
# Copy data from source to destination
if len(slices) == 1:
new_value[slices[0]] = value[slices[0]]
elif len(slices) == 2:
new_value[slices[0], slices[1]] = value[slices[0], slices[1]]
elif len(slices) == 3:
new_value[slices[0], slices[1], slices[2]] = value[slices[0], slices[1], slices[2]]
else:
raise ValueError("Unsupported number of dimensions")
value = new_value
logger.info(f"Successfully adapted shape for {full_name}")
else:
logger.warning(f"Cannot adapt shapes with different dimensionality. Skipping {full_name}")
return
except Exception as e:
logger.warning(f"Failed to adapt shape: {e}. Skipping {full_name}")
return
if weight_type == "weight":
hf_pointer.weight.data = value
elif weight_type == "weight_g":
if hasattr(hf_pointer, "weight_g"):
hf_pointer.weight_g.data = value
else:
hf_pointer.parametrizations.weight.original0.data = value
elif weight_type == "weight_v":
if hasattr(hf_pointer, "weight_v"):
hf_pointer.weight_v.data = value
else:
hf_pointer.parametrizations.weight.original1.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
elif weight_type == "param":
for attribute in hf_param_name.split("."):
hf_pointer = getattr(hf_pointer, attribute)
hf_pointer.data = value
else:
hf_pointer.data = value
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
def rename_dict(key, value, full_name, weight_type, hf_dict):
hf_param_name = None
for param_key in PARAM_MAPPING.keys():
if full_name.endswith(param_key):
hf_param_name = PARAM_MAPPING[full_name.split(".")[-1]]
weight_type = "param"
if weight_type is not None and weight_type != "param":
full_key = ".".join([key, weight_type])
elif weight_type is not None and weight_type == "param":
full_key = ".".join([key, hf_param_name])
else:
full_key = key
hf_dict[full_key] = value if "lm_head" in full_key else value[0]
def load_wav2vec2_layer(name, value, hf_model=None, hf_dict=None):
is_used = False
    # Special case: Handle DinoSR codebooks and codebook counters. Names starting with
    # "_codebook" cover both the codebook buffers and the "_codebook_cnts" counters;
    # set_recursively tells the two apart from full_name.
    if name.startswith("_codebook") and hf_model is not None and hasattr(hf_model, "codebook"):
        set_recursively("codebook.codebooks", value, name, None, hf_model)
        return True
# Special case: Handle DinoSR heads
if name.startswith("heads.") and hf_model is not None and hasattr(hf_model, "codebook"):
head_parts = name.split(".")
head_idx = int(head_parts[1])
param_type = head_parts[2] # weight or bias
# Map to our custom model's codebook heads
set_recursively("codebook.heads", value, name, param_type, hf_model)
return True
    # Special case: Handle additional positional conv layers
    if name.startswith("encoder.pos_conv.") and any(f".{i}." in name for i in range(1, 5)):
        layer_idx = int(name.split("encoder.pos_conv.")[1].split(".")[0])
        if layer_idx > 0 and hf_model is not None and hasattr(hf_model.wav2vec2.encoder, "pos_conv_embed"):
            # Map to our custom positional embedding's additional conv layers
            # (set_recursively derives the layer index and weight/bias type from the name)
            set_recursively("wav2vec2.encoder.pos_conv_embed.conv_layers", value, name, None, hf_model)
            return True
# Standard parameter mapping
for key, mapped_key in MAPPING.items():
mapped_key = "wav2vec2." + mapped_key if mapped_key not in TOP_LEVEL_KEYS else mapped_key
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
is_used = True
if "*" in mapped_key:
layer_index = name.split(key)[0].split(".")[-2]
mapped_key = mapped_key.replace("*", layer_index)
if "weight_g" in name:
weight_type = "weight_g"
elif "weight_v" in name:
weight_type = "weight_v"
elif "bias" in name:
weight_type = "bias"
elif "weight" in name:
weight_type = "weight"
else:
weight_type = None
if hf_dict is not None:
rename_dict(mapped_key, value, name, weight_type, hf_dict)
else:
set_recursively(mapped_key, value, name, weight_type, hf_model)
return is_used
return is_used
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
name = full_name.split("conv_layers.")[-1]
items = name.split(".")
layer_id = int(items[0])
type_id = int(items[1])
if type_id == 0:
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.bias.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.bias.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].conv.weight.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].conv.weight.data = value
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
if "bias" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
elif "weight" in name:
if value.shape != feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape:
raise ValueError(
f"{full_name} has size {value.shape}, but"
f" {feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape} was found."
)
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
else:
unused_weights.append(full_name)
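# Feature-extractor weights arrive with fairseq names of the form
# "...conv_layers.<layer_id>.<type_id>.{weight,bias}", where type_id 0 is the Conv1d and
# type_id 2 is the (group/layer) norm; anything else is collected in unused_weights.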
def recursively_load_weights(fairseq_model, hf_model, is_headless):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
feature_extractor = hf_model.wav2vec2.feature_extractor
for name, value in fairseq_dict.items():
is_used = False
if "conv_layers" in name:
load_conv_layer(
name,
value,
feature_extractor,
unused_weights,
hf_model.config.feat_extract_norm == "group",
)
is_used = True
else:
is_used = load_wav2vec2_layer(name, value, hf_model)
if not is_used:
unused_weights.append(name)
logger.warning(f"Unused weights: {unused_weights}")
@torch.no_grad()
def convert_wav2vec2_checkpoint(code_path="dinosr", checkpoint_path="dinosr.ckpt",
config_path=None, dict_path=None,
is_finetuned=False, is_seq_class=False,
pytorch_dump_folder_path="converted_model"):
"""
Convert DinoSR model to HuggingFace Wav2Vec2 format
"""
# Import user modules from fairseq
fairseq.utils.import_user_module(argparse.Namespace(user_dir=code_path))
# Load the model
model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
# Create configuration
if config_path is not None:
config = Wav2Vec2Config.from_pretrained(config_path)
else:
config = DinosrAudioConfig()
# Update config with some parameters from the DinoSR model base YAML
config.mask_prob = 0.8
config.mask_length = 10
    config.conv_pos = 95  # fairseq conv_pos (kernel size of the positional convolution)
config.encoder_embed_dim = 768
config.hidden_size = 768
config.encoder_layers = 12 # Default if not set in config
config.num_hidden_layers = 12 # Keep in sync with encoder_layers
config.feat_extract_norm = "layer"
    # Positional embedding settings: conv_pos_kernel_size is the kernel size used by the
    # custom DinosrPositionalConvEmbedding; num_conv_pos_embeddings keeps the HF default.
    config.conv_pos_kernel_size = 95
    config.num_conv_pos_embeddings = 128
config.hidden_dropout = 0.0
config.activation_dropout = 0.0
config.feat_proj_dropout = 0.0
config.layerdrop = 0.05
# Set codebook parameters
config.num_codevectors_per_group = 256
config.num_codevector_groups = 8
if is_seq_class:
id2label = read_txt_into_dict(dict_path)
config.id2label = id2label
hf_wav2vec = Wav2Vec2ForSequenceClassification(config)
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=True,
)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
elif is_finetuned:
if dict_path:
target_dict = Dictionary.load(dict_path)
# important change bos & pad token id since CTC symbol is <pad> and
# not <s> as in fairseq
config.bos_token_id = target_dict.pad_index
config.pad_token_id = target_dict.bos_index
config.eos_token_id = target_dict.eos_index
config.vocab_size = len(target_dict.symbols)
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
if not os.path.isdir(pytorch_dump_folder_path):
logger.error("--pytorch_dump_folder_path ({}) should be a directory".format(pytorch_dump_folder_path))
return
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
vocab_dict = target_dict.indices
# fairseq has the <pad> and <s> switched
vocab_dict["<pad>"] = 0
vocab_dict["<s>"] = 1
with open(vocab_path, "w", encoding="utf-8") as vocab_handle:
json.dump(vocab_dict, vocab_handle)
tokenizer = Wav2Vec2CTCTokenizer(
vocab_path,
unk_token=target_dict.unk_word,
pad_token=target_dict.pad_word,
bos_token=target_dict.bos_word,
eos_token=target_dict.eos_word,
word_delimiter_token="|",
do_lower_case=False,
)
return_attention_mask = True if config.feat_extract_norm == "layer" else False
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=return_attention_mask,
)
processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
processor.save_pretrained(pytorch_dump_folder_path)
hf_wav2vec = Wav2Vec2ForCTC(config)
else:
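        # Note: the DinosrForPreTraining class defined above could be used here instead to
        # keep the DinoSR codebooks and prediction heads; with Wav2Vec2ForPreTraining those
        # DinoSR-specific weights are skipped with a warning and only the shared backbone
        # is converted.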
hf_wav2vec = Wav2Vec2ForPreTraining(config)
# Create output directory
os.makedirs(pytorch_dump_folder_path, exist_ok=True)
# Extract model from ensemble
model = model[0].eval()
# Transfer weights
recursively_load_weights(model, hf_wav2vec, not is_finetuned)
# Save the model
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
# Create default feature extractor if not already created
if not is_finetuned and not is_seq_class:
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size=1,
sampling_rate=16000,
padding_value=0,
do_normalize=True,
return_attention_mask=True if config.feat_extract_norm == "layer" else False,
)
feature_extractor.save_pretrained(pytorch_dump_folder_path)
logger.info(f"Model saved to {pytorch_dump_folder_path}")
return hf_wav2vec
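# Programmatic usage sketch (arguments mirror the CLI flags below; paths are placeholders):
#
#     hf_model = convert_wav2vec2_checkpoint(
#         code_path="/path/to/dinosr",
#         checkpoint_path="dinosr.ckpt",
#         pytorch_dump_folder_path="converted_model",
#     )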
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--code_path", type=str, default="dinosr", help="Path to the DinoSR code")
parser.add_argument("--checkpoint_path", type=str, default="dinosr.ckpt", help="Path to the checkpoint")
parser.add_argument("--config_path", type=str, default=None, help="Path to the config file")
parser.add_argument("--dict_path", type=str, default=None, help="Path to the dictionary")
parser.add_argument("--is_finetuned", action="store_true", help="Whether the model is finetuned")
parser.add_argument("--is_seq_class", action="store_true", help="Whether the model is for sequence classification")
parser.add_argument("--pytorch_dump_folder_path", type=str, default="converted_model", help="Path to save the converted model")
args = parser.parse_args()
convert_wav2vec2_checkpoint(
code_path=args.code_path,
checkpoint_path=args.checkpoint_path,
config_path=args.config_path,
dict_path=args.dict_path,
is_finetuned=args.is_finetuned,
is_seq_class=args.is_seq_class,
pytorch_dump_folder_path=args.pytorch_dump_folder_path,
)