Upload CVLFace experiment code
Browse files. This view is limited to 50 files because it contains too many changes. See raw diff.
- cvlface/research/recognition/code/run_v1/README.md +0 -0
- cvlface/research/recognition/code/run_v1/aligners/__init__.py +25 -0
- cvlface/research/recognition/code/run_v1/aligners/base/__init__.py +60 -0
- cvlface/research/recognition/code/run_v1/aligners/base/utils.py +91 -0
- cvlface/research/recognition/code/run_v1/aligners/configs/dfa.yaml +10 -0
- cvlface/research/recognition/code/run_v1/aligners/configs/none.yaml +3 -0
- cvlface/research/recognition/code/run_v1/aligners/configs/retinaface.yaml +3 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/__init__.py +117 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/aligner_helper.py +97 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/__init__.py +27 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/config.py +18 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/__init__.py +2 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/functions/prior_box.py +140 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/modules/__init__.py +3 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/modules/multibox_loss.py +144 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/__init__.py +0 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/net.py +132 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/retinaface.py +142 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/preprocessor.py +93 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/utils/box_utils.py +239 -0
- cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/utils/model_utils.py +36 -0
- cvlface/research/recognition/code/run_v1/aligners/none/__init__.py +20 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/__init__.py +246 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/aligner_helper.py +97 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/__init__.py +28 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/config.py +18 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/__init__.py +2 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/functions/prior_box.py +140 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/modules/__init__.py +3 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/modules/multibox_loss.py +144 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/__init__.py +0 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/net.py +132 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/retinaface.py +123 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/preprocessor.py +93 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/utils/box_utils.py +239 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/utils/model_utils.py +36 -0
- cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface_pipeline.py +247 -0
- cvlface/research/recognition/code/run_v1/base.yaml +12 -0
- cvlface/research/recognition/code/run_v1/classifiers/__init__.py +31 -0
- cvlface/research/recognition/code/run_v1/classifiers/base/__init__.py +87 -0
- cvlface/research/recognition/code/run_v1/classifiers/base/utils.py +91 -0
- cvlface/research/recognition/code/run_v1/classifiers/configs/fc.yaml +4 -0
- cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc.yaml +4 -0
- cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_freeze.yaml +4 -0
- cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_sample10.yaml +4 -0
- cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_sample10_freeze.yaml +4 -0
- cvlface/research/recognition/code/run_v1/classifiers/fc/__init__.py +55 -0
- cvlface/research/recognition/code/run_v1/classifiers/fc/fc.py +67 -0
- cvlface/research/recognition/code/run_v1/classifiers/partial_fc/__init__.py +39 -0
- cvlface/research/recognition/code/run_v1/classifiers/partial_fc/partial_fc.py +289 -0
cvlface/research/recognition/code/run_v1/README.md
ADDED
|
File without changes
|
cvlface/research/recognition/code/run_v1/aligners/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .base import BaseAligner
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_aligner(aligner_cfg):
    """Build the aligner selected by ``aligner_cfg.name``.

    The config must expose ``name``, ``start_from`` and ``freeze``.

    * ``name`` selects the implementation: ``'none'``,
      ``'retinaface_aligner'`` (``'retinaface'`` is accepted as an alias,
      since the bundled ``configs/retinaface.yaml`` uses that spelling) or
      ``'differentiable_face_aligner'``.
    * ``start_from``, when truthy, is a checkpoint path loaded into the aligner.
    * ``freeze``, when truthy, disables gradients on every aligner parameter.

    Raises:
        ValueError: if ``name`` does not match a known aligner.
    """
    name = aligner_cfg.name

    # Implementations are imported lazily so that only the selected aligner's
    # dependencies need to be importable.
    if name == 'none':
        from .none import NoneAligner
        aligner = NoneAligner.from_config(aligner_cfg)
    elif name in ('retinaface_aligner', 'retinaface'):
        from .retinaface_aligner import RetinaFaceAligner
        aligner = RetinaFaceAligner.from_config(aligner_cfg)
    elif name == 'differentiable_face_aligner':
        from .differentiable_face_aligner import DifferentiableFaceAligner
        aligner = DifferentiableFaceAligner.from_config(aligner_cfg)
    else:
        # The message previously said "classifier" (copy-paste from the
        # classifier factory), which pointed users at the wrong config.
        raise ValueError(f"Unknown aligner: {name}")

    if aligner_cfg.start_from:
        aligner.load_state_dict_from_path(aligner_cfg.start_from)

    if aligner_cfg.freeze:
        for param in aligner.parameters():
            param.requires_grad = False
    return aligner
|
| 25 |
+
|
cvlface/research/recognition/code/run_v1/aligners/base/__init__.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Union
|
| 3 |
+
import torch
|
| 4 |
+
from torch import device
|
| 5 |
+
from .utils import get_parameter_device, get_parameter_dtype, save_state_dict_and_config, load_state_dict_from_path
|
| 6 |
+
|
| 7 |
+
class BaseAligner(torch.nn.Module):
    """Common base class for aligner modules.

    Stores the config the aligner was built from and provides checkpoint
    save/load helpers plus small parameter-introspection utilities.
    Subclasses implement ``from_config``, ``make_train_transform``,
    ``make_test_transform`` and ``forward``.
    """

    def __init__(self, config=None):
        super().__init__()
        self.config = config

    @classmethod
    def from_config(cls, config) -> "BaseAligner":
        """Alternate constructor; must be overridden by subclasses."""
        raise NotImplementedError('from_config must be implemented in subclass')

    def make_train_transform(self):
        """Return the training-time input transform; must be overridden."""
        raise NotImplementedError('make_train_transform must be implemented in subclass')

    def make_test_transform(self):
        """Return the test-time input transform; must be overridden."""
        raise NotImplementedError('make_test_transform must be implemented in subclass')

    def forward(self, x):
        """Run the aligner on ``x``; must be overridden."""
        raise NotImplementedError('forward must be implemented in subclass')

    def save_pretrained(
        self,
        save_dir: Union[str, os.PathLike],
        name: str = 'model.pt',
        rank: int = 0,
    ):
        """Save the state dict and config under ``save_dir/name``.

        Only the process with ``rank == 0`` writes; other ranks do nothing.
        """
        save_path = os.path.join(save_dir, name)
        if rank == 0:
            save_state_dict_and_config(self.state_dict(), self.config, save_path)

    def load_state_dict_from_path(self, pretrained_model_path):
        """Load weights from the checkpoint file at ``pretrained_model_path``."""
        state_dict = load_state_dict_from_path(pretrained_model_path)
        self.load_state_dict(state_dict)
        print(f"Loaded pretrained aligner from {pretrained_model_path}")

    @property
    def device(self) -> device:
        """Device reported by ``get_parameter_device`` for this module."""
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype reported by ``get_parameter_dtype`` for this module."""
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False) -> int:
        """Count parameter elements, optionally only those requiring grad."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

    def has_trainable_params(self):
        """Return True if at least one parameter requires gradients."""
        return any(param.requires_grad for param in self.parameters())

    def has_params(self):
        """Return True if the module owns at least one parameter."""
        return any(True for _ in self.parameters())
|
cvlface/research/recognition/code/run_v1/aligners/base/utils.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
import safetensors
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from omegaconf import DictConfig, OmegaConf
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the first parameter or buffer of ``parameter``.

    When the module has neither, fall back to the first plain tensor
    attribute found on the module tree (kept for torch.nn.DataParallel
    compatibility in PyTorch 1.5). Raises ``StopIteration`` if no tensor
    exists at all.
    """
    for tensor in itertools.chain(parameter.parameters(), parameter.buffers()):
        return tensor.device

    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [(key, value) for key, value in vars(module).items() if torch.is_tensor(value)]

    members = parameter._named_members(get_members_fn=find_tensor_attributes)
    _, first_tensor = next(members)
    return first_tensor.device
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the first parameter, buffer or tensor attribute.

    Lookup order: registered parameters, then registered buffers, then plain
    tensor attributes stored on any submodule. Returns ``None`` when the
    module holds no tensor at all.

    The earlier implementation materialised parameters/buffers with
    ``tuple()`` inside a ``try``/``except StopIteration``; ``tuple()`` never
    raises ``StopIteration``, so the tensor-attribute fallback could not run.
    """
    for tensor in itertools.chain(parameter.parameters(), parameter.buffers()):
        return tensor.dtype

    # Fallback: tensors assigned as ordinary attributes live in the module's
    # __dict__ rather than in its parameter/buffer registries.
    for module in parameter.modules():
        for value in vars(module).values():
            if torch.is_tensor(value):
                return value.dtype
    return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path:
    """Return the directory that contains ``save_path`` as a ``Path``."""
    return Path(save_path).parent
|
| 50 |
+
|
| 51 |
+
def get_base_name(save_path: Union[str, os.PathLike]) -> str:
    """Return the final path component of ``save_path`` (file name with extension)."""
    return Path(save_path).name
|
| 54 |
+
|
| 55 |
+
def load_state_dict_from_path(path: Union[str, os.PathLike]):
    """Load a state dict from ``path``.

    Paths containing ``'safetensors'`` are read with
    ``safetensors.torch.load_file``; everything else goes through
    ``torch.load`` mapped to CPU.

    ``path`` may be a ``str`` or any ``os.PathLike``; it is normalised with
    ``os.fspath`` first because the substring test does not work on
    ``pathlib.Path`` objects.
    """
    path_str = os.fspath(path)
    if 'safetensors' in path_str:
        # Import the submodule explicitly: a bare `import safetensors` does
        # not guarantee that `safetensors.torch` has been loaded.
        from safetensors.torch import load_file
        return load_file(path_str)

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from sources you trust.
    return torch.load(path_str, map_location="cpu")
|
| 62 |
+
|
| 63 |
+
def replace_extension(path, new_extension):
    """Return ``path`` with its last extension swapped for ``new_extension``.

    ``new_extension`` may be given with or without the leading dot.
    """
    suffix = new_extension if new_extension.startswith('.') else '.' + new_extension
    stem, _old_suffix = os.path.splitext(path)
    return stem + suffix
|
| 67 |
+
|
| 68 |
+
def make_config_path(save_path):
    """Return the path of the YAML config stored next to ``save_path``."""
    return replace_extension(save_path, '.yaml')
|
| 71 |
+
|
| 72 |
+
def save_config(config, config_path):
    """Write ``config`` to ``config_path`` as YAML via OmegaConf.

    ``config`` must be a plain ``dict`` or an OmegaConf ``DictConfig``;
    the parent directory is created if it does not exist.
    """
    assert isinstance(config, (dict, DictConfig))
    os.makedirs(get_parent_directory(config_path), exist_ok=True)
    # Plain dicts are wrapped so OmegaConf.save receives a config object.
    conf = OmegaConf.create(config) if isinstance(config, dict) else config
    OmegaConf.save(conf, config_path)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def save_state_dict_and_config(state_dict, config, save_path):
    """Save ``state_dict`` to ``save_path`` and ``config`` to a sibling YAML file.

    The config goes to ``save_path`` with its extension replaced by
    ``.yaml``. Weights are written with ``safetensors.torch.save_file`` when
    the path contains ``'safetensors'``, otherwise with ``torch.save``.

    ``save_path`` may be a ``str`` or ``os.PathLike``; it is normalised with
    ``os.fspath`` because the substring test does not work on
    ``pathlib.Path`` objects.
    """
    save_path = os.fspath(save_path)
    os.makedirs(get_parent_directory(save_path), exist_ok=True)

    # save config dict
    save_config(config, make_config_path(save_path))

    # Save the model
    if 'safetensors' in save_path:
        # Import the submodule explicitly: a bare `import safetensors` does
        # not guarantee that `safetensors.torch` has been loaded.
        from safetensors.torch import save_file
        save_file(state_dict, save_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, save_path)
|
cvlface/research/recognition/code/run_v1/aligners/configs/dfa.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: differentiable_face_aligner
|
| 2 |
+
arch: 'mobile0.25'
|
| 3 |
+
start_from: '../../../../pretrained_models/alignment/dfa_mobilenet/mobilenet0.25.pth'
|
| 4 |
+
freeze: True
|
| 5 |
+
|
| 6 |
+
input_padding_ratio: 0 # pad the input by this ratio of its size before resize
|
| 7 |
+
input_padding_val: 'zero'
|
| 8 |
+
input_size: 160 # resize the input to this size
|
| 9 |
+
output_size: 112 # size of the output of aligner
|
| 10 |
+
color_space: 'RGB' # color space of the input image
|
cvlface/research/recognition/code/run_v1/aligners/configs/none.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: none
|
| 2 |
+
start_from: ''
|
| 3 |
+
freeze: False
|
cvlface/research/recognition/code/run_v1/aligners/configs/retinaface.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: retinaface
|
| 2 |
+
start_from: ''
|
| 3 |
+
freeze: True
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/__init__.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..base import BaseAligner
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
from .dfa import get_landmark_predictor, get_preprocessor
|
| 4 |
+
from . import aligner_helper
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DifferentiableFaceAligner(BaseAligner):

    '''
    A differentiable face aligner that aligns the image with one face to a canonical position.
    The aligner is based on the following paper (check out supplementary material for more details):
    @inproceedings{kim2024kprpe,
      title={KeyPoint Relative Position Encoding for Face Recognition},
      author={Kim, Minchul and Su, Yiyang and Liu, Feng and Liu, Xiaoming},
      booktitle={CVPR},
      year={2024}
    }
    '''

    def __init__(self, net, prior_box, preprocessor, config):
        super().__init__()
        self.net = net
        self.prior_box = prior_box
        self.preprocessor = preprocessor
        self.config = config

    @classmethod
    def from_config(cls, config):
        """Build the landmark network, prior boxes and preprocessor from ``config``.

        Reads ``arch``, ``input_size``, ``input_padding_ratio``,
        ``input_padding_val`` and ``freeze``. The returned model is put in
        eval mode.
        """
        net, prior_box = get_landmark_predictor(
            network=config.arch,
            use_aggregator=True,
            input_size=config.input_size,
        )
        preprocessor = get_preprocessor(
            output_size=config.input_size,
            padding=config.input_padding_ratio,
            padding_val=config.input_padding_val,
        )
        if config.freeze:
            for param in net.parameters():
                param.requires_grad = False

        model = cls(net, prior_box, preprocessor, config)
        model.eval()
        return model

    def forward(self, x, padding_ratio_override=None):
        """Align a batch of 3-channel images.

        Returns ``(aligned_x, ldmks, aligned_ldmks, score, thetas,
        normalized_bbox)``. ``ldmks`` and ``normalized_bbox`` are ``None``
        for non-square inputs; when padding is active ``ldmks`` are the
        detached, un-padded landmark predictions.
        """
        # input size check
        assert x.shape[1] == 3
        assert x.ndim == 4
        assert isinstance(x, torch.Tensor)
        is_square = x.shape[2] == x.shape[3]

        x = self.preprocessor(x, padding_ratio_override=padding_ratio_override)
        assert self.prior_box.image_size == x.shape[2:]

        # the network is fed the channel-flipped (BGR) image
        x_bgr = x.flip(1)
        net_out = self.net(x_bgr, self.prior_box)
        orig_pred_ldmks, bbox, cls_logits = aligner_helper.split_network_output(net_out)
        score = torch.nn.Softmax(dim=-1)(cls_logits)[:, 1:]

        # similarity transform from predicted landmarks to the reference layout
        reference_ldmk = aligner_helper.reference_landmark()
        in_side = self.config.input_size
        out_side = self.config.output_size
        cv2_tfms = aligner_helper.get_cv2_affine_from_landmark(orig_pred_ldmks, reference_ldmk, in_side, in_side)
        thetas = aligner_helper.cv2_param_to_torch_theta(cv2_tfms, in_side, in_side, out_side, out_side)
        thetas = thetas.to(orig_pred_ldmks.device)

        grid_shape = torch.Size((len(thetas), 3, out_side, out_side))
        grid = F.affine_grid(thetas, grid_shape, align_corners=True)
        # +1, -1 for making padding pixel 0
        aligned_x = F.grid_sample(x + 1, grid, align_corners=True) - 1

        orig_pred_ldmks = orig_pred_ldmks.view(-1, 5, 2)
        aligned_ldmks = aligner_helper.adjust_ldmks(orig_pred_ldmks, thetas)

        # bbox (xmin, ymin, xmax, ymax)
        bbox_scale = torch.tensor([[x_bgr.size(3), x_bgr.size(2)] * 2]).to(bbox.device)
        normalized_bbox = bbox / bbox_scale

        if padding_ratio_override is None:
            padding_ratio = self.preprocessor.padding
        else:
            padding_ratio = padding_ratio_override

        if padding_ratio > 0:
            # unpad the landmark so that it is in the original image coordinate
            batch_size = orig_pred_ldmks.size(0)
            scale = 1 / (1 + (2 * padding_ratio))
            pad_inv_theta = torch.from_numpy(np.array([[1 / scale, 0, 0], [0, 1 / scale, 0]]))
            pad_inv_theta = pad_inv_theta.unsqueeze(0).float().to(self.device).repeat(batch_size, 1, 1)
            homogeneous = torch.concat(
                [orig_pred_ldmks.view(-1, 5, 2), torch.ones((batch_size, 5, 1)).to(self.device)],
                dim=-1,
            )
            unpadded = ((homogeneous * 2 - 1) @ pad_inv_theta.mT) / 2 + 0.5
            out_ldmks = unpadded.view(batch_size, -1).detach().view(-1, 5, 2)
        else:
            out_ldmks = orig_pred_ldmks

        if not is_square:
            # cannot use these if the input is not square because the preprocessor changes the input
            out_ldmks = None
            normalized_bbox = None
        return aligned_x, out_ldmks, aligned_ldmks, score, thetas, normalized_bbox

    def make_train_transform(self):
        """Training transform: to-tensor, then normalize each channel with mean 0.5 / std 0.5."""
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        return transforms.Compose(steps)

    def make_test_transform(self):
        """Test transform: to-tensor, then normalize each channel with mean 0.5 / std 0.5."""
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        return transforms.Compose(steps)
|
| 117 |
+
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/aligner_helper.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
from skimage import transform as trans
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def split_network_output(align_out):
    """Split the merged prediction of the landmark network.

    ``align_out`` is a 5-tuple whose fourth entry is a tensor with 16
    columns on dim 1: 4 bbox values, 2 class logits and 10 landmark
    coordinates. Returns ``(ldmk, bbox, cls)``.
    """
    _, _, _, merged, _ = align_out
    bbox, cls, ldmk = merged.split([4, 2, 10], dim=1)
    return ldmk, bbox, cls
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_cv2_affine_from_landmark(ldmks, reference_ldmk, image_width, image_height, ):
    """Estimate one 2x3 similarity transform per sample.

    ``ldmks`` is a ``(N, 10)`` tensor of landmark coordinates that are
    multiplied by ``(image_width, image_height)`` before fitting;
    ``reference_ldmk`` is a ``(5, 2)`` numpy array of target positions.
    Returns a numpy array of shape ``(N, 2, 3)``.
    """
    assert ldmks.ndim == 2  # batchdim
    assert ldmks.shape[1] == 10
    assert isinstance(ldmks, torch.Tensor)

    assert reference_ldmk.ndim == 2
    assert reference_ldmk.shape[0] == 5
    assert reference_ldmk.shape[1] == 2
    assert isinstance(reference_ldmk, np.ndarray)

    pixel_scale = np.array([[[image_width, image_height]]])
    points = ldmks.view(ldmks.shape[0], 5, 2).detach().cpu().numpy() * pixel_scale

    def _fit_similarity(src_points):
        # top two rows of the 3x3 homogeneous matrix form the 2x3 affine
        tform = trans.SimilarityTransform()
        tform.estimate(src_points, reference_ldmk)
        return tform.params[0:2, :]

    return np.stack([_fit_similarity(sample) for sample in points], axis=0)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def cv2_param_to_torch_theta(cv2_tfms, image_width, image_height, output_width, output_height):
    """Convert ``(N, 2, 3)`` cv2-style affine matrices to ``affine_grid`` thetas.

    Approach follows https://github.com/wuneng/WarpAffine2GridSample:
    three source points are pushed through each affine, both point sets are
    rescaled to [-1, 1], and the output-to-input affine is fitted from them.
    Returns a float tensor of shape ``(N, 2, 3)``.
    """
    assert cv2_tfms.ndim == 3  # N, 2, 3
    assert cv2_tfms.shape[1] == 2
    assert cv2_tfms.shape[2] == 3

    batch = cv2_tfms.shape[0]
    corners = np.array([[0, 0], [0, 1], [1, 1]], dtype=np.float32)
    srcs = np.expand_dims(corners, axis=0).repeat(batch, axis=0)
    linear_t = cv2_tfms[:, :, :2].transpose(0, 2, 1)
    shift_t = cv2_tfms[:, :, 2:3].transpose(0, 2, 1)
    dsts = np.matmul(srcs, linear_t) + shift_t

    # normalize to [-1, 1]
    srcs = srcs / np.array([[[image_width, image_height]]]) * 2 - 1
    dsts = dsts / np.array([[[output_width, output_height]]]) * 2 - 1

    # grid sampling needs the mapping from output coordinates back to input
    fitted = [
        trans.estimate_transform("affine", src=dst, dst=src).params[:2]
        for src, dst in zip(srcs, dsts)
    ]
    return torch.from_numpy(np.stack(fitted, axis=0)).float()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def adjust_ldmks(ldmks, thetas):
    """Map ``(N, 5, 2)`` landmarks through the inverse of ``thetas``.

    Landmarks are shifted from [0, 1] to [-1, 1], multiplied by the
    inverted ``(N, 2, 3)`` affine, and shifted back to [0, 1].
    """
    inv_thetas = inv_matrix(thetas).to(ldmks.device).float()
    ones = torch.ones((ldmks.shape[0], 5, 1)).to(ldmks.device)
    homogeneous = torch.cat([ldmks, ones], dim=2)
    centered = homogeneous * 2 - 1
    return (centered @ inv_thetas.permute(0, 2, 1)) / 2 + 0.5
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def inv_matrix(theta):
    """Invert a batch of 2x3 affine matrices in closed form.

    ``theta`` has shape ``(N, 2, 3)``, each ``[[a, b, t1], [c, d, t2]]``.
    Returns the ``(N, 2, 3)`` affines that undo them.
    """
    # torch batched version
    assert theta.ndim == 3
    top, bottom = theta[:, 0], theta[:, 1]
    a, b, t1 = top[:, 0], top[:, 1], top[:, 2]
    c, d, t2 = bottom[:, 0], bottom[:, 1], bottom[:, 2]

    inv_det = 1.0 / (a * d - b * c)
    first_row = torch.stack([d * inv_det, -b * inv_det, (b * t2 - d * t1) * inv_det], dim=1)
    second_row = torch.stack([-c * inv_det, a * inv_det, (c * t1 - a * t2) * inv_det], dim=1)
    return torch.stack([first_row, second_row], dim=1)
