sudo-paras-shah committed on
Commit
9825f94
·
1 Parent(s): 49bdc4b

Add all files

Browse files

Maybe keras will mess up

.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ **/__pycache__/
Dockerfile CHANGED
@@ -1,4 +1,4 @@
1
- FROM python:3.9-slim
2
 
3
  WORKDIR /app
4
 
@@ -18,4 +18,4 @@ EXPOSE 8501
18
 
19
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
 
21
- ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
 
1
+ FROM python:3.12-slim
2
 
3
  WORKDIR /app
4
 
 
18
 
19
  HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
20
 
21
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
 
1
+ streamlit
2
+ streamlit-webrtc
3
+ numpy
4
+ tensorflow
5
+ keras
6
+ Pillow
7
+ opencv-python
src/classification.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+
6
+ from nets import get_model_from_name
7
+ from utils.utils import (cvtColor, get_classes, letterbox_image,
8
+ preprocess_input)
9
+
10
+ from huggingface_hub import hf_hub_download
11
+
12
#--------------------------------------------#
#   To use your own trained model, update
#   model_path, classes_path, backbone and
#   (for mobilenet) alpha below.
#--------------------------------------------#
class Classification(object):
    """Wraps a trained Keras image classifier.

    Loads the class list and network weights once at construction; after
    that, detect_image() classifies individual PIL images.
    """

    _defaults = {
        #--------------------------------------------------------------------------#
        #   model_path points at the trained weight file, classes_path at the
        #   class-name list.  On a shape mismatch, check that both match the
        #   configuration used at training time.
        #   NOTE(review): hf_hub_download runs at class-definition (import)
        #   time and "username/model-name" looks like a placeholder repo_id —
        #   confirm the real repository before deploying.
        #--------------------------------------------------------------------------#
        # "model_path"  : 'model_data/mobilenet_2_5_224_tf_no_top.h5',
        "model_path"    : hf_hub_download(repo_id="username/model-name", filename="model.h5"),
        "classes_path"  : 'model_data/cls_classes.txt',
        # Input image size (height, width) fed to the network.
        "input_shape"   : [224, 224],
        # Backbone architecture: one of 'mobilenet', 'resnet50', 'vgg16'.
        "backbone"      : 'vgg16',
        # Width multiplier — only used when backbone == 'mobilenet'.
        "alpha"         : 0.25
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value for attribute *n*.

        Returns an error string for unknown names (kept instead of raising,
        for backward compatibility with existing callers).
        """
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   Initialise the classifier
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        """Populate attributes from _defaults, letting kwargs override them."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        #---------------------------------------------------#
        #   Load class names, then build the network eagerly.
        #---------------------------------------------------#
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.generate()

    #---------------------------------------------------#
    #   Build the model and load the weights
    #---------------------------------------------------#
    def generate(self):
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        if self.backbone == "mobilenet":
            self.model = get_model_from_name[self.backbone](
                input_shape=[self.input_shape[0], self.input_shape[1], 3],
                classes=self.num_classes, alpha=self.alpha)
        else:
            self.model = get_model_from_name[self.backbone](
                input_shape=[self.input_shape[0], self.input_shape[1], 3],
                classes=self.num_classes)
        # Bug fix: load from the tilde-expanded path computed above instead of
        # the raw self.model_path (expanduser was previously computed but
        # ignored for the actual load).
        self.model.load_weights(model_path)
        print('{} model, and classes {} loaded.'.format(model_path, self.class_names))

    #---------------------------------------------------#
    #   Classify one image
    #---------------------------------------------------#
    def detect_image(self, image):
        """Classify a single PIL image.

        Returns (class_name, probability) of the highest-scoring class.
        """
        #---------------------------------------------------------#
        #   Convert to RGB so grayscale/RGBA inputs don't break the
        #   prediction — only RGB images are supported downstream.
        #---------------------------------------------------------#
        image = cvtColor(image)
        #---------------------------------------------------#
        #   Aspect-preserving resize with letterbox padding.
        #---------------------------------------------------#
        image_data = letterbox_image(image, [self.input_shape[1], self.input_shape[0]])
        #---------------------------------------------------------#
        #   Normalise and add the batch dimension.
        #---------------------------------------------------------#
        image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)

        preds = self.model.predict(image_data)[0]
        class_name = self.class_names[np.argmax(preds)]
        probability = np.max(preds)

        # Optional debugging visualisation:
        # plt.subplot(1, 1, 1)
        # plt.imshow(np.array(image))
        # plt.title('Class:%s Probability:%.3f' % (class_name, probability))
        # plt.show()

        return class_name, probability
src/model_data/cls_classes.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ disgust
2
+ fear
3
+ happiness
4
+ others
5
+ repression
6
+ sadness
7
+ surprise
src/model_data/haarcascade_frontalface_alt.xml ADDED
The diff for this file is too large to render. See raw diff
 
