ZeyadMostafa22 committed
Commit · 4dc47c8 · 1 Parent(s): c629c7c
finall

app.py CHANGED
@@ -3,7 +3,6 @@ import torch
 import torchaudio
 import numpy as np
 from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
-import torch.nn.functional as F
 import torchaudio.transforms as T
 
 MODEL_ID = "Zeyadd-Mostaffa/wav2vec_checkpoints"
@@ -19,13 +18,15 @@ model.to(device)
 
 label_names = ["fake", "real"]  # According to your label2id = {"fake": 0, "real": 1}
 
+
 def classify_audio(audio_file):
     """
     audio_file: path to the uploaded file (WAV, MP3, etc.)
-    Returns:
+    Returns: "fake" or "real"
     """
 
     # 2) Load the audio file
+    # torchaudio returns (waveform, sample_rate)
     waveform, sr = torchaudio.load(audio_file)
 
     # If stereo, pick one channel or average
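A note on the context this hunk trails into: the diff elides old lines 32-38, which presumably downmix stereo to mono and resample to 16 kHz before feature extraction. A minimal sketch of that step, assuming mean-downmix and torchaudio.transforms.Resample (both assumptions, since those lines are not shown in the diff):

    import torchaudio
    import torchaudio.transforms as T

    waveform, sr = torchaudio.load(audio_file)         # shape: (channels, samples)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)  # average channels down to mono
    if sr != 16000:
        resampler = T.Resample(orig_freq=sr, new_freq=16000)
        waveform = resampler(waveform)                 # match the model's 16 kHz rate
        sr = 16000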
@@ -39,13 +40,14 @@ def classify_audio(audio_file):
     waveform = resampler(waveform)
     sr = 16000
 
+
     # 3) Preprocess with feature_extractor
     inputs = feature_extractor(
         waveform.numpy(),
         sampling_rate=sr,
         return_tensors="pt",
         truncation=True,
-        max_length=int(16000
+        max_length=int(16000 * 6.0),  # 6 second max
     )
 
     # Move everything to device
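Worth flagging on the new max_length: 6 s at 16 kHz is 96,000 samples, and truncation=True only trims longer clips; shorter ones pass through unchanged unless padding is also requested. A hedged sketch of the call (the squeeze(0) is an addition here, since torchaudio.load returns a (channels, samples) tensor while Wav2Vec2-style extractors expect one 1-D array per example):

    inputs = feature_extractor(
        waveform.squeeze(0).numpy(),   # 1-D (samples,) array for a single example
        sampling_rate=16000,
        return_tensors="pt",
        truncation=True,
        max_length=int(16000 * 6.0),   # 6 second cap, as in the committed line
        # padding="max_length",        # uncomment to also pad short clips to 6 s
    )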
@@ -53,24 +55,20 @@ def classify_audio(audio_file):
 
     with torch.no_grad():
         logits = model(input_values).logits
+    pred_id = torch.argmax(logits, dim=-1).item()
 
-
-
-
-    # Get predicted label and confidence
-    confidence, pred_id = torch.max(probabilities, dim=-1)
-    predicted_label = label_names[pred_id.item()]
+    # 4) Return label text
+    predicted_label = label_names[pred_id]
+    return predicted_label
 
-    # 5) Return label and confidence percentage
-    return f"Prediction: {predicted_label}, Confidence: {confidence.item() * 100:.2f}%"
 
-#
+# 5) Build Gradio interface
 demo = gr.Interface(
     fn=classify_audio,
-    inputs=gr.Audio(type="filepath"),
+    inputs=gr.Audio( type="filepath"),
     outputs="text",
     title="Wav2Vec2 Deepfake Detection",
-    description="Upload an audio sample to check if it is fake or real
+    description="Upload an audio sample to check if it is fake or real."
 )
 
 if __name__ == "__main__":
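The deleted block referenced probabilities without ever computing it (there was no softmax call, and this commit also drops the torch.nn.functional import at the top of the file), so it would have raised a NameError at inference time; replacing it with a plain argmax is the safe fix. If the confidence percentage is ever wanted back, a working version would look roughly like this (a sketch; input_values is assumed to be the device-moved tensor from the elided "Move everything to device" block):

    import torch
    import torch.nn.functional as F

    with torch.no_grad():
        logits = model(input_values).logits        # shape: (1, num_labels)
    probabilities = F.softmax(logits, dim=-1)      # the step the old code skipped
    confidence, pred_id = torch.max(probabilities, dim=-1)
    result = f"Prediction: {label_names[pred_id.item()]}, Confidence: {confidence.item() * 100:.2f}%"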
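For completeness: the diff's final context line suggests the file closes with the standard Gradio entry point; the body sits outside the hunk, so this is an assumption:

    if __name__ == "__main__":
        demo.launch()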