Koyang commited on
Commit
12645de
·
1 Parent(s): c2682b2
Files changed (8) hide show
  1. app.py +31 -0
  2. configs/BlazeFace.yml +22 -0
  3. requirements.txt +10 -0
  4. src/NetWork.py +98 -0
  5. src/detection.py +123 -0
  6. src/download.py +218 -0
  7. src/preprocess.py +208 -0
  8. src/visualize.py +103 -0
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ from src.detection import Detector
4
+
5
+
6
+ # UGC: Define the inference fn() for your models
7
def _get_detector(model_name='BlazeFace'):
    """Return a cached ``Detector`` for ``model_name``, building it on first use.

    Constructing a ``Detector`` parses the YAML config and creates an
    inference predictor, so it is done once per model name and the instance
    is reused by later requests.
    """
    cache = _get_detector.__dict__.setdefault('cache', {})
    if model_name not in cache:
        cache[model_name] = Detector(model_name)
    return cache[model_name]


def model_inference(image):
    """Run face detection on ``image`` and return the annotated image.

    Args:
        image: Image supplied by the Gradio ``"image"`` input component, or
            ``None`` when the user submitted the form without an image.

    Returns:
        The image returned by the detector (boxes and labels drawn on it),
        or ``None`` when no image was provided.
    """
    if image is None:
        # Nothing to analyse; return an empty output instead of crashing
        # inside the preprocessing code.
        return None
    # The detector also returns a dict with the raw boxes; the UI only
    # displays the annotated image.
    annotated, _ = _get_detector('BlazeFace')(image)
    return annotated
10
+
11
+
12
def clear_all():
    """Return three ``None`` values, one per UI component to reset."""
    return (None,) * 3
14
+
15
+
16
# Download the model weights into ./configs before the UI is built.
_WEIGHTS_URL = "https://huggingface.co/yangcsu/facialdetection-vgg/resolve/main"
_WEIGHTS_DIR = "./configs"

for _weight_file in ("vgg.pdparams", "model.pdiparams", "model.pdmodel"):
    # "wget -c" resumes a partial download and "-P" selects the target
    # directory; the command string is the same one the script always ran.
    _status = os.system("wget -c {}/{} -P {}".format(_WEIGHTS_URL, _weight_file, _WEIGHTS_DIR))
    _weight_path = os.path.join(_WEIGHTS_DIR, _weight_file)
    if not os.path.isfile(_weight_path):
        # Fail early with a clear message instead of a model-loading error
        # on the first request.
        raise RuntimeError(
            "Failed to download {} (wget exit status {})".format(_weight_path, _status))
    if _status != 0:
        # A file is present (possibly from an earlier run), so keep going.
        print("Warning: wget exited with status {} for {}".format(_status, _weight_file))

examples = [
    "https://s3.tebi.io/oss.haust.ml/images/face1.jpg"
]
title = "人脸识别,表情分析"
description = "使用BlazeFace模型识别图片中的人脸,并使用VGG16模型分析其表情"

demo = gr.Interface(
    fn=model_inference,
    inputs="image",
    outputs="image",
    title=title,
    description=description,
    examples=examples,
)

