abedir commited on
Commit
3461076
·
verified ·
1 Parent(s): caf81b0

Upload 7 files

Browse files
Files changed (7) hide show
  1. app.py +35 -0
  2. audio_utils.py +24 -0
  3. best_clstm.pt +3 -0
  4. config.py +19 -0
  5. inference.py +36 -0
  6. model.py +81 -0
  7. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI, UploadFile, File
import shutil
import uuid
import os
from inference import predict

app = FastAPI(title="Audio Emotion Recognition API")

# Scratch directory for uploads; each request writes a uniquely named file
# so concurrent requests cannot clobber each other.
UPLOAD_DIR = "/tmp"
os.makedirs(UPLOAD_DIR, exist_ok=True)


@app.get("/")
def root():
    """Liveness message for the API root."""
    return {
        "message": "Audio Emotion Recognition API is running"
    }


@app.get("/health")
def health():
    """Health-check endpoint for orchestrators / load balancers."""
    return {"status": "ok"}


@app.post("/predict")
async def predict_emotion(file: UploadFile = File(...)):
    """Accept an uploaded audio file and return the predicted emotion.

    The upload is spooled to a temp file for inference. The temp file and
    the upload handle are released in a ``finally`` block so a failure in
    ``predict`` (or during the copy) cannot leak either resource — the
    original removed the file only on the success path.
    """
    file_path = os.path.join(UPLOAD_DIR, f"{uuid.uuid4()}.wav")

    try:
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        result = predict(file_path)
    finally:
        await file.close()  # release the upload's spooled file
        if os.path.exists(file_path):
            os.remove(file_path)  # never leave temp audio behind

    return result
audio_utils.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import librosa
import numpy as np
import torch
from config import CONFIG

def preprocess_audio(path, device):
    """Load an audio file and turn it into a normalized log-mel tensor.

    Returns a float tensor of shape (1, 1, n_mels, frames) placed on
    *device*, ready to feed to the model.
    """
    signal, _ = librosa.load(path, sr=CONFIG["sample_rate"])

    # Force a fixed clip length: truncate long clips, zero-pad short ones.
    target = int(CONFIG["sample_rate"] * CONFIG["duration"])
    signal = np.pad(signal[:target], (0, max(0, target - len(signal))))

    spec = librosa.feature.melspectrogram(
        y=signal,
        sr=CONFIG["sample_rate"],
        n_fft=CONFIG["n_fft"],
        hop_length=CONFIG["hop_length"],
        n_mels=CONFIG["n_mels"],
    )

    # Log scale, then per-clip standardization (epsilon guards against a
    # zero std on silent input).
    log_spec = librosa.power_to_db(spec, ref=np.max)
    log_spec = (log_spec - log_spec.mean()) / (log_spec.std() + 1e-9)

    # Add batch and channel axes: (n_mels, frames) -> (1, 1, n_mels, frames).
    return torch.from_numpy(log_spec).unsqueeze(0).unsqueeze(0).to(device)
best_clstm.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10dfec9f10188b7bfc8663e36f2baa9d52b6d7a2819ba9b05ac5172d49775b1f
3
+ size 16568874
config.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Audio front-end and model-loading parameters shared across modules.
CONFIG = dict(
    model_path="best_clstm.pt",  # trained checkpoint loaded by inference.py
    sample_rate=16000,           # Hz; audio is resampled to this on load
    duration=3.0,                # seconds; clips are padded/truncated to this
    n_mels=40,                   # mel filterbank size
    n_fft=512,                   # STFT window length (samples)
    hop_length=256,              # STFT hop (samples)
)

# Supported emotion labels mapped to a display emoji. The key order doubles
# as the fallback class order when a checkpoint carries no label map.
EMOTION_CONFIG = dict(
    angry="😠",
    calm="😌",
    disgust="🤢",
    fearful="😨",
    happy="😊",
    neutral="😐",
    sad="😢",
    surprised="😲",
)
inference.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from model import CLSTMModel
from config import CONFIG, EMOTION_CONFIG
from audio_utils import preprocess_audio

# Run on GPU when available; the model and all inputs live on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint = torch.load(CONFIG["model_path"], map_location=device)

# Prefer the label order saved with the checkpoint; fall back to the static
# config order when the checkpoint carries no label map.
if "label_map" in checkpoint:
    inv = {v: k for k, v in checkpoint["label_map"].items()}
    emotions = [inv[i] for i in range(len(inv))]