|
| 80 |
+
|
| 81 |
+
def reference_landmark():
    """Return the five reference landmark positions as a ``(5, 2)`` float array."""
    points = [
        (38.29459953, 51.69630051),
        (73.53179932, 51.50139999),
        (56.02519989, 71.73660278),
        (41.54930115, 92.3655014),
        (70.72990036, 92.20410156),
    ]
    return np.array(points)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def draw_ldmk(img, ldmk):
    """Return a copy of ``img`` with five landmark dots drawn on it.

    ``ldmk`` is indexed as a flat sequence ``x0, y0, ..., x4, y4`` and each
    value is multiplied by the image width/height before drawing. When
    ``ldmk`` is ``None`` the input image is returned untouched.
    """
    if ldmk is None:
        return img
    colors = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)]
    canvas = img.copy()
    for idx in range(5):
        center = (int(ldmk[idx * 2] * canvas.shape[1]), int(ldmk[idx * 2 + 1] * canvas.shape[0]))
        cv2.circle(canvas, center, 1, colors[idx], 4)
    return canvas
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/__init__.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .models.retinaface import RetinaFace
|
| 2 |
+
from .utils.model_utils import load_model
|
| 3 |
+
from .config import cfg_mnet, cfg_re50
|
| 4 |
+
from .layers.functions.prior_box import PriorBox
|
| 5 |
+
from .preprocessor import Preprocessor
|
| 6 |
+
|
| 7 |
+
def get_landmark_predictor(network='mobile0.25', use_aggregator=True, input_size=160):
    """Create the RetinaFace landmark network and its matching prior boxes.

    Args:
        network: backbone name, ``"mobile0.25"`` or ``"resnet50"``.
        use_aggregator: forwarded to ``RetinaFace``.
        input_size: side length of the square network input; used to size
            the prior boxes.

    Returns:
        ``(net, priorbox)``.

    Raises:
        ValueError: if ``network`` is not a supported backbone. Previously an
            unknown name silently passed ``cfg=None`` into ``RetinaFace``.
    """
    if network == "mobile0.25":
        cfg = cfg_mnet
    elif network == "resnet50":
        cfg = cfg_re50
    else:
        raise ValueError(
            f"Unknown network: {network!r} (expected 'mobile0.25' or 'resnet50')"
        )

    net = RetinaFace(cfg=cfg, phase='test', use_aggregator=use_aggregator)
    priorbox = PriorBox(
        image_size=(input_size, input_size),
        min_sizes=[[64, 80], [96, 112], [128, 144]],
        steps=[8, 16, 32],
        clip=False,
        variances=[0.1, 0.2],
    )
    return net, priorbox
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_preprocessor(output_size=160, padding=0.0, padding_val='zero'):
    """Construct a ``Preprocessor`` with the given output size and padding settings."""
    preprocessor = Preprocessor(
        output_size=output_size,
        padding=padding,
        padding_val=padding_val,
    )
    return preprocessor
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/config.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
|
| 3 |
+
# Settings passed as `cfg` to RetinaFace when the "mobile0.25" backbone is selected.
cfg_mnet = {
    'name': 'mobilenet0.25',
    'pretrain': True,
    # NOTE(review): presumably maps backbone stage names to feature-level indices — confirm against the RetinaFace model.
    'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
    'in_channel': 32,
    'out_channel': 64
}

# Settings passed as `cfg` to RetinaFace when the "resnet50" backbone is selected.
cfg_re50 = {
    'name': 'Resnet50',
    'pretrain': True,
    # NOTE(review): presumably maps backbone layer names to feature-level indices — confirm against the RetinaFace model.
    'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
    'in_channel': 256,
    'out_channel': 256
}
|
| 18 |
+
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .functions import *
|
| 2 |
+
from .modules import *
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/functions/prior_box.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from itertools import product as product
|
| 3 |
+
from math import ceil
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PriorBox(object):
|
| 7 |
+
|
| 8 |
+
def __init__(self,
|
| 9 |
+
image_size,
|
| 10 |
+
min_sizes=[[64, 80], [96, 112], [128, 144]],
|
| 11 |
+
steps=[8,16,32],
|
| 12 |
+
clip=False,
|
| 13 |
+
variances=[0.1, 0.2],
|
| 14 |
+
):
|
| 15 |
+
super(PriorBox, self).__init__()
|
| 16 |
+
self.min_sizes = min_sizes
|
| 17 |
+
self.steps = steps
|
| 18 |
+
self.clip = clip
|
| 19 |
+
self.variances = variances
|
| 20 |
+
self.image_size = image_size
|
| 21 |
+
self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
|
| 22 |
+
with torch.no_grad():
|
| 23 |
+
self.priors = self.forward()
|
| 24 |
+
|
| 25 |
+
def forward(self):
|
| 26 |
+
anchors = []
|
| 27 |
+
for k, f in enumerate(self.feature_maps):
|
| 28 |
+
min_sizes = self.min_sizes[k]
|
| 29 |
+
for i, j in product(range(f[0]), range(f[1])):
|
| 30 |
+
for min_size in min_sizes:
|
| 31 |
+
s_kx = min_size / self.image_size[1]
|
| 32 |
+
s_ky = min_size / self.image_size[0]
|
| 33 |
+
dense_cx = [x * self.steps[k] / self.image_size[1] for x in [j + 0.5]]
|
| 34 |
+
dense_cy = [y * self.steps[k] / self.image_size[0] for y in [i + 0.5]]
|
| 35 |
+
for cy, cx in product(dense_cy, dense_cx):
|
| 36 |
+
anchors += [cx, cy, s_kx, s_ky]
|
| 37 |
+
|
| 38 |
+
# back to torch land
|
| 39 |
+
output = torch.Tensor(anchors).view(-1, 4)
|
| 40 |
+
# import pandas as pd
|
| 41 |
+
# pd.DataFrame(output.numpy()).to_csv('/mckim/temp/temp.csv')
|
| 42 |
+
if self.clip:
|
| 43 |
+
output.clamp_(max=1, min=0)
|
| 44 |
+
return output
|
| 45 |
+
|
| 46 |
+
def encode(self, matched):
|
| 47 |
+
"""Encode the variances from the priorbox layers into the ground truth boxes
|
| 48 |
+
we have matched (based on jaccard overlap) with the prior boxes.
|
| 49 |
+
"""
|
| 50 |
+
self.priors = self.priors.to(matched.device)
|
| 51 |
+
|
| 52 |
+
# dist b/t match center and prior's center
|
| 53 |
+
g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - self.priors[:, :2]
|
| 54 |
+
# encode variance
|
| 55 |
+
g_cxcy /= (self.variances[0] * self.priors[:, 2:])
|
| 56 |
+
# match wh / prior wh
|
| 57 |
+
g_wh = (matched[:, 2:] - matched[:, :2]) / self.priors[:, 2:]
|
| 58 |
+
g_wh = torch.log(g_wh) / self.variances[1]
|
| 59 |
+
# return target for smooth_l1_loss
|
| 60 |
+
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
|
| 61 |
+
|
| 62 |
+
def encode_landm(self, matched):
    """Encode matched landmark coordinates as regression targets.

    Args:
        matched: tensor with 10 values per row, reshaped here into five
            ``(x, y)`` points.

    Returns:
        Tensor with 10 values per row: each point's offset from the prior
        center divided by ``variances[0] * (prior_w, prior_h)``.

    Side effect: ``self.priors`` is moved to ``matched``'s device.
    """
    self.priors = self.priors.to(matched.device)

    count = matched.size(0)
    points = torch.reshape(matched, (count, 5, 2))

    # Repeat every prior once per landmark point: (count, 5, 4).
    tiled = self.priors.unsqueeze(1).expand(count, 5, 4)

    offsets = points - tiled[:, :, :2]
    offsets = offsets / (self.variances[0] * tiled[:, :, 2:])

    # Flatten the five points back into 10 values per row.
    return offsets.reshape(count, -1)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# Adapted from https://github.com/Hakuyume/chainer-ssd
|
| 85 |
+
def decode(self, loc):
    """Turn predicted box offsets back into corner-form boxes.

    Inverse of ``encode``: offsets are rescaled by the variances and the
    prior sizes, then converted from ``(cx, cy, w, h)`` to
    ``(xmin, ymin, xmax, ymax)``.

    Args:
        loc: tensor ``[num_priors, 4]`` of predicted offsets.

    Returns:
        Tensor ``[num_priors, 4]`` of decoded boxes.

    Side effect: ``self.priors`` is moved to ``loc``'s device.
    """
    self.priors = self.priors.to(loc.device)
    prior_center = self.priors[:, :2]
    prior_size = self.priors[:, 2:]

    center = prior_center + loc[:, :2] * self.variances[0] * prior_size
    size = prior_size * torch.exp(loc[:, 2:] * self.variances[1])

    # Center/size -> corners.
    top_left = center - size / 2
    bottom_right = top_left + size
    return torch.cat((top_left, bottom_right), 1)
|
| 97 |
+
|
| 98 |
+
def decode_landm(self, pre):
    """Turn predicted landmark offsets back into coordinates.

    Args:
        pre: tensor ``[num_priors, 10]`` holding five ``(dx, dy)`` pairs.

    Returns:
        Tensor ``[num_priors, 10]``: each pair becomes
        ``prior_center + offset * variances[0] * prior_size``.

    Side effect: ``self.priors`` is moved to ``pre``'s device.
    """
    self.priors = self.priors.to(pre.device)
    prior_center = self.priors[:, :2]
    prior_size = self.priors[:, 2:]

    # The five landmark pairs live at columns 0:2, 2:4, ..., 8:10.
    points = [
        prior_center + pre[:, start:start + 2] * self.variances[0] * prior_size
        for start in range(0, 10, 2)
    ]
    return torch.cat(points, dim=1)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def decode_batch(self, loc):
    """Batched variant of ``decode``.

    Args:
        loc: tensor ``[batch, num_priors, 4]`` of predicted offsets
            (must be 3-dimensional).

    Returns:
        Tensor ``[batch, num_priors, 4]`` of corner-form boxes.

    Side effect: ``self.priors`` is moved to ``loc``'s device.
    """
    self.priors = self.priors.to(loc.device)
    assert loc.ndim == 3

    # Share the same priors across the batch without copying.
    tiled = self.priors.unsqueeze(0).expand(loc.size(0), -1, -1)
    prior_center = tiled[:, :, :2]
    prior_size = tiled[:, :, 2:]

    center = prior_center + loc[:, :, :2] * self.variances[0] * prior_size
    size = prior_size * torch.exp(loc[:, :, 2:] * self.variances[1])

    top_left = center - size / 2
    bottom_right = top_left + size
    return torch.cat((top_left, bottom_right), -1)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def decode_landm_batch(self, prediction):
    """Batched variant of ``decode_landm``.

    Args:
        prediction: tensor ``[batch, num_priors, 10]`` holding five
            ``(dx, dy)`` pairs per prior (must be 3-dimensional).

    Returns:
        Tensor ``[batch, num_priors, 10]`` of decoded landmark coordinates.

    Side effect: ``self.priors`` is moved to ``prediction``'s device.
    """
    assert prediction.ndim == 3
    self.priors = self.priors.to(prediction.device)

    tiled = self.priors.unsqueeze(0).expand(prediction.size(0), -1, -1)
    prior_center = tiled[:, :, :2]
    prior_size = tiled[:, :, 2:]

    # The five landmark pairs live at columns 0:2, 2:4, ..., 8:10.
    points = [
        prior_center + prediction[:, :, start:start + 2] * self.variances[0] * prior_size
        for start in range(0, 10, 2)
    ]
    return torch.cat(points, dim=-1)
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/modules/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .multibox_loss import MultiBoxLoss
|
| 2 |
+
|
| 3 |
+
__all__ = ['MultiBoxLoss']
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/layers/modules/multibox_loss.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from ...utils.box_utils import match, log_sum_exp
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class MultiBoxLoss(nn.Module):
    """SSD-style multibox loss extended with a facial-landmark term.

    Compute targets:
        1) Produce confidence target indices by matching ground-truth boxes
           with the prior boxes (done by ``match`` using ``overlap_thresh``).
        2) Produce localization / landmark targets by encoding the matched
           ground truth relative to the priors (also done by ``match``).
        3) Hard negative mining keeps at most ``neg_pos`` negatives per
           positive so the many background priors do not dominate.

    Objective:
        L(x, c, l, g) = (Lconf(x, c) + αLloc(x, l, g)) / N
        where Lconf is cross entropy, Lloc is smooth L1 and N is the number
        of matched priors. The terms are returned separately; weighting
        them is left to the caller.

    See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
        """Store the loss configuration.

        ``forward`` reads ``num_classes``, ``overlap_thresh`` (as
        ``self.threshold``) and ``neg_pos`` (as ``self.negpos_ratio``); the
        remaining arguments are stored but not read by this class.
        """
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap

    def forward(self, predictions, priorbox, targets):
        """Compute the box, classification and landmark losses.

        Args:
            predictions (tuple): ``(loc, conf, landm, aggs, thetas)``.
                loc shape:   (batch_size, num_priors, 4)
                conf shape:  (batch_size, num_priors, num_classes)
                landm shape: (batch_size, num_priors, 10)
                aggs: ``None`` or a per-image aggregate whose columns are
                    read as ``[:4]`` box, ``[4:6]`` class logits and
                    ``[6:]`` landmarks.
                thetas: unpacked but not used here.
            priorbox: object exposing ``priors`` (num_priors, 4); it is also
                handed to ``match``.
            targets: per-image ground truth, indexable by image. Columns
                ``[:4]`` are the box, ``[4:14]`` the landmarks and the last
                column the label. A label ``> 0`` contributes to the
                landmark loss; any non-zero label contributes to the box
                and classification losses.
                NOTE(review): presumably negative labels mark faces without
                landmark annotations; confirm against the dataset.

        Returns:
            ``(loss_l, loss_c, loss_landm, aux_loss_dict)`` where
            ``aux_loss_dict`` is ``None`` when ``aggs`` is ``None``.
        """
        loc_data, conf_data, landm_data, aggs, thetas = predictions
        # Follow the predictions' device instead of assuming CUDA, so the
        # loss also runs on CPU or on a non-default accelerator.
        device = loc_data.device
        num = loc_data.size(0)
        num_priors = priorbox.priors.size(0)

        if aggs is not None:
            # Aggregator branch: one target row per image.
            stacked_target = torch.stack(targets, dim=0).squeeze(1)
            labels_per_image = stacked_target[:, -1]

            has_landm = labels_per_image > 0
            agg_ldmk = aggs[:, 6:][has_landm]
            tgt_ldmk = stacked_target[:, 4:14][has_landm]
            # max(..., 1) avoids 0/0 = NaN when no image qualifies.
            agg_loss_landm = F.smooth_l1_loss(agg_ldmk, tgt_ldmk, reduction='sum') / max(len(tgt_ldmk), 1)

            has_box = labels_per_image != 0
            agg_bbox = aggs[:, :4][has_box]
            tgt_bbox = stacked_target[:, :4][has_box]
            agg_loss_box = F.smooth_l1_loss(agg_bbox, tgt_bbox, reduction='sum') / max(len(tgt_bbox), 1)

            agg_cls = aggs[:, 4:6]
            tgt_cls = (labels_per_image > 0).long()
            agg_loss_cls = F.cross_entropy(agg_cls, tgt_cls, reduction='sum') / max(len(tgt_cls), 1)
            aux_loss_dict = {
                'agg_loss_landm': agg_loss_landm,
                'agg_loss_box': agg_loss_box,
                'agg_loss_cls': agg_loss_cls
            }
        else:
            aux_loss_dict = None

        # Match priors (default boxes) and ground truth boxes; ``match``
        # fills row ``idx`` of the three target buffers in place.
        loc_t = torch.Tensor(num, num_priors, 4)
        landm_t = torch.Tensor(num, num_priors, 10)
        conf_t = torch.LongTensor(num, num_priors)
        for idx in range(num):
            truths = targets[idx][:, :4].data
            labels = targets[idx][:, -1].data
            landms = targets[idx][:, 4:14].data
            match(self.threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx)

        loc_t = loc_t.to(device)
        conf_t = conf_t.to(device)
        landm_t = landm_t.to(device)

        # Landmark loss (smooth L1) over priors with a strictly positive label.
        # Shape: [batch, num_priors, 10]
        pos1 = conf_t > 0
        num_pos_landm = pos1.long().sum(1, keepdim=True)
        N1 = max(num_pos_landm.data.sum().float(), 1)
        pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
        landm_p = landm_data[pos_idx1].view(-1, 10)
        landm_t = landm_t[pos_idx1].view(-1, 10)
        loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

        # Every non-zero label counts as a positive for box/class losses.
        pos = conf_t != 0
        conf_t[pos] = 1

        # Localization loss (smooth L1).
        # Shape: [batch, num_priors, 4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Per-prior classification loss, used only to rank negatives.
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard negative mining: zero out positives, then keep the
        # highest-loss negatives up to negpos_ratio per positive.
        loss_c[pos.view(-1, 1)] = 0
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence loss over the positives plus the mined negatives.
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[pos_idx | neg_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[pos | neg]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Normalise: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        loss_landm /= N1

        return loss_l, loss_c, loss_landm, aux_loss_dict
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/__init__.py
ADDED
|
File without changes
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/net.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torchvision.models._utils as _utils
|
| 5 |
+
import torchvision.models as models
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.autograd import Variable
|
| 8 |
+
|
| 9 |
+
def conv_bn(inp, oup, stride=1, leaky=0):
    """Build a 3x3 conv -> batch norm -> LeakyReLU stack.

    Args:
        inp: input channel count.
        oup: output channel count.
        stride: convolution stride.
        leaky: negative slope of the LeakyReLU (0 behaves like ReLU).

    Returns:
        ``nn.Sequential`` with the three layers in that order.
    """
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.LeakyReLU(negative_slope=leaky, inplace=True)
    return nn.Sequential(conv, norm, act)
|
| 15 |
+
|
| 16 |
+
def conv_bn_no_relu(inp, oup, stride):
    """Build a 3x3 conv -> batch norm stack with no activation.

    Args:
        inp: input channel count.
        oup: output channel count.
        stride: convolution stride.

    Returns:
        ``nn.Sequential`` with the two layers in that order.
    """
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    return nn.Sequential(conv, norm)
|
| 21 |
+
|
| 22 |
+
def conv_bn1X1(inp, oup, stride, leaky=0):
    """Build a 1x1 conv -> batch norm -> LeakyReLU stack.

    Args:
        inp: input channel count.
        oup: output channel count.
        stride: convolution stride.
        leaky: negative slope of the LeakyReLU (0 behaves like ReLU).

    Returns:
        ``nn.Sequential`` with the three layers in that order.
    """
    pointwise = nn.Conv2d(inp, oup, kernel_size=1, stride=stride, padding=0, bias=False)
    norm = nn.BatchNorm2d(oup)
    act = nn.LeakyReLU(negative_slope=leaky, inplace=True)
    return nn.Sequential(pointwise, norm, act)
|
| 28 |
+
|
| 29 |
+
def conv_dw(inp, oup, stride, leaky=0.1):
    """Build a depthwise-separable convolution block.

    A 3x3 depthwise conv (``groups=inp``) is followed by a 1x1 pointwise
    conv; each is followed by batch norm and LeakyReLU.

    Args:
        inp: input channel count.
        oup: output channel count.
        stride: stride of the depthwise conv (the pointwise conv uses 1).
        leaky: negative slope shared by both LeakyReLU layers.

    Returns:
        ``nn.Sequential`` of six layers.
    """
    depthwise = [
        nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)
|
| 39 |
+
|
| 40 |
+
class SSH(nn.Module):
    """Context module with three parallel branches of stacked 3x3 convs.

    The branches contribute ``out_channel // 2``, ``out_channel // 4`` and
    ``out_channel // 4`` channels; they are concatenated and passed
    through ReLU. The second and third branch share their first conv.
    """

    def __init__(self, in_channel, out_channel):
        """
        :param in_channel: channels of the incoming feature map.
        :param out_channel: channels of the result; must be divisible by 4.
        """
        super().__init__()
        assert out_channel % 4 == 0
        # Narrow modules use a leaky activation, wide ones a plain ReLU.
        leaky = 0.1 if out_channel <= 64 else 0
        half = out_channel // 2
        quarter = out_channel // 4

        self.conv3X3 = conv_bn_no_relu(in_channel, half, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, quarter, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(quarter, quarter, stride=1)

        self.conv7X7_2 = conv_bn(quarter, quarter, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(quarter, quarter, stride=1)

    def forward(self, input):
        """Run the three branches and return their ReLU'd concatenation."""
        branch3 = self.conv3X3(input)

        # Shared stem of the two deeper branches.
        stem = self.conv5X5_1(input)
        branch5 = self.conv5X5_2(stem)
        branch7 = self.conv7x7_3(self.conv7X7_2(stem))

        merged = torch.cat([branch3, branch5, branch7], dim=1)
        return F.relu(merged)
