Commit ·
9825f94
1
Parent(s): 49bdc4b
Add all file
Browse files
Commit message: "Maybe keras will mess up"
- .gitignore +1 -0
- Dockerfile +2 -2
- requirements.txt +7 -3
- src/classification.py +120 -0
- src/model_data/cls_classes.txt +7 -0
- src/model_data/haarcascade_frontalface_alt.xml +0 -0
- src/model_data/mobilenet_2_5_224_tf_no_top.h5 +3 -0
- src/model_data/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5 +3 -0
- src/nets/Loss.py +113 -0
- src/nets/__init__.py +16 -0
- src/nets/mobilenet.py +105 -0
- src/nets/resnet50.py +118 -0
- src/nets/vgg16.py +98 -0
- src/streamlit_app.py +170 -37
- src/utils/__init__.py +1 -0
- src/utils/backend/__init__.py +1 -0
- src/utils/backend/tensorflow_backend.py +100 -0
- src/utils/callbacks.py +86 -0
- src/utils/dataloader.py +128 -0
- src/utils/utils.py +47 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
**/__pycache__/
|
Dockerfile
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
FROM python:3.
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
|
@@ -18,4 +18,4 @@ EXPOSE 8501
|
|
| 18 |
|
| 19 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 20 |
|
| 21 |
-
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
|
|
|
| 1 |
+
FROM python:3.12-slim
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
|
|
|
| 18 |
|
| 19 |
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 20 |
|
| 21 |
+
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
requirements.txt
CHANGED
|
@@ -1,3 +1,7 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
streamlit
|
| 2 |
+
streamlit-webrtc
|
| 3 |
+
numpy
|
| 4 |
+
tensorflow-gpu
|
| 5 |
+
keras
|
| 6 |
+
Pillow
|
| 7 |
+
opencv-python
|
src/classification.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from nets import get_model_from_name
|
| 7 |
+
from utils.utils import (cvtColor, get_classes, letterbox_image,
|
| 8 |
+
preprocess_input)
|
| 9 |
+
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
|
| 12 |
+
#--------------------------------------------#
|
| 13 |
+
# 使用自己训练好的模型预测需要修改4个参数
|
| 14 |
+
# model_path和classes_path、backbone
|
| 15 |
+
# 和alpha都需要修改!
|
| 16 |
+
#--------------------------------------------#
|
| 17 |
+
class Classification(object):
    """Image classifier wrapping the backbone networks in ``nets``.

    To predict with your own trained model you must update ``model_path``,
    ``classes_path`` and ``backbone`` (and ``alpha`` when the backbone is
    'mobilenet').  A shape mismatch on load usually means ``model_path`` and
    ``classes_path`` do not match the training configuration.
    """

    _defaults = {
        # NOTE(review): hf_hub_download() executes at import time and
        # "username/model-name" looks like a placeholder repo id — confirm
        # and point it at the real Hugging Face repository.
        # "model_path" : 'model_data/mobilenet_2_5_224_tf_no_top.h5',
        "model_path": hf_hub_download(repo_id="username/model-name", filename="model.h5"),
        "classes_path": 'model_data/cls_classes.txt',
        # Input image size fed to the network (height, width).
        "input_shape": [224, 224],
        # Backbone type: one of 'mobilenet', 'resnet50', 'vgg16'.
        "backbone": 'vgg16',
        # MobileNet width multiplier; only used when backbone == 'mobilenet'.
        "alpha": 0.25
    }

    @classmethod
    def get_defaults(cls, n):
        """Return the default value for attribute *n*.

        For an unknown name an error-message string is returned (kept for
        backward compatibility — callers do not expect an exception here).
        """
        if n in cls._defaults:
            return cls._defaults[n]
        return "Unrecognized attribute name '" + n + "'"

    def __init__(self, **kwargs):
        """Apply defaults, override them with **kwargs, then load classes and model."""
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)

        # Class labels come from the text file, one label per line.
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.generate()

    def generate(self):
        """Build the selected backbone and load its .h5 weights."""
        model_path = os.path.expanduser(self.model_path)
        assert model_path.endswith('.h5'), 'Keras model or weights must be a .h5 file.'

        # Only MobileNet takes the alpha (width multiplier) argument.
        if self.backbone == "mobilenet":
            self.model = get_model_from_name[self.backbone](
                input_shape=[self.input_shape[0], self.input_shape[1], 3],
                classes=self.num_classes, alpha=self.alpha)
        else:
            self.model = get_model_from_name[self.backbone](
                input_shape=[self.input_shape[0], self.input_shape[1], 3],
                classes=self.num_classes)
        # BUG FIX: load from the tilde-expanded path computed above, not the
        # raw self.model_path (the two differ for '~'-style paths).
        self.model.load_weights(model_path)
        print('{} model, and classes {} loaded.'.format(model_path, self.class_names))

    def detect_image(self, image):
        """Classify a PIL image.

        :param image: input PIL image (any mode; converted to RGB internally)
        :return: tuple ``(class_name, probability)`` for the top prediction
        """
        # Convert to RGB so grayscale/RGBA inputs do not break prediction;
        # only RGB images are supported downstream.
        image = cvtColor(image)
        # Aspect-ratio-preserving (letterboxed) resize to the network input size.
        image_data = letterbox_image(image, [self.input_shape[1], self.input_shape[0]])
        # Normalise and add the leading batch dimension.
        image_data = np.expand_dims(preprocess_input(np.array(image_data, np.float32)), 0)

        # Run the network; take the first (only) batch element.
        preds = self.model.predict(image_data)[0]
        class_name = self.class_names[np.argmax(preds)]
        probability = np.max(preds)
        return class_name, probability