src/model_data/mobilenet_2_5_224_tf_no_top.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dbdb03ee2a22fd895301636cd328b234bb3a9952358f436d82f46b81e0d5b0bf
3
+ size 2108140
src/model_data/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfe5187d0a272bed55ba430631598124cff8e880b98d38c9e56c8d66032abdc1
3
+ size 58889256
src/nets/Loss.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from keras import backend as K
3
+ import tensorflow as tf
4
+ import numpy as np
5
+
6
def multi_category_focal_loss2(gamma=2., alpha=1):
    """
    Focal loss for multi-class / multi-label problems.

    alpha weights the positive (y_true == 1) entries; negatives get 1 - alpha.
    Decrease alpha when the model is over-aggressive (always predicts 1);
    increase it when the model is too lazy (always predicts 0 / a constant).

    Usage:
        model.compile(loss=[multi_category_focal_loss2(alpha=0.25, gamma=2)],
                      metrics=["accuracy"], optimizer=adam)
    """
    epsilon = 1.e-7
    gamma = float(gamma)
    alpha = tf.constant(alpha, dtype=tf.float32)

    def multi_category_focal_loss2_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        # Clip predictions to avoid log(0).
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)

        alpha_t = y_true * alpha + (tf.ones_like(y_true) - y_true) * (1 - alpha)
        # y_t: probability assigned to the *true* value of each entry.
        y_t = tf.multiply(y_true, y_pred) + tf.multiply(1 - y_true, 1 - y_pred)
        # Bug fix: tf.log was removed in TF 2.x; tf.math.log exists in 1.x and 2.x.
        ce = -tf.math.log(y_t)
        weight = tf.pow(tf.subtract(1., y_t), gamma)
        fl = tf.multiply(tf.multiply(weight, ce), alpha_t)
        loss = tf.reduce_mean(fl)
        return loss

    return multi_category_focal_loss2_fixed
36
+
37
def multi_category_focal_loss1(alpha, gamma=2.0):
    """
    Focal loss for multi-class / multi-label problems with per-class weights.

    alpha is a per-class weight (column) vector whose length must equal the
    number of classes — useful when the label distribution is skewed.

    Usage:
        model.compile(loss=[multi_category_focal_loss1(alpha=[1,2,3,2], gamma=2)],
                      metrics=["accuracy"], optimizer=adam)
    """
    epsilon = 1.e-7
    alpha = tf.constant(alpha, dtype=tf.float32)
    gamma = float(gamma)

    def multi_category_focal_loss1_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        # Clip predictions to avoid log(0).
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
        # y_t: probability assigned to the *true* value of each entry.
        y_t = tf.multiply(y_true, y_pred) + tf.multiply(1 - y_true, 1 - y_pred)
        # Bug fix: tf.log was removed in TF 2.x; tf.math.log exists in 1.x and 2.x.
        ce = -tf.math.log(y_t)
        weight = tf.pow(tf.subtract(1., y_t), gamma)
        # Matrix-multiply against the per-class weight vector.
        fl = tf.matmul(tf.multiply(weight, ce), alpha)
        loss = tf.reduce_mean(fl)
        return loss

    return multi_category_focal_loss1_fixed
61
+
62
+
63
def Cross_entropy_loss(y_true, y_pred):
    '''
    Per-example categorical cross-entropy.

    :param y_true: one-hot encoding, shape is [batch_size, nums_classes]
    :param y_pred: shape is [batch_size, nums_classes], each example defined as probability for per class
    :return: shape is [batch_size,], a list include cross_entropy for per example
    '''
    # Clip predictions to avoid log(0).
    y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
    # Bug fix: tf.log was removed in TF 2.x; tf.math.log exists in 1.x and 2.x.
    crossEntropyLoss = -y_true * tf.math.log(y_pred)

    return tf.reduce_sum(crossEntropyLoss, -1)
73
+
74
+ # focal loss with multi label
75
# focal loss with multi label
def focal_loss(classes_num, gamma=2., alpha=.25, e=0.1):
    """Class-balanced focal loss.

    classes_num: list with the sample count of each class; used to derive the
    per-class balancing weights.
    NOTE(review): the `alpha` parameter is never used — the per-class weights
    derived from classes_num replace it; kept in the signature for
    backward compatibility.
    """
    def focal_loss_fixed(target_tensor, prediction_tensor):
        '''
        prediction_tensor is the output tensor with shape [None, 100], where 100 is the number of classes
        target_tensor is the label tensor, same shape as predcition_tensor
        '''
        import tensorflow as tf
        from tensorflow.python.ops import array_ops
        from keras import backend as K

        #1# focal term with no balancing weight — eq. (4) of the focal-loss paper
        zeros = array_ops.zeros_like(prediction_tensor, dtype=prediction_tensor.dtype)
        one_minus_p = array_ops.where(tf.greater(target_tensor, zeros), target_tensor - prediction_tensor, zeros)
        # Bug fix: tf.log was removed in TF 2.x; tf.math.log exists in 1.x and 2.x.
        FT = -1 * (one_minus_p ** gamma) * tf.math.log(tf.clip_by_value(prediction_tensor, 1e-8, 1.0))

        #2# balancing weights derived from per-class frequencies
        classes_weight = array_ops.zeros_like(prediction_tensor, dtype=prediction_tensor.dtype)

        total_num = float(sum(classes_num))
        classes_w_t1 = [total_num / ff for ff in classes_num]
        sum_ = sum(classes_w_t1)
        classes_w_t2 = [ff / sum_ for ff in classes_w_t1]  # scale to sum to 1
        classes_w_tensor = tf.convert_to_tensor(classes_w_t2, dtype=prediction_tensor.dtype)
        classes_weight += classes_w_tensor

        balanced_alpha = array_ops.where(tf.greater(target_tensor, zeros), classes_weight, zeros)

        #3# balanced focal loss
        balanced_fl = balanced_alpha * FT
        balanced_fl = tf.reduce_mean(balanced_fl)

        #4# label-smoothing-style regulariser to reduce overfitting
        # reference : https://spaces.ac.cn/archives/4493
        nb_classes = len(classes_num)
        final_loss = (1 - e) * balanced_fl + e * K.categorical_crossentropy(
            K.ones_like(prediction_tensor) / nb_classes, prediction_tensor)

        return final_loss
    return focal_loss_fixed
