# Safetensors
# wav2vec2-bert
# wav2vec2bert-jyutping / modeling.py
# indiejoseph
# Commit message: "Rename model.py to modeling.py"
# Commit: 4c25a54 (verified)
from dataclasses import dataclass
from transformers import (
Wav2Vec2BertModel,
Wav2Vec2BertPreTrainedModel,
Wav2Vec2BertProcessor,
Wav2Vec2CTCTokenizer,
Wav2Vec2Processor,
Wav2Vec2ForCTC,
Wav2Vec2PreTrainedModel,
Wav2Vec2Model,
)
from pycantonese.jyutping.parse_jyutping import ONSETS
import re
from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import (
_HIDDEN_STATES_START_POSITION,
)
from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2AttnAdapterLayer
from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import (
Wav2Vec2ConformerPreTrainedModel,
Wav2Vec2ConformerModel,
Wav2Vec2ConformerForCTC,
)
from transformers.modeling_outputs import ModelOutput
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import logging
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
@dataclass
class JuytpingOutput(ModelOutput):
    """
    Output type shared by the ``*ForCantonese`` models in this module
    (``Wav2Vec2BertForCantonese``, ``Wav2Vec2ForCantonese`` and
    ``Wav2Vec2ConformerForCantonese``).

    The class keeps its historical spelling ("Juytping") because callers
    import it by that name.

    Attributes:
        loss: Sum of the Jyutping and tone CTC losses; ``None`` unless the
            model computed losses (both label tensors were supplied).
        jyutping_logits: Logits produced by the Jyutping head.
        tone_logits: Logits produced by the tone head.
        jyutping_loss: CTC loss of the Jyutping head, when reported.
        tone_loss: CTC loss of the tone head, when reported.
        hidden_states: Encoder hidden states, when requested.
        attentions: Encoder attention weights, when requested.
    """

    loss: Optional[torch.FloatTensor] = None
    jyutping_logits: Optional[torch.FloatTensor] = None
    tone_logits: Optional[torch.FloatTensor] = None
    jyutping_loss: Optional[torch.FloatTensor] = None
    tone_loss: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
