"""Streamlit app that classifies an uploaded ``.wav`` file as real or fake."""

import os
import tempfile
import uuid

import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchvision.models as models

# Session defaults. ``uploader_key`` is regenerated on reset so that the
# file-uploader widget is re-created empty on the next run.
if "uploaded_file" not in st.session_state:
    st.session_state.uploaded_file = None
if "uploader_key" not in st.session_state:
    st.session_state.uploader_key = str(uuid.uuid4())


def reset_state() -> None:
    """Forget the current upload, rotate the uploader key and rerun the script."""
    st.session_state.uploaded_file = None
    st.session_state.uploader_key = str(uuid.uuid4())
    st.rerun()


# NOTE(review): this markup string is effectively empty -- it presumably held
# a <style> block that was lost before this file was captured; restore it
# from the original source if available.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)


def load_and_process_audio(
    file_path, target_length: float = 3.0, sample_rate: int = 16000
) -> torch.Tensor:
    """Load an audio file as a mono, fixed-length waveform.

    Multi-channel audio is averaged down to one channel, the signal is
    resampled to ``sample_rate`` when needed, and the result is centre-cropped
    or symmetrically zero-padded to ``target_length`` seconds.

    Args:
        file_path: Path handed to ``torchaudio.load``.
        target_length: Desired duration in seconds.
        sample_rate: Desired sampling rate in Hz.

    Returns:
        Tensor of shape ``(1, int(target_length * sample_rate))``.
    """
    waveform, sr = torchaudio.load(file_path)

    # Down-mix to mono by averaging channels.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)

    target_samples = int(target_length * sample_rate)
    current_samples = waveform.shape[1]

    if current_samples > target_samples:
        # Keep the middle of the clip.
        start = (current_samples - target_samples) // 2
        waveform = waveform[:, start : start + target_samples]
    elif current_samples < target_samples:
        # Zero-pad both sides; the right side takes the odd sample.
        padding = target_samples - current_samples
        pad_left = padding // 2
        pad_right = padding - pad_left
        waveform = F.pad(waveform, (pad_left, pad_right))

    return waveform


def extract_melspectrogram(
    waveform: torch.Tensor,
    sample_rate: int = 16000,
    n_mels: int = 80,
    n_fft: int = 1024,
    hop_length: int = 512,
) -> torch.Tensor:
    """Return the mel spectrogram of ``waveform`` converted to decibels.

    Args:
        waveform: Audio tensor whose last dimension is time.
        sample_rate: Sampling rate of ``waveform`` in Hz.
        n_mels: Number of mel filter banks.
        n_fft: FFT size.
        hop_length: Hop between successive STFT windows, in samples.

    Returns:
        The output of ``MelSpectrogram`` followed by ``AmplitudeToDB``.
    """
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    to_db = torchaudio.transforms.AmplitudeToDB()
    return to_db(to_mel(waveform))


class AudioDeepfakeMODEL(nn.Module):
    """ResNet-18 adapted to single-channel spectrogram input.

    Args:
        num_classes: Size of the final classification layer.
        pretrained: Whether to initialise the backbone from torchvision's
            pretrained weights. Pass ``False`` when a checkpoint is loaded
            right afterwards, so no weights have to be downloaded.
    """

    def __init__(self, num_classes: int = 2, pretrained: bool = True) -> None:
        super().__init__()
        self.resnet = models.resnet18(pretrained=pretrained)
        # Replace the 3-channel stem with a 1-channel one.
        self.resnet.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            bias=False,
        )
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

        # Freeze everything, then re-enable gradients for the new head only.
        # NOTE(review): the replaced conv1 is freshly initialised and is
        # frozen by this loop as well -- confirm that is intended for training.
        for param in self.resnet.parameters():
            param.requires_grad = False
        for param in self.resnet.fc.parameters():
            param.requires_grad = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the class logits produced by the wrapped ResNet."""
        return self.resnet(x)


def _build_model(checkpoint_path, device) -> AudioDeepfakeMODEL:
    """Instantiate the network and restore its weights from a checkpoint."""
    # The checkpoint overwrites every weight, so skip the pretrained download.
    model = AudioDeepfakeMODEL(pretrained=False)
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()
    return model


@st.cache_resource(show_spinner=False)
def _cached_model(
    checkpoint_path: str, device_name: str, checkpoint_mtime: float
) -> AudioDeepfakeMODEL:
    """Cache the loaded model across Streamlit reruns.

    ``checkpoint_mtime`` is unused in the body; it is part of the cache key so
    that replacing the checkpoint file on disk invalidates the cached model.
    """
    return _build_model(checkpoint_path, torch.device(device_name))


def predict_audio_deepfake(wav_path, checkpoint_path, device) -> int:
    """Classify one audio file and return the predicted class index.

    Args:
        wav_path: Path of the audio file to classify.
        checkpoint_path: Checkpoint containing a ``"model_state_dict"`` entry.
        device: Device on which the model is run.

    Returns:
        Index of the highest-scoring class.
    """
    if isinstance(checkpoint_path, (str, os.PathLike)):
        path = os.fspath(checkpoint_path)
        model = _cached_model(path, str(device), os.path.getmtime(path))
    else:
        # File-like checkpoint sources cannot be keyed reliably; load directly.
        model = _build_model(checkpoint_path, device)

    waveform = load_and_process_audio(wav_path)
    mel_spec = extract_melspectrogram(waveform)
    mel_spec = mel_spec.unsqueeze(0).to(device)  # add the batch dimension

    with torch.no_grad():
        outputs = model(mel_spec)

    return int(outputs.argmax(dim=1).item())


def main() -> None:
    """Render the page: upload widget, audio player, prediction and reset."""
    # NOTE(review): the wrapper markup passed to the two st.markdown calls in
    # this function appears to have been stripped (presumably an opening and a
    # closing HTML tag); only a newline remains.
    st.markdown("\n", unsafe_allow_html=True)
    st.title("Audio Deepfake Detector")
    st.write("Upload a **.wav** file to check if it's **Real** or **Fake**.")

    uploaded_file = st.file_uploader(
        "Choose a .wav file", type=["wav"], key=st.session_state.uploader_key
    )
    if uploaded_file is not None:
        st.session_state.uploaded_file = uploaded_file

    current_file = st.session_state.uploaded_file
    if current_file is not None:
        # getvalue() returns the full payload regardless of the stream
        # position, so reruns never end up writing an empty temp file.
        audio_bytes = current_file.getvalue()
        st.audio(audio_bytes, format="audio/wav")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint_path = "best_model.pth"

        tmp_path = None
        try:
            # The file is closed before being handed to torchaudio, hence
            # delete=False plus the explicit removal below.
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                tmp_file.write(audio_bytes)
                tmp_path = tmp_file.name

            result = predict_audio_deepfake(tmp_path, checkpoint_path, device)
            # Class index 1 is shown as "Real" -- presumably matches the
            # training labels; verify against the training code.
            label = "Real" if result == 1 else "Fake"
            st.success(f"Prediction: **{label}**")
        except Exception as e:
            st.error(f"Error during prediction: {e}")
        finally:
            if tmp_path is not None and os.path.exists(tmp_path):
                os.remove(tmp_path)

        # NOTE(review): the original indentation was lost; the Reset button is
        # placed inside the "file present" branch -- confirm.
        if st.button("Reset"):
            reset_state()

    st.markdown("\n", unsafe_allow_html=True)


if __name__ == "__main__":
    main()