|
src/model_data/cls_classes.txt
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
disgust
|
| 2 |
+
fear
|
| 3 |
+
happiness
|
| 4 |
+
others
|
| 5 |
+
repression
|
| 6 |
+
sadness
|
| 7 |
+
surprise
|
src/model_data/haarcascade_frontalface_alt.xml
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/model_data/mobilenet_2_5_224_tf_no_top.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dbdb03ee2a22fd895301636cd328b234bb3a9952358f436d82f46b81e0d5b0bf
|
| 3 |
+
size 2108140
|
src/model_data/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bfe5187d0a272bed55ba430631598124cff8e880b98d38c9e56c8d66032abdc1
|
| 3 |
+
size 58889256
|
src/nets/Loss.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from keras import backend as K
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
def multi_category_focal_loss2(gamma=2., alpha=1):
    """Focal loss for multi-category / multi-label problems.

    ``alpha`` weights the positive class (y_true == 1); negatives are
    weighted ``1 - alpha``.  Increase ``alpha`` when the model is too
    conservative (tends to predict 0 everywhere), decrease it when the
    model is too eager (tends to predict 1 everywhere).

    Usage:
        model.compile(loss=[multi_category_focal_loss2(alpha=0.25, gamma=2)],
                      metrics=["accuracy"], optimizer=adam)
    """
    epsilon = 1.e-7
    gamma = float(gamma)
    alpha = tf.constant(alpha, dtype=tf.float32)

    def multi_category_focal_loss2_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        # Clip predictions away from 0/1 to keep log() finite.
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)

        alpha_t = y_true * alpha + (tf.ones_like(y_true) - y_true) * (1 - alpha)
        # p_t: probability assigned to the true outcome (1 or 0).
        y_t = tf.multiply(y_true, y_pred) + tf.multiply(1 - y_true, 1 - y_pred)
        # BUG FIX: tf.log was removed in TensorFlow 2.x; use tf.math.log.
        ce = -tf.math.log(y_t)
        # (1 - p_t)^gamma down-weights easy examples.
        weight = tf.pow(tf.subtract(1., y_t), gamma)
        fl = tf.multiply(tf.multiply(weight, ce), alpha_t)
        loss = tf.reduce_mean(fl)
        return loss

    return multi_category_focal_loss2_fixed
|
| 36 |
+
|
| 37 |
+
def multi_category_focal_loss1(alpha, gamma=2.0):
    """Focal loss for multi-category / multi-label problems.

    ``alpha`` is a per-class weight vector whose length must equal the
    number of classes; use it when the dataset is skewed between classes.

    Usage:
        model.compile(loss=[multi_category_focal_loss1(alpha=[1,2,3,2], gamma=2)],
                      metrics=["accuracy"], optimizer=adam)
    """
    epsilon = 1.e-7
    alpha = tf.constant(alpha, dtype=tf.float32)
    gamma = float(gamma)

    def multi_category_focal_loss1_fixed(y_true, y_pred):
        y_true = tf.cast(y_true, tf.float32)
        # Clip predictions away from 0/1 to keep log() finite.
        y_pred = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
        # p_t: probability assigned to the true outcome (1 or 0).
        y_t = tf.multiply(y_true, y_pred) + tf.multiply(1 - y_true, 1 - y_pred)
        # BUG FIX: tf.log was removed in TensorFlow 2.x; use tf.math.log.
        ce = -tf.math.log(y_t)
        weight = tf.pow(tf.subtract(1., y_t), gamma)
        # matmul applies the per-class alpha weights.
        fl = tf.matmul(tf.multiply(weight, ce), alpha)
        loss = tf.reduce_mean(fl)
        return loss

    return multi_category_focal_loss1_fixed
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def Cross_entropy_loss(y_true, y_pred):
    """Per-example categorical cross entropy.

    :param y_true: one-hot encoding, shape [batch_size, num_classes]
    :param y_pred: predicted class probabilities, shape [batch_size, num_classes]
    :return: shape [batch_size] — the cross entropy for each example
    """
    # Clip predictions away from 0/1 to keep log() finite.
    y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())
    # BUG FIX: tf.log was removed in TensorFlow 2.x; use tf.math.log.
    crossEntropyLoss = -y_true * tf.math.log(y_pred)

    return tf.reduce_sum(crossEntropyLoss, -1)
|
| 73 |
+
|
| 74 |
+
# focal loss with multi label
|
| 75 |
+
# focal loss with multi label
def focal_loss(classes_num, gamma=2., alpha=.25, e=0.1):
    """Class-balanced focal loss with label smoothing.

    :param classes_num: list with the number of training samples per class;
        the per-class balancing weights are derived from it
    :param gamma: focusing parameter from the focal-loss paper
    :param alpha: not used by this variant (weights come from classes_num);
        kept in the signature for API compatibility
    :param e: label-smoothing mixing factor against the uniform distribution
    """
    def focal_loss_fixed(target_tensor, prediction_tensor):
        '''
        prediction_tensor is the output tensor with shape [None, num_classes]
        target_tensor is the one-hot label tensor, same shape as prediction_tensor
        '''
        import tensorflow as tf
        from tensorflow.python.ops import array_ops
        from keras import backend as K

        # 1. Unbalanced focal term, function (4) in the focal-loss paper.
        zeros = array_ops.zeros_like(prediction_tensor, dtype=prediction_tensor.dtype)
        one_minus_p = array_ops.where(tf.greater(target_tensor, zeros),
                                      target_tensor - prediction_tensor, zeros)
        # BUG FIX: tf.log was removed in TensorFlow 2.x; use tf.math.log.
        FT = -1 * (one_minus_p ** gamma) * tf.math.log(
            tf.clip_by_value(prediction_tensor, 1e-8, 1.0))

        # 2. Balancing weights: inverse class frequency, scaled to sum to 1.
        classes_weight = array_ops.zeros_like(prediction_tensor, dtype=prediction_tensor.dtype)

        total_num = float(sum(classes_num))
        classes_w_t1 = [total_num / ff for ff in classes_num]
        sum_ = sum(classes_w_t1)
        classes_w_t2 = [ff / sum_ for ff in classes_w_t1]  # scale
        classes_w_tensor = tf.convert_to_tensor(classes_w_t2, dtype=prediction_tensor.dtype)
        classes_weight += classes_w_tensor

        # Weight only the positions of the true classes.
        balance_weights = array_ops.where(tf.greater(target_tensor, zeros), classes_weight, zeros)

        # 3. Balanced focal loss.
        balanced_fl = balance_weights * FT
        balanced_fl = tf.reduce_mean(balanced_fl)

        # 4. Label smoothing against the uniform distribution to reduce
        #    overfitting (reference: https://spaces.ac.cn/archives/4493).
        nb_classes = len(classes_num)
        # BUG FIX: corrected the 'fianal_loss' typo.
        final_loss = (1 - e) * balanced_fl + e * K.categorical_crossentropy(
            K.ones_like(prediction_tensor) / nb_classes, prediction_tensor)

        return final_loss
    return focal_loss_fixed
