import streamlit as st
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import tempfile
import os
import uuid

# Persist the upload and the uploader widget key across Streamlit reruns.
if "uploaded_file" not in st.session_state:
    st.session_state.uploaded_file = None
if "uploader_key" not in st.session_state:
    st.session_state.uploader_key = str(uuid.uuid4())


def reset_state():
    # Clear the stored upload and rotate the uploader key so the
    # file_uploader widget is recreated empty on the next rerun.
    st.session_state.uploaded_file = None
    st.session_state.uploader_key = str(uuid.uuid4())
    st.rerun()


# Custom CSS/HTML injection (style block left empty in this snippet).
st.markdown(
    """
    """,
    unsafe_allow_html=True,
)


def load_and_process_audio(file_path, target_length=3.0, sample_rate=16000):
    """Load an audio file as a mono, 16 kHz waveform of exactly target_length seconds."""
    waveform, sr = torchaudio.load(file_path)
    # Downmix multi-channel audio to mono.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    # Resample to the target sample rate if needed.
    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)
    target_samples = int(target_length * sample_rate)
    current_samples = waveform.shape[1]
    if current_samples > target_samples:
        # Center-crop clips that are too long.
        start = (current_samples - target_samples) // 2
        waveform = waveform[:, start : start + target_samples]
    elif current_samples < target_samples:
        # Zero-pad clips that are too short, split evenly on both sides.
        padding = target_samples - current_samples
        pad_left = padding // 2
        pad_right = padding - pad_left
        waveform = F.pad(waveform, (pad_left, pad_right))
    return waveform


def extract_melspectrogram(waveform, sample_rate=16000, n_mels=80, n_fft=1024, hop_length=512):
    """Compute a log-mel spectrogram (amplitude converted to dB) from a waveform."""
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate, n_fft=n_fft, hop_length=hop_length, n_mels=n_mels
    )(waveform)
    mel_spec = torchaudio.transforms.AmplitudeToDB()(mel_spec)
    return mel_spec


class AudioDeepfakeMODEL(nn.Module):
    """ResNet-18 backbone adapted to single-channel spectrogram input,
    with only the final classification layer left trainable."""

    def __init__(self, num_classes=2):
        super(AudioDeepfakeMODEL, self).__init__()
        # `weights=` replaces the deprecated `pretrained=True` argument.
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        # Replace the first conv layer to accept 1-channel spectrograms
        # instead of 3-channel RGB images.
        self.resnet.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            bias=False,
        )
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)
        # Freeze the backbone; fine-tune only the classification head.
        for param in self.resnet.parameters():
            param.requires_grad = False
        for param in self.resnet.fc.parameters():
            param.requires_grad = True

    def forward(self, x):
        return self.resnet(x)


def predict_audio_deepfake(wav_path, checkpoint_path, device):
    """Run a single-file prediction; returns the predicted class index."""
    model = AudioDeepfakeMODEL()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()

    waveform = load_and_process_audio(wav_path)
    mel_spec = extract_melspectrogram(waveform)
    # Add a batch dimension: (1, n_mels, time) -> (1, 1, n_mels, time).
    mel_spec = mel_spec.unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(mel_spec)
        _, predicted = torch.max(outputs, 1)
    return predicted.item()


def main():
    st.markdown('