# voiceSHIELD-small / modeling_voiceshield.py
# (Hugging Face Hub page metadata, preserved as comments so the file parses:)
# sumitranjan's picture
# Update modeling_voiceshield.py
# 2af4e33 verified
"""
VoiceShield: Audio Classification Model for Voice Security
Combines Whisper encoder with custom classifier head for malicious audio detection
"""
import torch
import torch.nn as nn
from transformers import WhisperModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from transformers.configuration_utils import PretrainedConfig
class VoiceShieldConfig(PretrainedConfig):
    """
    Configuration for the VoiceShield audio-classification model.

    Args:
        num_labels (int): Number of classification labels (default: 2).
        base_model (str): Base Whisper checkpoint whose encoder is reused
            (default: "openai/whisper-small").
        id2label (dict, optional): Mapping from label index to label name.
        label2id (dict, optional): Mapping from label name to label index.
    """

    model_type = "voiceshield"

    def __init__(
        self,
        num_labels=2,
        base_model="openai/whisper-small",
        id2label=None,
        label2id=None,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_labels = num_labels
        self.base_model = base_model
        # Fall back to the default safe/malicious maps only when the caller
        # passed nothing; an explicit (even empty) mapping is kept as-is.
        self.id2label = {0: "safe", 1: "malicious"} if id2label is None else id2label
        self.label2id = {"safe": 0, "malicious": 1} if label2id is None else label2id
class VoiceShieldForAudioClassification(PreTrainedModel):
    """
    VoiceShield model for audio classification.

    Uses a pre-trained Whisper encoder with a custom classification head.
    The encoder weights are loaded from the base Whisper model at
    construction time (via ``from_pretrained``), while the classifier
    head is trained for voice security tasks.
    """
    config_class = VoiceShieldConfig

    def __init__(self, config):
        """
        Build the encoder + classifier stack.

        Args:
            config (VoiceShieldConfig): ``config.base_model`` names the
                Whisper checkpoint whose encoder is reused, and
                ``config.num_labels`` sizes the final linear layer.
        """
        super().__init__(config)
        # Tell HuggingFace to ignore missing encoder keys during load:
        # encoder weights come from the base Whisper model, not model.safetensors.
        self._keys_to_ignore_on_load_missing = [r"encoder\."]
        self._keys_to_ignore_on_load_unexpected = []
        # Load the full Whisper model but keep only its encoder; the decoder
        # is not referenced again (classification needs no generation).
        whisper = WhisperModel.from_pretrained(config.base_model)
        self.encoder = whisper.encoder
        # Hidden size of the encoder output, read from the encoder's own config.
        d_model = self.encoder.config.d_model
        # Classification head: two hidden layers (512 -> 128) with GELU and
        # dropout(0.3), projecting the pooled encoder state to num_labels logits.
        self.classifier = nn.Sequential(
            nn.Linear(d_model, 512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(128, config.num_labels),
        )
        # Initialize weights and apply final processing (HF convention).
        self.post_init()

    def forward(
        self,
        input_features=None,
        labels=None,
        output_hidden_states=False,
        return_dict=True,
        **kwargs
    ):
        """
        Forward pass for the VoiceShield model.

        Args:
            input_features: Mel-spectrogram features from audio — assumed
                (batch, n_mels, frames) as produced by the Whisper feature
                extractor; TODO confirm against the caller's processor.
            labels: Optional ground-truth class indices used to compute the loss.
            output_hidden_states: Whether to also return encoder hidden states.
            return_dict: Whether to return a ModelOutput object instead of a tuple.
            **kwargs: Ignored; absorbs extra arguments from generic HF callers.

        Returns:
            SequenceClassifierOutput with ``loss`` (when labels are given) and
            ``logits``, or the equivalent tuple when ``return_dict=False``.
        """
        # Encode audio features with the Whisper encoder.
        encoder_outputs = self.encoder(
            input_features,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        # Get last hidden state: attribute access for a ModelOutput,
        # index 0 when the encoder returned a plain tuple.
        hidden = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
        # Mean pooling over the sequence (time-frame) dimension -> (batch, d_model).
        pooled = hidden.mean(dim=1)
        # Project pooled features to per-class logits.
        logits = self.classifier(pooled)
        # Calculate cross-entropy loss only if labels were provided.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)
        if not return_dict:
            # Legacy tuple output: (loss?, logits, *remaining encoder outputs).
            output = (logits,) + encoder_outputs[1:]
            return ((loss,) + output) if loss is not None else output
        # This branch is only reached with return_dict=True, so attribute
        # access on encoder_outputs is safe here.
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
        )