src/nets/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mobilenet import MobileNet
2
+ from .resnet50 import ResNet50
3
+ from .vgg16 import VGG16
4
+
5
# Maps the backbone name used in configuration to its model constructor.
get_model_from_name = {
    "mobilenet" : MobileNet,
    "resnet50" : ResNet50,
    "vgg16" : VGG16,
}

# Number of leading layers to freeze per backbone during fine-tuning.
# NOTE(review): 'cspdarknet53' has no matching entry in get_model_from_name —
# presumably left over from another project; confirm before relying on it.
freeze_layers = {
    "mobilenet" : 81,
    "resnet50" : 173,
    "vgg16" : 19,
    "cspdarknet53" : 60,
}
src/nets/mobilenet.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras import backend as K
2
+ from keras.layers import (Activation, BatchNormalization, Conv2D,
3
+ DepthwiseConv2D, Dropout, GlobalAveragePooling2D,
4
+ Input, Reshape)
5
+ from keras.models import Model
6
+
7
+
8
def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
    """Entry convolution stage: Conv2D -> BatchNorm -> ReLU6.

    The filter count is scaled by the width multiplier `alpha`.
    """
    scaled_filters = int(filters * alpha)
    out = Conv2D(scaled_filters, kernel,
                 padding='same',
                 use_bias=False,
                 strides=strides,
                 name='conv1')(inputs)
    out = BatchNormalization(name='conv1_bn')(out)
    return Activation(relu6, name='conv1_relu')(out)
17
+
18
+
19
def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha,
                          depth_multiplier=1, strides=(1, 1), block_id=1):
    """Depthwise-separable conv block:
    3x3 depthwise -> BN -> ReLU6 -> 1x1 pointwise -> BN -> ReLU6."""
    scaled_pointwise = int(pointwise_conv_filters * alpha)

    # Depthwise 3x3 stage.
    out = DepthwiseConv2D((3, 3),
                          padding='same',
                          depth_multiplier=depth_multiplier,
                          strides=strides,
                          use_bias=False,
                          name='conv_dw_%d' % block_id)(inputs)
    out = BatchNormalization(name='conv_dw_%d_bn' % block_id)(out)
    out = Activation(relu6, name='conv_dw_%d_relu' % block_id)(out)

    # Pointwise 1x1 stage adjusts the channel count.
    out = Conv2D(scaled_pointwise, (1, 1),
                 padding='same',
                 use_bias=False,
                 strides=(1, 1),
                 name='conv_pw_%d' % block_id)(out)
    out = BatchNormalization(name='conv_pw_%d_bn' % block_id)(out)
    return Activation(relu6, name='conv_pw_%d_relu' % block_id)(out)
41
+
42
def MobileNet(input_shape=None,
              alpha=1.0,
              depth_multiplier=1,
              dropout=1e-3,
              classes=1000):
    """Build a MobileNet-v1 classifier.

    Args:
        input_shape: (H, W, C) of the input, e.g. (224, 224, 3).
        alpha: width multiplier scaling every layer's filter count.
        depth_multiplier: channel multiplier for the depthwise convolutions.
        dropout: dropout rate applied before the classification conv.
        classes: number of output classes.

    Returns:
        A keras Model mapping images to softmax class probabilities.
    """
    img_input = Input(shape=input_shape)

    # 224,224,3 -> 112,112,32
    x = _conv_block(img_input, 32, alpha, strides=(2, 2))

    # 112,112,32 -> 112,112,64
    x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1)

    # 112,112,64 -> 56,56,128
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier,
                              strides=(2, 2), block_id=2)
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3)

    # 56,56,128 -> 28,28,256
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier,
                              strides=(2, 2), block_id=4)
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5)

    # 28,28,256 -> 14,14,512
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier,
                              strides=(2, 2), block_id=6)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=7)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=8)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=9)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=10)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=11)

    # 14,14,512 -> 7,7,1024
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier,
                              strides=(2, 2), block_id=12)
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13)

    # 7,7,1024 -> 1,1,1024 (global average pooling, then re-add spatial dims
    # so a 1x1 conv can act as the classifier head)
    x = GlobalAveragePooling2D()(x)

    shape = (1, 1, int(1024 * alpha))

    x = Reshape(shape, name='reshape_1')(x)
    x = Dropout(dropout, name='dropout')(x)

    x = Conv2D(classes, (1, 1), padding='same', name='conv_preds')(x)
    x = Activation('softmax', name='act_softmax')(x)
    x = Reshape((classes,), name='reshape_2')(x)

    inputs = img_input

    model = Model(inputs, x, name='mobilenet_%0.2f' % (alpha))
    return model
99
+
100
def relu6(x):
    # ReLU capped at 6, the activation used throughout MobileNet.
    return K.relu(x, max_value=6)
102
+
103
if __name__ == '__main__':
    # Smoke test: build the network and print its layer summary.
    model = MobileNet(input_shape=(224, 224, 3))
    model.summary()