|
src/nets/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mobilenet import MobileNet
|
| 2 |
+
from .resnet50 import ResNet50
|
| 3 |
+
from .vgg16 import VGG16
|
| 4 |
+
|
| 5 |
+
# Registry mapping each supported backbone name to its constructor.
get_model_from_name = {
    "mobilenet": MobileNet,
    "resnet50": ResNet50,
    "vgg16": VGG16,
}

# How many leading layers to freeze per backbone during transfer learning.
freeze_layers = {
    "mobilenet": 81,
    "resnet50": 173,
    "vgg16": 19,
    "cspdarknet53": 60,
}
|
src/nets/mobilenet.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras import backend as K
|
| 2 |
+
from keras.layers import (Activation, BatchNormalization, Conv2D,
|
| 3 |
+
DepthwiseConv2D, Dropout, GlobalAveragePooling2D,
|
| 4 |
+
Input, Reshape)
|
| 5 |
+
from keras.models import Model
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
    """Initial standard convolution: Conv2D -> BatchNorm -> ReLU6.

    ``alpha`` is the MobileNet width multiplier applied to ``filters``.
    """
    scaled_filters = int(filters * alpha)
    conv = Conv2D(
        scaled_filters,
        kernel,
        padding='same',
        use_bias=False,
        strides=strides,
        name='conv1',
    )(inputs)
    normed = BatchNormalization(name='conv1_bn')(conv)
    return Activation(relu6, name='conv1_relu')(normed)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _depthwise_conv_block(inputs, pointwise_conv_filters, alpha,
                          depth_multiplier=1, strides=(1, 1), block_id=1):
    """Depthwise-separable convolution block.

    A 3x3 depthwise convolution followed by a 1x1 pointwise convolution,
    each with BatchNorm + ReLU6.  ``alpha`` scales the pointwise filters.
    """
    out_channels = int(pointwise_conv_filters * alpha)

    # Depthwise 3x3 stage.
    dw = DepthwiseConv2D((3, 3),
                         padding='same',
                         depth_multiplier=depth_multiplier,
                         strides=strides,
                         use_bias=False,
                         name='conv_dw_%d' % block_id)(inputs)
    dw = BatchNormalization(name='conv_dw_%d_bn' % block_id)(dw)
    dw = Activation(relu6, name='conv_dw_%d_relu' % block_id)(dw)

    # Pointwise 1x1 stage mixes channels.
    pw = Conv2D(out_channels, (1, 1),
                padding='same',
                use_bias=False,
                strides=(1, 1),
                name='conv_pw_%d' % block_id)(dw)
    pw = BatchNormalization(name='conv_pw_%d_bn' % block_id)(pw)
    return Activation(relu6, name='conv_pw_%d_relu' % block_id)(pw)
|
| 41 |
+
|
| 42 |
+
def MobileNet(input_shape=None,
              alpha=1.0,
              depth_multiplier=1,
              dropout=1e-3,
              classes=1000):
    """Build a MobileNet v1 classifier.

    :param input_shape: input tensor shape, e.g. (224, 224, 3)
    :param alpha: width multiplier for every conv block
    :param depth_multiplier: depthwise channel multiplier
    :param dropout: dropout rate before the classifier head
    :param classes: number of output classes
    :return: an uncompiled keras ``Model``
    """
    img_input = Input(shape=input_shape)

    # 224,224,3 -> 112,112,32
    x = _conv_block(img_input, 32, alpha, strides=(2, 2))
    # 112,112,32 -> 112,112,64
    x = _depthwise_conv_block(x, 64, alpha, depth_multiplier, block_id=1)

    # 112,112,64 -> 56,56,128
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, strides=(2, 2), block_id=2)
    x = _depthwise_conv_block(x, 128, alpha, depth_multiplier, block_id=3)

    # 56,56,128 -> 28,28,256
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, strides=(2, 2), block_id=4)
    x = _depthwise_conv_block(x, 256, alpha, depth_multiplier, block_id=5)

    # 28,28,256 -> 14,14,512 (strided block, then five identical 512 blocks)
    x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, strides=(2, 2), block_id=6)
    for block_id in range(7, 12):
        x = _depthwise_conv_block(x, 512, alpha, depth_multiplier, block_id=block_id)

    # 14,14,512 -> 7,7,1024
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, strides=(2, 2), block_id=12)
    x = _depthwise_conv_block(x, 1024, alpha, depth_multiplier, block_id=13)

    # 7,7,1024 -> 1,1,1024: global average pool, then restore spatial dims
    # so the head can be expressed as a 1x1 convolution.
    x = GlobalAveragePooling2D()(x)
    x = Reshape((1, 1, int(1024 * alpha)), name='reshape_1')(x)
    x = Dropout(dropout, name='dropout')(x)

    # Classifier head: 1x1 conv acts as a fully connected layer.
    x = Conv2D(classes, (1, 1), padding='same', name='conv_preds')(x)
    x = Activation('softmax', name='act_softmax')(x)
    x = Reshape((classes,), name='reshape_2')(x)

    return Model(img_input, x, name='mobilenet_%0.2f' % (alpha))
|
| 99 |
+
|
| 100 |
+
def relu6(x):
    """ReLU activation capped at 6, as used throughout MobileNet."""
    return K.relu(x, max_value=6)


if __name__ == '__main__':
    # Quick smoke test: build the network and print its layer summary.
    net = MobileNet(input_shape=(224, 224, 3))
    net.summary()