# Launch Gradio
demo.launch()
configs/BlazeFace.yml ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mode: paddle
2
+ draw_threshold: 0.5
3
+ metric: WiderFace
4
+ use_dynamic_shape: true
5
+ arch: Face
6
+ min_subgraph_size: 3
7
+ param_path: configs/model.pdiparams
8
+ model_path: configs/model.pdmodel
9
+ Preprocess:
10
+ - is_scale: false
11
+ mean:
12
+ - 123
13
+ - 117
14
+ - 104
15
+ std:
16
+ - 127.502231
17
+ - 127.502231
18
+ - 127.502231
19
+ type: NormalizeImage
20
+ - type: Permute
21
+ label_list:
22
+ - face
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ opencv-python
3
+ paddlepaddle
4
+ PyYAML
5
+ shapely
6
+ scipy
7
+ Cython
8
+ numpy
9
+ setuptools
10
+ pillow
src/NetWork.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding:utf-8 -*-
2
+ # VGG模型代码
3
+ import numpy as np
4
+ import paddle
5
+ # from paddle.nn import Conv2D, MaxPool2D, BatchNorm, Linear
6
+ from paddle.nn import Conv2D, MaxPool2D, BatchNorm2D, Linear
7
+
8
+
9
+ # 定义vgg网络
10
class VGG(paddle.nn.Layer):
    """VGG-16 style image classifier built from 3x3 convolutions.

    Five convolution blocks (2, 2, 3, 3 and 3 convolutions, each followed by
    ReLU) are each closed by a 2x2 max-pool, followed by three fully
    connected layers. The first FC layer expects 512 * 7 * 7 features, which
    is what a 224x224 input yields after five stride-2 poolings.

    Sub-layer attribute names are part of the checkpoint format (they are
    the state-dict keys) and must not be renamed.

    Args:
        num_class (int): Number of output classes (size of the final layer).
    """

    # Names of the convolution layers in each block, in forward order.
    _CONV_BLOCKS = (
        ('conv1_1', 'conv1_2'),
        ('conv2_1', 'conv2_2'),
        ('conv3_1', 'conv3_2', 'conv3_3'),
        ('conv4_1', 'conv4_2', 'conv4_3'),
        ('conv5_1', 'conv5_2', 'conv5_3'),
    )

    def __init__(self, num_class):
        super().__init__()

        def conv3x3(n_in, n_out):
            # Shape-preserving 3x3 convolution (padding 1, stride 1).
            return Conv2D(in_channels=n_in, out_channels=n_out,
                          kernel_size=3, padding=1, stride=1)

        # Block 1: 3 -> 64 channels.
        self.conv1_1 = conv3x3(3, 64)
        self.conv1_2 = conv3x3(64, 64)
        # Block 2: 64 -> 128 channels.
        self.conv2_1 = conv3x3(64, 128)
        self.conv2_2 = conv3x3(128, 128)
        # Block 3: 128 -> 256 channels.
        self.conv3_1 = conv3x3(128, 256)
        self.conv3_2 = conv3x3(256, 256)
        self.conv3_3 = conv3x3(256, 256)
        # Block 4: 256 -> 512 channels.
        self.conv4_1 = conv3x3(256, 512)
        self.conv4_2 = conv3x3(512, 512)
        self.conv4_3 = conv3x3(512, 512)
        # Block 5: 512 -> 512 channels.
        self.conv5_1 = conv3x3(512, 512)
        self.conv5_2 = conv3x3(512, 512)
        self.conv5_3 = conv3x3(512, 512)

        # Classifier head. fc1 and fc2 are (Linear + ReLU) sequences, so the
        # activation is already part of the layer.
        self.fc1 = paddle.nn.Sequential(paddle.nn.Linear(512 * 7 * 7, 4096), paddle.nn.ReLU())

        self.drop1_ratio = 0.5
        self.dropout1 = paddle.nn.Dropout(self.drop1_ratio, mode='upscale_in_train')

        self.fc2 = paddle.nn.Sequential(paddle.nn.Linear(4096, 4096), paddle.nn.ReLU())

        self.drop2_ratio = 0.5
        self.dropout2 = paddle.nn.Dropout(self.drop2_ratio, mode='upscale_in_train')

        # Output layer: one logit per class.
        self.fc3 = paddle.nn.Linear(4096, num_class)

        self.relu = paddle.nn.ReLU()
        self.pool = MaxPool2D(stride=2, kernel_size=2)

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Input batch; presumably shaped (N, 3, 224, 224), since fc1
                expects 512 * 7 * 7 flattened features.

        Returns:
            Tensor of raw (un-normalised) logits, one column per class.
        """
        for block in self._CONV_BLOCKS:
            for layer_name in block:
                x = self.relu(getattr(self, layer_name)(x))
            x = self.pool(x)

        x = paddle.flatten(x, 1, -1)
        # Dropout after each hidden FC layer to limit over-fitting. fc1/fc2
        # already apply ReLU, so no extra activation is needed here
        # (ReLU applied twice equals ReLU applied once).
        x = self.dropout1(self.fc1(x))
        x = self.dropout2(self.fc2(x))
        return self.fc3(x)
src/detection.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import os
3
+ import numpy as np
4
+ import yaml
5
+ from paddle.inference import Config, create_predictor, PrecisionType
6
+ from PIL import Image
7
+
8
+ from .download import get_model_path
9
+ from .preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, decode_image
10
+ from .visualize import draw_det
11
+
12
class Detector(object):
    """Face detector backed by a Paddle Inference predictor.

    The predictor and preprocessing pipeline are configured from
    ``configs/<model_name>.yml``, located two directory levels above this
    file.

    Args:
        model_name (str): Base name of the YAML config file to load.
    """

    def __init__(self, model_name):
        parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
        yml_file = os.path.join(parent_path, 'configs/{}.yml'.format(model_name))
        with open(yml_file, 'r') as f:
            yml_conf = yaml.safe_load(f)

        infer_model = get_model_path(yml_conf['model_path'])
        infer_params = get_model_path(yml_conf['param_path'])
        config = Config(infer_model, infer_params)
        device = yml_conf.get('device', 'CPU')
        run_mode = yml_conf.get('mode', 'paddle')
        cpu_threads = yml_conf.get('cpu_threads', 1)
        # The TensorRT settings below need a batch size; __call__ always
        # feeds a single image, so default to 1 unless the config overrides.
        batch_size = yml_conf.get('batch_size', 1)
        if device == 'CPU':
            config.disable_gpu()
            config.set_cpu_math_library_num_threads(cpu_threads)
        elif device == 'GPU':
            # initial GPU memory (MB), device ID
            config.enable_use_gpu(200, 0)
            # optimize graph and fuse op
            config.switch_ir_optim(True)

        precision_map = {
            'trt_int8': Config.Precision.Int8,
            'trt_fp32': Config.Precision.Float32,
            'trt_fp16': Config.Precision.Half
        }

        if run_mode in precision_map:
            config.enable_tensorrt_engine(
                workspace_size=(1 << 25) * batch_size,
                max_batch_size=batch_size,
                min_subgraph_size=yml_conf['min_subgraph_size'],
                precision_mode=precision_map[run_mode],
                use_static=True,
                use_calib_mode=False)

            if yml_conf['use_dynamic_shape']:
                min_input_shape = {
                    'image': [batch_size, 3, 640, 640],
                    'scale_factor': [batch_size, 2]
                }
                max_input_shape = {
                    'image': [batch_size, 3, 1280, 1280],
                    'scale_factor': [batch_size, 2]
                }
                opt_input_shape = {
                    'image': [batch_size, 3, 1024, 1024],
                    'scale_factor': [batch_size, 2]
                }
                config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                                  opt_input_shape)

        # disable print log when predict
        config.disable_glog_info()
        # enable shared memory
        config.enable_memory_optim()
        # disable feed, fetch OP, needed by zero_copy_run
        config.switch_use_feed_fetch_ops(False)
        self.predictor = create_predictor(config)
        self.yml_conf = yml_conf
        self.preprocess_ops = self.create_preprocess_ops(yml_conf)
        self.input_names = self.predictor.get_input_names()
        self.output_names = self.predictor.get_output_names()
        self.draw_threshold = yml_conf.get('draw_threshold', 0.5)
        self.class_names = yml_conf['label_list']

    def create_preprocess_ops(self, yml_conf):
        """Instantiate the preprocessing operators listed under ``Preprocess``.

        Each entry is a mapping with a ``type`` key naming the operator
        class; the remaining keys are passed as keyword arguments.

        Raises:
            ValueError: If an entry names an unknown operator type.
        """
        # Explicit whitelist instead of eval(): the type name comes from a
        # config file and must not be evaluated as arbitrary code.
        known_ops = {
            'Resize': Resize,
            'NormalizeImage': NormalizeImage,
            'Permute': Permute,
            'PadStride': PadStride,
        }
        preprocess_ops = []
        for op_info in yml_conf['Preprocess']:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            if op_type not in known_ops:
                raise ValueError('Unsupported preprocess op: {!r}'.format(op_type))
            preprocess_ops.append(known_ops[op_type](**new_op_info))
        return preprocess_ops

    def create_inputs(self, image_files):
        """Preprocess ``image_files`` and stack them into predictor inputs.

        Args:
            image_files (list): Image paths or ``np.ndarray`` images.

        Returns:
            dict: float32 arrays under ``im_shape``, ``scale_factor`` and
            ``image``, each stacked along a new leading axis.
        """
        inputs = dict()
        im_list, im_info_list = [], []
        for im_path in image_files:
            im, im_info = preprocess(im_path, self.preprocess_ops)
            im_list.append(im)
            im_info_list.append(im_info)

        inputs['im_shape'] = np.stack([e['im_shape'] for e in im_info_list], axis=0).astype('float32')
        inputs['scale_factor'] = np.stack([e['scale_factor'] for e in im_info_list], axis=0).astype('float32')
        inputs['image'] = np.stack(im_list, axis=0).astype('float32')
        return inputs

    def __call__(self, image_file):
        """Detect faces in one image and draw the results on it.

        Args:
            image_file (str | np.ndarray): Path to an image file, or an
                image array.

        Returns:
            tuple: ``(image, {'bboxes': [...]})`` where ``image`` is the
            array returned by ``draw_det`` and ``bboxes`` lists the kept
            rows ``[class_id, score, xmin, ymin, xmax, ymax]``.

        Raises:
            TypeError: If ``image_file`` is neither a str nor an ndarray.
        """
        # Validate up front so unsupported inputs fail with a clear error
        # rather than deep inside preprocessing.
        if not isinstance(image_file, (str, np.ndarray)):
            raise TypeError(
                'image_file must be a path (str) or np.ndarray, got {}'.format(
                    type(image_file).__name__))

        inputs = self.create_inputs([image_file])
        for name in self.input_names:
            input_tensor = self.predictor.get_input_handle(name)
            input_tensor.copy_from_cpu(inputs[name])

        self.predictor.run()
        boxes_tensor = self.predictor.get_output_handle(self.output_names[0])
        np_boxes = boxes_tensor.copy_to_cpu()
        boxes_num = self.predictor.get_output_handle(self.output_names[1])
        np_boxes_num = boxes_num.copy_to_cpu()
        if np_boxes_num.sum() <= 0:
            # No detections: use an empty (0, 6) array so the filtering
            # below still works.
            np_boxes = np.zeros([0, 6])

        if isinstance(image_file, str):
            # draw_det builds its canvas with Image.fromarray, so hand it an
            # array rather than a PIL image.
            image = np.array(Image.open(image_file).convert('RGB'))
        else:
            image = image_file
        # Keep boxes above the score threshold with a valid class id.
        expect_boxes = (np_boxes[:, 1] > self.draw_threshold) & (np_boxes[:, 0] > -1)
        np_boxes = np_boxes[expect_boxes, :]
        image = draw_det(image, np_boxes, self.class_names)
        return image, {'bboxes': np_boxes.tolist()}
src/download.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import os.path as osp
17
+ import sys
18
+ import yaml
19
+ import time
20
+ import shutil
21
+ import requests
22
+ import tqdm
23
+ import hashlib
24
+ import base64
25
+ import binascii
26
+ import tarfile
27
+ import zipfile
28
+
29
+ __all__ = [
30
+ 'get_model_path',
31
+ 'get_config_path',
32
+ 'get_dict_path',
33
+ ]
34
+
35
# Local cache directories for downloaded models, configs and dictionaries.
_CACHE_ROOT = "~/.cache/paddlecv"
WEIGHTS_HOME = osp.expanduser(_CACHE_ROOT + "/models")
CONFIGS_HOME = osp.expanduser(_CACHE_ROOT + "/configs")
DICTS_HOME = osp.expanduser(_CACHE_ROOT + "/dicts")

# Maximum number of download attempts before giving up.
DOWNLOAD_RETRY_LIMIT = 3

# HTTP prefix substituted for the "paddlecv://" URL scheme.
PMP_DOWNLOAD_URL_PREFIX = 'https://bj.bcebos.com/v1/paddle-model-ecology/paddlecv/'
45
+
46
+
47
def is_url(path):
    """
    Tell whether ``path`` is a URL.
    Args:
        path (string): candidate string; it counts as a URL when it starts
            with ``http://``, ``https://`` or ``paddlecv://``.
    """
    url_schemes = ('http://', 'https://', 'paddlecv://')
    return path.startswith(url_schemes)
56
+
57
+
58
def parse_url(url):
    """Expand the ``paddlecv://`` scheme in ``url`` to its HTTP prefix."""
    return url.replace("paddlecv://", PMP_DOWNLOAD_URL_PREFIX)