else:
    emotions = list(EMOTION_CONFIG.keys())

model = CLSTMModel(
    n_mels=CONFIG["n_mels"],
    n_classes=len(emotions)
).to(device)

# Accept both full training checkpoints ({"model_state_dict": ...}) and
# checkpoints saved as a bare state_dict — the original raised KeyError on
# the latter.
model.load_state_dict(checkpoint.get("model_state_dict", checkpoint))
model.eval()


def predict(path):
    """Classify the emotion of the audio file at *path*.

    Returns {"emotion": <label>, "confidence": <float in [0, 1]>}.
    """
    x = preprocess_audio(path, device)

    with torch.no_grad():
        logits = model(x)
        probs = torch.softmax(logits, dim=1)
        idx = torch.argmax(probs, dim=1).item()

    return {
        "emotion": emotions[idx],
        "confidence": float(probs[0][idx])
    }
model.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import math
4
+
5
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm -> ReLU -> MaxPool -> Dropout2d as one unit.

    'Same'-style padding preserves the spatial size, so only the pooling
    (default (2, 2)) shrinks the feature map.
    """

    def __init__(self, in_ch, out_ch, kernel_size=(3, 3), pool=(2, 2)):
        super().__init__()
        kh, kw = kernel_size
        layers = [
            nn.Conv2d(in_ch, out_ch, kernel_size, padding=(kh // 2, kw // 2)),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.MaxPool2d(pool),
            nn.Dropout2d(0.2),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out
19
+
20
+
21
class AttentionLayer(nn.Module):
    """Attention pooling over the time axis of LSTM outputs."""

    def __init__(self, hidden_dim):
        super().__init__()
        # One scalar score per timestep.
        self.attention = nn.Linear(hidden_dim, 1)

    def forward(self, lstm_out):
        # lstm_out: (batch, time, hidden). Softmax over time turns the
        # scores into a convex combination of timestep features.
        scores = self.attention(lstm_out)
        alpha = torch.softmax(scores, dim=1)
        weighted = alpha * lstm_out
        return weighted.sum(dim=1)
29
+
30
+
31
class CLSTMModel(nn.Module):
    """CNN + bidirectional LSTM + attention classifier for log-mel input.

    Input:  (batch, 1, n_mels, time) spectrogram tensor.
    Output: (batch, n_classes) unnormalized logits.
    """

    def __init__(
        self,
        n_mels=40,
        n_classes=8,
        conv_channels=(32, 64, 128),  # tuple: avoids a mutable default argument
        lstm_hidden=128,
        lstm_layers=2,
        dropout=0.4
    ):
        super().__init__()

        self.conv1 = ConvBlock(1, conv_channels[0])
        self.conv2 = ConvBlock(conv_channels[0], conv_channels[1])
        self.conv3 = ConvBlock(conv_channels[1], conv_channels[2])

        # Each ConvBlock halves the frequency axis with a (2, 2) max-pool,
        # which floors its output size. Track it with per-stage floor
        # division: math.ceil(n_mels / 8) over-counted whenever n_mels was
        # not a multiple of 8 (e.g. 41 -> ceil gives 6 but pooling yields
        # 41 -> 20 -> 10 -> 5), producing an LSTM input-size mismatch at
        # runtime. Identical to the old value for the default n_mels=40.
        freq_after = n_mels
        for _ in range(3):
            freq_after //= 2
        self.lstm_input = conv_channels[2] * freq_after

        self.lstm = nn.LSTM(
            self.lstm_input,
            lstm_hidden,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            # nn.LSTM applies dropout only between stacked layers.
            dropout=dropout if lstm_layers > 1 else 0
        )

        # Bidirectional -> the per-timestep feature size doubles.
        self.attention = AttentionLayer(lstm_hidden * 2)

        self.classifier = nn.Sequential(
            nn.Linear(lstm_hidden * 2, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # (b, c, f, t) -> (b, t, c*f): treat time as the sequence axis.
        b, c, f, t = x.size()
        x = x.permute(0, 3, 1, 2).contiguous().view(b, t, c * f)

        out, _ = self.lstm(x)
        out = self.attention(out)  # pool over time
        return self.classifier(out)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi
2
+ torch
3
+ librosa
4
+ numpy
5
+ python-multipart
6
+ soundfile