sumitranjan committed on
Commit
2af4e33
·
verified ·
1 Parent(s): 9d4e582

Update modeling_voiceshield.py

Browse files
Files changed (1) hide show
  1. modeling_voiceshield.py +102 -21
modeling_voiceshield.py CHANGED
@@ -1,33 +1,74 @@
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  from transformers import WhisperModel, PreTrainedModel
4
  from transformers.modeling_outputs import SequenceClassifierOutput
5
  from transformers.configuration_utils import PretrainedConfig
6
 
 
7
  class VoiceShieldConfig(PretrainedConfig):
 
 
 
 
 
 
 
8
  model_type = "voiceshield"
9
-
10
- def __init__(self, num_labels=2, base_model="openai/whisper-small", **kwargs):
 
 
 
 
 
 
 
11
  super().__init__(**kwargs)
12
  self.num_labels = num_labels
13
  self.base_model = base_model
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- from transformers import WhisperConfig
16
- from transformers.models.whisper.modeling_whisper import WhisperEncoder
17
 
18
  class VoiceShieldForAudioClassification(PreTrainedModel):
 
 
 
 
 
 
 
19
  config_class = VoiceShieldConfig
20
-
21
  def __init__(self, config):
22
  super().__init__(config)
23
-
 
 
24
  self._keys_to_ignore_on_load_missing = [r"encoder\."]
25
-
26
- # Load the Whisper configuration only (no weights)
27
- whisper_config = WhisperConfig.from_pretrained(config.base_model)
28
- self.encoder = WhisperEncoder(whisper_config) # creates the encoder architecture
 
 
 
29
  d_model = self.encoder.config.d_model
30
-
 
31
  self.classifier = nn.Sequential(
32
  nn.Linear(d_model, 512),
33
  nn.GELU(),
@@ -37,18 +78,58 @@ class VoiceShieldForAudioClassification(PreTrainedModel):
37
  nn.Dropout(0.3),
38
  nn.Linear(128, config.num_labels),
39
  )
40
-
 
41
  self.post_init()
42
-
43
-
44
-
45
- def forward(self, input_features=None, labels=None, **kwargs):
46
- hidden = self.encoder(input_features).last_hidden_state
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  pooled = hidden.mean(dim=1)
 
 
48
  logits = self.classifier(pooled)
49
-
 
50
  loss = None
51
  if labels is not None:
52
- loss = nn.CrossEntropyLoss()(logits, labels)
53
-
54
- return SequenceClassifierOutput(loss=loss, logits=logits)
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ VoiceShield: Audio Classification Model for Voice Security
3
+ Combines Whisper encoder with custom classifier head for malicious audio detection
4
+ """
5
  import torch
6
  import torch.nn as nn
7
  from transformers import WhisperModel, PreTrainedModel
8
  from transformers.modeling_outputs import SequenceClassifierOutput
9
  from transformers.configuration_utils import PretrainedConfig
10
 
11
+
12
class VoiceShieldConfig(PretrainedConfig):
    """
    Configuration for the VoiceShield audio-classification model.

    Args:
        num_labels (int): Number of output classes (default: 2).
        base_model (str): Hub id of the Whisper checkpoint whose encoder
            backs the model (default: "openai/whisper-small").
        id2label (dict, optional): Mapping from class index to label name.
            Defaults to ``{0: "safe", 1: "malicious"}`` when omitted.
        label2id (dict, optional): Inverse mapping from label name to index.
            Defaults to ``{"safe": 0, "malicious": 1}`` when omitted.
    """
    model_type = "voiceshield"

    def __init__(
        self,
        num_labels=2,
        base_model="openai/whisper-small",
        id2label=None,
        label2id=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # NOTE: num_labels is assigned before the label maps on purpose —
        # PretrainedConfig may derive placeholder label maps from it, and the
        # explicit assignments below must win.
        self.num_labels = num_labels
        self.base_model = base_model
        # Fall back to the binary safe/malicious vocabulary when the caller
        # supplies no explicit label maps.
        self.id2label = {0: "safe", 1: "malicious"} if id2label is None else id2label
        self.label2id = {"safe": 0, "malicious": 1} if label2id is None else label2id
44
 
 
 
45
 
46
  class VoiceShieldForAudioClassification(PreTrainedModel):
47
+ """
48
+ VoiceShield model for audio classification.
49
+
50
+ Uses a pre-trained Whisper encoder with a custom classification head.
51
+ The encoder weights are loaded from the base Whisper model, while
52
+ the classifier head is trained for voice security tasks.
53
+ """
54
  config_class = VoiceShieldConfig
55
+
56
  def __init__(self, config):
57
  super().__init__(config)
58
+
59
+ # Tell HuggingFace to ignore missing encoder keys during load
60
+ # Encoder weights come from base Whisper model, not model.safetensors
61
  self._keys_to_ignore_on_load_missing = [r"encoder\."]
62
+ self._keys_to_ignore_on_load_unexpected = []
63
+
64
+ # Load Whisper encoder
65
+ whisper = WhisperModel.from_pretrained(config.base_model)
66
+ self.encoder = whisper.encoder
67
+
68
+ # Get model dimension
69
  d_model = self.encoder.config.d_model
70
+
71
+ # Classification head
72
  self.classifier = nn.Sequential(
73
  nn.Linear(d_model, 512),
74
  nn.GELU(),
 
78
  nn.Dropout(0.3),
79
  nn.Linear(128, config.num_labels),
80
  )
81
+
82
+ # Initialize weights and apply final processing
83
  self.post_init()
84
+
85
+ def forward(
86
+ self,
87
+ input_features=None,
88
+ labels=None,
89
+ output_hidden_states=False,
90
+ return_dict=True,
91
+ **kwargs
92
+ ):
93
+ """
94
+ Forward pass for VoiceShield model.
95
+
96
+ Args:
97
+ input_features: Mel spectrogram features from audio
98
+ labels: Ground truth labels for training
99
+ output_hidden_states: Whether to return hidden states
100
+ return_dict: Whether to return ModelOutput object
101
+
102
+ Returns:
103
+ SequenceClassifierOutput with loss and logits
104
+ """
105
+ # Encode audio features
106
+ encoder_outputs = self.encoder(
107
+ input_features,
108
+ output_hidden_states=output_hidden_states,
109
+ return_dict=return_dict
110
+ )
111
+
112
+ # Get last hidden state
113
+ hidden = encoder_outputs.last_hidden_state if return_dict else encoder_outputs[0]
114
+
115
+ # Mean pooling over sequence dimension
116
  pooled = hidden.mean(dim=1)
117
+
118
+ # Classification
119
  logits = self.classifier(pooled)
120
+
121
+ # Calculate loss if labels provided
122
  loss = None
123
  if labels is not None:
124
+ loss_fct = nn.CrossEntropyLoss()
125
+ loss = loss_fct(logits, labels)
126
+
127
+ if not return_dict:
128
+ output = (logits,) + encoder_outputs[1:]
129
+ return ((loss,) + output) if loss is not None else output
130
+
131
+ return SequenceClassifierOutput(
132
+ loss=loss,
133
+ logits=logits,
134
+ hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
135
+ )