ashutoshroy02 commited on
Commit
8fed6be
Β·
verified Β·
1 Parent(s): 0f74a15

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +75 -15
src/streamlit_app.py CHANGED
@@ -2,8 +2,13 @@ import streamlit as st
2
  import torch
3
  import librosa
4
  import numpy as np
 
5
  from transformers import Wav2Vec2Processor
6
  from huggingface_hub import hf_hub_download
 
 
 
 
7
 
8
  # -------------------------
9
  # CONFIG
@@ -13,10 +18,10 @@ MODEL_FILE = "model.pt"
13
 
14
  st.set_page_config(page_title="Emotion & Stress Detection", layout="centered")
15
  st.title("🎀 Emotion & Stress Detection")
16
- st.write("Upload or record audio to detect emotion and stress")
17
 
18
  # -------------------------
19
- # LOAD MODEL
20
  # -------------------------
21
  @st.cache_resource
22
  def load_model():
@@ -31,9 +36,10 @@ def load_model():
31
  id2emotion = {v: k for k, v in emotion2id.items()}
32
  num_emotions = checkpoint["num_emotions"]
33
 
34
- from model import Wav2Vec2_LSTM_MultiTask
 
 
35
 
36
- processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
37
  model = Wav2Vec2_LSTM_MultiTask(num_emotions)
38
  model.load_state_dict(checkpoint["model_state"])
39
  model.eval()
@@ -41,23 +47,26 @@ def load_model():
41
  return model, processor, id2emotion
42
 
43
 
44
- # -------------------------
45
- # LOAD MODEL ON START
46
- # -------------------------
47
  with st.spinner("Loading model..."):
48
  model, processor, id2emotion = load_model()
49
 
50
  st.success("Model loaded successfully")
51
 
52
  # -------------------------
53
- # AUDIO INPUT
54
  # -------------------------
55
- uploaded_file = st.file_uploader("Upload a WAV file", type=["wav"])
 
 
 
 
 
 
 
56
 
57
- if uploaded_file is not None:
58
- st.audio(uploaded_file)
59
 
60
- audio, _ = librosa.load(uploaded_file, sr=16000)
 
61
 
62
  inputs = processor(
63
  audio,
@@ -71,6 +80,57 @@ if uploaded_file is not None:
71
  emotion = id2emotion[emotion_logits.argmax(dim=1).item()]
72
  stress = round(stress_pred.item(), 3)
73
 
74
- st.subheader("🧠 Prediction")
75
- st.write(f"**Emotion:** {emotion}")
76
- st.write(f"**Stress Level:** {stress}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import librosa
4
  import numpy as np
5
+ import tempfile
6
  from transformers import Wav2Vec2Processor
7
  from huggingface_hub import hf_hub_download
8
+ from pydub import AudioSegment
9
+ from streamlit_mic_recorder import mic_recorder
10
+
11
+ from model import Wav2Vec2_LSTM_MultiTask
12
 
13
  # -------------------------
14
  # CONFIG
 
18
 
19
  st.set_page_config(page_title="Emotion & Stress Detection", layout="centered")
20
  st.title("🎀 Emotion & Stress Detection")
21
+ st.write("Record live audio or upload any audio file")
22
 
23
  # -------------------------
24
+ # LOAD MODEL (CACHED)
25
  # -------------------------
26
  @st.cache_resource
27
  def load_model():
 
36
  id2emotion = {v: k for k, v in emotion2id.items()}
37
  num_emotions = checkpoint["num_emotions"]
38
 
39
+ processor = Wav2Vec2Processor.from_pretrained(
40
+ "facebook/wav2vec2-base"
41
+ )
42
 
 
43
  model = Wav2Vec2_LSTM_MultiTask(num_emotions)
44
  model.load_state_dict(checkpoint["model_state"])
45
  model.eval()
 
47
  return model, processor, id2emotion
48
 
49
 
 
 
 
50
  with st.spinner("Loading model..."):
51
  model, processor, id2emotion = load_model()
52
 
53
  st.success("Model loaded successfully")
54
 
55
  # -------------------------
56
+ # AUDIO UTILITIES
57
  # -------------------------
58
+ def convert_to_wav(audio_bytes):
59
+ """Convert any audio format to WAV (16kHz, mono)"""
60
+ audio = AudioSegment.from_file(audio_bytes)
61
+ audio = audio.set_channels(1).set_frame_rate(16000)
62
+
63
+ tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
64
+ audio.export(tmp.name, format="wav")
65
+ return tmp.name
66
 
 
 
67
 
68
+ def predict_from_audio(audio_path):
69
+ audio, _ = librosa.load(audio_path, sr=16000)
70
 
71
  inputs = processor(
72
  audio,
 
80
  emotion = id2emotion[emotion_logits.argmax(dim=1).item()]
81
  stress = round(stress_pred.item(), 3)
82
 
83
+ return emotion, stress
84
+
85
+ # -------------------------
86
+ # UI TABS
87
+ # -------------------------
88
+ tab1, tab2 = st.tabs(["πŸŽ™οΈ Live Record", "πŸ“ Upload Audio"])
89
+
90
+ # =========================
91
+ # πŸŽ™οΈ LIVE RECORD TAB
92
+ # =========================
93
+ with tab1:
94
+ st.subheader("Record Live Audio")
95
+
96
+ audio_data = mic_recorder(
97
+ start_prompt="πŸŽ™οΈ Start Recording",
98
+ stop_prompt="⏹️ Stop Recording",
99
+ just_once=True,
100
+ use_container_width=True
101
+ )
102
+
103
+ if audio_data:
104
+ st.audio(audio_data["bytes"])
105
+
106
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
107
+ f.write(audio_data["bytes"])
108
+ wav_path = f.name
109
+
110
+ emotion, stress = predict_from_audio(wav_path)
111
+
112
+ st.subheader("🧠 Prediction")
113
+ st.write(f"**Emotion:** {emotion}")
114
+ st.write(f"**Stress Level:** {stress}")
115
+
116
+ # =========================
117
+ # πŸ“ UPLOAD FILE TAB
118
+ # =========================
119
+ with tab2:
120
+ st.subheader("Upload Audio File")
121
+
122
+ uploaded_file = st.file_uploader(
123
+ "Upload audio (.wav, .mp3, .m4a, .flac)",
124
+ type=["wav", "mp3", "m4a", "flac"]
125
+ )
126
+
127
+ if uploaded_file:
128
+ st.audio(uploaded_file)
129
+
130
+ wav_path = convert_to_wav(uploaded_file)
131
+
132
+ emotion, stress = predict_from_audio(wav_path)
133
+
134
+ st.subheader("🧠 Prediction")
135
+ st.write(f"**Emotion:** {emotion}")
136
+ st.write(f"**Stress Level:** {stress}")