Spaces:
Runtime error
Runtime error
| #copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve. | |
| # | |
| #Licensed under the Apache License, Version 2.0 (the "License"); | |
| #you may not use this file except in compliance with the License. | |
| #You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| #Unless required by applicable law or agreed to in writing, software | |
| #distributed under the License is distributed on an "AS IS" BASIS, | |
| #WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| #See the License for the specific language governing permissions and | |
| #limitations under the License. | |
| # This code is adapted from: https://github.com/KaiyangZhou/pytorch-center-loss | |
| from __future__ import absolute_import | |
| from __future__ import division | |
| from __future__ import print_function | |
| import os | |
| import pickle | |
| import paddle | |
| import paddle.nn as nn | |
| import paddle.nn.functional as F | |
class CenterLoss(nn.Layer):
    """
    Center loss: mean squared L2 distance between each feature vector and the
    center of the class predicted (arg-max) for it.

    Reference: Wen et al. A Discriminative Feature Learning Approach for Deep Face Recognition. ECCV 2016.

    Args:
        num_classes: number of class centers (rows of ``self.centers``).
        feat_dim: length of each center vector; must equal the last dimension
            of the features passed to ``__call__``.
        center_file_path: optional path to a pickled mapping whose keys index
            rows of ``self.centers`` (presumably int class ids — confirm) and
            whose values are the center vectors that overwrite the random
            initialization.
    """

    def __init__(self, num_classes=6625, feat_dim=96, center_file_path=None):
        super().__init__()
        self.num_classes = num_classes
        self.feat_dim = feat_dim
        # Centers start as random float64 values; they are created with
        # paddle.randn rather than as a parameter, so presumably they are
        # meant to stay fixed (e.g. loaded from center_file_path) — confirm.
        self.centers = paddle.randn(
            shape=[self.num_classes, self.feat_dim]).astype("float64")

        if center_file_path is not None:
            assert os.path.exists(
                center_file_path
            ), f"center path({center_file_path}) must exist when it is not None."
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # center files produced by a trusted source.
            with open(center_file_path, 'rb') as f:
                char_dict = pickle.load(f)
            for key, center in char_dict.items():
                self.centers[key] = paddle.to_tensor(center)

    def __call__(self, predicts, batch):
        """Compute the center loss for one batch.

        Args:
            predicts: a ``(features, logits)`` pair. ``features`` is flattened
                to ``[-1, features.shape[-1]]``; ``logits`` is arg-maxed over
                axis 2 to give one pseudo-label per flattened feature row.
            batch: unused; kept for the common loss-call signature.

        Returns:
            dict: ``{'loss_center': loss}``, where ``loss`` is the sum over
            feature rows of the squared L2 distance to the row's class center
            (each distance clipped to ``[1e-12, 1e12]``), divided by the
            number of rows.

        Raises:
            ValueError: if the feature dimension differs from ``feat_dim`` or
                the number of labels differs from the number of feature rows.
        """
        assert isinstance(predicts, (list, tuple))
        features, predicts = predicts

        feats_reshape = paddle.reshape(
            features, [-1, features.shape[-1]]).astype("float64")
        label = paddle.argmax(predicts, axis=2)
        label = paddle.reshape(label, [-1])

        batch_size = feats_reshape.shape[0]
        # Explicit checks: the subtraction below would otherwise broadcast
        # silently on some mismatched shapes.
        if feats_reshape.shape[1] != self.feat_dim:
            raise ValueError(
                f"feature dim {feats_reshape.shape[1]} does not match "
                f"center feat_dim {self.feat_dim}")
        if label.shape[0] != batch_size:
            raise ValueError(
                f"got {label.shape[0]} labels for {batch_size} feature rows")

        # Squared L2 distance from each feature to the center of its own
        # class. This equals the reference implementation's full
        # [batch_size, num_classes] distance matrix masked down to one entry
        # per row, but avoids materializing that matrix and computes
        # ||f - c||^2 directly instead of ||f||^2 + ||c||^2 - 2 f.c (which can
        # cancel to small negative values). The reference also clipped the
        # masked-out zeros up to 1e-12, adding a constant
        # (num_classes - 1) * 1e-12 with zero gradient; that offset is dropped.
        centers_batch = paddle.gather(self.centers, label, axis=0)
        dist = paddle.sum(paddle.square(feats_reshape - centers_batch), axis=1)

        loss = paddle.sum(paddle.clip(dist, min=1e-12, max=1e+12)) / batch_size
        return {'loss_center': loss}