|
| 67 |
+
|
| 68 |
+
class FPN(nn.Module):
    """Three-level feature pyramid with a top-down pathway.

    Each input level is projected to ``out_channels`` by a 1x1 block; the
    coarser level is then upsampled (nearest), added to the next finer one
    and smoothed by a 3x3 block.
    """

    def __init__(self,in_channels_list,out_channels):
        """
        :param in_channels_list: channel counts of the three input levels,
            finest first (indices 0, 1, 2 are read).
        :param out_channels: channel count of every output level.
        """
        super().__init__()
        # Narrow pyramids use a leaky activation, wide ones a plain ReLU.
        leaky = 0.1 if out_channels <= 64 else 0

        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)

        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        """Fuse the feature maps.

        :param input: mapping whose values (in iteration order) are the
            three feature maps, finest first.
        :return: list ``[level1, level2, level3]`` of fused maps.
        """
        feats = list(input.values())

        lateral1 = self.output1(feats[0])
        lateral2 = self.output2(feats[1])
        lateral3 = self.output3(feats[2])

        # Coarsest -> middle.
        target_mid = [lateral2.size(2), lateral2.size(3)]
        level2 = self.merge2(lateral2 + F.interpolate(lateral3, size=target_mid, mode="nearest"))

        # Middle -> finest.
        target_fine = [lateral1.size(2), lateral1.size(3)]
        level1 = self.merge1(lateral1 + F.interpolate(level2, size=target_fine, mode="nearest"))

        return [level1, level2, lateral3]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MobileNetV1(nn.Module):
    """Slim MobileNetV1-style backbone split into three stages.

    Going by the stride arguments below, ``stage1`` downsamples by 8 in
    total and ``stage2`` / ``stage3`` by a further factor of 2 each; the
    stage outputs have 64, 128 and 256 channels.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels, stride) of each depthwise block.
        stage1_blocks = [(8, 16, 1), (16, 32, 2), (32, 32, 1), (32, 64, 2), (64, 64, 1)]
        stage2_blocks = [(64, 128, 2)] + [(128, 128, 1)] * 5
        stage3_blocks = [(128, 256, 2), (256, 256, 1)]

        # stage1 begins with a regular strided 3x3 conv on the RGB input.
        stem = conv_bn(3, 8, 2, leaky=0.1)
        self.stage1 = nn.Sequential(
            stem, *[conv_dw(cin, cout, stride) for cin, cout, stride in stage1_blocks]
        )
        self.stage2 = nn.Sequential(
            *[conv_dw(cin, cout, stride) for cin, cout, stride in stage2_blocks]
        )
        self.stage3 = nn.Sequential(
            *[conv_dw(cin, cout, stride) for cin, cout, stride in stage3_blocks]
        )

    def forward(self, x):
        """Apply the three stages in order and return the final feature map."""
        for stage in (self.stage1, self.stage2, self.stage3):
            x = stage(x)
        return x
|
| 131 |
+
|
| 132 |
+
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/models/retinaface.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torchvision.models._utils as _utils
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from .net import MobileNetV1 as MobileNetV1
|
| 6 |
+
from .net import FPN as FPN
|
| 7 |
+
from .net import SSH as SSH
|
| 8 |
+
|
| 9 |
+
from timm.models import mlp_mixer
|
| 10 |
+
|
| 11 |
+
class ClassHead(nn.Module):
    """1x1-conv head producing two class scores per anchor."""

    def __init__(self,inchannels=512,num_anchors=3):
        """
        :param inchannels: channels of the incoming feature map.
        :param num_anchors: anchors per spatial location.
        """
        super().__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(
            inchannels, self.num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0
        )

    def forward(self,x):
        """Map ``(B, C, H, W)`` features to ``(B, H * W * num_anchors, 2)`` scores."""
        # Channels-last so that each anchor's two scores are contiguous.
        scores = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        batch = scores.shape[0]
        return scores.view(batch, -1, 2)
|
| 22 |
+
|
| 23 |
+
class BboxHead(nn.Module):
    """1x1-conv head producing four box-regression values per anchor."""

    def __init__(self,inchannels=512,num_anchors=3):
        """
        :param inchannels: channels of the incoming feature map.
        :param num_anchors: anchors per spatial location.
        """
        super().__init__()
        self.conv1x1 = nn.Conv2d(
            inchannels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0
        )

    def forward(self,x):
        """Map ``(B, C, H, W)`` features to ``(B, H * W * num_anchors, 4)`` offsets."""
        # Channels-last so that each anchor's four values are contiguous.
        offsets = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        batch = offsets.shape[0]
        return offsets.view(batch, -1, 4)
|
| 33 |
+
|
| 34 |
+
class LandmarkHead(nn.Module):
    """1x1-conv head producing ten landmark-regression values per anchor."""

    def __init__(self,inchannels=512,num_anchors=3):
        """
        :param inchannels: channels of the incoming feature map.
        :param num_anchors: anchors per spatial location.
        """
        super().__init__()
        self.conv1x1 = nn.Conv2d(
            inchannels, num_anchors * 10, kernel_size=(1, 1), stride=1, padding=0
        )

    def forward(self,x):
        """Map ``(B, C, H, W)`` features to ``(B, H * W * num_anchors, 10)`` offsets."""
        # Channels-last so that each anchor's ten values are contiguous.
        offsets = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        batch = offsets.shape[0]
        return offsets.view(batch, -1, 10)
|
| 44 |
+
|
| 45 |
+
class RetinaFace(nn.Module):
    """RetinaFace-style detector: backbone -> FPN -> SSH -> prediction heads.

    Optionally adds an "aggregator" that learns one weight per prior and
    returns the weighted sum of all decoded predictions per image.
    """

    def __init__(self, cfg = None, phase = 'train', use_aggregator=False):
        """
        :param cfg: Network related settings (required). Keys read here:
            ``name`` ('mobilenet0.25' or 'Resnet50'), ``pretrain``
            (Resnet50 only), ``return_layers``, ``in_channel`` and
            ``out_channel``.
        :param phase: train or test. Any value other than 'train' makes
            ``forward`` return softmaxed class scores.
        :param use_aggregator: build and use the prior aggregator.
        :raises ValueError: if ``cfg`` is missing or names an unsupported
            backbone.
        """
        super(RetinaFace,self).__init__()
        if cfg is None:
            raise ValueError('RetinaFace requires a cfg dict; got None')
        self.phase = phase

        if cfg['name'] == 'mobilenet0.25':
            # No pretrained weights are loaded for this backbone here.
            backbone = MobileNetV1()
        elif cfg['name'] == 'Resnet50':
            import torchvision.models as models
            backbone = models.resnet50(pretrained=cfg['pretrain'])
        else:
            # Fail early instead of passing None on to IntermediateLayerGetter.
            raise ValueError(
                "Unsupported backbone {!r}; expected 'mobilenet0.25' or 'Resnet50'".format(cfg['name'])
            )

        self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
        in_channels_stage2 = cfg['in_channel']
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]
        out_channels = cfg['out_channel']
        self.fpn = FPN(in_channels_list,out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])

        self.use_aggregator = use_aggregator
        if self.use_aggregator:
            # 16 = 4 box + 2 class + 10 landmark values per prior.
            # NOTE(review): 1050 is hard-coded; presumably it is the number
            # of priors for the input size this model is trained with -
            # confirm before changing the input resolution.
            modules = [mlp_mixer.MixerBlock(16, 1050) for _ in range(3)]
            modules.append(nn.Linear(16, 1))
            self.aggregator = nn.Sequential(*modules)
        else:
            self.aggregator = None

    def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
        """Return a ModuleList with one ClassHead per pyramid level."""
        return nn.ModuleList(ClassHead(inchannels, anchor_num) for _ in range(fpn_num))

    def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
        """Return a ModuleList with one BboxHead per pyramid level."""
        return nn.ModuleList(BboxHead(inchannels, anchor_num) for _ in range(fpn_num))

    def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
        """Return a ModuleList with one LandmarkHead per pyramid level."""
        return nn.ModuleList(LandmarkHead(inchannels, anchor_num) for _ in range(fpn_num))

    def forward(self, inputs, priorbox):
        """Run detection.

        :param inputs: image batch fed to the backbone.
        :param priorbox: object providing ``decode_batch`` and
            ``decode_landm_batch``; only used when the aggregator is on.
        :return: ``(bbox_regressions, classifications, ldm_regressions,
            agg, theta)``. ``classifications`` are raw logits in the
            'train' phase and softmax probabilities otherwise; ``agg`` is
            ``None`` without the aggregator; ``theta`` is always ``None``.
        """
        out = self.body(inputs)

        # FPN
        fpn = self.fpn(out)

        # SSH
        features = [self.ssh1(fpn[0]), self.ssh2(fpn[1]), self.ssh3(fpn[2])]

        # Per-level predictions are concatenated along the prior axis.
        bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1)
        classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)], dim=1)
        ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1)

        theta = None
        if self.use_aggregator:
            decoded_bbox = priorbox.decode_batch(bbox_regressions)
            decoded_ldmk = priorbox.decode_landm_batch(ldm_regressions)
            # The aggregator sees raw class logits, not probabilities.
            combined = torch.cat([decoded_bbox, classifications, decoded_ldmk], dim=2)
            # One weight per prior, normalised over the prior axis.
            weight = F.softmax(self.aggregator(combined), dim=1)
            agg = torch.sum(weight * combined, dim=1)
        else:
            agg = None

        if self.phase == 'train':
            output = (bbox_regressions, classifications, ldm_regressions, agg, theta)
        else:
            output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions, agg, theta)
        return output
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/preprocessor.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
class Preprocessor():
    """Pad images to a square, add an optional border and resize them.

    Accepts ``torch.float32`` or ``torch.uint8`` tensors shaped
    ``(3, H, W)`` or ``(B, 3, H, W)`` and returns the same dtype and rank
    with spatial size ``(output_size, output_size)``.
    """

    def __init__(self, output_size=160, padding=0.0, padding_val='zero'):
        """
        :param output_size: side length of the square output.
        :param padding: extra border, as a fraction of the squared image
            side, added on every edge before resizing.
        :param padding_val: ``'zero'`` or ``'mean'``. ``'zero'`` fills with
            -1.0 for float32 input and 0 for uint8 input; ``'mean'`` fills
            with the mean of the whole input.
            NOTE(review): the -1.0 fill presumably assumes float images
            normalised to [-1, 1]; confirm against callers.
        """
        self.output_size = output_size
        self.padding = padding
        self.padding_val = padding_val

    def _resolve_padding_value(self, imgs):
        """Return the fill value for ``imgs`` or raise ValueError."""
        is_float = imgs.dtype == torch.float32
        if not is_float and imgs.dtype != torch.uint8:
            raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')
        if self.padding_val == 'zero':
            return -1.0 if is_float else 0
        if self.padding_val == 'mean':
            if is_float:
                return imgs.mean()
            # mean() is not defined for integer tensors, so average in
            # float and round back to a valid uint8 fill value.
            return int(imgs.float().mean().round().item())
        raise ValueError('padding_val must be "zero" or "mean"')

    def preprocess_batched(self, imgs, padding_ratio_override=None):
        """Square, pad and resize a ``(B, C, H, W)`` batch.

        :param imgs: float32 or uint8 image batch.
        :param padding_ratio_override: when not ``None``, used instead of
            ``self.padding`` for this call.
        :return: batch of shape ``(B, C, output_size, output_size)`` with
            the input dtype.
        :raises ValueError: on an unsupported dtype or padding mode.
        """
        padding_val = self._resolve_padding_value(imgs)

        square_imgs = self.make_square_img_batched(imgs, padding_val=padding_val)

        padding = self.padding if padding_ratio_override is None else padding_ratio_override
        padded_imgs = self.make_padded_img_batched(square_imgs, padding=padding, padding_val=padding_val)

        size = (self.output_size, self.output_size)
        if imgs.dtype == torch.float32:
            return F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True)

        # uint8: interpolate in float, then clamp and cast back.
        resized_imgs = F.interpolate(padded_imgs.to(torch.float32), size=size, mode='bilinear', align_corners=True)
        resized_imgs = torch.clip(resized_imgs, 0, 255)
        return resized_imgs.to(torch.uint8)

    def make_square_img_batched(self, imgs, padding_val):
        """Pad the shorter spatial side of ``imgs`` so that H == W.

        The padding is split between both sides; when the difference is
        odd, the extra pixel goes to the right / bottom.
        """
        assert imgs.ndim == 4
        h, w = imgs.shape[2:]
        if h > w:
            diff = h - w
            pad_left = diff // 2
            pad_right = diff - pad_left
            # F.pad order is (left, right, top, bottom).
            imgs = F.pad(imgs, (pad_left, pad_right, 0, 0), value=padding_val)
        elif w > h:
            diff = w - h
            pad_top = diff // 2
            pad_bottom = diff - pad_top
            imgs = F.pad(imgs, (0, 0, pad_top, pad_bottom), value=padding_val)
        assert imgs.shape[2] == imgs.shape[3]
        return imgs

    def make_padded_img_batched(self, imgs, padding, padding_val):
        """Add a border of ``int(side * padding)`` pixels on every edge.

        Returns ``imgs`` unchanged when ``padding`` is 0.
        """
        if padding == 0:
            return imgs
        assert imgs.ndim == 4

        h, w = imgs.shape[2:]
        pad_h = int(h * padding)
        pad_w = int(w * padding)
        return F.pad(imgs, (pad_w, pad_w, pad_h, pad_h), value=padding_val)

    def __call__(self, input, padding_ratio_override=None):
        """Preprocess one ``(3, H, W)`` image or a ``(B, 3, H, W)`` batch.

        :raises ValueError: if ``input`` is neither 3- nor 4-dimensional.
        """
        if input.ndim == 3:
            assert input.shape[0] == 3
            batch_input = input.unsqueeze(0)
            return self.preprocess_batched(batch_input, padding_ratio_override=padding_ratio_override)[0]
        elif input.ndim == 4:
            assert input.shape[1] == 3
            return self.preprocess_batched(input, padding_ratio_override=padding_ratio_override)
        else:
            raise ValueError(f'Invalid input shape: {input.shape}')
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/utils/box_utils.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def point_form(boxes):
    """Convert center-size boxes to corner form.

    Args:
        boxes: (tensor) rows of (cx, cy, w, h).
    Return:
        (tensor) rows of (xmin, ymin, xmax, ymax).
    """
    centers = boxes[:, :2]
    half_sizes = boxes[:, 2:] / 2
    top_left = centers - half_sizes
    bottom_right = centers + half_sizes
    return torch.cat((top_left, bottom_right), 1)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def center_size(boxes):
    """Convert corner (point) form boxes to center-size form.

    Args:
        boxes: (tensor) boxes given as (xmin, ymin, xmax, ymax), one per row.
    Return:
        (tensor) the same boxes as (cx, cy, w, h).
    """
    # torch.cat takes a *sequence* of tensors followed by the dim; the
    # previous version passed the two tensors as separate positional
    # arguments, which fails at call time.
    return torch.cat(((boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy
                      boxes[:, 2:] - boxes[:, :2]), 1)    # w, h
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def intersect(box_a, box_b):
    """Compute the pairwise intersection area of two sets of corner-form boxes.

    Both corner tensors are expanded (views, no copy) to [A,B,2] so that every
    box of ``box_a`` is compared against every box of ``box_b``.

    Args:
        box_a: (tensor) bounding boxes, Shape: [A,4].
        box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
        (tensor) intersection area, Shape: [A,B].
    """
    n_a, n_b = box_a.size(0), box_b.size(0)
    a_upper = box_a[:, 2:].unsqueeze(1).expand(n_a, n_b, 2)
    b_upper = box_b[:, 2:].unsqueeze(0).expand(n_a, n_b, 2)
    a_lower = box_a[:, :2].unsqueeze(1).expand(n_a, n_b, 2)
    b_lower = box_b[:, :2].unsqueeze(0).expand(n_a, n_b, 2)
    upper = torch.min(a_upper, b_upper)
    lower = torch.max(a_lower, b_lower)
    # Disjoint boxes give a negative extent; clamp it to zero area.
    extent = torch.clamp((upper - lower), min=0)
    return extent[:, :, 0] * extent[:, :, 1]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def jaccard(box_a, box_b):
    """Compute the pairwise jaccard overlap (IoU) of two sets of boxes.

    IoU = A ∩ B / (area(A) + area(B) - A ∩ B)

    Args:
        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    overlap = intersect(box_a, box_b)
    width_a = box_a[:, 2] - box_a[:, 0]
    height_a = box_a[:, 3] - box_a[:, 1]
    width_b = box_b[:, 2] - box_b[:, 0]
    height_b = box_b[:, 3] - box_b[:, 1]
    area_a = (width_a * height_a).unsqueeze(1).expand_as(overlap)  # [A,B]
    area_b = (width_b * height_b).unsqueeze(0).expand_as(overlap)  # [A,B]
    return overlap / (area_a + area_b - overlap)  # [A,B]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def matrix_iou(a, b):
    """
    Return the pairwise IoU of corner-form boxes a and b (numpy version,
    used for data augmentation).
    """
    top_left = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    bottom_right = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    # Zero out pairs that do not overlap along every axis.
    overlaps = (top_left < bottom_right).all(axis=2)
    inter_area = np.prod(bottom_right - top_left, axis=2) * overlaps
    a_area = np.prod(a[:, 2:] - a[:, :2], axis=1)
    b_area = np.prod(b[:, 2:] - b[:, :2], axis=1)
    union_area = a_area[:, np.newaxis] + b_area - inter_area
    return inter_area / union_area
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def matrix_iof(a, b):
    """
    Return the pairwise IoF of a and b: intersection area divided by the area
    of the box from ``a`` (numpy version, used for data augmentation).
    """
    top_left = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    bottom_right = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    # Zero out pairs that do not overlap along every axis.
    overlaps = (top_left < bottom_right).all(axis=2)
    inter_area = np.prod(bottom_right - top_left, axis=2) * overlaps
    a_area = np.prod(a[:, 2:] - a[:, :2], axis=1)
    # The denominator is floored at 1.
    denominator = np.maximum(a_area[:, np.newaxis], 1)
    return inter_area / denominator
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def match(threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx):
    """Match each prior box with the ground truth box of the highest jaccard
    overlap, encode the bounding boxes and landmarks, and write the targets
    for batch element ``idx`` into the provided output tensors.

    Args:
        threshold: (float) The overlap threshold used when matching boxes.
        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
        priorbox: object exposing ``priors`` (center-size prior boxes, converted
            here with ``point_form``) and ``encode`` / ``encode_landm``.
        labels: (tensor) All the class labels for the image, Shape: [num_obj].
        landms: (tensor) Ground truth landms, Shape [num_obj, 10].
        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
        landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
        idx: (int) current batch index
    Return:
        None. Results are written in place into ``loc_t[idx]``, ``conf_t[idx]``
        and ``landm_t[idx]``. When no ground truth reaches an overlap of 0.2
        with any prior, ``loc_t[idx]`` and ``conf_t[idx]`` are zeroed and
        ``landm_t[idx]`` is left untouched.
    """


    # jaccard index
    overlaps = jaccard(
        truths,
        point_form(priorbox.priors)
    )
    # (Bipartite Matching)
    # [num_objects,1] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)

    # ignore hard gt: ground truths whose best prior overlaps less than 0.2
    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
    if best_prior_idx_filter.shape[0] <= 0:
        # No usable ground truth: mark every prior as background.
        loc_t[idx] = 0
        conf_t[idx] = 0
        return

    # [1,num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    best_prior_idx.squeeze_(1)
    best_prior_idx_filter.squeeze_(1)
    best_prior_overlap.squeeze_(1)
    # Overlap 2 is above any real IoU, so these priors always pass the threshold.
    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
    # TODO refactor: index best_prior_idx with long tensor
    # ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.size(0)):  # decide which box this anchor predicts
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]  # Shape: [num_priors,4] bbox matched to each anchor
    conf = labels[best_truth_idx]  # Shape: [num_priors] label matched to each anchor
    conf[best_truth_overlap < threshold] = 0  # label as background: overlap below threshold is a negative sample
    loc = priorbox.encode(matches)

    matches_landm = landms[best_truth_idx]
    landm = priorbox.encode_landm(matches_landm)
    loc_t[idx] = loc  # [num_priors,4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior
    landm_t[idx] = landm
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def log_sum_exp(x):
    """Compute log(sum(exp(x))) along dim 1 in a numerically stable way.

    The global maximum of ``x`` is subtracted before exponentiating and added
    back afterwards. Used to determine the unaveraged confidence loss across
    all examples in a batch.

    Args:
        x (Variable(tensor)): conf_preds from conf layers
    Return:
        (tensor) result with dim 1 kept as size 1.
    """
    peak = x.data.max()
    shifted_sum = torch.sum(torch.exp(x - peak), 1, keepdim=True)
    return torch.log(shifted_sum) + peak
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# Original author: Francisco Massa:
|
| 171 |
+
# https://github.com/fmassa/object-detection.torch
|
| 172 |
+
# Ported to PyTorch by Max deGroot (02/01/2017)
|
| 173 |
+
def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.

    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
        scores: (tensor) The class predscores for the img, Shape:[num_priors].
        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
        top_k: (int) The Maximum number of box preds to consider.
    Return:
        (keep, count): ``keep`` is a long tensor of length num_priors whose
        first ``count`` entries are the indices of the kept boxes (highest
        score first); the remaining entries are 0.
    """
    keep = torch.zeros(scores.size(0), dtype=torch.long)
    if boxes.numel() == 0:
        # Return the same (keep, count) pair as the normal path so callers
        # can always unpack two values.
        return keep, 0

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    _, idx = scores.sort(0)  # sort in ascending order
    idx = idx[-top_k:]  # indices of the top-k largest vals

    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove kept element from view
        # Corners of the intersection between box i and each remaining box.
        xx1 = torch.clamp(torch.index_select(x1, 0, idx), min=x1[i])
        yy1 = torch.clamp(torch.index_select(y1, 0, idx), min=y1[i])
        xx2 = torch.clamp(torch.index_select(x2, 0, idx), max=x2[i])
        yy2 = torch.clamp(torch.index_select(y2, 0, idx), max=y2[i])
        # Disjoint boxes yield negative extents; clamp to zero area.
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h
        # IoU = i / (area(a) + area(b) - i)
        rem_areas = torch.index_select(area, 0, idx)  # load remaining areas
        union = (rem_areas - inter) + area[i]
        IoU = inter / union
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
|
| 238 |
+
|
| 239 |
+
|
cvlface/research/recognition/code/run_v1/aligners/differentiable_face_aligner/dfa/utils/model_utils.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def remove_prefix(state_dict, prefix):
    '''Return a copy of state_dict with ``prefix`` stripped from keys that start with it.

    Old style models are stored with all parameter names sharing the common
    prefix 'module.'. Keys that do not start with the prefix are kept as is.
    '''
    print('remove prefix \'{}\''.format(prefix))

    def strip(name):
        if not name.startswith(prefix):
            return name
        # Only the leading occurrence is removed.
        return name.partition(prefix)[2]

    return {strip(name): tensor for name, tensor in state_dict.items()}
|
| 8 |
+
|
| 9 |
+
def check_keys(model, pretrained_state_dict):
    """Report how checkpoint keys line up with the model's state dict.

    Prints the number of missing, unused and used keys, and asserts that at
    least one checkpoint key is present in the model.

    Returns:
        True when the assertion passes.
    """
    checkpoint_names = set(pretrained_state_dict.keys())
    model_names = set(model.state_dict().keys())
    shared = model_names.intersection(checkpoint_names)
    only_in_checkpoint = checkpoint_names.difference(model_names)
    only_in_model = model_names.difference(checkpoint_names)
    print('Missing keys:{}'.format(len(only_in_model)))
    print('Unused checkpoint keys:{}'.format(len(only_in_checkpoint)))
    print('Used keys:{}'.format(len(shared)))
    assert len(shared) > 0, 'load NONE from pretrained checkpoint'
    return True
|
| 20 |
+
|
| 21 |
+
def load_model(model, pretrained_path, load_to_cpu):
    """Load pretrained weights from ``pretrained_path`` into ``model``.

    The checkpoint is mapped to CPU when ``load_to_cpu`` is truthy, otherwise
    to the current CUDA device. If the checkpoint has a "state_dict" entry,
    that entry is used. A leading 'module.' is stripped from key names and the
    weights are loaded non-strictly.

    Returns:
        The same ``model`` instance with the weights loaded.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        checkpoint = torch.load(pretrained_path, map_location=lambda storage, loc: storage)
    else:
        device = torch.cuda.current_device()
        checkpoint = torch.load(pretrained_path, map_location=lambda storage, loc: storage.cuda(device))
    weights = checkpoint['state_dict'] if "state_dict" in checkpoint.keys() else checkpoint
    weights = remove_prefix(weights, 'module.')
    check_keys(model, weights)
    model.load_state_dict(weights, strict=False)
    return model