src/nets/resnet50.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras import layers
2
+ from keras.layers import (Activation, AveragePooling2D, BatchNormalization,
3
+ Conv2D, Dense, Flatten, Input, MaxPooling2D,
4
+ ZeroPadding2D)
5
+ from keras.models import Model
6
+
7
+
8
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """Bottleneck residual block whose shortcut is the identity
    (input and output shapes match, so no projection is needed)."""
    f_reduce, f_mid, f_expand = filters

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    # 1x1 conv: reduce channels.
    out = Conv2D(f_reduce, (1, 1), name=conv_name_base + '2a')(input_tensor)
    out = BatchNormalization(name=bn_name_base + '2a')(out)
    out = Activation('relu')(out)

    # 3x3 conv.
    out = Conv2D(f_mid, kernel_size, padding='same', name=conv_name_base + '2b')(out)
    out = BatchNormalization(name=bn_name_base + '2b')(out)
    out = Activation('relu')(out)

    # 1x1 conv: expand channels back.
    out = Conv2D(f_expand, (1, 1), name=conv_name_base + '2c')(out)
    out = BatchNormalization(name=bn_name_base + '2c')(out)

    # Residual sum with the untouched input.
    out = layers.add([out, input_tensor])
    return Activation('relu')(out)
32
+
33
+
34
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    """Bottleneck residual block with a projection shortcut — used when the
    spatial size or channel count changes between input and output."""
    f_reduce, f_mid, f_expand = filters

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    # 1x1 conv: reduce channels (and optionally downsample via strides).
    out = Conv2D(f_reduce, (1, 1), strides=strides, name=conv_name_base + '2a')(input_tensor)
    out = BatchNormalization(name=bn_name_base + '2a')(out)
    out = Activation('relu')(out)

    # 3x3 conv.
    out = Conv2D(f_mid, kernel_size, padding='same', name=conv_name_base + '2b')(out)
    out = BatchNormalization(name=bn_name_base + '2b')(out)
    out = Activation('relu')(out)

    # 1x1 conv: expand channels back.
    out = Conv2D(f_expand, (1, 1), name=conv_name_base + '2c')(out)
    out = BatchNormalization(name=bn_name_base + '2c')(out)

    # Projection shortcut matches the main branch's shape.
    shortcut = Conv2D(f_expand, (1, 1), strides=strides,
                      name=conv_name_base + '1')(input_tensor)
    shortcut = BatchNormalization(name=bn_name_base + '1')(shortcut)

    out = layers.add([out, shortcut])
    return Activation('relu')(out)
62
+
63
+
64
def ResNet50(input_shape=(224, 224, 3), classes=1000):
    """Build a ResNet-50 classifier.

    Args:
        input_shape: (height, width, channels) of the input image.
            Bug fix: the default was a mutable list, which is shared across
            calls; replaced with an equivalent tuple (same values).
        classes: number of output classes.

    Returns:
        A keras Model mapping images to softmax class probabilities.
    """
    img_input = Input(shape=input_shape)

    x = ZeroPadding2D((3, 3))(img_input)
    # 224,224,3 -> 112,112,64
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)

    x = ZeroPadding2D((1, 1))(x)
    # 112,112,64 -> 56,56,64
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    # 56,56,64 -> 56,56,256 (stride 1: no further downsampling in stage 2)
    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
    x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')

    # 56,56,256 -> 28,28,512
    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='b')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='c')
    x = identity_block(x, 3, [128, 128, 512], stage=3, block='d')

    # 28,28,512 -> 14,14,1024
    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='b')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='c')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='d')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='e')
    x = identity_block(x, 3, [256, 256, 1024], stage=4, block='f')

    # 14,14,1024 -> 7,7,2048
    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='b')
    x = identity_block(x, 3, [512, 512, 2048], stage=5, block='c')

    # 7,7,2048 -> 1,1,2048
    x = AveragePooling2D((7, 7), name='avg_pool')(x)

    # Flatten to 2048 and classify.
    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    model = Model(img_input, x, name='resnet50')

    return model
114
+
115
+
116
if __name__ == '__main__':
    # Smoke test: build the network and print its layer summary.
    model = ResNet50()
    model.summary()
src/nets/vgg16.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.layers import Conv2D, Dense, Flatten, Input, MaxPooling2D
2
+ from keras.models import Model #导入包Conv2D是卷积核 Flatten是展开 Input输入 MaxPooling2D最大卷积核
3
+
4
+
5
def VGG16(input_shape=None, classes=1000):
    """Build the classic VGG-16 classifier.

    Five convolutional stages (64/128/256/512/512 filters, each followed by
    a 2x2 max-pool that halves the spatial size), then two 4096-unit fully
    connected layers and a softmax head.  Layer names match the canonical
    Keras VGG16 so pretrained weights load by name.
    """
    img_input = Input(shape=input_shape)  # e.g. 224, 224, 3

    # (stage id, filter count, number of 3x3 convs) for the five stages.
    stage_configs = [
        (1, 64, 2),
        (2, 128, 2),
        (3, 256, 3),
        (4, 512, 3),
        (5, 512, 3),
    ]

    x = img_input
    for stage_id, n_filters, n_convs in stage_configs:
        for conv_id in range(1, n_convs + 1):
            # 'same' padding keeps the spatial size unchanged within a stage.
            x = Conv2D(n_filters, (3, 3),
                       activation='relu',
                       padding='same',
                       name='block%d_conv%d' % (stage_id, conv_id))(x)
        # Each pool halves height and width (224 -> 112 -> 56 -> 28 -> 14 -> 7).
        x = MaxPooling2D((2, 2), strides=(2, 2),
                         name='block%d_pool' % stage_id)(x)

    # 7,7,512 -> 25088 -> 4096 -> 4096 -> classes
    x = Flatten(name='flatten')(x)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)

    inputs = img_input

    model = Model(inputs, x, name='vgg16')
    return model
