LingoJr commited on
Commit
2dad960
Β·
verified Β·
1 Parent(s): a834bb6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -16
app.py CHANGED
@@ -7,28 +7,52 @@ speech_classifier = pipeline("audio-classification", model="superb/wav2vec2-base
7
  text_tokenizer = AutoTokenizer.from_pretrained("tae898/emoberta-base")
8
  text_model = AutoModelForSequenceClassification.from_pretrained("tae898/emoberta-base")
9
 
10
- def predict_emotion(audio, text):
11
- results = {}
12
-
13
- if audio is not None:
14
- waveform, sr = torchaudio.load(audio)
15
  preds = speech_classifier(waveform.squeeze().numpy(), sampling_rate=sr, top_k=3)
16
- results["audio_emotion"] = preds[0]["label"]
17
 
18
- if text is not None and text.strip() != "":
19
- inputs = text_tokenizer(text, return_tensors="pt")
 
 
 
 
 
 
 
20
  with torch.no_grad():
21
  outputs = text_model(**inputs)
22
- emotion = text_model.config.id2label[torch.argmax(outputs.logits)]
23
- results["text_emotion"] = emotion
24
 
25
- return results
 
 
 
 
 
 
 
 
 
 
26
 
27
- ui = gr.Interface(
28
- fn=predict_emotion,
29
- inputs=[gr.Audio(type="filepath"), gr.Textbox()],
 
 
 
 
 
 
30
  outputs="json",
31
- title="Multimodal Emotion Recognition"
 
32
  )
33
 
34
- ui.launch()
 
 
 
 
 
7
  text_tokenizer = AutoTokenizer.from_pretrained("tae898/emoberta-base")
8
  text_model = AutoModelForSequenceClassification.from_pretrained("tae898/emoberta-base")
9
 
10
+ def gradio_combined(audio_file, text):
11
+ # Case 1 β€” Audio provided
12
+ if audio_file is not None:
13
+ waveform, sr = torchaudio.load(audio_file)
 
14
  preds = speech_classifier(waveform.squeeze().numpy(), sampling_rate=sr, top_k=3)
 
15
 
16
+ return {
17
+ "Detected Emotion": preds[0]["label"],
18
+ "Top Predictions": {p["label"]: round(p["score"], 3) for p in preds},
19
+ "Source": "Audio"
20
+ }
21
+
22
+ # Case 2 β€” Text provided
23
+ if text.strip() != "":
24
+ inputs = text_tokenizer(text, return_tensors="pt", truncation=True)
25
  with torch.no_grad():
26
  outputs = text_model(**inputs)
 
 
27
 
28
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
29
+ label_id = torch.argmax(probs).item()
30
+
31
+ return {
32
+ "Detected Emotion": text_model.config.id2label[label_id],
33
+ "Top Predictions": {
34
+ text_model.config.id2label[i]: round(p, 3)
35
+ for i, p in enumerate(probs[0].tolist())
36
+ },
37
+ "Source": "Text"
38
+ }
39
 
40
+ return {"Error": "Please provide audio or text input."}
41
+
42
+ # Building the UI
43
+ gradio_ui = gr.Interface(
44
+ fn=gradio_combined,
45
+ inputs=[
46
+ gr.Audio(label="🎀 Upload or Record Speech", sources=["microphone", "upload"], type="filepath"),
47
+ gr.Textbox(label="πŸ’¬ Enter Text Emotion", placeholder="Type something...")
48
+ ],
49
  outputs="json",
50
+ title="🎭 Multimodal Emotion Recognizer",
51
+ description="Use either speech or text β€” the model detects the emotion automatically!"
52
  )
53
 
54
+ # Mount Gradio at /gradio
55
+ app = gr.mount_gradio_app(app, gradio_ui, path="/gradio")
56
+
57
+
58
+ gradio_ui.launch(share=True)