File size: 2,152 Bytes
7146854
 
 
 
045c92b
 
7146854
045c92b
7146854
045c92b
7146854
 
 
922c67e
7146854
045c92b
 
7146854
045c92b
922c67e
14fb181
7146854
14fb181
 
7146854
922c67e
14fb181
7146854
 
922c67e
14fb181
 
 
7146854
 
 
 
 
 
 
 
 
 
 
 
922c67e
 
 
 
 
 
7146854
 
 
 
 
 
045c92b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import librosa
import numpy as np
import io
import os
import tempfile
import torch.nn.functional as F
from transformers import AutoModelForAudioClassification, Wav2Vec2FeatureExtractor

# Hugging Face Hub model id of the pretrained audio-deepfake classifier
# loaded by ModelService (both the model weights and its feature extractor).
MODEL_NAME = "Hemgg/Deepfake-audio-detection"

class ModelService:
    """Wraps a pretrained Wav2Vec2 audio-deepfake classifier.

    Loads the model and feature extractor once at construction, then exposes
    ``predict()`` which maps raw encoded audio bytes to a
    ``("HUMAN" | "AI_GENERATED", confidence)`` pair.
    """

    def __init__(self):
        print("Loading AI Model...")
        # Prefer GPU when available; inputs are moved to the same device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
        self.model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(self.device)
        # Inference-only service: make eval mode explicit (disables dropout etc.).
        self.model.eval()

    def preprocess_audio(self, audio_bytes):
        """Decode encoded audio bytes into a 16 kHz waveform.

        Args:
            audio_bytes: raw contents of an audio file (MP3/WAV/OGG, ...).

        Returns:
            1-D float array of samples resampled to 16 kHz.

        Raises:
            ValueError: if the bytes cannot be decoded as audio.
        """
        # Temp file is the safest way to read MP3/WAV/OGG on cloud servers:
        # librosa's back-ends want a real file path, not an in-memory stream.
        fd, tmp_path = tempfile.mkstemp(suffix=".audio")
        try:
            with os.fdopen(fd, 'wb') as tmp:
                tmp.write(audio_bytes)

            # Load and resample to 16kHz (Standard for Wav2Vec2)
            speech, _ = librosa.load(tmp_path, sr=16000)
            return speech
        except Exception as e:
            # Chain the original exception so the decoder's root cause
            # remains visible in the traceback.
            raise ValueError(f"Audio processing failed: {str(e)}") from e
        finally:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def predict(self, audio_bytes):
        """Classify audio bytes as human or AI-generated speech.

        Args:
            audio_bytes: raw contents of an audio file.

        Returns:
            Tuple ``(label, confidence)`` where label is ``"HUMAN"`` or
            ``"AI_GENERATED"`` and confidence is the softmax probability
            of the predicted class.
        """
        speech = self.preprocess_audio(audio_bytes)
        inputs = self.feature_extractor(speech, sampling_rate=16000, return_tensors="pt", padding=True)
        inputs = {key: val.to(self.device) for key, val in inputs.items()}

        # Keep softmax inside no_grad too: pure inference, no autograd graph.
        with torch.no_grad():
            logits = self.model(**inputs).logits
            probs = F.softmax(logits, dim=-1)

        predicted_id = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][predicted_id].item()

        # Map the model's own label vocabulary onto the portal's two classes.
        lbl = self.model.config.id2label[predicted_id].lower()
        if "real" in lbl or "human" in lbl or "bonafide" in lbl:
            return "HUMAN", confidence
        return "AI_GENERATED", confidence

model_service = None


def get_model_service():
    """Return the process-wide ModelService, constructing it on first use.

    Model loading is expensive, so the instance is created lazily and
    cached in the module-level ``model_service`` global.
    """
    global model_service
    if model_service is not None:
        return model_service
    model_service = ModelService()
    return model_service