|
| 35 |
+
|
| 36 |
+
|
cvlface/research/recognition/code/run_v1/aligners/none/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..base import BaseAligner
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class NoneAligner(BaseAligner):
    """Pass-through aligner: inputs are returned unchanged."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    @classmethod
    def from_config(cls, aligner_config):
        """Build the aligner from its config object."""
        return cls(aligner_config)

    def make_train_transform(self):
        """Return an identity transform for training inputs."""
        def identity(sample):
            return sample
        return identity

    def make_test_transform(self):
        """Return an identity transform for test inputs."""
        def identity(sample):
            return sample
        return identity

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/__init__.py
ADDED
|
@@ -0,0 +1,246 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..base import BaseAligner
|
| 2 |
+
from torchvision import transforms
|
| 3 |
+
from .retinaface import get_landmark_predictor, get_preprocessor
|
| 4 |
+
from . import aligner_helper
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
class RetinaFaceAligner(BaseAligner):

    """
    A non-differentiable face aligner that aligns the image with one face to a canonical position.
    The aligner is based on the following paper:

    ```
    @inproceedings{deng2020retinaface,
      title={Retinaface: Single-shot multi-level face localisation in the wild},
      author={Deng, Jiankang and Guo, Jia and Ververas, Evangelos and Kotsia, Irene and Zafeiriou, Stefanos},
      booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition},
      pages={5203--5212},
      year={2020}
    }
    ```

    The forward pass runs the detector, keeps the top-scoring detection per
    image, fits a similarity transform from its five landmarks to a reference
    layout, and resamples the image with that transform.
    """

    def __init__(self, net, prior_box, preprocessor, config):
        # net: detection network; prior_box: anchor helper used for decoding;
        # preprocessor: callable applied to the input batch before detection.
        super(RetinaFaceAligner, self).__init__()
        self.net = net
        self.prior_box = prior_box
        self.preprocessor = preprocessor
        self.config = config

    @classmethod
    def from_config(cls, config):
        """Build the aligner from a config.

        Reads ``arch``, ``input_size``, ``input_padding_ratio``,
        ``input_padding_val`` and ``freeze`` from ``config``. The returned
        model is put in eval mode.
        """
        net, prior_box = get_landmark_predictor(network=config.arch,
                                                input_size=config.input_size)

        preprocessor = get_preprocessor(output_size=config.input_size,
                                        padding=config.input_padding_ratio,
                                        padding_val=config.input_padding_val)
        if config.freeze:
            for param in net.parameters():
                param.requires_grad = False
        model = cls(net, prior_box, preprocessor, config)
        model.eval()
        return model

    def forward(self, x, padding_ratio_override=None):
        """Detect the face in each image and return the aligned crop.

        Args:
            x: 4-D tensor with 3 channels in dim 1 (checked by the asserts).
                NOTE(review): presumably RGB normalized to [-1, 1], matching
                the transforms below - confirm against callers.
            padding_ratio_override: optional padding ratio forwarded to the
                preprocessor and used in place of ``self.preprocessor.padding``.

        Returns:
            Tuple ``(aligned_x, ldmks, aligned_ldmks, score, thetas, normalized_bbox)``.
            ``ldmks`` are the predicted landmarks (un-padded when padding is
            used) and ``normalized_bbox`` the box divided by the network input
            size; both are None when the input is not square.
        """

        # input size check
        assert x.shape[1] == 3
        assert x.ndim == 4
        assert isinstance(x, torch.Tensor)
        is_square = x.shape[2] == x.shape[3]

        x = self.preprocessor(x, padding_ratio_override=padding_ratio_override)
        assert self.prior_box.image_size == x.shape[2:]

        # make image into BGR
        x_bgr = x.flip(1)
        input_img = normalize_for_net(unnormalize(x_bgr))

        result = self.net(input_img, self.prior_box)
        batch_loc, batch_conf, batch_landms = result
        # Split along the batch dim so each image is post-processed on its own.
        batch_loc = torch.split(batch_loc, 1, dim=0)
        batch_conf = torch.split(batch_conf, 1, dim=0)
        batch_landms = torch.split(batch_landms, 1, dim=0)

        nms_ldmks = []
        nms_scores = []
        nms_bbox = []
        for loc, conf, landms, in zip(batch_loc, batch_conf, batch_landms):
            dets = postprocess(self.prior_box, loc, conf, landms, confidence_threshold=0.0, nms_threshold=0.4)
            # Only the highest-scoring detection is kept per image.
            bbox, score, ldmks = parse_one_det_result(dets)
            # Normalize landmark pixel coordinates by the image size.
            # NOTE(review): postprocess unpacks image_size as (height, width) while
            # landmarks are scaled as (x, y); harmless for square inputs - confirm.
            ldmks = ldmks / np.array( [self.prior_box.image_size[0], self.prior_box.image_size[1]] * 5)
            nms_ldmks.append(ldmks)
            nms_scores.append(score)
            nms_bbox.append(bbox)

        orig_pred_ldmks = torch.from_numpy(np.array(nms_ldmks)).to(self.device).float()
        score = torch.from_numpy(np.array(nms_scores)).to(self.device).float().unsqueeze(-1)
        bbox = torch.from_numpy(np.array(nms_bbox)).to(self.device).float()


        reference_ldmk = aligner_helper.reference_landmark()
        input_size = self.config.input_size
        output_size = self.config.output_size
        cv2_tfms = aligner_helper.get_cv2_affine_from_landmark(orig_pred_ldmks, reference_ldmk, input_size, input_size)
        thetas = aligner_helper.cv2_param_to_torch_theta(cv2_tfms, input_size, input_size, output_size, output_size)
        thetas = thetas.to(orig_pred_ldmks.device)

        # From here on output_size is the full (N, C, H, W) size for affine_grid.
        output_size = torch.Size((len(thetas), 3, output_size, output_size))
        grid = F.affine_grid(thetas, output_size, align_corners=True)
        # Shift by +1 before sampling and -1 after, so pixels sampled outside
        # the image (zero padding) end up at -1 instead of 0.
        aligned_x = F.grid_sample(x + 1, grid, align_corners=True) - 1  # +1, -1 for making padding pixel 0
        aligned_ldmks = aligner_helper.adjust_ldmks(orig_pred_ldmks.view(-1, 5, 2), thetas)

        orig_pred_ldmks = orig_pred_ldmks.view(-1, 5, 2)
        # bbox (xmin, ymin, xmax, ymax)
        normalized_bbox = bbox / torch.tensor([[input_img.size(3), input_img.size(2)] * 2]).to(bbox.device)


        if padding_ratio_override is None:
            padding_ratio = self.preprocessor.padding
        else:
            padding_ratio = padding_ratio_override
        if padding_ratio > 0:
            # unpad the landmark so that it is in the original image coordinate
            scale = 1 / (1 + (2 * padding_ratio))
            pad_inv_theta = torch.from_numpy(np.array([[1 / scale, 0, 0], [0, 1 / scale, 0]]))
            pad_inv_theta = pad_inv_theta.unsqueeze(0).float().to(self.device).repeat(orig_pred_ldmks.size(0), 1, 1)
            # Homogeneous coordinates: append a column of ones to each landmark.
            unpad_ldmk_pred = torch.concat([orig_pred_ldmks.view(-1, 5, 2),
                                            torch.ones((orig_pred_ldmks.size(0), 5, 1)).to(self.device)], dim=-1)
            # Map [0, 1] -> [-1, 1], apply the scaling, then map back to [0, 1].
            unpad_ldmk_pred = (((unpad_ldmk_pred) * 2 - 1) @ pad_inv_theta.mT) / 2 + 0.5
            unpad_ldmk_pred = unpad_ldmk_pred.view(orig_pred_ldmks.size(0), -1).detach()
            unpad_ldmk_pred = unpad_ldmk_pred.view(-1, 5, 2)
            if not is_square:
                unpad_ldmk_pred = None  # cannot use this if the input is not square because preprocessor changes input
                normalized_bbox = None  # cannot use this if the input is not square because preprocessor changes input
            return aligned_x, unpad_ldmk_pred, aligned_ldmks, score, thetas, normalized_bbox

        if not is_square:
            orig_pred_ldmks = None  # cannot use this if the input is not square because preprocessor changes input
            normalized_bbox = None  # cannot use this if the input is not square because preprocessor changes input
        return aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, normalized_bbox

    def make_train_transform(self):
        """Return the training transform: ToTensor, then Normalize with mean 0.5 and std 0.5."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        return transform

    def make_test_transform(self):
        """Return the test transform (same pipeline as the training transform)."""
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        return transform
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def normalize(image):
    """Map pixel values from the [0, 255] range to [-1, 1]."""
    unit_range = image / 255.
    return (unit_range - 0.5) / 0.5
|
| 145 |
+
|
| 146 |
+
def unnormalize(image):
    """Map values from the [-1, 1] range back to [0, 255]."""
    unit_range = image * 0.5 + 0.5
    return unit_range * 255.
|
| 150 |
+
|
| 151 |
+
def normalize_for_net(bgr_image_0_255):
    """Subtract the per-channel constants (104, 117, 123) from a batch.

    The constants are broadcast over dim 1 of a 4-D input, which the parameter
    name suggests is a BGR image in the 0-255 range.
    """
    channel_offsets = torch.tensor([104, 117, 123])[None, :, None, None]
    channel_offsets = channel_offsets.to(bgr_image_0_255.device)
    return bgr_image_0_255 - channel_offsets
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def postprocess(priorbox, loc, conf, landms, confidence_threshold, nms_threshold):
    """Decode one image's raw network outputs into NMS-filtered detections.

    Boxes and landmarks are decoded with ``priorbox``, scaled to pixels using
    ``priorbox.image_size``, filtered by score, sorted by descending score and
    passed through NMS.

    Returns:
        numpy array with one row per kept detection:
        4 box values, 1 score, then the landmark values.
    """
    device = loc.device
    im_height, im_width = priorbox.image_size

    box_scale = torch.Tensor([im_width, im_height, im_width, im_height]).to(device)
    boxes = (priorbox.decode(loc.data.squeeze(0)) * box_scale).cpu().numpy()
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]

    landm_scale = torch.Tensor([im_width, im_height] * 5).to(device)
    landms = (priorbox.decode_landm(landms.data.squeeze(0)) * landm_scale).cpu().numpy()

    # ignore low scores; if nothing passes, fall back to all non-negative scores
    selected = np.where(scores > confidence_threshold)[0]
    if len(selected) == 0:
        selected = np.where(scores >= 0)[0]
    boxes, landms, scores = boxes[selected], landms[selected], scores[selected]

    # sort by descending score before NMS
    ranking = scores.argsort()[::-1]
    boxes, landms, scores = boxes[ranking], landms[ranking], scores[ranking]

    # do NMS
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    keep = py_cpu_nms(dets, nms_threshold)

    return np.concatenate((dets[keep, :], landms[keep]), axis=1)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def py_cpu_nms(dets,
               thresh):
    """
    Pure Python NMS baseline.

    ``dets`` rows are (x1, y1, x2, y2, score). Box extents are computed with
    the inclusive-pixel (+1) convention. Returns the list of kept row indices,
    highest score first.
    """
    left, top, right, bottom = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3]
    areas = (right - left + 1) * (bottom - top + 1)
    remaining = dets[:, 4].argsort()[::-1]

    keep = []
    while remaining.size > 0:
        best = remaining[0]
        others = remaining[1:]
        keep.append(best)

        # Intersection of the best box with every other remaining box.
        inter_left = np.maximum(left[best], left[others])
        inter_top = np.maximum(top[best], top[others])
        inter_right = np.minimum(right[best], right[others])
        inter_bottom = np.minimum(bottom[best], bottom[others])
        inter_w = np.maximum(0.0, inter_right - inter_left + 1)
        inter_h = np.maximum(0.0, inter_bottom - inter_top + 1)
        inter = inter_w * inter_h
        iou = inter / (areas[best] + areas[others] - inter)

        # Positions in ``others`` are offset by one relative to ``remaining``.
        survivors = np.where(iou <= thresh)[0]
        remaining = remaining[survivors + 1]

    return keep
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def parse_one_det_result(dets):
    """Split the highest-scoring detection row into (bbox, score, ldmks).

    Rows of ``dets`` hold 4 box values, the score at index 4, then landmarks.
    """
    by_score_desc = dets[:, 4].argsort()[::-1]
    top = dets[by_score_desc][0]
    return top[:4], top[4], top[5:]
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/aligner_helper.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
from skimage import transform as trans
|
| 5 |
+
import cv2
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def split_network_output(align_out):
    """Split the merged prediction of a 5-element network output.

    The fourth element is split along dim 1 into chunks of size 4, 2 and 10
    (bbox, cls, ldmk). The other four elements are ignored.

    Returns:
        (ldmk, bbox, cls)
    """
    _, _, _, merged, _ = align_out
    bbox, cls, ldmk = merged.split([4, 2, 10], dim=1)
    return ldmk, bbox, cls
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_cv2_affine_from_landmark(ldmks, reference_ldmk, image_width, image_height):
    """Estimate a 2x3 similarity transform per face from its landmarks.

    Args:
        ldmks: torch tensor of shape (N, 10); each row is reshaped to five
            (x, y) points and multiplied by (image_width, image_height).
        reference_ldmk: numpy array of shape (5, 2), the target layout.
        image_width, image_height: factors used to scale ``ldmks``.

    Returns:
        numpy array of shape (N, 2, 3) with the transform of each face.
    """
    assert ldmks.ndim == 2  # batchdim
    assert ldmks.shape[1] == 10
    assert isinstance(ldmks, torch.Tensor)

    assert reference_ldmk.ndim == 2
    assert reference_ldmk.shape[0] == 5
    assert reference_ldmk.shape[1] == 2
    assert isinstance(reference_ldmk, np.ndarray)

    pixel_scale = np.array([[[image_width, image_height]]])
    points = ldmks.view(ldmks.shape[0], 5, 2).detach().cpu().numpy()
    points = points * pixel_scale

    matrices = []
    for face_points in points:
        similarity = trans.SimilarityTransform()
        similarity.estimate(face_points, reference_ldmk)
        # Keep the top two rows of the 3x3 homogeneous matrix.
        matrices.append(similarity.params[0:2, :])
    return np.stack(matrices, axis=0)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def cv2_param_to_torch_theta(cv2_tfms, image_width, image_height, output_width, output_height):
    """Convert 2x3 affine matrices into thetas for ``F.affine_grid``.

    Three reference points are pushed through each matrix, both point sets are
    rescaled to [-1, 1] by the input and output sizes respectively, and an
    affine transform from the transformed points back to the reference points
    is estimated.
    See https://github.com/wuneng/WarpAffine2GridSample

    Args:
        cv2_tfms: numpy array of shape (N, 2, 3).

    Returns:
        float torch tensor of shape (N, 2, 3).
    """
    assert cv2_tfms.ndim == 3  # N, 2, 3
    assert cv2_tfms.shape[1] == 2
    assert cv2_tfms.shape[2] == 3

    batch = cv2_tfms.shape[0]
    reference_points = np.array([[0, 0], [0, 1], [1, 1]], dtype=np.float32)
    srcs = np.expand_dims(reference_points, axis=0).repeat(batch, axis=0)
    linear_t = cv2_tfms[:, :, :2].transpose(0, 2, 1)
    offset_t = cv2_tfms[:, :, 2:3].transpose(0, 2, 1)
    dsts = np.matmul(srcs, linear_t) + offset_t

    # normalize to [-1, 1]
    srcs = srcs / np.array([[[image_width, image_height]]]) * 2 - 1
    dsts = dsts / np.array([[[output_width, output_height]]]) * 2 - 1

    thetas = [
        trans.estimate_transform("affine", src=dst, dst=src).params[:2]
        for src, dst in zip(srcs, dsts)
    ]
    return torch.from_numpy(np.stack(thetas, axis=0)).float()
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def adjust_ldmks(ldmks, thetas):
    """Map landmarks through the inverse of ``thetas``.

    Landmarks of shape (N, 5, 2) are rescaled from [0, 1] to [-1, 1] in
    homogeneous form, multiplied by the inverted thetas, and rescaled back to
    [0, 1].
    """
    inverse = inv_matrix(thetas).to(ldmks.device).float()
    ones = torch.ones((ldmks.shape[0], 5, 1)).to(ldmks.device)
    homogeneous = torch.cat([ldmks, ones], dim=2)
    centered = homogeneous * 2 - 1
    return (centered @ inverse.permute(0, 2, 1)) / 2 + 0.5
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def inv_matrix(theta):
    """Invert a batch of 2x3 affine matrices (torch batched version).

    Each matrix [[a, b, t1], [c, d, t2]] is inverted in closed form.

    Args:
        theta: tensor of shape (N, 2, 3).

    Returns:
        tensor of shape (N, 2, 3) holding the inverse transforms.
    """
    assert theta.ndim == 3
    a, b, t1 = theta[:, 0, 0], theta[:, 0, 1], theta[:, 0, 2]
    c, d, t2 = theta[:, 1, 0], theta[:, 1, 1], theta[:, 1, 2]
    inv_det = 1.0 / (a * d - b * c)

    top_row = torch.stack(
        [d * inv_det, -b * inv_det, (b * t2 - d * t1) * inv_det], dim=1)
    bottom_row = torch.stack(
        [-c * inv_det, a * inv_det, (c * t1 - a * t2) * inv_det], dim=1)
    return torch.stack([top_row, bottom_row], dim=1)
