Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| import librosa | |
| import numpy as np | |
| import onnxruntime as ort | |
| import os | |
| import requests | |
# Audio padding function
def pad(x, max_len=64600):
    """
    Pad or trim an audio segment to exactly ``max_len`` samples.

    Segments at least ``max_len`` long are truncated to their first ``max_len``
    samples. Shorter segments are filled by repeating (tiling) the segment
    end-to-end and cutting the result at ``max_len``. An empty segment has
    nothing to repeat, so it is returned as ``max_len`` zeros (silence) of the
    same dtype instead of failing with a division by zero.

    Args:
        x: Audio samples; the length is read from ``x.shape[0]``
            (presumably a 1-D waveform — the tiling below assumes that).
        max_len: Target number of samples.

    Returns:
        An array of length ``max_len``.
    """
    x_len = x.shape[0]
    if x_len >= max_len:
        return x[:max_len]  # Trim if longer (also covers max_len == 0)
    if x_len == 0:
        # Nothing to repeat: pad with silence rather than divide by zero below.
        return np.zeros(max_len, dtype=x.dtype)
    # Enough whole copies to cover max_len, then cut to size.
    num_repeats = (max_len // x_len) + 1
    padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
    return padded_x
# Preprocess audio for a single segment
def preprocess_audio_segment(segment, cut=64600):
    """
    Prepare one audio window as model input.

    The window is brought to exactly ``cut`` samples with ``pad``, converted to
    a float32 array, and given a leading batch axis of size 1.
    """
    fixed_length = pad(segment, max_len=cut)
    as_float32 = np.array(fixed_length, dtype=np.float32)
    # Leading axis of size 1 = a batch containing this single window.
    return as_float32[np.newaxis, ...]
# Download ONNX model from Hugging Face
def download_model(url, local_path="RawNet_model.onnx", timeout=60):
    """
    Download the ONNX model from ``url`` to ``local_path`` unless that file already exists.

    The response body is streamed to a temporary ``<local_path>.part`` file and
    only renamed to ``local_path`` once it has been fully written. An
    interrupted download therefore never leaves a truncated file that later
    runs would accept as a valid model (the existence check would skip it).

    Args:
        url: Direct download URL of the model file.
        local_path: Destination path. If it already exists, nothing is downloaded.
        timeout: Seconds passed to ``requests.get`` as its connect/read timeout,
            so a stalled server cannot hang the app indefinitely.

    Returns:
        ``local_path``.

    Raises:
        RuntimeError: If the server answers with a status code other than 200.
        requests.RequestException: On network errors or timeouts.
    """
    if os.path.exists(local_path):
        return local_path

    partial_path = local_path + ".part"
    with st.spinner("Downloading ONNX model..."):
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                raise RuntimeError(
                    f"Failed to download ONNX model (HTTP {response.status_code})"
                )
            try:
                with open(partial_path, "wb") as f:
                    # Stream in 1 MiB chunks instead of buffering the whole model in memory.
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        if chunk:  # iter_content may yield empty chunks; skip them
                            f.write(chunk)
                # Publish the file only once it is complete.
                os.replace(partial_path, local_path)
            except BaseException:
                # Also runs on interrupts/Streamlit reruns: never leave a partial file behind.
                if os.path.exists(partial_path):
                    os.remove(partial_path)
                raise
        st.success("Model downloaded successfully!")
    return local_path
# Sliding window prediction function
def predict_with_sliding_window(audio_path, onnx_model_path, window_size=64600, step_size=64600, sample_rate=16000):
    """
    Classify an audio file as "Real" or "Fake" by running the ONNX model on
    consecutive windows and aggregating the per-window predictions.

    The file is loaded with ``librosa.load`` at ``sample_rate``. Windows start
    every ``step_size`` samples and span up to ``window_size`` samples. Each
    window goes through ``preprocess_audio_segment``, which pads/trims it to
    that function's default ``cut`` of 64600 samples regardless of
    ``window_size``.

    Aggregation:
        * The label is the majority vote over windows. Ties are broken by the
          larger summed per-window confidence of the tied labels, then by which
          label appeared first. This keeps the result deterministic instead of
          dependent on set/hash ordering.
        * The confidence is the mean, over all windows, of each window's
          probability for its own predicted label. Windows that voted for the
          losing label are included.

    Args:
        audio_path: Path to an audio file readable by ``librosa.load``.
        onnx_model_path: Path to the ONNX model file.
        window_size: Number of samples taken per window.
        step_size: Offset in samples between consecutive window starts.
        sample_rate: Sample rate the audio is resampled to on load.

    Returns:
        ``(label, confidence_percent)`` where ``label`` is "Real" or "Fake".

    Raises:
        ValueError: If the decoded audio contains no samples.
    """
    from collections import Counter  # local import keeps this block self-contained

    # Passing providers explicitly is required by ORT >= 1.9 on builds with several providers.
    ort_session = ort.InferenceSession(onnx_model_path, providers=ort.get_available_providers())
    # The input name is fixed for the session; look it up once rather than per window.
    input_name = ort_session.get_inputs()[0].name

    waveform, _ = librosa.load(audio_path, sr=sample_rate)
    if len(waveform) == 0:
        raise ValueError("The audio file contains no samples to classify.")

    total_segments = []
    total_probabilities = []
    for start in range(0, len(waveform), step_size):
        segment = waveform[start:start + window_size]
        audio_tensor = preprocess_audio_segment(segment)
        outputs = ort_session.run(None, {input_name: audio_tensor})
        # NOTE(review): np.exp assumes the first output holds log-probabilities
        # (e.g. a final LogSoftmax layer) with shape (1, num_classes) — confirm against the exported graph.
        probabilities = np.exp(outputs[0][0])
        prediction = int(np.argmax(probabilities))
        # NOTE(review): assumes class index 1 means bona fide ("Real") — confirm with the training labels.
        total_segments.append("Real" if prediction == 1 else "Fake")
        total_probabilities.append(probabilities[prediction])

    # Majority vote with a deterministic tie-break (summed confidence, then first-seen order).
    vote_counts = Counter(total_segments)
    vote_mass = Counter()
    for label, prob in zip(total_segments, total_probabilities):
        vote_mass[label] += prob
    majority_class = max(vote_counts, key=lambda label: (vote_counts[label], vote_mass[label]))

    avg_probability = np.mean(total_probabilities) * 100  # Average probability in percent
    return majority_class, avg_probability
# Streamlit app — page configuration must be the first Streamlit call.
st.set_page_config(page_title="Audio Spoof Detection", page_icon="🎵", layout="centered")

# Header: centered title followed by a one-line description (raw HTML).
_TITLE_HTML = "<h1 style='text-align: center; color: blue;'>Audio Spoof Detection</h1>"
_INTRO_HTML = """
<p style='text-align: center;'>
Detect whether an uploaded audio file is <strong>Real</strong> or <strong>Fake</strong> using an ONNX model.
</p>
"""
st.markdown(_TITLE_HTML, unsafe_allow_html=True)
st.markdown(_INTRO_HTML, unsafe_allow_html=True)

# Sidebar: usage instructions and a short note about the model.
_INSTRUCTIONS_MD = """
- Upload an audio file in WAV or MP3 format.
- Wait for the model to process the file.
- View the prediction result and confidence score.
"""
_ABOUT_MODEL_MD = """
The model is trained to classify audio as Real or Fake using a RawNet-based architecture.
"""
with st.sidebar:
    st.header("Instructions")
    st.write(_INSTRUCTIONS_MD)
    st.markdown("### About the Model")
    st.info(_ABOUT_MODEL_MD)
import tempfile  # per-upload temporary file (see below)

# File uploader
uploaded_file = st.file_uploader("Upload your audio file (WAV or MP3)", type=["wav", "mp3"])

# ONNX model URL on the Hugging Face Hub
onnx_model_url = "https://huggingface.co/Mrkomiljon/DeepVoiceGuard/resolve/main/RawNet_model.onnx"

# Ensure the ONNX model is available locally. On failure, show the error in the
# UI and stop this run instead of crashing the whole app.
try:
    onnx_model_path = download_model(onnx_model_url)
except Exception as err:
    st.error(f"Could not download the ONNX model: {err}")
    st.stop()

if uploaded_file is not None:
    st.markdown("<h3 style='text-align: center;'>Processing Your File...</h3>", unsafe_allow_html=True)

    # Use a unique temp file per upload so concurrent sessions cannot overwrite
    # each other's audio. Keep the original extension (.wav/.mp3) so the decoder
    # sees the real format. getvalue() returns all bytes regardless of the stream position.
    suffix = os.path.splitext(uploaded_file.name)[1] or ".wav"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getvalue())
        temp_audio_path = tmp.name

    try:
        # Perform prediction
        with st.spinner("Running the model..."):
            result, avg_probability = predict_with_sliding_window(temp_audio_path, onnx_model_path)
    except Exception as err:
        # UI boundary: report the failure to the user instead of a raw traceback.
        st.error(f"Could not process the audio file: {err}")
    else:
        # Display results
        st.success(f"Prediction: {result}")
        st.metric(label="Confidence", value=f"{avg_probability:.2f}%", delta=None)
    finally:
        # Always clean up the temporary file, even if prediction failed.
        os.remove(temp_audio_path)
# Footer: a horizontal rule followed by a small, centered credit line (raw HTML).
_FOOTER_HTML = """
<hr>
<p style='text-align: center; font-size: small;'>
Created with ❤️ using Streamlit.
</p>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)