ROSHANNN123 committed on
Commit
922c67e
·
verified ·
1 Parent(s): 045c92b

Update model_service.py

Browse files
Files changed (1) hide show
  1. model_service.py +10 -15
model_service.py CHANGED
@@ -11,25 +11,23 @@ MODEL_NAME = "Hemgg/Deepfake-audio-detection"
11
 
12
  class ModelService:
13
  def __init__(self):
14
- print(f"Loading model: {MODEL_NAME}...")
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
  self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
17
  self.model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(self.device)
18
- print(f"Model loaded on {self.device}")
19
 
20
  def preprocess_audio(self, audio_bytes):
21
- # Using a temporary file is the most robust way to handle MP3/WAV with FFmpeg
22
  fd, tmp_path = tempfile.mkstemp(suffix=".audio")
23
  try:
24
  with os.fdopen(fd, 'wb') as tmp:
25
  tmp.write(audio_bytes)
26
 
27
- # Load and resample to 16kHz
28
  speech, _ = librosa.load(tmp_path, sr=16000)
29
  return speech
30
  except Exception as e:
31
- print(f"Error processing audio: {e}")
32
- raise ValueError(f"Invalid audio format: {str(e)}")
33
  finally:
34
  if os.path.exists(tmp_path):
35
  os.remove(tmp_path)
@@ -45,16 +43,13 @@ class ModelService:
45
  probs = F.softmax(logits, dim=-1)
46
  id2label = self.model.config.id2label
47
  predicted_id = torch.argmax(probs, dim=-1).item()
48
- predicted_label = id2label[predicted_id]
49
- confidence = probs[0][predicted_id].item()
 
 
 
 
50
 
51
- lower_label = predicted_label.lower()
52
- if "real" in lower_label or "human" in lower_label or "bonafide" in lower_label:
53
- return "HUMAN", confidence
54
- else:
55
- return "AI_GENERATED", confidence
56
-
57
- # Singleton
58
  model_service = None
59
  def get_model_service():
60
  global model_service
 
11
 
12
  class ModelService:
13
  def __init__(self):
14
+ print("Loading AI Model...")
15
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
16
  self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_NAME)
17
  self.model = AutoModelForAudioClassification.from_pretrained(MODEL_NAME).to(self.device)
 
18
 
19
  def preprocess_audio(self, audio_bytes):
20
+ # Temp file is the safest way to read MP3/WAV/OGG on cloud servers
21
  fd, tmp_path = tempfile.mkstemp(suffix=".audio")
22
  try:
23
  with os.fdopen(fd, 'wb') as tmp:
24
  tmp.write(audio_bytes)
25
 
26
+ # Load and resample to 16kHz (Standard for Wav2Vec2)
27
  speech, _ = librosa.load(tmp_path, sr=16000)
28
  return speech
29
  except Exception as e:
30
+ raise ValueError(f"Audio processing failed: {str(e)}")
 
31
  finally:
32
  if os.path.exists(tmp_path):
33
  os.remove(tmp_path)
 
43
  probs = F.softmax(logits, dim=-1)
44
  id2label = self.model.config.id2label
45
  predicted_id = torch.argmax(probs, dim=-1).item()
46
+
47
+ # Mapping to Portal Labels
48
+ lbl = id2label[predicted_id].lower()
49
+ if "real" in lbl or "human" in lbl or "bonafide" in lbl:
50
+ return "HUMAN", probs[0][predicted_id].item()
51
+ return "AI_GENERATED", probs[0][predicted_id].item()
52
 
 
 
 
 
 
 
 
53
  model_service = None
54
  def get_model_service():
55
  global model_service