|
| 80 |
+
|
| 81 |
+
def reference_landmark():
    """Return the five reference (x, y) landmark positions as a (5, 2) array."""
    reference_points = [
        [38.29459953, 51.69630051],
        [73.53179932, 51.50139999],
        [56.02519989, 71.73660278],
        [41.54930115, 92.3655014],
        [70.72990036, 92.20410156],
    ]
    return np.array(reference_points)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def draw_ldmk(img, ldmk):
    """Draw five landmarks on a copy of ``img`` and return the copy.

    ``ldmk`` is a flat sequence (x0, y0, x1, y1, ...) that is multiplied by
    the image width and height to get pixel positions. When ``ldmk`` is None
    the input image is returned as is.
    """
    if ldmk is None:
        return img
    palette = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)]
    img = img.copy()
    for point_idx, color in enumerate(palette[:5]):
        x_px = int(ldmk[point_idx * 2] * img.shape[1])
        y_px = int(ldmk[point_idx * 2 + 1] * img.shape[0])
        cv2.circle(img, (x_px, y_px), 1, color, 4)
    return img
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .models.retinaface import RetinaFace
|
| 2 |
+
from .utils.model_utils import load_model
|
| 3 |
+
from .config import cfg_mnet, cfg_re50
|
| 4 |
+
from .layers.functions.prior_box import PriorBox
|
| 5 |
+
from .preprocessor import Preprocessor
|
| 6 |
+
|
| 7 |
+
def get_landmark_predictor(network='mobile0.25', input_size=160):
    """Build the RetinaFace network and its matching prior boxes.

    Args:
        network: backbone name, either "mobile0.25" or "resnet50".
        input_size: side length of the square network input.

    Returns:
        (net, priorbox): the RetinaFace model built in 'test' phase and the
        PriorBox built for an ``input_size`` x ``input_size`` image.

    Raises:
        ValueError: if ``network`` is not one of the supported names.
    """
    if network == "mobile0.25":
        cfg = cfg_mnet
    elif network == "resnet50":
        cfg = cfg_re50
    else:
        # Fail early with a clear message instead of passing cfg=None on.
        raise ValueError(
            "Unsupported network '{}'; expected 'mobile0.25' or 'resnet50'".format(network))
    net = RetinaFace(cfg=cfg, phase='test')
    priorbox = PriorBox(image_size=(input_size, input_size),
                        min_sizes=[[16, 32], [64, 128], [256, 512]],
                        steps=[8, 16, 32],
                        clip=False,
                        variances=[0.1, 0.2],)

    return net, priorbox
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_preprocessor(output_size=160, padding=0.0, padding_val='zero'):
    """Build a Preprocessor with the given output size and padding settings."""
    preprocessor = Preprocessor(
        output_size=output_size,
        padding=padding,
        padding_val=padding_val,
    )
    return preprocessor
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/config.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# config.py
|
| 2 |
+
|
| 3 |
+
# Settings for the MobileNet-0.25 backbone.
cfg_mnet = dict(
    name='mobilenet0.25',
    pretrain=True,
    return_layers=dict(stage1=1, stage2=2, stage3=3),
    in_channel=32,
    out_channel=64,
)

