Spaces:
Sleeping
Sleeping
Update model_service.py
Browse files- model_service.py +10 -15
model_service.py
CHANGED
|
@@ -11,25 +11,23 @@ MODEL_NAME = "Hemgg/Deepfake-audio-detection"
|
|
| 11 |
|
| 12 |
class ModelService:
|
| 13 |
def __init__(self):
|
| 14 |
-
print(
|
| 15 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
|
| 17 |
self.model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(self.device)
|
| 18 |
-
print(f"Model loaded on {self.device}")
|
| 19 |
|
| 20 |
def preprocess_audio(self, audio_bytes):
|
| 21 |
-
#
|
| 22 |
fd, tmp_path = tempfile.mkstemp(suffix=".audio")
|
| 23 |
try:
|
| 24 |
with os.fdopen(fd, 'wb') as tmp:
|
| 25 |
tmp.write(audio_bytes)
|
| 26 |
|
| 27 |
-
# Load and resample to 16kHz
|
| 28 |
speech, _ = librosa.load(tmp_path, sr=16000)
|
| 29 |
return speech
|
| 30 |
except Exception as e:
|
| 31 |
-
|
| 32 |
-
raise ValueError(f"Invalid audio format: {str(e)}")
|
| 33 |
finally:
|
| 34 |
if os.path.exists(tmp_path):
|
| 35 |
os.remove(tmp_path)
|
|
@@ -45,16 +43,13 @@ class ModelService:
|
|
| 45 |
probs = F.softmax(logits, dim=-1)
|
| 46 |
id2label = self.model.config.id2label
|
| 47 |
predicted_id = torch.argmax(probs, dim=-1).item()
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
lower_label = predicted_label.lower()
|
| 52 |
-
if "real" in lower_label or "human" in lower_label or "bonafide" in lower_label:
|
| 53 |
-
return "HUMAN", confidence
|
| 54 |
-
else:
|
| 55 |
-
return "AI_GENERATED", confidence
|
| 56 |
-
|
| 57 |
-
# Singleton
|
| 58 |
model_service = None
|
| 59 |
def get_model_service():
|
| 60 |
global model_service
|
|
|
|
| 11 |
|
| 12 |
class ModelService:
|
| 13 |
def __init__(self):
|
| 14 |
+
print("Loading AI Model...")
|
| 15 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
|
| 17 |
self.model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(self.device)
|
|
|
|
| 18 |
|
| 19 |
def preprocess_audio(self, audio_bytes):
|
| 20 |
+
# Temp file is the safest way to read MP3/WAV/OGG on cloud servers
|
| 21 |
fd, tmp_path = tempfile.mkstemp(suffix=".audio")
|
| 22 |
try:
|
| 23 |
with os.fdopen(fd, 'wb') as tmp:
|
| 24 |
tmp.write(audio_bytes)
|
| 25 |
|
| 26 |
+
# Load and resample to 16kHz (Standard for Wav2Vec2)
|
| 27 |
speech, _ = librosa.load(tmp_path, sr=16000)
|
| 28 |
return speech
|
| 29 |
except Exception as e:
|
| 30 |
+
raise ValueError(f"Audio processing failed: {str(e)}")
|
|
|
|
| 31 |
finally:
|
| 32 |
if os.path.exists(tmp_path):
|
| 33 |
os.remove(tmp_path)
|
|
|
|
| 43 |
probs = F.softmax(logits, dim=-1)
|
| 44 |
id2label = self.model.config.id2label
|
| 45 |
predicted_id = torch.argmax(probs, dim=-1).item()
|
| 46 |
+
|
| 47 |
+
# Mapping to Portal Labels
|
| 48 |
+
lbl = id2label[predicted_id].lower()
|
| 49 |
+
if "real" in lbl or "human" in lbl or "bonafide" in lbl:
|
| 50 |
+
return "HUMAN", probs[0][predicted_id].item()
|
| 51 |
+
return "AI_GENERATED", probs[0][predicted_id].item()
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
model_service = None
|
| 54 |
def get_model_service():
|
| 55 |
global model_service
|