61
+
62
+
63
def get_model_path(path):
    """Resolve a model path.

    A non-URL ``path`` is returned unchanged. A URL is mapped under
    WEIGHTS_HOME (keeping the last two path components) and downloaded
    there if it is not already present.
    """
    if is_url(path):
        local_path, _ = get_path(parse_url(path), WEIGHTS_HOME, path_depth=2)
        return local_path
    return path
72
+
73
+
74
def get_config_path(path):
    """Resolve a config path.

    A non-URL ``path`` is returned unchanged. A URL is mapped under
    CONFIGS_HOME and downloaded there if it is not already present.
    """
    if is_url(path):
        local_path, _ = get_path(parse_url(path), CONFIGS_HOME)
        return local_path
    return path
83
+
84
+
85
def get_dict_path(path):
    """Resolve a dictionary path.

    A non-URL ``path`` is returned unchanged. A URL is mapped under
    DICTS_HOME and downloaded there if it is not already present.
    """
    if is_url(path):
        local_path, _ = get_path(parse_url(path), DICTS_HOME)
        return local_path
    return path
94
+
95
+
96
def map_path(url, root_dir, path_depth=1):
    """Map ``url`` to a local file path under ``root_dir``.

    The last ``path_depth`` components of ``url`` are kept and joined onto
    ``root_dir``.

    Returns:
        tuple: ``(path, dirname)``, the local file path and its directory.
    """
    assert path_depth > 0, "path_depth should be a positive integer"
    # Walk ``path_depth`` levels up the URL to find the part to strip.
    prefix = url
    remaining = path_depth
    while remaining:
        prefix = osp.dirname(prefix)
        remaining -= 1
    local_path = osp.join(root_dir, osp.relpath(url, prefix))
    return local_path, osp.dirname(local_path)