|
src/nets/resnet50.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras import layers
|
| 2 |
+
from keras.layers import (Activation, AveragePooling2D, BatchNormalization,
|
| 3 |
+
Conv2D, Dense, Flatten, Input, MaxPooling2D,
|
| 4 |
+
ZeroPadding2D)
|
| 5 |
+
from keras.models import Model
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def identity_block(input_tensor, kernel_size, filters, stage, block):
    """ResNet identity block: bottleneck convs with a pass-through shortcut.

    Input and output shapes are identical, so the residual is added
    straight from ``input_tensor``.
    """
    filters1, filters2, filters3 = filters

    conv_name = 'res' + str(stage) + block + '_branch'
    bn_name = 'bn' + str(stage) + block + '_branch'

    # 1x1 conv: reduce channels.
    x = Conv2D(filters1, (1, 1), name=conv_name + '2a')(input_tensor)
    x = BatchNormalization(name=bn_name + '2a')(x)
    x = Activation('relu')(x)

    # 3x3 conv.
    x = Conv2D(filters2, kernel_size, padding='same', name=conv_name + '2b')(x)
    x = BatchNormalization(name=bn_name + '2b')(x)
    x = Activation('relu')(x)

    # 1x1 conv: restore channels.
    x = Conv2D(filters3, (1, 1), name=conv_name + '2c')(x)
    x = BatchNormalization(name=bn_name + '2c')(x)

    # Residual connection from the unmodified input.
    x = layers.add([x, input_tensor])
    return Activation('relu')(x)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    """ResNet conv block: bottleneck convs with a projection shortcut.

    Used where the output shape/channel count differs from the input, so
    the shortcut branch needs its own 1x1 convolution.
    """
    filters1, filters2, filters3 = filters

    conv_name = 'res' + str(stage) + block + '_branch'
    bn_name = 'bn' + str(stage) + block + '_branch'

    # 1x1 conv: reduce channels (and downsample via strides).
    x = Conv2D(filters1, (1, 1), strides=strides, name=conv_name + '2a')(input_tensor)
    x = BatchNormalization(name=bn_name + '2a')(x)
    x = Activation('relu')(x)

    # 3x3 conv.
    x = Conv2D(filters2, kernel_size, padding='same', name=conv_name + '2b')(x)
    x = BatchNormalization(name=bn_name + '2b')(x)
    x = Activation('relu')(x)

    # 1x1 conv: expand channels.
    x = Conv2D(filters3, (1, 1), name=conv_name + '2c')(x)
    x = BatchNormalization(name=bn_name + '2c')(x)

    # Projection shortcut so the residual shapes match.
    shortcut = Conv2D(filters3, (1, 1), strides=strides,
                      name=conv_name + '1')(input_tensor)
    shortcut = BatchNormalization(name=bn_name + '1')(shortcut)

    x = layers.add([x, shortcut])
    return Activation('relu')(x)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def ResNet50(input_shape=[224, 224, 3], classes=1000):
    """Build a ResNet-50 classifier.

    :param input_shape: input tensor shape, default (224, 224, 3)
    :param classes: number of output classes
    :return: an uncompiled keras ``Model``
    """
    img_input = Input(shape=input_shape)

    # Stem: 224,224,3 -> 112,112,64 -> 56,56,64
    x = ZeroPadding2D((3, 3))(img_input)
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1')(x)
    x = BatchNormalization(name='bn_conv1')(x)
    x = Activation('relu')(x)
    x = ZeroPadding2D((1, 1))(x)
    x = MaxPooling2D((3, 3), strides=(2, 2))(x)

    # Stage 2: 56,56,64 -> 56,56,256
    x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
    for blk in 'bc':
        x = identity_block(x, 3, [64, 64, 256], stage=2, block=blk)

    # Stage 3: 56,56,256 -> 28,28,512
    x = conv_block(x, 3, [128, 128, 512], stage=3, block='a')
    for blk in 'bcd':
        x = identity_block(x, 3, [128, 128, 512], stage=3, block=blk)

    # Stage 4: 28,28,512 -> 14,14,1024
    x = conv_block(x, 3, [256, 256, 1024], stage=4, block='a')
    for blk in 'bcdef':
        x = identity_block(x, 3, [256, 256, 1024], stage=4, block=blk)

    # Stage 5: 14,14,1024 -> 7,7,2048
    x = conv_block(x, 3, [512, 512, 2048], stage=5, block='a')
    for blk in 'bc':
        x = identity_block(x, 3, [512, 512, 2048], stage=5, block=blk)

    # Head: 7x7 average pool -> flatten (2048) -> softmax classifier.
    x = AveragePooling2D((7, 7), name='avg_pool')(x)
    x = Flatten()(x)
    x = Dense(classes, activation='softmax', name='fc1000')(x)

    return Model(img_input, x, name='resnet50')


if __name__ == '__main__':
    # Quick smoke test: build the network and print its layer summary.
    net = ResNet50()
    net.summary()
|
src/nets/vgg16.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.layers import Conv2D, Dense, Flatten, Input, MaxPooling2D
|
| 2 |
+
from keras.models import Model #导入包Conv2D是卷积核 Flatten是展开 Input输入 MaxPooling2D最大卷积核
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def VGG16(input_shape=None, classes=1000):
    """Build a VGG16 classifier.

    The feature extractor is five conv blocks (2, 2, 3, 3, 3 convolutions
    with 64/128/256/512/512 filters, each conv 3x3 'same' + ReLU), each
    block ending in a 2x2 stride-2 max pool: 224x224 -> 7x7 spatially.
    The head is two 4096-unit dense layers and a softmax output.

    :param input_shape: input tensor shape, e.g. (224, 224, 3)
    :param classes: number of output classes
    :return: an uncompiled keras ``Model``
    """
    img_input = Input(shape=input_shape)
    x = img_input

    # (block number, conv count, filter count) — produces the exact
    # original layer names block{i}_conv{j} / block{i}_pool.
    block_config = [(1, 2, 64), (2, 2, 128), (3, 3, 256), (4, 3, 512), (5, 3, 512)]
    for block_num, n_convs, n_filters in block_config:
        for conv_num in range(1, n_convs + 1):
            x = Conv2D(n_filters, (3, 3),
                       activation='relu',
                       padding='same',
                       name='block%d_conv%d' % (block_num, conv_num))(x)
        x = MaxPooling2D((2, 2), strides=(2, 2),
                         name='block%d_pool' % block_num)(x)

    # Classifier head.
    x = Flatten(name='flatten')(x)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    x = Dense(classes, activation='softmax', name='predictions')(x)

    return Model(img_input, x, name='vgg16')


