| | from dataclasses import dataclass |
| | from transformers import ( |
| | Wav2Vec2BertModel, |
| | Wav2Vec2BertPreTrainedModel, |
| | Wav2Vec2BertProcessor, |
| | Wav2Vec2CTCTokenizer, |
| | Wav2Vec2Processor, |
| | Wav2Vec2ForCTC, |
| | Wav2Vec2PreTrainedModel, |
| | Wav2Vec2Model, |
| | ) |
| | from pycantonese.jyutping.parse_jyutping import ONSETS |
| | import re |
| | from transformers.models.wav2vec2_bert.modeling_wav2vec2_bert import ( |
| | _HIDDEN_STATES_START_POSITION, |
| | ) |
| | from transformers.models.wav2vec2.modeling_wav2vec2 import Wav2Vec2AttnAdapterLayer |
| | from transformers.models.wav2vec2_conformer.modeling_wav2vec2_conformer import ( |
| | Wav2Vec2ConformerPreTrainedModel, |
| | Wav2Vec2ConformerModel, |
| | Wav2Vec2ConformerForCTC, |
| | ) |
| | from transformers.modeling_outputs import ModelOutput |
| | from typing import Optional, Tuple, Union |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import logging |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
@dataclass
class JuytpingOutput(ModelOutput):
    """
    Output container returned by the `forward` methods of the Cantonese models
    in this file (Wav2Vec2BertForCantonese, Wav2Vec2ForCantonese and
    Wav2Vec2ConformerForCantonese).

    Every field defaults to None; a model fills in only what it computed.
    """

    # Combined training loss (Jyutping CTC loss + tone CTC loss); None when the
    # model was called without both label tensors.
    loss: Optional[torch.FloatTensor] = None
    # Raw outputs of the Jyutping head and of the tone head.
    jyutping_logits: Optional[torch.FloatTensor] = None
    tone_logits: Optional[torch.FloatTensor] = None
    # Per-head CTC losses; None when no loss was computed or when the producing
    # model does not report them separately.
    jyutping_loss: Optional[torch.FloatTensor] = None
    tone_loss: Optional[torch.FloatTensor] = None
    # Forwarded unchanged from the backbone model's output.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