106
+
107
+
108
def get_path(url, root_dir, md5sum=None, check_exist=True, path_depth=1):
    """Return the local path for ``url`` under ``root_dir``, downloading if needed.

    When the mapped path already exists (and ``check_exist`` is true) it is
    reused if it is not a regular file or if it passes the MD5 check; a
    regular file that fails the check is deleted and downloaded again.

    url (str): download url
    root_dir (str): root dir for downloading
    md5sum (str): md5 sum of download package
    check_exist (bool): whether an existing local copy may be reused
    path_depth (int): number of trailing url components kept under root_dir

    Returns:
        tuple: ``(fullpath, reused)``; ``reused`` is True when the existing
        local copy was used and False when a download was performed.
    """
    fullpath, dirname = map_path(url, root_dir, path_depth)

    already_there = osp.exists(fullpath) and check_exist
    if already_there:
        is_regular_file = osp.isfile(fullpath)
        if not is_regular_file or _check_exist_file_md5(fullpath, md5sum, url):
            return fullpath, True
        # Stale or corrupted file: drop it and fetch a fresh copy.
        os.remove(fullpath)

    _download(url, dirname, md5sum)
    return fullpath, False
130
+
131
+
132
def _download(url, path, md5sum=None, timeout=60):
    """
    Download from url, save to path.

    url (str): download url
    path (str): directory to download into (created if missing)
    md5sum (str): expected md5 of the file, or None
    timeout (float): seconds to wait for the server to connect/send data;
        prevents a stalled connection from hanging forever

    Returns:
        str: full path of the downloaded file.

    Raises:
        RuntimeError: on a non-200 response, or when the file still fails
            the check after DOWNLOAD_RETRY_LIMIT attempts.
    """
    # exist_ok avoids a race between the existence check and the creation.
    os.makedirs(path, exist_ok=True)

    fname = osp.split(url)[-1]
    fullname = osp.join(path, fname)
    retry_cnt = 0

    # Keep downloading until the file exists and passes the MD5 check.
    while not (osp.exists(fullname) and _check_exist_file_md5(fullname, md5sum,
                                                              url)):
        if retry_cnt < DOWNLOAD_RETRY_LIMIT:
            retry_cnt += 1
        else:
            raise RuntimeError("Download from {} failed. "
                               "Retry limit reached".format(url))

        # NOTE: windows path join may incur \, which is invalid in url
        if sys.platform == "win32":
            url = url.replace('\\', '/')

        # The context manager releases the connection even when writing
        # fails part-way through.
        with requests.get(url, stream=True, timeout=timeout) as req:
            if req.status_code != 200:
                raise RuntimeError("Downloading from {} failed with code "
                                   "{}!".format(url, req.status_code))

            # To survive an interrupted download, write to tmp_fullname
            # first and move it to fullname once the download has finished.
            tmp_fullname = fullname + "_tmp"
            total_size = req.headers.get('content-length')
            with open(tmp_fullname, 'wb') as f:
                if total_size:
                    # Size known: show a progress bar in 1 KB steps.
                    for chunk in tqdm.tqdm(
                            req.iter_content(chunk_size=1024),
                            total=(int(total_size) + 1023) // 1024,
                            unit='KB'):
                        f.write(chunk)
                else:
                    for chunk in req.iter_content(chunk_size=1024):
                        if chunk:
                            f.write(chunk)
        shutil.move(tmp_fullname, fullname)
    return fullname
