ROSHANNN123 committed on
Commit
7146854
·
verified ·
1 Parent(s): 690ba94

Update model_service.py

Browse files
Files changed (1) hide show
  1. model_service.py +92 -92
model_service.py CHANGED
@@ -1,92 +1,92 @@
1
- import torch
2
- import librosa
3
- import numpy as np
4
- import io
5
- import soundfile as sf
6
- from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
7
- import torch.nn.functional as F
8
-
9
- # Configuration
10
- MODEL_NAME = "Hemgg/Deepfake-audio-detection" # Using a known fine-tuned model
11
- # Alternative: "mo-thecreator/Deepfake-audio-detection" if the above fails or is private
12
- # But usually public models are fine.
13
-
14
- class ModelService:
15
- def __init__(self):
16
- print(f"Loading model: {MODEL_NAME}...")
17
- self.device = "cuda" if torch.cuda.is_available() else "cpu"
18
- try:
19
- self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
20
- self.model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(self.device)
21
- print(f"Model loaded on {self.device}")
22
- except Exception as e:
23
- print(f"Error loading model: {e}")
24
- raise e
25
-
26
- def preprocess_audio(self, audio_bytes):
27
- """
28
- Load audio bytes, resample to 16000 Hz (required by Wav2Vec2).
29
- """
30
- try:
31
- # Load audio from bytes
32
- # librosa.load supports file-like objects
33
- audio_file = io.BytesIO(audio_bytes)
34
-
35
- # Load and resample to 16k
36
- speech, sr = librosa.load(audio_file, sr=16000)
37
-
38
- # Ensure it's mono (if multi-channel, average them) - librosa.load handles this by default (mono=True)
39
-
40
- return speech
41
- except Exception as e:
42
- print(f"Error processing audio: {e}")
43
- raise ValueError("Invalid audio format or corrupted file.")
44
-
45
- def predict(self, audio_bytes):
46
- speech = self.preprocess_audio(audio_bytes)
47
-
48
- # Tokenize (extract features)
49
- inputs = self.feature_extractor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
50
- inputs = {key: val.to(self.device) for key, val in inputs.items()}
51
-
52
- with torch.no_grad():
53
- logits = self.model(**inputs).logits
54
-
55
- # Get probabilities
56
- probs = F.softmax(logits, dim=-1)
57
-
58
- # The model usually outputs [real, fake] or [fake, real].
59
- # We need to verify the label mapping.
60
- # Typically, id2label is stored in the config.
61
- id2label = self.model.config.id2label
62
- # Example id2label: {0: 'real', 1: 'fake'} or similar.
63
-
64
- predicted_id = torch.argmax(probs, dim=-1).item()
65
- predicted_label = id2label[predicted_id]
66
- confidence = probs[0][predicted_id].item()
67
-
68
- # Map to required output format "AI_GENERATED" or "HUMAN"
69
- # Adjust based on specific model labels.
70
- # Assuming common labels like "real"/"spoof" or "human"/"ai"
71
- normalized_label = "UNKNOWN"
72
-
73
- lower_label = predicted_label.lower()
74
- if "real" in lower_label or "human" in lower_label or "bonafide" in lower_label:
75
- normalized_label = "HUMAN"
76
- elif "fake" in lower_label or "spoof" in lower_label or "ai" in lower_label:
77
- normalized_label = "AI_GENERATED"
78
- else:
79
- # Fallback if labels are obscure, typically 0 is real, 1 is fake for many datasets but not all.
80
- # We trust the string matching first.
81
- normalized_label = predicted_label
82
-
83
- return normalized_label, confidence
84
-
85
- # Singleton instance
86
- model_service = None
87
-
88
- def get_model_service():
89
- global model_service
90
- if model_service is None:
91
- model_service = ModelService()
92
- return model_service
 
