# Momenta / app.py
# (Hugging Face Spaces header residue: uploaded by aneeshm44, "Update app.py",
#  commit 6237077, verified)
import streamlit as st
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import tempfile
import os
import uuid
# Seed session state on first run: the stored upload and a random widget key
# (regenerating the key is what forces the uploader to remount on reset).
if "uploaded_file" not in st.session_state:
    st.session_state["uploaded_file"] = None
if "uploader_key" not in st.session_state:
    st.session_state["uploader_key"] = str(uuid.uuid4())
def reset_state():
    """Drop the stored upload, remount the uploader widget, and rerun the app.

    A fresh ``uploader_key`` gives ``st.file_uploader`` a brand-new identity,
    which is the standard trick to clear a Streamlit uploader.
    """
    st.session_state["uploaded_file"] = None
    st.session_state["uploader_key"] = str(uuid.uuid4())
    st.rerun()
# Page-wide CSS: light backdrop, white centered card, green buttons.
_PAGE_CSS = """
<style>
body { background-color: #f5f5f5; }
.main { background-color: white; padding: 2rem; border-radius: 10px; margin: 2rem auto; max-width: 800px; }
.stButton>button { background-color: #4CAF50; color: white; border: none; padding: 0.5rem 1rem; border-radius: 5px; }
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)
def load_and_process_audio(file_path, target_length=3.0, sample_rate=16000):
    """Load an audio file and normalize it to a fixed-length mono waveform.

    Steps: downmix to mono, resample to ``sample_rate``, then center-crop or
    center-pad (with zeros) to exactly ``target_length`` seconds.

    Returns a tensor of shape (1, target_length * sample_rate).
    """
    audio, orig_sr = torchaudio.load(file_path)

    # Downmix multi-channel audio to a single channel by averaging.
    if audio.shape[0] > 1:
        audio = audio.mean(dim=0, keepdim=True)

    # Bring the waveform to the model's expected sample rate.
    if orig_sr != sample_rate:
        audio = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=sample_rate)(audio)

    n_target = int(target_length * sample_rate)
    n_have = audio.shape[1]
    if n_have > n_target:
        # Too long: keep the centered n_target-sample window.
        offset = (n_have - n_target) // 2
        audio = audio.narrow(1, offset, n_target)
    elif n_have < n_target:
        # Too short: zero-pad symmetrically (extra sample goes on the right).
        deficit = n_target - n_have
        left = deficit // 2
        audio = F.pad(audio, (left, deficit - left))
    return audio
def extract_melspectrogram(waveform, sample_rate=16000, n_mels=80, n_fft=1024, hop_length=512):
    """Convert a waveform tensor to a log-scaled (dB) mel spectrogram."""
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
    )
    to_db = torchaudio.transforms.AmplitudeToDB()
    return to_db(to_mel(waveform))
class AudioDeepfakeMODEL(nn.Module):
    """ResNet-18 adapted to classify single-channel mel spectrograms.

    The ImageNet stem is swapped for a 1-channel conv, the final FC layer is
    resized to ``num_classes``, and only the FC head is left trainable.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision
        # (use `weights=`); kept as-is for compatibility with this codebase.
        backbone = models.resnet18(pretrained=True)

        # Spectrograms have one channel, not three — replace the RGB stem.
        # (This conv is freshly initialized; the checkpoint loaded at inference
        # time overwrites its weights anyway.)
        backbone.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        backbone.fc = nn.Linear(backbone.fc.in_features, num_classes)

        # Freeze the whole network, then unfreeze just the classifier head.
        for p in backbone.parameters():
            p.requires_grad = False
        for p in backbone.fc.parameters():
            p.requires_grad = True

        # Attribute name must stay `resnet` — checkpoint state_dict keys depend on it.
        self.resnet = backbone

    def forward(self, x):
        return self.resnet(x)
def predict_audio_deepfake(wav_path, checkpoint_path, device):
    """Run the deepfake classifier on one .wav file.

    Loads the model weights from ``checkpoint_path`` (expects a dict with a
    ``model_state_dict`` entry), preprocesses the audio into a mel spectrogram,
    and returns the argmax class index as a plain int.
    """
    # SECURITY NOTE(review): torch.load unpickles arbitrary objects — only
    # ever point this at a trusted checkpoint file.
    state = torch.load(checkpoint_path, map_location=device)

    net = AudioDeepfakeMODEL()
    net.load_state_dict(state["model_state_dict"])
    net.to(device)
    net.eval()

    # waveform -> log-mel features, with a leading batch dimension of 1.
    features = extract_melspectrogram(load_and_process_audio(wav_path))
    features = features.unsqueeze(0).to(device)

    with torch.no_grad():
        logits = net(features)
    return torch.argmax(logits, dim=1).item()
def main():
    """Render the app: accept a .wav upload, classify it, and show the verdict."""
    st.markdown('<div class="main">', unsafe_allow_html=True)
    st.title("Audio Deepfake Detector")
    st.write("Upload a **.wav** file to check if it's **Real** or **Fake**.")

    uploaded_file = st.file_uploader(
        "Choose a .wav file",
        type=["wav"],
        key=st.session_state.uploader_key,
    )
    if uploaded_file is not None:
        # Persist across reruns so the prediction survives widget interactions.
        st.session_state.uploaded_file = uploaded_file

    if st.session_state.uploaded_file is not None:
        # BUG FIX: use getvalue() instead of read(). read() advances and
        # exhausts the stream, so on any Streamlit rerun (e.g. clicking the
        # Reset button) the temp file would be written empty and prediction
        # would fail before the reset ran. getvalue() always returns the full
        # payload regardless of stream position.
        audio_bytes = st.session_state.uploaded_file.getvalue()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
            tmp_file.write(audio_bytes)
            tmp_path = tmp_file.name

        st.audio(st.session_state.uploaded_file, format="audio/wav")

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        checkpoint_path = "best_model.pth"
        try:
            result = predict_audio_deepfake(tmp_path, checkpoint_path, device)
            # Label mapping assumes class 1 == genuine audio — TODO confirm
            # against the training label convention.
            label = "Real" if result == 1 else "Fake"
            st.success(f"Prediction: **{label}**")
        except Exception as e:
            # Surface the failure in the UI rather than crashing the app.
            st.error(f"Error during prediction: {e}")
        finally:
            # Always clean up the temp file, even when prediction fails.
            os.remove(tmp_path)

    if st.button("Reset"):
        reset_state()

    st.markdown("</div>", unsafe_allow_html=True)


if __name__ == "__main__":
    main()