181
+
182
+
183
def _check_exist_file_md5(filename, md5sum, url):
    """Check ``filename`` against an MD5 sum.

    When no ``md5sum`` is given and the file name ends with ``pdparams``,
    the expected MD5 is read from the response headers of ``url``;
    otherwise ``md5sum`` is checked directly.
    """
    use_remote_md5 = md5sum is None and filename.endswith('pdparams')
    if use_remote_md5:
        return _md5check_from_url(filename, url)
    return _md5check(filename, md5sum)
189
+
190
+
191
def _md5check_from_url(filename, url):
    """Check ``filename`` against the ``content-md5`` header served for ``url``.

    The header value is base64-decoded and hex-encoded before the
    comparison. Returns True when the server sends no ``content-md5``
    header or when the file matches it, False otherwise.
    """
    response = requests.get(url, stream=True)
    content_md5 = response.headers.get('content-md5')
    response.close()

    if not content_md5:
        # Nothing to compare against; treat the file as valid.
        return True

    raw_digest = base64.b64decode(content_md5.strip('"'))
    expected_hex = binascii.hexlify(raw_digest).decode()
    if _md5check(filename, expected_hex):
        return True
    return False
204
+
205
+
206
+ def _md5check(fullname, md5sum=None):
207
+ if md5sum is None:
208
+ return True
209
+
210
+ md5 = hashlib.md5()
211
+ with open(fullname, 'rb') as f:
212
+ for chunk in iter(lambda: f.read(4096), b""):
213
+ md5.update(chunk)
214
+ calc_md5sum = md5.hexdigest()
215
+
216
+ if calc_md5sum != md5sum:
217
+ return False
218
+ return True
src/preprocess.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import cv2
16
+ import numpy as np
17
+
18
+
19
def decode_image(im_file, im_info):
    """Load an image and record its shape in ``im_info``.

    Args:
        im_file (str|np.ndarray): image path, or an already decoded image;
            a non-str input is used as given
        im_info (dict): info of image, updated in place
    Returns:
        im (np.ndarray): image; files are decoded by OpenCV and converted
            from BGR to RGB
        im_info (dict): same dict with ``im_shape`` (first two dimensions)
            and ``scale_factor`` ([1., 1.]) set as float32 arrays
    """
    if not isinstance(im_file, str):
        im = im_file
    else:
        with open(im_file, 'rb') as stream:
            raw_bytes = stream.read()
        encoded = np.frombuffer(raw_bytes, dtype='uint8')
        bgr = cv2.imdecode(encoded, 1)  # flag 1: decode as 3-channel BGR
        im = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
    im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
    return im, im_info
