aneeshm44 commited on
Commit
cbaab68
·
verified ·
1 Parent(s): a9452c0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +128 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torchaudio
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision.models as models
7
+ import tempfile
8
+ import os
9
+
10
+ st.markdown("""
11
+ <style>
12
+ body {
13
+ background-color: #f5f5f5;
14
+ }
15
+ .main {
16
+ background-color: white;
17
+ padding: 2rem;
18
+ border-radius: 10px;
19
+ margin: 2rem auto;
20
+ max-width: 800px;
21
+ }
22
+ .stButton>button {
23
+ background-color: #4CAF50;
24
+ color: white;
25
+ border: none;
26
+ padding: 0.5rem 1rem;
27
+ border-radius: 5px;
28
+ }
29
+ </style>
30
+ """, unsafe_allow_html=True)
31
+
32
+ def load_and_process_audio(file_path, target_length=3.0, sample_rate=16000):
33
+ waveform, sr = torchaudio.load(file_path)
34
+ if waveform.shape[0] > 1:
35
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
36
+ if sr != sample_rate:
37
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
38
+ waveform = resampler(waveform)
39
+ target_samples = int(target_length * sample_rate)
40
+ current_samples = waveform.shape[1]
41
+ if current_samples > target_samples:
42
+ start = (current_samples - target_samples) // 2
43
+ waveform = waveform[:, start:start+target_samples]
44
+ elif current_samples < target_samples:
45
+ padding = target_samples - current_samples
46
+ pad_left = padding // 2
47
+ pad_right = padding - pad_left
48
+ waveform = F.pad(waveform, (pad_left, pad_right))
49
+ return waveform
50
+
51
+ def extract_melspectrogram(waveform, sample_rate=16000, n_mels=80, n_fft=1024, hop_length=512):
52
+ mel_spec = torchaudio.transforms.MelSpectrogram(
53
+ sample_rate=sample_rate,
54
+ n_fft=n_fft,
55
+ hop_length=hop_length,
56
+ n_mels=n_mels
57
+ )(waveform)
58
+ mel_spec = torchaudio.transforms.AmplitudeToDB()(mel_spec)
59
+ return mel_spec
60
+
61
+ class AudioDeepfakeMODEL(nn.Module):
62
+ def __init__(self, num_classes=2):
63
+ super(AudioDeepfakeMODEL, self).__init__()
64
+ self.resnet = models.resnet18(pretrained=True)
65
+ self.resnet.conv1 = nn.Conv2d(
66
+ in_channels=1,
67
+ out_channels=64,
68
+ kernel_size=(7, 7),
69
+ stride=(2, 2),
70
+ padding=(3, 3),
71
+ bias=False
72
+ )
73
+ self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
74
+ for param in self.resnet.parameters():
75
+ param.requires_grad = False
76
+ for param in self.resnet.fc.parameters():
77
+ param.requires_grad = True
78
+
79
+ def forward(self, x):
80
+ return self.resnet(x)
81
+
82
+ def predict_audio_deepfake(wav_path, checkpoint_path, device):
83
+ model = AudioDeepfakeMODEL()
84
+ checkpoint = torch.load(checkpoint_path, map_location=device)
85
+ model.load_state_dict(checkpoint['model_state_dict'])
86
+ model.to(device)
87
+ model.eval()
88
+
89
+ waveform = load_and_process_audio(wav_path)
90
+ mel_spec = extract_melspectrogram(waveform)
91
+ mel_spec = mel_spec.unsqueeze(0).to(device)
92
+ with torch.no_grad():
93
+ outputs = model(mel_spec)
94
+ _, predicted = torch.max(outputs, 1)
95
+ return predicted.item()
96
+
97
+ def main():
98
+ st.markdown('<div class="main">', unsafe_allow_html=True)
99
+ st.title("Audio Deepfake Detector")
100
+ st.write("Upload a **.wav** file to check if it's **Real** or **Fake**.")
101
+
102
+ # File uploader widget
103
+ uploaded_file = st.file_uploader("Choose a .wav file", type=["wav"])
104
+
105
+ if st.button("Reset"):
106
+ st.experimental_rerun()
107
+
108
+ if uploaded_file is not None:
109
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
110
+ tmp_file.write(uploaded_file.read())
111
+ tmp_path = tmp_file.name
112
+
113
+ st.audio(uploaded_file, format="audio/wav")
114
+
115
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
116
+ checkpoint_path = "best_model.pth"
117
+ try:
118
+ result = predict_audio_deepfake(tmp_path, checkpoint_path, device)
119
+ label = "Real" if result == 1 else "Fake"
120
+ st.success(f"Prediction: **{label}**")
121
+ except Exception as e:
122
+ st.error(f"Error during prediction: {e}")
123
+ finally:
124
+ os.remove(tmp_path)
125
+ st.markdown("</div>", unsafe_allow_html=True)
126
+
127
+ if __name__ == "__main__":
128
+ main()