95
+
96
if __name__ == '__main__':
    # Smoke test: build the network and print its layer summary.
    model = VGG16(input_shape=(224, 224, 3))
    model.summary()
src/streamlit_app.py CHANGED
@@ -1,40 +1,173 @@
1
- import altair as alt
 
 
 
2
  import numpy as np
3
- import pandas as pd
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+ import subprocess
3
+
4
+ import cv2
5
  import numpy as np
6
+ from PIL import Image
7
+
8
  import streamlit as st
9
+ from streamlit_webrtc import VideoProcessorBase, webrtc_streamer
10
+
11
+ from classification import Classification
12
+
13
# Module-level singletons: the emotion classifier (loads the Keras weights
# once, at import time) and OpenCV's Haar cascade for frontal-face detection.
# NOTE(review): the cascade path is relative to the working directory —
# assumes the app is launched from src/; confirm against the Dockerfile
# ENTRYPOINT.
classificator = Classification()
face_cascade = cv2.CascadeClassifier(
    os.path.join('model_data', 'haarcascade_frontalface_alt.xml')
)

# Streamlit Title
st.title("Real-Time Micro-Emotion Recognition")

# Only Live Emotion Detection Mode
st.write("Turn on your camera and detect emotions in real-time.")

# Camera selection UI
st.sidebar.header("Camera Settings")
26
def get_connected_cameras():
    """Enumerate V4L2 camera device indices via `v4l2-ctl --list-devices`.

    Returns a list of integer /dev/videoN indices.  Falls back to [0] when
    v4l2-ctl is missing, exits non-zero, or no camera index can be parsed.
    """
    try:
        result = subprocess.run(
            ['v4l2-ctl', '--list-devices'],
            capture_output=True,
            text=True,
            check=True)
        # Output is blank-line-separated groups: a device-name line followed
        # by one indented /dev/videoN path per line.
        camera_indices = []
        for device in result.stdout.split('\n\n'):
            if 'camera' not in device.lower():
                continue
            lines = device.split('\n')
            if len(lines) > 1:
                dev_path = lines[1].strip().split(':')[0].strip()
                # Bug fix: the old code used int(dev_path[4:]), which yields
                # "/video0" for "/dev/video0" and never parses as an int, so
                # every device was silently dropped.  Take the digits after
                # the final 'video' instead.
                suffix = dev_path.rsplit('video', 1)[-1]
                if suffix.isdigit():
                    camera_indices.append(int(suffix))
        # Always offer at least the default camera.
        return camera_indices or [0]
    except FileNotFoundError:
        return [0]  # Fallback to default camera if v4l2-ctl is not available
    except subprocess.CalledProcessError:
        return [0]
51
+
52
available_cameras = get_connected_cameras()

# Offer a camera picker only when more than one device was found.
# NOTE(review): camera_index is never passed to webrtc_streamer below, so
# this selection currently has no effect on which camera is captured —
# confirm the intended wiring (streamlit-webrtc media constraints).
if len(available_cameras) > 1:
    camera_index = st.sidebar.selectbox(
        "Select Camera Index",
        options=available_cameras,
        index=0,
        format_func=lambda x: f"Camera {x}"
    )
else:
    camera_index = 0
    st.sidebar.write("Only one camera detected. Using default camera.")
64
+
65
# --- Face detection and augmentation functions ---
def face_detect(img):
    """Run the module-level Haar cascade on a BGR frame.

    Returns the original frame, its grayscale version, and the detected face
    rectangles as (x, y, w, h) tuples.
    """
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    detections = face_cascade.detectMultiScale(
        grayscale,
        scaleFactor=1.1,
        minNeighbors=1,
        minSize=(30, 30)
    )
    return img, grayscale, detections
75
+
76
# --- Emotion class mapping ---
def map_emotion_to_class(emotion):
    """Collapse a fine-grained emotion label into one of four coarse classes.

    Matching is case-insensitive and substring-based; the first matching
    bucket wins (Positive, then Negative, then Surprise), with 'Others' as
    the catch-all.
    """
    label = emotion.lower()
    buckets = (
        ('Positive', ('happiness', 'happy')),
        ('Negative', ('disgust', 'sadness', 'fear', 'sad', 'angry', 'disgusted')),
        ('Surprise', ('surprise',)),
    )
    for coarse_class, keywords in buckets:
        if any(keyword in label for keyword in keywords):
            return coarse_class
    return 'Others'
91
+
92
# --- Streamlit session state for emotion tracking ---
# Rolling history of coarse emotion classes (most recent last); the video
# processor caps it at 10 entries.
if 'emotion_history' not in st.session_state:
    st.session_state['emotion_history'] = []