class Wav2Vec2BertForCantonese(Wav2Vec2BertPreTrainedModel):
    """
    Wav2Vec2BertModel with two linear heads on top of the (dropout-regularised)
    hidden states, trained jointly with CTC: one head over the Jyutping
    vocabulary (``config.vocab_size``) and one over tones (``tone_vocab_size``).
    """

    def __init__(
        self,
        config,
        tone_vocab_size: Optional[int] = None,
    ):
        """
        Args:
            config: Wav2Vec2Bert configuration. ``config.vocab_size`` must be
                set; it sizes the Jyutping head.
            tone_vocab_size: Number of tone classes. When ``None`` it is read
                from ``config.tone_vocab_size`` if present, else 9 (the old
                default). The resolved value is written back to the config so
                that a checkpoint saved with ``save_pretrained`` reloads with a
                tone head of the same size.

        Raises:
            ValueError: if ``config.vocab_size`` is ``None``.
        """
        super().__init__(config)
        if tone_vocab_size is None:
            tone_vocab_size = getattr(config, "tone_vocab_size", None)
            if tone_vocab_size is None:
                tone_vocab_size = 9
        # Persist on the config: otherwise a reload without the kwarg would
        # rebuild a default-size tone head that mismatches the saved weights.
        config.tone_vocab_size = tone_vocab_size

        self.wav2vec2_bert = Wav2Vec2BertModel(config)
        self.dropout = nn.Dropout(config.final_dropout)
        self.tone_vocab_size = tone_vocab_size

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2BertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # With an adapter on top of the encoder, hidden states come out at
        # `output_hidden_size` instead of `hidden_size`.
        output_hidden_size = (
            config.output_hidden_size
            if hasattr(config, "add_adapter") and config.add_adapter
            else config.hidden_size
        )
        self.jyutping_head = nn.Linear(output_hidden_size, config.vocab_size)
        self.tone_head = nn.Linear(output_hidden_size, tone_vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def _check_label_ranges(
        self,
        jyutping_labels: Optional[torch.Tensor],
        tone_labels: Optional[torch.Tensor],
    ) -> None:
        """Raise ValueError if any label id is >= its head's vocabulary size."""
        if (
            jyutping_labels is not None
            and jyutping_labels.numel() > 0
            and jyutping_labels.max() >= self.config.vocab_size
        ):
            raise ValueError(
                f"Label values must be < vocab_size: {self.config.vocab_size}"
            )
        if (
            tone_labels is not None
            and tone_labels.numel() > 0
            and tone_labels.max() >= self.tone_vocab_size
        ):
            raise ValueError(
                f"Label values must be < tone_vocab_size: {self.tone_vocab_size}"
            )

    def _ctc_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        input_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """CTC loss for one head; label entries < 0 (e.g. -100) are padding."""
        labels_mask = labels >= 0
        target_lengths = labels_mask.sum(-1)
        flattened_targets = labels.masked_select(labels_mask)
        # ctc_loss doesn't support fp16; it also expects (time, batch, classes).
        log_probs = nn.functional.log_softmax(
            logits, dim=-1, dtype=torch.float32
        ).transpose(0, 1)
        # NOTE(review): both heads use config.pad_token_id as the CTC blank, so
        # this assumes the tone vocabulary puts its blank at that same id
        # (and that it is < tone_vocab_size) — confirm against the tone tokenizer.
        with torch.backends.cudnn.flags(enabled=False):
            return nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                input_lengths,
                target_lengths,
                blank=self.config.pad_token_id,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=self.config.ctc_zero_infinity,
            )

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        jyutping_labels: Optional[torch.Tensor] = None,
        tone_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, JuytpingOutput]:
        """
        Encode ``input_features`` and score them with both heads.

        Args:
            input_features: Features for Wav2Vec2BertModel; presumably
                (batch, frames, feature_dim), since its first two dims build
                the default attention mask — confirm.
            attention_mask: Mask over the input frames; when ``None`` every
                frame is treated as valid for the loss lengths.
            output_attentions / output_hidden_states / return_dict: Forwarded
                to the encoder; ``return_dict`` defaults to
                ``config.use_return_dict``.
            jyutping_labels / tone_labels: Label ids, entries < 0 are padding.
                Losses are computed only when BOTH are given.

        Returns:
            ``JuytpingOutput`` when ``return_dict`` is true. Otherwise the tuple
            ``(jyutping_logits, tone_logits, *encoder_extras)``, prefixed with
            ``loss`` when it was computed.

        Raises:
            ValueError: if a label id is >= its head's vocabulary size.
        """
        self._check_label_ranges(jyutping_labels, tone_labels)
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.wav2vec2_bert(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = self.dropout(outputs[0])
        jyutping_logits = self.jyutping_head(hidden_states)
        tone_logits = self.tone_head(hidden_states)

        loss = jyutping_loss = tone_loss = None
        if jyutping_labels is not None and tone_labels is not None:
            # retrieve loss input_lengths from attention_mask
            if attention_mask is None:
                attention_mask = torch.ones(
                    input_features.shape[:2],
                    device=input_features.device,
                    dtype=torch.long,
                )
            input_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum([-1])
            ).to(torch.long)
            jyutping_loss = self._ctc_loss(
                jyutping_logits, jyutping_labels, input_lengths
            )
            tone_loss = self._ctc_loss(tone_logits, tone_labels, input_lengths)
            loss = jyutping_loss + tone_loss

        if not return_dict:
            output = (jyutping_logits, tone_logits) + outputs[
                _HIDDEN_STATES_START_POSITION:
            ]
            return ((loss,) + output) if loss is not None else output

        return JuytpingOutput(
            loss=loss,
            jyutping_logits=jyutping_logits,
            tone_logits=tone_logits,
            jyutping_loss=jyutping_loss,
            tone_loss=tone_loss,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def inference(
        self,
        processor: Wav2Vec2BertProcessor,
        tone_tokenizer: Wav2Vec2CTCTokenizer,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Greedy-decode a toned Jyutping transcription.

        Only the FIRST batch item is decoded to text; the logits returned cover
        the whole batch.

        The Jyutping head's decoded tokens are regrouped into syllables. A token
        found in pycantonese's ``ONSETS`` is glued to the token that follows it;
        every other token closes a syllable. Syllables are then zipped with the
        decoded tones: extra tones are dropped, and the last tone is repeated if
        there are too few.

        Returns:
            ``(transcription, jyutping_logits, tone_logits)``, where
            transcription is space-separated ``f"{syllable}{tone}"`` items.
        """
        outputs = self.forward(
            input_features=input_features,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        jyutping_logits = outputs.jyutping_logits
        tone_logits = outputs.tone_logits
        jyutping_pred_ids = torch.argmax(jyutping_logits, dim=-1)
        tone_pred_ids = torch.argmax(tone_logits, dim=-1)
        jyutping_pred = processor.batch_decode(jyutping_pred_ids)[0]
        tone_pred = tone_tokenizer.batch_decode(tone_pred_ids)[0]
        jyutping_list = jyutping_pred.split(" ")
        tone_list = tone_pred.split(" ")

        # Mark syllable boundaries with "_": before an onset, after anything
        # else. Turning "_" into spaces then yields one item per syllable.
        jyutping_output = []
        for jypt in jyutping_list:
            if jypt in ONSETS:
                jypt = "_" + jypt
            else:
                jypt = jypt + "_"
            jyutping_output.append(jypt)
        jyutping_output = re.sub(
            r"\s+", " ", "".join(jyutping_output).replace("_", " ").strip()
        ).split(" ")

        if len(tone_list) > len(jyutping_output):
            tone_list = tone_list[: len(jyutping_output)]
        elif len(tone_list) < len(jyutping_output):
            # repeat the last tone if the length of tone list is shorter than the length of jyutping list
            tone_list = tone_list + [tone_list[-1]] * (
                len(jyutping_output) - len(tone_list)
            )
        return (
            " ".join(
                [f"{jypt}{tone}" for jypt, tone in zip(jyutping_output, tone_list)]
            ),
            jyutping_logits,
            tone_logits,
        )
class Wav2Vec2ForCantonese(Wav2Vec2PreTrainedModel):
    """
    Wav2Vec2Model with two linear heads on top of the (dropout-regularised)
    hidden states, trained jointly with CTC: one head over the Jyutping
    vocabulary (``config.vocab_size``) and one over tones (``tone_vocab_size``).

    Weight tying, freezing and adapter handling are borrowed from
    ``Wav2Vec2ForCTC`` by calling its methods on this instance.
    """

    def __init__(
        self,
        config,
        tone_vocab_size: Optional[int] = None,
        target_lang: Optional[str] = None,
    ):
        """
        Args:
            config: Wav2Vec2 configuration. ``config.vocab_size`` must be set;
                it sizes the Jyutping head.
            tone_vocab_size: Number of tone classes. When ``None`` it is read
                from ``config.tone_vocab_size`` if present, else 9 (the old
                default). The resolved value is written back to the config so
                that a saved checkpoint reloads with a matching tone head.
            target_lang: Stored on ``self.target_lang``; read by the borrowed
                ``Wav2Vec2ForCTC.tie_weights`` (presumably to choose an adapter
                to load — confirm).

        Raises:
            ValueError: if ``config.vocab_size`` is ``None``.
        """
        super().__init__(config)
        if tone_vocab_size is None:
            tone_vocab_size = getattr(config, "tone_vocab_size", None)
            if tone_vocab_size is None:
                tone_vocab_size = 9
        # Persist on the config: otherwise a reload without the kwarg would
        # rebuild a default-size tone head that mismatches the saved weights.
        config.tone_vocab_size = tone_vocab_size

        self.wav2vec2 = Wav2Vec2Model(config)
        self.dropout = nn.Dropout(config.final_dropout)
        self.tone_vocab_size = tone_vocab_size
        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # With an adapter on top of the encoder, hidden states come out at
        # `output_hidden_size` instead of `hidden_size`.
        output_hidden_size = (
            config.output_hidden_size
            if hasattr(config, "add_adapter") and config.add_adapter
            else config.hidden_size
        )
        self.jyutping_head = nn.Linear(output_hidden_size, config.vocab_size)
        self.tone_head = nn.Linear(output_hidden_size, tone_vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def tie_weights(self):
        """Delegate to ``Wav2Vec2ForCTC.tie_weights``."""
        Wav2Vec2ForCTC.tie_weights(self)

    def freeze_feature_extractor(self):
        """Delegate to ``Wav2Vec2ForCTC.freeze_feature_extractor``."""
        Wav2Vec2ForCTC.freeze_feature_extractor(self)

    def freeze_feature_encoder(self):
        """Delegate to ``Wav2Vec2ForCTC.freeze_feature_encoder``."""
        Wav2Vec2ForCTC.freeze_feature_encoder(self)

    def freeze_base_model(self):
        """Delegate to ``Wav2Vec2ForCTC.freeze_base_model``."""
        Wav2Vec2ForCTC.freeze_base_model(self)

    def _get_adapters(self):
        """
        Collect adapter parameters keyed by dotted module path.

        Returns every parameter of each ``Wav2Vec2AttnAdapterLayer`` plus both
        output heads (``jyutping_head.*`` and ``tone_head.*``).

        Raises:
            ValueError: if ``config.adapter_attn_dim`` is ``None``.
        """
        if self.config.adapter_attn_dim is None:
            raise ValueError(
                f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`."
            )

        adapter_weights = {}
        for name, module in self.named_modules():
            if isinstance(module, Wav2Vec2AttnAdapterLayer):
                for param_name, param in module.named_parameters():
                    adapter_weights[".".join([name, param_name])] = param

        # The heads are always included. This class does not subclass
        # Wav2Vec2ForCTC, so the `isinstance(self, Wav2Vec2ForCTC)` guard
        # copied from it was always False and silently dropped both heads.
        for name, param in self.jyutping_head.named_parameters():
            adapter_weights[".".join(["jyutping_head", name])] = param
        for name, param in self.tone_head.named_parameters():
            adapter_weights[".".join(["tone_head", name])] = param

        return adapter_weights

    def _check_label_ranges(
        self,
        jyutping_labels: Optional[torch.Tensor],
        tone_labels: Optional[torch.Tensor],
    ) -> None:
        """Raise ValueError if any label id is >= its head's vocabulary size."""
        if (
            jyutping_labels is not None
            and jyutping_labels.numel() > 0
            and jyutping_labels.max() >= self.config.vocab_size
        ):
            raise ValueError(
                f"Label values must be < vocab_size: {self.config.vocab_size}"
            )
        if (
            tone_labels is not None
            and tone_labels.numel() > 0
            and tone_labels.max() >= self.tone_vocab_size
        ):
            raise ValueError(
                f"Label values must be < tone_vocab_size: {self.tone_vocab_size}"
            )

    def _ctc_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        input_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """CTC loss for one head; label entries < 0 (e.g. -100) are padding."""
        labels_mask = labels >= 0
        target_lengths = labels_mask.sum(-1)
        flattened_targets = labels.masked_select(labels_mask)
        # ctc_loss doesn't support fp16; it also expects (time, batch, classes).
        log_probs = nn.functional.log_softmax(
            logits, dim=-1, dtype=torch.float32
        ).transpose(0, 1)
        # NOTE(review): both heads use config.pad_token_id as the CTC blank, so
        # this assumes the tone vocabulary puts its blank at that same id
        # (and that it is < tone_vocab_size) — confirm against the tone tokenizer.
        with torch.backends.cudnn.flags(enabled=False):
            return nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                input_lengths,
                target_lengths,
                blank=self.config.pad_token_id,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=self.config.ctc_zero_infinity,
            )

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        jyutping_labels: Optional[torch.Tensor] = None,
        tone_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, JuytpingOutput]:
        """
        Encode ``input_values`` and score them with both heads.

        Args:
            input_values: Input for Wav2Vec2Model; presumably a raw waveform
                batch (batch, samples) — confirm.
            attention_mask: Mask over the input; when ``None`` every position
                is treated as valid for the loss lengths.
            output_attentions / output_hidden_states / return_dict: Forwarded
                to the encoder; ``return_dict`` defaults to
                ``config.use_return_dict``.
            jyutping_labels / tone_labels: Label ids, entries < 0 are padding.
                Losses are computed only when BOTH are given.

        Returns:
            ``JuytpingOutput`` (including the per-head losses) when
            ``return_dict`` is true. Otherwise the tuple
            ``(jyutping_logits, tone_logits, *encoder_extras)``, prefixed with
            ``loss`` when it was computed.

        Raises:
            ValueError: if a label id is >= its head's vocabulary size.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        self._check_label_ranges(jyutping_labels, tone_labels)

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = self.dropout(outputs[0])
        jyutping_logits = self.jyutping_head(hidden_states)
        tone_logits = self.tone_head(hidden_states)

        loss = jyutping_loss = tone_loss = None
        if jyutping_labels is not None and tone_labels is not None:
            # retrieve loss input_lengths from attention_mask
            if attention_mask is None:
                attention_mask = torch.ones(
                    input_values.shape[:2],
                    device=input_values.device,
                    dtype=torch.long,
                )
            input_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum([-1])
            ).to(torch.long)
            jyutping_loss = self._ctc_loss(
                jyutping_logits, jyutping_labels, input_lengths
            )
            tone_loss = self._ctc_loss(tone_logits, tone_labels, input_lengths)
            loss = jyutping_loss + tone_loss

        if not return_dict:
            output = (jyutping_logits, tone_logits) + outputs[
                _HIDDEN_STATES_START_POSITION:
            ]
            return ((loss,) + output) if loss is not None else output

        return JuytpingOutput(
            loss=loss,
            jyutping_logits=jyutping_logits,
            tone_logits=tone_logits,
            jyutping_loss=jyutping_loss,
            tone_loss=tone_loss,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class Wav2Vec2ConformerForCantonese(Wav2Vec2ConformerPreTrainedModel):
    """
    Wav2Vec2ConformerModel with two linear heads on top of the
    (dropout-regularised) hidden states, trained jointly with CTC: one head
    over the Jyutping vocabulary (``config.vocab_size``) and one over tones
    (``tone_vocab_size``).
    """

    def __init__(
        self,
        config,
        tone_vocab_size: Optional[int] = None,
        target_lang: Optional[str] = None,
    ):
        """
        Args:
            config: Wav2Vec2Conformer configuration. ``config.vocab_size`` must
                be set; it sizes the Jyutping head.
            tone_vocab_size: Number of tone classes. When ``None`` it is read
                from ``config.tone_vocab_size`` if present, else 9 (the old
                default). The resolved value is written back to the config so
                that a saved checkpoint reloads with a matching tone head.
            target_lang: Stored on ``self.target_lang``; not read elsewhere in
                this class.

        Raises:
            ValueError: if ``config.vocab_size`` is ``None``.
        """
        super().__init__(config)
        if tone_vocab_size is None:
            tone_vocab_size = getattr(config, "tone_vocab_size", None)
            if tone_vocab_size is None:
                tone_vocab_size = 9
        # Persist on the config: otherwise a reload without the kwarg would
        # rebuild a default-size tone head that mismatches the saved weights.
        config.tone_vocab_size = tone_vocab_size

        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
        self.dropout = nn.Dropout(config.final_dropout)
        self.tone_vocab_size = tone_vocab_size
        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )
        # With an adapter on top of the encoder, hidden states come out at
        # `output_hidden_size` instead of `hidden_size`.
        output_hidden_size = (
            config.output_hidden_size
            if hasattr(config, "add_adapter") and config.add_adapter
            else config.hidden_size
        )
        self.jyutping_head = nn.Linear(output_hidden_size, config.vocab_size)
        self.tone_head = nn.Linear(output_hidden_size, tone_vocab_size)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_encoder(self):
        """Delegate to ``Wav2Vec2ConformerForCTC.freeze_feature_encoder``."""
        Wav2Vec2ConformerForCTC.freeze_feature_encoder(self)

    def _check_label_ranges(
        self,
        jyutping_labels: Optional[torch.Tensor],
        tone_labels: Optional[torch.Tensor],
    ) -> None:
        """Raise ValueError if any label id is >= its head's vocabulary size."""
        if (
            jyutping_labels is not None
            and jyutping_labels.numel() > 0
            and jyutping_labels.max() >= self.config.vocab_size
        ):
            raise ValueError(
                f"Label values must be < vocab_size: {self.config.vocab_size}"
            )
        if (
            tone_labels is not None
            and tone_labels.numel() > 0
            and tone_labels.max() >= self.tone_vocab_size
        ):
            raise ValueError(
                f"Label values must be < tone_vocab_size: {self.tone_vocab_size}"
            )

    def _ctc_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        input_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """CTC loss for one head; label entries < 0 (e.g. -100) are padding."""
        labels_mask = labels >= 0
        target_lengths = labels_mask.sum(-1)
        flattened_targets = labels.masked_select(labels_mask)
        # ctc_loss doesn't support fp16; it also expects (time, batch, classes).
        log_probs = nn.functional.log_softmax(
            logits, dim=-1, dtype=torch.float32
        ).transpose(0, 1)
        # NOTE(review): both heads use config.pad_token_id as the CTC blank, so
        # this assumes the tone vocabulary puts its blank at that same id
        # (and that it is < tone_vocab_size) — confirm against the tone tokenizer.
        with torch.backends.cudnn.flags(enabled=False):
            return nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                input_lengths,
                target_lengths,
                blank=self.config.pad_token_id,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=self.config.ctc_zero_infinity,
            )

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        jyutping_labels: Optional[torch.Tensor] = None,
        tone_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, JuytpingOutput]:
        """
        Encode ``input_values`` and score them with both heads.

        Args:
            input_values: Input for Wav2Vec2ConformerModel; presumably a raw
                waveform batch (batch, samples) — confirm.
            attention_mask: Mask over the input; when ``None`` every position
                is treated as valid for the loss lengths.
            output_attentions / output_hidden_states / return_dict: Forwarded
                to the encoder; ``return_dict`` defaults to
                ``config.use_return_dict``.
            jyutping_labels / tone_labels: Label ids, entries < 0 are padding.
                Losses are computed only when BOTH are given.

        Returns:
            ``JuytpingOutput`` (including the per-head losses) when
            ``return_dict`` is true. Otherwise the tuple
            ``(jyutping_logits, tone_logits, *encoder_extras)``, prefixed with
            ``loss`` when it was computed.

        Raises:
            ValueError: if a label id is >= its head's vocabulary size.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        self._check_label_ranges(jyutping_labels, tone_labels)

        outputs = self.wav2vec2_conformer(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = self.dropout(outputs[0])
        jyutping_logits = self.jyutping_head(hidden_states)
        tone_logits = self.tone_head(hidden_states)

        loss = jyutping_loss = tone_loss = None
        if jyutping_labels is not None and tone_labels is not None:
            # retrieve loss input_lengths from attention_mask
            if attention_mask is None:
                attention_mask = torch.ones(
                    input_values.shape[:2],
                    device=input_values.device,
                    dtype=torch.long,
                )
            input_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum([-1])
            ).to(torch.long)
            jyutping_loss = self._ctc_loss(
                jyutping_logits, jyutping_labels, input_lengths
            )
            tone_loss = self._ctc_loss(tone_logits, tone_labels, input_lengths)
            loss = jyutping_loss + tone_loss

        if not return_dict:
            output = (jyutping_logits, tone_logits) + outputs[
                _HIDDEN_STATES_START_POSITION:
            ]
            return ((loss,) + output) if loss is not None else output

        return JuytpingOutput(
            loss=loss,
            jyutping_logits=jyutping_logits,
            tone_logits=tone_logits,
            jyutping_loss=jyutping_loss,
            tone_loss=tone_loss,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
if __name__ == "__main__":
    # Smoke test: one forward pass of Wav2Vec2ForCantonese on an audio file
    # with random labels, printing the loss and the logit shapes.
    # Usage: python modeling.py [path/to/audio.wav]
    import sys

    import librosa

    DEFAULT_AUDIO_PATH = (
        "/home/pj24001684/ku40000295/jc/projects/wav2vec2bert-jyutping/test2.wav"
    )
    audio_path = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_AUDIO_PATH

    # Reuse the pretrained feature extractor, paired with the local Jyutping
    # tokenizer built from vocab.json.
    base_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    tokenizer = Wav2Vec2CTCTokenizer(
        "vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|"
    )
    processor = Wav2Vec2Processor(
        feature_extractor=base_processor.feature_extractor, tokenizer=tokenizer
    )

    model = Wav2Vec2ForCantonese.from_pretrained(
        "TencentGameMate/chinese-hubert-base",
        tone_vocab_size=6,
        vocab_size=32,
        ctc_loss_reduction="mean",
    )

    wav, sr = librosa.load(audio_path, sr=16000)
    input_values = processor(wav, sampling_rate=sr).input_values[0]
    input_values = torch.from_numpy(input_values).unsqueeze(0)  # add batch dim

    # Random label ids within each head's vocabulary (vocab_size=32, tones=6).
    jyutping_labels = torch.randint(0, 32, (1, 10))
    tone_labels = torch.randint(0, 6, (1, 10))

    output = model(
        input_values,
        jyutping_labels=jyutping_labels,
        tone_labels=tone_labels,
    )
    print(output.loss, output.jyutping_logits.shape, output.tone_logits.shape)