sumitranjan committed
Commit 3051af2 · verified · 1 Parent(s): 2af4e33

Update pipeline_voiceshield.py

Files changed (1)
  1. pipeline_voiceshield.py +108 -49
pipeline_voiceshield.py CHANGED
@@ -1,3 +1,6 @@
+"""
+VoiceShield Pipeline for audio classification and transcription
+"""
 import torch
 import torch.nn.functional as F
 import numpy as np
@@ -7,91 +10,147 @@ from transformers import Pipeline, WhisperProcessor, WhisperForConditionalGenera
 
 
 class VoiceShieldPipeline(Pipeline):
-
+    """
+    Pipeline for VoiceShield audio classification.
+
+    Combines transcription (via Whisper) with malicious audio detection.
+
+    Args:
+        model: VoiceShield classification model
+        threshold: Confidence threshold for malicious classification (default: 0.2)
+    """
+
     def __init__(self, model, threshold=0.2, **kwargs):
         self.threshold = threshold
-
-        # FIX 1: tokenizer= must be passed to super().__init__()
-        # Pipeline requires it even for audio tasks, pass None explicitly
-        # to prevent "tokenizer is required" crash
+
+        # Pipeline requires tokenizer parameter, pass None for audio tasks
         kwargs.setdefault("tokenizer", None)
-
         super().__init__(model=model, **kwargs)
-
+
+        # Load processor and STT model after super().__init__() so self.device is set
         base_model = model.config.base_model
-
-        # FIX 2: load processor AFTER super().__init__() so self.device is set
         self.processor = WhisperProcessor.from_pretrained(base_model)
         self.stt_model = WhisperForConditionalGeneration.from_pretrained(base_model)
         self.stt_model.to(self.device)
         self.stt_model.eval()
-
+
     def _sanitize_parameters(self, threshold=None, **kwargs):
+        """
+        Sanitize parameters for preprocess, forward, and postprocess.
+        Must return exactly 3 dictionaries.
+        """
+        preprocess_kwargs = {}
         forward_kwargs = {}
+        postprocess_kwargs = {}
+
         if threshold is not None:
             forward_kwargs["threshold"] = threshold
-        # FIX 3: must return exactly 3 dicts (preprocess, forward, postprocess)
-        return {}, forward_kwargs, {}
-
+
+        return preprocess_kwargs, forward_kwargs, postprocess_kwargs
+
     def preprocess(self, inputs, **kwargs):
-        # FIX 4: soundfile.read returns (data, samplerate) — unpack correctly
-        audio_np, sr = sf.read(inputs)
-
-        # Stereo → mono
+        """
+        Preprocess audio input.
+
+        Args:
+            inputs: Path to audio file or numpy array
+
+        Returns:
+            Dictionary with processed features
+        """
+        # Load audio file
+        if isinstance(inputs, str):
+            audio_np, sr = sf.read(inputs)
+        else:
+            audio_np = inputs
+            sr = kwargs.get("sampling_rate", 16000)
+
+        # Convert stereo to mono
         if len(audio_np.shape) > 1:
             audio_np = np.mean(audio_np, axis=1)
-
+
         # Resample to 16kHz if needed
         if sr != 16000:
             num_samples = int(len(audio_np) * 16000 / sr)
             audio_np = resample(audio_np, num_samples).astype(np.float32)
-
+
+        # Process with Whisper processor
         features = self.processor(
             audio_np,
             sampling_rate=16000,
             return_tensors="pt"
         ).input_features.to(self.device)
-
-        return {"features": features}
-
+
+        return {"features": features, "audio": audio_np}
+
     def _forward(self, model_inputs, threshold=None, **kwargs):
-        # FIX 5: use instance threshold as fallback, not hardcoded 0.2
+        """
+        Forward pass: transcribe and classify audio.
+
+        Args:
+            model_inputs: Preprocessed features
+            threshold: Classification threshold
+
+        Returns:
+            Dictionary with transcript, label, and confidence scores
+        """
+        # Use instance threshold as fallback
         threshold = threshold if threshold is not None else self.threshold
-        features = model_inputs["features"]
-
-        # Transcription
-        attn_mask = torch.ones(
-            features.shape[:2], dtype=torch.long, device=self.device
-        )
+        features = model_inputs["features"]
+
+        # Generate transcription
         with torch.no_grad():
-            ids = self.stt_model.generate(
+            # Create attention mask
+            attn_mask = torch.ones(
+                features.shape[:2],
+                dtype=torch.long,
+                device=self.device
+            )
+
+            # Generate transcript
+            generated_ids = self.stt_model.generate(
                 features,
-                attention_mask=attn_mask,  # FIX 6: prevents pad==eos warning
+                attention_mask=attn_mask,
                 language="en",
                 task="transcribe",
-                suppress_tokens=[],  # FIX 7: prevents duplicate processor warning
+                suppress_tokens=[],  # Prevents duplicate processor warning
             )
-        transcript = self.processor.batch_decode(
-            ids, skip_special_tokens=True
-        )[0].strip()
-
-        # Classification
+
+            # Decode transcript
+            transcript = self.processor.batch_decode(
+                generated_ids,
+                skip_special_tokens=True
+            )[0].strip()
+
+        # Classify audio
         with torch.no_grad():
-            probs = F.softmax(self.model(features).logits, dim=-1)[0]
-        mal_prob = probs[1].item()
+            outputs = self.model(features)
+            probs = F.softmax(outputs.logits, dim=-1)[0]
+
         safe_prob = probs[0].item()
-
-        label = "malicious" if mal_prob >= threshold else "safe"
+        mal_prob = probs[1].item()
+
+        # Determine label and confidence
+        label = "malicious" if mal_prob >= threshold else "safe"
         confidence = mal_prob if label == "malicious" else safe_prob
-
+
         return {
-            "transcript": transcript,
-            "label": label,
-            "confidence": round(confidence, 6),
+            "transcript": transcript,
+            "label": label,
+            "confidence": round(confidence, 6),
+            "p_safe": round(safe_prob, 6),
             "p_malicious": round(mal_prob, 6),
-            "p_safe": round(safe_prob, 6),
-            "threshold": threshold,
+            "threshold": threshold,
         }
-
+
     def postprocess(self, model_outputs, **kwargs):
-        return model_outputs
+        """
+        Postprocess model outputs.
+
+        Args:
+            model_outputs: Outputs from forward pass
+
+        Returns:
+            Final formatted outputs
+        """
+        return model_outputs
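
For reference, a minimal usage sketch of the updated pipeline (not part of the commit). The classifier class, import path, and checkpoint id below are placeholders, since the diff only requires that the model expose config.base_model pointing at a Whisper checkpoint; the call pattern with a file path and a per-call threshold follows preprocess() and _sanitize_parameters() as committed.

# Usage sketch only: VoiceShieldModel and the checkpoint id are placeholders,
# not names confirmed by this diff.
from pipeline_voiceshield import VoiceShieldPipeline
from modeling_voiceshield import VoiceShieldModel  # hypothetical classifier class

# The model must expose config.base_model (a Whisper checkpoint id) so the
# pipeline can load its WhisperProcessor and STT model.
model = VoiceShieldModel.from_pretrained("sumitranjan/VoiceShield")  # placeholder repo id

pipe = VoiceShieldPipeline(model=model, threshold=0.2)

# File-path input: preprocess() reads the file with soundfile, downmixes stereo
# to mono, and resamples to 16 kHz before extracting Whisper input features.
result = pipe("sample.wav")
print(result["transcript"], result["label"], result["confidence"])

# The decision threshold can be overridden per call; _sanitize_parameters()
# routes it to _forward() as a forward kwarg.
strict = pipe("sample.wav", threshold=0.5)
print(strict["p_malicious"], strict["threshold"])

Note that array inputs fall back to a 16 kHz assumption, since _sanitize_parameters() does not forward a sampling_rate argument to preprocess(); passing a file path is the path exercised by this code.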