if __name__ == '__main__':
    # Quick smoke test: build the network and print its layer summary.
    net = VGG16(input_shape=(224, 224, 3))
    net.summary()
|
src/streamlit_app.py
CHANGED
|
@@ -1,40 +1,173 @@
|
|
| 1 |
-
import
|
|
|
|
|
|
|
|
|
|
| 2 |
import numpy as np
|
| 3 |
-
|
|
|
|
| 4 |
import streamlit as st
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import subprocess
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
import streamlit as st
|
| 9 |
+
from streamlit_webrtc import VideoProcessorBase, webrtc_streamer
|
| 10 |
+
|
| 11 |
+
from classification import Classification
|
| 12 |
+
|
| 13 |
+
classificator = Classification()
|
| 14 |
+
face_cascade = cv2.CascadeClassifier(
|
| 15 |
+
os.path.join('model_data', 'haarcascade_frontalface_alt.xml')
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
# Streamlit Title
|
| 19 |
+
st.title("Real-Time Micro-Emotion Recognition")
|
| 20 |
+
|
| 21 |
+
# Only Live Emotion Detection Mode
|
| 22 |
+
st.write("Turn on your camera and detect emotions in real-time.")
|
| 23 |
+
|
| 24 |
+
# Camera selection UI
|
| 25 |
+
st.sidebar.header("Camera Settings")
|
| 26 |
+
def get_connected_cameras():
    """Probe Linux video devices via ``v4l2-ctl --list-devices``.

    Returns:
        list[int]: V4L2 device indices (e.g. ``0`` for ``/dev/video0``) of
        entries whose description mentions "camera".  Falls back to ``[0]``
        when ``v4l2-ctl`` is missing or exits non-zero.
    """
    try:
        result = subprocess.run(
            ['v4l2-ctl', '--list-devices'],
            capture_output=True,
            text=True,
            check=True)
    except FileNotFoundError:
        return [0]  # v4l2-ctl not installed; assume the default camera
    except subprocess.CalledProcessError:
        return [0]  # v4l2-ctl failed; assume the default camera

    camera_indices = []
    # v4l2-ctl output: a device-name line followed by indented /dev/video*
    # paths, device blocks separated by blank lines.
    for device in result.stdout.split('\n\n'):
        if 'camera' not in device.lower():
            continue
        lines = device.split('\n')
        if len(lines) < 2:
            continue
        device_path = lines[1].strip()
        # Extract the numeric suffix of e.g. "/dev/video2".
        # (The old ``int(index_str[4:])`` slicing could never parse a
        # "/dev/videoN" path and silently dropped every camera via the
        # ValueError handler.)
        _, sep, suffix = device_path.rpartition('video')
        if sep and suffix.isdigit():
            camera_indices.append(int(suffix))
    return camera_indices
|
| 51 |
+
|
| 52 |
+
available_cameras = get_connected_cameras()
|
| 53 |
+
|
| 54 |
+
if len(available_cameras) > 1:
|
| 55 |
+
camera_index = st.sidebar.selectbox(
|
| 56 |
+
"Select Camera Index",
|
| 57 |
+
options=available_cameras,
|
| 58 |
+
index=0,
|
| 59 |
+
format_func=lambda x: f"Camera {x}"
|
| 60 |
+
)
|
| 61 |
+
else:
|
| 62 |
+
camera_index = 0
|
| 63 |
+
st.sidebar.write("Only one camera detected. Using default camera.")
|
| 64 |
+
|
| 65 |
+
# --- Face detection and augmentation functions ---
|
| 66 |
+
def face_detect(img):
    """Run the Haar cascade face detector on a BGR frame.

    Returns the untouched frame, its grayscale conversion, and the
    detected face rectangles as ``(x, y, w, h)`` entries.
    """
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    detections = face_cascade.detectMultiScale(
        grayscale, scaleFactor=1.1, minNeighbors=1, minSize=(30, 30))
    return img, grayscale, detections
|
| 75 |
+
|
| 76 |
+
# --- Emotion class mapping ---
|
| 77 |
+
def map_emotion_to_class(emotion):
    """Collapse a raw emotion label into one of four coarse classes.

    Matching is case-insensitive and substring based; anything that does
    not hit the Positive/Negative/Surprise keyword sets maps to 'Others'.
    """
    label = emotion.lower()
    buckets = (
        ('Positive', ('happiness', 'happy')),
        ('Negative', ('disgust', 'sadness', 'fear', 'sad', 'angry', 'disgusted')),
        ('Surprise', ('surprise',)),
    )
    for coarse_class, keywords in buckets:
        if any(keyword in label for keyword in keywords):
            return coarse_class
    return 'Others'
|
| 91 |
+
|
| 92 |
+
# --- Streamlit session state for emotion tracking ---
|
| 93 |
+
if 'emotion_history' not in st.session_state:
|
| 94 |
+
st.session_state['emotion_history'] = []
|
| 95 |
+
|
| 96 |
+
# Video Processing Class
|
| 97 |
+
class EmotionRecognitionProcessor(VideoProcessorBase):
    """streamlit-webrtc video processor: detects faces per frame, classifies
    each face's emotion, and draws the annotations onto the frame."""

    def __init__(self):
        # Last coarse emotion class seen (currently written nowhere; kept for API stability).
        self.last_class = None
        # Consecutive frames in which the recent class history was unstable.
        self.rapid_change_count = 0

    def recv(self, frame):
        """Process one incoming video frame and return the annotated frame.

        Steps: convert to a BGR ndarray, detect faces, classify each face
        crop, draw rectangle/label/probability, and record the coarse
        emotion class in the shared history.
        """
        border_color = (255, 0, 0)  # Rectangle color (blue in BGR)
        font_color = (0, 0, 255)  # Text color (red in BGR)
        img = frame.to_ndarray(format="bgr24")
        img_disp, img_gray, faces = face_detect(img)
        current_class = None

        if len(faces) == 0:
            # No face found: annotate the frame instead of classifying.
            cv2.putText(
                img_disp, 'No Face Detect.', (2, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 1
            )

        for (x, y, w, h) in faces:
            # Expand the detection box by 10px on each side, clamped to the frame.
            x1, y1 = max(x - 10, 0), max(y - 10, 0)
            x2 = min(x + w + 10, img_disp.shape[1])
            y2 = min(y + h + 10, img_disp.shape[0])

            face_img_gray = img_gray[y1:y2, x1:x2]
            if face_img_gray.size == 0:
                continue  # degenerate crop; skip classification
            face_img_pil = Image.fromarray(face_img_gray)
            emotion, probability = classificator.detect_image(face_img_pil)
            emotion_class = map_emotion_to_class(emotion)

            cv2.rectangle(
                img_disp,
                (x1, y1),
                (x2, y2),
                border_color,
                thickness=2
            )
            cv2.putText(
                img_disp, emotion, (x + 30, y - 30),
                cv2.FONT_HERSHEY_SIMPLEX, 1, font_color, 1
            )
            # Show probability
            cv2.putText(
                img_disp, str(round(probability, 3)), (x + 30, y - 50),
                cv2.FONT_HERSHEY_SIMPLEX, 0.3, font_color, 1
            )
            # If multiple faces are present, the last one wins.
            current_class = emotion_class

        # Track emotion class changes
        # NOTE(review): recv runs on a background media thread; accessing
        # st.session_state from here may not be supported by all
        # streamlit-webrtc versions — verify this actually updates the UI.
        if current_class:
            history = st.session_state['emotion_history']
            history.append(current_class)
            if len(history) > 10:
                history.pop(0)  # keep a rolling window of the last 10 classes
            # Detect rapid changes
            if len(history) >= 3 and len(set(history[-3:])) > 1:
                self.rapid_change_count += 1
            else:
                self.rapid_change_count = 0

        return frame.from_ndarray(img_disp, format="bgr24")
|
| 158 |
+
|
| 159 |
+
# Start the WebRTC stream; each incoming camera frame is handled by
# EmotionRecognitionProcessor.recv.
webrtc_streamer(
    key="emotion-detection",
    video_processor_factory=EmotionRecognitionProcessor,
)

# --- Streamlit alert for rapid emotion changes ---
# NOTE(review): this check runs only on Streamlit script reruns while the
# history is appended from the video callback, so the warning can lag the
# live feed — confirm the intended refresh behavior.
history = st.session_state['emotion_history']
if len(history) >= 3 and len(set(history[-3:])) > 1:
    st.warning(
        "⚠️ Rapid changes in your detected emotional state were observed. "
        "Micro-expressions may not always reflect your true feelings. "
        "If you feel emotionally unstable or distressed, " \
        "consider reaching out to a mental health professional, "
        "talking it over with a close person or taking a break."
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
#
|
src/utils/backend/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .tensorflow_backend import * # noqa: F401,F403
|
src/utils/backend/tensorflow_backend.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def disable_tensorflow_v2_behavior():
|
| 5 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/compat/v1/disable_tensorflow_v2_behavior .
|
| 6 |
+
"""
|
| 7 |
+
tensorflow.compat.v1.disable_v2_behavior()
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def ones(*args, **kwargs):
|
| 11 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/ones .
|
| 12 |
+
"""
|
| 13 |
+
return tensorflow.ones(*args, **kwargs)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def transpose(*args, **kwargs):
|
| 17 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/transpose .
|
| 18 |
+
"""
|
| 19 |
+
return tensorflow.transpose(*args, **kwargs)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def map_fn(*args, **kwargs):
|
| 23 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/map_fn .
|
| 24 |
+
"""
|
| 25 |
+
return tensorflow.map_fn(*args, **kwargs)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def pad(*args, **kwargs):
|
| 29 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/pad .
|
| 30 |
+
"""
|
| 31 |
+
return tensorflow.pad(*args, **kwargs)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def top_k(*args, **kwargs):
|
| 35 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/nn/top_k .
|
| 36 |
+
"""
|
| 37 |
+
return tensorflow.nn.top_k(*args, **kwargs)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def clip_by_value(*args, **kwargs):
|
| 41 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/clip_by_value .
|
| 42 |
+
"""
|
| 43 |
+
return tensorflow.clip_by_value(*args, **kwargs)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def resize_images(images, size, method='bilinear', align_corners=False):
    """ See https://www.tensorflow.org/versions/r1.14/api_docs/python/tf/image/resize_images .

    Args
        method: The method used for interpolation. One of ('bilinear', 'nearest', 'bicubic', 'area').
        align_corners: If True, align the corner pixels of the input and
            output images (TF1 resize semantics).

    Raises
        KeyError: if `method` is not one of the four supported names.
    """
    methods = {
        'bilinear': tensorflow.image.ResizeMethod.BILINEAR,
        'nearest' : tensorflow.image.ResizeMethod.NEAREST_NEIGHBOR,
        'bicubic' : tensorflow.image.ResizeMethod.BICUBIC,
        'area'    : tensorflow.image.ResizeMethod.AREA,
    }
    # Bug fix: the original ignored both arguments and always resized with
    # NEAREST_NEIGHBOR / align_corners=False despite building `methods`.
    return tensorflow.compat.v1.image.resize_images(
        images, size, methods[method], align_corners)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def non_max_suppression(*args, **kwargs):
|
| 62 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/image/non_max_suppression .
|
| 63 |
+
"""
|
| 64 |
+
return tensorflow.image.non_max_suppression(*args, **kwargs)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def range(*args, **kwargs):
|
| 68 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/range .
|
| 69 |
+
"""
|
| 70 |
+
return tensorflow.range(*args, **kwargs)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def scatter_nd(*args, **kwargs):
|
| 74 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/scatter_nd .
|
| 75 |
+
"""
|
| 76 |
+
return tensorflow.scatter_nd(*args, **kwargs)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def gather_nd(*args, **kwargs):
|
| 80 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/gather_nd .
|
| 81 |
+
"""
|
| 82 |
+
return tensorflow.gather_nd(*args, **kwargs)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def meshgrid(*args, **kwargs):
|
| 86 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/meshgrid .
|
| 87 |
+
"""
|
| 88 |
+
return tensorflow.meshgrid(*args, **kwargs)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def where(*args, **kwargs):
|
| 92 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/where .
|
| 93 |
+
"""
|
| 94 |
+
return tensorflow.where(*args, **kwargs)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def unstack(*args, **kwargs):
|
| 98 |
+
""" See https://www.tensorflow.org/api_docs/python/tf/unstack .
|
| 99 |
+
"""
|
| 100 |
+
return tensorflow.unstack(*args, **kwargs)
|
src/utils/callbacks.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import keras
|
| 4 |
+
import matplotlib
|
| 5 |
+
matplotlib.use('Agg')
|
| 6 |
+
from matplotlib import pyplot as plt
|
| 7 |
+
import scipy.signal
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LossHistory(keras.callbacks.Callback):
    """Keras callback that logs per-epoch train/val loss.

    After every epoch it appends both losses to text files and redraws a
    loss-curve PNG (raw plus Savitzky-Golay smoothed) inside a
    timestamped directory under ``log_dir``.
    """

    def __init__(self, log_dir):
        import datetime
        curr_time = datetime.datetime.now()
        time_str = datetime.datetime.strftime(curr_time, '%Y_%m_%d_%H_%M_%S')
        self.log_dir = log_dir
        self.time_str = time_str
        self.save_path = os.path.join(self.log_dir, "loss_" + str(self.time_str))
        self.losses = []
        self.val_loss = []

        # exist_ok prevents a FileExistsError when two runs start within
        # the same second (the original bare makedirs crashed).
        os.makedirs(self.save_path, exist_ok=True)

    def on_epoch_end(self, batch, logs=None):
        """Record this epoch's losses and refresh the plot.

        Keras passes the epoch index as the first argument; the parameter
        name ``batch`` is kept for interface compatibility.
        """
        logs = logs or {}  # avoid the original mutable default `logs={}`
        self.losses.append(logs.get('loss'))
        self.val_loss.append(logs.get('val_loss'))
        with open(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".txt"), 'a') as f:
            f.write(str(logs.get('loss')))
            f.write("\n")
        with open(os.path.join(self.save_path, "epoch_val_loss_" + str(self.time_str) + ".txt"), 'a') as f:
            f.write(str(logs.get('val_loss')))
            f.write("\n")
        self.loss_plot()

    def loss_plot(self):
        """Redraw and save the loss curve PNG for the recorded epochs."""
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
        try:
            # Use a shorter smoothing window while few epochs are available.
            if len(self.losses) < 25:
                num = 5
            else:
                num = 15

            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
        except Exception:
            # savgol_filter needs more points than the window; skip the
            # smoothed curves (was a bare `except:`) until enough epochs.
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('A Loss Curve')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.save_path, "epoch_loss_" + str(self.time_str) + ".png"))

        plt.cla()
        plt.close("all")
| 62 |
+
|
| 63 |
+
class ExponentDecayScheduler(keras.callbacks.Callback):
    """Keras callback that multiplies the optimizer's learning rate by
    ``decay_rate`` at the end of every epoch (exponential decay)."""

    def __init__(self,
                 decay_rate,
                 verbose=0):
        super(ExponentDecayScheduler, self).__init__()
        self.decay_rate = decay_rate  # multiplicative LR factor per epoch (e.g. 0.95)
        self.verbose = verbose  # >0 prints the new learning rate each epoch
        self.learning_rates = []  # NOTE(review): never appended to anywhere in this class — dead attribute?

    def on_epoch_end(self, batch, logs=None):
        # Keras passes the epoch index as the first argument; the name
        # `batch` is historical.
        lr = self.model.optimizer.learning_rate
        try:
            # Works whether `lr` is a tf.Variable or already a plain float.
            current_lr = tf.keras.backend.get_value(lr)
        except Exception:
            current_lr = lr

        new_lr = current_lr * self.decay_rate
        try:
            tf.keras.backend.set_value(lr, new_lr)
        except Exception:
            # e.g. the optimizer exposes a non-settable schedule object.
            print("Warning: Could not set learning rate dynamically.")

        if self.verbose > 0:
            print('Setting learning rate to %s.' % (new_lr))
|
src/utils/dataloader.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from random import shuffle
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import keras
|
| 6 |
+
import numpy as np
|
| 7 |
+
from keras.utils import to_categorical
|
| 8 |
+
from PIL import Image
|
| 9 |
+
|
| 10 |
+
from .utils import cvtColor, preprocess_input
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ClsDatasets(keras.utils.Sequence):
    """Keras ``Sequence`` yielding (images, one-hot labels) batches.

    Each annotation line has the form ``"<class_id>;<image_path>"``.
    Images are letterboxed to ``input_shape`` for evaluation, or randomly
    augmented (jitter, flip, rotation, HSV distortion) for training.
    """

    def __init__(self, annotation_lines, input_shape, batch_size, num_classes, train, **kwargs):
        super().__init__()
        self.annotation_lines = annotation_lines
        self.length = len(self.annotation_lines)

        self.input_shape = input_shape  # target (h, w)
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.train = train  # True enables random augmentation + per-epoch shuffle

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(math.ceil(self.length / float(self.batch_size)))

    def __getitem__(self, index):
        """Build and return batch ``index`` as (X, one-hot Y) arrays."""
        X_train = []
        Y_train = []
        start = index * self.batch_size
        end = min((index + 1) * self.batch_size, self.length)
        for i in range(start, end):
            # Line format "<class_id>;<path>"; keep the first
            # whitespace-separated token of the path part.
            annotation_path = self.annotation_lines[i].split(';')[1].split()[0]
            image = Image.open(annotation_path)
            image = self.get_random_data(image, self.input_shape, random=self.train)
            image = preprocess_input(np.array(image).astype(np.float32))

            X_train.append(image)
            Y_train.append(int(self.annotation_lines[i].split(';')[0]))

        X_train = np.array(X_train)
        Y_train = to_categorical(np.array(Y_train), num_classes = self.num_classes)
        return X_train, Y_train

    def on_epoch_end(self):
        # Reshuffle sample order between training epochs.
        if self.train:
            np.random.shuffle(self.annotation_lines)

    def rand(self, a=0, b=1):
        # Uniform random float in [a, b).
        return np.random.rand()*(b-a) + a

    def get_random_data(self, image, input_shape, jitter=.3, hue=.1, sat=1.5, val=1.5, random=True):
        """Resize (and optionally augment) one PIL image to ``input_shape``.

        With ``random=False`` the image is letterboxed onto a gray canvas;
        otherwise it is jittered, flipped, rotated, and HSV-distorted.
        Returns a float32-compatible numpy array (RGB).
        """
        #------------------------------#
        #   Read the image and convert it to RGB.
        #------------------------------#
        image = cvtColor(image)
        #------------------------------#
        #   Get the image size and the target size.
        #------------------------------#
        iw, ih = image.size
        h, w = input_shape

        if not random:
            scale = min(w/iw, h/ih)
            nw = int(iw*scale)
            nh = int(ih*scale)
            dx = (w-nw)//2
            dy = (h-nh)//2

            #---------------------------------#
            #   Pad the leftover area with gray bars.
            #---------------------------------#
            image = image.resize((nw,nh), Image.BICUBIC)
            new_image = Image.new('RGB', (w,h), (128,128,128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image, np.float32)

            return image_data

        #------------------------------------------#
        #   Resize the image and jitter its aspect ratio.
        #------------------------------------------#
        new_ar = w/h * self.rand(1-jitter,1+jitter)/self.rand(1-jitter,1+jitter)
        scale = self.rand(.75, 1.25)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)

        #------------------------------------------#
        #   Pad the leftover area with gray bars
        #   (paste at a random offset within the canvas).
        #------------------------------------------#
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        new_image = Image.new('RGB', (w,h), (128,128,128))
        new_image.paste(image, (dx, dy))
        image = new_image

        #------------------------------------------#
        #   Randomly flip the image horizontally.
        #------------------------------------------#
        flip = self.rand()<.5
        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # Random rotation of up to +/-15 degrees around the image center,
        # filling uncovered corners with gray.
        rotate = self.rand()<.5
        if rotate:
            angle = np.random.randint(-15,15)
            a,b = w/2,h/2
            M = cv2.getRotationMatrix2D((a,b),angle,1)
            image = cv2.warpAffine(np.array(image), M, (w,h), borderValue=[128,128,128])

        #------------------------------------------#
        #   Color-space (HSV) distortion.
        #   NOTE(review): `hue` is sampled below but never applied to
        #   x[..., 0]; only the clamp at 360 touches the hue channel.
        #   Confirm whether a hue shift was intended here.
        #------------------------------------------#
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand()<.5 else 1/self.rand(1, sat)
        val = self.rand(1, val) if self.rand()<.5 else 1/self.rand(1, val)
        x = cv2.cvtColor(np.array(image,np.float32)/255, cv2.COLOR_RGB2HSV)
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x[:,:, 0]>360, 0] = 360
        x[:, :, 1:][x[:, :, 1:]>1] = 1
        x[x<0] = 0
        image_data = cv2.cvtColor(x, cv2.COLOR_HSV2RGB)*255
        return image_data
|
src/utils/utils.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from PIL import Image
|
| 3 |
+
|
| 4 |
+
#---------------------------------------------------#
|
| 5 |
+
# 不失真的resize
|
| 6 |
+
#---------------------------------------------------#
|
| 7 |
+
def letterbox_image(image, size):
    """Resize ``image`` to fit inside ``size`` (w, h) without distortion,
    centering it on a gray (128, 128, 128) canvas."""
    src_w, src_h = image.size
    dst_w, dst_h = size

    ratio = min(dst_w / src_w, dst_h / src_h)
    new_w = int(src_w * ratio)
    new_h = int(src_h * ratio)

    resized = image.resize((new_w, new_h), Image.BICUBIC)
    canvas = Image.new('RGB', size, (128, 128, 128))
    canvas.paste(resized, ((dst_w - new_w) // 2, (dst_h - new_h) // 2))

    return canvas
|
| 20 |
+
|
| 21 |
+
#---------------------------------------------------#
|
| 22 |
+
# 获得类
|
| 23 |
+
#---------------------------------------------------#
|
| 24 |
+
def get_classes(classes_path):
    """Read one class name per line from ``classes_path`` (UTF-8).

    Surrounding whitespace is stripped from each line.  Returns a tuple
    of (class_names, class_count).
    """
    with open(classes_path, encoding='utf-8') as handle:
        names = [line.strip() for line in handle]
    return names, len(names)
|
| 29 |
+
|
| 30 |
+
#---------------------------------------------------------#
|
| 31 |
+
# 将图像转换成RGB图像,防止灰度图在预测时报错。
|
| 32 |
+
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
|
| 33 |
+
#---------------------------------------------------------#
|
| 34 |
+
def cvtColor(image):
    """Return ``image`` as 3-channel RGB.

    Images that are already HxWx3 pass through unchanged; anything else
    (e.g. grayscale or RGBA PIL images) is converted via ``convert('RGB')``
    so downstream prediction code never sees a non-RGB input.
    """
    shape = np.shape(image)
    if len(shape) == 3 and shape[2] == 3:
        return image
    return image.convert('RGB')
|
| 40 |
+
|
| 41 |
+
#----------------------------------------#
|
| 42 |
+
# 预处理训练图片
|
| 43 |
+
#----------------------------------------#
|
| 44 |
+
def preprocess_input(x):
    """Scale pixel values from [0, 255] to [-1, 1].

    Note: mutates ``x`` in place via augmented assignment (callers pass a
    freshly created float array) and also returns it for convenience.
    """
    x /= 127.5
    x -= 1.
    return x
|