AlexTolstenko committed on
Commit
75aea43
·
1 Parent(s): bc1ecc3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np # linear algebra
2
+ import pandas as pd
3
+ import os
4
+ import librosa as lr
5
+ import torch
6
+ import torch.nn as nn
7
+ import pytorch_lightning as pl
8
+ import gradio
9
+ from model import MFCC_CNN
10
+
11
# Mapping from "<emotion>_<gender>" label to the class index the model emits.
# (14 classes total: 7 emotions x 2 genders.)
EMOTIONS = {
    'neutral_Male': 9,
    'happy_Male': 7,
    'sad_Male': 11,
    'angry_Male': 1,
    'fear_Male': 5,
    'disgust_Male': 3,
    'surprise_Male': 13,
    'neutral_Female': 8,
    'happy_Female': 6,
    'sad_Female': 10,
    'angry_Female': 0,
    'fear_Female': 4,
    'disgust_Female': 2,
    'surprise_Female': 12
}

# LOAD AUDIO
SAMPLE_RATE = 16000  # Hz — resample every clip to this rate
DURATION = 3         # seconds — clips are truncated/zero-padded to this length

# GET MFCC
N_MFCC = 50          # number of MFCC coefficients per frame
WIN_LENGTH = 2048    # STFT window length in samples
WINDOW = 'hann'      # STFT window function
HOP_LENGTH = 512     # STFT hop length in samples

PATH = './chekpoint/models-epoch=97-val_loss=2.09.ckpt'
# map_location='cpu' so a checkpoint saved on a GPU machine still loads on
# CPU-only hosts (e.g. a Spaces inference box).
ckpt = torch.load(PATH, map_location='cpu')

pretrained_model = MFCC_CNN(14)  # 14 output classes, matching EMOTIONS
pretrained_model.load_state_dict(ckpt['state_dict'])
pretrained_model.eval()
pretrained_model.freeze()  # pytorch-lightning: disable grads for inference
45
+
46
def processAudio(audio_file):
    """Predict the emotion/gender label for an audio file.

    Loads at most DURATION seconds of audio at SAMPLE_RATE, zero-pads it to
    a fixed-length signal, extracts MFCC features, and runs the pretrained
    CNN classifier.

    Args:
        audio_file: path to an audio file readable by librosa.

    Returns:
        A label string from EMOTIONS, e.g. 'happy_Male'.
    """
    audio, sr = lr.load(audio_file,
                        duration=DURATION,
                        sr=SAMPLE_RATE)

    # Zero-pad to exactly DURATION seconds so the MFCC frame count is fixed.
    signal = np.zeros(SAMPLE_RATE * DURATION)
    signal[:len(audio)] = audio

    # BUG FIX: librosa's keyword is `sr`, not `r` (was a TypeError).
    mfcc = lr.feature.mfcc(y=signal,
                           sr=sr,
                           n_mfcc=N_MFCC,
                           win_length=WIN_LENGTH,
                           window=WINDOW,
                           hop_length=HOP_LENGTH,
                           )

    feature_set = torch.tensor(mfcc, dtype=torch.float)
    # (batch, channel, n_mfcc, frames) as the CNN expects;
    # 94 frames = 1 + (SAMPLE_RATE * DURATION) // HOP_LENGTH.
    feature_set = feature_set.view(-1, 1, N_MFCC, 94)

    prediction = torch.argmax(pretrained_model(feature_set))

    # BUG FIX: EMOTIONS maps label -> class index, so indexing it with the
    # predicted integer raised KeyError; invert the mapping first.
    index_to_label = {idx: label for label, idx in EMOTIONS.items()}
    return index_to_label[prediction.item()]
71
+
72
# BUG FIX: the module is imported as `gradio` — the names `grad` and `gr`
# were undefined — and the output component is spelled `Label`, not `Lable`.
demo = gradio.Interface(
    fn=processAudio,
    # type="filepath" hands processAudio a file path, which librosa.load expects.
    inputs=gradio.Audio(type="filepath"),
    outputs=gradio.Label(),
    examples=[
        [os.path.join(os.path.dirname(__file__), "files/03-01-01-01-02-02-01.wav")],
        [os.path.join(os.path.dirname(__file__), "files/03-01-07-01-02-02-01.wav")],
        [os.path.join(os.path.dirname(__file__), "files/03-01-08-02-02-02-01.wav")],
    ],
)

if __name__ == '__main__':
    demo.launch()