Spaces:
Paused
Paused
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFile | |
| from .NetWork import VGG | |
| import paddle | |
| import cv2 | |
def get_color_map_list(num_classes):
    """Build a deterministic RGB palette with one color per class id.

    The bits of each class id are consumed three at a time (one bit per
    channel) and written into the channel bytes from the most significant
    bit downwards, so small ids map to well-separated colors.

    Args:
        num_classes (int): number of classes.

    Returns:
        list: ``num_classes`` entries, each an ``[R, G, B]`` list of ints.
    """
    palette = []
    for class_id in range(num_classes):
        rgb = [0, 0, 0]
        remaining = class_id
        bit_pos = 7  # next bit of each channel byte to fill, starting at the MSB
        while remaining:
            for channel in range(3):
                rgb[channel] |= ((remaining >> channel) & 1) << bit_pos
            bit_pos -= 1
            remaining >>= 3
        palette.append(rgb)
    return palette
def draw_det(image, dt_bboxes, name_set):
    """Draw detection boxes with an emotion label and score on an image.

    Each box is cropped from the image and passed to ``emotic`` to obtain
    the text label; the box outline and label background use a color picked
    by the detection's class id.

    Args:
        image (np.ndarray): image array accepted by ``PIL.Image.fromarray``.
        dt_bboxes (iterable): rows of
            ``(cls_id, score, xmin, ymin, xmax, ymax)``.
        name_set: sized collection of class names; only its length is used,
            to size the color palette.

    Returns:
        np.ndarray: the annotated image.

    Raises:
        IndexError: if a ``cls_id`` is not smaller than ``len(name_set)``.
    """
    im = Image.fromarray(image)
    # Clamp to 1 so images narrower than 320 px do not request width=0.
    draw_thickness = max(1, min(im.size) // 320)
    draw = ImageDraw.Draw(im)
    color_list = get_color_map_list(len(name_set))

    for (cls_id, score, xmin, ymin, xmax, ymax) in dt_bboxes:
        # Classify the cropped region to get the label shown on the box.
        image_box = im.crop((xmin, ymin, xmax, ymax))
        label = emotic(image_box)

        color = tuple(color_list[int(cls_id)])

        # draw bbox (closed polyline through the four corners)
        draw.line(
            [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
             (xmin, ymin)],
            width=draw_thickness,
            fill=color)

        # draw label: filled background sized to the text, then white text
        text = "{} {:.4f}".format(label, score)
        box = draw.textbbox((xmin, ymin), text, anchor='lt')
        draw.rectangle(box, fill=color)
        draw.text((box[0], box[1]), text, fill=(255, 255, 255))

    return np.array(im)
# Emotion names indexed by the classifier's output class id.
# NOTE(review): 'SUPERISE' and 'NATUREAL' look like misspellings of
# SURPRISE / NATURAL; kept unchanged because they are emitted at runtime.
_EMOTION_LABELS = (
    'SAD',
    'DISGUST',
    'HAPPY',
    'FEAR',
    'SUPERISE',
    'NATUREAL',
    'ANGRY',
)

# Weights file, resolved relative to the current working directory.
_EMOTION_PARAMS_PATH = r'configs/vgg.pdparams'

# Lazily-built classifier shared across calls so the weights are read once.
_emotion_model = None


def _get_emotion_model():
    """Return the VGG emotion classifier, loading its weights on first use."""
    global _emotion_model
    if _emotion_model is None:
        model = VGG(num_class=len(_EMOTION_LABELS))
        model.load_dict(paddle.load(_EMOTION_PARAMS_PATH))
        # Inference must not run in training mode (dropout / batch-norm).
        # NOTE(review): assumes VGG is a paddle.nn.Layer, as its use of
        # load_dict suggests -- confirm.
        model.eval()
        # Cache only after a successful load so a failure is retried.
        _emotion_model = model
    return _emotion_model


def _preprocess_face(img):
    """Resize to 224x224, reorder to [C, H, W] and scale to [-1.0, 1.0]."""
    img = cv2.resize(img, (224, 224))
    # The array is laid out as [H, W, C]; transpose it to [C, H, W].
    img = np.transpose(img, (2, 0, 1)).astype('float32')
    # Map pixel values from [0, 255] to [-1.0, 1.0].
    return img / 255. * 2.0 - 1.0


def emotic(image):
    """Classify the emotion shown in an image crop.

    Args:
        image: image convertible with ``np.array`` (e.g. a PIL image) to an
            [H, W, C] array.

    Returns:
        str: one of the entries of ``_EMOTION_LABELS``.

    Raises:
        ValueError: if the predicted class index has no known label.
    """
    model = _get_emotion_model()

    # Add a leading batch dimension of size 1 before feeding the network.
    batch = np.expand_dims(_preprocess_face(np.array(image)), 0)
    results = model(paddle.to_tensor(batch))

    # Take the class with the highest score as the prediction.
    tap = int(np.argmax(results.numpy()[0]))
    if not 0 <= tap < len(_EMOTION_LABELS):
        raise ValueError(
            'Unexpected emotion class index: {}'.format(tap))
    return _EMOTION_LABELS[tap]