File size: 3,658 Bytes
162f472 3c0a049 8fed6be 3c0a049 8fed6be 3c0a049 8fed6be 3c0a049 8fed6be 3c0a049 8fed6be 3c0a049 8fed6be 3c0a049 8fed6be 3c0a049 8fed6be 3c0a049 162f472 8fed6be |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import streamlit as st
import torch
import librosa
import numpy as np
import tempfile
from transformers import Wav2Vec2Processor
from huggingface_hub import hf_hub_download
from pydub import AudioSegment
from streamlit_mic_recorder import mic_recorder
from model import Wav2Vec2_LSTM_MultiTask
# -------------------------
# CONFIG
# -------------------------
MODEL_REPO = "ashutoshroy02/hybrid-wave2vec-LSTM-emotion-stress-RAVDESS"
MODEL_FILE = "model.pt"
st.set_page_config(page_title="Emotion & Stress Detection", layout="centered")
st.title("π€ Emotion & Stress Detection")
st.write("Record live audio or upload any audio file")
# -------------------------
# LOAD MODEL (CACHED)
# -------------------------
@st.cache_resource
def load_model():
model_path = hf_hub_download(
repo_id=MODEL_REPO,
filename=MODEL_FILE
)
checkpoint = torch.load(model_path, map_location="cpu")
emotion2id = checkpoint["emotion2id"]
id2emotion = {v: k for k, v in emotion2id.items()}
num_emotions = checkpoint["num_emotions"]
processor = Wav2Vec2Processor.from_pretrained(
"facebook/wav2vec2-base"
)
model = Wav2Vec2_LSTM_MultiTask(num_emotions)
model.load_state_dict(checkpoint["model_state"])
model.eval()
return model, processor, id2emotion
with st.spinner("Loading model..."):
model, processor, id2emotion = load_model()
st.success("Model loaded successfully")
# -------------------------
# AUDIO UTILITIES
# -------------------------
def convert_to_wav(audio_bytes):
"""Convert any audio format to WAV (16kHz, mono)"""
audio = AudioSegment.from_file(audio_bytes)
audio = audio.set_channels(1).set_frame_rate(16000)
tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
audio.export(tmp.name, format="wav")
return tmp.name
def predict_from_audio(audio_path):
audio, _ = librosa.load(audio_path, sr=16000)
inputs = processor(
audio,
sampling_rate=16000,
return_tensors="pt"
).input_values
with torch.no_grad():
emotion_logits, stress_pred = model(inputs)
emotion = id2emotion[emotion_logits.argmax(dim=1).item()]
stress = round(stress_pred.item(), 3)
return emotion, stress
# -------------------------
# UI TABS
# -------------------------
tab1, tab2 = st.tabs(["ποΈ Live Record", "π Upload Audio"])
# =========================
# ποΈ LIVE RECORD TAB
# =========================
with tab1:
st.subheader("Record Live Audio")
audio_data = mic_recorder(
start_prompt="ποΈ Start Recording",
stop_prompt="βΉοΈ Stop Recording",
just_once=True,
use_container_width=True
)
if audio_data:
st.audio(audio_data["bytes"])
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
f.write(audio_data["bytes"])
wav_path = f.name
emotion, stress = predict_from_audio(wav_path)
st.subheader("π§ Prediction")
st.write(f"**Emotion:** {emotion}")
st.write(f"**Stress Level:** {stress}")
# =========================
# π UPLOAD FILE TAB
# =========================
with tab2:
st.subheader("Upload Audio File")
uploaded_file = st.file_uploader(
"Upload audio (.wav, .mp3, .m4a, .flac)",
type=["wav", "mp3", "m4a", "flac"]
)
if uploaded_file:
st.audio(uploaded_file)
wav_path = convert_to_wav(uploaded_file)
emotion, stress = predict_from_audio(wav_path)
st.subheader("π§ Prediction")
st.write(f"**Emotion:** {emotion}")
st.write(f"**Stress Level:** {stress}")
|