39
+
40
+
41
class Resize(object):
    """resize image by target_size and max_size
    Args:
        target_size (int|list): target size of the image; an int is
            expanded to [target_size, target_size]
        keep_ratio (bool): whether to keep the aspect ratio, default true
        interp (int): OpenCV interpolation flag used for resizing
    """

    def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
        if isinstance(target_size, int):
            target_size = [target_size, target_size]
        self.target_size = target_size
        self.keep_ratio = keep_ratio
        self.interp = interp

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
            im_info (dict): info of image
        Returns:
            im (np.ndarray): resized image (np.ndarray)
            im_info (dict): info of processed image, with ``im_shape`` and
                ``scale_factor`` ([scale_y, scale_x]) updated
        """
        assert len(self.target_size) == 2
        assert self.target_size[0] > 0 and self.target_size[1] > 0
        im_scale_y, im_scale_x = self.generate_scale(im)
        # dsize is None so OpenCV derives the output size from fx / fy.
        im = cv2.resize(
            im,
            None,
            None,
            fx=im_scale_x,
            fy=im_scale_y,
            interpolation=self.interp)
        im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
        im_info['scale_factor'] = np.array(
            [im_scale_y, im_scale_x]).astype('float32')
        return im, im_info

    def generate_scale(self, im):
        """
        Args:
            im (np.ndarray): image (np.ndarray)
        Returns:
            im_scale_y: the resize ratio of Y
            im_scale_x: the resize ratio of X
        """
        origin_shape = im.shape[:2]
        if self.keep_ratio:
            im_size_min = np.min(origin_shape)
            im_size_max = np.max(origin_shape)
            target_size_min = np.min(self.target_size)
            target_size_max = np.max(self.target_size)
            # Scale the short side to the small target, unless that would
            # push the long side past the large target.
            im_scale = float(target_size_min) / float(im_size_min)
            if np.round(im_scale * im_size_max) > target_size_max:
                im_scale = float(target_size_max) / float(im_size_max)
            im_scale_x = im_scale
            im_scale_y = im_scale
        else:
            resize_h, resize_w = self.target_size
            im_scale_y = resize_h / float(origin_shape[0])
            im_scale_x = resize_w / float(origin_shape[1])
        return im_scale_y, im_scale_x
106
+
107
+
108
+
109
class NormalizeImage(object):
    """normalize image
    Args:
        mean (list): per-channel value subtracted from the image
        std (list): per-channel value the image is divided by
        is_scale (bool): whether to first multiply the image by 1/255
        norm_type (str): 'mean_std' applies mean/std; any other value
            skips that step
    """

    def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
        self.mean, self.std = mean, std
        self.is_scale = is_scale
        self.norm_type = norm_type

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image (np.ndarray), channels on the last axis
            im_info (dict): info of image
        Returns:
            im (np.ndarray): float32 normalized image
            im_info (dict): info of image, returned unchanged
        """
        # copy=False: an input that is already float32 is modified in place.
        im = im.astype(np.float32, copy=False)
        if self.is_scale:
            im *= 1.0 / 255.0

        if self.norm_type != 'mean_std':
            return im, im_info

        # Reshape to (1, 1, C) so the values broadcast over H and W.
        offset = np.array(self.mean)[None, None, :]
        divisor = np.array(self.std)[None, None, :]
        im -= offset
        im /= divisor
        return im, im_info
