LingoJr commited on
Commit
2de24ed
Β·
verified Β·
1 Parent(s): 9b7ce52

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +172 -0
app.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -------------------------------------------------------------
2
+ # app.py β€” Multimodal Emotion Recognition System
3
+ # Speech (Wav2Vec2) + Text (EmoBERTa)
4
+ # FastAPI + Gradio integrated into one application
5
+ # -------------------------------------------------------------
6
+
7
+ from fastapi import FastAPI, UploadFile, File, Form
8
+ from fastapi.responses import JSONResponse
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
11
+ import torchaudio
12
+ import io
13
+ import torch
14
+ import gradio as gr
15
+ from typing import Optional
16
+
17
+ # -------------------------------------------------------------
18
+ # 1️⃣ Initialize FastAPI
19
+ # -------------------------------------------------------------
20
+ app = FastAPI(
21
+ title="Multimodal Emotion Recognition API",
22
+ description="Detect emotions from Speech or Text using AI",
23
+ version="1.0.0"
24
+ )
25
+
26
+ # Allow any frontend to access the API
27
+ app.add_middleware(
28
+ CORSMiddleware,
29
+ allow_origins=["*"],
30
+ allow_credentials=True,
31
+ allow_methods=["*"],
32
+ allow_headers=["*"],
33
+ )
34
+
35
+ # -------------------------------------------------------------
36
+ # 2️⃣ Load Models (Global = Faster)
37
+ # -------------------------------------------------------------
38
+
39
+ # Speech Emotion Model
40
+ speech_classifier = pipeline(
41
+ "audio-classification",
42
+ model="superb/wav2vec2-base-superb-er"
43
+ )
44
+
45
+ # Text Emotion Model
46
+ text_tokenizer = AutoTokenizer.from_pretrained("tae898/emoberta-base")
47
+ text_model = AutoModelForSequenceClassification.from_pretrained("tae898/emoberta-base")
48
+
49
+ # -------------------------------------------------------------
50
+ # 3️⃣ FastAPI Endpoint β†’ Multimodal /predict
51
+ # -------------------------------------------------------------
52
+ @app.post("/predict")
53
+ async def predict(
54
+ file: Optional[UploadFile] = File(None),
55
+ text: Optional[str] = Form(None)
56
+ ):
57
+ """
58
+ Accepts:
59
+ - Audio file (wav/mp3)
60
+ - OR text
61
+ - OR both (audio takes priority)
62
+ """
63
+
64
+ # ----------------------------------------
65
+ # Case 1 β€” If audio is provided
66
+ # ----------------------------------------
67
+ if file is not None:
68
+ try:
69
+ audio_bytes = await file.read()
70
+ waveform, sr = torchaudio.load(io.BytesIO(audio_bytes))
71
+
72
+ preds = speech_classifier(
73
+ waveform.squeeze().numpy(),
74
+ sampling_rate=sr,
75
+ top_k=3
76
+ )
77
+
78
+ return {
79
+ "mode": "audio",
80
+ "filename": file.filename,
81
+ "emotion": preds[0]["label"],
82
+ "top_predictions": preds
83
+ }
84
+
85
+ except Exception as e:
86
+ return JSONResponse({"error": f"Audio error: {e}"}, status_code=500)
87
+
88
+ # ----------------------------------------
89
+ # Case 2 β€” If text is provided
90
+ # ----------------------------------------
91
+ if text is not None and text.strip() != "":
92
+ try:
93
+ inputs = text_tokenizer(text, return_tensors="pt", truncation=True)
94
+
95
+ with torch.no_grad():
96
+ outputs = text_model(**inputs)
97
+
98
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
99
+ label_id = torch.argmax(probs).item()
100
+ emotion = text_model.config.id2label[label_id]
101
+
102
+ return {
103
+ "mode": "text",
104
+ "text": text,
105
+ "emotion": emotion,
106
+ "probabilities": {
107
+ text_model.config.id2label[i]: float(round(p, 4))
108
+ for i, p in enumerate(probs[0].tolist())
109
+ }
110
+ }
111
+
112
+ except Exception as e:
113
+ return JSONResponse({"error": f"Text error: {e}"}, status_code=500)
114
+
115
+ # ----------------------------------------
116
+ # Case 3 β€” Nothing provided
117
+ # ----------------------------------------
118
+ return JSONResponse(
119
+ {"error": "Provide an audio file or text."},
120
+ status_code=400
121
+ )
122
+
123
+
124
+ # -------------------------------------------------------------
125
+ # 4️⃣ Gradio Interface (Single Tab: Audio + Text)
126
+ # -------------------------------------------------------------
127
+ def gradio_combined(audio_file, text):
128
+ # Case 1 β€” Audio provided
129
+ if audio_file is not None:
130
+ waveform, sr = torchaudio.load(audio_file)
131
+ preds = speech_classifier(waveform.squeeze().numpy(), sampling_rate=sr, top_k=3)
132
+
133
+ return {
134
+ "Detected Emotion": preds[0]["label"],
135
+ "Top Predictions": {p["label"]: round(p["score"], 3) for p in preds},
136
+ "Source": "Audio"
137
+ }
138
+
139
+ # Case 2 β€” Text provided
140
+ if text.strip() != "":
141
+ inputs = text_tokenizer(text, return_tensors="pt", truncation=True)
142
+ with torch.no_grad():
143
+ outputs = text_model(**inputs)
144
+
145
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
146
+ label_id = torch.argmax(probs).item()
147
+
148
+ return {
149
+ "Detected Emotion": text_model.config.id2label[label_id],
150
+ "Top Predictions": {
151
+ text_model.config.id2label[i]: round(p, 3)
152
+ for i, p in enumerate(probs[0].tolist())
153
+ },
154
+ "Source": "Text"
155
+ }
156
+
157
+ return {"Error": "Please provide audio or text input."}
158
+
159
+ # Building the UI
160
+ gradio_ui = gr.Interface(
161
+ fn=gradio_combined,
162
+ inputs=[
163
+ gr.Audio(label="🎀 Upload or Record Speech", sources=["microphone", "upload"], type="filepath"),
164
+ gr.Textbox(label="πŸ’¬ Enter Text Emotion", placeholder="Type something...")
165
+ ],
166
+ outputs="json",
167
+ title="🎭 Multimodal Emotion Recognizer",
168
+ description="Use either speech or text β€” the model detects the emotion automatically!"
169
+ )
170
+
171
+ # Mount Gradio at /gradio
172
+ app = gr.mount_gradio_app(app, gradio_ui, path="/gradio")