# Source: Hugging Face blob view — prashant-garg, commit 71edc1f ("removing pyaudio")
# (blob-view metadata: raw / history / blame, 4.76 kB)
"""
Streamlit application for real-time gender detection from audio input.
Uses wav2vec2 model to analyze voice and predict speaker gender.
"""
import streamlit as st
import pyaudio
import numpy as np
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification
import logging
# Configure logging
# Root-logger setup: timestamped INFO-level messages for the whole app.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Define audio stream parameters
FORMAT = pyaudio.paInt16 # 16-bit resolution (signed 16-bit PCM samples)
CHANNELS = 1 # Mono audio
RATE = 16000 # 16kHz sampling rate — also passed to the feature extractor as sampling_rate
CHUNK = 1024 # Number of frames per buffer read from the stream
@st.cache_resource
def load_model():
    """
    Load and cache the wav2vec2 gender-recognition model and its feature extractor.

    Cached via st.cache_resource so the download/initialization happens once
    per server process, not on every Streamlit rerun.

    Returns:
        tuple: (feature_extractor, model), with the model already in eval mode.
    """
    checkpoint = "alefiury/wav2vec2-large-xlsr-53-gender-recognition-librispeech"
    # checkpoint = "./local-model"
    extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    classifier = AutoModelForAudioClassification.from_pretrained(checkpoint)
    classifier.eval()
    logging.info("Model loaded successfully.")
    return extractor, classifier
st.title("Gender Detection")
# Initialize session state
# Seed the session-state keys this app reads, if a fresh session lacks them.
for _key, _default in (('listening', False), ('prediction', "")):
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Function to stop listening
def stop_listening():
    """Stop the audio stream, release PyAudio, and reset session state.

    Safe to call repeatedly: the 'stream' and 'audio' entries are *popped*
    from session state, so a second call will not try to stop/close a stream
    that was already closed (the original left the stale handles behind).
    Cleanup errors are logged rather than raised so the UI state is always
    reset. Ends with st.rerun() to refresh the page.
    """
    stream = st.session_state.pop('stream', None)
    if stream is not None:
        logging.info("Stopping stream")
        try:
            stream.stop_stream()
            stream.close()
        except Exception as e:  # best-effort cleanup; never skip the state reset
            logging.error(f"Error while closing stream: {e}")
    audio = st.session_state.pop('audio', None)
    if audio is not None:
        logging.info("Stopping audio")
        try:
            audio.terminate()
        except Exception as e:  # best-effort cleanup
            logging.error(f"Error while terminating PyAudio: {e}")
    st.session_state['listening'] = False
    st.session_state['prediction'] = "Stopped listening, click 'Start Listening' to start again."
    st.rerun()
def start_listening():
    """Open a microphone stream and run a continuous gender-detection loop.

    Loads the (cached) model, opens a 16 kHz mono PyAudio input stream, and
    repeatedly captures ~1.5 second windows of audio. Windows whose peak
    amplitude exceeds a small threshold are classified with the wav2vec2
    model and the predicted label is shown via a Streamlit placeholder;
    quiet windows are skipped. On any failure the error is shown and the
    stream is torn down via stop_listening().
    """
    placeholder = st.empty()
    try:
        placeholder.write("Loading model...")
        feature_extractor, model = load_model()
        audio = pyaudio.PyAudio()
        stream = audio.open(format=FORMAT,
                            channels=CHANNELS,
                            rate=RATE,
                            input=True,
                            frames_per_buffer=CHUNK)
        # Keep the handles in session state so stop_listening() can close them.
        st.session_state['stream'] = stream
        st.session_state['audio'] = audio
        st.session_state['listening'] = True
        st.session_state['prediction'] = "Listening........................"
        placeholder.write("Listening for audio...")
        # Loop-invariant: number of CHUNK-sized reads per ~1.5 second window.
        chunks_per_window = int(RATE / CHUNK * 1.5)
        while st.session_state['listening']:
            # Accumulate chunks in a list and concatenate once per window —
            # np.concatenate inside the read loop (as before) reallocates the
            # whole buffer on every chunk, i.e. accidental O(n^2) copying.
            chunks = []
            for _ in range(chunks_per_window):
                # Read audio chunk from the stream
                data = stream.read(CHUNK, exception_on_overflow=False)
                # Convert int16 PCM bytes to float32 in [-1.0, 1.0)
                chunks.append(np.frombuffer(data, dtype=np.int16).astype(np.float32) / 32768.0)
            audio_data = np.concatenate(chunks)
            # Check if there is significant sound
            if np.max(np.abs(audio_data)) > 0.05:  # peak-amplitude threshold for detecting sound
                # Process the audio data
                inputs = feature_extractor(audio_data, sampling_rate=RATE, return_tensors="pt", padding=True)
                # Perform inference without tracking gradients
                with torch.no_grad():
                    logits = model(**inputs).logits
                predicted_ids = torch.argmax(logits, dim=-1)
                # Map predicted IDs to labels
                predicted_label = model.config.id2label[predicted_ids.item()]
                # Only update the UI when the prediction actually changes
                if predicted_label != st.session_state['prediction']:
                    st.session_state['prediction'] = predicted_label
                    placeholder.write(f"Detected Gender: {predicted_label}")
            else:
                st.session_state['prediction'] = "---- No significant sound detected, skipping prediction. ----"
                placeholder.empty()
        placeholder.empty()
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        st.error(f"An error occurred: {e}")
        stop_listening()
# Buttons to start and stop listening
# Two side-by-side columns, one button each.
start_col, stop_col = st.columns(2)
with start_col:
    if st.button("Start Listening"):
        start_listening()
with stop_col:
    if st.button("Stop Listening"):
        stop_listening()