| """ |
| VoiceShield: Audio Classification Model for Voice Security |
| Combines Whisper encoder with custom classifier head for malicious audio detection |
| """ |
| import torch |
| import torch.nn as nn |
| from transformers import WhisperModel, PreTrainedModel |
| from transformers.modeling_outputs import SequenceClassifierOutput |
| from transformers.configuration_utils import PretrainedConfig |
|
|
|
|
class VoiceShieldConfig(PretrainedConfig):
    """
    Configuration for the VoiceShield audio-classification model.

    Args:
        num_labels (int): Number of classification labels (default: 2).
        base_model (str): Hugging Face model id of the base Whisper
            checkpoint whose encoder is reused (default:
            "openai/whisper-small").
        id2label (dict, optional): Mapping of label id -> label name.
        label2id (dict, optional): Mapping of label name -> label id.
    """

    model_type = "voiceshield"

    def __init__(
        self,
        num_labels=2,
        base_model="openai/whisper-small",
        id2label=None,
        label2id=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_labels = num_labels
        self.base_model = base_model

        # Fall back to the built-in safe/malicious mapping only when the
        # caller passes None — an explicitly supplied mapping (even an
        # empty dict) is kept as-is.
        self.id2label = {0: "safe", 1: "malicious"} if id2label is None else id2label
        self.label2id = {"safe": 0, "malicious": 1} if label2id is None else label2id
|
|
|
|
class VoiceShieldForAudioClassification(PreTrainedModel):
    """
    VoiceShield audio classifier.

    Wraps the encoder of a pre-trained Whisper model (named in
    ``config.base_model``) and feeds its mean-pooled hidden states through
    a small MLP head to produce ``config.num_labels`` logits.
    """

    config_class = VoiceShieldConfig

    def __init__(self, config):
        super().__init__(config)

        # Encoder weights are loaded from the base Whisper checkpoint below,
        # so they may legitimately be missing from a fine-tuned VoiceShield
        # checkpoint — suppress the corresponding load warnings.
        self._keys_to_ignore_on_load_missing = [r"encoder\."]
        self._keys_to_ignore_on_load_unexpected = []

        # Only the encoder half of the Whisper seq2seq model is needed.
        self.encoder = WhisperModel.from_pretrained(config.base_model).encoder

        hidden_size = self.encoder.config.d_model

        # Classification head: d_model -> 512 -> 128 -> num_labels,
        # with GELU activations and dropout between the linear layers.
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(128, config.num_labels),
        )

        self.post_init()

    def forward(
        self,
        input_features=None,
        labels=None,
        output_hidden_states=False,
        return_dict=True,
        **kwargs
    ):
        """
        Run the encoder and the classification head.

        Args:
            input_features: Mel-spectrogram batch for the Whisper encoder.
            labels: Optional ground-truth class ids; when provided, a
                cross-entropy loss is computed against the logits.
            output_hidden_states: If True, also return the encoder's
                hidden states.
            return_dict: If True, return a ``SequenceClassifierOutput``;
                otherwise return a plain tuple.

        Returns:
            ``SequenceClassifierOutput`` (or tuple) with loss and logits.
        """
        encoder_outputs = self.encoder(
            input_features,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict:
            sequence_states = encoder_outputs.last_hidden_state
        else:
            sequence_states = encoder_outputs[0]

        # Mean-pool over the time axis, then classify.
        # NOTE(review): the mean runs over every encoder frame, including any
        # padded tail of the fixed-length input window — confirm intended.
        logits = self.classifier(sequence_states.mean(dim=1))

        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits, labels)

        if not return_dict:
            output = (logits,) + encoder_outputs[1:]
            return output if loss is None else (loss,) + output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
        )