# Settings for the ResNet-50 backbone.
cfg_re50 = dict(
    name='Resnet50',
    pretrain=True,
    return_layers=dict(layer2=1, layer3=2, layer4=3),
    in_channel=256,
    out_channel=256,
)
|
| 18 |
+
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .functions import *
|
| 2 |
+
from .modules import *
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/functions/prior_box.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from itertools import product as product
|
| 3 |
+
from math import ceil
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class PriorBox(object):
    """Prior (anchor) box generator with matching encode/decode helpers.

    The priors are computed once at construction and kept in ``self.priors``
    as a ``(num_priors, 4)`` tensor whose rows are ``(cx, cy, w, h)`` divided
    by the image size. Every encode/decode method first moves
    ``self.priors`` to the device of its argument.
    """

    def __init__(self,
                 image_size,
                 min_sizes=[[64, 80], [96, 112], [128, 144]],
                 steps=[8,16,32],
                 clip=False,
                 variances=[0.1, 0.2],
                 ):
        """Store the anchor configuration and precompute the priors.

        Args:
            image_size: ``(height, width)`` of the input image.
            min_sizes: Per-level list of anchor side lengths in pixels.
            steps: Per-level stride in pixels.
            clip: Whether to clamp the priors to ``[0, 1]``.
            variances: Two scale factors; the first is applied to centre
                offsets, the second to log-size offsets.
        """
        super(PriorBox, self).__init__()
        self.min_sizes = min_sizes
        self.steps = steps
        self.clip = clip
        self.variances = variances
        self.image_size = image_size
        # One [rows, cols] grid per level.
        self.feature_maps = [
            [ceil(self.image_size[0] / stride), ceil(self.image_size[1] / stride)]
            for stride in self.steps
        ]
        with torch.no_grad():
            self.priors = self.forward()

    def forward(self):
        """Return all priors as a ``(num_priors, 4)`` tensor of (cx, cy, w, h)."""
        img_h = self.image_size[0]
        img_w = self.image_size[1]
        flat = []
        for level, grid in enumerate(self.feature_maps):
            level_sizes = self.min_sizes[level]
            stride = self.steps[level]
            # Row-major walk over the grid; anchors sit at cell centres.
            for row in range(grid[0]):
                for col in range(grid[1]):
                    for size in level_sizes:
                        width = size / img_w
                        height = size / img_h
                        cx = (col + 0.5) * stride / img_w
                        cy = (row + 0.5) * stride / img_h
                        flat.extend((cx, cy, width, height))

        priors = torch.Tensor(flat).view(-1, 4)
        if self.clip:
            priors.clamp_(max=1, min=0)
        return priors

    def encode(self, matched):
        """Encode corner-form boxes as regression targets relative to the priors.

        Args:
            matched: ``(num_priors, 4)`` boxes as ``(x1, y1, x2, y2)``, one
                per prior.

        Returns:
            ``(num_priors, 4)`` tensor of scaled centre offsets followed by
            scaled log size ratios.
        """
        self.priors = self.priors.to(matched.device)
        prior_centers = self.priors[:, :2]
        prior_sizes = self.priors[:, 2:]

        # Offset of the box centre from the prior centre, scaled by variance.
        center_delta = (matched[:, :2] + matched[:, 2:]) / 2 - prior_centers
        center_delta = center_delta / (self.variances[0] * prior_sizes)

        # Log ratio between box size and prior size, scaled by variance.
        size_ratio = (matched[:, 2:] - matched[:, :2]) / prior_sizes
        size_delta = torch.log(size_ratio) / self.variances[1]

        return torch.cat([center_delta, size_delta], 1)

    def encode_landm(self, matched):
        """Encode five landmark points as offsets from the prior centres.

        Args:
            matched: ``(num_priors, 10)`` landmark coordinates, laid out as
                five consecutive ``(x, y)`` pairs.

        Returns:
            ``(num_priors, 10)`` tensor of offsets divided by
            ``variances[0] * prior size``.
        """
        self.priors = self.priors.to(matched.device)

        points = torch.reshape(matched, (matched.size(0), 5, 2))
        # Repeat each prior for its five landmark points.
        tiled = self.priors.unsqueeze(1).expand(points.size(0), 5, 4)
        deltas = points[:, :, :2] - tiled[:, :, :2]
        deltas = deltas / (self.variances[0] * tiled[:, :, 2:])
        return deltas.reshape(deltas.size(0), -1)

    # Adapted from https://github.com/Hakuyume/chainer-ssd
    def decode(self, loc):
        """Invert :meth:`encode` for a ``(num_priors, 4)`` prediction.

        Returns:
            ``(num_priors, 4)`` boxes as ``(x1, y1, x2, y2)``.
        """
        self.priors = self.priors.to(loc.device)
        prior_centers = self.priors[:, :2]
        prior_sizes = self.priors[:, 2:]

        centers = prior_centers + loc[:, :2] * self.variances[0] * prior_sizes
        sizes = prior_sizes * torch.exp(loc[:, 2:] * self.variances[1])
        # Centre/size form to corner form.
        top_left = centers - sizes / 2
        bottom_right = sizes + top_left
        return torch.cat((top_left, bottom_right), 1)

    def decode_landm(self, pre):
        """Invert :meth:`encode_landm` for a ``(num_priors, 10)`` prediction.

        Returns:
            ``(num_priors, 10)`` landmark coordinates.
        """
        self.priors = self.priors.to(pre.device)
        prior_centers = self.priors[:, :2]
        prior_sizes = self.priors[:, 2:]
        points = [
            prior_centers + pre[:, start:start + 2] * self.variances[0] * prior_sizes
            for start in range(0, 10, 2)
        ]
        return torch.cat(points, dim=1)

    def decode_batch(self, loc):
        """Batched :meth:`decode` for a ``(batch, num_priors, 4)`` prediction.

        Returns:
            ``(batch, num_priors, 4)`` boxes as ``(x1, y1, x2, y2)``.
        """
        self.priors = self.priors.to(loc.device)
        assert loc.ndim == 3
        tiled = self.priors.unsqueeze(0).expand(loc.size(0), -1, -1)
        prior_centers = tiled[:, :, :2]
        prior_sizes = tiled[:, :, 2:]

        centers = prior_centers + loc[:, :, :2] * self.variances[0] * prior_sizes
        sizes = prior_sizes * torch.exp(loc[:, :, 2:] * self.variances[1])
        top_left = centers - sizes / 2
        bottom_right = sizes + top_left
        return torch.cat((top_left, bottom_right), -1)

    def decode_landm_batch(self, prediction):
        """Batched :meth:`decode_landm` for a ``(batch, num_priors, 10)`` prediction.

        Returns:
            ``(batch, num_priors, 10)`` landmark coordinates.
        """
        assert prediction.ndim == 3
        self.priors = self.priors.to(prediction.device)
        tiled = self.priors.unsqueeze(0).expand(prediction.size(0), -1, -1)
        prior_centers = tiled[:, :, :2]
        prior_sizes = tiled[:, :, 2:]
        points = [
            prior_centers + prediction[:, :, start:start + 2] * self.variances[0] * prior_sizes
            for start in range(0, 10, 2)
        ]
        return torch.cat(points, dim=-1)
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/modules/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .multibox_loss import MultiBoxLoss
|
| 2 |
+
|
| 3 |
+
__all__ = ['MultiBoxLoss']
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/layers/modules/multibox_loss.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from ...utils.box_utils import match, log_sum_exp
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of ground
           truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative examples
           that comes with using a large number of default bounding boxes.
           (default negative:positive ratio 3:1)
    Objective Loss:
        L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.
        Args:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes
            N: number of matched default boxes
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
        """Store the loss hyper-parameters.

        Of these, ``forward`` reads only ``num_classes``, ``overlap_thresh``
        (as ``self.threshold``) and ``neg_pos`` (as ``self.negpos_ratio``);
        the remaining values are stored but not used by this class.
        """
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap

    def forward(self, predictions, priorbox, targets):
        """Compute box, classification and landmark losses.

        Args:
            predictions (tuple): ``(loc_data, conf_data, landm_data, aggs, thetas)``.
                loc_data: ``(batch, num_priors, 4)`` box regressions.
                conf_data: ``(batch, num_priors, num_classes)`` class scores.
                landm_data: ``(batch, num_priors, 10)`` landmark regressions.
                aggs: ``None`` or a per-image aggregate prediction whose
                    columns are read as ``[:4]`` box, ``[4:6]`` class scores
                    and ``[6:]`` landmarks.
                thetas: unpacked but not used by this loss.
            priorbox: Object exposing ``priors`` (``(num_priors, 4)``); it is
                also handed to ``match``.
            targets (sequence of tensors): One tensor per image; columns
                ``[:4]`` are boxes, ``[4:14]`` are landmarks and the last
                column is the label. A label ``> 0`` contributes to the
                landmark loss; any non-zero label counts as a positive for
                the box and classification losses.

        Returns:
            Tuple ``(loss_l, loss_c, loss_landm, aux_loss_dict)`` where
            ``aux_loss_dict`` is ``None`` when ``aggs`` is ``None``.
        """
        loc_data, conf_data, landm_data, aggs, thetas = predictions
        # Keep every target on the predictions' device instead of assuming
        # CUDA, so the loss also works on CPU and on non-default GPUs.
        device = loc_data.device
        num = loc_data.size(0)
        num_priors = priorbox.priors.size(0)

        if aggs is not None:
            stacked_target = torch.stack(targets, dim=0).squeeze(1)

            pos_idx = stacked_target[:, -1] > 0
            agg_ldmk = aggs[:, 6:][pos_idx]
            tgt_ldmk = stacked_target[:, 4:14][pos_idx]
            # max(..., 1) avoids a 0/0 NaN when the batch has no positives.
            agg_loss_landm = F.smooth_l1_loss(agg_ldmk, tgt_ldmk, reduction='sum') / max(len(tgt_ldmk), 1)

            pos_idx = stacked_target[:, -1] != 0
            agg_bbox = aggs[:, :4][pos_idx]
            tgt_bbox = stacked_target[:, :4][pos_idx]
            agg_loss_box = F.smooth_l1_loss(agg_bbox, tgt_bbox, reduction='sum') / max(len(tgt_bbox), 1)

            agg_cls = aggs[:, 4:6]
            tgt_cls = (stacked_target[:, -1] > 0).long()
            agg_loss_cls = F.cross_entropy(agg_cls, tgt_cls, reduction='sum') / len(tgt_cls)
            aux_loss_dict = {
                'agg_loss_landm': agg_loss_landm,
                'agg_loss_box': agg_loss_box,
                'agg_loss_cls': agg_loss_cls
            }
        else:
            aux_loss_dict = None

        # match priors (default boxes) and ground truth boxes; `match` fills
        # the three buffers below in place, one batch index at a time.
        loc_t = torch.empty(num, num_priors, 4, dtype=torch.float32)
        landm_t = torch.empty(num, num_priors, 10, dtype=torch.float32)
        conf_t = torch.empty(num, num_priors, dtype=torch.long)
        for idx in range(num):
            truths = targets[idx][:, :4].data
            labels = targets[idx][:, -1].data
            landms = targets[idx][:, 4:14].data
            match(self.threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx)

        loc_t = loc_t.to(device)
        conf_t = conf_t.to(device)
        landm_t = landm_t.to(device)

        # landm Loss (Smooth L1)
        # Shape: [batch,num_priors,10]
        pos1 = conf_t > 0
        num_pos_landm = pos1.long().sum(1, keepdim=True)
        N1 = max(num_pos_landm.data.sum().float(), 1)
        pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
        landm_p = landm_data[pos_idx1].view(-1, 10)
        landm_t = landm_t[pos_idx1].view(-1, 10)
        loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

        # From here on every non-zero label is treated as class 1.
        pos = conf_t != 0
        conf_t[pos] = 1

        # Localization Loss (Smooth L1)
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard Negative Mining
        loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        # Sorting the sort indices yields each prior's rank by loss.
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[pos_idx | neg_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[pos | neg]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        loss_landm /= N1

        return loss_l, loss_c, loss_landm, aux_loss_dict
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/__init__.py
ADDED
|
File without changes
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/net.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torchvision.models._utils as _utils
|
| 5 |
+
import torchvision.models as models
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torch.autograd import Variable
|
| 8 |
+
|
| 9 |
+
def conv_bn(inp, oup, stride = 1, leaky = 0):
    """Return a 3x3 conv (padding 1, no bias) -> BatchNorm -> LeakyReLU stack.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Convolution stride.
        leaky: Negative slope of the in-place LeakyReLU.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)
|
| 15 |
+
|
| 16 |
+
def conv_bn_no_relu(inp, oup, stride):
    """Return a 3x3 conv (padding 1, no bias) -> BatchNorm stack without activation."""
    conv = nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False)
    norm = nn.BatchNorm2d(oup)
    return nn.Sequential(conv, norm)
|
| 21 |
+
|
| 22 |
+
def conv_bn1X1(inp, oup, stride, leaky=0):
    """Return a 1x1 conv (no padding, no bias) -> BatchNorm -> LeakyReLU stack.

    Args:
        inp: Number of input channels.
        oup: Number of output channels.
        stride: Convolution stride.
        leaky: Negative slope of the in-place LeakyReLU.
    """
    layers = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=stride, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)
|
| 28 |
+
|
| 29 |
+
def conv_dw(inp, oup, stride, leaky=0.1):
    """Return a depthwise-separable convolution block.

    A 3x3 depthwise conv (``groups=inp``) with BatchNorm and LeakyReLU is
    followed by a 1x1 pointwise conv with BatchNorm and LeakyReLU. Only the
    depthwise conv uses ``stride``.
    """
    depthwise = [
        nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)
|
| 39 |
+
|
| 40 |
+
class SSH(nn.Module):
    """Three-branch context module with a channel-wise concatenated output.

    The branches produce ``out_channel // 2``, ``out_channel // 4`` and
    ``out_channel // 4`` channels. The second and third branches share their
    first convolution.
    """

    def __init__(self, in_channel, out_channel):
        super(SSH, self).__init__()
        assert out_channel % 4 == 0
        # Narrow modules use a leaky slope of 0.1; wider ones use 0.
        leaky = 0.1 if out_channel <= 64 else 0
        half = out_channel // 2
        quarter = out_channel // 4

        self.conv3X3 = conv_bn_no_relu(in_channel, half, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, quarter, stride=1, leaky = leaky)
        self.conv5X5_2 = conv_bn_no_relu(quarter, quarter, stride=1)

        self.conv7X7_2 = conv_bn(quarter, quarter, stride=1, leaky = leaky)
        self.conv7x7_3 = conv_bn_no_relu(quarter, quarter, stride=1)

    def forward(self, input):
        """Run the three branches, concatenate along channels and apply ReLU."""
        branch3 = self.conv3X3(input)

        shared = self.conv5X5_1(input)
        branch5 = self.conv5X5_2(shared)

        # The third branch stacks two more convs on the shared output.
        branch7 = self.conv7x7_3(self.conv7X7_2(shared))

        return F.relu(torch.cat([branch3, branch5, branch7], dim=1))
|
| 67 |
+
|
| 68 |
+
class FPN(nn.Module):
    """Top-down feature pyramid over three backbone feature maps."""

    def __init__(self,in_channels_list,out_channels):
        super(FPN,self).__init__()
        # Narrow pyramids use a leaky slope of 0.1; wider ones use 0.
        leaky = 0.1 if out_channels <= 64 else 0

        # 1x1 lateral projections bring every level to `out_channels`.
        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride = 1, leaky = leaky)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride = 1, leaky = leaky)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride = 1, leaky = leaky)

        # 3x3 convs applied after adding the upsampled coarser level.
        self.merge1 = conv_bn(out_channels, out_channels, leaky = leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky = leaky)

    def forward(self, input):
        """Fuse the three feature maps found in the mapping ``input``.

        Only the mapping's values are used, in iteration order.

        Returns:
            ``[level1, level2, level3]``; level3 is the plain lateral
            projection, the other two are merged with the level above.
        """
        feats = list(input.values())

        lateral1 = self.output1(feats[0])
        lateral2 = self.output2(feats[1])
        lateral3 = self.output3(feats[2])

        up3 = F.interpolate(lateral3, size=[lateral2.size(2), lateral2.size(3)], mode="nearest")
        merged2 = self.merge2(lateral2 + up3)

        up2 = F.interpolate(merged2, size=[lateral1.size(2), lateral1.size(3)], mode="nearest")
        merged1 = self.merge1(lateral1 + up2)

        return [merged1, merged2, lateral3]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class MobileNetV1(nn.Module):
    """Three-stage MobileNetV1-style backbone built from ``conv_dw`` blocks.

    The stages end with 64, 128 and 256 channels respectively.
    """

    # (in_channels, out_channels, stride) of the conv_dw blocks in each stage.
    _STAGE1_DW = ((8, 16, 1), (16, 32, 2), (32, 32, 1), (32, 64, 2), (64, 64, 1))
    _STAGE2_DW = ((64, 128, 2),) + ((128, 128, 1),) * 5
    _STAGE3_DW = ((128, 256, 2), (256, 256, 1))

    def __init__(self):
        super(MobileNetV1, self).__init__()
        # Stage 1 starts with the only regular (non-depthwise) convolution.
        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky = 0.1),
            *(conv_dw(cin, cout, stride) for cin, cout, stride in self._STAGE1_DW),
        )
        self.stage2 = nn.Sequential(
            *(conv_dw(cin, cout, stride) for cin, cout, stride in self._STAGE2_DW),
        )
        self.stage3 = nn.Sequential(
            *(conv_dw(cin, cout, stride) for cin, cout, stride in self._STAGE3_DW),
        )

    def forward(self, x):
        """Run the three stages in order and return the last stage's output."""
        for stage in (self.stage1, self.stage2, self.stage3):
            x = stage(x)
        return x
|
| 131 |
+
|
| 132 |
+
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/models/retinaface.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torchvision.models._utils as _utils
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from .net import MobileNetV1 as MobileNetV1
|
| 6 |
+
from .net import FPN as FPN
|
| 7 |
+
from .net import SSH as SSH
|
| 8 |
+
|
| 9 |
+
from timm.models import mlp_mixer
|
| 10 |
+
|
| 11 |
+
class ClassHead(nn.Module):
    """1x1 conv head producing two class scores per anchor."""

    def __init__(self,inchannels=512,num_anchors=3):
        super(ClassHead,self).__init__()
        self.num_anchors = num_anchors
        out_channels = self.num_anchors * 2
        self.conv1x1 = nn.Conv2d(inchannels, out_channels, kernel_size=(1,1), stride=1, padding=0)

    def forward(self,x):
        """Map ``(B, C, H, W)`` features to ``(B, H*W*num_anchors, 2)`` scores."""
        batch = x.shape[0]
        scores = self.conv1x1(x)
        # Channels-last so each anchor's values are contiguous before flattening.
        scores = scores.permute(0, 2, 3, 1).contiguous()
        return scores.view(batch, -1, 2)
|
| 22 |
+
|
| 23 |
+
class BboxHead(nn.Module):
    """1x1 conv head producing four box regression values per anchor."""

    def __init__(self,inchannels=512,num_anchors=3):
        super(BboxHead,self).__init__()
        out_channels = num_anchors * 4
        self.conv1x1 = nn.Conv2d(inchannels, out_channels, kernel_size=(1,1), stride=1, padding=0)

    def forward(self,x):
        """Map ``(B, C, H, W)`` features to ``(B, H*W*num_anchors, 4)`` offsets."""
        batch = x.shape[0]
        offsets = self.conv1x1(x)
        # Channels-last so each anchor's values are contiguous before flattening.
        offsets = offsets.permute(0, 2, 3, 1).contiguous()
        return offsets.view(batch, -1, 4)
|
| 33 |
+
|
| 34 |
+
class LandmarkHead(nn.Module):
    """1x1 conv head producing ten landmark regression values per anchor."""

    def __init__(self,inchannels=512,num_anchors=3):
        super(LandmarkHead,self).__init__()
        out_channels = num_anchors * 10
        self.conv1x1 = nn.Conv2d(inchannels, out_channels, kernel_size=(1,1), stride=1, padding=0)

    def forward(self,x):
        """Map ``(B, C, H, W)`` features to ``(B, H*W*num_anchors, 10)`` offsets."""
        batch = x.shape[0]
        offsets = self.conv1x1(x)
        # Channels-last so each anchor's values are contiguous before flattening.
        offsets = offsets.permute(0, 2, 3, 1).contiguous()
        return offsets.view(batch, -1, 10)
|
| 44 |
+
|
| 45 |
+
class RetinaFace(nn.Module):
    """RetinaFace detector: backbone -> FPN -> SSH -> box/class/landmark heads."""

    def __init__(self, cfg = None, phase = 'train'):
        """
        :param cfg: Network related settings. Reads the keys ``name``,
            ``pretrain`` (Resnet50 only), ``return_layers``, ``in_channel``
            and ``out_channel``.
        :param phase: train or test. Any value other than ``'train'`` makes
            ``forward`` return softmax-normalised class scores.
        :raises ValueError: if ``cfg`` is ``None`` or names an unsupported backbone.
        """
        super(RetinaFace,self).__init__()
        if cfg is None:
            raise ValueError('RetinaFace requires a cfg dict (e.g. cfg_mnet or cfg_re50)')
        self.phase = phase

        backbone_name = cfg['name']
        if backbone_name == 'mobilenet0.25':
            # NOTE: cfg['pretrain'] is ignored for this backbone; no
            # pretrained MobileNet weights are loaded here.
            backbone = MobileNetV1()
        elif backbone_name == 'Resnet50':
            import torchvision.models as models
            backbone = models.resnet50(pretrained=cfg['pretrain'])
        else:
            # Fail early instead of passing backbone=None to IntermediateLayerGetter.
            raise ValueError(
                f"Unsupported backbone {backbone_name!r}; expected 'mobilenet0.25' or 'Resnet50'"
            )

        self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers'])
        in_channels_stage2 = cfg['in_channel']
        # The three tapped feature maps have 2x, 4x and 8x `in_channel` channels.
        in_channels_list = [
            in_channels_stage2 * 2,
            in_channels_stage2 * 4,
            in_channels_stage2 * 8,
        ]
        out_channels = cfg['out_channel']
        self.fpn = FPN(in_channels_list,out_channels)
        self.ssh1 = SSH(out_channels, out_channels)
        self.ssh2 = SSH(out_channels, out_channels)
        self.ssh3 = SSH(out_channels, out_channels)

        self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel'])
        self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel'])

    def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2):
        """Return a ModuleList with one ``ClassHead`` per pyramid level."""
        return nn.ModuleList(ClassHead(inchannels, anchor_num) for _ in range(fpn_num))

    def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2):
        """Return a ModuleList with one ``BboxHead`` per pyramid level."""
        return nn.ModuleList(BboxHead(inchannels, anchor_num) for _ in range(fpn_num))

    def _make_landmark_head(self,fpn_num=3,inchannels=64,anchor_num=2):
        """Return a ModuleList with one ``LandmarkHead`` per pyramid level."""
        return nn.ModuleList(LandmarkHead(inchannels, anchor_num) for _ in range(fpn_num))

    def forward(self, inputs, priorbox=None):
        """Run detection on a batch of images.

        :param inputs: image batch fed to the backbone.
        :param priorbox: accepted for interface compatibility; not used here.
        :return: ``(bbox_regressions, classifications, ldm_regressions)`` with
            last dimensions 4, 2 and 10. Outside the ``'train'`` phase the
            classifications are passed through a softmax.
        """
        out = self.body(inputs)

        # FPN
        fpn = self.fpn(out)

        # SSH
        features = [self.ssh1(fpn[0]), self.ssh2(fpn[1]), self.ssh3(fpn[2])]

        # Each level has its own heads; per-level outputs are joined along the anchor axis.
        bbox_regressions = torch.cat(
            [head(feature) for head, feature in zip(self.BboxHead, features)], dim=1)
        classifications = torch.cat(
            [head(feature) for head, feature in zip(self.ClassHead, features)], dim=1)
        ldm_regressions = torch.cat(
            [head(feature) for head, feature in zip(self.LandmarkHead, features)], dim=1)

        if self.phase != 'train':
            classifications = F.softmax(classifications, dim=-1)
        return (bbox_regressions, classifications, ldm_regressions)
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/preprocessor.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
|
| 4 |
+
class Preprocessor():
|
| 5 |
+
|
| 6 |
+
def __init__(self, output_size=160, padding=0.0, padding_val='zero'):
|
| 7 |
+
self.output_size = output_size
|
| 8 |
+
self.padding = padding
|
| 9 |
+
self.padding_val = padding_val
|
| 10 |
+
|
| 11 |
+
def preprocess_batched(self, imgs, padding_ratio_override=None):
|
| 12 |
+
|
| 13 |
+
# check img is of float
|
| 14 |
+
if imgs.dtype == torch.float32:
|
| 15 |
+
if self.padding_val == 'zero':
|
| 16 |
+
padding_val = -1.0
|
| 17 |
+
elif self.padding_val == 'mean':
|
| 18 |
+
padding_val = imgs.mean()
|
| 19 |
+
else:
|
| 20 |
+
raise ValueError('padding_val must be "zero" or "mean"')
|
| 21 |
+
elif imgs.dtype == torch.uint8:
|
| 22 |
+
if self.padding_val == 'zero':
|
| 23 |
+
padding_val = 0
|
| 24 |
+
elif self.padding_val == 'mean':
|
| 25 |
+
padding_val = imgs.mean()
|
| 26 |
+
else:
|
| 27 |
+
raise ValueError('padding_val must be "zero" or "mean"')
|
| 28 |
+
else:
|
| 29 |
+
raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')
|
| 30 |
+
|
| 31 |
+
square_imgs = self.make_square_img_batched(imgs, padding_val=padding_val)
|
| 32 |
+
|
| 33 |
+
if padding_ratio_override is not None:
|
| 34 |
+
padding = padding_ratio_override
|
| 35 |
+
else:
|
| 36 |
+
padding = self.padding
|
| 37 |
+
padded_imgs = self.make_padded_img_batched(square_imgs, padding=padding, padding_val=padding_val)
|
| 38 |
+
|
| 39 |
+
size=(self.output_size, self.output_size)
|
| 40 |
+
if imgs.dtype == torch.float32:
|
| 41 |
+
resized_imgs = F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True)
|
| 42 |
+
elif imgs.dtype == torch.uint8:
|
| 43 |
+
padded_imgs = padded_imgs.to(torch.float32)
|
| 44 |
+
resized_imgs = F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True)
|
| 45 |
+
resized_imgs = torch.clip(resized_imgs, 0, 255)
|
| 46 |
+
resized_imgs = resized_imgs.to(torch.uint8)
|
| 47 |
+
else:
|
| 48 |
+
raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')
|
| 49 |
+
return resized_imgs
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def make_square_img_batched(self, imgs, padding_val):
|
| 53 |
+
assert imgs.ndim == 4
|
| 54 |
+
# squarify the image
|
| 55 |
+
h, w = imgs.shape[2:]
|
| 56 |
+
if h > w:
|
| 57 |
+
diff = (h - w)
|
| 58 |
+
pad_left = diff // 2
|
| 59 |
+
pad_right = diff - pad_left
|
| 60 |
+
imgs = F.pad(imgs, (pad_left, pad_right, 0, 0), value=padding_val)
|
| 61 |
+
elif w > h:
|
| 62 |
+
diff = (w - h)
|
| 63 |
+
pad_top = diff // 2
|
| 64 |
+
pad_bottom = diff - pad_top
|
| 65 |
+
imgs = F.pad(imgs, (0, 0, pad_top, pad_bottom), value=padding_val)
|
| 66 |
+
assert imgs.shape[2] == imgs.shape[3]
|
| 67 |
+
return imgs
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def make_padded_img_batched(self, imgs, padding, padding_val):
|
| 71 |
+
if padding == 0:
|
| 72 |
+
return imgs
|
| 73 |
+
assert imgs.ndim == 4
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
# pad the image
|
| 77 |
+
h, w = imgs.shape[2:]
|
| 78 |
+
pad_h = int(h * padding)
|
| 79 |
+
pad_w = int(w * padding)
|
| 80 |
+
imgs = F.pad(imgs, (pad_w, pad_w, pad_h, pad_h), value=padding_val)
|
| 81 |
+
return imgs
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def __call__(self, input, padding_ratio_override=None):
    """Preprocess a single image (3-D) or a batch of images (4-D).

    Args:
        input: tensor with 3 channels, shaped (3, H, W) or (N, 3, H, W).
        padding_ratio_override: forwarded unchanged to ``preprocess_batched``.

    Returns:
        The result of ``preprocess_batched``; for a 3-D input the temporary
        batch dimension is stripped again.

    Raises:
        ValueError: if ``input`` is neither 3-D nor 4-D.
    """
    if input.ndim not in (3, 4):
        raise ValueError(f'Invalid input shape: {input.shape}')

    if input.ndim == 4:
        assert input.shape[1] == 3
        return self.preprocess_batched(input, padding_ratio_override=padding_ratio_override)

    # Single image: wrap it in a batch of one, then unwrap the result.
    assert input.shape[0] == 3
    batched_result = self.preprocess_batched(input.unsqueeze(0),
                                             padding_ratio_override=padding_ratio_override)
    return batched_result[0]
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/utils/box_utils.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def point_form(boxes):
    """Convert center-size boxes (cx, cy, w, h) to corner form.

    Args:
        boxes: (tensor) center-size boxes, one per row.
    Return:
        (tensor) the same boxes as (xmin, ymin, xmax, ymax).
    """
    centers = boxes[:, :2]
    half_sizes = boxes[:, 2:] / 2
    top_left = centers - half_sizes      # xmin, ymin
    bottom_right = centers + half_sizes  # xmax, ymax
    return torch.cat((top_left, bottom_right), 1)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def center_size(boxes):
    """Convert corner-form boxes (xmin, ymin, xmax, ymax) to center-size form.

    Args:
        boxes: (tensor) point_form boxes, one per row.
    Return:
        (tensor) the same boxes as (cx, cy, w, h).
    """
    # torch.cat takes a sequence of tensors followed by the dim. The previous
    # parenthesisation passed the two halves as separate positional arguments.
    return torch.cat(((boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy
                      boxes[:, 2:] - boxes[:, :2]), 1)    # w, h
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def intersect(box_a, box_b):
    """Pairwise intersection area between two sets of corner-form boxes.

    Both inputs are expanded (as views) to a common [A,B,2] shape:
      [A,2] -> [A,1,2] -> [A,B,2]
      [B,2] -> [1,B,2] -> [A,B,2]
    Args:
      box_a: (tensor) bounding boxes, Shape: [A,4].
      box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    pair_shape = (box_a.size(0), box_b.size(0), 2)
    # Lower-right corner of the overlap: the smaller of the two max corners.
    lower_right = torch.min(box_a[:, 2:].unsqueeze(1).expand(*pair_shape),
                            box_b[:, 2:].unsqueeze(0).expand(*pair_shape))
    # Upper-left corner of the overlap: the larger of the two min corners.
    upper_left = torch.max(box_a[:, :2].unsqueeze(1).expand(*pair_shape),
                           box_b[:, :2].unsqueeze(0).expand(*pair_shape))
    # Disjoint boxes give a negative extent; clamp so their area is zero.
    extent = (lower_right - upper_left).clamp(min=0)
    return extent[..., 0] * extent[..., 1]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def jaccard(box_a, box_b):
    """Pairwise jaccard overlap (intersection over union) of two box sets.

    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (tensor) Ground truth bounding boxes, Shape: [num_objects,4]
        box_b: (tensor) Prior boxes from priorbox layers, Shape: [num_priors,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    inter = intersect(box_a, box_b)

    width_a = box_a[:, 2] - box_a[:, 0]
    height_a = box_a[:, 3] - box_a[:, 1]
    area_a = (width_a * height_a).unsqueeze(1).expand_as(inter)  # [A,B]

    width_b = box_b[:, 2] - box_b[:, 0]
    height_b = box_b[:, 3] - box_b[:, 1]
    area_b = (width_b * height_b).unsqueeze(0).expand_as(inter)  # [A,B]

    return inter / (area_a + area_b - inter)  # [A,B]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def matrix_iou(a, b):
    """
    Return the pairwise IoU of corner-form boxes a and b
    (numpy version, used for data augmentation).
    """
    top_left = np.maximum(a[:, None, :2], b[:, :2])
    bottom_right = np.minimum(a[:, None, 2:], b[:, 2:])

    # Zero out pairs that do not overlap on both axes.
    overlaps = np.all(top_left < bottom_right, axis=2)
    area_i = (bottom_right - top_left).prod(axis=2) * overlaps
    area_a = (a[:, 2:] - a[:, :2]).prod(axis=1)
    area_b = (b[:, 2:] - b[:, :2]).prod(axis=1)
    union = area_a[:, None] + area_b - area_i
    return area_i / union
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def matrix_iof(a, b):
    """
    Return the pairwise intersection of a and b divided by the area of a,
    with that area floored at 1 (numpy version, used for data augmentation).
    """
    top_left = np.maximum(a[:, None, :2], b[:, :2])
    bottom_right = np.minimum(a[:, None, 2:], b[:, 2:])

    # Zero out pairs that do not overlap on both axes.
    overlaps = np.all(top_left < bottom_right, axis=2)
    area_i = (bottom_right - top_left).prod(axis=2) * overlaps
    area_a = (a[:, 2:] - a[:, :2]).prod(axis=1)
    denominator = np.maximum(area_a[:, None], 1)
    return area_i / denominator
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def match(threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx):
    """Match each prior box with the ground truth box of the highest jaccard
    overlap, encode the boxes and landmarks, and write the per-prior targets
    for batch element ``idx`` into ``loc_t``, ``conf_t`` and ``landm_t``.
    Args:
        threshold: (float) The overlap threshold used when matching boxes.
        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
        priorbox: object exposing ``priors`` (center-size prior boxes passed
            through ``point_form`` here) and the ``encode`` / ``encode_landm``
            methods used below.
        labels: (tensor) All the class labels for the image, Shape: [num_obj].
        landms: (tensor) Ground truth landms, Shape [num_obj, 10].
        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
        landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
        idx: (int) current batch index
    Return:
        None. The targets are written in place into loc_t, conf_t and landm_t.
    """


    # jaccard index
    overlaps = jaccard(
        truths,
        point_form(priorbox.priors)
    )
    # (Bipartite Matching)
    # [1,num_objects] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True)

    # ignore hard gt
    valid_gt_idx = best_prior_overlap[:, 0] >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx, :]
    if best_prior_idx_filter.shape[0] <= 0:
        # No ground truth reaches 0.2 overlap with any prior: zero the box and
        # class targets and stop.
        # NOTE(review): landm_t[idx] is left untouched on this path — confirm
        # the caller masks landmark targets by conf_t.
        loc_t[idx] = 0
        conf_t[idx] = 0
        return

    # [1,num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True)
    best_truth_idx.squeeze_(0)
    best_truth_overlap.squeeze_(0)
    best_prior_idx.squeeze_(1)
    best_prior_idx_filter.squeeze_(1)
    best_prior_overlap.squeeze_(1)
    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)  # ensure best prior
    # TODO refactor: index best_prior_idx with long tensor
    # ensure every gt matches with its prior of max overlap
    for j in range(best_prior_idx.size(0)):     # decide which gt box this anchor predicts
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]            # Shape: [num_priors,4], the gt bbox assigned to each anchor
    conf = labels[best_truth_idx]               # Shape: [num_priors], the gt label assigned to each anchor
    conf[best_truth_overlap < threshold] = 0    # label as background: anchors below the overlap threshold become negatives
    loc = priorbox.encode(matches)

    matches_landm = landms[best_truth_idx]
    landm = priorbox.encode_landm(matches_landm)
    loc_t[idx] = loc    # [num_priors,4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior
    landm_t[idx] = landm
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def log_sum_exp(x):
    """Compute log(sum(exp(x))) along dim 1, keeping that dim.

    The global maximum of ``x`` is subtracted before exponentiating and added
    back afterwards so the exponentials do not overflow.
    Args:
        x (Variable(tensor)): conf_preds from conf layers
    """
    shift = x.data.max()
    exp_sum = torch.exp(x - shift).sum(1, keepdim=True)
    return exp_sum.log() + shift
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# Original author: Francisco Massa:
|
| 171 |
+
# https://github.com/fmassa/object-detection.torch
|
| 172 |
+
# Ported to PyTorch by Max deGroot (02/01/2017)
|
| 173 |
+
def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
        scores: (tensor) The class predscores for the img, Shape:[num_priors].
        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
        top_k: (int) The Maximum number of box preds to consider.
    Return:
        (keep, count): ``keep`` is a long tensor of length num_priors whose
        first ``count`` entries are the indices of the kept boxes in
        descending score order; the remaining entries are zero padding.
    """
    keep = torch.zeros(scores.size(0), dtype=torch.long)
    count = 0
    if boxes.numel() == 0:
        # Return the same (keep, count) pair as the normal path so callers can
        # always unpack the result.
        return keep, count

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    _, idx = scores.sort(0)  # sort in ascending order
    idx = idx[-top_k:]  # indices of the top-k largest vals

    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove kept element from view

        # Corners of the intersection between box i and every remaining box.
        xx1 = torch.clamp(x1[idx], min=x1[i])
        yy1 = torch.clamp(y1[idx], min=y1[i])
        xx2 = torch.clamp(x2[idx], max=x2[i])
        yy2 = torch.clamp(y2[idx], max=y2[i])
        # Disjoint boxes give a negative extent; clamp so their area is zero.
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h

        # IoU = i / (area(a) + area(b) - i)
        union = (area[idx] - inter) + area[i]
        iou = inter / union
        # keep only elements with an IoU <= overlap
        idx = idx[iou.le(overlap)]
    return keep, count
|
| 238 |
+
|
| 239 |
+
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/utils/model_utils.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
def remove_prefix(state_dict, prefix):
    """Return a copy of ``state_dict`` with a leading ``prefix`` stripped from keys.

    Old style checkpoints store every parameter name under a common prefix
    such as 'module.'. Only one leading occurrence is removed; keys that do
    not start with ``prefix`` are kept as they are.
    """
    print('remove prefix \'{}\''.format(prefix))
    renamed = {}
    for key, value in state_dict.items():
        if key.startswith(prefix):
            # partition splits at the first occurrence, which is the leading one here.
            key = key.partition(prefix)[2]
        renamed[key] = value
    return renamed
|
| 8 |
+
|
| 9 |
+
def check_keys(model, pretrained_state_dict):
    """Report how the checkpoint keys line up with the model's state dict.

    Prints the number of missing, unused and used keys, and asserts that at
    least one checkpoint key matches a model key.

    Returns:
        True when the assertion passes.
    """
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())

    shared = model_keys.intersection(ckpt_keys)
    only_in_ckpt = ckpt_keys.difference(model_keys)
    only_in_model = model_keys.difference(ckpt_keys)

    report = (
        ('Missing keys:{}', only_in_model),
        ('Unused checkpoint keys:{}', only_in_ckpt),
        ('Used keys:{}', shared),
    )
    for template, keys in report:
        print(template.format(len(keys)))

    assert len(shared) > 0, 'load NONE from pretrained checkpoint'
    return True
|
| 20 |
+
|
| 21 |
+
def load_model(model, pretrained_path, load_to_cpu):
    """Load a checkpoint into ``model`` (non-strict) and return the model.

    Args:
        model: module that receives the weights.
        pretrained_path: checkpoint path handed to ``torch.load``.
        load_to_cpu: when true, storages stay on CPU; otherwise they are moved
            to the current CUDA device.

    The checkpoint may be a bare state dict or a dict holding one under
    'state_dict'. A leading 'module.' is stripped from every key first.

    NOTE: ``torch.load`` unpickles the file, so only load checkpoints from a
    trusted source.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))

    if load_to_cpu:
        def map_location(storage, loc):
            return storage
    else:
        device = torch.cuda.current_device()

        def map_location(storage, loc):
            return storage.cuda(device)

    checkpoint = torch.load(pretrained_path, map_location=map_location)
    weights = checkpoint['state_dict'] if "state_dict" in checkpoint.keys() else checkpoint
    pretrained_dict = remove_prefix(weights, 'module.')

    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model
|
| 35 |
+
|
| 36 |
+
|
cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface_pipeline.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import cv2
|
| 4 |
+
from .retinaface.utils.model_utils import load_model
|
| 5 |
+
from .retinaface.layers.functions.prior_box import PriorBox
|
| 6 |
+
from .retinaface.models.retinaface import RetinaFace
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Configuration selected by load_retinface_model when network == "mobile0.25".
# Only the keys are visible here; how RetinaFace interprets them is defined in
# the model code (presumably 'return_layers' names the backbone stages to tap
# and 'in_channel' / 'out_channel' size the feature heads — TODO confirm).
cfg_mnet = {
    'name': 'mobilenet0.25',
    'gpu_train': True,
    'ngpu': 1,
    'epoch': 250,
    'decay1': 190,
    'decay2': 220,
    # 'image_size': 640,
    'pretrain': True,
    'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
    'in_channel': 32,
    'out_channel': 64
}