1
+ import torch
2
+ import librosa
3
+ import numpy as np
4
+ import io
5
+ import soundfile as sf
6
+ from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor
7
+ import torch.nn.functional as F
8
+
9
+ # Configuration
10
+ MODEL_NAME = "Hemgg/Deepfake-audio-detection" # Using a known fine-tuned model
11
+ # Alternative: "mo-thecreator/Deepfake-audio-detection" if the above fails or is private
12
+ # But usually public models are fine.
13
+
14
class ModelService:
    """Audio deepfake detector backed by a Hugging Face Wav2Vec2 classifier.

    Loads the model named by the module-level ``MODEL_NAME`` once at
    construction, then exposes :meth:`predict`, which maps raw audio file
    bytes to a normalized label ("HUMAN" / "AI_GENERATED") plus a softmax
    confidence score.
    """

    def __init__(self):
        """Load the feature extractor and model, moving the model to GPU if available.

        Raises:
            Exception: re-raises any error from ``from_pretrained`` (missing
                repo, network failure, etc.) after logging it.
        """
        print(f"Loading model: {MODEL_NAME}...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
            self.model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(self.device)
            # Inference-only service: disable dropout / training-mode layers.
            self.model.eval()
            print(f"Model loaded on {self.device}")
        except Exception as e:
            print(f"Error loading model: {e}")
            # Bare `raise` preserves the original traceback.
            raise

    def preprocess_audio(self, audio_bytes):
        """Decode audio bytes and resample to 16 kHz mono (Wav2Vec2's input rate).

        Args:
            audio_bytes: raw contents of an audio file in any format
                librosa/soundfile can decode.

        Returns:
            1-D numpy float array of samples at 16000 Hz. librosa's default
            ``mono=True`` averages multi-channel input down to one channel.

        Raises:
            ValueError: if the bytes cannot be decoded as audio.
        """
        try:
            # librosa.load accepts file-like objects, so wrap the bytes.
            audio_file = io.BytesIO(audio_bytes)
            speech, sr = librosa.load(audio_file, sr=16000)
            return speech
        except Exception as e:
            print(f"Error processing audio: {e}")
            # BUG FIX: the original string lacked the f-prefix, so the literal
            # text "{str(e)}" was shown to callers instead of the actual error.
            # Chain the cause (`from e`) so the underlying decode error is kept.
            raise ValueError(f"Invalid audio format or corrupted file: {e}") from e

    def predict(self, audio_bytes):
        """Classify audio bytes as human or AI-generated speech.

        Args:
            audio_bytes: raw audio file contents.

        Returns:
            Tuple ``(label, confidence)`` where ``label`` is "HUMAN",
            "AI_GENERATED", or the model's raw class name if it matches
            neither convention, and ``confidence`` is the softmax probability
            of the predicted class.
        """
        speech = self.preprocess_audio(audio_bytes)

        # Extract features; padding=True keeps this safe for batched inputs.
        inputs = self.feature_extractor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
        inputs = {key: val.to(self.device) for key, val in inputs.items()}

        with torch.no_grad():
            logits = self.model(**inputs).logits

        probs = F.softmax(logits, dim=-1)
        predicted_id = torch.argmax(probs, dim=-1).item()
        # The class-name mapping lives in the model config; label strings vary
        # between checkpoints, hence the normalization step below.
        predicted_label = self.model.config.id2label[predicted_id]
        # Single-clip service: batch dimension is always index 0 here.
        confidence = probs[0][predicted_id].item()

        return self._normalize_label(predicted_label), confidence

    @staticmethod
    def _normalize_label(predicted_label):
        """Map a checkpoint-specific class name to "HUMAN" / "AI_GENERATED".

        Matches common dataset conventions ("real"/"bonafide"/"human" vs
        "fake"/"spoof"/"ai"). Falls back to the raw label string when it
        matches neither — index-based guessing (0=real, 1=fake) is not
        reliable across checkpoints, so we deliberately avoid it.
        """
        lower_label = predicted_label.lower()
        if "real" in lower_label or "human" in lower_label or "bonafide" in lower_label:
            return "HUMAN"
        if "fake" in lower_label or "spoof" in lower_label or "ai" in lower_label:
            return "AI_GENERATED"
        return predicted_label
84
+
85
# Process-wide singleton: created lazily so importing this module does not
# trigger a (slow) model download/load.
model_service = None

def get_model_service():
    """Return the shared ModelService, constructing it on first use."""
    global model_service
    if model_service is not None:
        return model_service
    model_service = ModelService()
    return model_service