Raemih committed on
Commit
e23e048
·
verified ·
1 Parent(s): aa9449c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -21
app.py CHANGED
@@ -1,36 +1,129 @@
1
- # app.py — Space 5
2
- # requirements.txt: transformers, torch, gradio, TTS, numpy, soundfile
3
-
4
  import gradio as gr
5
- from transformers import pipeline
6
- from TTS.api import TTS
7
  import numpy as np
 
 
 
 
 
 
8
 
9
- emotion_model = pipeline(
10
- "audio-classification",
11
- model="E-motionAssistant/mms-300m-multilingual-ser"
 
 
 
 
 
12
  )
13
 
14
- tts_english = TTS(model_name="E-motionAssistant/text-to-speech-VITS-english", progress_bar=False)
15
- tts_sinhala = TTS(model_name="E-motionAssistant/Text-to-speech-VITS-sinhala", progress_bar=False)
16
- tts_tamil = TTS(model_name="E-motionAssistant/text-to-speech-VITS-tamil", progress_bar=False)
 
 
 
 
 
 
 
 
 
 
17
 
18
- def transcribe(audio):
19
- return asr(audio)["text"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- def speak_english(text): return (22050, np.array(tts_english.tts(text)))
22
- def speak_sinhala(text): return (22050, np.array(tts_sinhala.tts(text)))
23
- def speak_tamil(text): return (22050, np.array(tts_tamil.tts(text)))
24
 
25
  with gr.Blocks() as demo:
 
 
 
26
  gr.TabbedInterface(
27
  [
28
- gr.Interface(fn=transcribe, inputs=gr.Audio(type="filepath"), outputs=gr.Textbox(), title="ASR"),
29
- gr.Interface(fn=speak_english, inputs=gr.Textbox(), outputs=gr.Audio(), title="TTS English"),
30
- gr.Interface(fn=speak_sinhala, inputs=gr.Textbox(), outputs=gr.Audio(), title="TTS Sinhala"),
31
- gr.Interface(fn=speak_tamil, inputs=gr.Textbox(), outputs=gr.Audio(), title="TTS Tamil"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ],
33
- ["ASR (MMS)", "TTS English", "TTS Sinhala", "TTS Tamil"]
 
 
 
 
 
 
34
  )
35
 
 
36
  demo.launch()
 
 
 
 
1
  import gradio as gr
2
+ import torch
 
3
  import numpy as np
4
+ import librosa
5
+
6
+ from transformers import AutoFeatureExtractor
7
+ from TTS.api import TTS
8
+
9
+ from model import MMSForMultilingualSER
10
 
11
# HF Hub repo that hosts both the SER checkpoint and its feature-extractor config.
MODEL_ID = "E-motionAssistant/mms-300m-multilingual-ser"

# Load feature extractor + model
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)

# Custom wrapper class from the local model.py. Size mismatches are ignored so a
# repurposed classification head can still load — NOTE(review): confirm this is
# intentional and not masking a wrong checkpoint.
emotion_model = MMSForMultilingualSER.from_pretrained(
    MODEL_ID,
    ignore_mismatched_sizes=True
)

# Inference mode: disables dropout / batch-norm statistic updates.
emotion_model.eval()


# Emotion labels (adjust if different)
# NOTE(review): index order must match the model head's class order — TODO confirm
# against the training config.
emotion_labels = [
    "anger",
    "disgust",
    "fear",
    "happy",
    "neutral",
    "sad"
]
33
+
34
 
35
def detect_emotion(audio_file):
    """Classify the dominant emotion in a recorded audio clip.

    Parameters
    ----------
    audio_file : str
        Path to an audio file in any format librosa can decode.

    Returns
    -------
    str
        The predicted label from ``emotion_labels``.
    """
    # Resample to 16 kHz, the rate the MMS feature extractor expects.
    # The sample rate librosa returns is fixed by sr=16000, so it is discarded.
    speech, _ = librosa.load(audio_file, sr=16000)

    inputs = feature_extractor(
        speech,
        sampling_rate=16000,
        return_tensors="pt"
    )

    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        # assumes the custom forward returns raw logits directly (not a
        # ModelOutput object) — TODO confirm against model.py
        logits = emotion_model(**inputs)

    pred = torch.argmax(logits, dim=-1).item()

    return emotion_labels[pred]
51
+
52
+
53
# Load TTS models
# Coqui-TTS VITS checkpoints, one per supported language, fetched from the
# Hugging Face Hub at startup; progress bars are suppressed to keep Space logs clean.
tts_english = TTS(
    model_name="E-motionAssistant/text-to-speech-VITS-english",
    progress_bar=False
)

tts_sinhala = TTS(
    # NOTE(review): repo name is capitalized "Text-to-speech" unlike the other
    # two — confirm the repo id is correct.
    model_name="E-motionAssistant/Text-to-speech-VITS-sinhala",
    progress_bar=False
)

tts_tamil = TTS(
    model_name="E-motionAssistant/text-to-speech-VITS-tamil",
    progress_bar=False
)
68
+
69
+
70
def speak_english(text):
    """Synthesize English speech; return a (sample_rate_hz, waveform) tuple for gr.Audio."""
    waveform = np.array(tts_english.tts(text))
    return (22050, waveform)
73
+
74
+
75
def speak_sinhala(text):
    """Synthesize Sinhala speech; return a (sample_rate_hz, waveform) tuple for gr.Audio."""
    waveform = np.array(tts_sinhala.tts(text))
    return (22050, waveform)
78
+
79
+
80
def speak_tamil(text):
    """Synthesize Tamil speech; return a (sample_rate_hz, waveform) tuple for gr.Audio."""
    waveform = np.array(tts_tamil.tts(text))
    return (22050, waveform)
83
 
 
 
 
84
 
85
with gr.Blocks() as demo:

    gr.Markdown("# Emotion Regulation Assistant")

    # First tab: audio in, predicted label out.
    emotion_tab = gr.Interface(
        fn=detect_emotion,
        inputs=gr.Audio(type="filepath"),
        outputs=gr.Textbox(),
        title="Emotion Detection"
    )

    # Remaining tabs share the same text-in / audio-out shape, so build them
    # from (function, title) pairs.
    tts_tabs = [
        gr.Interface(fn=speak_fn, inputs=gr.Textbox(), outputs=gr.Audio(), title=tab_title)
        for speak_fn, tab_title in (
            (speak_english, "TTS English"),
            (speak_sinhala, "TTS Sinhala"),
            (speak_tamil, "TTS Tamil"),
        )
    ]

    gr.TabbedInterface(
        [emotion_tab, *tts_tabs],
        [
            "Emotion Detection",
            "English TTS",
            "Sinhala TTS",
            "Tamil TTS"
        ]
    )


demo.launch()