144
+
145
+
146
class Permute(object):
    """permute image from HWC layout to CHW layout

    Takes no arguments.
    """

    def __init__(self):
        super(Permute, self).__init__()

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image with channels on the last axis
            im_info (dict): info of image
        Returns:
            im (np.ndarray): copy of the image with the last axis moved
                to the front
            im_info (dict): info of image, returned unchanged
        """
        channel_first = np.moveaxis(im, 2, 0)
        # .copy() materialises the transposed view as a C-ordered array.
        return channel_first.copy(), im_info
167
+
168
+
169
class PadStride(object):
    """ padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
    Args:
        stride (int): pad height and width up to a multiple of this value;
            a value <= 0 disables padding
    """

    def __init__(self, stride=0):
        self.coarsest_stride = stride

    def __call__(self, im, im_info):
        """
        Args:
            im (np.ndarray): image in (C, H, W) layout
            im_info (dict): info of image
        Returns:
            im (np.ndarray): float32 image zero-padded at the bottom and
                right, or the input itself when padding is disabled
            im_info (dict): info of image, returned unchanged
        """
        stride = self.coarsest_stride
        if stride <= 0:
            return im, im_info

        channels, height, width = im.shape
        # Round each spatial size up to the next multiple of the stride.
        padded_h = int(np.ceil(float(height) / stride) * stride)
        padded_w = int(np.ceil(float(width) / stride) * stride)
        canvas = np.zeros((channels, padded_h, padded_w), dtype=np.float32)
        canvas[:, :height, :width] = im
        return canvas, im_info
196
+
197
+
198
def preprocess(im, preprocess_ops):
    """Decode ``im`` and run it through ``preprocess_ops`` in order.

    Args:
        im (str|np.ndarray): image path or image array
        preprocess_ops (list): callables taking and returning
            ``(im, im_info)``
    Returns:
        tuple: the processed image and its info dict
    """
    info = {
        'scale_factor': np.array([1., 1.], dtype=np.float32),
        'im_shape': None,
    }
    image, info = decode_image(im, info)
    for op in preprocess_ops:
        image, info = op(image, info)
    return image, info
src/visualize.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image, ImageDraw, ImageFile
3
+ from .NetWork import VGG
4
+ import paddle
5
+ import cv2
6
+
7
def get_color_map_list(num_classes):
    """Build a distinct RGB colour for each class id.

    Args:
        num_classes (int): number of class
    Returns:
        color_map (list): list of ``[r, g, b]`` lists, one per class
    """
    color_map = []
    for class_id in range(num_classes):
        rgb = [0, 0, 0]
        remaining = class_id
        bit_pos = 7
        # Spread the bits of the class id over the three channels, three
        # bits at a time, filling each channel from its highest bit down.
        while remaining:
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << bit_pos
            bit_pos -= 1
            remaining >>= 3
        color_map.append(rgb)
    return color_map
26
+
27
+
28
def draw_det(image, dt_bboxes, name_set):
    """Draw detection boxes with an emotion label on ``image``.

    Args:
        image (np.ndarray): image array accepted by ``Image.fromarray``
        dt_bboxes: iterable of rows
            ``(cls_id, score, xmin, ymin, xmax, ymax)``
        name_set (list): class names; only its length is used, to size the
            colour palette
    Returns:
        np.ndarray: the image with boxes and captions drawn on it
    """
    canvas = Image.fromarray(image)
    line_width = min(canvas.size) // 320
    painter = ImageDraw.Draw(canvas)
    palette = get_color_map_list(len(name_set))

    for (cls_id, score, xmin, ymin, xmax, ymax) in dt_bboxes:
        # Classify the expression of the cropped face region.
        face_crop = canvas.crop(tuple([xmin, ymin, xmax, ymax]))
        label = emotic(face_crop)
        color = tuple(palette[int(cls_id)])

        # draw bbox as a closed polyline
        outline = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
                   (xmin, ymin)]
        painter.line(outline, width=line_width, fill=color)

        # draw label on a filled background at the box's top-left corner
        caption = "{} {:.4f}".format(label, score)
        caption_box = painter.textbbox((xmin, ymin), caption, anchor='lt')
        painter.rectangle(caption_box, fill=color)
        painter.text((caption_box[0], caption_box[1]), caption, fill=(255, 255, 255))

    return np.array(canvas)
54
+
55
+
56
# Expression labels indexed by the classifier's output class id.
# NOTE(review): 'SUPERISE' and 'NATUREAL' look like typos of SURPRISE and
# NEUTRAL; they are kept as-is because they are the strings shown to users.
_EMOTION_LABELS = ('SAD', 'DISGUST', 'HAPPY', 'FEAR', 'SUPERISE', 'NATUREAL', 'ANGRY')

# Loaded classifiers keyed by weights path, so weights are read only once.
_EMOTION_MODELS = {}


def _load_emotion_model(params_file_path):
    """Return the VGG classifier for ``params_file_path``, loading it once."""
    model = _EMOTION_MODELS.get(params_file_path)
    if model is None:
        model = VGG(num_class=len(_EMOTION_LABELS))
        model.load_dict(paddle.load(params_file_path))
        # Inference mode: without this Dropout stays active and the
        # prediction for the same face changes from call to call.
        model.eval()
        _EMOTION_MODELS[params_file_path] = model
    return model


def emotic(image):
    """Classify the facial expression shown in ``image``.

    Args:
        image: face crop convertible with ``np.array`` (e.g. a PIL image);
            presumably 3-channel, since the last axis is moved to the front.

    Returns:
        str: one of the labels in ``_EMOTION_LABELS``.

    Raises:
        ValueError: if the predicted class id has no label.
    """
    img = np.array(image)
    # Resize to the 224x224 input the network expects.
    img = cv2.resize(img, (224, 224))
    # [H, W, C] -> [C, H, W]
    img = np.transpose(img, (2, 0, 1))
    img = img.astype('float32')
    # Map pixel values from [0, 255] to [-1.0, 1.0].
    img = img / 255.
    img = img * 2.0 - 1.0
    batch = np.expand_dims(img, 0)

    model = _load_emotion_model(r'configs/vgg.pdparams')
    # No gradients are needed for inference.
    with paddle.no_grad():
        results = model(paddle.to_tensor(batch))

    # The class with the highest score is the prediction.
    tap = int(np.argmax(results.numpy()[0]))
    if not 0 <= tap < len(_EMOTION_LABELS):
        raise ValueError('Not excepted file name')
    return _EMOTION_LABELS[tap]