95
+
96
# Video Processing Class
class EmotionRecognitionProcessor(VideoProcessorBase):
    """streamlit-webrtc processor: per frame, detects faces, classifies each
    face crop's emotion, and draws the results onto the frame."""

    def __init__(self):
        # NOTE(review): last_class is never written after init — appears unused.
        self.last_class = None
        # Count of consecutive frames whose recent history showed mixed classes.
        self.rapid_change_count = 0

    def recv(self, frame):
        """Process one incoming video frame; return the annotated frame."""
        border_color = (255, 0, 0)  # Rectangle color (blue in BGR)
        font_color = (0, 0, 255)    # Text color (red in BGR)
        img = frame.to_ndarray(format="bgr24")
        img_disp, img_gray, faces = face_detect(img)
        current_class = None

        if len(faces) == 0:
            cv2.putText(
                img_disp, 'No Face Detect.', (2, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1
            )

        for (x, y, w, h) in faces:
            # Expand the detected box by 10 px on every side, clamped to the frame.
            x1, y1 = max(x - 10, 0), max(y - 10, 0)
            x2 = min(x + w + 10, img_disp.shape[1])
            y2 = min(y + h + 10, img_disp.shape[0])

            face_img_gray = img_gray[y1:y2, x1:x2]
            if face_img_gray.size == 0:
                continue
            # The classifier accepts PIL images; cvtColor inside detect_image
            # converts the grayscale crop to RGB.
            face_img_pil = Image.fromarray(face_img_gray)
            emotion, probability = classificator.detect_image(face_img_pil)
            emotion_class = map_emotion_to_class(emotion)

            cv2.rectangle(
                img_disp,
                (x1, y1),
                (x2, y2),
                border_color,
                thickness=2
            )
            cv2.putText(
                img_disp, emotion, (x + 30, y - 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, font_color, 1
            )
            # Show probability
            cv2.putText(
                img_disp, str(round(probability, 3)), (x + 30, y - 50),
                cv2.FONT_HERSHEY_SIMPLEX, 0.3, font_color, 1
            )
            # Last face in the list wins when several faces are present.
            current_class = emotion_class

        # Track emotion class changes
        # NOTE(review): st.session_state is mutated here from the webrtc worker
        # thread; Streamlit session state is tied to the script thread —
        # confirm this works with the deployed streamlit/streamlit-webrtc
        # versions.
        if current_class:
            history = st.session_state['emotion_history']
            history.append(current_class)
            if len(history) > 10:
                history.pop(0)
            # Detect rapid changes: more than one distinct class in the last 3.
            if len(history) >= 3 and len(set(history[-3:])) > 1:
                self.rapid_change_count += 1
            else:
                self.rapid_change_count = 0

        return frame.from_ndarray(img_disp, format="bgr24")
158
+
159
# Start the live video stream; each frame is handled by the processor above.
webrtc_streamer(
    key="emotion-detection",
    video_processor_factory=EmotionRecognitionProcessor,
)

# --- Streamlit alert for rapid emotion changes ---
# NOTE(review): this runs once per script rerun, so the warning reflects the
# history as of the last rerun, not live per-frame state.
history = st.session_state['emotion_history']
if len(history) >= 3 and len(set(history[-3:])) > 1:
    st.warning(
        "⚠️ Rapid changes in your detected emotional state were observed. "
        "Micro-expressions may not always reflect your true feelings. "
        "If you feel emotionally unstable or distressed, " \
        "consider reaching out to a mental health professional, "
        "talking it over with a close person or taking a break."
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ #
src/utils/backend/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .tensorflow_backend import * # noqa: F401,F403
src/utils/backend/tensorflow_backend.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow
2
+
3
+
4
# Thin forwarding wrappers so the rest of the code base imports TensorFlow
# ops from one place (eases swapping or patching the backend).
def disable_tensorflow_v2_behavior():
    """ See https://www.tensorflow.org/api_docs/python/tf/compat/v1/disable_tensorflow_v2_behavior .
    """
    tensorflow.compat.v1.disable_v2_behavior()


def ones(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/ones .
    """
    return tensorflow.ones(*args, **kwargs)


def transpose(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/transpose .
    """
    return tensorflow.transpose(*args, **kwargs)


def map_fn(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/map_fn .
    """
    return tensorflow.map_fn(*args, **kwargs)


def pad(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/pad .
    """
    return tensorflow.pad(*args, **kwargs)


def top_k(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/nn/top_k .
    """
    return tensorflow.nn.top_k(*args, **kwargs)


def clip_by_value(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/clip_by_value .
    """
    return tensorflow.clip_by_value(*args, **kwargs)
44
+
45
+
46
def resize_images(images, size, method='bilinear', align_corners=False):
    """ See https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/image/resize_images .

    Args
        method: The method used for interpolation. One of ('bilinear', 'nearest', 'bicubic', 'area').
        align_corners: Forwarded to tf.compat.v1.image.resize_images.

    Raises
        KeyError: if `method` is not one of the four supported names.
    """
    methods = {
        'bilinear': tensorflow.image.ResizeMethod.BILINEAR,
        'nearest' : tensorflow.image.ResizeMethod.NEAREST_NEIGHBOR,
        'bicubic' : tensorflow.image.ResizeMethod.BICUBIC,
        'area'    : tensorflow.image.ResizeMethod.AREA,
    }
    # Bug fix: method and align_corners were previously ignored —
    # NEAREST_NEIGHBOR and False were hard-coded despite the methods table.
    return tensorflow.compat.v1.image.resize_images(images, size, methods[method], align_corners)
59
+
60
+
61
def non_max_suppression(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/image/non_max_suppression .
    """
    return tensorflow.image.non_max_suppression(*args, **kwargs)


# NOTE(review): this wrapper shadows the builtin `range` within this module
# (and in any `import *` consumer) — intentional for the backend API, but be
# careful adding plain Python loops here.
def range(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/range .
    """
    return tensorflow.range(*args, **kwargs)


def scatter_nd(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/scatter_nd .
    """
    return tensorflow.scatter_nd(*args, **kwargs)


def gather_nd(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/gather_nd .
    """
    return tensorflow.gather_nd(*args, **kwargs)


def meshgrid(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/meshgrid .
    """
    return tensorflow.meshgrid(*args, **kwargs)


def where(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/where .
    """
    return tensorflow.where(*args, **kwargs)


def unstack(*args, **kwargs):
    """ See https://www.tensorflow.org/api_docs/python/tf/unstack .
    """
    return tensorflow.unstack(*args, **kwargs)
src/utils/callbacks.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import keras
4
+ import matplotlib
5
+ matplotlib.use('Agg')
6
+ from matplotlib import pyplot as plt
7
+ import scipy.signal
8
+ import tensorflow as tf
9
+
10
+
11
class LossHistory(keras.callbacks.Callback):
    """Record train/val loss after every epoch.

    Losses are appended to per-run text files and a (smoothed) loss curve
    is re-drawn and saved as a PNG under ``log_dir/loss_<timestamp>/``.
    """
    def __init__(self, log_dir):
        import datetime
        curr_time = datetime.datetime.now()
        time_str = datetime.datetime.strftime(curr_time, '%Y_%m_%d_%H_%M_%S')
        self.log_dir = log_dir
        self.time_str = time_str
        # One timestamped sub-directory per training run.
        self.save_path = os.path.join(self.log_dir, "loss_" + str(self.time_str))
        self.losses = []
        self.val_loss = []

        # exist_ok=True: the original os.makedirs(...) raised FileExistsError
        # when two runs started within the same second or the dir already existed.
        os.makedirs(self.save_path, exist_ok=True)

    def on_epoch_end(self, batch, logs=None):
        """Append the epoch's losses to disk and refresh the loss plot.

        ``batch`` is the epoch index (Keras passes it positionally); the
        parameter name is kept for backward compatibility.
        """
        # ``logs=None`` avoids the mutable-default-argument pitfall of ``logs={}``.
        logs = logs or {}
        self.losses.append(logs.get('loss'))
        self.val_loss.append(logs.get('val_loss'))
        with open(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".txt"), 'a') as f:
            f.write(str(logs.get('loss')))
            f.write("\n")
        with open(os.path.join(self.save_path, "epoch_val_loss_" + str(self.time_str) + ".txt"), 'a') as f:
            f.write(str(logs.get('val_loss')))
            f.write("\n")
        self.loss_plot()

    def loss_plot(self):
        """Draw raw and Savitzky-Golay-smoothed loss curves and save the PNG."""
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
        try:
            # Shorter histories get a smaller smoothing window.
            if len(self.losses) < 25:
                num = 5
            else:
                num = 15

            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
        except Exception:
            # Smoothing needs more points than the window size; with too few
            # epochs just plot the raw curves. (Was a bare ``except:``.)
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('A Loss Curve')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".png"))

        plt.cla()
        plt.close("all")
62
+
63
class ExponentDecayScheduler(keras.callbacks.Callback):
    """After each epoch, multiply the optimizer's learning rate by ``decay_rate``."""

    def __init__(self, decay_rate, verbose=0):
        super(ExponentDecayScheduler, self).__init__()
        # Multiplicative factor applied once per epoch.
        self.decay_rate = decay_rate
        self.verbose = verbose
        self.learning_rates = []

    def on_epoch_end(self, batch, logs=None):
        """Decay the LR; tolerate optimizers whose LR is not a backend variable."""
        lr_variable = self.model.optimizer.learning_rate
        try:
            # Works when the LR is a tf.Variable / backend tensor.
            old_lr = tf.keras.backend.get_value(lr_variable)
        except Exception:
            # Fall back to the raw attribute (e.g. a plain Python float).
            old_lr = lr_variable

        updated_lr = old_lr * self.decay_rate
        try:
            tf.keras.backend.set_value(lr_variable, updated_lr)
        except Exception:
            print("Warning: Could not set learning rate dynamically.")

        if self.verbose > 0:
            print('Setting learning rate to %s.' % (updated_lr))
src/utils/dataloader.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from random import shuffle
3
+
4
+ import cv2
5
+ import keras
6
+ import numpy as np
7
+ from keras.utils import to_categorical
8
+ from PIL import Image
9
+
10
+ from .utils import cvtColor, preprocess_input
11
+
12
+
13
class ClsDatasets(keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of (image, one-hot label) pairs.

    Each annotation line looks like ``"<class_id>;<image_path>"`` — the label
    precedes ``';'`` and the path is the first whitespace token after it.
    """
    def __init__(self, annotation_lines, input_shape, batch_size, num_classes, train, **kwargs):
        super().__init__()
        self.annotation_lines = annotation_lines
        self.length = len(self.annotation_lines)

        # (h, w) target size passed to get_random_data.
        self.input_shape = input_shape
        self.batch_size = batch_size
        self.num_classes = num_classes
        # When True: random augmentation per sample and reshuffle each epoch.
        self.train = train

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(math.ceil(self.length / float(self.batch_size)))

    def __getitem__(self, index):
        """Load, optionally augment, and normalise the ``index``-th batch."""
        X_train = []
        Y_train = []
        start = index * self.batch_size
        end = min((index + 1) * self.batch_size, self.length)
        for i in range(start, end):
            # Path is the first token after ';' in the annotation line.
            annotation_path = self.annotation_lines[i].split(';')[1].split()[0]
            image = Image.open(annotation_path)
            image = self.get_random_data(image, self.input_shape, random=self.train)
            # Scale pixels to [-1, 1] (see utils.preprocess_input).
            image = preprocess_input(np.array(image).astype(np.float32))

            X_train.append(image)
            Y_train.append(int(self.annotation_lines[i].split(';')[0]))

        X_train = np.array(X_train)
        Y_train = to_categorical(np.array(Y_train), num_classes = self.num_classes)
        return X_train, Y_train

    def on_epoch_end(self):
        # Reshuffle sample order between epochs, but only while training.
        if self.train:
            np.random.shuffle(self.annotation_lines)

    def rand(self, a=0, b=1):
        # Uniform random float in [a, b).
        return np.random.rand()*(b-a) + a

    def get_random_data(self, image, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        """Resize (and, when ``random`` is True, augment) ``image``.

        Returns a float32 HxWx3 array of shape ``input_shape``. With
        ``random=False`` the image is letterboxed onto a grey canvas;
        otherwise it is jittered, randomly placed, flipped, rotated and
        HSV-distorted.
        """
        #------------------------------#
        #   Read the image and convert it to RGB
        #------------------------------#
        image = cvtColor(image)
        #------------------------------#
        #   Source size and target size
        #------------------------------#
        iw, ih = image.size
        h, w = input_shape

        if not random:
            scale = min(w/iw, h/ih)
            nw = int(iw*scale)
            nh = int(ih*scale)
            dx = (w-nw)//2
            dy = (h-nh)//2

            #---------------------------------#
            #   Pad the leftover area with grey bars
            #---------------------------------#
            image = image.resize((nw,nh), Image.BICUBIC)
            new_image = Image.new('RGB', (w,h), (128,128,128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image, np.float32)

            return image_data

        #------------------------------------------#
        #   Random scale and aspect-ratio jitter
        #------------------------------------------#
        new_ar = w/h * self.rand(1-jitter,1+jitter)/self.rand(1-jitter,1+jitter)
        scale = self.rand(.75, 1.25)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)

        #------------------------------------------#
        #   Paste onto a grey canvas at a random offset
        #------------------------------------------#
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        new_image = Image.new('RGB', (w,h), (128,128,128))
        new_image.paste(image, (dx, dy))
        image = new_image

        #------------------------------------------#
        #   Random horizontal flip
        #------------------------------------------#
        flip = self.rand()<.5
        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # Random rotation in [-15, 15) degrees about the image centre,
        # filling exposed corners with the same grey as the canvas.
        rotate = self.rand()<.5
        if rotate:
            angle = np.random.randint(-15,15)
            a,b = w/2,h/2
            M = cv2.getRotationMatrix2D((a,b),angle,1)
            image = cv2.warpAffine(np.array(image), M, (w,h), borderValue=[128,128,128])

        #------------------------------------------#
        #   HSV colour distortion
        #------------------------------------------#
        # NOTE(review): ``hue`` is sampled here but never applied to the H
        # channel below — only saturation and value are actually distorted.
        # Confirm whether the hue shift was dropped intentionally.
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat)
        val = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val)
        x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV)
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x[:,:, 0]>360, 0] = 360
        x[:, :, 1:][x[:, :, 1:]>1] = 1
        x[x<0] = 0
        image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
        return image_data
src/utils/utils.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+
4
#---------------------------------------------------#
#   Distortion-free resize (letterboxing)
#---------------------------------------------------#
def letterbox_image(image, size):
    """Resize ``image`` to fit inside ``size`` = (w, h) without distortion,
    padding the leftover area with grey (128, 128, 128)."""
    src_w, src_h = image.size
    dst_w, dst_h = size

    # Scale by the tighter dimension so the whole image fits.
    ratio = min(dst_w / src_w, dst_h / src_h)
    new_w = int(src_w * ratio)
    new_h = int(src_h * ratio)

    resized = image.resize((new_w, new_h), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128,128,128))
    # Centre the resized image on the grey canvas.
    canvas.paste(resized, ((dst_w - new_w) // 2, (dst_h - new_h) // 2))

    return canvas
20
+
21
#---------------------------------------------------#
#   Load the class list
#---------------------------------------------------#
def get_classes(classes_path):
    """Read class names (one per line) from ``classes_path``.

    Returns the list of whitespace-stripped names and its length.
    """
    with open(classes_path, encoding='utf-8') as f:
        names = [line.strip() for line in f]
    return names, len(names)
29
+
30
#---------------------------------------------------------#
#   Convert the image to RGB so grayscale inputs do not
#   break prediction; only RGB images are supported, all
#   other kinds are converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    """Return ``image`` as RGB.

    Inputs that already have three channels are returned untouched;
    anything else is converted via ``Image.convert('RGB')``.
    """
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
40
+
41
#----------------------------------------#
#   Preprocess training images: map pixel
#   values from [0, 255] to [-1, 1]
#----------------------------------------#
def preprocess_input(x):
    # Augmented assignments mutate numpy arrays in place, so the caller's
    # array is normalised as well — do not replace with `x / 127.5 - 1.`.
    x /= 127.5
    x -= 1.
    return x