File size: 3,658 Bytes
162f472
3c0a049
 
 
8fed6be
3c0a049
 
8fed6be
 
 
 
3c0a049
 
 
 
 
 
 
 
 
8fed6be
3c0a049
 
8fed6be
3c0a049
 
 
 
 
 
 
 
 
 
 
 
 
 
8fed6be
 
 
3c0a049
 
 
 
 
 
 
 
 
 
 
 
 
 
8fed6be
3c0a049
8fed6be
 
 
 
 
 
 
 
3c0a049
 
8fed6be
 
3c0a049
 
 
 
 
 
 
 
 
 
 
 
162f472
8fed6be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import tempfile

import librosa
import numpy as np
import streamlit as st
import torch
from huggingface_hub import hf_hub_download
from pydub import AudioSegment
from streamlit_mic_recorder import mic_recorder
from transformers import Wav2Vec2Processor

from model import Wav2Vec2_LSTM_MultiTask

# -------------------------
# CONFIG
# -------------------------
# Hugging Face Hub repository holding the trained multi-task checkpoint.
MODEL_REPO = "ashutoshroy02/hybrid-wave2vec-LSTM-emotion-stress-RAVDESS"
MODEL_FILE = "model.pt"

# Page chrome; st.set_page_config must run before any other Streamlit render call.
st.set_page_config(page_title="Emotion & Stress Detection", layout="centered")
st.title("🎀 Emotion & Stress Detection")
st.write("Record live audio or upload any audio file")

# -------------------------
# LOAD MODEL (CACHED)
# -------------------------
@st.cache_resource
def load_model():
    """Download the checkpoint from the Hub and build the inference model.

    Wrapped in st.cache_resource so the download and model construction run
    once per server process instead of on every Streamlit rerun.

    Returns:
        tuple: (model, processor, id2emotion) where ``id2emotion`` maps
        integer class ids back to emotion label strings.
    """
    model_path = hf_hub_download(
        repo_id=MODEL_REPO,
        filename=MODEL_FILE
    )

    # NOTE(review): torch.load with default args unpickles arbitrary objects;
    # acceptable for a trusted repo, but consider weights_only=True if the
    # checkpoint holds only tensors/primitives — TODO confirm.
    checkpoint = torch.load(model_path, map_location="cpu")

    # The checkpoint bundles the label mapping alongside the weights.
    emotion2id = checkpoint["emotion2id"]
    id2emotion = {v: k for k, v in emotion2id.items()}
    num_emotions = checkpoint["num_emotions"]

    processor = Wav2Vec2Processor.from_pretrained(
        "facebook/wav2vec2-base"
    )

    model = Wav2Vec2_LSTM_MultiTask(num_emotions)
    model.load_state_dict(checkpoint["model_state"])
    model.eval()  # inference mode: disables dropout etc.

    return model, processor, id2emotion


# Load once at startup; cached, so reruns after the first are instant.
with st.spinner("Loading model..."):
    model, processor, id2emotion = load_model()

st.success("Model loaded successfully")

# -------------------------
# AUDIO UTILITIES
# -------------------------
def convert_to_wav(audio_bytes):
    """Convert any audio format to a mono 16 kHz WAV file on disk.

    Args:
        audio_bytes: A file path or file-like object readable by
            ``pydub.AudioSegment.from_file`` (e.g. a Streamlit upload).

    Returns:
        str: Path to the temporary WAV file. The caller is responsible
        for deleting it when done.
    """
    audio = AudioSegment.from_file(audio_bytes)
    # Model expects mono audio sampled at 16 kHz.
    audio = audio.set_channels(1).set_frame_rate(16000)

    # delete=False so the file outlives this function; close our handle
    # before exporting so pydub/ffmpeg can re-open the path (required on
    # Windows, and avoids leaking a file descriptor elsewhere).
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    tmp.close()
    audio.export(tmp.name, format="wav")
    return tmp.name


def predict_from_audio(audio_path):
    """Run the multi-task model on a WAV file.

    Args:
        audio_path: Path to an audio file readable by librosa.

    Returns:
        tuple: (emotion label string, stress level rounded to 3 decimals).
    """
    waveform, _ = librosa.load(audio_path, sr=16000)

    features = processor(
        waveform,
        sampling_rate=16000,
        return_tensors="pt"
    )
    input_values = features.input_values

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        emotion_logits, stress_pred = model(input_values)

    predicted_id = emotion_logits.argmax(dim=1).item()
    emotion_label = id2emotion[predicted_id]
    stress_level = round(stress_pred.item(), 3)

    return emotion_label, stress_level

# -------------------------
# UI TABS
# -------------------------
# Two entry points for audio: live microphone capture and file upload.
tab1, tab2 = st.tabs(["πŸŽ™οΈ Live Record", "πŸ“ Upload Audio"])

# =========================
# πŸŽ™οΈ LIVE RECORD TAB
# =========================
with tab1:
    st.subheader("Record Live Audio")

    # just_once=True hands back each recording a single time, so a rerun
    # does not re-trigger prediction on stale audio.
    audio_data = mic_recorder(
        start_prompt="πŸŽ™οΈ Start Recording",
        stop_prompt="⏹️ Stop Recording",
        just_once=True,
        use_container_width=True
    )

    if audio_data:
        st.audio(audio_data["bytes"])

        # Persist the raw bytes so librosa can load from a file path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
            f.write(audio_data["bytes"])
            wav_path = f.name

        # Clean up the temp file even if prediction fails — otherwise
        # every recording leaves an orphaned WAV in the temp directory.
        try:
            emotion, stress = predict_from_audio(wav_path)
        finally:
            os.remove(wav_path)

        st.subheader("🧠 Prediction")
        st.write(f"**Emotion:** {emotion}")
        st.write(f"**Stress Level:** {stress}")

# =========================
# πŸ“ UPLOAD FILE TAB
# =========================
with tab2:
    st.subheader("Upload Audio File")

    uploaded_file = st.file_uploader(
        "Upload audio (.wav, .mp3, .m4a, .flac)",
        type=["wav", "mp3", "m4a", "flac"]
    )

    if uploaded_file:
        st.audio(uploaded_file)

        # Normalize whatever format was uploaded to mono 16 kHz WAV.
        wav_path = convert_to_wav(uploaded_file)

        # Clean up the converted file even if prediction fails — otherwise
        # every upload leaves an orphaned WAV in the temp directory.
        try:
            emotion, stress = predict_from_audio(wav_path)
        finally:
            os.remove(wav_path)

        st.subheader("🧠 Prediction")
        st.write(f"**Emotion:** {emotion}")
        st.write(f"**Stress Level:** {stress}")