| |
|
| |
|
class Wav2Vec2BertForCantonese(Wav2Vec2BertPreTrainedModel):
    """
    Wav2Vec2BertModel with two linear heads on top of its (dropout-regularised)
    hidden states: one producing Jyutping logits and one producing tone logits.
    Both heads are trained with CTC loss.
    """

    def __init__(
        self,
        config,
        tone_vocab_size: int = 9,
    ):
        """
        Build the backbone and both output heads.

        Args:
            config: model configuration. `config.vocab_size` sizes the Jyutping
                head and must not be None.
            tone_vocab_size: number of output classes of the tone head.

        Raises:
            ValueError: if `config.vocab_size` is None.
        """
        super().__init__(config)

        self.wav2vec2_bert = Wav2Vec2BertModel(config)
        self.dropout = nn.Dropout(config.final_dropout)
        self.tone_vocab_size = tone_vocab_size

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2BertForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )

        # With an adapter enabled the heads read the adapter's output width,
        # otherwise the backbone's hidden width.
        output_hidden_size = (
            config.output_hidden_size
            if hasattr(config, "add_adapter") and config.add_adapter
            else config.hidden_size
        )
        self.jyutping_head = nn.Linear(output_hidden_size, config.vocab_size)
        self.tone_head = nn.Linear(output_hidden_size, tone_vocab_size)

        self.post_init()

    def _head_ctc_loss(self, logits, labels, input_lengths):
        """
        Compute the CTC loss of one head.

        Args:
            logits: output of the head for the current batch.
            labels: label ids for that head; negative entries are treated as
                padding and excluded from the targets.
            input_lengths: per-sample output lengths of the backbone.

        Returns:
            The loss tensor produced by `nn.functional.ctc_loss`.
        """
        valid = labels >= 0
        target_lengths = valid.sum(-1)
        flattened_targets = labels.masked_select(valid)

        # Log-probabilities are computed in float32, then the first two axes
        # are swapped because ctc_loss takes time-major input.
        log_probs = nn.functional.log_softmax(
            logits, dim=-1, dtype=torch.float32
        ).transpose(0, 1)

        # NOTE(review): both heads use `config.pad_token_id` as the CTC blank;
        # confirm that this id is also the blank/pad id of the tone vocabulary.
        with torch.backends.cudnn.flags(enabled=False):
            return nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                input_lengths,
                target_lengths,
                blank=self.config.pad_token_id,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=self.config.ctc_zero_infinity,
            )

    def forward(
        self,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        jyutping_labels: Optional[torch.Tensor] = None,
        tone_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, JuytpingOutput]:
        """
        Run the backbone and both heads, optionally computing the CTC losses.

        Args:
            input_features: features handed to the Wav2Vec2BertModel backbone.
            attention_mask: optional mask forwarded to the backbone; when
                absent, an all-ones mask over the first two dimensions of
                `input_features` is used for the loss computation.
            output_attentions, output_hidden_states: forwarded to the backbone.
            return_dict: return a JuytpingOutput when true, a plain tuple when
                false; defaults to `config.use_return_dict`.
            jyutping_labels: label ids for the Jyutping head (negative = ignored).
            tone_labels: label ids for the tone head (negative = ignored).

        Returns:
            A JuytpingOutput, or the tuple
            `(loss?, jyutping_logits, tone_logits, *extra_backbone_outputs)`.
            Losses are computed only when BOTH label tensors are supplied.

        Raises:
            ValueError: if a label id is not smaller than the size of the
                corresponding head.
        """
        if (
            jyutping_labels is not None
            and jyutping_labels.max() >= self.config.vocab_size
        ):
            raise ValueError(
                f"Label values must be < vocab_size: {self.config.vocab_size}"
            )

        if tone_labels is not None and tone_labels.max() >= self.tone_vocab_size:
            raise ValueError(
                f"Label values must be < tone_vocab_size: {self.tone_vocab_size}"
            )

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.wav2vec2_bert(
            input_features,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = self.dropout(outputs[0])

        jyutping_logits = self.jyutping_head(hidden_states)
        tone_logits = self.tone_head(hidden_states)

        loss = None
        jyutping_loss = None
        tone_loss = None

        if jyutping_labels is not None and tone_labels is not None:
            if attention_mask is None:
                attention_mask = torch.ones(
                    input_features.shape[:2],
                    device=input_features.device,
                    dtype=torch.long,
                )
            # Number of unmasked input positions, mapped to backbone output lengths.
            input_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum([-1])
            ).to(torch.long)

            jyutping_loss = self._head_ctc_loss(
                jyutping_logits, jyutping_labels, input_lengths
            )
            tone_loss = self._head_ctc_loss(tone_logits, tone_labels, input_lengths)
            loss = jyutping_loss + tone_loss

        if not return_dict:
            output = (jyutping_logits, tone_logits) + outputs[
                _HIDDEN_STATES_START_POSITION:
            ]
            return ((loss,) + output) if loss is not None else output

        return JuytpingOutput(
            loss=loss,
            jyutping_logits=jyutping_logits,
            tone_logits=tone_logits,
            jyutping_loss=jyutping_loss,
            tone_loss=tone_loss,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def inference(
        self,
        processor: Wav2Vec2BertProcessor,
        tone_tokenizer: Wav2Vec2CTCTokenizer,
        input_features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ):
        """
        Greedy-decode a toned Jyutping string for the FIRST item of the batch.

        Args:
            processor: its `batch_decode` turns Jyutping prediction ids into text.
            tone_tokenizer: its `batch_decode` turns tone prediction ids into text.
            input_features: features handed to `forward`.
            attention_mask: optional mask handed to `forward`.

        Returns:
            A tuple `(text, jyutping_logits, tone_logits)` where `text` is the
            space-separated sequence of syllables, each followed by its tone.
        """
        outputs = self.forward(
            input_features=input_features,
            attention_mask=attention_mask,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=True,
        )
        jyutping_logits = outputs.jyutping_logits
        tone_logits = outputs.tone_logits

        jyutping_pred_ids = torch.argmax(jyutping_logits, dim=-1)
        tone_pred_ids = torch.argmax(tone_logits, dim=-1)

        # Only the first decoded sequence of the batch is used.
        jyutping_pred = processor.batch_decode(jyutping_pred_ids)[0]
        tone_pred = tone_tokenizer.batch_decode(tone_pred_ids)[0]
        jyutping_list = jyutping_pred.split(" ")
        tone_list = tone_pred.split(" ")

        # Tokens found in ONSETS get a "_" marker in front, every other token
        # gets it behind. After concatenation the markers become spaces, so an
        # onset token fuses with the token that follows it.
        marked = [
            "_" + token if token in ONSETS else token + "_"
            for token in jyutping_list
        ]
        merged = "".join(marked).replace("_", " ").strip()
        jyutping_output = re.sub(r"\s+", " ", merged).split(" ")

        # Align the tone sequence with the syllables: truncate surplus tones,
        # or repeat the last tone to fill the gap.
        if len(tone_list) > len(jyutping_output):
            tone_list = tone_list[: len(jyutping_output)]
        elif len(tone_list) < len(jyutping_output):
            tone_list = tone_list + [tone_list[-1]] * (
                len(jyutping_output) - len(tone_list)
            )

        return (
            " ".join(
                [f"{jypt}{tone}" for jypt, tone in zip(jyutping_output, tone_list)]
            ),
            jyutping_logits,
            tone_logits,
        )
| |
|
| |
|
class Wav2Vec2ForCantonese(Wav2Vec2PreTrainedModel):
    """
    Wav2Vec2Model with two linear heads on top of its (dropout-regularised)
    hidden states: one producing Jyutping logits and one producing tone logits.
    Both heads are trained with CTC loss.
    """

    def __init__(
        self, config, tone_vocab_size: int = 9, target_lang: Optional[str] = None
    ):
        """
        Build the backbone and both output heads.

        Args:
            config: model configuration. `config.vocab_size` sizes the Jyutping
                head and must not be None.
            tone_vocab_size: number of output classes of the tone head.
            target_lang: stored on the instance as `self.target_lang`.

        Raises:
            ValueError: if `config.vocab_size` is None.
        """
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)
        self.dropout = nn.Dropout(config.final_dropout)

        self.tone_vocab_size = tone_vocab_size
        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )

        # With an adapter enabled the heads read the adapter's output width,
        # otherwise the backbone's hidden width.
        output_hidden_size = (
            config.output_hidden_size
            if hasattr(config, "add_adapter") and config.add_adapter
            else config.hidden_size
        )
        self.jyutping_head = nn.Linear(output_hidden_size, config.vocab_size)
        self.tone_head = nn.Linear(output_hidden_size, tone_vocab_size)

        self.post_init()

    # The four methods below borrow Wav2Vec2ForCTC's implementations and run
    # them on this instance.
    # NOTE(review): those implementations were written for Wav2Vec2ForCTC;
    # confirm they do not rely on attributes this class lacks (e.g. `lm_head`).
    def tie_weights(self):
        """Delegate to `Wav2Vec2ForCTC.tie_weights`."""
        Wav2Vec2ForCTC.tie_weights(self)

    def freeze_feature_extractor(self):
        """Delegate to `Wav2Vec2ForCTC.freeze_feature_extractor`."""
        Wav2Vec2ForCTC.freeze_feature_extractor(self)

    def freeze_feature_encoder(self):
        """Delegate to `Wav2Vec2ForCTC.freeze_feature_encoder`."""
        Wav2Vec2ForCTC.freeze_feature_encoder(self)

    def freeze_base_model(self):
        """Delegate to `Wav2Vec2ForCTC.freeze_base_model`."""
        Wav2Vec2ForCTC.freeze_base_model(self)

    def _get_adapters(self):
        """
        Collect the adapter parameters of this model.

        Returns:
            A dict mapping dotted parameter names to parameters, covering every
            Wav2Vec2AttnAdapterLayer in the model plus both output heads.

        Raises:
            ValueError: if `config.adapter_attn_dim` is None.
        """
        if self.config.adapter_attn_dim is None:
            raise ValueError(
                f"{self.__class__} has no adapter layers. Make sure to define `config.adapter_attn_dim`."
            )

        adapter_weights = {}
        for name, module in self.named_modules():
            if isinstance(module, Wav2Vec2AttnAdapterLayer):
                for param_name, param in module.named_parameters():
                    adapter_weights[".".join([name, param_name])] = param

        # The heads are always included. The previous guard,
        # `isinstance(self, Wav2Vec2ForCTC)`, could never be true for this
        # class, so the head parameters were silently left out.
        for head_name in ("jyutping_head", "tone_head"):
            for name, param in getattr(self, head_name).named_parameters():
                adapter_weights[".".join([head_name, name])] = param

        return adapter_weights

    def _head_ctc_loss(self, logits, labels, input_lengths):
        """
        Compute the CTC loss of one head.

        Args:
            logits: output of the head for the current batch.
            labels: label ids for that head; negative entries are treated as
                padding and excluded from the targets.
            input_lengths: per-sample output lengths of the backbone.

        Returns:
            The loss tensor produced by `nn.functional.ctc_loss`.
        """
        valid = labels >= 0
        target_lengths = valid.sum(-1)
        flattened_targets = labels.masked_select(valid)

        # Log-probabilities are computed in float32, then the first two axes
        # are swapped because ctc_loss takes time-major input.
        log_probs = nn.functional.log_softmax(
            logits, dim=-1, dtype=torch.float32
        ).transpose(0, 1)

        # NOTE(review): both heads use `config.pad_token_id` as the CTC blank;
        # confirm that this id is also the blank/pad id of the tone vocabulary.
        with torch.backends.cudnn.flags(enabled=False):
            return nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                input_lengths,
                target_lengths,
                blank=self.config.pad_token_id,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=self.config.ctc_zero_infinity,
            )

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        jyutping_labels: Optional[torch.Tensor] = None,
        tone_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, JuytpingOutput]:
        """
        Run the backbone and both heads, optionally computing the CTC losses.

        Args:
            input_values: input handed to the Wav2Vec2Model backbone.
            attention_mask: optional mask forwarded to the backbone; when
                absent, an all-ones mask over the first two dimensions of
                `input_values` is used for the loss computation.
            output_attentions, output_hidden_states: forwarded to the backbone.
            return_dict: return a JuytpingOutput when true, a plain tuple when
                false; defaults to `config.use_return_dict`.
            jyutping_labels: label ids for the Jyutping head (negative = ignored).
            tone_labels: label ids for the tone head (negative = ignored).

        Returns:
            A JuytpingOutput (including the per-head losses), or the tuple
            `(loss?, jyutping_logits, tone_logits, *extra_backbone_outputs)`.
            Losses are computed only when BOTH label tensors are supplied.

        Raises:
            ValueError: if a label id is not smaller than the size of the
                corresponding head.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (
            jyutping_labels is not None
            and jyutping_labels.max() >= self.config.vocab_size
        ):
            raise ValueError(
                f"Label values must be < vocab_size: {self.config.vocab_size}"
            )

        if tone_labels is not None and tone_labels.max() >= self.tone_vocab_size:
            raise ValueError(
                f"Label values must be < tone_vocab_size: {self.tone_vocab_size}"
            )

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = self.dropout(outputs[0])

        jyutping_logits = self.jyutping_head(hidden_states)
        tone_logits = self.tone_head(hidden_states)

        loss = None
        jyutping_loss = None
        tone_loss = None

        if jyutping_labels is not None and tone_labels is not None:
            if attention_mask is None:
                attention_mask = torch.ones(
                    input_values.shape[:2],
                    device=input_values.device,
                    dtype=torch.long,
                )
            # Number of unmasked input positions, mapped to backbone output lengths.
            input_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum([-1])
            ).to(torch.long)

            jyutping_loss = self._head_ctc_loss(
                jyutping_logits, jyutping_labels, input_lengths
            )
            tone_loss = self._head_ctc_loss(tone_logits, tone_labels, input_lengths)
            loss = jyutping_loss + tone_loss

        if not return_dict:
            output = (jyutping_logits, tone_logits) + outputs[
                _HIDDEN_STATES_START_POSITION:
            ]
            return ((loss,) + output) if loss is not None else output

        # Per-head losses are reported alongside the combined loss, matching
        # Wav2Vec2BertForCantonese.
        return JuytpingOutput(
            loss=loss,
            jyutping_logits=jyutping_logits,
            tone_logits=tone_logits,
            jyutping_loss=jyutping_loss,
            tone_loss=tone_loss,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|
| |
|
class Wav2Vec2ConformerForCantonese(Wav2Vec2ConformerPreTrainedModel):
    """
    Wav2Vec2ConformerModel with two linear heads on top of its
    (dropout-regularised) hidden states: one producing Jyutping logits and one
    producing tone logits. Both heads are trained with CTC loss.
    """

    def __init__(
        self, config, tone_vocab_size: int = 9, target_lang: Optional[str] = None
    ):
        """
        Build the backbone and both output heads.

        Args:
            config: model configuration. `config.vocab_size` sizes the Jyutping
                head and must not be None.
            tone_vocab_size: number of output classes of the tone head.
            target_lang: stored on the instance as `self.target_lang`.

        Raises:
            ValueError: if `config.vocab_size` is None.
        """
        super().__init__(config)

        self.wav2vec2_conformer = Wav2Vec2ConformerModel(config)
        self.dropout = nn.Dropout(config.final_dropout)

        self.tone_vocab_size = tone_vocab_size
        self.target_lang = target_lang

        if config.vocab_size is None:
            raise ValueError(
                f"You are trying to instantiate {self.__class__} with a configuration that "
                "does not define the vocabulary size of the language model head. Please "
                "instantiate the model as follows: `Wav2Vec2ConformerForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
                "or define `vocab_size` of your model's configuration."
            )

        # With an adapter enabled the heads read the adapter's output width,
        # otherwise the backbone's hidden width.
        output_hidden_size = (
            config.output_hidden_size
            if hasattr(config, "add_adapter") and config.add_adapter
            else config.hidden_size
        )

        self.jyutping_head = nn.Linear(output_hidden_size, config.vocab_size)
        self.tone_head = nn.Linear(output_hidden_size, tone_vocab_size)

        self.post_init()

    def freeze_feature_encoder(self):
        """Delegate to `Wav2Vec2ConformerForCTC.freeze_feature_encoder`, run on this instance."""
        Wav2Vec2ConformerForCTC.freeze_feature_encoder(self)

    def _head_ctc_loss(self, logits, labels, input_lengths):
        """
        Compute the CTC loss of one head.

        Args:
            logits: output of the head for the current batch.
            labels: label ids for that head; negative entries are treated as
                padding and excluded from the targets.
            input_lengths: per-sample output lengths of the backbone.

        Returns:
            The loss tensor produced by `nn.functional.ctc_loss`.
        """
        valid = labels >= 0
        target_lengths = valid.sum(-1)
        flattened_targets = labels.masked_select(valid)

        # Log-probabilities are computed in float32, then the first two axes
        # are swapped because ctc_loss takes time-major input.
        log_probs = nn.functional.log_softmax(
            logits, dim=-1, dtype=torch.float32
        ).transpose(0, 1)

        # NOTE(review): both heads use `config.pad_token_id` as the CTC blank;
        # confirm that this id is also the blank/pad id of the tone vocabulary.
        with torch.backends.cudnn.flags(enabled=False):
            return nn.functional.ctc_loss(
                log_probs,
                flattened_targets,
                input_lengths,
                target_lengths,
                blank=self.config.pad_token_id,
                reduction=self.config.ctc_loss_reduction,
                zero_infinity=self.config.ctc_zero_infinity,
            )

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        jyutping_labels: Optional[torch.Tensor] = None,
        tone_labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, JuytpingOutput]:
        """
        Run the backbone and both heads, optionally computing the CTC losses.

        Args:
            input_values: input handed to the Wav2Vec2ConformerModel backbone.
            attention_mask: optional mask forwarded to the backbone; when
                absent, an all-ones mask over the first two dimensions of
                `input_values` is used for the loss computation.
            output_attentions, output_hidden_states: forwarded to the backbone.
            return_dict: return a JuytpingOutput when true, a plain tuple when
                false; defaults to `config.use_return_dict`.
            jyutping_labels: label ids for the Jyutping head (negative = ignored).
            tone_labels: label ids for the tone head (negative = ignored).

        Returns:
            A JuytpingOutput (including the per-head losses), or the tuple
            `(loss?, jyutping_logits, tone_logits, *extra_backbone_outputs)`.
            Losses are computed only when BOTH label tensors are supplied.

        Raises:
            ValueError: if a label id is not smaller than the size of the
                corresponding head.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (
            jyutping_labels is not None
            and jyutping_labels.max() >= self.config.vocab_size
        ):
            raise ValueError(
                f"Label values must be < vocab_size: {self.config.vocab_size}"
            )

        if tone_labels is not None and tone_labels.max() >= self.tone_vocab_size:
            raise ValueError(
                f"Label values must be < tone_vocab_size: {self.tone_vocab_size}"
            )

        outputs = self.wav2vec2_conformer(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = self.dropout(outputs[0])

        jyutping_logits = self.jyutping_head(hidden_states)
        tone_logits = self.tone_head(hidden_states)

        loss = None
        jyutping_loss = None
        tone_loss = None

        if jyutping_labels is not None and tone_labels is not None:
            if attention_mask is None:
                attention_mask = torch.ones(
                    input_values.shape[:2],
                    device=input_values.device,
                    dtype=torch.long,
                )
            # Number of unmasked input positions, mapped to backbone output lengths.
            input_lengths = self._get_feat_extract_output_lengths(
                attention_mask.sum([-1])
            ).to(torch.long)

            jyutping_loss = self._head_ctc_loss(
                jyutping_logits, jyutping_labels, input_lengths
            )
            tone_loss = self._head_ctc_loss(tone_logits, tone_labels, input_lengths)
            loss = jyutping_loss + tone_loss

        if not return_dict:
            output = (jyutping_logits, tone_logits) + outputs[
                _HIDDEN_STATES_START_POSITION:
            ]
            return ((loss,) + output) if loss is not None else output

        # Per-head losses are reported alongside the combined loss, matching
        # Wav2Vec2BertForCantonese.
        return JuytpingOutput(
            loss=loss,
            jyutping_logits=jyutping_logits,
            tone_logits=tone_logits,
            jyutping_loss=jyutping_loss,
            tone_loss=tone_loss,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| |
|
| |
|
if __name__ == "__main__":
    import torch
    import librosa

    from transformers import Wav2Vec2Processor, Wav2Vec2CTCTokenizer

    # Manual smoke test: load a checkpoint into Wav2Vec2ForCantonese, push one
    # audio file through it together with random labels, and print the loss
    # and the shapes of both logit tensors.

    # Reuse the feature extractor of a published processor, paired with a
    # tokenizer built from the local vocabulary file.
    base_processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
    tokenizer = Wav2Vec2CTCTokenizer(
        "vocab.json",
        unk_token="[UNK]",
        pad_token="[PAD]",
        word_delimiter_token="|",
    )
    processor = Wav2Vec2Processor(
        feature_extractor=base_processor.feature_extractor,
        tokenizer=tokenizer,
    )

    model = Wav2Vec2ForCantonese.from_pretrained(
        "TencentGameMate/chinese-hubert-base",
        tone_vocab_size=6,
        vocab_size=32,
        ctc_loss_reduction="mean",
    )

    # NOTE(review): machine-specific absolute path; the script only runs where
    # this file exists.
    audio_path = (
        "/home/pj24001684/ku40000295/jc/projects/wav2vec2bert-jyutping/test2.wav"
    )
    waveform, sample_rate = librosa.load(audio_path, sr=16000)

    # Take the first processed sequence and add a leading batch dimension.
    features = processor(waveform, sampling_rate=sample_rate).input_values[0]
    input_values = torch.from_numpy(features).unsqueeze(0)

    # Random label ids within each head's range, one row of ten per head.
    jyutping_labels = torch.randint(0, 32, (1, 10))
    tone_labels = torch.randint(0, 6, (1, 10))

    output = model(
        input_values,
        jyutping_labels=jyutping_labels,
        tone_labels=tone_labels,
    )

    print(output.loss, output.jyutping_logits.shape, output.tone_logits.shape)
| |
|