# Configuration selected by load_retinface_model when network == "resnet50"
# (that function's default). Same keys as cfg_mnet.
cfg_re50 = {
    'name': 'Resnet50',
    'gpu_train': True,
    'ngpu': 4,
    'epoch': 100,
    'decay1': 70,
    'decay2': 90,
    # 'image_size': 840,
    'pretrain': True,
    'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
    'in_channel': 256,
    'out_channel': 256
}
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load_retinface_model(network='resnet50', trained_model_path=''):
    """Build a RetinaFace network in test phase, load its weights and freeze it.

    Args:
        network: backbone name, either "mobile0.25" or "resnet50".
        trained_model_path: checkpoint path passed to ``load_model``
            (weights are loaded to CPU).

    Returns:
        The network in eval mode with ``requires_grad`` disabled on every
        parameter.

    Raises:
        ValueError: if ``network`` is not a supported backbone name.
    """
    if network == "mobile0.25":
        cfg = cfg_mnet
    elif network == "resnet50":
        cfg = cfg_re50
    else:
        # Previously an unknown name silently passed cfg=None to RetinaFace.
        raise ValueError(
            f"Unknown network: {network!r} (expected 'mobile0.25' or 'resnet50')")

    # net and model
    net = RetinaFace(cfg=cfg, phase='test')
    net = load_model(net, trained_model_path, True)
    net.eval()

    # freeze grad
    for param in net.parameters():
        param.requires_grad = False

    return net
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class RetinaFacePipeline(torch.nn.Module):
    """Run a RetinaFace detector on a batch and return the landmarks of the
    highest-scoring detection per image, normalised by the image size.
    """

    def __init__(self, net, priorbox, input_size, device='cuda'):
        """Store the detector, its prior-box helper and the working device.

        Args:
            net: detector module; called on the batch and expected to return
                (loc, conf, landms).
            priorbox: prior-box helper exposing ``image_size`` as
                (height, width) plus the decode methods used by ``postprocess``.
            input_size: side length of the square input that
                ``prealign_preprocess`` produces.
            device: device the inputs and outputs are moved to.
        """
        super().__init__()
        self.net = net
        self.priorbox = priorbox
        self.input_size = input_size
        self.output_size = 112
        self.device = device

    def normalize(self, image):
        """Map pixel values from [0, 255] to [-1, 1]."""
        image = image / 255.
        image = (image - 0.5) / 0.5
        return image

    def unnormalize(self, image):
        """Inverse of ``normalize``: map values from [-1, 1] back to [0, 255]."""
        image = image * 0.5 + 0.5
        image = image * 255.
        return image

    def normalize_for_net(self, bgr_image_0_255):
        """Subtract the per-channel mean (104, 117, 123) from a BGR batch in [0, 255]."""
        # Build the constant on the target device instead of creating it on
        # CPU and copying it over on every call.
        mean = torch.tensor([104, 117, 123], device=self.device)
        return bgr_image_0_255 - mean[None, :, None, None]

    def prealign_preprocess(self, images, value=0.0):
        """Fit an image or batch into an ``input_size`` x ``input_size`` canvas.

        Inputs larger than the canvas are first shrunk (aspect ratio kept) so
        the longer side fits; the result is then padded symmetrically with
        ``value`` (the extra pixel of an odd gap goes to the right/bottom).

        Args:
            images: 3-D (C, H, W) or 4-D (N, C, H, W) tensor.
            value: padding fill value.
        """
        assert isinstance(images, torch.Tensor)
        assert images.ndim == 4 or images.ndim == 3
        input_size = self.input_size

        data_width = images.shape[-1]
        data_height = images.shape[-2]
        if data_width > input_size or data_height > input_size:
            # image is bigger than the input size: resize such that the larger
            # side becomes the input_size without changing the aspect ratio
            scale = input_size / max(data_width, data_height)
            is_single = images.ndim == 3
            # F.interpolate in bilinear mode needs a batch dimension.
            batched = images.unsqueeze(0) if is_single else images
            batched = F.interpolate(input=batched, scale_factor=scale,
                                    mode='bilinear', align_corners=False)
            images = batched.squeeze(0) if is_single else batched

        data_width = images.shape[-1]
        data_height = images.shape[-2]
        padding_width1 = (input_size - data_width) // 2
        padding_width2 = (input_size - data_width) - padding_width1
        padding_height1 = (input_size - data_height) // 2
        padding_height2 = (input_size - data_height) - padding_height1

        result = F.pad(input=images,
                       pad=(padding_width1, padding_width2,
                            padding_height1, padding_height2),
                       value=value)
        assert result.shape[-1] == input_size
        assert result.shape[-2] == input_size
        return result

    def _normalize_landmarks(self, ldmks):
        """Divide flat (x, y) landmark pairs by (width, height) of the image."""
        # priorbox.image_size is (height, width) — postprocess unpacks it that
        # way — while landmarks are x, y pairs, so x is divided by the width
        # and y by the height. The previous code divided by (height, width),
        # which only matches for square inputs.
        im_height, im_width = self.priorbox.image_size
        return ldmks / np.array([im_width, im_height] * 5)

    def forward(self, rgb_images):
        """Detect faces and return the top detection's landmarks per image.

        Args:
            rgb_images: (N, 3, H, W) RGB tensor in the [-1, 1] range produced
                by ``normalize``; (H, W) must equal ``priorbox.image_size``.

        Returns:
            Float tensor on ``self.device`` with one row of 10 normalised
            landmark coordinates per image.
        """
        assert isinstance(rgb_images, torch.Tensor)
        assert rgb_images.ndim == 4
        assert rgb_images.shape[1] == 3
        assert self.priorbox.image_size == rgb_images.shape[2:]
        rgb_images = rgb_images.to(self.device)

        # make image into BGR
        bgr_images = rgb_images.flip(1)
        input_img = self.normalize_for_net(self.unnormalize(bgr_images))
        batch_loc, batch_conf, batch_landms = self.net(input_img)
        # Split into per-image chunks (each keeps a leading batch dim of 1).
        batch_loc = torch.split(batch_loc, 1, dim=0)
        batch_conf = torch.split(batch_conf, 1, dim=0)
        batch_landms = torch.split(batch_landms, 1, dim=0)

        all_ldmks = []
        for loc, conf, landms in zip(batch_loc, batch_conf, batch_landms):
            dets = postprocess(self.priorbox, loc, conf, landms,
                               confidence_threshold=0.0, nms_threshold=0.4)
            _, _, ldmks = parse_one_det_result(dets)
            all_ldmks.append(self._normalize_landmarks(ldmks))
        all_ldmks = torch.from_numpy(np.array(all_ldmks)).to(self.device).float()
        return all_ldmks
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def postprocess(priorbox, loc, conf, landms, confidence_threshold, nms_threshold):
    """Decode one image's raw detector outputs into NMS-filtered detections.

    Args:
        priorbox: helper exposing ``image_size`` as (height, width) and the
            ``decode`` / ``decode_landm`` methods.
        loc, conf, landms: detector outputs for one image, each with a
            leading batch dimension of 1.
        confidence_threshold: detections scoring at or below this are dropped
            (unless that would drop everything, see below).
        nms_threshold: overlap threshold passed to ``py_cpu_nms``.

    Returns:
        numpy array with one row per kept detection:
        4 box coordinates, 1 score, 10 landmark coordinates, all in pixels.
    """
    device = loc.device
    im_height, im_width = priorbox.image_size

    # Boxes are x, y pairs, so scale by (width, height) repeated.
    box_scale = torch.Tensor([im_width, im_height] * 2).to(device)
    boxes = (priorbox.decode(loc.data.squeeze(0)) * box_scale).cpu().numpy()

    # Column 1 of the confidence output is used as the detection score.
    scores = conf.squeeze(0).data.cpu().numpy()[:, 1]

    landm_scale = torch.Tensor([im_width, im_height] * 5).to(device)
    landms = (priorbox.decode_landm(landms.data.squeeze(0)) * landm_scale).cpu().numpy()

    # ignore low scores; if nothing passes, fall back to every non-negative score
    selected = np.where(scores > confidence_threshold)[0]
    if len(selected) == 0:
        selected = np.where(scores >= 0)[0]
    boxes, landms, scores = boxes[selected], landms[selected], scores[selected]

    # sort by descending score before NMS
    ranking = scores.argsort()[::-1]
    boxes, landms, scores = boxes[ranking], landms[ranking], scores[ranking]

    # do NMS
    dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False)
    kept = py_cpu_nms(dets, nms_threshold)

    return np.concatenate((dets[kept, :], landms[kept]), axis=1)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def py_cpu_nms(dets, thresh):
    """Pure Python NMS baseline.

    Args:
        dets: array whose first five columns are x1, y1, x2, y2, score.
        thresh: boxes whose overlap with an already kept box exceeds this
            are suppressed.

    Returns:
        List of kept row indices, in descending score order. Areas use the
        inclusive-pixel convention (+1 on each side length).
    """
    x1, y1, x2, y2, scores = (dets[:, col] for col in range(5))
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    remaining = scores.argsort()[::-1]
    keep = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        keep.append(best)

        # Intersection of the best box with every other remaining box.
        inter_w = np.maximum(
            0.0, np.minimum(x2[best], x2[rest]) - np.maximum(x1[best], x1[rest]) + 1)
        inter_h = np.maximum(
            0.0, np.minimum(y2[best], y2[rest]) - np.maximum(y1[best], y1[rest]) + 1)
        inter = inter_w * inter_h
        overlap = inter / (areas[best] + areas[rest] - inter)

        # Carry on with the boxes that do not overlap the kept one too much.
        remaining = rest[np.where(overlap <= thresh)[0]]

    return keep
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def parse_one_det_result(dets):
    """Split the highest-scoring detection row into its parts.

    Args:
        dets: 2-D array with one detection per row: 4 box values, the score
            in column 4, then the landmark values.

    Returns:
        (bbox, score, ldmks) of the row with the largest score.
    """
    by_score_desc = dets[:, 4].argsort()[::-1]
    top = dets[by_score_desc][0]
    return top[:4], top[4], top[5:]
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def load_retinaface_pipeline(network, trained_model_path, input_size, device):
    """Build a ready-to-run ``RetinaFacePipeline`` on ``device``.

    Args:
        network: backbone name requested by the caller (currently unused,
            see the note below).
        trained_model_path: checkpoint path for the detector weights.
        input_size: side length of the square input; also used as the
            prior-box image size.
        device: device the network, prior boxes and pipeline live on.

    Returns:
        The assembled pipeline, moved to ``device``.
    """
    # NOTE(review): ``network`` is ignored and 'resnet50' is always loaded.
    # Kept as is because callers may pass names load_retinface_model does not
    # know — confirm before forwarding the argument.
    net = load_retinface_model(network='resnet50', trained_model_path=trained_model_path)
    net = net.to(device)
    priorbox = PriorBox(image_size=(input_size, input_size),
                        min_sizes=[[16, 32], [64, 128], [256, 512]],
                        steps=[8, 16, 32], clip=False,
                        variances=[0.1, 0.2],
                        device=device)
    pipeline = RetinaFacePipeline(net, priorbox, input_size, device=device)
    # Honour the requested device. An unconditional .cuda() moved the network
    # to the default GPU even when the caller asked for CPU or another GPU.
    pipeline.to(device)
    return pipeline
