Aloukik21 commited on
Commit
cef9fed
·
verified ·
1 Parent(s): b710a14

Upload audio/DF_Arena_1B_V_1/pipeline_antispoofing.py with huggingface_hub

Browse files
audio/DF_Arena_1B_V_1/pipeline_antispoofing.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline
2
+ import torch
3
+ from .feature_extraction_antispoofing import AntispoofingFeatureExtractor
4
+ class AntispoofingPipeline(Pipeline):
5
+ def __init__(self, model, **kwargs):
6
+ super().__init__(model=model, **kwargs)
7
+ self.feature_extractor = AntispoofingFeatureExtractor()
8
+
9
+ def _sanitize_parameters(self, **kwargs):
10
+ preprocess_kwargs = {}
11
+ postprocess_kwargs = {}
12
+
13
+ if "sampling_rate" in kwargs:
14
+ preprocess_kwargs["sampling_rate"] = kwargs["sampling_rate"]
15
+
16
+ return preprocess_kwargs, {}, postprocess_kwargs
17
+
18
+ def preprocess(self, audio, sampling_rate=16000):
19
+ audio = self.feature_extractor(audio)['input_values']
20
+ inputs = {"input_values": audio}
21
+
22
+ return inputs
23
+
24
+ def _forward(self, model_inputs):
25
+ outputs = self.model(**model_inputs)
26
+ return outputs
27
+
28
+ def postprocess(self, model_outputs):
29
+ logits = model_outputs['logits']
30
+ probs = torch.nn.functional.softmax(logits, dim=-1)
31
+ predicted_class = torch.argmax(probs, dim=-1).item()
32
+ confidence = probs[0][predicted_class].item()
33
+
34
+ return {
35
+ "label": self.model.config.id2label[predicted_class],
36
+ "logits": logits.tolist(),
37
+ "score": confidence,
38
+ "all_scores": {
39
+ self.model.config.id2label[i]: probs[0][i].item()
40
+ for i in range(len(probs[0]))
41
+ }
42
+ }