| | --- |
| | language: |
| | - en |
| | license: mit |
| | base_model: |
| | - facebook/wav2vec2-base |
| | --- |
| | |
| | SCD(Speaker Change Detection,讲者变化检测)是指在音频或视频内容中识别讲者发生变化的技术。它通常应用于多讲者的对话或演讲场景,用于检测何时从一个讲者切换到另一个讲者。本模型以 facebook/wav2vec2-base 为基础,通过逐帧音频分类(audio frame classification)来检测讲者变化。 |
| |
|
| | 如何使用 |
import torch
import transformers
from huggingface_hub import PyTorchModelHubMixin
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import TokenClassifierOutput
from transformers.models.wav2vec2.modeling_wav2vec2 import (
    _HIDDEN_STATES_START_POSITION,
    Wav2Vec2Model,
)


# Note: at the time this code was originally written,
# transformers.Wav2Vec2ForAudioFrameClassification was incomplete
# -> this subclass adds the then-missing parts (the `labels` argument and the
# loss computation, marked "ADDED" below).
class Wav2Vec2ForAudioFrameClassification_custom(
    transformers.Wav2Vec2ForAudioFrameClassification,
    PyTorchModelHubMixin,
    repo_url="your-repo-url",  # placeholder: replace with the model's Hub repo URL
    pipeline_tag="audio-classification",
    license="mit",
):
    """Wav2Vec2 encoder followed by a linear classifier applied to every frame.

    When ``labels`` are passed to :meth:`forward`, a training loss is also
    computed: mean-squared error if ``config.num_labels == 1``, cross-entropy
    otherwise.
    """

    def __init__(self, config):
        """Build the encoder, the optional layer weights and the classifier.

        Args:
            config: Wav2Vec2 model configuration.

        Raises:
            ValueError: if ``config.add_adapter`` is set; adapters are not
                supported for audio frame classification.
        """
        super().__init__(config)
        self.num_labels = config.num_labels

        if getattr(config, "add_adapter", False):
            raise ValueError(
                "Audio frame classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
            )
        self.wav2vec2 = Wav2Vec2Model(config)
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            # One learnable weight per layer, initialised uniformly; they are
            # softmax-normalised in forward().
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        self.init_weights()

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,  # ADDED
    ):
        """Encode ``input_values`` and classify each frame.

        Args:
            input_values: audio batch, forwarded unchanged to ``self.wav2vec2``.
            attention_mask: optional mask, forwarded to the encoder.
            output_attentions: forwarded to the encoder.
            output_hidden_states: forwarded to the encoder; forced to ``True``
                when ``config.use_weighted_layer_sum`` is set.
            return_dict: return a ``TokenClassifierOutput`` instead of a tuple;
                defaults to ``config.use_return_dict``.
            labels: optional targets. With ``num_labels == 1`` they are
                reshaped to a column, cast to the logits' dtype and scored with
                MSE. Otherwise they are flattened and used as class indices for
                cross-entropy. Pass ``None`` (the default) for inference.

        Returns:
            ``TokenClassifierOutput`` or, if ``return_dict`` is false, a tuple
            ``(loss, logits, *extra_outputs)`` (``loss`` omitted when no labels
            were given).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs the hidden states of every layer.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # Softmax-weighted sum over the stacked per-layer hidden states.
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        logits = self.classifier(hidden_states)

        # ADDED
        loss = None
        if labels is not None:
            # Labels are only reshaped inside this guard: reshaping them
            # unconditionally crashed inference calls where labels is None.
            if self.num_labels == 1:
                loss_fct = MSELoss()
                # MSELoss needs floating-point targets shaped like the logits
                # (1xN -> Nx1).
                targets = labels.reshape(-1, 1).to(logits.dtype)
                loss = loss_fct(logits.view(-1, self.num_labels), targets)
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.reshape(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
| | |
| |
|