|
cvlface/research/recognition/code/run_v1/base.yaml
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- trainers : configs/default
|
| 3 |
+
- optims : configs/cosine
|
| 4 |
+
- pefts: configs/none
|
| 5 |
+
- models: vit/configs/v1_small
|
| 6 |
+
- classifiers: configs/partial_fc
|
| 7 |
+
- aligners: configs/none
|
| 8 |
+
- dataset: configs/casia
|
| 9 |
+
- data_augs: configs/v7
|
| 10 |
+
- losses: configs/adaface
|
| 11 |
+
- pipelines: configs/train_model_cls
|
| 12 |
+
- evaluations: configs/base
|
cvlface/research/recognition/code/run_v1/classifiers/__init__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from . import partial_fc
|
| 2 |
+
from . import fc
|
| 3 |
+
|
| 4 |
+
def get_classifier(classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
    """Create the classifier head named by ``classifier_cfg.name``.

    Args:
        classifier_cfg: config with ``name`` ('partial_fc' or 'fc'),
            ``start_from`` (checkpoint path, falsy to skip loading) and
            ``freeze`` (truthy to disable gradients).
        margin_loss_fn: margin loss; when None no classifier is built.
        model_cfg, num_classes, rank, world_size: forwarded to the
            classifier's ``from_config``.

    Returns:
        The classifier, or None when ``margin_loss_fn`` is None.

    Raises:
        ValueError: if ``classifier_cfg.name`` is not a known classifier.
    """
    if margin_loss_fn is None:
        print("No margin loss function provided, classifier will not be created")
        return None

    if classifier_cfg.name == 'partial_fc':
        classifier_cls = partial_fc.PartialFCClassifier
    elif classifier_cfg.name == 'fc':
        classifier_cls = fc.FCClassifier
    else:
        raise ValueError(f"Unknown classifier: {classifier_cfg.name}")

    classifier = classifier_cls.from_config(classifier_cfg, margin_loss_fn,
                                            model_cfg, num_classes,
                                            rank, world_size)

    if classifier_cfg.start_from:
        classifier.load_state_dict_from_path(classifier_cfg.start_from)

    if classifier_cfg.freeze:
        for param in classifier.parameters():
            param.requires_grad = False

    return classifier
|
| 31 |
+
|
cvlface/research/recognition/code/run_v1/classifiers/base/__init__.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Union
|
| 3 |
+
import torch
|
| 4 |
+
from torch import device
|
| 5 |
+
from .utils import get_parameter_device, get_parameter_dtype, save_state_dict_and_config, load_state_dict_from_path
|
| 6 |
+
from general_utils.os_utils import natural_sort
|
| 7 |
+
|
| 8 |
+
class BaseClassifier(torch.nn.Module):
|
| 9 |
+
|
| 10 |
+
def __init__(self, config=None):
    """Initialise the module and keep ``config`` for later use.

    Args:
        config: classifier configuration; stored as ``self.config`` and passed
            to ``save_state_dict_and_config`` by ``save_pretrained``.
    """
    super(BaseClassifier, self).__init__()
    self.config = config
|
| 13 |
+
|
| 14 |
+
@classmethod
|
| 15 |
+
def from_config(cls, classifier_cfg, margin_loss_fn, model_cfg, dataset_cfg, rank, world_size) -> "BaseClassifier":
    """Alternate constructor that subclasses must override.

    NOTE(review): the fourth parameter is named ``dataset_cfg`` here, while
    ``get_classifier`` passes ``num_classes`` in that position — confirm which
    one subclasses expect.

    Raises:
        NotImplementedError: always, in this base class.
    """
    raise NotImplementedError('from_config must be implemented in subclass')
|
| 17 |
+
|
| 18 |
+
def forward(self, local_embeddings, local_labels):
    """Forward pass that subclasses must override.

    Args:
        local_embeddings: embeddings held by this rank.
        local_labels: labels matching ``local_embeddings``.

    Raises:
        NotImplementedError: always, in this base class.
    """
    # The message used to name ``from_config``, pointing at the wrong method.
    raise NotImplementedError('forward must be implemented in subclass')
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@property
|
| 23 |
+
def device(self) -> device:
    """Device of this module, as reported by ``get_parameter_device``."""
    return get_parameter_device(self)
|
| 25 |
+
|
| 26 |
+
@property
|
| 27 |
+
def dtype(self) -> torch.dtype:
    """Dtype of this module, as reported by ``get_parameter_dtype``."""
    return get_parameter_dtype(self)
|
| 29 |
+
|
| 30 |
+
def num_parameters(self, only_trainable: bool = False) -> int:
    """Count the scalar parameters of this module.

    Args:
        only_trainable: when true, parameters with ``requires_grad`` disabled
            are left out of the count.
    """
    total = 0
    for param in self.parameters():
        if only_trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total
|
| 32 |
+
|
| 33 |
+
def has_trainable_params(self):
    """Return True if at least one parameter has ``requires_grad`` enabled."""
    return any(param.requires_grad for param in self.parameters())
|
| 38 |
+
|
| 39 |
+
def save_pretrained(
    self,
    save_dir: Union[str, os.PathLike],
    name: str = 'model.pt',
    rank: int = 0,
):
    """Save this rank's state dict and config under ``save_dir``.

    The rank is inserted before the file extension, e.g. ``model.pt`` with
    rank 0 is written as ``model_rank0.pt``.

    Args:
        save_dir: directory that receives the file.
        name: base file name.
        rank: rank whose shard is being saved.
    """
    stem, extension = os.path.splitext(name)
    save_path = os.path.join(save_dir, f'{stem}_rank{rank}{extension}')
    save_state_dict_and_config(self.state_dict(), self.config, save_path)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_state_dict_from_path(self, pretrained_model_path):
|
| 51 |
+
|
| 52 |
+
save_dir = os.path.dirname(pretrained_model_path)
|
| 53 |
+
save_name = os.path.basename(pretrained_model_path)
|
| 54 |
+
rank_added_name = os.path.splitext(save_name)[0] + f'_rank{self.rank}' + os.path.splitext(save_name)[1]
|
| 55 |
+
pretrained_model_path = os.path.join(save_dir, rank_added_name)
|
| 56 |
+
|
| 57 |
+
all_partitions = [name for name in os.listdir(save_dir) if '_rank' in name and '.pt' in name]
|
| 58 |
+
all_partitions = natural_sort(all_partitions)
|
| 59 |
+
ckpt_worldsize = len(all_partitions)
|
| 60 |
+
|
| 61 |
+
if self.world_size != ckpt_worldsize:
|
| 62 |
+
# we need to redistribute the partialfc weights
|
| 63 |
+
part_ckpts = [torch.load(os.path.join(save_dir, name), map_location='cpu') for name in all_partitions]
|
| 64 |
+
total_ckpt_num_subjects = sum([ckpt['partial_fc.weight'].shape[0] for ckpt in part_ckpts])
|
| 65 |
+
assert total_ckpt_num_subjects - self.partial_fc.num_classes < 10, \
|
| 66 |
+
(f"total_ckpt_num_subjects: {total_ckpt_num_subjects}, "
|
| 67 |
+
f"self.partial_fc.num_classes: {self.partial_fc.num_classes}"
|
| 68 |
+
f"The number can be slightly different due to the last partition.")
|
| 69 |
+
|
| 70 |
+
combined_weight = torch.cat([ckpt['partial_fc.weight'] for ckpt in part_ckpts], dim=0)
|
| 71 |
+
state_dict = part_ckpts[0]
|
| 72 |
+
|
| 73 |
+
class_start = self.partial_fc.class_start
|
| 74 |
+
num_sample = self.partial_fc.num_local
|
| 75 |
+
sub_center = combined_weight[class_start:class_start + num_sample, :]
|
| 76 |
+
if sub_center.shape[0] != num_sample:
|
| 77 |
+
# append zero
|
| 78 |
+
extra_center = torch.zeros(num_sample - sub_center.shape[0], sub_center.shape[1],
|
| 79 |
+
device=self.device, dtype=self.dtype)
|
| 80 |
+
sub_center = torch.cat([sub_center, extra_center], dim=0)
|
| 81 |
+
state_dict['partial_fc.weight'] = sub_center
|
| 82 |
+
|
| 83 |
+
else:
|
| 84 |
+
state_dict = load_state_dict_from_path(pretrained_model_path)
|
| 85 |
+
|
| 86 |
+
result = self.load_state_dict(state_dict, strict=False)
|
| 87 |
+
print(result)
|
cvlface/research/recognition/code/run_v1/classifiers/base/utils.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
from typing import List, Optional, Tuple, Union
|
| 3 |
+
import safetensors
|
| 4 |
+
import torch
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from omegaconf import DictConfig, OmegaConf
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the first tensor found on ``parameter``.

    Parameters are consulted first, then buffers.  If the module has neither,
    plain tensor attributes stored in the modules' ``__dict__`` are searched;
    ``StopIteration`` propagates when none exists.
    """
    for tensor in itertools.chain(parameter.parameters(), parameter.buffers()):
        return tensor.device

    # For torch.nn.DataParallel compatibility in PyTorch 1.5
    def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
        return [(key, value) for key, value in module.__dict__.items() if torch.is_tensor(value)]

    attribute_iter = parameter._named_members(get_members_fn=find_tensor_attributes)
    _, first_tensor = next(attribute_iter)
    return first_tensor.device
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the first tensor found on ``parameter``.

    Lookup order: parameters, then buffers, then plain tensor attributes
    stored in the ``__dict__`` of the module or any of its submodules.

    Returns:
        The ``torch.dtype`` of the first tensor found, or ``None`` when the
        module holds no tensor at all.
    """
    for tensor in itertools.chain(parameter.parameters(), parameter.buffers()):
        return tensor.dtype

    # Fallback for modules whose tensors are plain attributes (the case the
    # original DataParallel-compatibility branch was written for).  It used to
    # sit behind an ``except StopIteration`` that could never trigger.
    for module in parameter.modules():
        for value in vars(module).values():
            if torch.is_tensor(value):
                return value.dtype

    return None
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path:
    """Return the directory part of ``save_path`` as a ``Path``."""
    return Path(save_path).parent
|
| 50 |
+
|
| 51 |
+
def get_base_name(save_path: Union[str, os.PathLike]) -> str:
    """Return the final path component of ``save_path``."""
    return Path(save_path).name
|
| 54 |
+
|
| 55 |
+
def load_state_dict_from_path(path: Union[str, os.PathLike]):
    """Load a state dict from ``path``.

    Paths containing the substring ``'safetensors'`` are read with
    ``safetensors.torch.load_file``; everything else goes through
    ``torch.load`` with ``map_location="cpu"``.

    Args:
        path: checkpoint location, as a string or any ``os.PathLike``.

    Returns:
        The loaded state dict.
    """
    # ``in`` is not defined for Path objects, so normalise to str first.
    path_str = os.fspath(path)
    if 'safetensors' in path_str:
        # Import the submodule explicitly: a bare ``import safetensors`` does
        # not guarantee that ``safetensors.torch`` has been loaded.
        from safetensors.torch import load_file
        return load_file(path_str)
    return torch.load(path_str, map_location="cpu")
|
| 62 |
+
|
| 63 |
+
def replace_extension(path, new_extension):
    """Return ``path`` with its last extension swapped for ``new_extension``.

    A leading dot is added to ``new_extension`` when it is missing.
    """
    suffix = new_extension if new_extension.startswith('.') else '.' + new_extension
    stem, _old_suffix = os.path.splitext(path)
    return stem + suffix
|
| 67 |
+
|
| 68 |
+
def make_config_path(save_path):
    """Return the ``.yaml`` path that sits next to ``save_path``."""
    return replace_extension(save_path, '.yaml')
|
| 71 |
+
|
| 72 |
+
def save_config(config, config_path):
    """Write ``config`` (a dict or ``DictConfig``) to ``config_path`` via OmegaConf.

    The parent directory is created if needed; plain dicts are converted with
    ``OmegaConf.create`` before saving.
    """
    assert isinstance(config, (dict, DictConfig))
    os.makedirs(get_parent_directory(config_path), exist_ok=True)
    to_write = OmegaConf.create(config) if isinstance(config, dict) else config
    OmegaConf.save(to_write, config_path)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def save_state_dict_and_config(state_dict, config, save_path):
    """Save ``state_dict`` to ``save_path`` and ``config`` to the sibling ``.yaml`` file.

    Paths containing the substring ``'safetensors'`` are written with
    ``safetensors.torch.save_file``; everything else uses ``torch.save``.

    Args:
        state_dict: mapping of tensor names to tensors.
        config: ``dict`` or ``DictConfig`` handed to ``save_config``.
        save_path: destination of the weights, string or ``os.PathLike``.
    """
    # ``in`` is not defined for Path objects, so normalise to str first.
    save_path = os.fspath(save_path)
    os.makedirs(get_parent_directory(save_path), exist_ok=True)

    # save config dict
    config_path = make_config_path(save_path)
    save_config(config, config_path)

    # Save the model
    if 'safetensors' in save_path:
        # Import the submodule explicitly: a bare ``import safetensors`` does
        # not guarantee that ``safetensors.torch`` has been loaded.
        from safetensors.torch import save_file
        save_file(state_dict, save_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, save_path)
|
cvlface/research/recognition/code/run_v1/classifiers/configs/fc.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: 'fc'
|
| 2 |
+
sample_rate: 1.0
|
| 3 |
+
start_from: ''
|
| 4 |
+
freeze: False
|
cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: 'partial_fc'
|
| 2 |
+
sample_rate: 1.0
|
| 3 |
+
start_from: ''
|
| 4 |
+
freeze: False
|
cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_freeze.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: 'partial_fc'
|
| 2 |
+
sample_rate: 1.0
|
| 3 |
+
start_from: ''
|
| 4 |
+
freeze: True
|
cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_sample10.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: 'partial_fc'
|
| 2 |
+
sample_rate: 0.1
|
| 3 |
+
start_from: ''
|
| 4 |
+
freeze: False
|
cvlface/research/recognition/code/run_v1/classifiers/configs/partial_fc_sample10_freeze.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: 'partial_fc'
|
| 2 |
+
sample_rate: 0.1
|
| 3 |
+
start_from: ''
|
| 4 |
+
freeze: True
|
cvlface/research/recognition/code/run_v1/classifiers/fc/__init__.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..base import BaseClassifier, load_state_dict_from_path
|
| 2 |
+
from .fc import FC
|
| 3 |
+
from typing import Union
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class FCClassifier(BaseClassifier):
    """Classifier wrapper around a single ``FC`` head."""

    def __init__(self, classifier, config, rank, world_size):
        super().__init__()
        self.classifier = classifier
        self.config = config
        self.rank = rank
        self.world_size = world_size
        # NOTE(review): presumably tells the trainer to wrap this module in DDP - confirm.
        self.apply_ddp = True

    @classmethod
    def from_config(cls, classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
        """Create an ``FC``-backed classifier; only ``classifier_cfg.name == 'fc'`` is supported."""
        if classifier_cfg.name != 'fc':
            raise NotImplementedError

        head = FC(
            margin_loss=margin_loss_fn,
            embedding_size=model_cfg.output_dim,
            num_classes=num_classes,
        )
        instance = cls(head, classifier_cfg, rank, world_size)
        instance.eval()
        return instance

    def forward(self, local_embeddings, local_labels):
        """Return the loss computed by the wrapped head."""
        return self.classifier(local_embeddings, local_labels)

    def save_pretrained(
        self,
        save_dir: Union[str, os.PathLike],
        name: str = 'classifier.pt',
        rank: int = 0,
    ):
        """Save the checkpoint, but only when called with ``rank == 0``."""
        if rank != 0:
            return
        super().save_pretrained(save_dir, name, rank)

    def load_state_dict_from_path(self, pretrained_model_path):
        """Load the ``_rank0`` checkpoint that belongs to ``pretrained_model_path``."""
        folder, filename = os.path.split(pretrained_model_path)
        stem, ext = os.path.splitext(filename)
        rank0_path = os.path.join(folder, stem + '_rank0' + ext)

        weights = load_state_dict_from_path(rank0_path)
        result = self.load_state_dict(weights, strict=False)
        print('classifier loading result', result)
|
| 54 |
+
|
| 55 |
+
|
cvlface/research/recognition/code/run_v1/classifiers/fc/fc.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
import torch
|
| 3 |
+
from torch import distributed
|
| 4 |
+
from torch.nn.functional import linear, normalize
|
| 5 |
+
from losses.margin_loss import CombinedMarginLoss
|
| 6 |
+
from losses.adaface import AdaFaceLoss
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class FC(torch.nn.Module):
    """Fully connected margin-softmax head holding one weight row per class.

    The forward pass computes cosine logits between L2-normalised embeddings
    and L2-normalised class weights, applies the configured margin loss and
    returns the cross-entropy loss.
    """

    def __init__(
        self,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
    ):
        """
        Parameters:
        -----------
        margin_loss: Callable
            Margin function; ``forward`` supports ``CombinedMarginLoss`` and ``AdaFaceLoss``.
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required

        Raises:
        -------
        ValueError
            If ``margin_loss`` is not callable.
        """
        super(FC, self).__init__()

        # Fail early with a real message (the old bare ``raise`` produced
        # "RuntimeError: No active exception to reraise").
        if not isinstance(margin_loss, Callable):
            raise ValueError(f'margin_loss must be callable, got {type(margin_loss).__name__}')

        self.cross_entropy = torch.nn.CrossEntropyLoss()
        self.embedding_size = embedding_size
        self.num_classes = num_classes
        self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_classes, embedding_size)))

        # margin_loss
        self.margin_softmax = margin_loss
        if isinstance(margin_loss, AdaFaceLoss):
            # Running statistics passed to and updated from AdaFaceLoss in forward().
            self.register_buffer('batch_mean', torch.ones(1)*(20))
            self.register_buffer('batch_std', torch.ones(1)*100)

    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
    ):
        """Return the margin-softmax cross-entropy loss for one batch.

        Raises:
            ValueError: if the margin loss is neither ``CombinedMarginLoss``
                nor ``AdaFaceLoss``.
        """
        embeddings = local_embeddings
        labels = local_labels
        weight = self.weight

        # Norms are clamped so the division below cannot hit zero.
        norms = embeddings.norm(p=2, dim=1, keepdim=True).clamp_min(1e-8)
        norm_embeddings = embeddings / norms

        norm_weight_activated = normalize(weight)
        logits = linear(norm_embeddings, norm_weight_activated)
        logits = logits.clamp(-1, 1)

        if isinstance(self.margin_softmax, CombinedMarginLoss):
            logits = self.margin_softmax(logits=logits, labels=labels)
        elif isinstance(self.margin_softmax, AdaFaceLoss):
            logits, batch_mean, batch_std = self.margin_softmax(logits=logits, labels=labels, norms=norms,
                                                                batch_mean=self.batch_mean,
                                                                batch_std=self.batch_std)
            self.batch_mean.data = batch_mean.data
            self.batch_std.data = batch_std.data
        else:
            raise ValueError('FC margin_softmax not supported type')

        loss = self.cross_entropy(logits, labels)
        return loss
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
cvlface/research/recognition/code/run_v1/classifiers/partial_fc/__init__.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ..base import BaseClassifier
|
| 2 |
+
from .partial_fc import PartialFC_V2
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PartialFCClassifier(BaseClassifier):
    """Classifier wrapper around a ``PartialFC_V2`` head stored as ``self.partial_fc``."""

    def __init__(self, classifier, config, rank, world_size):
        super().__init__()
        self.partial_fc = classifier
        self.config = config
        self.rank = rank
        self.world_size = world_size
        # NOTE(review): presumably tells the trainer not to wrap this module in DDP - confirm.
        self.apply_ddp = False

    @classmethod
    def from_config(cls, classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
        """Create a ``PartialFC_V2``-backed classifier; only ``name == 'partial_fc'`` is supported."""
        if classifier_cfg.name != 'partial_fc':
            raise NotImplementedError

        head = PartialFC_V2(
            rank=rank,
            world_size=world_size,
            margin_loss=margin_loss_fn,
            embedding_size=model_cfg.output_dim,
            num_classes=num_classes,
            sample_rate=classifier_cfg.sample_rate,
        )
        instance = cls(head, classifier_cfg, rank, world_size)
        instance.eval()
        return instance

    def forward(self, local_embeddings, local_labels):
        """Return the loss computed by the wrapped partial-FC head."""
        return self.partial_fc(local_embeddings, local_labels)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
|
cvlface/research/recognition/code/run_v1/classifiers/partial_fc/partial_fc.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
import torch
|
| 3 |
+
from torch import distributed
|
| 4 |
+
from torch.nn.functional import linear, normalize
|
| 5 |
+
from losses.margin_loss import CombinedMarginLoss
|
| 6 |
+
from losses.adaface import AdaFaceLoss
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class PartialFC_V2(torch.nn.Module):
    """
    https://arxiv.org/abs/2203.15565
    A distributed sparsely updating variant of the FC layer, named Partial FC (PFC).
    When the sample rate is less than 1, in each iteration the positive class centers and a
    random subset of negative class centers are selected to compute the margin-based softmax
    loss. All class centers are still maintained throughout the whole training process, but
    only a subset is selected and updated in each iteration.
    .. note::
        When the sample rate equals 1, Partial FC is equal to model parallelism (default sample rate is 1).
    Example:
    --------
    >>> module_pfc = PartialFC_V2(rank, world_size, margin_loss,
    ...                           embedding_size=512, num_classes=8000000, sample_rate=0.2)
    >>> for img, labels in data_loader:
    >>>     embeddings = net(img)
    >>>     loss = module_pfc(embeddings, labels)
    >>>     loss.backward()
    >>>     optimizer.step()
    """
    _version = 2

    def __init__(
        self,
        rank: int,
        world_size: int,
        margin_loss: Callable,
        embedding_size: int,
        num_classes: int,
        sample_rate: float = 1.0,
    ):
        """
        Parameters:
        -----------
        rank: int
            Rank of this process; selects which slice of classes is held locally.
        world_size: int
            Number of processes the class centers are split across.
        margin_loss: Callable
            Margin function; ``forward`` supports ``CombinedMarginLoss`` and ``AdaFaceLoss``.
        embedding_size: int
            The dimension of embedding, required
        num_classes: int
            Total number of classes, required
        sample_rate: float
            The rate of negative centers participating in the calculation, default is 1.0.

        Raises:
        -------
        ValueError
            If ``margin_loss`` is not callable.
        """
        super(PartialFC_V2, self).__init__()
        assert (
            distributed.is_initialized()
        ), "must initialize distributed before create this"
        self.rank = rank
        self.world_size = world_size

        self.dist_cross_entropy = DistCrossEntropy()
        self.embedding_size = embedding_size
        self.sample_rate: float = sample_rate

        # make num_class divisible by self.world_size for ddp (round up)
        _num_classes = num_classes // self.world_size * self.world_size
        if _num_classes < num_classes:
            _num_classes = _num_classes + self.world_size
        num_classes = _num_classes
        self.num_classes: int = num_classes

        # Number of class centers held by this rank.
        self.num_local: int = num_classes // self.world_size + int(
            self.rank < num_classes % self.world_size
        )

        # Global index of the first class held by this rank.
        self.class_start: int = num_classes // self.world_size * self.rank + min(
            self.rank, num_classes % self.world_size
        )
        self.num_sample: int = int(self.sample_rate * self.num_local)
        self.last_batch_size: int = 0

        self.is_updated: bool = True
        self.init_weight_update: bool = True
        self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size)))

        # margin_loss
        if isinstance(margin_loss, Callable):
            self.margin_softmax = margin_loss
            if isinstance(margin_loss, AdaFaceLoss):
                # Running statistics passed to and updated from AdaFaceLoss in forward().
                self.register_buffer('batch_mean', torch.ones(1)*(20))
                self.register_buffer('batch_std', torch.ones(1)*100)
        else:
            # The old bare ``raise`` produced "RuntimeError: No active exception to reraise".
            raise ValueError(f'margin_loss must be callable, got {type(margin_loss).__name__}')

    def sample(self, labels, index_positive):
        """
        Select the class centers used in this iteration.
        This function will change the value of labels: entries selected by
        ``index_positive`` are remapped to positions inside the sampled subset.
        Parameters:
        -----------
        labels: torch.Tensor
            Local (rank-relative) labels; modified in place.
        index_positive: torch.Tensor
            Boolean mask marking the labels whose class is held by this rank.
        Returns:
        --------
        torch.Tensor
            Rows of ``self.weight`` selected by the sampled indices (also
            stored in ``self.weight_index``).
        """
        with torch.no_grad():
            # Work on the device that holds the class centers instead of
            # hard-coding ``.cuda()``.
            device = self.weight.device
            positive = torch.unique(labels[index_positive], sorted=True).to(device)
            if self.num_sample - positive.size(0) >= 0:
                perm = torch.rand(size=[self.num_local]).to(device)
                # Scores above 1 guarantee every positive center survives top-k.
                perm[positive] = 2.0
                index = torch.topk(perm, k=self.num_sample)[1]
                index = index.sort()[0]
            else:
                index = positive
            self.weight_index = index

            labels[index_positive] = torch.searchsorted(index, labels[index_positive])

        return self.weight[self.weight_index]

    def forward(
        self,
        local_embeddings: torch.Tensor,
        local_labels: torch.Tensor,
    ):
        """
        Parameters:
        ----------
        local_embeddings: torch.Tensor
            feature embeddings on each GPU(Rank).
        local_labels: torch.Tensor
            labels on each GPU(Rank).
        Returns:
        -------
        loss: torch.Tensor
            Value returned by the distributed cross-entropy.
        """

        local_labels.squeeze_()
        # reshape(-1) keeps a batch of one 1-D after the in-place squeeze above.
        local_labels = local_labels.long().reshape(-1)

        batch_size = local_embeddings.size(0)
        if self.last_batch_size == 0:
            self.last_batch_size = batch_size
        assert self.last_batch_size == batch_size, (
            f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}")

        _gather_embeddings = [
            torch.zeros((batch_size, self.embedding_size), dtype=local_embeddings.dtype, device=local_embeddings.device)
            for _ in range(self.world_size)
        ]
        # Receive buffers live on the same device as the labels being gathered.
        _gather_labels = [
            torch.zeros(batch_size, dtype=torch.long, device=local_labels.device)
            for _ in range(self.world_size)
        ]
        _list_embeddings = AllGather(local_embeddings, *_gather_embeddings)
        distributed.all_gather(_gather_labels, local_labels)

        embeddings = torch.cat(_list_embeddings)
        labels = torch.cat(_gather_labels)

        labels = labels.view(-1, 1)
        index_positive = (self.class_start <= labels) & (
            labels < self.class_start + self.num_local
        )
        # Classes owned by other ranks are marked -1; local ones become rank-relative.
        labels[~index_positive] = -1
        labels[index_positive] -= self.class_start

        if self.sample_rate < 1:
            weight = self.sample(labels, index_positive)
        else:
            weight = self.weight

        norms = embeddings.norm(p=2, dim=1, keepdim=True).clamp_min(1e-8)
        norm_embeddings = embeddings / norms

        norm_weight_activated = normalize(weight)
        logits = linear(norm_embeddings, norm_weight_activated)

        logits = logits.clamp(-1, 1)

        if isinstance(self.margin_softmax, CombinedMarginLoss):
            logits = self.margin_softmax(logits=logits, labels=labels)
        elif isinstance(self.margin_softmax, AdaFaceLoss):
            logits, batch_mean, batch_std = self.margin_softmax(logits=logits, labels=labels, norms=norms,
                                                                batch_mean=self.batch_mean,
                                                                batch_std=self.batch_std)
            self.batch_mean.data = batch_mean.data
            self.batch_std.data = batch_std.data
        else:
            raise ValueError('partial FC margin_softmax not supported type')

        loss = self.dist_cross_entropy(logits, labels)
        return loss
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class DistCrossEntropyFunc(torch.autograd.Function):
    """
    Cross-entropy loss computed in parallel: the softmax maximum and denominator are
    all-reduced across ranks so that each rank only needs its own slice of the logits.
    Implementation following ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf).
    """

    @staticmethod
    def forward(ctx, logits: torch.Tensor, label: torch.Tensor):
        """Return the mean negative log of the target-class probability.

        ``logits`` is overwritten in place and ends up holding the softmax
        probabilities of this rank's columns.  Rows whose ``label`` is -1 are
        skipped locally; their contribution arrives through the final
        all-reduce (presumably from the rank that owns the class - confirm
        against the caller).
        """
        batch_size = logits.size(0)
        # for numerical stability
        max_logits, _ = torch.max(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(max_logits, distributed.ReduceOp.MAX)
        # In-place softmax numerator: exp(logits - global row max).
        logits.sub_(max_logits)
        logits.exp_()
        sum_logits_exp = torch.sum(logits, dim=1, keepdim=True)
        # local to global
        distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM)
        logits.div_(sum_logits_exp)
        # Rows for which this rank has a valid (non -1) label.
        index = torch.where(label != -1)[0]
        # loss
        loss = torch.zeros(batch_size, 1, device=logits.device, dtype=logits.dtype)
        loss[index] = logits[index].gather(1, label[index])
        # Sum the per-rank partial results (zeros where the label was -1).
        distributed.all_reduce(loss, distributed.ReduceOp.SUM)
        # The softmax probabilities are reused (and modified) by backward().
        ctx.save_for_backward(index, logits, label)
        # Clamp before log so a zero probability cannot produce -inf.
        return loss.clamp_min_(1e-30).log_().mean() * (-1)

    @staticmethod
    def backward(ctx, loss_gradient):
        """
        Args:
            loss_gradient (torch.Tensor): gradient backward by last layer
        Returns:
            gradients for each input in forward function
            `None` gradients for one-hot label
        """
        (
            index,
            logits,
            label,
        ) = ctx.saved_tensors
        batch_size = logits.size(0)
        one_hot = torch.zeros(
            size=[index.size(0), logits.size(1)], device=logits.device
        )
        one_hot.scatter_(1, label[index], 1)
        # Gradient is (softmax - one_hot) / batch_size, written into the saved tensor in place.
        logits[index] -= one_hot
        logits.div_(batch_size)
        # ``.item()`` reads the incoming gradient as a Python scalar.
        return logits * loss_gradient.item(), None
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class DistCrossEntropy(torch.nn.Module):
    """Module wrapper that applies ``DistCrossEntropyFunc`` to a logit/label partition."""

    def __init__(self):
        super().__init__()

    def forward(self, logit_part, label_part):
        """Return ``DistCrossEntropyFunc.apply(logit_part, label_part)``."""
        loss = DistCrossEntropyFunc.apply(logit_part, label_part)
        return loss
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class AllGatherFunc(torch.autograd.Function):
    """AllGather op with gradient backward"""

    @staticmethod
    def forward(ctx, tensor, *gather_list):
        """Fill ``gather_list`` with ``tensor`` from every rank and return it as a tuple."""
        gather_list = list(gather_list)
        distributed.all_gather(gather_list, tensor)
        return tuple(gather_list)

    @staticmethod
    def backward(ctx, *grads):
        """Reduce each gathered slot's gradient onto its source rank.

        Returns the summed gradient for ``tensor`` followed by one ``None``
        per entry of ``gather_list``.
        """
        grad_list = list(grads)
        rank = distributed.get_rank()
        grad_out = grad_list[rank]

        # One asynchronous SUM-reduce per rank: slot i is reduced onto rank i.
        # Both branches issue the same call, since grad_out is grad_list[rank].
        dist_ops = [
            distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True)
            if i == rank
            else distributed.reduce(
                grad_list[i], i, distributed.ReduceOp.SUM, async_op=True
            )
            for i in range(distributed.get_world_size())
        ]
        for _op in dist_ops:
            _op.wait()

        grad_out *= len(grad_list)  # cooperate with distributed loss function
        return (grad_out, *[None for _ in range(len(grad_list))])


# Functional alias: call as ``AllGather(tensor, *gather_list)``.
AllGather = AllGatherFunc.apply
|