| import streamlit as st | |
| import torch | |
| import torchaudio | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| import tempfile | |
| import os | |
| import uuid | |
# One-time session-state initialisation: the stored upload survives reruns,
# and the uploader key lets "Reset" remount the file_uploader widget.
st.session_state.setdefault("uploaded_file", None)
st.session_state.setdefault("uploader_key", str(uuid.uuid4()))
def reset_state():
    """Clear the stored upload and remount the uploader, then rerun the app."""
    # A new widget key forces Streamlit to build a brand-new (empty)
    # file_uploader instead of restoring the previous widget instance.
    st.session_state.update(uploaded_file=None, uploader_key=str(uuid.uuid4()))
    st.rerun()
# Page-wide styling, injected once at import time.
_PAGE_CSS = """
<style>
body { background-color: #f5f5f5; }
.main { background-color: white; padding: 2rem; border-radius: 10px; margin: 2rem auto; max-width: 800px; }
.stButton>button { background-color: #4CAF50; color: white; border: none; padding: 0.5rem 1rem; border-radius: 5px; }
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
def load_and_process_audio(file_path, target_length=3.0, sample_rate=16000):
    """Load an audio file and return a mono waveform of exactly
    ``target_length`` seconds at ``sample_rate`` Hz.

    Multi-channel input is averaged down to one channel, the signal is
    resampled if its native rate differs, and the result is center-cropped
    or symmetrically zero-padded to the target number of samples.
    Returns a tensor of shape (1, target_length * sample_rate).
    """
    waveform, orig_sr = torchaudio.load(file_path)
    # Downmix multi-channel audio to mono by averaging the channels.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Bring the signal to the model's expected sample rate.
    if orig_sr != sample_rate:
        resample = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=sample_rate)
        waveform = resample(waveform)
    target_samples = int(target_length * sample_rate)
    n_samples = waveform.shape[1]
    if n_samples > target_samples:
        # Too long: keep the middle of the clip.
        offset = (n_samples - target_samples) // 2
        waveform = waveform[:, offset : offset + target_samples]
    elif n_samples < target_samples:
        # Too short: pad both sides equally (any odd sample goes on the right).
        deficit = target_samples - n_samples
        left = deficit // 2
        waveform = F.pad(waveform, (left, deficit - left))
    return waveform
def extract_melspectrogram(waveform, sample_rate=16000, n_mels=80, n_fft=1024, hop_length=512):
    """Compute a log-scaled (dB) mel spectrogram for ``waveform``.

    Returns a tensor of shape (channels, n_mels, time_frames).
    """
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    to_db = torchaudio.transforms.AmplitudeToDB()
    return to_db(to_mel(waveform))
class AudioDeepfakeMODEL(nn.Module):
    """ResNet-18 classifier adapted for single-channel mel-spectrogram input.

    The stock 3-channel ImageNet stem is swapped for a 1-channel conv so the
    network accepts spectrograms directly, and the final fully-connected
    layer is resized to ``num_classes`` outputs.  All parameters are frozen
    except the classifier head.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Spectrograms have one channel, not three — replace the stem conv.
        backbone.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            bias=False,
        )
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)
        # Freeze the whole backbone, then re-enable gradients for the head.
        # NOTE(review): the freshly-initialised conv1 ends up frozen too —
        # harmless for inference (a checkpoint is loaded before use), but
        # worth confirming if this class is ever reused for training.
        for param in backbone.parameters():
            param.requires_grad = False
        for param in backbone.fc.parameters():
            param.requires_grad = True
        self.resnet = backbone

    def forward(self, x):
        return self.resnet(x)
def predict_audio_deepfake(wav_path, checkpoint_path, device):
    """Run the deepfake classifier on one audio file.

    Parameters
    ----------
    wav_path : str — path to the audio file to classify.
    checkpoint_path : str — path to a torch checkpoint; either a training
        checkpoint dict with a ``"model_state_dict"`` entry or a bare
        state dict saved directly.
    device : torch.device — device to run inference on.

    Returns
    -------
    int — the argmax class index (0 or 1).
    """
    model = AudioDeepfakeMODEL()
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Accept both full training checkpoints and bare state dicts
    # (a missing "model_state_dict" key falls back to the dict itself).
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    waveform = load_and_process_audio(wav_path)
    mel_spec = extract_melspectrogram(waveform)
    # Add a batch dimension: (1, channels, n_mels, time).
    batch = mel_spec.unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(batch)
    return logits.argmax(dim=1).item()
def main():
    """Render the single-page Streamlit UI and run detection on an upload."""
    st.markdown('<div class="main">', unsafe_allow_html=True)
    st.title("Audio Deepfake Detector")
    st.write("Upload a **.wav** file to check if it's **Real** or **Fake**.")
    uploaded_file = st.file_uploader(
        "Choose a .wav file",
        type=["wav"],
        key=st.session_state.uploader_key,
    )
    if uploaded_file is not None:
        st.session_state.uploaded_file = uploaded_file
    if st.session_state.uploaded_file is not None:
        # BUG FIX: use getvalue() instead of read(). read() exhausts the
        # in-memory buffer, so on the next Streamlit rerun it returns b""
        # and an empty (unplayable, unclassifiable) temp file gets written.
        audio_bytes = st.session_state.uploaded_file.getvalue()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_file.write(audio_bytes)
            tmp_path = tmp_file.name
        st.audio(audio_bytes, format="audio/wav")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint_path = "best_model.pth"
        try:
            result = predict_audio_deepfake(tmp_path, checkpoint_path, device)
            label = "Real" if result == 1 else "Fake"
            st.success(f"Prediction: **{label}**")
        except Exception as e:
            # Surface failures (missing checkpoint, undecodable audio) in
            # the UI instead of crashing the app.
            st.error(f"Error during prediction: {e}")
        finally:
            # Always remove the temp file, even when prediction fails.
            os.remove(tmp_path)
    if st.button("Reset"):
        reset_state()
    st.markdown("</div>", unsafe_allow_html=True)


if __name__ == "__main__":
    main()