# ---- aligners/__init__.py ----

def get_aligner(aligner_cfg):
    """Build, optionally initialise and optionally freeze an aligner.

    Args:
        aligner_cfg: config object read through attribute access. The fields
            used here are ``name`` (selects the implementation),
            ``start_from`` (weights path; skipped when falsy) and ``freeze``
            (when truthy, every parameter gets ``requires_grad = False``).

    Returns:
        The aligner produced by the selected class's ``from_config``.

    Raises:
        ValueError: if ``aligner_cfg.name`` matches no known aligner.
    """
    name = aligner_cfg.name

    # Each implementation is imported inside its own branch, so only the
    # selected module is imported.
    if name == 'none':
        from .none import NoneAligner
        aligner = NoneAligner.from_config(aligner_cfg)
    elif name in ('retinaface_aligner', 'retinaface'):
        # 'retinaface' is the name used by configs/retinaface.yaml; accept it
        # alongside the module name so that config resolves.
        from .retinaface_aligner import RetinaFaceAligner
        aligner = RetinaFaceAligner.from_config(aligner_cfg)
    elif name == 'differentiable_face_aligner':
        from .differentiable_face_aligner import DifferentiableFaceAligner
        aligner = DifferentiableFaceAligner.from_config(aligner_cfg)
    else:
        raise ValueError(f"Unknown aligner: {name}")

    if aligner_cfg.start_from:
        aligner.load_state_dict_from_path(aligner_cfg.start_from)

    if aligner_cfg.freeze:
        for param in aligner.parameters():
            param.requires_grad = False
    return aligner


# ---- aligners/base/__init__.py ----

class BaseAligner(torch.nn.Module):
    """Common base class for aligners.

    Stores the config, declares the methods subclasses must provide
    (``from_config``, ``make_train_transform``, ``make_test_transform``,
    ``forward``) and offers save/load and parameter-introspection helpers.
    The save/load/device/dtype helpers delegate to functions of the sibling
    ``utils`` module, which must be in scope at module level.
    """

    def __init__(self, config=None):
        super().__init__()
        self.config = config

    @classmethod
    def from_config(cls, config) -> "BaseAligner":
        """Alternate constructor; must be overridden."""
        raise NotImplementedError('from_config must be implemented in subclass')

    def make_train_transform(self):
        """Return the training-time input transform; must be overridden."""
        raise NotImplementedError('make_train_transform must be implemented in subclass')

    def make_test_transform(self):
        """Return the test-time input transform; must be overridden."""
        raise NotImplementedError('make_test_transform must be implemented in subclass')

    def forward(self, x):
        """Run the aligner on ``x``; must be overridden."""
        raise NotImplementedError('forward must be implemented in subclass')

    def save_pretrained(
        self,
        save_dir: Union[str, os.PathLike],
        name: str = 'model.pt',
        rank: int = 0,
    ):
        """Write the state dict and config to ``save_dir/name``.

        Nothing is written unless ``rank == 0``.
        """
        save_path = os.path.join(save_dir, name)
        if rank == 0:
            save_state_dict_and_config(self.state_dict(), self.config, save_path)

    def load_state_dict_from_path(self, pretrained_model_path):
        """Load weights from ``pretrained_model_path`` into this module."""
        # Resolves to the module-level helper of the same name, not this method.
        state_dict = load_state_dict_from_path(pretrained_model_path)
        self.load_state_dict(state_dict)
        print(f"Loaded pretrained aligner from {pretrained_model_path}")

    @property
    def device(self) -> torch.device:
        """Device reported by ``get_parameter_device`` for this module."""
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        """Dtype reported by ``get_parameter_dtype`` for this module."""
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False) -> int:
        """Count parameter elements, optionally only those requiring grad."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

    def has_trainable_params(self):
        """Return True if at least one parameter requires grad."""
        return any(param.requires_grad for param in self.parameters())

    def has_params(self):
        """Return True if the module owns at least one parameter."""
        return next(self.parameters(), None) is not None
# ---- aligners/base/utils.py ----

def _find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
    """List ``(name, tensor)`` pairs stored as plain attributes of ``module``."""
    return [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]


def get_parameter_device(parameter: torch.nn.Module):
    """Return the device of the first parameter or buffer of ``parameter``.

    When the module has neither, the first tensor stored as a plain attribute
    on the module or a submodule is used instead. As in the original code,
    ``StopIteration`` propagates if no tensor is found at all.
    """
    try:
        parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers())
        return next(parameters_and_buffers).device
    except StopIteration:
        # For torch.nn.DataParallel compatibility in PyTorch 1.5
        gen = parameter._named_members(get_members_fn=_find_tensor_attributes)
        first_tuple = next(gen)
        return first_tuple[1].device


def get_parameter_dtype(parameter: torch.nn.Module):
    """Return the dtype of the first parameter, else of the first buffer.

    If there is neither, the first plain tensor attribute is used. The
    original wrapped this fallback in ``except StopIteration``, which
    ``tuple()`` never raises, so it could not run. Returns ``None`` when the
    module holds no tensor at all, as before.
    """
    for tensors in (parameter.parameters(), parameter.buffers()):
        first = next(iter(tensors), None)
        if first is not None:
            return first.dtype

    # For torch.nn.DataParallel compatibility in PyTorch 1.5
    gen = parameter._named_members(get_members_fn=_find_tensor_attributes)
    first_tuple = next(gen, None)
    return None if first_tuple is None else first_tuple[1].dtype


def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path:
    """Return the directory part of ``save_path`` as a ``Path``."""
    return Path(save_path).parent


def get_base_name(save_path: Union[str, os.PathLike]) -> str:
    """Return the final path component of ``save_path``."""
    return Path(save_path).name


def load_state_dict_from_path(path: Union[str, os.PathLike]):
    """Load a state dict from ``path``.

    Paths containing the substring ``'safetensors'`` are read with
    ``safetensors.torch.load_file``; everything else goes through
    ``torch.load`` onto the CPU.
    """
    # os.fspath lets the substring test work for pathlib.Path as well as str.
    path = os.fspath(path)
    if 'safetensors' in path:
        # Import the submodule explicitly: a bare `import safetensors` does
        # not guarantee that `safetensors.torch` is loaded.
        import safetensors.torch
        state_dict = safetensors.torch.load_file(path)
    else:
        # NOTE(review): torch.load unpickles the file; only use it on
        # checkpoints from a trusted source.
        state_dict = torch.load(path, map_location="cpu")
    return state_dict


def replace_extension(path, new_extension):
    """Return ``path`` with its extension replaced by ``new_extension``.

    ``new_extension`` may be given with or without the leading dot.
    """
    if not new_extension.startswith('.'):
        new_extension = '.' + new_extension
    return os.path.splitext(path)[0] + new_extension


def make_config_path(save_path):
    """Return the ``.yaml`` path that sits next to ``save_path``."""
    return replace_extension(save_path, '.yaml')


def save_config(config, config_path):
    """Save ``config`` (a ``dict`` or ``DictConfig``) to ``config_path``.

    The parent directory is created if needed, and a plain dict is converted
    with ``OmegaConf.create`` before saving.
    """
    assert isinstance(config, (dict, DictConfig))
    os.makedirs(get_parent_directory(config_path), exist_ok=True)
    if isinstance(config, dict):
        config = OmegaConf.create(config)
    OmegaConf.save(config, config_path)


def save_state_dict_and_config(state_dict, config, save_path):
    """Save ``state_dict`` to ``save_path`` and ``config`` to a sibling ``.yaml``.

    The weights format follows the same substring rule as
    ``load_state_dict_from_path``.
    """
    save_path = os.fspath(save_path)
    os.makedirs(get_parent_directory(save_path), exist_ok=True)

    # save config dict
    save_config(config, make_config_path(save_path))

    # Save the model
    if 'safetensors' in save_path:
        import safetensors.torch
        safetensors.torch.save_file(state_dict, save_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, save_path)
# ---- aligners/differentiable_face_aligner/__init__.py ----

class DifferentiableFaceAligner(BaseAligner):
    """
    A differentiable face aligner that aligns the image with one face to a canonical position.
    The aligner is based on the following paper (check out supplementary material for more details):
    @inproceedings{kim2024kprpe,
       title={KeyPoint Relative Position Encoding for Face Recognition},
       author={Kim, Minchul and Su, Yiyang and Liu, Feng and Liu, Xiaoming},
       booktitle={CVPR},
       year={2024}
    }
    """

    def __init__(self, net, prior_box, preprocessor, config):
        super(DifferentiableFaceAligner, self).__init__()
        self.net = net
        self.prior_box = prior_box
        self.preprocessor = preprocessor
        self.config = config

    @classmethod
    def from_config(cls, config):
        """Build the landmark network, prior boxes and preprocessor from ``config``.

        Reads ``arch``, ``input_size``, ``input_padding_ratio``,
        ``input_padding_val`` and ``freeze``. The returned model is put in
        eval mode.
        """
        net, prior_box = get_landmark_predictor(network=config.arch,
                                                use_aggregator=True,
                                                input_size=config.input_size)

        preprocessor = get_preprocessor(output_size=config.input_size,
                                        padding=config.input_padding_ratio,
                                        padding_val=config.input_padding_val)
        if config.freeze:
            for param in net.parameters():
                param.requires_grad = False
        model = cls(net, prior_box, preprocessor, config)
        model.eval()
        return model

    def forward(self, x, padding_ratio_override=None):
        """Predict landmarks on ``x`` and warp it to the reference landmark layout.

        Args:
            x: 4-D tensor with 3 channels in dimension 1.
            padding_ratio_override: if not None, replaces the preprocessor's
                padding ratio for this call.

        Returns:
            Tuple ``(aligned_x, ldmks, aligned_ldmks, score, thetas, normalized_bbox)``.
            ``ldmks`` are the predicted landmarks reshaped to ``(-1, 5, 2)``;
            when padding is active they are mapped back to unpadded
            coordinates and detached. ``ldmks`` and ``normalized_bbox`` are
            ``None`` for non-square inputs.
        """
        # input size check (type first, so the shape checks are meaningful)
        assert isinstance(x, torch.Tensor)
        assert x.ndim == 4
        assert x.shape[1] == 3
        is_square = x.shape[2] == x.shape[3]

        x = self.preprocessor(x, padding_ratio_override=padding_ratio_override)
        assert self.prior_box.image_size == x.shape[2:]

        # make image into BGR
        x_bgr = x.flip(1)
        result = self.net(x_bgr, self.prior_box)
        orig_pred_ldmks, bbox, cls = aligner_helper.split_network_output(result)
        score = torch.nn.Softmax(dim=-1)(cls)[:, 1:]

        reference_ldmk = aligner_helper.reference_landmark()
        input_size = self.config.input_size
        output_size = self.config.output_size
        # The similarity transform is estimated in numpy on detached landmarks
        # (see get_cv2_affine_from_landmark), then converted to a grid_sample theta.
        cv2_tfms = aligner_helper.get_cv2_affine_from_landmark(orig_pred_ldmks, reference_ldmk, input_size, input_size)
        thetas = aligner_helper.cv2_param_to_torch_theta(cv2_tfms, input_size, input_size, output_size, output_size)
        thetas = thetas.to(orig_pred_ldmks.device)

        grid_shape = torch.Size((len(thetas), 3, output_size, output_size))
        grid = F.affine_grid(thetas, grid_shape, align_corners=True)
        # x is shifted by +1 before sampling and back by -1 afterwards
        # (original note: "for making padding pixel 0"); with grid_sample's
        # default zero padding, out-of-range samples come out as -1.
        aligned_x = F.grid_sample(x + 1, grid, align_corners=True) - 1
        aligned_ldmks = aligner_helper.adjust_ldmks(orig_pred_ldmks.view(-1, 5, 2), thetas)

        orig_pred_ldmks = orig_pred_ldmks.view(-1, 5, 2)
        # bbox (xmin, ymin, xmax, ymax)
        normalized_bbox = bbox / torch.tensor([[x_bgr.size(3), x_bgr.size(2)] * 2]).to(bbox.device)

        if padding_ratio_override is None:
            padding_ratio = self.preprocessor.padding
        else:
            padding_ratio = padding_ratio_override

        if padding_ratio > 0:
            # unpad the landmark so that it is in the original image coordinate
            scale = 1 / (1 + (2 * padding_ratio))
            pad_inv_theta = torch.from_numpy(np.array([[1 / scale, 0, 0], [0, 1 / scale, 0]]))
            pad_inv_theta = pad_inv_theta.unsqueeze(0).float().to(self.device).repeat(orig_pred_ldmks.size(0), 1, 1)
            homogeneous = torch.concat([orig_pred_ldmks.view(-1, 5, 2),
                                        torch.ones((orig_pred_ldmks.size(0), 5, 1)).to(self.device)], dim=-1)
            unpadded = ((homogeneous * 2 - 1) @ pad_inv_theta.mT) / 2 + 0.5
            ldmks_out = unpadded.view(orig_pred_ldmks.size(0), -1).detach().view(-1, 5, 2)
        else:
            ldmks_out = orig_pred_ldmks

        if not is_square:
            # cannot use these if the input is not square because the preprocessor changes the input
            ldmks_out = None
            normalized_bbox = None
        return aligned_x, ldmks_out, aligned_ldmks, score, thetas, normalized_bbox

    @staticmethod
    def _make_normalizing_transform():
        """ToTensor followed by Normalize with mean 0.5 and std 0.5 per channel."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def make_train_transform(self):
        """Return the training transform (identical to the test transform)."""
        return self._make_normalizing_transform()

    def make_test_transform(self):
        """Return the test transform (identical to the training transform)."""
        return self._make_normalizing_transform()


# ---- aligners/differentiable_face_aligner/aligner_helper.py ----

def split_network_output(align_out):
    """Split the 4th element of the 5-tuple ``align_out`` into (ldmk, bbox, cls).

    That element is split along dim 1 into widths 4 (bbox), 2 (cls) and
    10 (landmarks); the other four elements are ignored.
    """
    _anchor_bbox_pred, _anchor_cls_pred, _anchor_ldmk_pred, merged, _ = align_out
    bbox, cls, ldmk = torch.split(merged, [4, 2, 10], dim=1)
    return ldmk, bbox, cls


def get_cv2_affine_from_landmark(ldmks, reference_ldmk, image_width, image_height, ):
    """Estimate one 2x3 similarity transform per sample, landmarks -> reference.

    Args:
        ldmks: tensor of shape ``(N, 10)``; multiplied by the image size
            before estimation.
        reference_ldmk: ``(5, 2)`` numpy array of target landmark positions.
        image_width, image_height: factors applied to the x and y coordinates.

    Returns:
        Numpy array of shape ``(N, 2, 3)``. The computation leaves the autograd
        graph (``detach().cpu().numpy()``).
    """
    assert isinstance(ldmks, torch.Tensor)
    assert ldmks.ndim == 2  # batchdim
    assert ldmks.shape[1] == 10

    assert isinstance(reference_ldmk, np.ndarray)
    assert reference_ldmk.ndim == 2
    assert reference_ldmk.shape[0] == 5
    assert reference_ldmk.shape[1] == 2

    to_img_size = np.array([[[image_width, image_height]]])
    ldmks = ldmks.view(ldmks.shape[0], 5, 2).detach().cpu().numpy()
    ldmks = ldmks * to_img_size
    transforms = []
    for ldmk in ldmks:
        tform = trans.SimilarityTransform()
        tform.estimate(ldmk, reference_ldmk)
        transforms.append(tform.params[0:2, :])
    return np.stack(transforms, axis=0)


def cv2_param_to_torch_theta(cv2_tfms, image_width, image_height, output_width, output_height):
    # https://github.com/wuneng/WarpAffine2GridSample
    """4.Affine Transformation Matrix to theta

    Converts ``(N, 2, 3)`` affine matrices into ``(N, 2, 3)`` float32 thetas
    for ``F.affine_grid``. Three probe points are pushed through each matrix,
    both point sets are rescaled to [-1, 1], and the affine mapping from the
    transformed points back to the probes is fitted.
    """
    assert cv2_tfms.ndim == 3  # N, 2, 3
    assert cv2_tfms.shape[1] == 2
    assert cv2_tfms.shape[2] == 3

    srcs = np.array([[0, 0], [0, 1], [1, 1]], dtype=np.float32)
    srcs = np.expand_dims(srcs, axis=0).repeat(cv2_tfms.shape[0], axis=0)
    dsts = np.matmul(srcs, cv2_tfms[:, :, :2].transpose(0, 2, 1)) + cv2_tfms[:, :, 2:3].transpose(0, 2, 1)

    # normalize to [-1, 1]
    srcs = srcs / np.array([[[image_width, image_height]]]) * 2 - 1
    dsts = dsts / np.array([[[output_width, output_height]]]) * 2 - 1

    thetas = [
        trans.estimate_transform("affine", src=dst, dst=src).params[:2]
        for src, dst in zip(srcs, dsts)
    ]
    thetas = np.stack(thetas, axis=0)
    return torch.from_numpy(thetas).float()


def adjust_ldmks(ldmks, thetas):
    """Map ``(N, 5, 2)`` landmarks through the inverse of each theta.

    Coordinates are moved from [0, 1] to [-1, 1], transformed, then moved back.
    """
    inv_thetas = inv_matrix(thetas).to(ldmks.device).float()
    homogeneous = torch.cat([ldmks, torch.ones((ldmks.shape[0], 5, 1)).to(ldmks.device)], dim=2)
    return ((homogeneous * 2 - 1) @ inv_thetas.permute(0, 2, 1)) / 2 + 0.5


def inv_matrix(theta):
    """Invert a batch of 2x3 affine matrices, shape ``(N, 2, 3)`` in and out.

    Each ``[[a, b, t1], [c, d, t2]]`` is inverted in closed form. No guard is
    applied for a zero determinant.
    """
    # torch batched version
    assert theta.ndim == 3
    a, b, t1 = theta[:, 0, 0], theta[:, 0, 1], theta[:, 0, 2]
    c, d, t2 = theta[:, 1, 0], theta[:, 1, 1], theta[:, 1, 2]
    det = a * d - b * c
    inv_det = 1.0 / det
    inv_mat = torch.stack([
        torch.stack([d * inv_det, -b * inv_det, (b * t2 - d * t1) * inv_det], dim=1),
        torch.stack([-c * inv_det, a * inv_det, (c * t1 - a * t2) * inv_det], dim=1)
    ], dim=1)
    return inv_mat


def reference_landmark():
    """Return the five reference landmark positions as a ``(5, 2)`` array."""
    return np.array([[38.29459953, 51.69630051],
                     [73.53179932, 51.50139999],
                     [56.02519989, 71.73660278],
                     [41.54930115, 92.3655014],
                     [70.72990036, 92.20410156]])


def draw_ldmk(img, ldmk):
    """Return a copy of ``img`` with five landmarks drawn as coloured dots.

    ``ldmk`` holds x/y pairs as fractions of the image width/height. It may be
    a flat sequence of 10 values (the original contract) or an array/tensor of
    shape ``(5, 2)``. If ``ldmk`` is None, ``img`` itself is returned.
    """
    if ldmk is None:
        return img
    colors = [(0, 255, 0), (255, 0, 0), (0, 0, 255), (255, 255, 0), (0, 255, 255), (255, 0, 255)]
    # Flatten array-likes so (5, 2) input is indexed the same way as flat input.
    flat = ldmk.reshape(-1) if hasattr(ldmk, 'reshape') else ldmk
    img = img.copy()
    for i in range(5):
        center = (int(flat[i * 2] * img.shape[1]), int(flat[i * 2 + 1] * img.shape[0]))
        cv2.circle(img, center, 1, colors[i], 4)
    return img
# ---- aligners/differentiable_face_aligner/dfa/__init__.py ----

def get_landmark_predictor(network='mobile0.25', use_aggregator=True, input_size=160):
    """Create the RetinaFace landmark network and its matching ``PriorBox``.

    Args:
        network: ``'mobile0.25'`` or ``'resnet50'``.
        use_aggregator: forwarded to ``RetinaFace``.
        input_size: side length of the square input the priors are built for.

    Returns:
        Tuple ``(net, priorbox)``.

    Raises:
        ValueError: if ``network`` is not a supported name. Previously
            ``cfg=None`` was passed on to ``RetinaFace`` in that case.
    """
    if network == "mobile0.25":
        cfg = cfg_mnet
    elif network == "resnet50":
        cfg = cfg_re50
    else:
        raise ValueError(f"Unknown network: {network}")
    net = RetinaFace(cfg=cfg, phase='test', use_aggregator=use_aggregator)
    priorbox = PriorBox(image_size=(input_size, input_size),
                        min_sizes=[[64, 80], [96, 112], [128, 144]],
                        steps=[8, 16, 32],
                        clip=False,
                        variances=[0.1, 0.2],)
    return net, priorbox


def get_preprocessor(output_size=160, padding=0.0, padding_val='zero'):
    """Create a ``Preprocessor`` with the given output size and padding."""
    return Preprocessor(output_size=output_size, padding=padding, padding_val=padding_val)


# ---- aligners/differentiable_face_aligner/dfa/config.py ----

# Backbone settings selected by get_landmark_predictor.
cfg_mnet = {
    'name': 'mobilenet0.25',
    'pretrain': True,
    'return_layers': {'stage1': 1, 'stage2': 2, 'stage3': 3},
    'in_channel': 32,
    'out_channel': 64
}

cfg_re50 = {
    'name': 'Resnet50',
    'pretrain': True,
    'return_layers': {'layer2': 1, 'layer3': 2, 'layer4': 3},
    'in_channel': 256,
    'out_channel': 256
}


# ---- aligners/differentiable_face_aligner/dfa/layers/functions/prior_box.py ----

class PriorBox(object):
    """Anchor ("prior") boxes for a fixed image size, with encode/decode helpers.

    ``self.priors`` has shape ``(num_priors, 4)`` with rows
    ``(cx, cy, w, h)``, each divided by the image width or height. Every
    encode/decode method first moves ``self.priors`` to the device of its
    argument and keeps it there.
    """

    def __init__(self,
                 image_size,
                 min_sizes=None,
                 steps=None,
                 clip=False,
                 variances=None,
                 ):
        """
        Args:
            image_size: ``(height, width)`` of the network input.
            min_sizes: anchor sizes in pixels per feature level; defaults to
                ``[[64, 80], [96, 112], [128, 144]]``.
            steps: stride of each feature level; defaults to ``[8, 16, 32]``.
            clip: clamp the priors to [0, 1] when True.
            variances: the two scaling factors used by encode/decode; defaults
                to ``[0.1, 0.2]``.
        """
        super(PriorBox, self).__init__()
        # Fresh lists per instance: the defaults used to be shared mutable
        # default arguments.
        self.min_sizes = [[64, 80], [96, 112], [128, 144]] if min_sizes is None else min_sizes
        self.steps = [8, 16, 32] if steps is None else steps
        self.clip = clip
        self.variances = [0.1, 0.2] if variances is None else variances
        self.image_size = image_size
        self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
        with torch.no_grad():
            self.priors = self.forward()

    def forward(self):
        """Generate the prior boxes as a ``(num_priors, 4)`` tensor.

        For each feature-map cell (row-major) and each of the level's
        ``min_sizes``, one box centred on the cell is emitted.
        """
        img_h, img_w = self.image_size[0], self.image_size[1]
        anchors = []
        for k, f in enumerate(self.feature_maps):
            step = self.steps[k]
            for i, j in product(range(f[0]), range(f[1])):
                cx = (j + 0.5) * step / img_w
                cy = (i + 0.5) * step / img_h
                for min_size in self.min_sizes[k]:
                    anchors += [cx, cy, min_size / img_w, min_size / img_h]

        # back to torch land
        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output

    def _priors_on(self, device):
        """Move the stored priors to ``device`` (kept there) and return them."""
        self.priors = self.priors.to(device)
        return self.priors

    def encode(self, matched):
        """Encode the variances from the priorbox layers into the ground truth boxes
        we have matched (based on jaccard overlap) with the prior boxes.

        ``matched`` is ``(num_priors, 4)`` in corner form (xmin, ymin, xmax, ymax).
        """
        priors = self._priors_on(matched.device)

        # dist b/t match center and prior's center
        g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - priors[:, :2]
        # encode variance
        g_cxcy /= (self.variances[0] * priors[:, 2:])
        # match wh / prior wh
        g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
        g_wh = torch.log(g_wh) / self.variances[1]
        # return target for smooth_l1_loss
        return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]

    def encode_landm(self, matched):
        """Encode matched landmarks as offsets from the prior centres.

        ``matched`` is ``(num_priors, 10)`` (five x/y pairs); the result has
        the same shape, with offsets divided by ``variances[0] * prior size``.
        """
        base = self._priors_on(matched.device)

        # dist b/t match center and prior's center
        matched = torch.reshape(matched, (matched.size(0), 5, 2))
        priors = base.unsqueeze(1).expand(matched.size(0), 5, 4)
        g_cxcy = matched[:, :, :2] - priors[:, :, :2]
        # encode variance
        g_cxcy /= (self.variances[0] * priors[:, :, 2:])
        g_cxcy = g_cxcy.reshape(g_cxcy.size(0), -1)
        # return target for smooth_l1_loss
        return g_cxcy

    # Adapted from https://github.com/Hakuyume/chainer-ssd
    def decode(self, loc):
        """Decode locations from predictions using priors to undo
        the encoding we did for offset regression at train time.

        ``loc`` is ``(num_priors, 4)``; the result is in corner form.
        """
        priors = self._priors_on(loc.device)

        boxes = torch.cat((
            priors[:, :2] + loc[:, :2] * self.variances[0] * priors[:, 2:],
            priors[:, 2:] * torch.exp(loc[:, 2:] * self.variances[1])), 1)
        # (cx, cy, w, h) -> (xmin, ymin, xmax, ymax)
        boxes[:, :2] -= boxes[:, 2:] / 2
        boxes[:, 2:] += boxes[:, :2]
        return boxes

    def decode_landm(self, pre):
        """Decode landm from predictions using priors to undo
        the encoding we did for offset regression at train time.

        ``pre`` is ``(num_priors, 10)``.
        """
        priors = self._priors_on(pre.device)
        centers, sizes = priors[:, :2], priors[:, 2:]
        points = [centers + pre[:, 2 * i:2 * i + 2] * self.variances[0] * sizes for i in range(5)]
        return torch.cat(points, dim=1)

    def decode_batch(self, loc):
        """Batched ``decode``: ``loc`` is ``(batch, num_priors, 4)``."""
        base = self._priors_on(loc.device)
        assert loc.ndim == 3
        priors = base.unsqueeze(0).expand(loc.size(0), -1, -1)
        boxes = torch.cat((
            priors[:, :, :2] + loc[:, :, :2] * self.variances[0] * priors[:, :, 2:],
            priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * self.variances[1])), -1)
        boxes[:, :, :2] -= boxes[:, :, 2:] / 2
        boxes[:, :, 2:] += boxes[:, :, :2]
        return boxes

    def decode_landm_batch(self, prediction):
        """Batched ``decode_landm``: ``prediction`` is ``(batch, num_priors, 10)``."""
        assert prediction.ndim == 3
        base = self._priors_on(prediction.device)
        priors = base.unsqueeze(0).expand(prediction.size(0), -1, -1)
        centers, sizes = priors[:, :, :2], priors[:, :, 2:]
        points = [centers + prediction[:, :, 2 * i:2 * i + 2] * self.variances[0] * sizes for i in range(5)]
        return torch.cat(points, dim=-1)


# ---- aligners/differentiable_face_aligner/dfa/layers/modules/multibox_loss.py ----

class MultiBoxLoss(nn.Module):
    """SSD Weighted Loss Function
    Compute Targets:
        1) Produce Confidence Target Indices by matching ground truth boxes
           with (default) 'priorboxes' that have jaccard index > threshold parameter
           (default threshold: 0.5).
        2) Produce localization target by 'encoding' variance into offsets of ground
           truth boxes and their matched 'priorboxes'.
        3) Hard negative mining to filter the excessive number of negative examples
           that comes with using a large number of default bounding boxes.
           (default negative:positive ratio 3:1)
    Objective Loss:
        $L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N$
        Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss
        weighted by α which is set to 1 by cross val.
        Args:
            c: class confidences,
            l: predicted boxes,
            g: ground truth boxes
            N: number of matched default boxes
        See: https://arxiv.org/pdf/1512.02325.pdf for more details.
    """

    def __init__(self, num_classes, overlap_thresh, prior_for_matching, bkg_label, neg_mining, neg_pos, neg_overlap, encode_target):
        super(MultiBoxLoss, self).__init__()
        self.num_classes = num_classes
        self.threshold = overlap_thresh
        self.background_label = bkg_label
        self.encode_target = encode_target
        self.use_prior_for_matching = prior_for_matching
        self.do_neg_mining = neg_mining
        self.negpos_ratio = neg_pos
        self.neg_overlap = neg_overlap

    @staticmethod
    def _aggregated_losses(aggs, targets):
        """Losses on the aggregated per-image prediction ``aggs``.

        ``aggs`` columns are read as 0:4 box, 4:6 class logits, 6: landmarks.
        Each stacked target row is read as :4 box, 4:14 landmarks, -1 label.
        Each sum is divided by its sample count, floored at 1 so an empty
        selection gives 0 rather than NaN.
        """
        stacked_target = torch.stack(targets, dim=0).squeeze(1)
        labels = stacked_target[:, -1]

        with_ldmk = labels > 0
        tgt_ldmk = stacked_target[:, 4:14][with_ldmk]
        agg_loss_landm = F.smooth_l1_loss(aggs[:, 6:][with_ldmk], tgt_ldmk, reduction='sum') / max(len(tgt_ldmk), 1)

        with_box = labels != 0
        tgt_bbox = stacked_target[:, :4][with_box]
        agg_loss_box = F.smooth_l1_loss(aggs[:, :4][with_box], tgt_bbox, reduction='sum') / max(len(tgt_bbox), 1)

        tgt_cls = (labels > 0).long()
        agg_loss_cls = F.cross_entropy(aggs[:, 4:6], tgt_cls, reduction='sum') / max(len(tgt_cls), 1)
        return {
            'agg_loss_landm': agg_loss_landm,
            'agg_loss_box': agg_loss_box,
            'agg_loss_cls': agg_loss_cls
        }

    def forward(self, predictions, priorbox, targets):
        """Multibox Loss
        Args:
            predictions (tuple): ``(loc, conf, landm, aggs, thetas)``.
                conf shape: torch.size(batch_size,num_priors,num_classes)
                loc shape: torch.size(batch_size,num_priors,4)
                landm is indexed as (batch_size,num_priors,10); ``aggs`` may
                be None; ``thetas`` is not used here.
            priorbox: object exposing ``priors`` of shape (num_priors, 4);
                it is also handed to ``match``.
            targets: per-image ground truth tensors; columns :4 are the box,
                4:14 the landmarks and the last column the label.

        Returns:
            ``(loss_l, loss_c, loss_landm, aux_loss_dict)``;
            ``aux_loss_dict`` is None when ``aggs`` is None.
        """
        loc_data, conf_data, landm_data, aggs, thetas = predictions
        # Work on the predictions' device instead of a hard-coded .cuda().
        device = loc_data.device
        num = loc_data.size(0)
        num_priors = priorbox.priors.size(0)

        aux_loss_dict = self._aggregated_losses(aggs, targets) if aggs is not None else None

        # match priors (default boxes) and ground truth boxes;
        # the buffers are filled in place by `match`
        loc_t = torch.empty(num, num_priors, 4)
        landm_t = torch.empty(num, num_priors, 10)
        conf_t = torch.empty(num, num_priors, dtype=torch.long)
        for idx in range(num):
            truths = targets[idx][:, :4].data
            labels = targets[idx][:, -1].data
            landms = targets[idx][:, 4:14].data
            match(self.threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx)

        loc_t = loc_t.to(device)
        conf_t = conf_t.to(device)
        landm_t = landm_t.to(device)
        zeros = torch.tensor(0).to(device)

        # landm Loss (Smooth L1)
        # Shape: [batch,num_priors,10]
        pos1 = conf_t > zeros
        num_pos_landm = pos1.long().sum(1, keepdim=True)
        N1 = max(num_pos_landm.data.sum().float(), 1)
        pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
        landm_p = landm_data[pos_idx1].view(-1, 10)
        landm_t = landm_t[pos_idx1].view(-1, 10)
        loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

        # every non-zero label counts as a positive for box and class losses
        pos = conf_t != zeros
        conf_t[pos] = 1

        # Localization Loss (Smooth L1)
        # Shape: [batch,num_priors,4]
        pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
        loc_p = loc_data[pos_idx].view(-1, 4)
        loc_t = loc_t[pos_idx].view(-1, 4)
        loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

        # Compute max conf across batch for hard negative mining
        batch_conf = conf_data.view(-1, self.num_classes)
        loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

        # Hard Negative Mining
        loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
        loss_c = loss_c.view(num, -1)
        _, loss_idx = loss_c.sort(1, descending=True)
        _, idx_rank = loss_idx.sort(1)
        num_pos = pos.long().sum(1, keepdim=True)
        num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
        neg = idx_rank < num_neg.expand_as(idx_rank)

        # Confidence Loss Including Positive and Negative Examples
        pos_idx = pos.unsqueeze(2).expand_as(conf_data)
        neg_idx = neg.unsqueeze(2).expand_as(conf_data)
        conf_p = conf_data[pos_idx | neg_idx].view(-1, self.num_classes)
        targets_weighted = conf_t[pos | neg]
        loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

        # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
        N = max(num_pos.data.sum().float(), 1)
        loss_l /= N
        loss_c /= N
        loss_landm /= N1

        return loss_l, loss_c, loss_landm, aux_loss_dict
def conv_bn(inp, oup, stride=1, leaky=0):
    """Build a 3x3 conv (padding 1, no bias) -> BatchNorm -> LeakyReLU stack."""
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_bn_no_relu(inp, oup, stride):
    """Build a 3x3 conv (padding 1, no bias) -> BatchNorm stack without activation."""
    layers = [
        nn.Conv2d(inp, oup, kernel_size=3, stride=stride, padding=1, bias=False),
        nn.BatchNorm2d(oup),
    ]
    return nn.Sequential(*layers)


def conv_bn1X1(inp, oup, stride, leaky=0):
    """Build a 1x1 conv (no padding, no bias) -> BatchNorm -> LeakyReLU stack."""
    layers = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=stride, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*layers)


def conv_dw(inp, oup, stride, leaky=0.1):
    """Build a depthwise 3x3 conv followed by a pointwise 1x1 conv.

    Each convolution is followed by BatchNorm and LeakyReLU; only the
    depthwise convolution applies ``stride``.
    """
    depthwise = [
        nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)


class SSH(nn.Module):
    """Three parallel conv branches whose outputs are concatenated and ReLU-ed.

    The branches produce ``out_channel // 2``, ``out_channel // 4`` and
    ``out_channel // 4`` channels, so the output has ``out_channel`` channels.
    The 5x5 and 7x7 branches share their first convolution.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        assert out_channel % 4 == 0
        # Narrow heads use a slightly leaky activation, wide ones a plain ReLU.
        leaky = 0.1 if out_channel <= 64 else 0
        half = out_channel // 2
        quarter = out_channel // 4

        # Attribute names are part of the checkpoint format; do not rename.
        self.conv3X3 = conv_bn_no_relu(in_channel, half, stride=1)

        self.conv5X5_1 = conv_bn(in_channel, quarter, stride=1, leaky=leaky)
        self.conv5X5_2 = conv_bn_no_relu(quarter, quarter, stride=1)

        self.conv7X7_2 = conv_bn(quarter, quarter, stride=1, leaky=leaky)
        self.conv7x7_3 = conv_bn_no_relu(quarter, quarter, stride=1)

    def forward(self, input):
        branch3 = self.conv3X3(input)

        shared = self.conv5X5_1(input)
        branch5 = self.conv5X5_2(shared)
        branch7 = self.conv7x7_3(self.conv7X7_2(shared))

        return F.relu(torch.cat([branch3, branch5, branch7], dim=1))


class FPN(nn.Module):
    """Top-down feature pyramid over three feature maps.

    ``forward`` takes a mapping whose first three values are the feature
    maps (ordered from highest to lowest resolution) and returns a list of
    three maps with ``out_channels`` channels each.
    """

    def __init__(self, in_channels_list, out_channels):
        super().__init__()
        leaky = 0.1 if out_channels <= 64 else 0
        c1, c2, c3 = in_channels_list[0], in_channels_list[1], in_channels_list[2]

        self.output1 = conv_bn1X1(c1, out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(c2, out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(c3, out_channels, stride=1, leaky=leaky)

        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        feats = list(input.values())

        lat1 = self.output1(feats[0])
        lat2 = self.output2(feats[1])
        lat3 = self.output3(feats[2])

        # Upsample the coarser level to the finer one, add, then smooth.
        up3 = F.interpolate(lat3, size=[lat2.size(2), lat2.size(3)], mode="nearest")
        lat2 = self.merge2(lat2 + up3)

        up2 = F.interpolate(lat2, size=[lat1.size(2), lat1.size(3)], mode="nearest")
        lat1 = self.merge1(lat1 + up2)

        return [lat1, lat2, lat3]


class MobileNetV1(nn.Module):
    """Small MobileNetV1-style backbone split into three sequential stages.

    The stages end with 64, 128 and 256 channels respectively; the overall
    stride is 32 (five stride-2 layers).
    """

    # (in_channels, out_channels, stride) for every depthwise-separable layer.
    _STAGE1 = ((8, 16, 1), (16, 32, 2), (32, 32, 1), (32, 64, 2), (64, 64, 1))
    _STAGE2 = ((64, 128, 2),) + ((128, 128, 1),) * 5
    _STAGE3 = ((128, 256, 2), (256, 256, 1))

    def __init__(self):
        super().__init__()
        self.stage1 = nn.Sequential(
            conv_bn(3, 8, 2, leaky=0.1),
            *[conv_dw(i, o, s) for i, o, s in self._STAGE1],
        )
        self.stage2 = nn.Sequential(*[conv_dw(i, o, s) for i, o, s in self._STAGE2])
        self.stage3 = nn.Sequential(*[conv_dw(i, o, s) for i, o, s in self._STAGE3])

    def forward(self, x):
        for stage in (self.stage1, self.stage2, self.stage3):
            x = stage(x)
        return x
class ClassHead(nn.Module):
    """1x1-conv head emitting 2 class scores per anchor.

    ``forward`` returns a tensor of shape (batch, H * W * num_anchors, 2).
    """

    def __init__(self, inchannels=512, num_anchors=3):
        super().__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(inchannels, self.num_anchors * 2,
                                 kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        # Move channels last so each anchor's values are contiguous.
        scores = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        return scores.view(scores.shape[0], -1, 2)


class BboxHead(nn.Module):
    """1x1-conv head emitting 4 box-regression values per anchor.

    ``forward`` returns a tensor of shape (batch, H * W * num_anchors, 4).
    """

    def __init__(self, inchannels=512, num_anchors=3):
        super().__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 4,
                                 kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        offsets = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        return offsets.view(offsets.shape[0], -1, 4)


class LandmarkHead(nn.Module):
    """1x1-conv head emitting 10 landmark-regression values per anchor.

    ``forward`` returns a tensor of shape (batch, H * W * num_anchors, 10).
    """

    def __init__(self, inchannels=512, num_anchors=3):
        super().__init__()
        self.conv1x1 = nn.Conv2d(inchannels, num_anchors * 10,
                                 kernel_size=(1, 1), stride=1, padding=0)

    def forward(self, x):
        points = self.conv1x1(x).permute(0, 2, 3, 1).contiguous()
        return points.view(points.shape[0], -1, 10)
+ """ + super(RetinaFace,self).__init__() + self.phase = phase + backbone = None + if cfg['name'] == 'mobilenet0.25': + backbone = MobileNetV1() + # if cfg['pretrain']: + # checkpoint = torch.load("./weights/mobilenetV1X0.25_pretrain.tar", map_location=torch.device('cpu')) + # from collections import OrderedDict + # new_state_dict = OrderedDict() + # for k, v in checkpoint['state_dict'].items(): + # name = k[7:] # remove module. + # new_state_dict[name] = v + # load params + # backbone.load_state_dict(new_state_dict) + elif cfg['name'] == 'Resnet50': + import torchvision.models as models + backbone = models.resnet50(pretrained=cfg['pretrain']) + + self.body = _utils.IntermediateLayerGetter(backbone, cfg['return_layers']) + in_channels_stage2 = cfg['in_channel'] + in_channels_list = [ + in_channels_stage2 * 2, + in_channels_stage2 * 4, + in_channels_stage2 * 8, + ] + out_channels = cfg['out_channel'] + self.fpn = FPN(in_channels_list,out_channels) + self.ssh1 = SSH(out_channels, out_channels) + self.ssh2 = SSH(out_channels, out_channels) + self.ssh3 = SSH(out_channels, out_channels) + + self.ClassHead = self._make_class_head(fpn_num=3, inchannels=cfg['out_channel']) + self.BboxHead = self._make_bbox_head(fpn_num=3, inchannels=cfg['out_channel']) + self.LandmarkHead = self._make_landmark_head(fpn_num=3, inchannels=cfg['out_channel']) + + self.use_aggregator = use_aggregator + if self.use_aggregator: + modules = [mlp_mixer.MixerBlock(16, 1050) for _ in range(3)] + modules.append(nn.Linear(16, 1)) + self.aggregator = nn.Sequential(*modules) + else: + self.aggregator = None + + def _make_class_head(self,fpn_num=3,inchannels=64,anchor_num=2): + classhead = nn.ModuleList() + for i in range(fpn_num): + classhead.append(ClassHead(inchannels,anchor_num)) + return classhead + + def _make_bbox_head(self,fpn_num=3,inchannels=64,anchor_num=2): + bboxhead = nn.ModuleList() + for i in range(fpn_num): + bboxhead.append(BboxHead(inchannels,anchor_num)) + return bboxhead + + def 
class Preprocessor():
    """Pad images to a square, optionally add a border, and resize them.

    Accepts float32 or uint8 image tensors shaped (3, H, W) or (N, 3, H, W)
    and returns a tensor of the same dtype with spatial size
    ``output_size x output_size``.

    ``padding_val`` selects the fill value for the padded area:
      * ``'zero'``: -1.0 for float32 input, 0 for uint8 input
        (presumably float images are normalized to [-1, 1]; confirm with callers).
      * ``'mean'``: the mean of the whole input batch.
    """

    def __init__(self, output_size=160, padding=0.0, padding_val='zero'):
        self.output_size = output_size
        self.padding = padding
        self.padding_val = padding_val

    def _fill_value(self, imgs):
        """Return the padding fill value as a plain Python number.

        Raises:
            ValueError: if the dtype is not float32/uint8 or ``padding_val``
                is neither ``'zero'`` nor ``'mean'``.
        """
        is_float = imgs.dtype == torch.float32
        if not is_float and imgs.dtype != torch.uint8:
            raise ValueError('imgs.dtype must be torch.float32 or torch.uint8')
        if self.padding_val == 'zero':
            return -1.0 if is_float else 0
        if self.padding_val == 'mean':
            if is_float:
                # .item() gives F.pad a Python float instead of a 0-dim tensor.
                return imgs.mean().item()
            # torch.mean is undefined for integer dtypes, so average in float
            # and round back to a value representable in uint8.
            if imgs.numel() == 0:
                return 0
            return int(round(imgs.float().mean().item()))
        raise ValueError('padding_val must be "zero" or "mean"')

    def preprocess_batched(self, imgs, padding_ratio_override=None):
        """Square-pad, border-pad and bilinearly resize a (N, C, H, W) batch.

        Args:
            imgs: float32 or uint8 tensor with 4 dimensions.
            padding_ratio_override: if not None, used instead of ``self.padding``
                as the border ratio.

        Returns:
            Tensor of shape (N, C, output_size, output_size) with the input dtype.

        Raises:
            ValueError: on unsupported dtype or ``padding_val`` setting.
        """
        padding_val = self._fill_value(imgs)

        square_imgs = self.make_square_img_batched(imgs, padding_val=padding_val)

        padding = self.padding if padding_ratio_override is None else padding_ratio_override
        padded_imgs = self.make_padded_img_batched(square_imgs, padding=padding, padding_val=padding_val)

        size = (self.output_size, self.output_size)
        if imgs.dtype == torch.float32:
            return F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True)

        # uint8: interpolate in float, then clamp and cast back.
        resized_imgs = F.interpolate(padded_imgs.to(torch.float32), size=size,
                                     mode='bilinear', align_corners=True)
        resized_imgs = torch.clip(resized_imgs, 0, 255)
        return resized_imgs.to(torch.uint8)

    def make_square_img_batched(self, imgs, padding_val):
        """Pad the shorter spatial side symmetrically so that H == W."""
        assert imgs.ndim == 4
        h, w = imgs.shape[2:]
        if h > w:
            diff = h - w
            pad_left = diff // 2
            pad_right = diff - pad_left
            imgs = F.pad(imgs, (pad_left, pad_right, 0, 0), value=padding_val)
        elif w > h:
            diff = w - h
            pad_top = diff // 2
            pad_bottom = diff - pad_top
            imgs = F.pad(imgs, (0, 0, pad_top, pad_bottom), value=padding_val)
        assert imgs.shape[2] == imgs.shape[3]
        return imgs

    def make_padded_img_batched(self, imgs, padding, padding_val):
        """Add a border of ``int(side * padding)`` pixels on every side."""
        if padding == 0:
            return imgs
        assert imgs.ndim == 4

        h, w = imgs.shape[2:]
        pad_h = int(h * padding)
        pad_w = int(w * padding)
        return F.pad(imgs, (pad_w, pad_w, pad_h, pad_h), value=padding_val)

    def __call__(self, input, padding_ratio_override=None):
        """Preprocess a single (3, H, W) image or a (N, 3, H, W) batch.

        Raises:
            ValueError: if ``input`` has neither 3 nor 4 dimensions.
        """
        if input.ndim == 3:
            assert input.shape[0] == 3
            batch_input = input.unsqueeze(0)
            return self.preprocess_batched(batch_input, padding_ratio_override=padding_ratio_override)[0]
        elif input.ndim == 4:
            assert input.shape[1] == 3
            return self.preprocess_batched(input, padding_ratio_override=padding_ratio_override)
        else:
            raise ValueError(f'Invalid input shape: {input.shape}')


def point_form(boxes):
    """Convert (cx, cy, w, h) boxes to (xmin, ymin, xmax, ymax).

    Args:
        boxes: (tensor) center-size boxes, one per row.
    Return:
        boxes: (tensor) the same boxes in corner form.
    """
    return torch.cat((boxes[:, :2] - boxes[:, 2:]/2,     # xmin, ymin
                      boxes[:, :2] + boxes[:, 2:]/2), 1)  # xmax, ymax


def center_size(boxes):
    """Convert (xmin, ymin, xmax, ymax) boxes to (cx, cy, w, h).

    Args:
        boxes: (tensor) corner-form boxes, one per row.
    Return:
        boxes: (tensor) the same boxes in center-size form.
    """
    # Both halves must be passed to torch.cat as ONE tuple, followed by the dim.
    return torch.cat(((boxes[:, 2:] + boxes[:, :2]) / 2,  # cx, cy
                      boxes[:, 2:] - boxes[:, :2]), 1)    # w, h


def intersect(box_a, box_b):
    """Compute the pairwise intersection area of two sets of corner-form boxes.

    Both inputs are broadcast to [A, B, 2] via expand (no copy):
      [A,2] -> [A,1,2] -> [A,B,2]
      [B,2] -> [1,B,2] -> [A,B,2]
    Args:
      box_a: (tensor) bounding boxes, Shape: [A,4].
      box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    A = box_a.size(0)
    B = box_b.size(0)
    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
                       box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
                       box_b[:, :2].unsqueeze(0).expand(A, B, 2))
    # Non-overlapping pairs give negative extents; clamp them to zero.
    inter = torch.clamp((max_xy - min_xy), min=0)
    return inter[:, :, 0] * inter[:, :, 1]


def jaccard(box_a, box_b):
    """Compute the pairwise intersection-over-union of two sets of boxes.

    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (tensor) corner-form boxes, Shape: [num_objects,4]
        box_b: (tensor) corner-form boxes, Shape: [num_priors,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    inter = intersect(box_a, box_b)
    area_a = ((box_a[:, 2]-box_a[:, 0]) *
              (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2]-box_b[:, 0]) *
              (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    union = area_a + area_b - inter
    return inter / union  # [A,B]


def matrix_iou(a, b):
    """Return the pairwise IoU of corner-form boxes ``a`` and ``b`` (NumPy version)."""
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    # The boolean mask zeroes the area of pairs that do not overlap.
    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
    return area_i / (area_a[:, np.newaxis] + area_b - area_i)


def matrix_iof(a, b):
    """Return the pairwise intersection over the area of ``a`` (NumPy version).

    The denominator is floored at 1 to avoid dividing by tiny areas.
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    return area_i / np.maximum(area_a[:, np.newaxis], 1)
+ labels: (tensor) All the class labels for the image, Shape: [num_obj]. + landms: (tensor) Ground truth landms, Shape [num_obj, 10]. + loc_t: (tensor) Tensor to be filled w/ endcoded location targets. + conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds. + landm_t: (tensor) Tensor to be filled w/ endcoded landm targets. + idx: (int) current batch index + Return: + The matched indices corresponding to 1)location 2)confidence 3)landm preds. + """ + + + # jaccard index + overlaps = jaccard( + truths, + point_form(priorbox.priors) + ) + # (Bipartite Matching) + # [1,num_objects] best prior for each ground truth + best_prior_overlap, best_prior_idx = overlaps.max(1, keepdim=True) + + # ignore hard gt + valid_gt_idx = best_prior_overlap[:, 0] >= 0.2 + best_prior_idx_filter = best_prior_idx[valid_gt_idx, :] + if best_prior_idx_filter.shape[0] <= 0: + loc_t[idx] = 0 + conf_t[idx] = 0 + return + + # [1,num_priors] best ground truth for each prior + best_truth_overlap, best_truth_idx = overlaps.max(0, keepdim=True) + best_truth_idx.squeeze_(0) + best_truth_overlap.squeeze_(0) + best_prior_idx.squeeze_(1) + best_prior_idx_filter.squeeze_(1) + best_prior_overlap.squeeze_(1) + best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2) # ensure best prior + # TODO refactor: index best_prior_idx with long tensor + # ensure every gt matches with its prior of max overlap + for j in range(best_prior_idx.size(0)): # 判别此anchor是预测哪一个boxes + best_truth_idx[best_prior_idx[j]] = j + matches = truths[best_truth_idx] # Shape: [num_priors,4] 此处为每一个anchor对应的bbox取出来 + conf = labels[best_truth_idx] # Shape: [num_priors] 此处为每一个anchor对应的label取出来 + conf[best_truth_overlap < threshold] = 0 # label as background overlap<0.35的全部作为负样本 + loc = priorbox.encode(matches) + + matches_landm = landms[best_truth_idx] + landm = priorbox.encode_landm(matches_landm) + loc_t[idx] = loc # [num_priors,4] encoded offsets to learn + conf_t[idx] = conf # [num_priors] top class label for each 
def log_sum_exp(x):
    """Compute a numerically stabilized log-sum-exp over dimension 1.

    The global maximum of ``x`` is subtracted before exponentiation and
    added back afterwards.

    Args:
        x (tensor): 2-D scores, one row per example.
    Return:
        (tensor) Shape: [x.size(0), 1].
    """
    x_max = x.data.max()
    return torch.log(torch.sum(torch.exp(x-x_max), 1, keepdim=True)) + x_max


# Original author: Francisco Massa:
# https://github.com/fmassa/object-detection.torch
# Ported to PyTorch by Max deGroot (02/01/2017)
def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply greedy non-maximum suppression.

    Args:
        boxes: (tensor) corner-form boxes, Shape: [num_priors,4].
        scores: (tensor) one score per box, Shape: [num_priors].
        overlap: (float) boxes whose IoU with a kept box exceeds this are dropped.
        top_k: (int) only the ``top_k`` highest-scoring boxes are considered.
    Return:
        (keep, count): ``keep`` is a long tensor of length ``scores.size(0)``
        whose first ``count`` entries are the kept indices, highest score
        first; the remaining entries are zero.
    """
    # Allocated on the CPU, exactly as the previous implementation did.
    keep = torch.zeros(scores.size(0), dtype=torch.long)
    if boxes.numel() == 0:
        # Always return the same (keep, count) pair so callers can unpack it.
        return keep, 0

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    _, idx = scores.sort(0)  # ascending order
    idx = idx[-top_k:]       # indices of the top-k largest scores

    count = 0
    while idx.numel() > 0:
        i = idx[-1]  # index of the current highest score
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # drop the kept element from the candidates

        # Corners of the intersection between box i and every remaining box.
        xx1 = torch.clamp(x1[idx], min=x1[i])
        yy1 = torch.clamp(y1[idx], min=y1[i])
        xx2 = torch.clamp(x2[idx], max=x2[i])
        yy2 = torch.clamp(y2[idx], max=y2[i])
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h

        # IoU = i / (area(a) + area(b) - i)
        union = (area[idx] - inter) + area[i]
        IoU = inter / union
        # keep only candidates with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
def check_keys(model, pretrained_state_dict):
    """Report how a checkpoint's keys line up with the model's state dict.

    Prints the number of missing, unused and used keys, and asserts that at
    least one checkpoint key matches the model. Returns True on success.
    """
    checkpoint_names = set(pretrained_state_dict.keys())
    model_names = set(model.state_dict().keys())

    shared = model_names & checkpoint_names
    only_in_checkpoint = checkpoint_names - model_names
    only_in_model = model_names - checkpoint_names

    print('Missing keys:{}'.format(len(only_in_model)))
    print('Unused checkpoint keys:{}'.format(len(only_in_checkpoint)))
    print('Used keys:{}'.format(len(shared)))
    assert len(shared) > 0, 'load NONE from pretrained checkpoint'
    return True


def load_model(model, pretrained_path, load_to_cpu):
    """Load a checkpoint into ``model`` (non-strict) and return the model.

    The checkpoint is mapped to CPU when ``load_to_cpu`` is truthy, otherwise
    to the current CUDA device. If it contains a ``"state_dict"`` entry that
    entry is used; in both cases a leading ``'module.'`` is stripped from key
    names via ``remove_prefix`` before loading.
    """
    print('Loading pretrained model from {}'.format(pretrained_path))
    if load_to_cpu:
        def map_location(storage, loc):
            return storage
    else:
        device = torch.cuda.current_device()

        def map_location(storage, loc):
            return storage.cuda(device)
    pretrained_dict = torch.load(pretrained_path, map_location=map_location)

    if "state_dict" in pretrained_dict.keys():
        weights = pretrained_dict['state_dict']
    else:
        weights = pretrained_dict
    pretrained_dict = remove_prefix(weights, 'module.')

    check_keys(model, pretrained_dict)
    model.load_state_dict(pretrained_dict, strict=False)
    return model
cls(aligner_config) + + def make_train_transform(self): + return lambda x:x + + def make_test_transform(self): + return lambda x:x + + def forward(self, x): + return x diff --git a/cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/__init__.py b/cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b4f10b9ca2ed011810b4560d2b2a3f4d32fecc8 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/__init__.py @@ -0,0 +1,246 @@ +from ..base import BaseAligner +from torchvision import transforms +from .retinaface import get_landmark_predictor, get_preprocessor +from . import aligner_helper +import torch +import torch.nn.functional as F +import numpy as np + +class RetinaFaceAligner(BaseAligner): + + """ + A non-differentiable face aligner that aligns the image with one face to a canonical position. + The aligner is based on the following paper: + + ``` + @inproceedings{deng2020retinaface, + title={Retinaface: Single-shot multi-level face localisation in the wild}, + author={Deng, Jiankang and Guo, Jia and Ververas, Evangelos and Kotsia, Irene and Zafeiriou, Stefanos}, + booktitle={Proceedings of the IEEE/CVF conference on computer vision and pattern recognition}, + pages={5203--5212}, + year={2020} + } + ``` + """ + + def __init__(self, net, prior_box, preprocessor, config): + super(RetinaFaceAligner, self).__init__() + self.net = net + self.prior_box = prior_box + self.preprocessor = preprocessor + self.config = config + + @classmethod + def from_config(cls, config): + net, prior_box = get_landmark_predictor(network=config.arch, + input_size=config.input_size) + + preprocessor = get_preprocessor(output_size=config.input_size, + padding=config.input_padding_ratio, + padding_val=config.input_padding_val) + if config.freeze: + for param in net.parameters(): + param.requires_grad = False + model = cls(net, prior_box, 
preprocessor, config) + model.eval() + return model + + def forward(self, x, padding_ratio_override=None): + + # input size check + assert x.shape[1] == 3 + assert x.ndim == 4 + assert isinstance(x, torch.Tensor) + is_square = x.shape[2] == x.shape[3] + + x = self.preprocessor(x, padding_ratio_override=padding_ratio_override) + assert self.prior_box.image_size == x.shape[2:] + + # make image into BGR + x_bgr = x.flip(1) + input_img = normalize_for_net(unnormalize(x_bgr)) + + result = self.net(input_img, self.prior_box) + batch_loc, batch_conf, batch_landms = result + batch_loc = torch.split(batch_loc, 1, dim=0) + batch_conf = torch.split(batch_conf, 1, dim=0) + batch_landms = torch.split(batch_landms, 1, dim=0) + + nms_ldmks = [] + nms_scores = [] + nms_bbox = [] + for loc, conf, landms, in zip(batch_loc, batch_conf, batch_landms): + dets = postprocess(self.prior_box, loc, conf, landms, confidence_threshold=0.0, nms_threshold=0.4) + bbox, score, ldmks = parse_one_det_result(dets) + ldmks = ldmks / np.array( [self.prior_box.image_size[0], self.prior_box.image_size[1]] * 5) + nms_ldmks.append(ldmks) + nms_scores.append(score) + nms_bbox.append(bbox) + + orig_pred_ldmks = torch.from_numpy(np.array(nms_ldmks)).to(self.device).float() + score = torch.from_numpy(np.array(nms_scores)).to(self.device).float().unsqueeze(-1) + bbox = torch.from_numpy(np.array(nms_bbox)).to(self.device).float() + + + reference_ldmk = aligner_helper.reference_landmark() + input_size = self.config.input_size + output_size = self.config.output_size + cv2_tfms = aligner_helper.get_cv2_affine_from_landmark(orig_pred_ldmks, reference_ldmk, input_size, input_size) + thetas = aligner_helper.cv2_param_to_torch_theta(cv2_tfms, input_size, input_size, output_size, output_size) + thetas = thetas.to(orig_pred_ldmks.device) + + output_size = torch.Size((len(thetas), 3, output_size, output_size)) + grid = F.affine_grid(thetas, output_size, align_corners=True) + aligned_x = F.grid_sample(x + 1, grid, 
align_corners=True) - 1 # +1, -1 for making padding pixel 0 + aligned_ldmks = aligner_helper.adjust_ldmks(orig_pred_ldmks.view(-1, 5, 2), thetas) + + orig_pred_ldmks = orig_pred_ldmks.view(-1, 5, 2) + # bbox (xmin, ymin, xmax, ymax) + normalized_bbox = bbox / torch.tensor([[input_img.size(3), input_img.size(2)] * 2]).to(bbox.device) + + + if padding_ratio_override is None: + padding_ratio = self.preprocessor.padding + else: + padding_ratio = padding_ratio_override + if padding_ratio > 0: + # unpad the landmark so that it is in the original image coordinate + scale = 1 / (1 + (2 * padding_ratio)) + pad_inv_theta = torch.from_numpy(np.array([[1 / scale, 0, 0], [0, 1 / scale, 0]])) + pad_inv_theta = pad_inv_theta.unsqueeze(0).float().to(self.device).repeat(orig_pred_ldmks.size(0), 1, 1) + unpad_ldmk_pred = torch.concat([orig_pred_ldmks.view(-1, 5, 2), + torch.ones((orig_pred_ldmks.size(0), 5, 1)).to(self.device)], dim=-1) + unpad_ldmk_pred = (((unpad_ldmk_pred) * 2 - 1) @ pad_inv_theta.mT) / 2 + 0.5 + unpad_ldmk_pred = unpad_ldmk_pred.view(orig_pred_ldmks.size(0), -1).detach() + unpad_ldmk_pred = unpad_ldmk_pred.view(-1, 5, 2) + if not is_square: + unpad_ldmk_pred = None # cannot use this if the input is not square becaouse preprocessor changes input + normalized_bbox = None # cannot use this if the input is not square becaouse preprocessor changes input + return aligned_x, unpad_ldmk_pred, aligned_ldmks, score, thetas, normalized_bbox + + if not is_square: + orig_pred_ldmks = None # cannot use this if the input is not square becaouse preprocessor changes input + normalized_bbox = None # cannot use this if the input is not square becaouse preprocessor changes input + return aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, normalized_bbox + + def make_train_transform(self): + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + + def make_test_transform(self): + 
transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + + +def normalize(image): + image = image / 255. + image = (image - 0.5) / 0.5 + return image + +def unnormalize(image): + image = image * 0.5 + 0.5 + image = image * 255. + return image + +def normalize_for_net(bgr_image_0_255): + # bgr_image = cv2.imread(image_path, cv2.IMREAD_COLOR) + return bgr_image_0_255 - torch.tensor([104, 117, 123])[None, :, None, None].to(bgr_image_0_255.device) + + +def postprocess(priorbox, loc, conf, landms, confidence_threshold, nms_threshold): + + device = loc.device + im_height, im_width = priorbox.image_size + + scale = torch.Tensor([im_width, im_height, im_width, im_height]) + scale = scale.to(device) + + boxes = priorbox.decode(loc.data.squeeze(0)) + boxes = boxes * scale + boxes = boxes.cpu().numpy() + scores = conf.squeeze(0).data.cpu().numpy()[:, 1] + landms = priorbox.decode_landm(landms.data.squeeze(0)) + scale1 = torch.Tensor([im_width, im_height, im_width, im_height, + im_width, im_height, im_width, im_height, + im_width, im_height]) + scale1 = scale1.to(device) + landms = landms * scale1 + landms = landms.cpu().numpy() + + # ignore low scores + inds = np.where(scores > confidence_threshold)[0] + if len(inds) == 0: + inds = np.where(scores >= 0)[0] + boxes = boxes[inds] + landms = landms[inds] + scores = scores[inds] + + # keep top-K before NMS + order = scores.argsort()[::-1] + # order = scores.argsort()[::-1][:args.top_k] + boxes = boxes[order] + landms = landms[order] + scores = scores[order] + + # do NMS + dets = np.hstack((boxes, scores[:, np.newaxis])).astype(np.float32, copy=False) + keep = py_cpu_nms(dets, nms_threshold) + # keep = nms(dets, args.nms_threshold,force_cpu=args.cpu) + dets = dets[keep, :] + landms = landms[keep] + + # keep top-K faster NMS + # dets = dets[:args.keep_top_k, :] + # landms = landms[:args.keep_top_k, :] + + dets = np.concatenate((dets, 
def py_cpu_nms(dets, thresh):
    """Greedy non-maximum suppression implemented with NumPy.

    Args:
        dets: 2-D array whose first five columns are x1, y1, x2, y2, score.
        thresh: a box is discarded when its IoU with an already selected box
            is strictly greater than this value.

    Returns:
        List of row indices into ``dets`` that survive, highest score first.
    """
    left, top, right, bottom, conf = (dets[:, col] for col in range(5))

    # Box areas use the inclusive-pixel convention (+1 on each side).
    box_area = (right - left + 1) * (bottom - top + 1)
    remaining = conf.argsort()[::-1]

    selected = []
    while remaining.size > 0:
        best, rest = remaining[0], remaining[1:]
        selected.append(best)

        # Intersection of the best box with every other remaining box.
        inter_left = np.maximum(left[best], left[rest])
        inter_top = np.maximum(top[best], top[rest])
        inter_right = np.minimum(right[best], right[rest])
        inter_bottom = np.minimum(bottom[best], bottom[rest])
        inter_w = np.maximum(0.0, inter_right - inter_left + 1)
        inter_h = np.maximum(0.0, inter_bottom - inter_top + 1)
        inter_area = inter_w * inter_h

        iou = inter_area / (box_area[best] + box_area[rest] - inter_area)

        # Keep only the boxes that do not overlap the best box too much.
        remaining = rest[np.where(iou <= thresh)[0]]

    return selected
def adjust_ldmks(ldmks, thetas):
    """Map landmarks through the inverse of each sample's affine ``thetas``.

    Args:
        ldmks: tensor of shape (N, K, 2). The number of points K is read from
            the input instead of being fixed at 5.
            NOTE(review): the ``* 2 - 1`` / ``/ 2 + 0.5`` pair suggests the
            coordinates are normalized to [0, 1] -- confirm with callers.
        thetas: batch of affine matrices handed to ``inv_matrix``
            (which asserts a 3-D input).

    Returns:
        Tensor of shape (N, K, 2) with the transformed landmarks.
    """
    inv_thetas = inv_matrix(thetas).to(ldmks.device).float()

    # Append a column of ones so the 2x3 affine can be applied by matmul.
    num_samples, num_points = ldmks.shape[0], ldmks.shape[1]
    ones = torch.ones((num_samples, num_points, 1)).to(ldmks.device)
    homogeneous = torch.cat([ldmks, ones], dim=2)

    # Shift to the [-1, 1] range, apply the inverse affine, then shift back.
    ldmk_aligned = ((homogeneous * 2 - 1) @ inv_thetas.permute(0, 2, 1)) / 2 + 0.5
    return ldmk_aligned
def get_landmark_predictor(network='mobile0.25', input_size=160):
    """Create the RetinaFace network (test phase) and its matching prior boxes.

    Args:
        network: backbone selector, either ``"mobile0.25"`` or ``"resnet50"``.
        input_size: side length of the square input; used as the prior-box
            image size ``(input_size, input_size)``.

    Returns:
        Tuple ``(net, priorbox)``.

    Raises:
        ValueError: if ``network`` is not one of the supported names.
            Previously this fell through with ``cfg=None`` and failed later
            inside ``RetinaFace`` with an unrelated-looking error.
    """
    if network == "mobile0.25":
        cfg = cfg_mnet
    elif network == "resnet50":
        cfg = cfg_re50
    else:
        raise ValueError(
            f"Unknown network: {network!r} (expected 'mobile0.25' or 'resnet50')"
        )

    net = RetinaFace(cfg=cfg, phase='test')
    priorbox = PriorBox(image_size=(input_size, input_size),
                        min_sizes=[[16, 32], [64, 128], [256, 512]],
                        steps=[8, 16, 32],
                        clip=False,
                        variances=[0.1, 0.2],)
    return net, priorbox
class PriorBox(object):
    """Anchor ("prior") box generator with the matching encode/decode helpers.

    ``self.priors`` is a ``(num_priors, 4)`` tensor of ``(cx, cy, w, h)`` rows
    expressed as fractions of the image size. ``image_size`` is indexed as
    ``(height, width)``. Every encode/decode method first moves
    ``self.priors`` to the device of its input tensor.
    """

    def __init__(self,
                 image_size,
                 min_sizes=None,
                 steps=None,
                 clip=False,
                 variances=None,
                 ):
        """Store the configuration and precompute the priors.

        Args:
            image_size: ``(height, width)`` of the network input.
            min_sizes: per-feature-map list of anchor sizes in pixels;
                defaults to ``[[64, 80], [96, 112], [128, 144]]``.
            steps: stride of each feature map; defaults to ``[8, 16, 32]``.
            clip: clamp the generated priors to ``[0, 1]`` when true.
            variances: two scale factors used by encode/decode; defaults to
                ``[0.1, 0.2]``.
        """
        super(PriorBox, self).__init__()
        # Defaults are created per instance: list defaults in the signature
        # would be shared (and mutable) across every PriorBox.
        self.min_sizes = [[64, 80], [96, 112], [128, 144]] if min_sizes is None else min_sizes
        self.steps = [8, 16, 32] if steps is None else steps
        self.clip = clip
        self.variances = [0.1, 0.2] if variances is None else variances
        self.image_size = image_size
        self.feature_maps = [[ceil(self.image_size[0]/step), ceil(self.image_size[1]/step)] for step in self.steps]
        with torch.no_grad():
            self.priors = self.forward()

    def forward(self):
        """Generate all priors as ``(cx, cy, w, h)`` rows, shape ``(num_priors, 4)``."""
        img_h, img_w = self.image_size[0], self.image_size[1]
        anchors = []
        for k, (rows, cols) in enumerate(self.feature_maps):
            sizes = self.min_sizes[k]
            step = self.steps[k]
            for i, j in product(range(rows), range(cols)):
                # Centre of cell (i, j), as a fraction of the image size.
                cx = (j + 0.5) * step / img_w
                cy = (i + 0.5) * step / img_h
                for min_size in sizes:
                    anchors.extend((cx, cy, min_size / img_w, min_size / img_h))

        output = torch.Tensor(anchors).view(-1, 4)
        if self.clip:
            output.clamp_(max=1, min=0)
        return output

    def encode(self, matched):
        """Encode matched ground-truth boxes as regression targets.

        Args:
            matched: ``(num_priors, 4)`` boxes in ``(xmin, ymin, xmax, ymax)``
                form, one per prior.

        Returns:
            ``(num_priors, 4)`` tensor of centre offsets and log size ratios,
            scaled by ``variances``.
        """
        self.priors = self.priors.to(matched.device)

        # Distance between the matched box centre and the prior centre.
        g_cxcy = (matched[:, :2] + matched[:, 2:])/2 - self.priors[:, :2]
        g_cxcy /= (self.variances[0] * self.priors[:, 2:])
        # Ratio of matched size to prior size, in log space.
        g_wh = (matched[:, 2:] - matched[:, :2]) / self.priors[:, 2:]
        g_wh = torch.log(g_wh) / self.variances[1]
        return torch.cat([g_cxcy, g_wh], 1)  # [num_priors,4]

    def encode_landm(self, matched):
        """Encode matched landmarks as offsets from the prior centres.

        Args:
            matched: ``(num_priors, 10)`` tensor holding five ``(x, y)`` points.

        Returns:
            ``(num_priors, 10)`` tensor of offsets divided by
            ``variances[0] * prior_size``.
        """
        self.priors = self.priors.to(matched.device)

        matched = torch.reshape(matched, (matched.size(0), 5, 2))
        # Repeat every prior once per landmark: (num_priors, 5, 4).
        priors = self.priors.unsqueeze(1).expand(matched.size(0), 5, 4)
        g_cxcy = matched[:, :, :2] - priors[:, :, :2]
        g_cxcy /= (self.variances[0] * priors[:, :, 2:])
        return g_cxcy.reshape(g_cxcy.size(0), -1)

    # Adapted from https://github.com/Hakuyume/chainer-ssd
    def decode(self, loc):
        """Decode box regressions into ``(xmin, ymin, xmax, ymax)`` boxes.

        Args:
            loc: ``(num_priors, 4)`` predicted offsets.
        """
        self.priors = self.priors.to(loc.device)

        boxes = torch.cat((
            self.priors[:, :2] + loc[:, :2] * self.variances[0] * self.priors[:, 2:],
            self.priors[:, 2:] * torch.exp(loc[:, 2:] * self.variances[1])), 1)
        # Convert centre/size to corner form in place.
        boxes[:, :2] -= boxes[:, 2:] / 2
        boxes[:, 2:] += boxes[:, :2]
        return boxes

    def decode_landm(self, pre):
        """Decode landmark regressions, ``(num_priors, 10)`` in and out."""
        self.priors = self.priors.to(pre.device)
        centers = self.priors[:, :2]
        sizes = self.priors[:, 2:]
        points = [centers + pre[:, k:k + 2] * self.variances[0] * sizes
                  for k in range(0, 10, 2)]
        return torch.cat(points, dim=1)

    def decode_batch(self, loc):
        """Batched :meth:`decode`; ``loc`` is ``(batch, num_priors, 4)``."""
        self.priors = self.priors.to(loc.device)
        assert loc.ndim == 3
        priors = self.priors.unsqueeze(0).expand(loc.size(0), -1, -1)
        boxes = torch.cat((
            priors[:, :, :2] + loc[:, :, :2] * self.variances[0] * priors[:, :, 2:],
            priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * self.variances[1])), -1)
        boxes[:, :, :2] -= boxes[:, :, 2:] / 2
        boxes[:, :, 2:] += boxes[:, :, :2]
        return boxes

    def decode_landm_batch(self, prediction):
        """Batched :meth:`decode_landm`; ``prediction`` is ``(batch, num_priors, 10)``."""
        assert prediction.ndim == 3
        self.priors = self.priors.to(prediction.device)
        priors = self.priors.unsqueeze(0).expand(prediction.size(0), -1, -1)
        centers = priors[:, :, :2]
        sizes = priors[:, :, 2:]
        points = [centers + prediction[:, :, k:k + 2] * self.variances[0] * sizes
                  for k in range(0, 10, 2)]
        return torch.cat(points, dim=-1)
def forward(self, predictions, priorbox, targets):
    """Compute the multibox losses for one batch.

    Args:
        predictions (tuple): ``(loc_data, conf_data, landm_data, aggs, thetas)``.
            loc_data shape: (batch_size, num_priors, 4)
            conf_data shape: (batch_size, num_priors, num_classes)
            landm_data shape: (batch_size, num_priors, 10)
            ``aggs`` may be ``None``; when given, its columns are read as
            0:4 box, 4:6 class logits, 6: landmarks. ``thetas`` is unused here.
        priorbox: object exposing ``priors`` (num_priors, 4); passed to ``match``.
        targets: per-image tensors with columns 0:4 box, 4:14 landmarks and
            the label in the last column.

    Returns:
        ``(loss_l, loss_c, loss_landm, aux_loss_dict)`` where
        ``aux_loss_dict`` is ``None`` when ``aggs`` is ``None``.
    """

    loc_data, conf_data, landm_data, aggs, _thetas = predictions
    # All targets are moved to wherever the predictions live, so the loss
    # works on CPU and on any CUDA device rather than only the current one.
    device = loc_data.device
    num = loc_data.size(0)
    num_priors = (priorbox.priors.size(0))

    if aggs is not None:
        stacked_target = torch.stack(targets, dim=0).squeeze(1)

        # Landmark loss only where the label is strictly positive.
        pos_idx = stacked_target[:, -1] > 0
        agg_ldmk = aggs[:, 6:][pos_idx]
        tgt_ldmk = stacked_target[:, 4:14][pos_idx]
        agg_loss_landm = F.smooth_l1_loss(agg_ldmk, tgt_ldmk, reduction='sum') / len(tgt_ldmk)

        # Box loss wherever the label is non-zero.
        pos_idx = stacked_target[:, -1] != 0
        agg_bbox = aggs[:, :4][pos_idx]
        tgt_bbox = stacked_target[:, :4][pos_idx]
        agg_loss_box = F.smooth_l1_loss(agg_bbox, tgt_bbox, reduction='sum') / len(tgt_bbox)

        agg_cls = aggs[:, 4:6]
        tgt_cls = (stacked_target[:, -1] > 0).long()
        agg_loss_cls = F.cross_entropy(agg_cls, tgt_cls, reduction='sum') / len(tgt_cls)
        aux_loss_dict = {
            'agg_loss_landm': agg_loss_landm,
            'agg_loss_box': agg_loss_box,
            'agg_loss_cls': agg_loss_cls
        }
    else:
        aux_loss_dict = None

    # Match priors (default boxes) with ground-truth boxes; ``match`` fills
    # the three target buffers in place, one batch row per call.
    loc_t = torch.Tensor(num, num_priors, 4)
    landm_t = torch.Tensor(num, num_priors, 10)
    conf_t = torch.LongTensor(num, num_priors)
    for idx in range(num):
        truths = targets[idx][:, :4].data
        labels = targets[idx][:, -1].data
        landms = targets[idx][:, 4:14].data
        match(self.threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx)

    loc_t = loc_t.to(device)
    conf_t = conf_t.to(device)
    landm_t = landm_t.to(device)
    zeros = torch.tensor(0, device=device)

    # Landmark loss (Smooth L1) over priors with a strictly positive label.
    # Shape: [batch,num_priors,10]
    pos1 = conf_t > zeros
    num_pos_landm = pos1.long().sum(1, keepdim=True)
    N1 = max(num_pos_landm.data.sum().float(), 1)
    pos_idx1 = pos1.unsqueeze(pos1.dim()).expand_as(landm_data)
    landm_p = landm_data[pos_idx1].view(-1, 10)
    landm_t = landm_t[pos_idx1].view(-1, 10)
    loss_landm = F.smooth_l1_loss(landm_p, landm_t, reduction='sum')

    # Any non-zero label counts as a positive for box and class losses.
    pos = conf_t != zeros
    conf_t[pos] = 1

    # Localization loss (Smooth L1)
    # Shape: [batch,num_priors,4]
    pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data)
    loc_p = loc_data[pos_idx].view(-1, 4)
    loc_t = loc_t[pos_idx].view(-1, 4)
    loss_l = F.smooth_l1_loss(loc_p, loc_t, reduction='sum')

    # Per-prior confidence loss, used to rank negatives for hard mining.
    batch_conf = conf_data.view(-1, self.num_classes)
    loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1, 1))

    # Hard negative mining
    loss_c[pos.view(-1, 1)] = 0  # filter out pos boxes for now
    loss_c = loss_c.view(num, -1)
    _, loss_idx = loss_c.sort(1, descending=True)
    _, idx_rank = loss_idx.sort(1)
    num_pos = pos.long().sum(1, keepdim=True)
    num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1)
    neg = idx_rank < num_neg.expand_as(idx_rank)

    # Confidence loss over the positives plus the mined negatives.
    pos_idx = pos.unsqueeze(2).expand_as(conf_data)
    neg_idx = neg.unsqueeze(2).expand_as(conf_data)
    conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes)
    targets_weighted = conf_t[(pos+neg).gt(0)]
    loss_c = F.cross_entropy(conf_p, targets_weighted, reduction='sum')

    # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N
    N = max(num_pos.data.sum().float(), 1)
    loss_l /= N
    loss_c /= N
    loss_landm /= N1

    return loss_l, loss_c, loss_landm, aux_loss_dict
def conv_dw(inp, oup, stride, leaky=0.1):
    """Build a depthwise-separable convolution block.

    The block is a 3x3 depthwise convolution (``groups=inp``) followed by a
    1x1 pointwise convolution; each is followed by BatchNorm and LeakyReLU.

    Args:
        inp: number of input channels.
        oup: number of output channels.
        stride: stride of the depthwise convolution.
        leaky: negative slope of both LeakyReLU layers.

    Returns:
        ``nn.Sequential`` holding the six layers in order.
    """
    depthwise = [
        nn.Conv2d(inp, inp, kernel_size=3, stride=stride, padding=1, groups=inp, bias=False),
        nn.BatchNorm2d(inp),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    pointwise = [
        nn.Conv2d(inp, oup, kernel_size=1, stride=1, padding=0, bias=False),
        nn.BatchNorm2d(oup),
        nn.LeakyReLU(negative_slope=leaky, inplace=True),
    ]
    return nn.Sequential(*depthwise, *pointwise)
class FPN(nn.Module):
    """Three-level top-down feature pyramid.

    Each input level is projected to ``out_channels`` with a 1x1 conv block.
    The coarser level is then upsampled (nearest) and added to the next finer
    one, and the two finer levels are smoothed with a 3x3 conv block.
    """

    def __init__(self, in_channels_list, out_channels):
        super(FPN, self).__init__()
        # Narrow pyramids (<= 64 channels) use a leaky activation.
        leaky = 0.1 if out_channels <= 64 else 0

        self.output1 = conv_bn1X1(in_channels_list[0], out_channels, stride=1, leaky=leaky)
        self.output2 = conv_bn1X1(in_channels_list[1], out_channels, stride=1, leaky=leaky)
        self.output3 = conv_bn1X1(in_channels_list[2], out_channels, stride=1, leaky=leaky)

        self.merge1 = conv_bn(out_channels, out_channels, leaky=leaky)
        self.merge2 = conv_bn(out_channels, out_channels, leaky=leaky)

    def forward(self, input):
        """Return ``[level1, level2, level3]`` for a mapping of three feature maps.

        ``input`` is read through ``input.values()``, so the levels are taken
        in the mapping's iteration order.
        """
        feats = list(input.values())

        level1 = self.output1(feats[0])
        level2 = self.output2(feats[1])
        level3 = self.output3(feats[2])

        # Coarsest -> middle.
        level3_up = F.interpolate(level3, size=[level2.size(2), level2.size(3)], mode="nearest")
        level2 = self.merge2(level2 + level3_up)

        # Middle -> finest.
        level2_up = F.interpolate(level2, size=[level1.size(2), level1.size(3)], mode="nearest")
        level1 = self.merge1(level1 + level2_up)

        return [level1, level2, level3]
class ClassHead(nn.Module):
    """1x1 convolution head producing two class scores per anchor.

    The output is reshaped to ``(batch, H * W * num_anchors, 2)``.
    """

    def __init__(self, inchannels=512, num_anchors=3):
        super().__init__()
        self.num_anchors = num_anchors
        self.conv1x1 = nn.Conv2d(
            in_channels=inchannels,
            out_channels=2 * num_anchors,
            kernel_size=(1, 1),
            stride=1,
            padding=0,
        )

    def forward(self, x):
        scores = self.conv1x1(x)
        batch_size = scores.shape[0]
        # Move channels last so each anchor's two scores end up adjacent.
        channels_last = scores.permute(0, 2, 3, 1).contiguous()
        return channels_last.view(batch_size, -1, 2)
def _make_class_head(self, fpn_num=3, inchannels=64, anchor_num=2):
    """Create one ``ClassHead`` per pyramid level.

    Args:
        fpn_num: number of heads to create.
        inchannels: input channels of every head.
        anchor_num: anchors per location, forwarded to ``ClassHead``.

    Returns:
        ``nn.ModuleList`` with ``fpn_num`` independently initialised heads.
    """
    heads = [ClassHead(inchannels, anchor_num) for _ in range(fpn_num)]
    return nn.ModuleList(heads)
priorbox=None): + out = self.body(inputs) + + # FPN + fpn = self.fpn(out) + + # SSH + feature1 = self.ssh1(fpn[0]) + feature2 = self.ssh2(fpn[1]) + feature3 = self.ssh3(fpn[2]) + features = [feature1, feature2, feature3] + + bbox_regressions = torch.cat([self.BboxHead[i](feature) for i, feature in enumerate(features)], dim=1) + classifications = torch.cat([self.ClassHead[i](feature) for i, feature in enumerate(features)],dim=1) + ldm_regressions = torch.cat([self.LandmarkHead[i](feature) for i, feature in enumerate(features)], dim=1) + if self.phase == 'train': + output = (bbox_regressions, classifications, ldm_regressions) + else: + output = (bbox_regressions, F.softmax(classifications, dim=-1), ldm_regressions) + return output \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/preprocessor.py b/cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..d7de83e8847b5aed6313b95dceb2d7c54285015e --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/aligners/retinaface_aligner/retinaface/preprocessor.py @@ -0,0 +1,93 @@ +import torch +import torch.nn.functional as F + +class Preprocessor(): + + def __init__(self, output_size=160, padding=0.0, padding_val='zero'): + self.output_size = output_size + self.padding = padding + self.padding_val = padding_val + + def preprocess_batched(self, imgs, padding_ratio_override=None): + + # check img is of float + if imgs.dtype == torch.float32: + if self.padding_val == 'zero': + padding_val = -1.0 + elif self.padding_val == 'mean': + padding_val = imgs.mean() + else: + raise ValueError('padding_val must be "zero" or "mean"') + elif imgs.dtype == torch.uint8: + if self.padding_val == 'zero': + padding_val = 0 + elif self.padding_val == 'mean': + padding_val = imgs.mean() + else: + raise ValueError('padding_val must be "zero" or "mean"') + else: + raise 
ValueError('imgs.dtype must be torch.float32 or torch.uint8') + + square_imgs = self.make_square_img_batched(imgs, padding_val=padding_val) + + if padding_ratio_override is not None: + padding = padding_ratio_override + else: + padding = self.padding + padded_imgs = self.make_padded_img_batched(square_imgs, padding=padding, padding_val=padding_val) + + size=(self.output_size, self.output_size) + if imgs.dtype == torch.float32: + resized_imgs = F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True) + elif imgs.dtype == torch.uint8: + padded_imgs = padded_imgs.to(torch.float32) + resized_imgs = F.interpolate(padded_imgs, size=size, mode='bilinear', align_corners=True) + resized_imgs = torch.clip(resized_imgs, 0, 255) + resized_imgs = resized_imgs.to(torch.uint8) + else: + raise ValueError('imgs.dtype must be torch.float32 or torch.uint8') + return resized_imgs + + + def make_square_img_batched(self, imgs, padding_val): + assert imgs.ndim == 4 + # squarify the image + h, w = imgs.shape[2:] + if h > w: + diff = (h - w) + pad_left = diff // 2 + pad_right = diff - pad_left + imgs = F.pad(imgs, (pad_left, pad_right, 0, 0), value=padding_val) + elif w > h: + diff = (w - h) + pad_top = diff // 2 + pad_bottom = diff - pad_top + imgs = F.pad(imgs, (0, 0, pad_top, pad_bottom), value=padding_val) + assert imgs.shape[2] == imgs.shape[3] + return imgs + + + def make_padded_img_batched(self, imgs, padding, padding_val): + if padding == 0: + return imgs + assert imgs.ndim == 4 + + + # pad the image + h, w = imgs.shape[2:] + pad_h = int(h * padding) + pad_w = int(w * padding) + imgs = F.pad(imgs, (pad_w, pad_w, pad_h, pad_h), value=padding_val) + return imgs + + + def __call__(self, input, padding_ratio_override=None): + if input.ndim == 3: + assert input.shape[0] == 3 + batch_input = input.unsqueeze(0) + return self.preprocess_batched(batch_input, padding_ratio_override=padding_ratio_override)[0] + elif input.ndim == 4: + assert input.shape[1] == 3 + return 
def point_form(boxes):
    """ Convert center-size boxes to corner form.
    Args:
        boxes: (tensor) boxes as (cx, cy, w, h), Shape: [num_boxes,4].
    Return:
        boxes: (tensor) boxes as (xmin, ymin, xmax, ymax), Shape: [num_boxes,4].
    """
    return torch.cat((boxes[:, :2] - boxes[:, 2:] / 2,     # xmin, ymin
                      boxes[:, :2] + boxes[:, 2:] / 2), 1)  # xmax, ymax


def center_size(boxes):
    """ Convert corner-form boxes to center-size form.
    Args:
        boxes: (tensor) boxes as (xmin, ymin, xmax, ymax), Shape: [num_boxes,4].
    Return:
        boxes: (tensor) boxes as (cx, cy, w, h), Shape: [num_boxes,4].
    """
    # torch.cat takes a *sequence* of tensors followed by the dim; both
    # halves must therefore sit inside one tuple.
    return torch.cat(((boxes[:, 2:] + boxes[:, :2]) / 2,   # cx, cy
                      boxes[:, 2:] - boxes[:, :2]), 1)      # w, h


def intersect(box_a, box_b):
    """ Compute pairwise intersection areas of two sets of corner-form boxes.
    Both inputs are broadcast to [A,B,2] views:
    [A,2] -> [A,1,2] -> [A,B,2]
    [B,2] -> [1,B,2] -> [A,B,2]
    Args:
      box_a: (tensor) bounding boxes, Shape: [A,4].
      box_b: (tensor) bounding boxes, Shape: [B,4].
    Return:
      (tensor) intersection area, Shape: [A,B].
    """
    A = box_a.size(0)
    B = box_b.size(0)
    max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2),
                       box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
    min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2),
                       box_b[:, :2].unsqueeze(0).expand(A, B, 2))
    # Non-overlapping pairs give a negative extent; clamp it to zero area.
    inter = torch.clamp((max_xy - min_xy), min=0)
    return inter[:, :, 0] * inter[:, :, 1]


def jaccard(box_a, box_b):
    """Compute the jaccard overlap (intersection over union) of two box sets.
    E.g.:
        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
    Args:
        box_a: (tensor) corner-form boxes, Shape: [A,4]
        box_b: (tensor) corner-form boxes, Shape: [B,4]
    Return:
        jaccard overlap: (tensor) Shape: [box_a.size(0), box_b.size(0)]
    """
    inter = intersect(box_a, box_b)
    area_a = ((box_a[:, 2] - box_a[:, 0]) *
              (box_a[:, 3] - box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
    area_b = ((box_b[:, 2] - box_b[:, 0]) *
              (box_b[:, 3] - box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
    union = area_a + area_b - inter
    return inter / union  # [A,B]


def matrix_iou(a, b):
    """
    Return the pairwise IoU of corner-form boxes a [A,4] and b [B,4] as an
    [A,B] array; numpy version for data augmentation.
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    # The boolean mask zeroes pairs that do not overlap on both axes.
    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    area_b = np.prod(b[:, 2:] - b[:, :2], axis=1)
    return area_i / (area_a[:, np.newaxis] + area_b - area_i)


def matrix_iof(a, b):
    """
    Return the pairwise intersection-over-foreground (intersection divided by
    the area of a) of a [A,4] and b [B,4]; numpy version for data augmentation.
    """
    lt = np.maximum(a[:, np.newaxis, :2], b[:, :2])
    rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:])

    area_i = np.prod(rb - lt, axis=2) * (lt < rb).all(axis=2)
    area_a = np.prod(a[:, 2:] - a[:, :2], axis=1)
    # The denominator is floored at 1 so degenerate boxes do not divide by 0.
    return area_i / np.maximum(area_a[:, np.newaxis], 1)
def match(threshold, truths, priorbox, labels, landms, loc_t, conf_t, landm_t, idx):
    """Match each prior box with the ground truth box of the highest jaccard
    overlap, encode the boxes and landmarks, and write the targets for batch
    element ``idx`` into ``loc_t``, ``conf_t`` and ``landm_t`` in place.
    Args:
        threshold: (float) The overlap threshold used when matching boxes.
        truths: (tensor) Ground truth boxes, Shape: [num_obj, 4].
        priorbox: object exposing ``priors`` (center-size prior boxes),
            ``encode`` and ``encode_landm``.
        labels: (tensor) All the class labels for the image, Shape: [num_obj].
        landms: (tensor) Ground truth landms, Shape [num_obj, 10].
        loc_t: (tensor) Tensor to be filled w/ encoded location targets.
        conf_t: (tensor) Tensor to be filled w/ matched indices for conf preds.
        landm_t: (tensor) Tensor to be filled w/ encoded landm targets.
        idx: (int) current batch index
    Return:
        None. Results are written into loc_t[idx], conf_t[idx], landm_t[idx].
    """
    # jaccard index between every ground truth and every prior: [num_obj, num_priors]
    overlaps = jaccard(
        truths,
        point_form(priorbox.priors)
    )
    # (Bipartite Matching)
    # [num_obj] best prior for each ground truth
    best_prior_overlap, best_prior_idx = overlaps.max(1)

    # ignore hard gt: objects whose best prior overlaps them by less than 0.2
    valid_gt_idx = best_prior_overlap >= 0.2
    best_prior_idx_filter = best_prior_idx[valid_gt_idx]
    if best_prior_idx_filter.shape[0] <= 0:
        # Nothing matchable in this image: every prior is background. All
        # three targets are zeroed so no stale values remain in landm_t.
        loc_t[idx] = 0
        conf_t[idx] = 0
        landm_t[idx] = 0
        return

    # [num_priors] best ground truth for each prior
    best_truth_overlap, best_truth_idx = overlaps.max(0)
    # Give the best prior of each valid gt an overlap of 2 so it always
    # survives the threshold below.
    best_truth_overlap.index_fill_(0, best_prior_idx_filter, 2)
    # TODO refactor: index best_prior_idx with long tensor
    # Ensure every gt matches with its prior of max overlap. Kept as a loop:
    # when two gts share a prior the later gt must win, which a single indexed
    # assignment with duplicate indices does not guarantee.
    for j in range(best_prior_idx.size(0)):
        best_truth_idx[best_prior_idx[j]] = j
    matches = truths[best_truth_idx]          # Shape: [num_priors,4] gt box assigned to each prior
    conf = labels[best_truth_idx]             # Shape: [num_priors] label assigned to each prior
    conf[best_truth_overlap < threshold] = 0  # priors below the threshold become background
    loc = priorbox.encode(matches)

    matches_landm = landms[best_truth_idx]
    landm = priorbox.encode_landm(matches_landm)
    loc_t[idx] = loc    # [num_priors,4] encoded offsets to learn
    conf_t[idx] = conf  # [num_priors] top class label for each prior
    landm_t[idx] = landm
# Original author: Francisco Massa:
# https://github.com/fmassa/object-detection.torch
# Ported to PyTorch by Max deGroot (02/01/2017)
def nms(boxes, scores, overlap=0.5, top_k=200):
    """Apply non-maximum suppression at test time to avoid detecting too many
    overlapping bounding boxes for a given object.
    Args:
        boxes: (tensor) The location preds for the img, Shape: [num_priors,4].
        scores: (tensor) The class predscores for the img, Shape:[num_priors].
        overlap: (float) The overlap thresh for suppressing unnecessary boxes.
        top_k: (int) The Maximum number of box preds to consider.
    Return:
        (keep, count): ``keep`` is a long tensor of length num_priors whose
        first ``count`` entries are the indices of the kept boxes in
        descending score order; the remaining entries are 0.
    """
    # Allocate on the same device as the scores so GPU inputs work.
    keep = scores.new_zeros(scores.size(0), dtype=torch.long)
    count = 0
    if boxes.numel() == 0:
        # Same (keep, count) shape as the normal path so callers can unpack.
        return keep, count

    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    area = torch.mul(x2 - x1, y2 - y1)
    _, idx = scores.sort(0)  # sort in ascending order
    idx = idx[-top_k:]       # indices of the top-k largest vals

    while idx.numel() > 0:
        i = idx[-1]  # index of current largest val
        keep[count] = i
        count += 1
        if idx.size(0) == 1:
            break
        idx = idx[:-1]  # remove kept element from view
        # intersection rectangle of box i with every remaining box
        xx1 = torch.max(x1[idx], x1[i])
        yy1 = torch.max(y1[idx], y1[i])
        xx2 = torch.min(x2[idx], x2[i])
        yy2 = torch.min(y2[idx], y2[i])
        w = torch.clamp(xx2 - xx1, min=0.0)
        h = torch.clamp(yy2 - yy1, min=0.0)
        inter = w * h
        # IoU = i / (area(a) + area(b) - i)
        union = (area[idx] - inter) + area[i]
        IoU = inter / union
        # keep only elements with an IoU <= overlap
        idx = idx[IoU.le(overlap)]
    return keep, count
def check_keys(model, pretrained_state_dict):
    """Report how a checkpoint's keys line up with a model's state dict.

    Prints the number of missing, unused and used keys.

    Args:
        model: object exposing ``state_dict()``.
        pretrained_state_dict: mapping of parameter name to value.

    Returns:
        True when at least one checkpoint key matches a model key.

    Raises:
        AssertionError: if no checkpoint key matches any model key.
    """
    ckpt_keys = set(pretrained_state_dict.keys())
    model_keys = set(model.state_dict().keys())
    used_pretrained_keys = model_keys & ckpt_keys
    unused_pretrained_keys = ckpt_keys - model_keys
    missing_keys = model_keys - ckpt_keys
    print('Missing keys:{}'.format(len(missing_keys)))
    print('Unused checkpoint keys:{}'.format(len(unused_pretrained_keys)))
    print('Used keys:{}'.format(len(used_pretrained_keys)))
    # Raised explicitly rather than via `assert`, so the check still runs
    # under `python -O`; the exception type is unchanged for callers.
    if not used_pretrained_keys:
        raise AssertionError('load NONE from pretrained checkpoint')
    return True
class RetinaFacePipeline(torch.nn.Module):
    """Run a RetinaFace network and return the top face's landmarks.

    Args:
        net: detection network returning ``(loc, conf, landms)`` per batch.
        priorbox: prior-box helper exposing ``image_size`` as
            ``(height, width)`` plus the decode methods used by
            ``postprocess``.
        input_size: side length of the square network input.
        device: device that images are moved to before inference.
    """

    def __init__(self, net, priorbox, input_size, device='cuda'):
        super().__init__()
        self.net = net
        self.priorbox = priorbox
        self.input_size = input_size
        self.output_size = 112
        self.device = device

    def normalize(self, image):
        """Map pixel values from [0, 255] to [-1, 1]."""
        image = image / 255.
        image = (image - 0.5) / 0.5
        return image

    def unnormalize(self, image):
        """Map pixel values from [-1, 1] back to [0, 255]."""
        image = image * 0.5 + 0.5
        image = image * 255.
        return image

    def normalize_for_net(self, bgr_image_0_255):
        """Subtract the per-channel mean (104, 117, 123) from a BGR batch."""
        # Build the mean on the image's own device so the subtraction cannot
        # hit a device mismatch.
        mean = torch.tensor([104, 117, 123], device=bgr_image_0_255.device)
        return bgr_image_0_255 - mean[None, :, None, None]

    def prealign_preprocess(self, images, value=0.0):
        """Fit images into an ``input_size`` square, keeping aspect ratio.

        Images larger than ``input_size`` on either side are downscaled so the
        longer side fits; the result is then centre-padded with ``value``.

        Args:
            images: tensor of rank 3 ``(C, H, W)`` or rank 4 ``(B, C, H, W)``.
            value: fill value for the padded border.
        """
        assert isinstance(images, torch.Tensor)
        assert images.ndim == 4 or images.ndim == 3
        input_size = self.input_size

        data_width = images.shape[-1]
        data_height = images.shape[-2]
        if data_width > input_size or data_height > input_size:
            # image is bigger than the input size: resize such that the larger
            # side becomes input_size without changing the aspect ratio
            if data_width > data_height:
                scale = input_size / data_width
            else:
                scale = input_size / data_height
            if images.ndim == 4:
                images = F.interpolate(input=images, scale_factor=scale,
                                       mode='bilinear', align_corners=False)
            else:
                images = F.interpolate(input=images.unsqueeze(0), scale_factor=scale,
                                       mode='bilinear', align_corners=False).squeeze(0)

        data_width = images.shape[-1]
        data_height = images.shape[-2]
        padding_width1 = (input_size - data_width) // 2
        padding_width2 = (input_size - data_width) - padding_width1
        padding_height1 = (input_size - data_height) // 2
        padding_height2 = (input_size - data_height) - padding_height1

        result = torch.nn.functional.pad(input=images,
                                         pad=(padding_width1, padding_width2,
                                              padding_height1, padding_height2),
                                         value=value)
        assert result.shape[-1] == input_size
        assert result.shape[-2] == input_size
        return result

    def forward(self, rgb_images):
        """Detect the highest-scoring face per image and return its landmarks.

        Args:
            rgb_images: ``(B, 3, H, W)`` RGB batch in [-1, 1] whose spatial
                size equals ``priorbox.image_size``.

        Returns:
            Float tensor ``(B, 10)`` of five (x, y) landmarks, with x divided
            by the image width and y by the image height.
        """
        # Type check first so a non-tensor fails on the intended assertion.
        assert isinstance(rgb_images, torch.Tensor)
        assert rgb_images.ndim == 4
        assert rgb_images.shape[1] == 3
        assert self.priorbox.image_size == rgb_images.shape[2:]
        rgb_images = rgb_images.to(self.device)

        # make image into BGR
        bgr_images = rgb_images.flip(1)
        input_img = self.normalize_for_net(self.unnormalize(bgr_images))
        batch_loc, batch_conf, batch_landms = self.net(input_img)
        batch_loc = torch.split(batch_loc, 1, dim=0)
        batch_conf = torch.split(batch_conf, 1, dim=0)
        batch_landms = torch.split(batch_landms, 1, dim=0)

        # postprocess() scales landmarks as (x * width, y * height) pairs with
        # image_size read as (height, width); divide in that same order.
        im_height, im_width = self.priorbox.image_size
        ldmk_scale = np.array([im_width, im_height] * 5)

        all_ldmks = []
        for loc, conf, landms in zip(batch_loc, batch_conf, batch_landms):
            dets = postprocess(self.priorbox, loc, conf, landms,
                               confidence_threshold=0.0, nms_threshold=0.4)
            _, _, ldmks = parse_one_det_result(dets)
            all_ldmks.append(ldmks / ldmk_scale)
        all_ldmks = torch.from_numpy(np.array(all_ldmks)).to(self.device).float()
        return all_ldmks
def load_retinaface_pipeline(network, trained_model_path, input_size, device):
    """Build a ready-to-run ``RetinaFacePipeline`` on ``device``.

    Args:
        network: backbone name requested by the caller.
            NOTE(review): this value is currently not forwarded; the model is
            always built with ``network='resnet50'``. Left unchanged because
            the values callers pass are not visible here -- confirm before
            wiring it through.
        trained_model_path: checkpoint path passed to the model loader.
        input_size: side length of the square detector input; also used for
            the prior-box grid.
        device: device for the network, the prior boxes and the pipeline.

    Returns:
        The constructed ``RetinaFacePipeline``.
    """
    net = load_retinface_model(network='resnet50', trained_model_path=trained_model_path)
    net = net.to(device)
    priorbox = PriorBox(image_size=(input_size, input_size),
                        min_sizes=[[16, 32], [64, 128], [256, 512]],
                        steps=[8, 16, 32], clip=False,
                        variances=[0.1, 0.2],
                        device=device)
    pipeline = RetinaFacePipeline(net, priorbox, input_size, device=device)
    # Honour the requested device. `.cuda()` would move the network to the
    # current CUDA device regardless of `device`, which breaks CPU runs and
    # non-default GPUs.
    pipeline.to(device)
    return pipeline
def get_classifier(classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
    """Create the classification head described by ``classifier_cfg``.

    Returns ``None`` when no margin loss is given. Otherwise builds the head
    named by ``classifier_cfg.name`` ('partial_fc' or 'fc'), loads weights
    from ``classifier_cfg.start_from`` when it is set, and freezes all
    parameters when ``classifier_cfg.freeze`` is truthy.

    Raises:
        ValueError: if ``classifier_cfg.name`` is not a known classifier.
    """
    if margin_loss_fn is None:
        print("No margin loss function provided, classifier will not be created")
        return None

    kind = classifier_cfg.name
    if kind == 'partial_fc':
        classifier_cls = partial_fc.PartialFCClassifier
    elif kind == 'fc':
        classifier_cls = fc.FCClassifier
    else:
        raise ValueError(f"Unknown classifier: {classifier_cfg.name}")

    head = classifier_cls.from_config(classifier_cfg, margin_loss_fn,
                                      model_cfg, num_classes,
                                      rank, world_size)

    if classifier_cfg.start_from:
        head.load_state_dict_from_path(classifier_cfg.start_from)

    if classifier_cfg.freeze:
        for weight in head.parameters():
            weight.requires_grad = False

    return head
class BaseClassifier(torch.nn.Module):
    """Common base for classification heads.

    Subclasses implement ``from_config`` and ``forward``.
    ``load_state_dict_from_path`` additionally reads ``self.rank``,
    ``self.world_size`` and ``self.partial_fc``, none of which are set here;
    presumably only partial-FC style subclasses rely on it -- confirm before
    reusing it elsewhere.
    """

    def __init__(self, config=None):
        super().__init__()
        self.config = config

    @classmethod
    def from_config(cls, classifier_cfg, margin_loss_fn, model_cfg, dataset_cfg, rank, world_size) -> "BaseClassifier":
        raise NotImplementedError('from_config must be implemented in subclass')

    def forward(self, local_embeddings, local_labels):
        raise NotImplementedError('forward must be implemented in subclass')

    @property
    def device(self) -> device:
        return get_parameter_device(self)

    @property
    def dtype(self) -> torch.dtype:
        return get_parameter_dtype(self)

    def num_parameters(self, only_trainable: bool = False) -> int:
        """Count parameters, optionally only those with ``requires_grad``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)

    def has_trainable_params(self):
        """Return True if any parameter has ``requires_grad`` set."""
        return any(param.requires_grad for param in self.parameters())

    def save_pretrained(
        self,
        save_dir: Union[str, os.PathLike],
        name: str = 'model.pt',
        rank: int = 0,
    ):
        """Save this rank's weights and config as ``<stem>_rank<rank><ext>``."""
        stem, ext = os.path.splitext(name)
        save_path = os.path.join(save_dir, f'{stem}_rank{rank}{ext}')
        save_state_dict_and_config(self.state_dict(), self.config, save_path)

    def load_state_dict_from_path(self, pretrained_model_path):
        """Load per-rank weights saved by ``save_pretrained``.

        ``pretrained_model_path`` is the un-suffixed path (e.g.
        ``dir/classifier.pt``); the files read are ``dir/classifier_rank<k>.pt``.
        When the checkpoint was written with a different world size, the
        ``partial_fc.weight`` shards are concatenated and re-sliced for this
        rank, zero-padding the last shard if it comes up short.
        """
        save_dir, save_name = os.path.split(pretrained_model_path)
        stem, ext = os.path.splitext(save_name)
        pretrained_model_path = os.path.join(save_dir, f'{stem}_rank{self.rank}{ext}')

        # Only count shards of *this* checkpoint; other `_rank` files in the
        # same directory would otherwise inflate the detected world size.
        all_partitions = [name for name in os.listdir(save_dir or '.')
                          if name.startswith(f'{stem}_rank') and name.endswith(ext)]
        all_partitions = natural_sort(all_partitions)
        ckpt_worldsize = len(all_partitions)

        if self.world_size != ckpt_worldsize:
            # we need to redistribute the partialfc weights
            part_ckpts = [torch.load(os.path.join(save_dir, name), map_location='cpu') for name in all_partitions]
            total_ckpt_num_subjects = sum([ckpt['partial_fc.weight'].shape[0] for ckpt in part_ckpts])
            assert total_ckpt_num_subjects - self.partial_fc.num_classes < 10, \
                (f"total_ckpt_num_subjects: {total_ckpt_num_subjects}, "
                 f"self.partial_fc.num_classes: {self.partial_fc.num_classes}"
                 f"The number can be slightly different due to the last partition.")

            combined_weight = torch.cat([ckpt['partial_fc.weight'] for ckpt in part_ckpts], dim=0)
            state_dict = part_ckpts[0]

            class_start = self.partial_fc.class_start
            num_sample = self.partial_fc.num_local
            sub_center = combined_weight[class_start:class_start + num_sample, :]
            if sub_center.shape[0] != num_sample:
                # Append zero rows. They are created from `sub_center` so they
                # share its (CPU) device and dtype; building them on the
                # module's device made the concatenation fail for GPU models.
                extra_center = sub_center.new_zeros(num_sample - sub_center.shape[0], sub_center.shape[1])
                sub_center = torch.cat([sub_center, extra_center], dim=0)
            state_dict['partial_fc.weight'] = sub_center

        else:
            state_dict = load_state_dict_from_path(pretrained_model_path)

        result = self.load_state_dict(state_dict, strict=False)
        print(result)
def load_state_dict_from_path(path: Union[str, os.PathLike]):
    """Load a state dict from ``path`` onto the CPU.

    Paths containing the substring ``'safetensors'`` are read with
    ``safetensors.torch.load_file``; everything else goes through
    ``torch.load``. The substring test is kept as-is so that it stays in step
    with how ``save_state_dict_and_config`` chooses the format.
    """
    # Accept pathlib.Path and other PathLike objects: the substring test
    # below needs a plain string.
    path = os.fspath(path)
    if 'safetensors' in path:
        # `import safetensors` alone does not load the `torch` submodule.
        import safetensors.torch
        state_dict = safetensors.torch.load_file(path)
    else:
        state_dict = torch.load(path, map_location="cpu")
    return state_dict
def save_state_dict_and_config(state_dict, config, save_path):
    """Write ``state_dict`` to ``save_path`` and ``config`` next to it as YAML.

    The config path is ``save_path`` with its extension replaced by ``.yaml``.
    Paths containing the substring ``'safetensors'`` are written with
    ``safetensors.torch.save_file``; everything else uses ``torch.save``.
    """
    # Accept pathlib.Path and other PathLike objects: the substring test
    # below needs a plain string.
    save_path = os.fspath(save_path)
    os.makedirs(get_parent_directory(save_path), exist_ok=True)

    # save config dict
    config_path = make_config_path(save_path)
    save_config(config, config_path)

    # Save the model
    if 'safetensors' in save_path:
        # `import safetensors` alone does not load the `torch` submodule.
        import safetensors.torch
        safetensors.torch.save_file(state_dict, save_path, metadata={"format": "pt"})
    else:
        torch.save(state_dict, save_path)
class FCClassifier(BaseClassifier):
    """Classifier wrapper around a plain ``FC`` margin head.

    ``apply_ddp`` is set to True on every instance. Checkpoints are written
    only when ``rank == 0`` and are always read back from the ``_rank0`` file.
    """

    def __init__(self, classifier, config, rank, world_size):
        super().__init__()
        self.classifier = classifier
        self.config = config
        self.rank = rank
        self.world_size = world_size
        self.apply_ddp = True

    @classmethod
    def from_config(cls, classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size):
        """Build the head from config; only ``name == 'fc'`` is supported."""
        if classifier_cfg.name != 'fc':
            raise NotImplementedError

        head = FC(
            margin_loss=margin_loss_fn,
            embedding_size=model_cfg.output_dim,
            num_classes=num_classes,
        )
        instance = cls(head, classifier_cfg, rank, world_size)
        instance.eval()
        return instance

    def forward(self, local_embeddings, local_labels):
        """Return the loss computed by the wrapped head."""
        return self.classifier(local_embeddings, local_labels)

    def save_pretrained(
        self,
        save_dir: Union[str, os.PathLike],
        name: str = 'classifier.pt',
        rank: int = 0,
    ):
        """Save via the base class, but only when called with ``rank == 0``."""
        if rank != 0:
            return
        super().save_pretrained(save_dir, name, rank)

    def load_state_dict_from_path(self, pretrained_model_path):
        """Load weights from the ``_rank0`` sibling of ``pretrained_model_path``."""
        folder, file_name = os.path.split(pretrained_model_path)
        stem, ext = os.path.splitext(file_name)
        rank0_path = os.path.join(folder, stem + '_rank0' + ext)

        weights = load_state_dict_from_path(rank0_path)
        outcome = self.load_state_dict(weights, strict=False)
        print('classifier loading result', outcome)
num_classes: int, + ): + super(FC, self).__init__() + + self.cross_entropy = torch.nn.CrossEntropyLoss() + self.embedding_size = embedding_size + self.num_classes = num_classes + self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_classes, embedding_size))) + + # margin_loss + if isinstance(margin_loss, Callable): + self.margin_softmax = margin_loss + if isinstance(margin_loss, AdaFaceLoss): + self.register_buffer('batch_mean', torch.ones(1)*(20)) + self.register_buffer('batch_std', torch.ones(1)*100) + else: + raise + + + def forward( + self, + local_embeddings: torch.Tensor, + local_labels: torch.Tensor, + ): + + embeddings = local_embeddings + labels = local_labels + weight = self.weight + + norms = embeddings.norm(p=2, dim=1, keepdim=True).clamp_min(1e-8) + norm_embeddings = embeddings / norms + + norm_weight_activated = normalize(weight) + logits = linear(norm_embeddings, norm_weight_activated) + logits = logits.clamp(-1, 1) + + if isinstance(self.margin_softmax, CombinedMarginLoss): + logits = self.margin_softmax(logits=logits, labels=labels) + elif isinstance(self.margin_softmax, AdaFaceLoss): + logits, batch_mean, batch_std = self.margin_softmax(logits=logits, labels=labels, norms=norms, + batch_mean=self.batch_mean, + batch_std=self.batch_std) + self.batch_mean.data = batch_mean.data + self.batch_std.data = batch_std.data + else: + raise ValueError('parital FC margin_softmax not supported type') + + loss = self.cross_entropy(logits, labels) + return loss + + + diff --git a/cvlface/research/recognition/code/run_v1/classifiers/partial_fc/__init__.py b/cvlface/research/recognition/code/run_v1/classifiers/partial_fc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..484c5008a8e07866b5df04f63bb80ae5b0a24110 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/classifiers/partial_fc/__init__.py @@ -0,0 +1,39 @@ +from ..base import BaseClassifier +from .partial_fc import PartialFC_V2 + + +class 
PartialFCClassifier(BaseClassifier): + + def __init__(self, classifier, config, rank, world_size): + super(PartialFCClassifier, self).__init__() + self.partial_fc = classifier + self.config = config + self.rank = rank + self.world_size = world_size + self.apply_ddp = False + + @classmethod + def from_config(cls, classifier_cfg, margin_loss_fn, model_cfg, num_classes, rank, world_size): + if classifier_cfg.name == 'partial_fc': + classifier = PartialFC_V2( + rank=rank, + world_size=world_size, + margin_loss=margin_loss_fn, + embedding_size=model_cfg.output_dim, + num_classes=num_classes, + sample_rate=classifier_cfg.sample_rate, + ) + else: + raise NotImplementedError + + model = cls(classifier, classifier_cfg, rank, world_size) + model.eval() + return model + + def forward(self, local_embeddings, local_labels): + loss = self.partial_fc(local_embeddings, local_labels) + return loss + + + + diff --git a/cvlface/research/recognition/code/run_v1/classifiers/partial_fc/partial_fc.py b/cvlface/research/recognition/code/run_v1/classifiers/partial_fc/partial_fc.py new file mode 100644 index 0000000000000000000000000000000000000000..beaeb54cd7624f0555f3d59ca38b706aac4d5b1e --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/classifiers/partial_fc/partial_fc.py @@ -0,0 +1,289 @@ +from typing import Callable +import torch +from torch import distributed +from torch.nn.functional import linear, normalize +from losses.margin_loss import CombinedMarginLoss +from losses.adaface import AdaFaceLoss + + + +class PartialFC_V2(torch.nn.Module): + """ + https://arxiv.org/abs/2203.15565 + A distributed sparsely updating variant of the FC layer, named Partial FC (PFC). 
+ When sample rate less than 1, in each iteration, positive class centers and a random subset of + negative class centers are selected to compute the margin-based softmax loss, all class + centers are still maintained throughout the whole training process, but only a subset is + selected and updated in each iteration. + .. note:: + When sample rate equal to 1, Partial FC is equal to model parallelism(default sample rate is 1). + Example: + -------- + >>> module_pfc = PartialFC(embedding_size=512, num_classes=8000000, sample_rate=0.2) + >>> for img, labels in data_loader: + >>> embeddings = net(img) + >>> loss = module_pfc(embeddings, labels) + >>> loss.backward() + >>> optimizer.step() + """ + _version = 2 + + def __init__( + self, + rank: int, + world_size: int, + margin_loss: Callable, + embedding_size: int, + num_classes: int, + sample_rate: float = 1.0, + ): + """ + Paramenters: + ----------- + embedding_size: int + The dimension of embedding, required + num_classes: int + Total number of classes, required + sample_rate: float + The rate of negative centers participating in the calculation, default is 1.0. 
+ """ + super(PartialFC_V2, self).__init__() + assert ( + distributed.is_initialized() + ), "must initialize distributed before create this" + self.rank = rank + self.world_size = world_size + + self.dist_cross_entropy = DistCrossEntropy() + self.embedding_size = embedding_size + self.sample_rate: float = sample_rate + + # make num_class divisible by self.world_size for ddp + _num_classes = num_classes // self.world_size * self.world_size + if _num_classes < num_classes: + _num_classes = _num_classes + self.world_size + num_classes = _num_classes + self.num_classes: int = num_classes + + self.num_local: int = num_classes // self.world_size + int( + self.rank < num_classes % self.world_size + ) + + # for i in range(8): + # num_local = (num_classes // self.world_size + int( i < num_classes % self.world_size )) + # class_start = num_classes // self.world_size * i + min( i, num_classes % self.world_size ) + # print(num_local, class_start) + + self.class_start: int = num_classes // self.world_size * self.rank + min( + self.rank, num_classes % self.world_size + ) + self.num_sample: int = int(self.sample_rate * self.num_local) + self.last_batch_size: int = 0 + + self.is_updated: bool = True + self.init_weight_update: bool = True + self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (self.num_local, embedding_size))) + + # margin_loss + if isinstance(margin_loss, Callable): + self.margin_softmax = margin_loss + if isinstance(margin_loss, AdaFaceLoss): + self.register_buffer('batch_mean', torch.ones(1)*(20)) + self.register_buffer('batch_std', torch.ones(1)*100) + else: + raise + + def sample(self, labels, index_positive): + """ + This functions will change the value of labels + Parameters: + ----------- + labels: torch.Tensor + pass + index_positive: torch.Tensor + pass + optimizer: torch.optim.Optimizer + pass + """ + with torch.no_grad(): + positive = torch.unique(labels[index_positive], sorted=True).cuda() + if self.num_sample - positive.size(0) >= 0: + perm = 
torch.rand(size=[self.num_local]).cuda() + perm[positive] = 2.0 + index = torch.topk(perm, k=self.num_sample)[1].cuda() + index = index.sort()[0].cuda() + else: + index = positive + self.weight_index = index + + labels[index_positive] = torch.searchsorted(index, labels[index_positive]) + + return self.weight[self.weight_index] + + def forward( + self, + local_embeddings: torch.Tensor, + local_labels: torch.Tensor, + ): + """ + Parameters: + ---------- + local_embeddings: torch.Tensor + feature embeddings on each GPU(Rank). + local_labels: torch.Tensor + labels on each GPU(Rank). + Returns: + ------- + loss: torch.Tensor + pass + """ + + local_labels.squeeze_() + local_labels = local_labels.long() + + batch_size = local_embeddings.size(0) + if self.last_batch_size == 0: + self.last_batch_size = batch_size + assert self.last_batch_size == batch_size, ( + f"last batch size do not equal current batch size: {self.last_batch_size} vs {batch_size}") + + _gather_embeddings = [ + torch.zeros((batch_size, self.embedding_size), dtype=local_embeddings.dtype, device=local_embeddings.device) + for _ in range(self.world_size) + ] + _gather_labels = [ + torch.zeros(batch_size).long().cuda() for _ in range(self.world_size) + ] + _list_embeddings = AllGather(local_embeddings, *_gather_embeddings) + distributed.all_gather(_gather_labels, local_labels) + + embeddings = torch.cat(_list_embeddings) + labels = torch.cat(_gather_labels) + + labels = labels.view(-1, 1) + index_positive = (self.class_start <= labels) & ( + labels < self.class_start + self.num_local + ) + labels[~index_positive] = -1 + labels[index_positive] -= self.class_start + + if self.sample_rate < 1: + weight = self.sample(labels, index_positive) + else: + weight = self.weight + + # with torch.cuda.amp.autocast(self.fp16): + norms = embeddings.norm(p=2, dim=1, keepdim=True).clamp_min(1e-8) + norm_embeddings = embeddings / norms + + norm_weight_activated = normalize(weight) + logits = linear(norm_embeddings, 
norm_weight_activated) + + logits = logits.clamp(-1, 1) + + if isinstance(self.margin_softmax, CombinedMarginLoss): + logits = self.margin_softmax(logits=logits, labels=labels) + elif isinstance(self.margin_softmax, AdaFaceLoss): + logits, batch_mean, batch_std = self.margin_softmax(logits=logits, labels=labels, norms=norms, + batch_mean=self.batch_mean, + batch_std=self.batch_std) + self.batch_mean.data = batch_mean.data + self.batch_std.data = batch_std.data + else: + raise ValueError('parital FC margin_softmax not supported type') + + loss = self.dist_cross_entropy(logits, labels) + return loss + + +class DistCrossEntropyFunc(torch.autograd.Function): + """ + CrossEntropy loss is calculated in parallel, allreduce denominator into single gpu and calculate softmax. + Implemented of ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf): + """ + + @staticmethod + def forward(ctx, logits: torch.Tensor, label: torch.Tensor): + """ """ + batch_size = logits.size(0) + # for numerical stability + max_logits, _ = torch.max(logits, dim=1, keepdim=True) + # local to global + distributed.all_reduce(max_logits, distributed.ReduceOp.MAX) + logits.sub_(max_logits) + logits.exp_() + sum_logits_exp = torch.sum(logits, dim=1, keepdim=True) + # local to global + distributed.all_reduce(sum_logits_exp, distributed.ReduceOp.SUM) + logits.div_(sum_logits_exp) + index = torch.where(label != -1)[0] + # loss + loss = torch.zeros(batch_size, 1, device=logits.device, dtype=logits.dtype) + loss[index] = logits[index].gather(1, label[index]) + distributed.all_reduce(loss, distributed.ReduceOp.SUM) + ctx.save_for_backward(index, logits, label) + return loss.clamp_min_(1e-30).log_().mean() * (-1) + + @staticmethod + def backward(ctx, loss_gradient): + """ + Args: + loss_grad (torch.Tensor): gradient backward by last layer + Returns: + gradients for each input in forward function + `None` gradients for one-hot label + """ + ( + index, + logits, + label, + ) = ctx.saved_tensors + batch_size = 
logits.size(0) + one_hot = torch.zeros( + size=[index.size(0), logits.size(1)], device=logits.device + ) + one_hot.scatter_(1, label[index], 1) + logits[index] -= one_hot + logits.div_(batch_size) + return logits * loss_gradient.item(), None + + +class DistCrossEntropy(torch.nn.Module): + def __init__(self): + super(DistCrossEntropy, self).__init__() + + def forward(self, logit_part, label_part): + return DistCrossEntropyFunc.apply(logit_part, label_part) + + +class AllGatherFunc(torch.autograd.Function): + """AllGather op with gradient backward""" + + @staticmethod + def forward(ctx, tensor, *gather_list): + gather_list = list(gather_list) + distributed.all_gather(gather_list, tensor) + return tuple(gather_list) + + @staticmethod + def backward(ctx, *grads): + grad_list = list(grads) + rank = distributed.get_rank() + grad_out = grad_list[rank] + + dist_ops = [ + distributed.reduce(grad_out, rank, distributed.ReduceOp.SUM, async_op=True) + if i == rank + else distributed.reduce( + grad_list[i], i, distributed.ReduceOp.SUM, async_op=True + ) + for i in range(distributed.get_world_size()) + ] + for _op in dist_ops: + _op.wait() + + grad_out *= len(grad_list) # cooperate with distributed loss function + return (grad_out, *[None for _ in range(len(grad_list))]) + + +AllGather = AllGatherFunc.apply diff --git a/cvlface/research/recognition/code/run_v1/config.py b/cvlface/research/recognition/code/run_v1/config.py new file mode 100644 index 0000000000000000000000000000000000000000..d3ebf5410b07c88279d19d2b492cbd1d67012459 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/config.py @@ -0,0 +1,100 @@ +from hydra import compose, initialize +import hydra +import omegaconf +from omegaconf import OmegaConf +import os +from time import gmtime, strftime +from dataclasses import dataclass + + +@dataclass +class Config: + trainers: omegaconf.dictconfig.DictConfig + optims: omegaconf.dictconfig.DictConfig + models: omegaconf.dictconfig.DictConfig + dataset: 
omegaconf.dictconfig.DictConfig + data_augs: omegaconf.dictconfig.DictConfig + losses: omegaconf.dictconfig.DictConfig + classifiers: omegaconf.dictconfig.DictConfig + aligners: omegaconf.dictconfig.DictConfig + pipelines: omegaconf.dictconfig.DictConfig + evaluations: omegaconf.dictconfig.DictConfig + pefts: omegaconf.dictconfig.DictConfig + + +def init(root): + initialize(config_path="./", job_name="configs") + args_parser = hydra._internal.utils.get_args_parser() + args = args_parser.parse_args() + cfg = compose(config_name="base", overrides=args.overrides) + + + # writing config's path + OmegaConf.set_struct(cfg, False) + base_cfg = {} + for row in OmegaConf.load('base.yaml')['defaults']: + for key, val in row.items(): + base_cfg[key] = val + for key in Config.__dataclass_fields__.keys(): + has_override = [ + x.split('=')[1] if '.yaml' in x.split('=')[1] else x.split('=')[1] + '.yaml' + for x in args.overrides if key + '=' in x + ] + if has_override: + print('has_override', has_override[0]) + getattr(cfg, key).yaml_path = '/'+has_override[0] + else: + getattr(cfg, key).yaml_path = '/'+base_cfg[key] + OmegaConf.set_struct(cfg, True) + + # make output_dir + output_dir = prepare_output_dir(cfg, root) + cfg.trainers.output_dir = output_dir + os.makedirs(cfg.trainers.output_dir, exist_ok=True) + + return cfg + + +def parse_config_string(config_string): + lst = config_string.split('.') + assert len(lst) <= 2 + if len(lst) == 1: + return 'configs/' + lst[0] + '.yaml' + else: + return lst[0] + '/configs/' + lst[1] + '.yaml' + +def load_yaml(config_string, directory='models'): + yaml_path = os.path.join(directory, parse_config_string(config_string)) + assert os.path.exists(yaml_path), yaml_path + cfg = OmegaConf.load(yaml_path) + cfg.yaml_path = yaml_path + return cfg + +def is_used_directory(directory): + return os.path.isdir(directory) and os.path.exists(os.path.join(directory, 'train.py')) + +def prepare_output_dir(cfg, root): + # set working dir + cur_time = 
strftime("%m-%d_0", gmtime()) + task = os.path.basename(os.path.dirname(__file__)) + cfg.trainers.task = task + output_dir = os.path.join(root, 'research/recognition/experiments', cfg.trainers.task, cfg.trainers.prefix + "_" + cur_time) + if is_used_directory(output_dir): + while True: + cur_exp_number = int(output_dir[-2:].replace('_', "")) + output_dir = output_dir[:-2] + "_{}".format(cur_exp_number+1) + # replace repeating _ with _ + output_dir = output_dir.replace('__', '_') + if not is_used_directory(output_dir): + break + + if cfg.pipelines.resume: + print('resume ', cfg.pipelines.resume) + assert os.path.isdir(cfg.pipelines.resume) + if '/checkpoints/' in cfg.pipelines.resume: + output_dir = cfg.pipelines.resume.split('/checkpoints/')[0] + else: + output_dir = cfg.pipelines.resume + + os.makedirs(output_dir, exist_ok=True) + return output_dir diff --git a/cvlface/research/recognition/code/run_v1/data_augs/__init__.py b/cvlface/research/recognition/code/run_v1/data_augs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8ce6ad00d89521810735b91b9b09887248741f59 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/data_augs/__init__.py @@ -0,0 +1,14 @@ +from .basic_augmenter import BasicAugmenter + +def make_augmenter(augmentation_version, aug_params): + if augmentation_version == 'basic': + augmenter = BasicAugmenter(crop_augmentation_prob=aug_params.crop_augmentation_prob, + photometric_augmentation_prob=aug_params.photometric_augmentation_prob, + low_res_augmentation_prob=aug_params.low_res_augmentation_prob, + ) + elif augmentation_version == 'gridsample': + from .gridsample_augmenter import GridSampleAugmenter + augmenter = GridSampleAugmenter(aug_params, input_size=112) + else: + raise ValueError('not correct augmentation version') + return augmenter diff --git a/cvlface/research/recognition/code/run_v1/data_augs/aug_utils/sanity_check.py b/cvlface/research/recognition/code/run_v1/data_augs/aug_utils/sanity_check.py 
new file mode 100644 index 0000000000000000000000000000000000000000..da6d41bb7a52cd9efe0a6eded70b4e0a86574e0a --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/data_augs/aug_utils/sanity_check.py @@ -0,0 +1,58 @@ +from data_augs.aug_utils import transform_cv2 +from data_augs.aug_utils import transform_torch +import os +import numpy as np +from PIL import Image +import random + +if __name__ == '__main__': + + # set seed + np.random.seed(0) + random.seed(0) + + root = '/data/data/faces/casia_webface/raw/CASIA-WebFace_raw_aligned_mtcnn/3010926' + images = [os.path.join(root, name) for name in os.listdir(root)] + sample_images = [np.array(Image.open(image_path)) for image_path in images] + + # original shape + os.makedirs('/mckim/temp/temp_aug_v12', exist_ok=True) + os.makedirs('/mckim/temp/temp_aug_v12_torch', exist_ok=True) + output_width, output_height = 160, 160 + + os.makedirs('/mckim/temp/temp_aug_v12_determ', exist_ok=True) + for i, image in enumerate(sample_images): + + params = transform_cv2.sample_param(scale_min=0.8, scale_max=1.2, rot_prob=1.0, max_rot=45, hflip_prob=0.5, extra_offset=0.5) + + mat_cv = transform_cv2.generate_transform_cv2(image, output_width, output_height, **params) + output_cv = transform_cv2.augment_cv2_deterministic(image.copy(), mat_cv, output_width, output_height) + output_cv.save('/mckim/temp/temp_aug_v12_determ/{}_cv.png'.format(i)) + + mat_torch = transform_torch.generate_transform_torch(image, output_width, output_height, **params) + output_torch = transform_torch.augment_torch_deterministic(image.copy(), mat_torch, output_width, output_height) + output_torch.save('/mckim/temp/temp_aug_v12_determ/{}_torch.png'.format(i)) + + + # time it (torch is slower) + import time + + start = time.time() + for _ in range(1000): + for i, image in enumerate(sample_images): + params = transform_cv2.sample_param(scale_min=0.8, scale_max=1.2, rot_prob=1.0, max_rot=45, hflip_prob=0.5, extra_offset=0.5) + mat_cv = 
transform_cv2.generate_transform_cv2(image, output_width, output_height, **params) + output_cv = transform_cv2.augment_cv2_deterministic(image.copy(), mat_cv, output_width, output_height) + + end = time.time() + print('cv2 time: {}'.format(end - start)) + + start = time.time() + for _ in range(1000): + for i, image in enumerate(sample_images): + params = transform_cv2.sample_param(scale_min=0.8, scale_max=1.2, rot_prob=1.0, max_rot=45, hflip_prob=0.5, extra_offset=0.5) + mat_torch = transform_torch.generate_transform_torch(image, output_width, output_height, **params) + output_torch = transform_torch.augment_torch_deterministic(image.copy(), mat_torch, output_width, output_height) + + end = time.time() + print('torch time: {}'.format(end - start)) diff --git a/cvlface/research/recognition/code/run_v1/data_augs/aug_utils/transform_cv2.py b/cvlface/research/recognition/code/run_v1/data_augs/aug_utils/transform_cv2.py new file mode 100644 index 0000000000000000000000000000000000000000..6a9c73f6a01e00e0db8ca3040c557c5bf8c04ada --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/data_augs/aug_utils/transform_cv2.py @@ -0,0 +1,137 @@ +from PIL import Image +import numpy as np +import os +import cv2 +import torch +import random + + +def sample_param_debug(): + result = { + 'px': 0.5, + 'py': 0.5, + 'signx': 1, + 'signy': 1, + 'scale': 1, + 'angle': 0, + 'hflip': 0, + 'extra_offset_val': 0, + } + return result + + +def sample_param(scale_min=1.0, scale_max=1.0, rot_prob=0.0, max_rot=0, hflip_prob=0.0, extra_offset=0.0): + px = np.random.uniform(0, 1) + py = np.random.uniform(0, 1) + signx = (-1) ** np.random.randint(0, 2) + signy = (-1) ** np.random.randint(0, 2) + scale = np.random.uniform(scale_min, scale_max) + + if random.random() < rot_prob: + angle = np.random.uniform(-max_rot, max_rot) + else: + angle = 0 + + if random.random() < hflip_prob: + hflip = 1 + else: + hflip = 0 + + extra_offset_val = np.random.uniform(0, extra_offset) + result = { + 'px': px, + 
'py': py, + 'signx': signx, + 'signy': signy, + 'scale': scale, + 'angle': angle, + 'hflip': hflip, + 'extra_offset_val': extra_offset_val, + } + return result + + +def generate_transform_cv2(image_np, output_width, output_height, px, py, signx, signy, scale, angle, hflip, extra_offset_val): + assert image_np.ndim == 3 + orig_shape = image_np.shape + + transforms = [] + + # origin + center_x = orig_shape[1] // 2 + center_y = orig_shape[0] // 2 + + # rotation + if angle != 0: + translation = np.asarray([[1, 0, -center_x], + [0, 1, -center_y], + [0, 0, 1]], dtype=np.float32) + transforms.append(translation) + + angle = np.deg2rad(angle) + rotation = np.asarray([[np.cos(-angle), -np.sin(-angle), 0], + [np.sin(-angle), np.cos(-angle), 0], + [0, 0, 1]], dtype=np.float32) + transforms.append(rotation) + translation = np.asarray([[1, 0, center_x], + [0, 1, center_y], + [0, 0, 1]], dtype=np.float32) + transforms.append(translation) + ######################### + + # possible offset without going out of bounds + padx = (output_width // 2 / scale - orig_shape[1] // 2) + pady = (output_height // 2 / scale - orig_shape[0] // 2) + # padx = np.abs(padx) + # pady = np.abs(pady) + padx = padx + extra_offset_val * center_x + pady = pady + extra_offset_val * center_y + + padx = padx * signx + pady = pady * signy + + # patch center + tx = center_x * (1-px) + (center_x + padx) * px + ty = center_y * (1-py) + (center_y + pady) * py + + translation = np.asarray([[1, 0, -tx], + [0, 1, -ty], + [0, 0, 1]], dtype=np.float32) + transforms.append(translation) + + # scale + scale_matrix = np.asarray([[scale, 0, 0], + [0, scale, 0], + [0, 0, 1]], dtype=np.float32) + transforms.append(scale_matrix) + + + # horizontal flip + if hflip: + transforms.append(np.asarray([[-1, 0, 0], + [0, 1, 0], + [0, 0, 1]], dtype=np.float32)) + + + # move to center + reverse_translation = np.asarray([[1, 0, output_width / 2], + [0, 1, output_height / 2], + [0, 0, 1]], dtype=np.float32) + 
transforms.append(reverse_translation) + + # aggregate + final_transform = np.eye(3) + for t in transforms: + final_transform = np.matmul(t, final_transform) + + return final_transform + + +def augment_cv2_deterministic(image_np, transform, output_width, output_height): + image_t = cv2.warpPerspective(image_np, + transform, + (output_width, output_height), + borderValue=0) + image_t = Image.fromarray(image_t) + return image_t + diff --git a/cvlface/research/recognition/code/run_v1/data_augs/aug_utils/transform_torch.py b/cvlface/research/recognition/code/run_v1/data_augs/aug_utils/transform_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b0dd24837f4281ebd24fb58d30083dd4d15ffd --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/data_augs/aug_utils/transform_torch.py @@ -0,0 +1,121 @@ +from PIL import Image +import numpy as np +import os +import cv2 +import torch +import random + + +def sample_param_debug(): + result = { + 'px': 0.5, + 'py': 0.5, + 'signx': 1, + 'signy': 1, + 'scale': 1, + 'angle': 0, + 'hflip': 0, + 'extra_offset_val': 0, + } + return result + + +def sample_param(scale_min=1.0, scale_max=1.0, rot_prob=0.0, max_rot=0, hflip_prob=0.0, extra_offset=0.0, **kwargs, ): + px = np.random.uniform(0, 1) + py = np.random.uniform(0, 1) + signx = (-1) ** np.random.randint(0, 2) + signy = (-1) ** np.random.randint(0, 2) + scale = np.random.uniform(scale_min, scale_max) + + if random.random() < rot_prob: + angle = np.random.uniform(-max_rot, max_rot) + else: + angle = 0 + + if random.random() < hflip_prob: + hflip = 1 + else: + hflip = 0 + + extra_offset_val = np.random.uniform(0, extra_offset) + result = { + 'px': px, + 'py': py, + 'signx': signx, + 'signy': signy, + 'scale': scale, + 'angle': angle, + 'hflip': hflip, + 'extra_offset_val': extra_offset_val, + } + return result + + + +def generate_transform_torch(image_np, output_width, output_height, px, py, signx, signy, scale, angle, hflip, extra_offset_val): + assert 
image_np.ndim == 3 + orig_shape = image_np.shape + orig_width = orig_shape[1] + orig_height = orig_shape[0] + + # origin + center_x = 0 + center_y = 0 + + # extreme + extreme_x = 1 - (scale * 1 * orig_width / output_width) + extreme_y = 1 - (scale * 1 * orig_height / output_height) + + extreme_x = extreme_x + extra_offset_val * (scale * 1 * orig_width / output_width) + extreme_y = extreme_y + extra_offset_val * (scale * 1 * orig_height / output_height) + + + tx = center_x * (1 - px) + extreme_x * px + ty = center_y * (1 - py) + extreme_y * py + tx = tx * signx + ty = ty * signy + + transforms = [] + + # horizontal flip + if hflip: + transforms.append(np.asarray([[-1, 0, 0], + [0, 1, 0], + [0, 0, 1]], dtype=np.float32)) + + translation = np.asarray([[1, 0, +tx], # -1 to 1 + [0, 1, +ty], # -1 to 1 + [0, 0, 1]], dtype=np.float32) + transforms.append(translation) + + # scale + scale_x = output_width / orig_width / scale + scale_y = output_height / orig_height / scale + scale_matrix = np.asarray([[scale_x, 0, 0], + [0, scale_y, 0], + [0, 0, 1]], dtype=np.float32) + transforms.append(scale_matrix) + + # rotation + angle = np.deg2rad(angle) + rotation = np.asarray([[np.cos(angle), -np.sin(angle), 0], + [np.sin(angle), np.cos(angle), 0], + [0, 0, 1]], dtype=np.float32) + transforms.append(rotation) + + # aggregate + final_transform = np.eye(3) + for t in transforms: + final_transform = np.matmul(t, final_transform) + final_transform = torch.from_numpy(final_transform).float()[None, :2, :] + return final_transform + + +def augment_torch_deterministic(image_np, transform, output_width, output_height): + grid = torch.nn.functional.affine_grid(transform, + [1, image_np.shape[2], output_height, output_width], align_corners=True) + image_torch = torch.from_numpy(image_np).float().permute(2, 0, 1)[None, :, :, :] + image_t = torch.nn.functional.grid_sample(image_torch, grid, align_corners=True) + image_t = image_t.permute(0, 2, 3, 1).numpy()[0].astype(np.uint8) + image_t = 
Image.fromarray(image_t) + return image_t diff --git a/cvlface/research/recognition/code/run_v1/data_augs/basic_augmenter.py b/cvlface/research/recognition/code/run_v1/data_augs/basic_augmenter.py new file mode 100644 index 0000000000000000000000000000000000000000..a4d6142bae464f8fe99618154b62924216408d8a --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/data_augs/basic_augmenter.py @@ -0,0 +1,116 @@ +import numpy as np +import cv2 +from torchvision.transforms import functional as F +from PIL import Image +from torchvision import transforms + + +class BasicAugmenter(): + + def __init__(self, crop_augmentation_prob, photometric_augmentation_prob, low_res_augmentation_prob): + self.crop_augmentation_prob = crop_augmentation_prob + self.photometric_augmentation_prob = photometric_augmentation_prob + self.low_res_augmentation_prob = low_res_augmentation_prob + self.random_resized_crop = transforms.RandomResizedCrop(size=(112, 112), + scale=(0.2, 1.0), + ratio=(0.75, 1.3333333333333333)) + self.photometric = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0) + + def augment(self, sample): + + # crop with zero padding augmentation + if np.random.random() < self.crop_augmentation_prob: + # RandomResizedCrop augmentation + sample, crop_ratio = self.crop_augment(sample) + + # low resolution augmentation + if np.random.random() < self.low_res_augmentation_prob: + # low res augmentation + img_np, resize_ratio = self.low_res_augmentation(np.array(sample)) + sample = Image.fromarray(img_np.astype(np.uint8)) + + # photometric augmentation + if np.random.random() < self.photometric_augmentation_prob: + sample = self.photometric_augmentation(sample) + + # random flip + if np.random.random() < 0.5: + sample = F.hflip(sample) + + return sample + + def crop_augment(self, sample): + new = np.zeros_like(np.array(sample)) + if hasattr(F, '_get_image_size'): + orig_W, orig_H = F._get_image_size(sample) + else: + # torchvision 0.11.0 and above + orig_W, 
orig_H = F.get_image_size(sample) + i, j, h, w = self.random_resized_crop.get_params(sample, + self.random_resized_crop.scale, + self.random_resized_crop.ratio) + cropped = F.crop(sample, i, j, h, w) + new[i:i+h,j:j+w, :] = np.array(cropped) + sample = Image.fromarray(new.astype(np.uint8)) + crop_ratio = min(h, w) / max(orig_H, orig_W) + return sample, crop_ratio + + def low_res_augmentation(self, img): + # resize the image to a small size and enlarge it back + img_shape = img.shape + side_ratio = np.random.uniform(0.2, 1.0) + small_side = int(side_ratio * img_shape[0]) + interpolation = np.random.choice( + [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]) + small_img = cv2.resize(img, (small_side, small_side), interpolation=interpolation) + interpolation = np.random.choice( + [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]) + aug_img = cv2.resize(small_img, (img_shape[1], img_shape[0]), interpolation=interpolation) + + return aug_img, side_ratio + + def photometric_augmentation(self, sample): + fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \ + self.photometric.get_params(self.photometric.brightness, self.photometric.contrast, + self.photometric.saturation, self.photometric.hue) + for fn_id in fn_idx: + if fn_id == 0 and brightness_factor is not None: + sample = F.adjust_brightness(sample, brightness_factor) + elif fn_id == 1 and contrast_factor is not None: + sample = F.adjust_contrast(sample, contrast_factor) + elif fn_id == 2 and saturation_factor is not None: + sample = F.adjust_saturation(sample, saturation_factor) + + return sample + + + +def main(): + from PIL import Image, ImageDraw + import torch + + image = Image.open('/data/data/faces/ms1mv2_subset_images/84946/5770863.jpg') + # draw a square box on the image + image_draw = ImageDraw.Draw(image) + image_draw.rectangle((10, 10, 110, 110), outline='red') + image_draw.rectangle((0, 0, 120, 
120), outline='blue') + + augmenter = BasicAugmenter(0.2, 0.2, 0.2) + # make a grid 10x10 + grids = [] + for i in range(10): + grid = [] + for j in range(10): + align_input_sample = augmenter.augment(image) + grid.append(align_input_sample) + grids.append(grid) + # save the grid + grid_image = Image.new('RGB', (1120, 1120)) + for i in range(10): + for j in range(10): + grid_image.paste(grids[i][j], (112 * j, 112 * i)) + grid_image.save(f'/mckim/temp/BasicAugmenter.jpg') + + +if __name__ == '__main__': + main() diff --git a/cvlface/research/recognition/code/run_v1/data_augs/configs/basic_v1.yaml b/cvlface/research/recognition/code/run_v1/data_augs/configs/basic_v1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7fdd868566af68b1d243db5582846e28d099ec00 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/data_augs/configs/basic_v1.yaml @@ -0,0 +1,5 @@ + +augmentation_version: 'basic' +aug_params: { 'crop_augmentation_prob': 0.2, + 'photometric_augmentation_prob': 0.2, + 'low_res_augmentation_prob': 0.2} diff --git a/cvlface/research/recognition/code/run_v1/data_augs/configs/gridsample_v1.yaml b/cvlface/research/recognition/code/run_v1/data_augs/configs/gridsample_v1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7aedc487d56fd9e873cebc76b22c103a66bd7c4d --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/data_augs/configs/gridsample_v1.yaml @@ -0,0 +1,15 @@ +augmentation_version: 'gridsample' +aug_params: { 'scale_min': 0.8, + 'scale_max': 1.2, + 'rot_prob': 0.2, + 'max_rot': 20, + 'hflip_prob': 0.5, + 'extra_offset': 0.1, + 'photometric_num_ops': 2, + 'photometric_magnitude': 14, + 'photometric_magnitude_offset': 9, + 'photometric_num_magnitude_bins': 31, + 'blur_magnitude': 1.0, + 'blur_prob': 0.2, + 'cutout_prob': 0.2,} +disable_aug_during_warmup: False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/data_augs/gridsample_augmenter.py 
class GridSampleAugmenter():
    '''
    Augment an input image while keeping track of the corresponding theta
    for grid sampling.

    ``augment`` returns ``(image, theta)`` where theta can be used as::

        from torchvision.transforms import ToTensor
        image_tensor = ToTensor()(image_pil).unsqueeze(0)
        align_input_theta = theta.unsqueeze(0)
        b, c, h, w = image_tensor.shape
        sample_grid = torch.nn.functional.affine_grid(align_input_theta, [b, c, h, w], align_corners=True)
        image_tensor_aug = torch.nn.functional.grid_sample(image_tensor, sample_grid, align_corners=True)
    '''

    def __init__(self, aug_params, input_size=112):
        """Build the photometric, blur and cutout sub-augmenters.

        Args:
            aug_params: mapping holding the geometric keys read in ``augment``
                plus the ``photometric_*``, ``blur_*`` and ``cutout_prob`` keys.
            input_size: side length passed to the geometric transforms.
        """
        print('GridSampleAugmenter')
        self.aug_params = aug_params
        self.input_size = input_size
        self.photo_aug = PhotometricRandAugment(num_ops=self.aug_params['photometric_num_ops'],
                                                magnitude=self.aug_params['photometric_magnitude'],
                                                magnitude_offset=self.aug_params['photometric_magnitude_offset'],
                                                num_magnitude_bins=self.aug_params['photometric_num_magnitude_bins'])
        self.blur_aug = BlurAugmenter(magnitude=self.aug_params['blur_magnitude'], prob=self.aug_params['blur_prob'])
        self.cutout = CutoutAugment(self.aug_params['cutout_prob'])

    def augment(self, sample):
        """Return ``(augmented_image, theta)`` for one input image.

        The same sampled geometric parameters feed both the cv2 transform
        (which produces the image) and the torch transform (which produces
        theta). Cutout, blur and photometric augmentations are then applied
        to the image only, in that order.
        """
        image_np = np.array(sample)

        # geometric augmentation
        params = transform_torch.sample_param(
            scale_min=self.aug_params['scale_min'],
            scale_max=self.aug_params['scale_max'],
            rot_prob=self.aug_params['rot_prob'],
            max_rot=self.aug_params['max_rot'],
            hflip_prob=self.aug_params['hflip_prob'],
            extra_offset=self.aug_params['extra_offset'],
        )
        mat = transform_cv2.generate_transform_cv2(image_np, self.input_size, self.input_size, **params)
        aug_sample = transform_cv2.augment_cv2_deterministic(image_np, mat, self.input_size, self.input_size)

        # corresponding theta
        align_input_theta = transform_torch.generate_transform_torch(image_np, self.input_size, self.input_size, **params)
        align_input_theta = align_input_theta.squeeze(0)

        # cutout
        aug_sample = self.cutout.augment(aug_sample)

        # blur
        blur_params = self.blur_aug.sample_param()
        aug_sample = self.blur_aug.augment(aug_sample, param=blur_params)

        # photometric
        photo_params = self.photo_aug.sample_param()
        aug_sample = self.photo_aug.augment(aug_sample, param=photo_params)

        return aug_sample, align_input_theta


class CutoutAugment():
    """Randomly black out parts of an image.

    With probability ``cutout_prob`` the image is modified: in 5% of those
    cases by coarse dropout holes, otherwise by keeping one random rectangle
    and zeroing everything outside of it.
    """

    def __init__(self, cutout_prob):
        self.cutout_prob = cutout_prob
        self.dropout = A.CoarseDropout(max_holes=20,  # maximum number of regions to zero out
                                       max_height=16,  # maximum height of a hole
                                       max_width=16,  # maximum width of a hole
                                       min_holes=12,  # minimum number of regions to zero out
                                       min_height=None,  # minimum height of a hole (None: library default)
                                       min_width=None,  # minimum width of a hole (None: library default)
                                       fill_value=0,  # value for dropped pixels
                                       mask_fill_value=None,  # fill value for dropped pixels in mask
                                       always_apply=False,
                                       p=1.0
                                       )
        # Only used as a sampler of crop rectangles (get_params) in this class.
        self.random_resized_crop = transforms.RandomResizedCrop(size=(112, 112),
                                                                scale=(0.2, 1.0),
                                                                ratio=(0.75, 1.3333333333333333))

    def augment(self, sample):
        """Return ``sample`` unchanged or with regions blacked out."""
        if np.random.random() >= self.cutout_prob:
            return sample

        if np.random.random() < 0.05:
            # coarse dropout is kept rare because it does not look natural
            return Image.fromarray(self.dropout(image=np.array(sample))['image'])

        # Keep one random rectangle of the image, zero out the rest.
        new = np.zeros_like(np.array(sample))
        i, j, h, w = self.random_resized_crop.get_params(sample,
                                                         self.random_resized_crop.scale,
                                                         self.random_resized_crop.ratio)
        cropped = F.crop(sample, i, j, h, w)
        new[i:i + h, j:j + w, :] = np.array(cropped)
        return Image.fromarray(new.astype(np.uint8))


class PhotometricRandAugment():
    """RandAugment-style sampler restricted to photometric operations.

    Parameters are sampled separately (``sample_param``) from their
    application (``augment``) so a sampled recipe can be reused.
    """

    def __init__(self,
                 num_ops: int = 2,
                 magnitude: int = 9,
                 magnitude_offset: int = 4,
                 num_magnitude_bins: int = 31) -> None:
        """
        Args:
            num_ops: number of operations sampled per recipe.
            magnitude: centre of the sampled magnitude bin index.
            magnitude_offset: half-width of the sampled bin-index range.
            num_magnitude_bins: number of bins each magnitude range is split into.
        """
        self.num_ops = num_ops
        self.magnitude = magnitude
        self.magnitude_offset = magnitude_offset
        self.num_magnitude_bins = num_magnitude_bins
        # Build the table once and derive the names from it.
        self.op_meta = self._augmentation_space(self.num_magnitude_bins)
        self.op_names = list(self.op_meta.keys())

    def _augmentation_space(self, num_bins: int) -> Dict[str, Tuple[Tensor, bool]]:
        """Return ``op_name -> (magnitudes, signed)`` for every supported op."""
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Saturate": (torch.linspace(0.0, 0.9, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
            "Equalize": (torch.tensor(0.0), False),
            "Grayscale": (torch.tensor(0.0), False),
        }

    def apply_op(self, img: Tensor, op_name: str, magnitude: float):
        """Apply one named operation to ``img`` and return the result.

        Raises:
            ValueError: if ``op_name`` is not a supported operation.
        """
        if op_name == "Brightness":
            img = F.adjust_brightness(img, 1.0 + magnitude)
        elif op_name == "Saturate":
            img = F.adjust_saturation(img, 1.0 + magnitude)
        elif op_name == "Contrast":
            img = F.adjust_contrast(img, 1.0 + magnitude)
        elif op_name == "Sharpness":
            img = F.adjust_sharpness(img, 1.0 + magnitude)
        elif op_name == "Equalize":
            img = F.equalize(img)
        elif op_name == 'Grayscale':
            img = F.to_grayscale(img, num_output_channels=3)
        elif op_name == "Identity":
            pass
        else:
            raise ValueError("The provided operator {} is not recognized.".format(op_name))
        return img

    def sample_param(self):
        """Sample ``num_ops`` pairs ``(op_name, magnitude)``.

        Returns:
            List of ``(op_name, magnitude)`` tuples; ``magnitude`` is 0.0 for
            operations without a magnitude range and may be negative for
            signed operations.
        """
        rare_ops = ('Equalize', 'Grayscale')
        low = self.magnitude - self.magnitude_offset
        high = self.magnitude + self.magnitude_offset
        ops = []
        for _ in range(self.num_ops):
            # random sample op
            op_name = np.random.choice(self.op_names)
            # reduce probability of the rare ops by re-drawing up to twice
            if op_name in rare_ops:
                op_name = np.random.choice(self.op_names)
            if op_name in rare_ops:
                op_name = np.random.choice(self.op_names)

            magnitudes, signed = self.op_meta[op_name]
            # random sample magnitude; np.random.randint rejects an empty
            # range, so a non-positive offset falls back to the centre bin.
            if high > low:
                magnitude_idx = np.random.randint(low, high)
            else:
                magnitude_idx = self.magnitude
            magnitude_idx = int(np.clip(magnitude_idx, 0, self.num_magnitude_bins - 1))
            if magnitudes.ndim > 0:
                magnitude = float(magnitudes[magnitude_idx].item())
            else:
                magnitude = 0.0
            if signed and torch.randint(2, (1,)):
                magnitude *= -1.0
            ops.append((op_name, magnitude))
        return ops

    def augment(self, img: Tensor, param=None) -> Tensor:
        """Apply a recipe to ``img``.

        Args:
            img (PIL Image or Tensor): image to be transformed.
            param: recipe from ``sample_param``; sampled here when ``None``.

        Returns:
            PIL Image or Tensor: transformed image.
        """
        if param is None:
            param = self.sample_param()
        for op_name, magnitude in param:
            img = self.apply_op(img, op_name, magnitude)

        return img


class BlurAugmenter():
    """Blur an image by averaging, Gaussian blur, motion blur or down/up-scaling."""

    def __init__(self, magnitude=0.5, prob=0.2):
        """
        Args:
            magnitude: strength multiplier for the sampled blur parameters.
            prob: probability that ``sample_param`` selects a blur at all.
        """
        self.magnitude = magnitude
        self.prob = prob

    def sample_param(self):
        """Sample a blur recipe.

        Returns:
            ``['skip']`` or a list whose first item is the blur method,
            followed by that method's parameters.
        """
        if np.random.random() >= self.prob:
            return ['skip']

        # 'resize' is listed eight times so it is drawn far more often;
        # 'motion' is not listed and is therefore never sampled here.
        blur_method = np.random.choice(['avg', 'gaussian',
                                        'resize', 'resize', 'resize', 'resize',
                                        'resize', 'resize', 'resize', 'resize'])
        if blur_method == 'avg':
            # Keep the upper bound above the lower one: randint(1, 1) raises
            # for small magnitudes (same guard style as the motion branch).
            k = np.random.randint(1, max(int(10 * self.magnitude), 2))
            param = [blur_method, k]
        elif blur_method == 'gaussian':
            sigma = np.random.random() * 4 * self.magnitude
            param = [blur_method, sigma]
        elif blur_method == 'motion':
            k = np.random.randint(5, max(int(10 * self.magnitude), 6))
            angle = np.random.randint(-45, 45)
            direction = np.random.random() * 2 - 1
            param = [blur_method, k, angle, direction]
        elif blur_method == 'resize':
            interpolation_modes = [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_AREA,
                                   cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]
            side_ratio = np.random.uniform(1.0 - 0.8 * self.magnitude, 1.0)
            interpolation1 = np.random.choice(interpolation_modes)
            interpolation2 = np.random.choice(interpolation_modes)
            param = [blur_method, side_ratio, [interpolation1, interpolation2]]
        else:
            raise ValueError('not a correct blur')

        return param

    def augment(self, sample, param=None):
        """Apply a blur recipe to ``sample`` and return a PIL image.

        ``sample`` is returned unchanged when the recipe is ``['skip']``.

        Raises:
            ValueError: if the recipe names an unknown blur method.
        """
        if param is None:
            param = self.sample_param()
        blur_method = param[0]
        if blur_method == 'skip':
            return sample

        if blur_method == 'avg':
            blur_method, k = param
            avg_blur = iaa.AverageBlur(k=k)  # max 10
            blurred = avg_blur(image=np.array(sample))
        elif blur_method == 'gaussian':
            blur_method, sigma = param
            gaussian_blur = iaa.GaussianBlur(sigma=sigma)  # 4 is max
            blurred = gaussian_blur(image=np.array(sample))
        elif blur_method == 'motion':
            blur_method, k, angle, direction = param
            motion_blur = iaa.MotionBlur(k=k, angle=angle, direction=direction)  # k 20 max angle:-45 45, dir:-1 1
            blurred = motion_blur(image=np.array(sample))
        elif blur_method == 'resize':
            blur_method, side_ratio, interpolation = param
            blurred = self.low_res_augmentation(np.array(sample), side_ratio, interpolation)
        else:
            raise ValueError('not a correct blur')

        sample = Image.fromarray(blurred.astype(np.uint8))

        return sample

    def low_res_augmentation(self, img, side_ratio, interpolation):
        """Resize ``img`` to a small square and enlarge it back.

        Args:
            img: array exposing ``.shape`` and accepted by ``cv2.resize``.
            side_ratio: small side as a fraction of ``img.shape[0]``.
            interpolation: pair ``(shrink_mode, enlarge_mode)``.
        """
        img_shape = img.shape
        # At least one pixel: a ratio near (or below) zero would otherwise
        # request an empty image, which cv2.resize rejects.
        small_side = max(1, int(side_ratio * img_shape[0]))
        small_img = cv2.resize(img, (small_side, small_side), interpolation=interpolation[0])
        aug_img = cv2.resize(small_img, (img_shape[1], img_shape[0]), interpolation=interpolation[1])
        return aug_img
def get_train_dataset(dataset_cfg, train_transform, aug_cfg, local_rank=0):
    """Build the training dataset described by ``dataset_cfg``.

    The dataset kind is chosen in this order:

    1. ``dataset_cfg.rec == "synthetic"``: a ``SyntheticDataset``.
    2. ``train.rec`` and ``train.idx`` exist under ``data_root/rec``: one of
       the MXNet RecordIO datasets, selected by the augmentation version and
       by ``repeated_sampling_cfg``; optionally wrapped in a ``SubsetDataset``
       when ``resample_dataset`` is set.
    3. ``dataset_cfg.rec == ''``: error.
    4. Otherwise: a torchvision ``ImageFolder`` rooted at ``data_root/rec``.

    Args:
        dataset_cfg: config with ``data_root``, ``rec``, ``color_space`` and
            the optional fields read below.
        train_transform: transform attached to the dataset.
        aug_cfg: config with ``augmentation_version`` and ``aug_params``.
        local_rank: rank forwarded to the RecordIO datasets.

    Returns:
        Tuple ``(train_set, label_mapping)``; ``label_mapping`` is ``None``
        unless the dataset was resampled.

    Raises:
        ValueError: if ``dataset_cfg.rec`` is empty and no RecordIO files
            were found.
    """
    root_dir = os.path.join(dataset_cfg.data_root, dataset_cfg.rec)
    rec = os.path.join(root_dir, 'train.rec')
    idx = os.path.join(root_dir, 'train.idx')

    # Synthetic
    if dataset_cfg.rec == "synthetic":
        from .base_dataset import SyntheticDataset
        train_set = SyntheticDataset(dataset_cfg.num_classes, dataset_cfg.num_image)
        label_mapping = None

    # Mxnet RecordIO
    elif os.path.exists(rec) and os.path.exists(idx):
        if aug_cfg.augmentation_version == 'none':
            from .base_dataset import MXFaceDataset
            assert dataset_cfg.repeated_sampling_cfg is None
            train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank)
        else:
            if dataset_cfg.repeated_sampling_cfg is not None:
                if dataset_cfg.repeated_sampling_cfg.ldmk_path:
                    from .repeated_dataset_with_ldmk_theta import RepeatedWithLdmkThetaMXDataset
                    # repeated sampling + augmentation + ldmk
                    train_set = RepeatedWithLdmkThetaMXDataset(root_dir=root_dir, local_rank=local_rank,
                                                               augmentation_version=aug_cfg.augmentation_version,
                                                               aug_params=aug_cfg.aug_params,
                                                               repeated_sampling_cfg=dataset_cfg.repeated_sampling_cfg)
                else:
                    from .repeated_dataset import RepeatedSamplingMXDataset
                    # repeated sampling + augmentation
                    train_set = RepeatedSamplingMXDataset(root_dir=root_dir, local_rank=local_rank,
                                                          augmentation_version=aug_cfg.augmentation_version,
                                                          aug_params=aug_cfg.aug_params,
                                                          repeated_sampling_cfg=dataset_cfg.repeated_sampling_cfg)
            else:
                from .augment_dataset import AugmentMXDataset
                # augmentation
                train_set = AugmentMXDataset(root_dir=root_dir, local_rank=local_rank,
                                             augmentation_version=aug_cfg.augmentation_version,
                                             aug_params=aug_cfg.aug_params)

        train_set.transform = train_transform

        # resample dataset if needed
        if hasattr(dataset_cfg, 'resample_dataset') and dataset_cfg.resample_dataset:
            from .subset_dataset import SubsetDataset

            if dataset_cfg.resample_dataset == 'one_half':
                # drop every index that is not a multiple of 2
                removing_index = list(set(range(0, len(train_set))) - set(range(0, len(train_set), 2)))
            elif dataset_cfg.resample_dataset == 'one_fourth':
                # drop every index that is not a multiple of 4
                removing_index = list(set(range(0, len(train_set))) - set(range(0, len(train_set), 4)))
            else:
                # any other value is a file name (relative to root_dir) holding the indices
                removing_index = np.load(os.path.join(root_dir, dataset_cfg.resample_dataset))
            train_set = SubsetDataset(train_set, removing_index)
            dataset_cfg.num_classes = len(train_set.unique_label)
            label_mapping = train_set.label_mapping
        else:
            label_mapping = None

    elif dataset_cfg.rec == '':
        raise ValueError('No dataset is provided')

    # Image Folder
    else:
        train_set = ImageFolder(root_dir, train_transform)
        label_mapping = None

    train_set.color_space = dataset_cfg.color_space

    return train_set, label_mapping


def set_epoch(dataloader, epoch, cfg):
    """Propagate the current epoch to the sampler and the dataset.

    The sampler's ``set_epoch`` is called when it exists. When the dataset
    exposes ``set_augmentation`` and ``cfg.data_augs.disable_aug_during_warmup``
    is set, augmentation is switched on only once ``epoch`` has reached
    ``cfg.optims.warmup_epoch``. Progress is printed on local rank 0 only.
    """
    if hasattr(dataloader.sampler, 'set_epoch'):
        if cfg.trainers.local_rank == 0:
            print(f'Dataloader set epoch: {epoch}')
        dataloader.sampler.set_epoch(epoch)
    if hasattr(dataloader.dataset, 'set_augmentation'):
        if hasattr(cfg.data_augs, 'disable_aug_during_warmup') and cfg.data_augs.disable_aug_during_warmup:
            if cfg.trainers.local_rank == 0:
                print(f'set augmentation, epoch: {epoch} : {epoch >= cfg.optims.warmup_epoch}')
            dataloader.dataset.set_augmentation(epoch >= cfg.optims.warmup_epoch)


def visualize_dataset(dataloader, save_path):
    """Save a one-row image grid of the first (up to four) dataset samples.

    The first element of every dataset item is taken as the image tensor.
    The grid is written with PIL when the dataset's ``color_space`` is
    ``'RGB'`` and with ``cv2.imwrite`` otherwise.
    """
    dataset = dataloader.dataset
    # Do not index past the end of datasets holding fewer than four samples.
    num_samples = min(4, len(dataset))
    batch = [dataset[i] for i in range(num_samples)]
    batch_img = torch.stack([b[0] for b in batch], dim=0)
    grid = make_grid(batch_img, nrow=4, padding=2, normalize=False)
    grid = tensor_to_numpy_uin8(grid)
    if dataset.color_space == 'RGB':
        Image.fromarray(grid).save(save_path)
    else:
        cv2.imwrite(save_path, grid)


def tensor_to_numpy_uin8(tensor):
    """Convert a CxHxW tensor normalised to [-1, 1] into an HxWxC uint8 array.

    Values are mapped through ``x * 0.5 + 0.5`` to [0, 1], scaled to
    [0, 255], rounded and clamped, so 1.0 maps to 255 instead of overflowing
    uint8 and out-of-range inputs saturate rather than wrap around.
    """
    scaled = (tensor.detach() * 0.5 + 0.5) * 255
    array = scaled.round().clamp(0, 255).cpu().numpy().astype(np.uint8)
    return np.transpose(array, (1, 2, 0))
class AugmentMXDataset(MXFaceDataset):
    """RecordIO face dataset that augments every sample it reads.

    The augmenter is created by ``make_augmenter``. Its ``augment`` may return
    either an image or an ``(image, theta)`` tuple; the item layout depends on
    which of the two was returned.
    """

    def __init__(self,
                 root_dir,
                 local_rank,
                 augmentation_version='v1',
                 aug_params=None,
                 ):
        super().__init__(root_dir, local_rank)
        print('augmentation_version', augmentation_version)
        self.augmenter = make_augmenter(augmentation_version, aug_params)

    def __getitem__(self, index, skip_augment=False):
        """Read, optionally augment and transform one sample.

        Returns:
            ``(sample, target)`` when no theta was produced, otherwise
            ``(sample, 0, target, theta)`` where ``0`` is a placeholder.
        """
        image, target = self.read_sample(index)
        theta = None

        if not skip_augment:
            augmented = self.augmenter.augment(image)
            if isinstance(augmented, tuple):
                image, theta = augmented
            else:
                image = augmented

        if self.transform is not None:
            image = self.transform(image)

        if theta is None:
            return image, target

        placeholder = 0
        assert theta.shape == (2, 3)
        return image, placeholder, target, theta
# Lazily imported mxnet module; populated by get_mxnet().
mx = None


def get_mxnet():
    """Import mxnet on first use, cache it in the module global and return it."""
    global mx
    if mx is None:
        import mxnet as mxnet
        mx = mxnet
    return mx


def iterate_record(imgidx, record):
    """Scan a RecordIO file once and collect one info row per record.

    Args:
        imgidx: iterable of record keys to read.
        record: object exposing ``read_idx(key)``.

    Returns:
        ``pandas.DataFrame`` with columns ``_idx``, ``path`` and ``label``.
    """
    mx = get_mxnet()
    record_info = []
    for idx in tqdm(imgidx, total=len(imgidx), desc='Iterating Dataset for extracting info (done only once)'):
        s = record.read_idx(idx)
        header, _ = mx.recordio.unpack(s)
        label = header.label
        # multi-value labels keep the identity in their first entry
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = int(label)
        row = {'_idx': idx, 'path': f'{label}/{idx}.jpg', 'label': label}
        record_info.append(row)
    record_info = pd.DataFrame(record_info)
    return record_info


class MXFaceDataset(Dataset):
    """Face dataset backed by MXNet ``train.rec`` / ``train.idx`` files.

    Per-record info is cached in ``train.tsv`` next to the record files; it
    is created on first use by scanning the whole record file.
    """

    def __init__(self, root_dir, local_rank):
        """
        Args:
            root_dir: directory holding ``train.rec`` and ``train.idx``.
            local_rank: stored on the instance; not used in this class.
        """
        super(MXFaceDataset, self).__init__()
        mx = get_mxnet()
        self.to_PIL = transforms.ToPILImage()
        self.root_dir = root_dir
        self.local_rank = local_rank
        # Callers assign the real transform after construction; default to
        # None so __getitem__ works before that happens.
        self.transform = None
        path_imgrec = os.path.join(root_dir, 'train.rec')
        path_imgidx = os.path.join(root_dir, 'train.idx')
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
        # close() is replaced by a no-op on this instance, so dispose() never
        # actually closes the record file.
        self.imgrec.close = lambda: None
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = np.array(range(1, int(header.label[0])))
        else:
            self.imgidx = np.array(list(self.imgrec.keys))

        info_path = os.path.join(root_dir, 'train.tsv')
        if os.path.isfile(info_path):
            self.info = pd.read_csv(info_path, sep='\t', header=None)
            self.info.columns = ['_idx', 'path', 'label']
        else:
            self.info = iterate_record(self.imgidx, self.imgrec)
            self.info.to_csv(info_path, sep='\t', header=False, index=False)
        self.label_info = {k: v for k, v in self.info.groupby('label')}

        atexit.register(self.dispose)

    def __getitem__(self, index):
        """Return ``(sample, label)`` with the transform applied when set."""
        sample, label = self.read_sample(index)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.info)

    def read_sample(self, index):
        """Decode record ``index`` into ``(PIL image, long label tensor)``."""
        # Resolve mxnet through the accessor rather than relying on the
        # module global having been filled in as a side effect of __init__.
        mx = get_mxnet()
        info_index = self.info.index[index]
        idx = self.imgidx[info_index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = mx.image.imdecode(img).asnumpy()
        sample = self.to_PIL(sample)
        return sample, label

    def dispose(self):
        """Release the record handle if one was opened."""
        # __del__ also runs for instances whose __init__ failed early.
        imgrec = getattr(self, 'imgrec', None)
        if imgrec is not None:
            imgrec.close()

    def __del__(self):
        self.dispose()


class SyntheticDataset(Dataset):
    """Dataset returning one fixed random image with a random label."""

    def __init__(self, num_class, num_sample):
        """
        Args:
            num_class: labels are drawn from ``[0, num_class - 1]``.
            num_sample: reported dataset length.
        """
        super(SyntheticDataset, self).__init__()
        img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
        img = np.transpose(img, (2, 0, 1))
        img = torch.from_numpy(img).float()
        # same normalisation as mean 0.5 / std 0.5 on [0, 1] images
        img = ((img / 255) - 0.5) / 0.5
        self.img = img
        self.num_class = num_class
        self.num_sample = num_sample
        print("SyntheticDataset: num_class: {}, num_sample: {}".format(num_class, num_sample))

    def __getitem__(self, index):
        """Return the shared image and a freshly drawn label (index is ignored)."""
        label = random.randint(0, self.num_class - 1)
        return self.img, label

    def __len__(self):
        return self.num_sample
0000000000000000000000000000000000000000..54ba829d5ed747bdb855875e8c72470e17a839db --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/dataset/configs/casia.yaml @@ -0,0 +1,9 @@ + +data_root: ${oc.env:DATA_ROOT} +rec: 'casia_webface' +color_space: 'RGB' +num_classes: 10572 +num_image: 490623 + +repeated_sampling_cfg: null +semi_sampling_cfg: null diff --git a/cvlface/research/recognition/code/run_v1/dataset/configs/synthetic.yaml b/cvlface/research/recognition/code/run_v1/dataset/configs/synthetic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c42793f58cdd5e1730b26ad212406aa657dbacf0 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/dataset/configs/synthetic.yaml @@ -0,0 +1,9 @@ + +data_root: ${oc.env:DATA_ROOT} +rec: 'synthetic' +color_space: 'RGB' +num_classes: 205990 +num_image: 4235242 + +repeated_sampling_cfg: null +semi_sampling_cfg: null diff --git a/cvlface/research/recognition/code/run_v1/dataset/configs/webface12m_ldmktheta_RA10.yaml b/cvlface/research/recognition/code/run_v1/dataset/configs/webface12m_ldmktheta_RA10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1f782f4d5be0760f75a7e020bead8814abff3592 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/dataset/configs/webface12m_ldmktheta_RA10.yaml @@ -0,0 +1,15 @@ + +data_root: ${oc.env:DATA_ROOT} +rec: 'webface260m/WebFace12M' +color_space: 'RGB' +num_classes: 617970 +num_image: 12720066 + +repeated_sampling_cfg: + use_same_image: False + second_img_augment: False + ldmk_path: 'ldmk_5points.csv' + disable_repeat: True # performs repeated aug without increasing batch size + skip_aug_prob_in_disable_repeat: 0.0 + repeated_augment_prob: 0.1 +semi_sampling_cfg: null diff --git a/cvlface/research/recognition/code/run_v1/dataset/configs/webface42m_ldmktheta_RA10.yaml b/cvlface/research/recognition/code/run_v1/dataset/configs/webface42m_ldmktheta_RA10.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..0d99cae23d20f4fc6e5dee369dc391a1b74715e3 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/dataset/configs/webface42m_ldmktheta_RA10.yaml @@ -0,0 +1,17 @@ + +data_root: ${oc.env:DATA_ROOT} +rec: 'webface260m/WebFace42M' +color_space: 'RGB' +num_classes: 2059906 +num_image: 42474557 + +resample_dataset: 'removing_indices.npy' + +repeated_sampling_cfg: + use_same_image: False + second_img_augment: False + ldmk_path: 'ldmk_5points.csv' + disable_repeat: True # performs repeated aug without increasing batch size + skip_aug_prob_in_disable_repeat: 0.0 + repeated_augment_prob: 0.1 +semi_sampling_cfg: null diff --git a/cvlface/research/recognition/code/run_v1/dataset/configs/webface4m.yaml b/cvlface/research/recognition/code/run_v1/dataset/configs/webface4m.yaml new file mode 100644 index 0000000000000000000000000000000000000000..58ff755ae3a20e41eca6bb5e097ecc9ed00f3e56 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/dataset/configs/webface4m.yaml @@ -0,0 +1,9 @@ + +data_root: ${oc.env:DATA_ROOT} +rec: 'webface260m/WebFace4M' +color_space: 'RGB' +num_classes: 205990 +num_image: 4235242 + +repeated_sampling_cfg: null +semi_sampling_cfg: null diff --git a/cvlface/research/recognition/code/run_v1/dataset/configs/webface4m_ldmktheta_RA10.yaml b/cvlface/research/recognition/code/run_v1/dataset/configs/webface4m_ldmktheta_RA10.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9bb863105fbe7eaf5a38bdb9da04c6c2c98a7b30 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/dataset/configs/webface4m_ldmktheta_RA10.yaml @@ -0,0 +1,16 @@ + +data_root: ${oc.env:DATA_ROOT} +rec: 'webface260m/WebFace4M' +color_space: 'RGB' +num_classes: 205990 +num_image: 4235242 + +repeated_sampling_cfg: + use_same_image: False + second_img_augment: False + ldmk_path: 'ldmk_5points.csv' + disable_repeat: True # performs repeated aug without increasing batch size + skip_aug_prob_in_disable_repeat: 0.0 + 
class GeneralAugmentDataset(Dataset):
    """Wrap a ``(PIL image, target)`` dataset with augmentation and a transform."""

    def __init__(self, dataset, augmenter, transform):
        super(GeneralAugmentDataset, self).__init__()
        self.dataset = dataset
        self.augmenter = augmenter
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(sample, target)`` or ``(sample, 0, target, theta)``.

        The longer form is used when the augmenter returned an
        ``(image, theta)`` tuple; ``0`` is a placeholder.
        """
        batch = self.dataset[index]
        sample, target = batch
        assert isinstance(sample, Image.Image)

        sample = self.augmenter.augment(sample)
        theta = None
        if isinstance(sample, tuple):
            sample, theta = sample

        if self.transform is not None:
            sample = self.transform(sample)

        if theta is not None:
            placeholder = 0
            assert theta.shape == (2, 3)
            return sample, placeholder, target, theta
        else:
            return sample, target


class GeneralAugmentLmdkThetaDataset(Dataset):
    """Wrap a ``(PIL image, target, landmarks)`` dataset with augmentation.

    When the augmenter also returns a theta, the landmarks are moved with
    the image; otherwise an identity theta is reported.
    """

    def __init__(self,
                 dataset,
                 augmenter,
                 transform,
                 ):
        super(GeneralAugmentLmdkThetaDataset, self).__init__()
        self.dataset = dataset
        self.augmenter = augmenter
        self.transform = transform

        # 2x3 identity affine used when no geometric augmentation happened
        self.identity_theta = torch.zeros(2, 3)
        self.identity_theta[0, 0] = 1
        self.identity_theta[1, 1] = 1

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(sample, target, ldmk, theta, 0, 0, 0)``; zeros are placeholders."""
        placeholder = 0
        sample, target, ldmk = self.dataset[index]
        assert isinstance(sample, Image.Image)

        theta = None
        sample = self.augmenter.augment(sample)
        if isinstance(sample, tuple):
            sample, theta = sample

        if self.transform is not None:
            sample = self.transform(sample)

        if theta is not None:
            ldmk = self.transform_ldmk(ldmk, theta)
        else:
            theta = self.identity_theta.clone()
        ldmk = ldmk.float()
        return sample, target, ldmk, theta, placeholder, placeholder, placeholder

    def transform_ldmk(self, ldmk, theta):
        """Map landmarks through the inverse of ``theta``.

        Landmarks are taken from [0, 1] coordinates to [-1, 1], multiplied by
        the inverted 2x3 affine and mapped back to [0, 1]. When the inverse
        has a negative ``[0, 0]`` entry the result is passed through
        ``mirror_ldmk``.

        Args:
            ldmk: ``(N, 2)`` tensor.
            theta: ``(2, 3)`` affine tensor.

        Returns:
            ``(N, 2)`` float32 tensor.
        """
        inv_theta = inv_matrix(theta.unsqueeze(0)).squeeze(0)
        # Cast before concatenating so float64 landmarks never meet the
        # float32 ones column with a different dtype.
        ldmk = ldmk.float()
        homogeneous = torch.cat([ldmk, torch.ones(ldmk.shape[0], 1)], dim=1)
        # the ones column is unchanged by "* 2 - 1"
        transformed_ldmk = ((homogeneous * 2 - 1) @ inv_theta.T) / 2 + 0.5
        if inv_theta[0, 0] < 0:
            transformed_ldmk = self.mirror_ldmk(transformed_ldmk)
        return transformed_ldmk

    def mirror_ldmk(self, ldmk):
        """Re-order landmarks after a mirror; only 5-point sets are supported."""
        if len(ldmk) == 5:
            return self.mirror_ldmk_5(ldmk)
        else:
            return self.mirror_ldmk_34(ldmk)

    def mirror_ldmk_5(self, ldmk):
        """Swap rows 0<->1 and 3<->4 of a 5-point landmark tensor (row 2 stays)."""
        new_ldmk = ldmk.clone()
        new_ldmk[[0, 1, 3, 4]] = ldmk[[1, 0, 4, 3]]
        return new_ldmk

    def mirror_ldmk_34(self, ldmk):
        raise NotImplementedError


def iterate_record(imgidx, record):
    """Scan a RecordIO file and collect one info row per record.

    Returns:
        ``pandas.DataFrame`` with columns ``idx``, ``path`` and ``label``.
    """
    record_info = []
    for idx in tqdm(imgidx, total=len(imgidx)):
        s = record.read_idx(idx)
        header, _ = mx.recordio.unpack(s)
        label = header.label
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = int(label)
        row = {'idx': idx, 'path': f'{label}/{idx}.jpg', 'label': label}
        record_info.append(row)
    record_info = pd.DataFrame(record_info)
    return record_info


def visualize_landmark(img, landmark):
    """Draw numbered landmark dots on a copy of ``img`` and return it.

    Args:
        img: CxHxW tensor normalised with mean 0.5 / std 0.5, or an image
            array exposing ``.copy()`` and ``.shape``.
        landmark: tensor of (x, y) pairs in [0, 1] relative coordinates.

    Returns:
        HxWxC array with the landmarks drawn on it.
    """
    if isinstance(img, torch.Tensor):
        # detach/cpu so tensors on an accelerator or with grad also work
        img = img.detach().cpu().permute(1, 2, 0).numpy()
        img = (img * 0.5 + 0.5) * 255
        img = img.astype(np.uint8)
        img = img.copy()
    else:
        img = img.copy()
    # scale relative coordinates to pixels: x by width, y by height
    landmark = landmark.detach().cpu().reshape(-1, 2) * torch.tensor([[img.shape[1], img.shape[0]]])
    for i in range(landmark.shape[0]):
        point = (int(landmark[i][0]), int(landmark[i][1]))
        cv2.circle(img, point, 2, (0, 0, 255), -1)
        # put index on the landmark
        cv2.putText(img, str(i), point, cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
    return img


def inv_matrix(theta):
    """Invert a batch of 2x3 affine matrices.

    Each matrix ``[[a, b, t1], [c, d, t2]]`` is treated as the top two rows
    of a 3x3 homogeneous transform; the top two rows of its inverse are
    returned. No check is made for a zero determinant.

    Args:
        theta: ``(B, 2, 3)`` tensor.

    Returns:
        ``(B, 2, 3)`` tensor.
    """
    assert theta.ndim == 3
    a, b, t1 = theta[:, 0, 0], theta[:, 0, 1], theta[:, 0, 2]
    c, d, t2 = theta[:, 1, 0], theta[:, 1, 1], theta[:, 1, 2]
    det = a * d - b * c
    inv_det = 1.0 / det
    inv_mat = torch.stack([
        torch.stack([d * inv_det, -b * inv_det, (b * t2 - d * t1) * inv_det], dim=1),
        torch.stack([-c * inv_det, a * inv_det, (c * t1 - a * t2) * inv_det], dim=1)
    ], dim=1)
    return inv_mat
class RepeatedSamplingMXDataset(AugmentMXDataset):
    """Dataset that returns two views per item: the indexed image and a partner.

    The partner is either the same image again or another row drawn from
    ``self.label_info[label]``, depending on
    ``repeated_sampling_cfg.use_same_image``.
    """

    def __init__(self,
                 root_dir,
                 local_rank,
                 augmentation_version='v1',
                 aug_params=None,
                 repeated_sampling_cfg=None,
                 ):
        """Initialise the parent dataset and store the repeated-sampling config.

        Args:
            root_dir, local_rank, augmentation_version, aug_params: Forwarded
                unchanged to ``AugmentMXDataset.__init__``.
            repeated_sampling_cfg: Required config object; this class reads
                ``use_same_image`` and ``second_img_augment`` from it.
        """
        super(RepeatedSamplingMXDataset, self).__init__(root_dir, local_rank, augmentation_version, aug_params)
        # The parent is expected to provide self.augmenter and self.imgidx
        # (per the original author's note) -- TODO confirm against AugmentMXDataset.
        assert repeated_sampling_cfg is not None
        self.repeated_sampling_cfg = repeated_sampling_cfg

    def __getitem__(self, index, skip_augment=False):
        """Return the indexed sample together with a second, paired sample.

        Args:
            index: Dataset index handed to ``self.read_sample``.
            skip_augment: When true, neither image is augmented.

        Returns:
            ``(sample1, sample2, target, theta1, theta2)`` when the augmenter
            returned a transform for the first image whose ``ndim`` is not 1,
            otherwise ``(sample1, sample2, target)``.
        """
        sample1, target = self.read_sample(index)
        # 1-D dummies mark "no transform available" (see the return below).
        theta1 = torch.tensor([0])
        theta2 = torch.tensor([0])

        if not skip_augment:
            sample1 = self.augmenter.augment(sample1)
            if isinstance(sample1, tuple):
                sample1, theta1 = sample1

        # Guarded like the second image below, so a dataset configured
        # without a transform no longer fails on the first image.
        if self.transform is not None:
            sample1 = self.transform(sample1)

        if self.repeated_sampling_cfg.use_same_image:
            # same image
            extra_index = index
        else:
            # same subject: draw one row from this label's table
            extra_index = self.label_info[target.item()].sample(1).index.item()

        sample2, _target = self.read_sample(extra_index)
        assert target == _target
        if not skip_augment and self.repeated_sampling_cfg.second_img_augment:
            sample2 = self.augmenter.augment(sample2)
            if isinstance(sample2, tuple):
                sample2, theta2 = sample2

        if self.transform is not None:
            sample2 = self.transform(sample2)

        if theta1.ndim != 1:
            return sample1, sample2, target, theta1, theta2
        # Only the dummy theta is available: keep the shorter tuple.
        return sample1, sample2, target
class RepeatedWithLdmkThetaMXDataset(AugmentMXDataset):
    """Dataset yielding images with their landmarks and augmentation ``theta``.

    Each item is a 7-tuple ``(sample1, target, ldmk1, theta1, sample2, ldmk2,
    theta2)``. With ``disable_repeat`` set, the last three entries are the
    integer placeholder ``0``.

    NOTE(review): ``prev_index`` / ``prev_label`` / ``repeated`` are mutated on
    every ``__getitem__`` call, so results depend on call order. With
    multi-worker loaders each worker presumably keeps its own copy of this
    state -- confirm.
    """

    def __init__(self,
                 root_dir,
                 local_rank,
                 augmentation_version='v1',
                 aug_params=None,
                 repeated_sampling_cfg=None,
                 ):
        """Set up the parent dataset, load the landmark table, reset repeat state.

        Args:
            root_dir, local_rank, augmentation_version, aug_params: Forwarded
                unchanged to ``AugmentMXDataset.__init__``.
            repeated_sampling_cfg: Required config; ``disable_repeat``,
                ``skip_aug_prob_in_disable_repeat`` and ``ldmk_path`` are read
                here, further fields in ``__getitem__``.
        """
        super(RepeatedWithLdmkThetaMXDataset, self).__init__(root_dir, local_rank, augmentation_version, aug_params)
        # self.augmenter
        # self.imgidx
        assert repeated_sampling_cfg is not None
        self.repeated_sampling_cfg = repeated_sampling_cfg
        self.disable_repeat = repeated_sampling_cfg.disable_repeat
        self.skip_aug_prob_in_disable_repeat = repeated_sampling_cfg.skip_aug_prob_in_disable_repeat

        # Landmark CSV lives under root_dir; its first column becomes the row
        # index that get_one_sample looks up with .loc[index].
        self.ldmk_info = pd.read_csv(os.path.join(root_dir, repeated_sampling_cfg.ldmk_path), sep=',', index_col=0)
        # 2x3 identity affine, returned when no augmentation theta exists.
        self.identity_theta = torch.zeros(2, 3)
        self.identity_theta[0, 0] = 1
        self.identity_theta[1, 1] = 1

        self.do_augment = True
        # State carried between consecutive __getitem__ calls.
        self.prev_index = None
        self.prev_label = None
        self.repeated = False

    def set_augmentation(self, value):
        """Globally enable or disable augmentation for later items."""
        print('set augmentation', value)
        self.do_augment = value

    def get_one_sample(self, index, augment=True):
        """Load one image with its landmarks and the affine applied to it.

        Args:
            index: Dataset index, also used as the row label in ``ldmk_info``.
            augment: Whether to run ``self.augmenter.augment`` on the image.

        Returns:
            ``(sample, target, ldmk, theta)`` where ``ldmk`` is a float tensor
            of (x, y) rows and ``theta`` is the augmenter's transform or a
            clone of the identity when none was produced.
        """
        sample, target = self.read_sample(index)

        theta = None

        if augment:
            sample = self.augmenter.augment(sample)
            # The augmenter may return (image, theta) instead of just the image.
            if isinstance(sample, tuple):
                sample, theta = sample

        if self.transform is not None:
            sample = self.transform(sample)

        # load landmark
        ldmk = self.ldmk_info.loc[index].values
        if len(ldmk) == 10:
            # 10 values -> five (x, y) pairs
            ldmk = ldmk.reshape(-1, 2)
        else:
            # otherwise rows of three values; only the first two are kept
            # (third column presumably a score/visibility -- TODO confirm)
            ldmk = ldmk.reshape(-1, 3)[:, :2]

        ldmk = torch.from_numpy(ldmk)
        if theta is not None:
            # Move the landmarks along with the augmented image.
            ldmk = self.transform_ldmk(ldmk, theta)
        else:
            theta = self.identity_theta.clone()

        ldmk = ldmk.float()
        return sample, target, ldmk, theta


    def __getitem__(self, index, skip_augment=False):
        """Return one or two views of an identity with landmarks and thetas.

        With probability ``repeated_augment_prob`` (and never twice in a row)
        the requested ``index`` is replaced by the previously served index, or
        by another row of the previously served label, before loading.

        Args:
            index: Requested dataset index (may be overridden as described).
            skip_augment: Disable augmentation and the repeat logic for this call.

        Returns:
            ``(sample1, target, ldmk1, theta1, sample2, ldmk2, theta2)``; the
            last three are ``0`` when ``disable_repeat`` is set.
        """

        placeholder = 0
        augment = not skip_augment and self.do_augment

        if self.prev_index is not None and augment:
            if self.repeated_sampling_cfg.repeated_augment_prob > 0:
                if np.random.rand() < self.repeated_sampling_cfg.repeated_augment_prob and not self.repeated:
                    # Re-serve the previous image (or its identity) this call.
                    self.repeated = True
                    if self.repeated_sampling_cfg.use_same_image:
                        index = self.prev_index
                    else:
                        index = self.label_info[self.prev_label].sample(1).index.item()
                else:
                    self.repeated = False


        if self.disable_repeat:
            # In single-image mode, randomly skip augmentation for this item.
            if np.random.rand() < self.skip_aug_prob_in_disable_repeat:
                augment = False
        sample1, target, ldmk1, theta1 = self.get_one_sample(index, augment=augment)
        if self.repeated:
            if self.prev_label is not None:
                # Sanity check: a repeated draw should keep the previous label.
                if augment and self.prev_label != target.item():
                    print('Warning repeated label different {} {}'.format(target.item(), self.prev_label))

        self.prev_index = index
        self.prev_label = target.item()

        if self.disable_repeat:
            return sample1, target, ldmk1, theta1, placeholder, placeholder, placeholder

        # get extra image index
        if self.repeated_sampling_cfg.use_same_image:
            extra_index = index
        else:
            extra_index = self.label_info[target.item()].sample(1).index.item()

        extra_augment = augment and self.repeated_sampling_cfg.second_img_augment
        sample2, target2, ldmk2, theta2 = self.get_one_sample(extra_index, augment=extra_augment)
        assert target == target2

        return sample1, target, ldmk1, theta1, sample2, ldmk2, theta2

    def transform_ldmk(self, ldmk, theta):
        """Apply the inverse of ``theta`` to landmark coordinates.

        Coordinates are mapped with ``x * 2 - 1`` before and ``/ 2 + 0.5``
        after the matrix product, i.e. they are presumably normalised to
        [0, 1] on input and output -- TODO confirm.

        Args:
            ldmk: ``(N, 2)`` tensor of landmark coordinates.
            theta: ``(2, 3)`` affine matrix returned by the augmenter.

        Returns:
            ``(N, 2)`` tensor of transformed landmarks, mirrored via
            ``mirror_ldmk`` when the inverse has a negative x-scale.
        """
        inv_theta = inv_matrix(theta.unsqueeze(0)).squeeze(0)
        # Homogeneous coordinates; the appended ones stay 1 under x * 2 - 1.
        ldmk = torch.cat([ldmk, torch.ones(ldmk.shape[0], 1)], dim=1).float()
        transformed_ldmk = (((ldmk) * 2 - 1) @ inv_theta.T) / 2 + 0.5
        # Negative x-scale is treated as a horizontal flip: swap paired points.
        if inv_theta[0, 0] < 0:
            transformed_ldmk = self.mirror_ldmk(transformed_ldmk)
        return transformed_ldmk

    def mirror_ldmk(self, ldmk):
        """Dispatch to the mirroring routine matching the landmark count."""
        if len(ldmk) == 5:
            return self.mirror_ldmk_5(ldmk)
        else:
            return self.mirror_ldmk_34(ldmk)

    def mirror_ldmk_5(self, ldmk):
        """Return a copy with rows 0<->1 and 3<->4 swapped; row 2 is unchanged.

        Presumably these are the left/right eye and mouth-corner pairs of a
        5-point layout -- TODO confirm against the landmark CSV.
        """
        new_ldmk = ldmk.clone()
        tmp = new_ldmk[1, :].clone()
        new_ldmk[1, :] = new_ldmk[0, :]
        new_ldmk[0, :] = tmp
        tmp1 = new_ldmk[4, :].clone()
        new_ldmk[4, :] = new_ldmk[3, :]
        new_ldmk[3, :] = tmp1
        return new_ldmk

    def mirror_ldmk_34(self, ldmk):
        """Mirroring for non-5-point layouts; not implemented."""
        raise NotImplementedError




def visualize_landmark(img, landmark):
    """Draw numbered landmark dots on a copy of ``img`` (debug helper).

    ``landmark`` is reshaped to ``(-1, 2)`` and scaled by the image width and
    height, so coordinates are relative to the image size.
    """
    if isinstance(img, torch.Tensor):
        img = img.clone()
        # make it to numpy array
        img = img.permute(1, 2, 0).numpy()
        # undo the x * 0.5 + 0.5 normalisation and scale to 0..255
        img = (img * 0.5 + 0.5) * 255
        img = img.astype(np.uint8)
        img = img.copy()
    else:
        img = img.copy()
    landmark = landmark.clone().reshape(-1, 2) * torch.tensor([[img.shape[1], img.shape[0]]])
    for i in range(landmark.shape[0]):
        cv2.circle(img, (int(landmark[i][0]), int(landmark[i][1])), 2, (0, 0, 255), -1)
        # put index on the landmark
        cv2.putText(img, str(i), (int(landmark[i][0]), int(landmark[i][1])), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
    return img

def inv_matrix(theta):
    """Invert a batch of 2x3 affine matrices; returns a ``(batch, 2, 3)`` tensor.

    Each matrix ``[[a, b, t1], [c, d, t2]]`` is inverted in closed form using
    the determinant of its 2x2 linear part.
    """
    # torch batched version
    assert theta.ndim == 3
    a, b, t1 = theta[:, 0,0], theta[:, 0,1], theta[:, 0,2]
    c, d, t2 = theta[:, 1,0], theta[:, 1,1], theta[:, 1,2]
    det = a * d - b * c
    inv_det = 1.0 / det
    inv_mat = torch.stack([
        torch.stack([d * inv_det, -b * inv_det, (b * t2 - d * t1) * inv_det], dim=1),
        torch.stack([-c * inv_det, a * inv_det, (c * t1 - a * t2) * inv_det], dim=1)
    ], dim=1)
    return inv_mat
class SubsetDataset(Dataset):
    """View of a dataset with some indices removed and labels re-numbered.

    Labels that survive the removal are mapped to ``0..K-1`` in sorted order,
    and every item returned has its target replaced by the mapped value.
    """

    def __init__(self, dataset, drop_indices):
        """
        Initializes the SubsetDataset.
        Args:
        - dataset (Dataset): The original dataset. Must expose an ``info``
          DataFrame with a ``label`` column.
        - drop_indices (set of int): The indices to drop from the dataset.

        Side effect: when ``dataset`` has a ``label_info`` attribute it is
        replaced in place with a per-label table built from the kept rows.
        """
        self.dataset = dataset
        self.drop_indices = set(drop_indices)

        # Calculate the indices to keep
        self.indices = [i for i in range(len(dataset)) if i not in self.drop_indices]

        total = len(self.dataset)
        # Guard the percentage so an empty dataset does not raise ZeroDivisionError.
        drop_percent = len(self.drop_indices) / total * 100 if total else 0.0
        print('original sample count :', total)
        print('original label count :', len(self.dataset.info['label'].unique()))
        print(f'removing {len(self.drop_indices)} ({drop_percent:.2f}%) samples ')
        dropped_info = self.dataset.info.copy()
        dropped_info = dropped_info.drop(list(self.drop_indices))

        unique_label = np.sort(dropped_info['label'].unique())
        self.unique_label = unique_label
        self.label_mapping = {k: i for i, k in enumerate(unique_label)}
        print('new sample count :', len(dropped_info))
        print('new label count :', len(self.label_mapping))

        # adjust self.label_info if there is one, so you don't repeat samples from drop indices
        if hasattr(self.dataset, 'label_info'):
            self.dataset.label_info = {k: v for k, v in dropped_info.groupby('label')}

    def __len__(self):
        """Number of samples that remain after dropping."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Fetch item ``idx`` of the subset with its target remapped.

        The position of the target inside the returned tuple depends on the
        wrapped dataset's class.

        Raises:
            NotImplementedError: If the wrapped dataset type is not supported.
        """
        # Adjust the index to refer to the index in the original dataset
        result = self.dataset[self.indices[idx]]
        # Subclasses are tested before their AugmentMXDataset base.
        if isinstance(self.dataset, RepeatedWithLdmkThetaMXDataset):
            target_index = 1
        elif isinstance(self.dataset, RepeatedSamplingMXDataset):
            target_index = 2
        elif isinstance(self.dataset, AugmentMXDataset):
            target_index = 2 if len(result) == 4 else 1
        else:
            raise NotImplementedError

        result = list(result)
        result[target_index] = self.remap_target(result[target_index])
        return tuple(result)

    def remap_target(self, target):
        """Translate an original label into its index in the reduced label set.

        Args:
            target: ``int``/``float`` (Python or NumPy scalar), a NumPy array
                of labels, or a single-element ``torch.Tensor``.

        Returns:
            Mapped label of a matching kind: a plain mapping value for
            scalars, ``np.ndarray`` for arrays, a ``torch.long`` tensor for
            tensors.

        Raises:
            NotImplementedError: For any other target type.
            KeyError: If the label was removed together with its samples.
        """
        if isinstance(target, torch.Tensor):
            return torch.tensor(self.label_mapping[target.item()], dtype=torch.long)
        if isinstance(target, np.ndarray):
            if target.ndim == 0:
                # 0-d arrays cannot be iterated; treat them as one label.
                return np.array(self.label_mapping[target.item()])
            return np.array([self.label_mapping[t] for t in target])
        # np.integer covers np.int64 & co., which are not subclasses of int.
        if isinstance(target, (int, np.integer)):
            return self.label_mapping[target]
        if isinstance(target, (float, np.floating)):
            return self.label_mapping[int(target)]
        raise NotImplementedError(f'Unsupported target type: {type(target)}')
def get_runname_and_task(ckpt_dir):
    """Derive the run name and task folder names from a checkpoint path.

    Args:
        ckpt_dir: '/'-separated checkpoint directory. A trailing '/' is
            ignored so ``a/b/`` and ``a/b`` resolve to the same run name.

    Returns:
        ``(runname, save_dir_task, code_task)``. ``code_task`` is the name of
        the directory containing this file. For paths containing
        ``pretrained_models`` the run name is the last path component and
        ``save_dir_task`` is ``'pretrained_models'``; otherwise the run name
        is the third component from the end and ``save_dir_task`` equals
        ``code_task``.
    """
    # Without the strip, a trailing separator produced an empty last
    # component and shifted the [-3] lookup by one.
    parts = ckpt_dir.rstrip('/').split('/')
    code_task = os.path.abspath(__file__).split('/')[-2]
    if 'pretrained_models' in ckpt_dir:
        return parts[-1], 'pretrained_models', code_task
    return parts[-3], code_task, code_task



if __name__ == '__main__':

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_gpu', type=int, default=1)
    parser.add_argument('--precision', type=str, default='32-true')
    parser.add_argument('--eval_config_name', type=str, default='quick')
    parser.add_argument('--pipeline_name', type=str, default='default')
    parser.add_argument('--ckpt_dir', type=str, default="../../../../pretrained_models/recognition/adaface_ir101_webface12m")
    args = parser.parse_args()

    # setup output dir
    runname, save_dir_task, task = get_runname_and_task(args.ckpt_dir)
    eval_config = load_config(f'evaluations/configs/{args.eval_config_name}.yaml')
    output_dir = os.path.join(root, 'research/recognition/experiments', save_dir_task, 'eval_' + runname)
    os.makedirs(output_dir, exist_ok=True)

    # load model
    model_config = load_config(os.path.join(args.ckpt_dir, 'model.yaml'))
    model = get_model(model_config, task)
    model.load_state_dict_from_path(os.path.join(args.ckpt_dir, 'model.pt'))
    # NOTE(review): both transforms are unused below (the pipeline builds its
    # own test transform); kept in case the calls have side effects.
    train_transform = model.make_train_transform()
    test_transform = model.make_test_transform()

    # maybe load aligner: use the checkpoint's aligner when present,
    # otherwise fall back to the 'none' aligner config of this task.
    if os.path.exists(os.path.join(args.ckpt_dir, 'aligner.yaml')):
        aligner_config = load_config(os.path.join(args.ckpt_dir, 'aligner.yaml'))
        aligner_config.start_from = os.path.join(args.ckpt_dir, 'aligner.pt')
        aligner = get_aligner(aligner_config)
    else:
        aligner_config = load_config(os.path.join(root, 'research/recognition/code/', task, 'aligners/configs/none.yaml'))
        aligner = get_aligner(aligner_config)


    # load pipeline: 'default' means "whatever the training config recorded".
    if args.pipeline_name == 'default':
        full_config_path = os.path.join(args.ckpt_dir, 'config.yaml')
        assert os.path.isfile(full_config_path), f"config.yaml not found at {full_config_path}, try with pipeline name"
        pipeline_name = load_config(full_config_path).pipelines.eval_pipeline_name
    else:
        pipeline_name = args.pipeline_name

    # launch fabric
    csv_logger = CSVLogger(root_dir=output_dir, flush_logs_every_n_steps=1)
    wandb_logger = WandbLogger(project=task, save_dir=output_dir,
                               name=os.path.basename(output_dir),
                               log_model=False)
    # Honour --precision; it used to be parsed and printed but the value
    # passed here was hard-coded to '32-true' (still the default).
    fabric = Fabric(precision=args.precision,
                    accelerator="auto",
                    strategy="ddp",
                    devices=args.num_gpu,
                    loggers=[csv_logger, wandb_logger],
                    )

    # NOTE(review): launch() is only called for a single device; multi-GPU
    # runs are presumably started through an external launcher -- confirm.
    if args.num_gpu == 1:
        fabric.launch()
    print(f"Fabric launched with {args.num_gpu} GPUS and {args.precision}")
    fabric.setup_dataloader_from_dataset = partial(setup_dataloader_from_dataset, fabric=fabric, seed=2048)

    # prepare accelerator
    model = fabric.setup(model)
    if aligner.has_trainable_params():
        aligner = fabric.setup(aligner)

    # make inference pipe (after accelerator setup)
    eval_pipeline = pipeline_from_name(pipeline_name, model, aligner)
    eval_pipeline.integrity_check(dataset_color_space='RGB')

    # evaluation callbacks
    evaluators = []
    for name, info in eval_config.per_epoch_evaluations.items():
        eval_data_path = os.path.join(eval_config.data_root, info.path)
        evaluator = get_evaluator_by_name(eval_type=info.evaluation_type, name=name,
                                          eval_data_path=eval_data_path,
                                          transform=eval_pipeline.make_test_transform(),
                                          fabric=fabric, batch_size=info.batch_size,
                                          num_workers=info.num_workers)
        evaluator.integrity_check(info.color_space, eval_pipeline.color_space)
        evaluators.append(evaluator)

    # Evaluation
    print('Evaluation Started')
    all_result = {}
    for evaluator in evaluators:
        if fabric.local_rank == 0:
            print(f"Evaluating {evaluator.name}")
        result = evaluator.evaluate(eval_pipeline, epoch=0, step=0, n_images_seen=0)
        if fabric.local_rank == 0:
            print(f"{evaluator.name}")
            print(result)
            # Namespace every metric with the evaluator it came from.
            all_result.update({evaluator.name + "/" + k: v for k, v in result.items()})

    if fabric.local_rank == 0:
        os.makedirs(os.path.join(output_dir, 'result'), exist_ok=True)
        save_result = pd.DataFrame(pd.Series(all_result), columns=['val'])
        save_result.to_csv(os.path.join(output_dir, 'result/eval_final.csv'))
        mean, summary_dict = summary(save_result, epoch=0, step=0, n_images_seen=0)
        fabric.log_dict(summary_dict)
        summary_result = pd.DataFrame(pd.Series(summary_dict), columns=['val'])
        summary_result.to_csv(os.path.join(output_dir, 'result/eval_summary_final.csv'))

    print('Evaluation Finished')
Check that cvlface/.env file is set correctly ' + 'and the dataset is downloaded.') + + if eval_type == 'verification': + return VerificationEvaluator(name, eval_data_path, transform, fabric, batch_size, num_workers) + elif eval_type == 'ijbbc': + from .ijbbc_evaluator import IJBBCEvaluator + return IJBBCEvaluator(name, eval_data_path, transform, fabric, batch_size, num_workers) + elif eval_type == 'tinyface': + from .tinyface_evaluator import TinyFaceEvaluator + return TinyFaceEvaluator(name, eval_data_path, transform, fabric, batch_size, num_workers) + else: + raise ValueError('Unknown evaluation type: %s' % eval_type) + + +def summary(save_result, epoch, step, n_images_seen): + key_metrics = ['cfpfp/acc', 'agedb_30/acc', 'lfw/acc', + 'cplfw/acc', 'calfw/acc', + 'tinyface/rank-1', 'tinyface/rank-5', + 'IJBB_gt_aligned/Norm:False_Det:True_tpr_at_fpr_0.0001', + 'IJBC_gt_aligned/Norm:False_Det:True_tpr_at_fpr_0.0001'] + key_metrics_in_save_result = [k for k in key_metrics if k in save_result.index] + if key_metrics_in_save_result: + summary = save_result.loc[key_metrics_in_save_result] + summary.index = ['summary/'+k.replace('/', '_') for k in summary.index] + summary.index = [k.replace('Norm:False_Det:True_tpr_at_fpr_0.0001', 'TPR@FPR0.01') for k in summary.index] + summary.index = [k.replace('_gt_aligned', '') for k in summary.index] + mean = summary['val'].mean() + + summary_dict = summary['val'].to_dict() + summary_dict['epoch'] = epoch + summary_dict['step'] = step + summary_dict['n_images_seen'] = n_images_seen + summary_dict['trainer/global_step'] = step + summary_dict['trainer/epoch'] = epoch + + else: + mean = save_result['val'].mean() + summary_dict = save_result['val'].to_dict() + summary_dict['epoch'] = epoch + summary_dict['step'] = step + summary_dict['n_images_seen'] = n_images_seen + summary_dict['trainer/global_step'] = step + summary_dict['trainer/epoch'] = epoch + return mean, summary_dict + + +class IsBestTracker(): + + def __init__(self, 
class IsBestTracker():
    """Remember the best metric seen so far and whether the latest one beat it."""

    def __init__(self, fabric):
        """Start with no history: the first metric above -1 counts as best."""
        self._is_best = True
        self.prev_best_metric = -1
        self.fabric = fabric

    def set_is_best(self, metric):
        """Record ``metric`` and update the is-best flag on every process.

        The value held by process 0 is broadcast so that all processes reach
        the same decision.

        Args:
            metric: Scalar accepted by ``torch.tensor``.
        """
        metric_tensor = torch.tensor(metric, device=self.fabric.device)
        self.fabric.barrier()
        # broadcast() returns the received object rather than filling its
        # argument, so the result has to be kept; discarding it left every
        # rank comparing its own local metric.
        metric_tensor = self.fabric.broadcast(metric_tensor, 0)
        metric = metric_tensor.item()

        if metric > self.prev_best_metric:
            self.prev_best_metric = metric
            self._is_best = True
        else:
            self._is_best = False

    def is_best(self):
        """Return whether the most recent metric was a strict improvement."""
        return self._is_best


def preprocess_transform(examples, image_transforms):
    """Convert a batch of images to RGB and store the transformed images.

    Args:
        examples: Mapping with an ``'image'`` entry holding objects that
            expose ``convert("RGB")``.
        image_transforms: Callable applied to every converted image.

    Returns:
        The same ``examples`` mapping with a new ``"pixel_values"`` entry.
    """
    images = [image.convert("RGB") for image in examples['image']]
    examples["pixel_values"] = [image_transforms(image) for image in images]
    return examples


def collate_fn(examples):
    """Stack per-example fields into batch tensors.

    Args:
        examples: Sequence of mappings with ``"pixel_values"`` (tensor),
            ``"index"`` and ``"is_same"`` entries.

    Returns:
        Dict with ``"pixel_values"`` (contiguous float tensor), ``"index"``
        (``torch.int``) and ``"is_same"`` (``torch.bool``).
    """
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    indexes = torch.tensor([example["index"] for example in examples], dtype=torch.int)
    is_sames = torch.tensor([example["is_same"] for example in examples], dtype=torch.bool)

    return {
        "pixel_values": pixel_values,
        "index": indexes,
        "is_same": is_sames,
    }
def repeat_tensor_along_dim(tensor, dim, repeats):
    """Tile ``tensor`` ``repeats`` times along ``dim``; other dims are untouched."""
    # Create the shape for repeating using list comprehension.
    # For all dimensions other than the specified 'dim', it will just have a '1' (i.e., no repeat).
    repeat_shape = [repeats if i == dim else 1 for i in range(tensor.dim())]
    return tensor.repeat(*repeat_shape)


def flatten_first_two_dims(tensor):
    """Merge the first two dimensions of a tensor: ``(A, B, ...) -> (A*B, ...)``."""
    flattened_shape = [-1] + list(tensor.shape[2:])
    return tensor.reshape(*flattened_shape)


def flatten_first_two_dims_numpy(array):
    """NumPy counterpart of :func:`flatten_first_two_dims`."""
    flattened_shape = (-1,) + array.shape[2:]
    return array.reshape(*flattened_shape)


def first_unique_index(array):
    """Return, for each unique entry (sorted), the position of its first occurrence.

    Args:
        array: Tensor whose entries along dim 0 are compared.

    Returns:
        1-D index tensor with one entry per unique value.
    """
    unique, idx, counts = torch.unique(array, dim=0, sorted=True, return_inverse=True, return_counts=True)
    # Stable sort groups equal values while keeping their original order,
    # so the first member of each group is the first occurrence.
    _, ind_sorted = torch.sort(idx, stable=True)
    cum_sum = counts.cumsum(0)
    # Shift right by one to obtain the start offset of each group.
    cum_sum = torch.cat((torch.tensor([0], device=cum_sum.device, dtype=cum_sum.dtype), cum_sum[:-1]))
    first_indicies = ind_sorted[cum_sum]
    return first_indicies


class BaseEvaluator():
    """Common plumbing for evaluators: logging, batch padding, cross-process gather."""

    def __init__(self, name, fabric, batch_size):
        self.name = name
        self.fabric = fabric
        self.batch_size = batch_size

    def integrity_check(self, eval_color_space, pipeline_color_space):
        """Validate dataset/pipeline compatibility (subclass hook)."""
        raise NotImplementedError('integrity_check method must be implemented in subclass')

    def extract(self, pipeline):
        """Run the pipeline over the evaluation data (subclass hook)."""
        raise NotImplementedError('extract method must be implemented in subclass')

    def compute_metric(self, gathered_collection):
        """Turn gathered outputs into metrics (subclass hook)."""
        raise NotImplementedError('compute_metric method must be implemented in subclass')

    def evaluate(self, pipeline, epoch=0, step=0, n_images_seen=0):
        """Full evaluation entry point (subclass hook)."""
        raise NotImplementedError('evaluate method must be implemented in subclass')

    def log(self, result, epoch, step, n_images_seen):
        """Log ``result`` under ``val/<evaluator name>/`` together with progress counters."""
        # append the name of the evaluation to the key
        log_result = {f'val/{self.name}/{k}': v for k, v in result.items()}
        log_result['epoch'] = epoch
        log_result['step'] = step
        log_result['n_images_seen'] = n_images_seen
        # log
        self.fabric.log_dict(log_result)

    def complete_batch(self, batch):
        """Pad a short batch to ``self.batch_size`` by repeating its last element.

        Tensor and list entries are padded; other types are left as they are.
        The padding duplicates are what ``remove_duplicates`` later drops.
        """
        for key in batch.keys():
            if len(batch[key]) != self.batch_size:
                num_missing = self.batch_size - len(batch[key])
                if isinstance(batch[key], torch.Tensor):
                    last_example = batch[key][-1].unsqueeze(0)
                    additional_examples = repeat_tensor_along_dim(last_example, dim=0, repeats=num_missing)
                    batch[key] = torch.cat([batch[key], additional_examples], dim=0)
                elif isinstance(batch[key], list):
                    batch[key] = batch[key] + [batch[key][-1]] * num_missing
        return batch

    def gather_collection(self, method='cpu', per_gpu_collection=None):
        """Merge the per-process result dicts into one de-duplicated collection.

        Args:
            method: ``'cpu'`` exchanges results through files in a cache
                directory under ``~/.cache``; any other value uses
                ``fabric.all_gather``.
            per_gpu_collection: Dict of this process's results; must contain
                an ``'index'`` entry. Defaults to an empty dict.

        Returns:
            With ``method='cpu'``: the merged dict on local rank 0 and
            ``None`` elsewhere. Otherwise the merged dict on every process.
        """
        # A fresh dict per call instead of a shared mutable default.
        if per_gpu_collection is None:
            per_gpu_collection = {}

        if method == 'cpu':
            import shutil

            runname = os.getcwd().split('/')[-1]
            if hasattr(self.fabric, 'loggers') and len(self.fabric.loggers) > 0 and hasattr(self.fabric.loggers[0], 'root_dir'):
                runname = runname + '_' + self.fabric.loggers[0].root_dir.split('/')[-1]

            cache_dir = os.path.join(os.path.expanduser('~/.cache'), runname, 'temporary_cpu_communication')
            os.makedirs(cache_dir, exist_ok=True)
            torch.save(per_gpu_collection, os.path.join(cache_dir, f'per_gpu_collection_rank{self.fabric.local_rank}.pt'))
            # Every process must have written its file before rank 0 reads.
            self.fabric.barrier()

            if self.fabric.local_rank == 0:
                # load per gpu collection from cache
                gathered_collection = [
                    torch.load(os.path.join(cache_dir, f'per_gpu_collection_rank{rank}.pt'))
                    for rank in range(self.fabric.world_size)
                ]
                collection = {}
                for key in gathered_collection[0].keys():
                    concat = [rank_collection[key] for rank_collection in gathered_collection]
                    # Interleave the ranks: (rank, item) -> (item, rank) -> flat.
                    if isinstance(concat[0], list):
                        stacked = np.array(concat).transpose(1, 0)
                        stacked = flatten_first_two_dims_numpy(stacked)
                    else:
                        assert isinstance(concat[0], torch.Tensor)
                        stacked = torch.stack(concat, dim=1)
                        stacked = flatten_first_two_dims(stacked)
                    collection[key] = stacked
                collection = self.remove_duplicates(collection)
                self.check_index_order(collection)
                # erase cache_dir (no shell involved, so odd characters in the
                # path cannot be misinterpreted)
                shutil.rmtree(cache_dir, ignore_errors=True)
            else:
                collection = None
            self.fabric.barrier()
        else:
            # gpu based gathering
            gathered_collection = self.fabric.all_gather(per_gpu_collection)
            collection = self.flatten_collection(gathered_collection)
            collection = self.remove_duplicates(collection)
            self.check_index_order(collection)
        return collection

    def flatten_collection(self, gathered_collection):
        """Order the gathered per-process chunks by smallest index and flatten them.

        ``gathered_collection`` is modified in place while being reordered.
        """
        # flatten collection by sorting by index
        # gathered_collection['index'] = torch.tensor([[2,3],[0,1],[4,5]])
        collection_order = torch.argsort(gathered_collection['index'].min(dim=1)[0])
        for key, val in gathered_collection.items():
            gathered_collection[key] = val[collection_order].transpose(0, 1)
        return {k: flatten_first_two_dims(v) for k, v in gathered_collection.items()}

    def remove_duplicates(self, collection):
        """Keep only the first row for every distinct ``collection['index']`` value."""
        # find duplicate index and drop except the first one
        unique_idx = first_unique_index(collection['index'])
        for key, val in collection.items():
            collection[key] = val[unique_idx]
        return collection

    def check_index_order(self, collection):
        """Assert that ``collection['index']`` is exactly ``0, 1, ..., N-1``."""
        index_to_check = collection['index'].to(self.fabric.device)
        assert (index_to_check == torch.arange(index_to_check.shape[0],
                                               dtype=index_to_check.dtype,
                                               device=index_to_check.device)).all()

    def is_debug_run(self):
        """Return True only if ``fabric.cfg.trainers.debug`` exists and is truthy."""
        try:
            return bool(self.fabric.cfg.trainers.debug)
        except Exception:
            # Best effort: a missing or malformed config means "not debug".
            # Unlike a bare except, KeyboardInterrupt/SystemExit still propagate.
            return False
+per_epoch_evaluations: { + "agedb_30": { + 'path': 'facerec_val/agedb_30_1to1', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, +} diff --git a/cvlface/research/recognition/code/run_v1/evaluations/configs/base.yaml b/cvlface/research/recognition/code/run_v1/evaluations/configs/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..9ca45cc4bd7c6a673c8624b8f854fb74c8bb2f5e --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/configs/base.yaml @@ -0,0 +1,40 @@ +data_root: ${oc.env:DATA_ROOT} +eval_every_n_epochs: 1 +per_epoch_evaluations: { + "lfw": { + 'path': 'facerec_val/lfw', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "agedb_30": { + 'path': 'facerec_val/agedb_30', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "cfpfp": { + 'path': 'facerec_val/cfp_fp', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "cplfw": { + 'path': 'facerec_val/cplfw', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "calfw": { + 'path': 'facerec_val/calfw', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, +} + diff --git a/cvlface/research/recognition/code/run_v1/evaluations/configs/final.yaml b/cvlface/research/recognition/code/run_v1/evaluations/configs/final.yaml new file mode 100644 index 0000000000000000000000000000000000000000..88b415d9a600a4f4b899125d48eea0e360858e4e --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/configs/final.yaml @@ -0,0 +1,11 @@ +data_root: ${oc.env:AGEDB_PROTOCOL_ROOT} +eval_every_n_epochs: 1 +per_epoch_evaluations: { + "agedb_30": { + 'path': 'facerec_val/agedb_30_1to1', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 
'batch_size': 32, + 'num_workers': 4 + }, +} diff --git a/cvlface/research/recognition/code/run_v1/evaluations/configs/full.yaml b/cvlface/research/recognition/code/run_v1/evaluations/configs/full.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ffe80c853b14824ad6ddc2e1897bef55828c7322 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/configs/full.yaml @@ -0,0 +1,61 @@ +data_root: ${oc.env:DATA_ROOT} +eval_every_n_epochs: 1 +per_epoch_evaluations: { + "lfw": { + 'path': 'facerec_val/lfw', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "agedb_30": { + 'path': 'facerec_val/agedb_30', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "cfpfp": { + 'path': 'facerec_val/cfp_fp', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "cplfw": { + 'path': 'facerec_val/cplfw', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "calfw": { + 'path': 'facerec_val/calfw', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "IJBB_gt_aligned": { + 'path': 'facerec_val/IJBB_gt_aligned', + 'evaluation_type': 'ijbbc', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "IJBC_gt_aligned": { + 'path': 'facerec_val/IJBC_gt_aligned', + 'evaluation_type': 'ijbbc', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, + "tinyface": { + 'path': 'facerec_val/tinyface_aligned_pad_0.1', + 'evaluation_type': 'tinyface', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, +} + diff --git a/cvlface/research/recognition/code/run_v1/evaluations/configs/quick.yaml b/cvlface/research/recognition/code/run_v1/evaluations/configs/quick.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..e6495b89c5b490d107ae62ba1607134a3c95d33b --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/configs/quick.yaml @@ -0,0 +1,12 @@ +data_root: ${oc.env:DATA_ROOT} +eval_every_n_epochs: 1 +per_epoch_evaluations: { + "lfw": { + 'path': 'facerec_val/lfw', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, +} + diff --git a/cvlface/research/recognition/code/run_v1/evaluations/configs/skip_eval.yaml b/cvlface/research/recognition/code/run_v1/evaluations/configs/skip_eval.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bb7eabb8e42d0efccce1eebf07a7d028485922ce --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/configs/skip_eval.yaml @@ -0,0 +1,5 @@ +data_root: ${oc.env:DATA_ROOT} +eval_every_n_epochs: -1 +per_epoch_evaluations: { +} + diff --git a/cvlface/research/recognition/code/run_v1/evaluations/ijbbc/evaluate.py b/cvlface/research/recognition/code/run_v1/evaluations/ijbbc/evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..bc79c18620fd4cb1443855ec5032a157ed48bf0e --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/ijbbc/evaluate.py @@ -0,0 +1,104 @@ +import numpy as np +from tqdm import tqdm +import sklearn +from sklearn.metrics import roc_curve + + +def image2template_feature(img_feats=None, templates=None, medias=None, dummy=False): + # ========================================================== + # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim] + # 2. compute media feature. + # 3. compute template feature. 
+ # ========================================================== + unique_templates = np.unique(templates) + template_feats = np.zeros((len(unique_templates), img_feats.shape[1])) + + if dummy: + template_feats = np.random.randn(len(unique_templates), img_feats.shape[1]) + template_norm_feats = sklearn.preprocessing.normalize(template_feats) + return template_norm_feats, unique_templates + + for count_template, uqt in tqdm(enumerate(unique_templates), total=len(unique_templates), desc='image2template_feature'): + + (ind_t,) = np.where(templates == uqt) + face_norm_feats = img_feats[ind_t] + face_medias = medias[ind_t] + unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True) + media_norm_feats = [] + for u, ct in zip(unique_medias, unique_media_counts): + (ind_m,) = np.where(face_medias == u) + if ct == 1: + media_norm_feats += [face_norm_feats[ind_m]] + else: # image features from the same video will be aggregated into one feature + media_norm_feats += [ + np.mean(face_norm_feats[ind_m], axis=0, keepdims=True) + ] + media_norm_feats = np.array(media_norm_feats) + # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True)) + template_feats[count_template] = np.sum(media_norm_feats, axis=0) + # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True)) + template_norm_feats = sklearn.preprocessing.normalize(template_feats) + # print(template_norm_feats.shape) + return template_norm_feats, unique_templates + + +def verification(template_norm_feats=None, unique_templates=None, p1=None, p2=None): + # ========================================================== + # Compute set-to-set Similarity Score. 
+ # ========================================================== + template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int) + for count_template, uqt in enumerate(unique_templates): + template2id[uqt] = count_template + + score = np.zeros((len(p1),)) # save cosine distance between pairs + + total_pairs = np.array(range(len(p1))) + batchsize = 100000 # small batchsize instead of all pairs in one batch due to the memory limiation + sublists = [ total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize) ] + for c, s in tqdm(enumerate(sublists), total=len(sublists), desc='verification'): + feat1 = template_norm_feats[template2id[p1[s]]] + feat2 = template_norm_feats[template2id[p2[s]]] + similarity_score = np.sum(feat1 * feat2, -1) + score[s] = similarity_score.flatten() + return score + + + +def evaluate(embeddings, faceness_scores, templates, medias, label, p1, p2, dummy=False): + + infernece_configs = [{'use_norm_score': True, 'use_detector_score': True}, + {'use_norm_score': True, 'use_detector_score': False}, + {'use_norm_score': False, 'use_detector_score': True}, ] + + scores = {} + for config in infernece_configs: + use_norm_score = config['use_norm_score'] + use_detector_score = config['use_detector_score'] + + img_input_feats = embeddings.copy() + if not use_norm_score: + # normalise features to remove norm information + img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True)) + if use_detector_score: + img_input_feats = img_input_feats * faceness_scores[:, np.newaxis] + + template_norm_feats, unique_templates = image2template_feature(img_input_feats, templates, medias, dummy=dummy) + score = verification(template_norm_feats, unique_templates, p1, p2) + method = f"Norm:{use_norm_score}_Det:{use_detector_score}" + scores[method] = score + + x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1] + result = {} + for method in scores.keys(): + fpr, tpr, thresholds = roc_curve(label, 
scores[method]) + fpr = np.flipud(fpr) + tpr = np.flipud(tpr) # select largest tpr at same fpr + thresholds = np.flipud(thresholds) + for fpr_iter in np.arange(len(x_labels)): + _, min_index = min(list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr))))) + best_thresh = thresholds[min_index] + _fpr_val = x_labels[fpr_iter] + _tpr_val = tpr[min_index] * 100 + result[f'{method}_tpr_at_fpr_{_fpr_val}'] = _tpr_val + result[f'{method}_thresh_at_fpr_{_fpr_val}'] = best_thresh + return result diff --git a/cvlface/research/recognition/code/run_v1/evaluations/ijbbc_evaluator.py b/cvlface/research/recognition/code/run_v1/evaluations/ijbbc_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..f9936ec7b98b29fec54974c4b4658b092d2ffcd3 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/ijbbc_evaluator.py @@ -0,0 +1,122 @@ +from datasets import Dataset +import torch +from functools import partial +from .base_evaluator import BaseEvaluator +from tqdm import tqdm +from .ijbbc.evaluate import evaluate +import os + +def preprocess_transform(examples, image_transforms): + images = [image.convert("RGB") for image in examples['image']] + images = [image_transforms(image) for image in images] + examples["pixel_values"] = images + return examples + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + indexes = torch.tensor([example["index"] for example in examples], dtype=torch.int) + + return { + "pixel_values": pixel_values, + "index": indexes, + } + + +class IJBBCEvaluator(BaseEvaluator): + def __init__(self, name, data_path, transform, fabric, batch_size, num_workers): + super().__init__(name, fabric, batch_size) + self.name = name + self.data_path = data_path + dataset = Dataset.load_from_disk(data_path) + preprocess = partial(preprocess_transform, image_transforms=transform) + dataset = 
dataset.with_transform(preprocess) + self.dataloader = fabric.setup_dataloader_from_dataset(dataset, + is_train=False, + batch_size=batch_size, + num_workers=num_workers, + collate_fn=collate_fn) + self.meta = torch.load(os.path.join(data_path, 'metadata.pt')) + + + def integrity_check(self, eval_color_space, pipeline_color_space): + assert eval_color_space == pipeline_color_space + + + @torch.no_grad() + def evaluate(self, pipeline, epoch=0, step=0, n_images_seen=0): + pipeline.eval() + collection = self.extract(pipeline) + collection_flip = self.extract(pipeline, flip_images=True) + if self.fabric.local_rank == 0: + result = self.compute_metric(collection, collection_flip) + self.log(result, epoch, step, n_images_seen) + else: + result = {} + return result + + def extract(self, pipeline, flip_images=False): + all_features = [] + all_index = [] + for batch_idx, batch in tqdm(enumerate(self.dataloader), total=len(self.dataloader), desc='IJB Feature', + disable=self.fabric.local_rank != 0): + batch = self.complete_batch(batch) # needed for last batch to be gather compatible + + if self.is_debug_run(): + if batch_idx > 10: + break + + images = batch['pixel_values'] + index = batch['index'] + + if flip_images: + images = torch.flip(images, dims=[3]) + features = pipeline(images) + all_features.append(features.cpu().detach()) + all_index.append(index.cpu().detach()) + + # aggregate across all gpus + per_gpu_collection = {"index": torch.cat(all_index, dim=0), + 'features': torch.cat(all_features, dim=0)} + + # cpu based gathering just in case we have a lot of data + collection = self.gather_collection(method='cpu', per_gpu_collection=per_gpu_collection) + return collection + + + def compute_metric(self, collection, collection_flip): + if self.is_debug_run(): + return dummy_result + + embeddings = (collection['features'] + collection_flip['features']).numpy() + + faceness_scores = self.meta['faceness_scores'] + templates = self.meta['templates'] + medias = 
self.meta['medias'] + label = self.meta['label'] + p1 = self.meta['p1'] + p2 = self.meta['p2'] + result = evaluate(embeddings, faceness_scores, templates, medias, label, p1, p2, dummy=False) + return result + + +dummy_result = {k: 0.0 for k in ['Norm:True_Det:True_tpr_at_fpr_1e-06', 'Norm:True_Det:True_thresh_at_fpr_1e-06', + 'Norm:True_Det:True_tpr_at_fpr_1e-05', 'Norm:True_Det:True_thresh_at_fpr_1e-05', + 'Norm:True_Det:True_tpr_at_fpr_0.0001', 'Norm:True_Det:True_thresh_at_fpr_0.0001', + 'Norm:True_Det:True_tpr_at_fpr_0.001', 'Norm:True_Det:True_thresh_at_fpr_0.001', + 'Norm:True_Det:True_tpr_at_fpr_0.01', 'Norm:True_Det:True_thresh_at_fpr_0.01', + 'Norm:True_Det:True_tpr_at_fpr_0.1', 'Norm:True_Det:True_thresh_at_fpr_0.1', + 'Norm:True_Det:False_tpr_at_fpr_1e-06', 'Norm:True_Det:False_thresh_at_fpr_1e-06', + 'Norm:True_Det:False_tpr_at_fpr_1e-05', 'Norm:True_Det:False_thresh_at_fpr_1e-05', + 'Norm:True_Det:False_tpr_at_fpr_0.0001', 'Norm:True_Det:False_thresh_at_fpr_0.0001', + 'Norm:True_Det:False_tpr_at_fpr_0.001', 'Norm:True_Det:False_thresh_at_fpr_0.001', + 'Norm:True_Det:False_tpr_at_fpr_0.01', 'Norm:True_Det:False_thresh_at_fpr_0.01', + 'Norm:True_Det:False_tpr_at_fpr_0.1', 'Norm:True_Det:False_thresh_at_fpr_0.1', + 'Norm:False_Det:True_tpr_at_fpr_1e-06', 'Norm:False_Det:True_thresh_at_fpr_1e-06', + 'Norm:False_Det:True_tpr_at_fpr_1e-05', 'Norm:False_Det:True_thresh_at_fpr_1e-05', + 'Norm:False_Det:True_tpr_at_fpr_0.0001', 'Norm:False_Det:True_thresh_at_fpr_0.0001', + 'Norm:False_Det:True_tpr_at_fpr_0.001', 'Norm:False_Det:True_thresh_at_fpr_0.001', + 'Norm:False_Det:True_tpr_at_fpr_0.01', 'Norm:False_Det:True_thresh_at_fpr_0.01', + 'Norm:False_Det:True_tpr_at_fpr_0.1', 'Norm:False_Det:True_thresh_at_fpr_0.1', ]} diff --git a/cvlface/research/recognition/code/run_v1/evaluations/tinyface/evaluate.py b/cvlface/research/recognition/code/run_v1/evaluations/tinyface/evaluate.py new file mode 100644 index 
0000000000000000000000000000000000000000..f2794499f48e7bb3c67b81557c6e1958de3769d1 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/tinyface/evaluate.py @@ -0,0 +1,71 @@ +import os +import numpy as np +from .metrics import DIR_FAR + + +def evaluate( + all_features, + image_paths, + meta, + ranks=[1, 5, 20] +): + evaluator = TinyFaceTest(meta) + results = evaluator.test_identification(all_features, image_paths, ranks) + results = {k: v for k, v in zip(['rank-{}'.format(r) for r in ranks], results)} + results = {k: v * 100 for k, v in results.items()} + return results + + +class TinyFaceTest: + def __init__(self, meta): + self.meta = meta + + def get_key(self, image_path): + return os.path.splitext(os.path.basename(image_path))[0] + + def get_label(self, image_path): + return int(os.path.basename(image_path).split('_')[0]) + + def init_proto(self, image_paths, probe_paths, match_paths, distractor_paths): + index_dict = {} + for i, image_path in enumerate(image_paths): + index_dict[self.get_key(image_path)] = i + + self.indices_probe = np.array([index_dict[self.get_key(img)] for img in probe_paths]) + self.indices_match = np.array([index_dict[self.get_key(img)] for img in match_paths]) + self.indices_distractor = np.array([index_dict[self.get_key(img)] for img in distractor_paths]) + + self.labels_probe = np.array([self.get_label(img) for img in probe_paths]) + self.labels_match = np.array([self.get_label(img) for img in match_paths]) + self.labels_distractor = np.array([-100 for img in distractor_paths]) + + self.indices_gallery = np.concatenate([self.indices_match, self.indices_distractor]) + self.labels_gallery = np.concatenate([self.labels_match, self.labels_distractor]) + + def test_identification(self, features, image_paths, ranks=[1, 5, 20]): + assert len(image_paths) == len(features) + assert len(image_paths) == len(self.meta['image_paths']) + self.init_proto(image_paths, + self.meta['probe_paths'], + self.meta['gallery_paths'], + 
self.meta['distractor_paths']) + + feat_probe = features[self.indices_probe] + feat_gallery = features[self.indices_gallery] + compare_func = inner_product + score_mat = compare_func(feat_probe, feat_gallery) + + label_mat = self.labels_probe[:, None] == self.labels_gallery[None, :] + + results, _, __ = DIR_FAR(score_mat, label_mat, ranks) + + return results + + +def inner_product(x1, x2): + + # normalize + x1 = x1 / np.linalg.norm(x1, axis=1, keepdims=True) + x2 = x2 / np.linalg.norm(x2, axis=1, keepdims=True) + + return np.dot(x1, x2.T) diff --git a/cvlface/research/recognition/code/run_v1/evaluations/tinyface/metrics.py b/cvlface/research/recognition/code/run_v1/evaluations/tinyface/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..cb12ced26fe8d7442c8e6fcda1da8c63b8f84c48 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/tinyface/metrics.py @@ -0,0 +1,251 @@ +"""Common metrics used for evaluate results +""" +# MIT License +# +# Copyright (c) 2017 Yichun Shi +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import numpy as np +from warnings import warn + +# Find thresholds given FARs +# but the real FARs using these thresholds could be different +# the exact FARs need to recomputed using calcROC +def find_thresholds_by_FAR(score_vec, label_vec, FARs=None, epsilon=1e-5): + assert len(score_vec.shape)==1 + assert score_vec.shape == label_vec.shape + assert label_vec.dtype == np.bool + score_neg = score_vec[~label_vec] + score_neg[::-1].sort() + # score_neg = np.sort(score_neg)[::-1] # score from high to low + num_neg = len(score_neg) + + assert num_neg >= 1 + + if FARs is None: + thresholds = np.unique(score_neg) + thresholds = np.insert(thresholds, 0, thresholds[0]+epsilon) + thresholds = np.insert(thresholds, thresholds.size, thresholds[-1]-epsilon) + else: + FARs = np.array(FARs) + num_false_alarms = np.round(num_neg * FARs).astype(np.int32) + + thresholds = [] + for num_false_alarm in num_false_alarms: + if num_false_alarm==0: + threshold = score_neg[0] + epsilon + else: + threshold = score_neg[num_false_alarm-1] + thresholds.append(threshold) + thresholds = np.array(thresholds) + + return thresholds + + + +def ROC(score_vec, label_vec, thresholds=None, FARs=None, get_false_indices=False): + ''' Compute Receiver operating characteristic (ROC) with a score and label vector. + ''' + assert score_vec.ndim == 1 + assert score_vec.shape == label_vec.shape + assert label_vec.dtype == np.bool + + if thresholds is None: + thresholds = find_thresholds_by_FAR(score_vec, label_vec, FARs=FARs) + + assert len(thresholds.shape)==1 + if np.size(thresholds) > 10000: + warn('number of thresholds (%d) very large, computation may take a long time!' 
% np.size(thresholds)) + + # FARs would be check again + TARs = np.zeros(thresholds.shape[0]) + FARs = np.zeros(thresholds.shape[0]) + false_accept_indices = [] + false_reject_indices = [] + for i,threshold in enumerate(thresholds): + accept = score_vec >= threshold + TARs[i] = np.mean(accept[label_vec]) + FARs[i] = np.mean(accept[~label_vec]) + if get_false_indices: + false_accept_indices.append(np.argwhere(accept & (~label_vec)).flatten()) + false_reject_indices.append(np.argwhere((~accept) & label_vec).flatten()) + + if get_false_indices: + return TARs, FARs, thresholds, false_accept_indices, false_reject_indices + else: + return TARs, FARs, thresholds + +def ROC_by_mat(score_mat, label_mat, thresholds=None, FARs=None, get_false_indices=False, triu_k=None): + ''' Compute ROC using a pairwise score matrix and a corresponding label matrix. + A wapper of ROC function. + ''' + assert score_mat.ndim == 2 + assert score_mat.shape == label_mat.shape + assert label_mat.dtype == np.bool + + # Convert into vectors + m,n = score_mat.shape + if triu_k is not None: + assert m==n, "If using triu for ROC, the score matrix must be a sqaure matrix!" 
+ triu_indices = np.triu_indices(m, triu_k) + score_vec = score_mat[triu_indices] + label_vec = label_mat[triu_indices] + else: + score_vec = score_mat.flatten() + label_vec = label_mat.flatten() + + # Compute ROC + if get_false_indices: + TARs, FARs, thresholds, false_accept_indices, false_reject_indices = \ + ROC(score_vec, label_vec, thresholds, FARs, True) + else: + TARs, FARs, thresholds = ROC(score_vec, label_vec, thresholds, FARs, False) + + # Convert false accept/reject indices into [row, col] indices + if get_false_indices: + rows, cols = np.meshgrid(np.arange(m), np.arange(n), indexing='ij') + rc = np.stack([rows, cols], axis=2) + if triu_k is not None: + rc = rc[triu_indices,:] + else: + rc = rc.reshape([-1,2]) + + for i in range(len(FARs)): + false_accept_indices[i] = rc[false_accept_indices[i]] + false_reject_indices[i] = rc[false_reject_indices[i]] + return TARs, FARs, thresholds, false_accept_indices, false_reject_indices + else: + return TARs, FARs, thresholds + + +def DIR_FAR(score_mat, label_mat, ranks=[1], FARs=[1.0], get_false_indices=False): + ''' Closed/Open-set Identification. + A general case of Cummulative Match Characteristic (CMC) + where thresholding is allowed for open-set identification. + args: + score_mat: a P x G matrix, P is number of probes, G is size of gallery + label_mat: a P x G matrix, bool + ranks: a list of integers + FARs: false alarm rates, if 1.0, closed-set identification (CMC) + get_false_indices: not implemented yet + return: + DIRs: an F x R matrix, F is the number of FARs, R is the number of ranks, + flatten into a vector if F=1 or R=1. + FARs: an vector of length = F. + thredholds: an vector of length = F. 
+ ''' + assert score_mat.shape==label_mat.shape + # assert np.all(label_mat.astype(np.float32).sum(axis=1) <=1 ) + # Split the matrix for match probes and non-match probes + # subfix _m: match, _nm: non-match + # For closed set, we only use the match probes + match_indices = label_mat.astype(np.bool).any(axis=1) + score_mat_m = score_mat[match_indices,:] + label_mat_m = label_mat[match_indices,:] + score_mat_nm = score_mat[np.logical_not(match_indices),:] + label_mat_nm = label_mat[np.logical_not(match_indices),:] + + print('mate probes: %d, non mate probes: %d' % (score_mat_m.shape[0], score_mat_nm.shape[0])) + + # Find the thresholds for different FARs + max_score_nm = np.max(score_mat_nm, axis=1) + label_temp = np.zeros(max_score_nm.shape, dtype=np.bool) + if len(FARs) == 1 and FARs[0] >= 1.0: + # If only testing closed-set identification, use the minimum score as threshold + # in case there is no non-mate probes + thresholds = [np.min(score_mat) - 1e-10] + openset = False + else: + # If there is open-set identification, find the thresholds by FARs. + assert score_mat_nm.shape[0] > 0, "For open-set identification (FAR<1.0), there should be at least one non-mate probe!" 
+ thresholds = find_thresholds_by_FAR(max_score_nm, label_temp, FARs=FARs) + openset = True + + # Sort the labels row by row according to scores + sort_idx_mat_m = np.argsort(score_mat_m, axis=1) + sorted_label_mat_m = np.ndarray(label_mat_m.shape, dtype=np.bool) + for row in range(label_mat_m.shape[0]): + sort_idx = (sort_idx_mat_m[row, :])[::-1] + sorted_label_mat_m[row,:] = label_mat_m[row, sort_idx] + + # Calculate DIRs for different FARs and ranks + if openset: + gt_score_m = score_mat_m[label_mat_m] + assert gt_score_m.size == score_mat_m.shape[0] + + DIRs = np.zeros([len(FARs), len(ranks)], dtype=np.float32) + FARs = np.zeros([len(FARs)], dtype=np.float32) + if get_false_indices: + false_retrieval = np.zeros([len(FARs), len(ranks), score_mat_m.shape[0]], dtype=np.bool) + false_reject = np.zeros([len(FARs), len(ranks), score_mat_m.shape[0]], dtype=np.bool) + false_accept = np.zeros([len(FARs), len(ranks), score_mat_nm.shape[0]], dtype=np.bool) + for i, threshold in enumerate(thresholds): + for j, rank in enumerate(ranks): + success_retrieval = sorted_label_mat_m[:,0:rank].any(axis=1) + if openset: + success_threshold = gt_score_m >= threshold + DIRs[i,j] = (success_threshold & success_retrieval).astype(np.float32).mean() + else: + DIRs[i,j] = success_retrieval.astype(np.float32).mean() + if get_false_indices: + false_retrieval[i,j] = ~success_retrieval + false_accept[i,j] = score_mat_nm.max(1) >= threshold + if openset: + false_reject[i,j] = ~success_threshold + if score_mat_nm.shape[0] > 0: + FARs[i] = (max_score_nm >= threshold).astype(np.float32).mean() + + if DIRs.shape[0] == 1 or DIRs.shape[1] == 1: + DIRs = DIRs.flatten() + + if get_false_indices: + return DIRs, FARs, thresholds, match_indices, false_retrieval, false_reject, false_accept, sort_idx_mat_m + else: + return DIRs, FARs, thresholds + +def accuracy(score_vec, label_vec, thresholds=None): + assert len(score_vec.shape)==1 + assert len(label_vec.shape)==1 + assert score_vec.shape == 
label_vec.shape + assert label_vec.dtype==np.bool + # find thresholds by TAR + if thresholds is None: + score_pos = score_vec[label_vec==True] + thresholds = np.sort(score_pos)[::1] + + assert len(thresholds.shape)==1 + if np.size(thresholds) > 10000: + warn('number of thresholds (%d) very large, computation may take a long time!' % np.size(thresholds)) + + # Loop Computation + accuracies = np.zeros(np.size(thresholds)) + for i, threshold in enumerate(thresholds): + pred_vec = score_vec>=threshold + accuracies[i] = np.mean(pred_vec==label_vec) + + # Matrix Computation, Each column is a threshold + # predictions = score_vec[:,None] >= thresholds[None,:] + # accuracies = np.mean(predictions==label_vec[:,None], axis=0) + + argmax = np.argmax(accuracies) + accuracy = accuracies[argmax] + threshold = np.mean(thresholds[accuracies==accuracy]) + + return accuracy, threshold diff --git a/cvlface/research/recognition/code/run_v1/evaluations/tinyface_evaluator.py b/cvlface/research/recognition/code/run_v1/evaluations/tinyface_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..f147b751103315e15aa928c08a117ae739d289af --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/tinyface_evaluator.py @@ -0,0 +1,108 @@ +from datasets import Dataset +import torch +from functools import partial +from .base_evaluator import BaseEvaluator +from tqdm import tqdm +import os +from .tinyface.evaluate import evaluate + +def preprocess_transform(examples, image_transforms): + images = [image.convert("RGB") for image in examples['image']] + images = [image_transforms(image) for image in images] + examples["pixel_values"] = images + return examples + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + indexes = torch.tensor([example["index"] for example in examples], dtype=torch.int) + image_paths = 
[example["path"] for example in examples] + + return { + "pixel_values": pixel_values, + "index": indexes, + "image_paths": image_paths + } + + +class TinyFaceEvaluator(BaseEvaluator): + def __init__(self, name, data_path, transform, fabric, batch_size, num_workers): + super().__init__(name, fabric, batch_size) + self.name = name + self.data_path = data_path + dataset = Dataset.load_from_disk(data_path) + preprocess = partial(preprocess_transform, image_transforms=transform) + dataset = dataset.with_transform(preprocess) + self.dataloader = fabric.setup_dataloader_from_dataset(dataset, + is_train=False, + batch_size=batch_size, + num_workers=num_workers, + collate_fn=collate_fn) + self.meta = torch.load(os.path.join(data_path, 'metadata.pt')) + + + def integrity_check(self, eval_color_space, pipeline_color_space): + assert eval_color_space == pipeline_color_space + + + @torch.no_grad() + def evaluate(self, pipeline, epoch=0, step=0, n_images_seen=0): + pipeline.eval() + collection = self.extract(pipeline) + collection_flip = self.extract(pipeline, flip_images=True) + if self.fabric.local_rank == 0: + result = self.compute_metric(collection, collection_flip) + self.log(result, epoch, step, n_images_seen) + else: + result = {} + return result + + def extract(self, pipeline, flip_images=False): + all_features = [] + all_index = [] + all_image_paths = [] + for batch_idx, batch in tqdm(enumerate(self.dataloader), total=len(self.dataloader), desc='TinyFace Feature', + disable=self.fabric.local_rank != 0): + batch = self.complete_batch(batch) # needed for last batch to be gather compatible + + if self.is_debug_run(): + if batch_idx > 10: + break + + images = batch['pixel_values'] + index = batch['index'] + image_paths = batch['image_paths'] + + if flip_images: + images = torch.flip(images, dims=[3]) + features = pipeline(images) + all_features.append(features.cpu().detach()) + all_index.append(index.cpu().detach()) + all_image_paths.extend(image_paths) + + # aggregate 
across all gpus + per_gpu_collection = {"index": torch.cat(all_index, dim=0), + 'features': torch.cat(all_features, dim=0), + 'image_paths': all_image_paths} + + # cpu based gathering just in case we have a lot of data + collection = self.gather_collection(method='cpu', per_gpu_collection=per_gpu_collection) + return collection + + + def compute_metric(self, collection, collection_flip): + if self.is_debug_run(): + print('Debug run, skipping metric computation') + ranks = [1, 5, 20] + return {k: 0.0 for k in ['rank-{}'.format(r) for r in ranks]} + + embeddings = (collection['features'] + collection_flip['features']).numpy() + result = evaluate( + all_features=embeddings, + image_paths=collection['image_paths'], + meta=self.meta, + ) + return result + diff --git a/cvlface/research/recognition/code/run_v1/evaluations/verification_evaluator.py b/cvlface/research/recognition/code/run_v1/evaluations/verification_evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..bd3485c2930f5a57611f32aa84ade08faef2efce --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/evaluations/verification_evaluator.py @@ -0,0 +1,138 @@ +import torch +from functools import partial +from .base_evaluator import BaseEvaluator +from .verifications.verification import evaluate +import sklearn +from tqdm import tqdm +import numpy as np +import os +import pandas as pd +from PIL import Image +from torch.utils.data import Dataset as TorchDataset + +def preprocess_transform(examples, image_transforms): + images = [image.convert("RGB") for image in examples['image']] + images = [image_transforms(image) for image in images] + examples["pixel_values"] = images + return examples + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + indexes = torch.tensor([example["index"] for example in examples], dtype=torch.int) + is_sames = 
torch.tensor([example["is_same"] for example in examples], dtype=torch.bool) + + return { + "pixel_values": pixel_values, + "index": indexes, + "is_same": is_sames, + } + + +class VerificationEvaluator(BaseEvaluator): + def __init__(self, name, data_path, transform, fabric, batch_size, num_workers): + super().__init__(name, fabric, batch_size) + self.name = name + self.data_path = data_path + pairs_csv = os.path.join(data_path, 'pairs.csv') + if os.path.isfile(pairs_csv): + dataset = LocalVerificationPairDataset(pairs_csv, transform) + else: + from datasets import Dataset + dataset = Dataset.load_from_disk(data_path) + preprocess = partial(preprocess_transform, image_transforms=transform) + dataset = dataset.with_transform(preprocess) + self.dataloader = fabric.setup_dataloader_from_dataset(dataset, + is_train=False, + batch_size=batch_size, + num_workers=num_workers, + collate_fn=collate_fn) + + + def integrity_check(self, eval_color_space, pipeline_color_space): + assert eval_color_space == pipeline_color_space + + + @torch.no_grad() + def evaluate(self, pipeline, epoch=0, step=0, n_images_seen=0): + pipeline.eval() + collection = self.extract(pipeline) + collection_flip = self.extract(pipeline, flip_images=True) + if self.fabric.local_rank == 0: + result = self.compute_metric(collection, collection_flip) + self.log(result, epoch, step, n_images_seen) + else: + result = {} + return result + + + def extract(self, pipeline, flip_images=False): + all_features = [] + all_is_sames = [] + all_index = [] + for batch_idx, batch in tqdm(enumerate(self.dataloader), total=len(self.dataloader), + desc=f'Verification {self.name}', + disable=self.fabric.local_rank != 0): + batch = self.complete_batch(batch) # needed for last batch to be gather compatible + + if self.is_debug_run(): + if batch_idx > 10: + break + + images = batch['pixel_values'] + is_sames = batch['is_same'] + index = batch['index'] + + if flip_images: + images = torch.flip(images, dims=[3]) + features = 
pipeline(images) + all_features.append(features.cpu().detach()) + all_is_sames.append(is_sames.cpu().detach()) + all_index.append(index.cpu().detach()) + + # aggregate across all gpus + per_gpu_collection = {"index": torch.cat(all_index, dim=0), + 'is_same': torch.cat(all_is_sames, dim=0), + 'features': torch.cat(all_features, dim=0)} + + # cpu based gathering just in case we have a lot of data + collection = self.gather_collection(method='cpu', per_gpu_collection=per_gpu_collection) + return collection + + + def compute_metric(self, collection, collection_flip): + if self.is_debug_run(): + print('Debug run, skipping metric computation') + return {'acc': 0, 'std': 0} + + embeddings = (collection['features'] + collection_flip['features']).numpy() + embeddings = sklearn.preprocessing.normalize(embeddings) + issame_list = collection['is_same'].numpy()[::2] + _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=10) + accuracy = accuracy * 100 + acc, std = np.mean(accuracy), np.std(accuracy) + result = {'acc': acc, 'std': std} + return result + + +class LocalVerificationPairDataset(TorchDataset): + def __init__(self, pairs_csv, transform): + self.rows = pd.read_csv(pairs_csv) + self.transform = transform + + def __len__(self): + return len(self.rows) + + def __getitem__(self, index): + row = self.rows.iloc[index] + image = Image.open(row['path']).convert('RGB') + is_same = row['is_same'] + if isinstance(is_same, str): + is_same = is_same.lower() == 'true' + return { + 'pixel_values': self.transform(image), + 'index': int(row['index']), + 'is_same': bool(is_same), + } diff --git a/cvlface/research/recognition/code/run_v1/evaluations/verifications/verification.py b/cvlface/research/recognition/code/run_v1/evaluations/verifications/verification.py new file mode 100644 index 0000000000000000000000000000000000000000..ad1d640a0c238741df1e77ca0f2dfa075e9dbc72 --- /dev/null +++ 
# Lazily populated handles to the optional ``mxnet`` dependency (see get_mxnet).
mx = None
nd = None


def get_mxnet():
    """Import ``mxnet`` on first use and cache it in the module globals.

    Returns:
        tuple: ``(mxnet, mxnet.ndarray)``.
    """
    global mx, nd
    if mx is None:
        import mxnet as mxnet
        from mxnet import ndarray as ndarray
        mx = mxnet
        nd = ndarray
    return mx, nd


class LFold:
    """K-fold splitter that falls back to a single train == test split.

    With ``n_splits > 1`` the work is delegated to ``KFold``; otherwise
    ``split`` yields the full index set as both train and test fold.
    """

    def __init__(self, n_splits=2, shuffle=False):
        self.n_splits = n_splits
        if self.n_splits > 1:
            self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)

    def split(self, indices):
        """Return an iterable of ``(train_indices, test_indices)`` pairs."""
        if self.n_splits > 1:
            return self.k_fold.split(indices)
        return [(indices, indices)]


def calculate_roc(thresholds,
                  embeddings1,
                  embeddings2,
                  actual_issame,
                  nrof_folds=10,
                  pca=0):
    """Cross-validated ROC points and accuracy for pairwise verification.

    For every fold the threshold with the best training accuracy is chosen
    and then scored on the held-out pairs.

    Args:
        thresholds: 1-D sequence of squared-distance thresholds to sweep.
        embeddings1: 2-D array, first embedding of every pair.
        embeddings2: 2-D array, second embedding of every pair (same shape).
        actual_issame: Boolean array, True where a pair is a genuine match.
        nrof_folds: Number of folds (1 means train and test on everything).
        pca: If > 0, fit a PCA with this many components on each training
            fold and compare the re-normalised projections instead.

    Returns:
        tuple: ``(tpr, fpr, accuracy)`` where ``tpr``/``fpr`` are averaged
        over folds (one value per threshold) and ``accuracy`` holds one
        value per fold.
    """
    assert (embeddings1.shape[0] == embeddings2.shape[0])
    assert (embeddings1.shape[1] == embeddings2.shape[1])
    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
    nrof_thresholds = len(thresholds)
    k_fold = LFold(n_splits=nrof_folds, shuffle=False)

    tprs = np.zeros((nrof_folds, nrof_thresholds))
    fprs = np.zeros((nrof_folds, nrof_thresholds))
    accuracy = np.zeros((nrof_folds))
    indices = np.arange(nrof_pairs)

    if pca == 0:
        # Without PCA the distances are fold-independent; compute them once.
        diff = np.subtract(embeddings1, embeddings2)
        dist = np.sum(np.square(diff), 1)

    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
        if pca > 0:
            print('doing pca on', fold_idx)
            embed1_train = embeddings1[train_set]
            embed2_train = embeddings2[train_set]
            _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
            pca_model = PCA(n_components=pca)
            pca_model.fit(_embed_train)
            embed1 = pca_model.transform(embeddings1)
            embed2 = pca_model.transform(embeddings2)
            embed1 = sklearn.preprocessing.normalize(embed1)
            embed2 = sklearn.preprocessing.normalize(embed2)
            diff = np.subtract(embed1, embed2)
            dist = np.sum(np.square(diff), 1)

        # Find the best threshold for the fold
        acc_train = np.zeros((nrof_thresholds))
        for threshold_idx, threshold in enumerate(thresholds):
            _, _, acc_train[threshold_idx] = calculate_accuracy(
                threshold, dist[train_set], actual_issame[train_set])
        best_threshold_index = np.argmax(acc_train)
        for threshold_idx, threshold in enumerate(thresholds):
            tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
                threshold, dist[test_set],
                actual_issame[test_set])
        _, _, accuracy[fold_idx] = calculate_accuracy(
            thresholds[best_threshold_index], dist[test_set],
            actual_issame[test_set])

    tpr = np.mean(tprs, 0)
    fpr = np.mean(fprs, 0)

    return tpr, fpr, accuracy


def plot_roc_custom(fprs, tprs, save_path, figsize=(6, 3), use_log_scale=True):
    """Draw a ROC curve and write it to ``save_path``.

    Args:
        fprs: False-match rates (x axis).
        tprs: True-match rates (y axis).
        save_path: Output file handed to ``plt.savefig``.
        figsize: Figure size in inches.
        use_log_scale: Use a logarithmic x axis restricted to [1e-5, 1] and
            zoom the y axis to [0.9, 1].
    """
    fig = plt.figure(figsize=figsize)
    try:
        ax = plt.subplot(1, 1, 1)
        if use_log_scale:
            ax.set_title("Receiver Operating Characteristic on LFW Dataset (Log Scale)")
            ax.set_xscale('log')
            ax.set_xlim(1e-5, 1.0)
            ax.set_ylim(0.9, 1.0)
        else:
            ax.set_title("Receiver Operating Characteristic on LFW Dataset")
            ax.set_xlim(0, 1)

        ax.set_xlabel('False Match Rate (FMR)')
        ax.set_ylabel('True Match Rate (TMR)')

        plt.plot(fprs, tprs, color='blue', lw=1, label='ROC curve ViT (WebFace4M)')
        plt.plot([0, 1], [0, 1], color='grey', lw=1, linestyle='--')

        plt.legend()
        plt.tight_layout()
        plt.savefig(save_path)
    finally:
        # Clearing the axes/figure leaves the figure registered with pyplot;
        # close it so repeated calls do not accumulate open figures.
        plt.close(fig)


def calculate_accuracy(threshold, dist, actual_issame):
    """Score a single threshold.

    A pair is predicted "same" when its distance is strictly below
    ``threshold``.

    Returns:
        tuple: ``(tpr, fpr, acc)``; a rate whose denominator is zero is
        reported as 0.
    """
    predict_issame = np.less(dist, threshold)
    tp = np.sum(np.logical_and(predict_issame, actual_issame))
    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
    tn = np.sum(
        np.logical_and(np.logical_not(predict_issame),
                       np.logical_not(actual_issame)))
    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))

    tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
    fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
    acc = 0.0 if dist.size == 0 else float(tp + tn) / dist.size
    return tpr, fpr, acc


def calculate_val(thresholds,
                  embeddings1,
                  embeddings2,
                  actual_issame,
                  far_target,
                  nrof_folds=10):
    """Cross-validated validation rate at a target false-accept rate.

    For every fold the threshold whose training FAR equals ``far_target`` is
    found by interpolation (0.0 if the target is never reached) and then
    evaluated on the held-out pairs.

    Returns:
        tuple: ``(val_mean, val_std, far_mean)`` over the folds.
    """
    assert (embeddings1.shape[0] == embeddings2.shape[0])
    assert (embeddings1.shape[1] == embeddings2.shape[1])
    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
    nrof_thresholds = len(thresholds)
    k_fold = LFold(n_splits=nrof_folds, shuffle=False)

    val = np.zeros(nrof_folds)
    far = np.zeros(nrof_folds)

    diff = np.subtract(embeddings1, embeddings2)
    dist = np.sum(np.square(diff), 1)
    indices = np.arange(nrof_pairs)

    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):

        # Find the threshold that gives FAR = far_target
        far_train = np.zeros(nrof_thresholds)
        for threshold_idx, threshold in enumerate(thresholds):
            _, far_train[threshold_idx] = calculate_val_far(
                threshold, dist[train_set], actual_issame[train_set])
        if np.max(far_train) >= far_target:
            f = interpolate.interp1d(far_train, thresholds, kind='slinear')
            threshold = f(far_target)
        else:
            threshold = 0.0

        val[fold_idx], far[fold_idx] = calculate_val_far(
            threshold, dist[test_set], actual_issame[test_set])

    val_mean = np.mean(val)
    far_mean = np.mean(far)
    val_std = np.std(val)
    return val_mean, val_std, far_mean


def calculate_val_far(threshold, dist, actual_issame):
    """Validation rate and false-accept rate at ``threshold``.

    Returns:
        tuple: ``(val, far)`` as floats. When the input contains no genuine
        pairs (or no impostor pairs) the corresponding rate is reported as
        0.0 instead of dividing by zero, matching ``calculate_accuracy``.
    """
    predict_issame = np.less(dist, threshold)
    true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
    false_accept = np.sum(
        np.logical_and(predict_issame, np.logical_not(actual_issame)))
    n_same = np.sum(actual_issame)
    n_diff = np.sum(np.logical_not(actual_issame))
    val = float(true_accept) / float(n_same) if n_same > 0 else 0.0
    far = float(false_accept) / float(n_diff) if n_diff > 0 else 0.0
    return val, far


def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
    """Run the ROC and VAL@FAR=1e-3 protocols on interleaved pair embeddings.

    Args:
        embeddings: 2-D array in which rows ``2k`` and ``2k + 1`` form pair k.
        actual_issame: Per-pair ground truth.
        nrof_folds: Number of cross-validation folds.
        pca: Forwarded to ``calculate_roc``.

    Returns:
        tuple: ``(tpr, fpr, accuracy, val, val_std, far)``. If
        ``calculate_val`` raises ``ValueError`` the last three are 0.0.
    """
    # Calculate evaluation metrics
    thresholds = np.arange(0, 4, 0.01)
    embeddings1 = embeddings[0::2]
    embeddings2 = embeddings[1::2]
    tpr, fpr, accuracy = calculate_roc(thresholds,
                                       embeddings1,
                                       embeddings2,
                                       np.asarray(actual_issame),
                                       nrof_folds=nrof_folds,
                                       pca=pca)
    # A finer sweep is used for the FAR-targeted threshold search.
    thresholds = np.arange(0, 4, 0.001)
    try:
        val, val_std, far = calculate_val(thresholds,
                                          embeddings1,
                                          embeddings2,
                                          np.asarray(actual_issame),
                                          1e-3,
                                          nrof_folds=nrof_folds)
    except ValueError as exc:
        print(f'calculate_val failed: {exc}')
        val, val_std, far = 0.0, 0.0, 0.0
    return tpr, fpr, accuracy, val, val_std, far


@torch.no_grad()
def load_bin(path, image_size):
    """Load a pickled verification set into plain and flipped image tensors.

    Args:
        path: Pickle file holding ``(bins, issame_list)``; ``bins`` are
            encoded images decoded with ``mxnet``.
        image_size: ``(height, width)`` of the output tensors.

    Returns:
        tuple: ``(data_list, issame_list)`` where ``data_list[0]`` holds the
        decoded images and ``data_list[1]`` the same images flipped along
        the last axis, each shaped ``(2 * len(issame_list), 3, H, W)``.
    """
    mx, nd = get_mxnet()
    # NOTE(security): pickle executes arbitrary code on load; only use this
    # on trusted benchmark files.
    try:
        with open(path, 'rb') as f:
            bins, issame_list = pickle.load(f)  # py2
    except UnicodeDecodeError:
        with open(path, 'rb') as f:
            bins, issame_list = pickle.load(f, encoding='bytes')  # py3
    data_list = []
    for flip in [0, 1]:
        data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
        data_list.append(data)
    for idx in range(len(issame_list) * 2):
        _bin = bins[idx]
        img = mx.image.imdecode(_bin)
        if img.shape[1] != image_size[0]:
            img = mx.image.resize_short(img, image_size[0])
        img = nd.transpose(img, axes=(2, 0, 1))
        for flip in [0, 1]:
            # flip == 0 stores the image as is; flip == 1 mirrors it first.
            if flip == 1:
                img = mx.ndarray.flip(data=img, axis=2)
            data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
        if idx % 1000 == 0:
            print('loading bin', idx)
    print(data_list[0].shape)
    return data_list, issame_list
@torch.no_grad()
def test(data_set, backbone, batch_size, nfolds=10):
    """Embed a verification set with ``backbone`` and score it.

    Args:
        data_set: ``(data_list, issame_list)`` as produced by ``load_bin``;
            ``data_list`` holds the original and the flipped image tensors.
        backbone: Callable mapping an image batch to a 2-D embedding tensor.
        batch_size: Number of images fed to ``backbone`` per call.
        nfolds: Number of cross-validation folds passed to ``evaluate``.

    Returns:
        tuple: ``(acc1, std1, acc2, std2, xnorm, embeddings_list)``.
        ``acc1``/``std1`` are always 0.0 (kept for interface compatibility),
        ``acc2``/``std2`` are the mean/std fold accuracy of the summed
        original + flipped embeddings, and ``xnorm`` is the mean L2 norm of
        all raw embeddings.
    """
    print('testing verification..')
    data_list = data_set[0]
    issame_list = data_set[1]
    embeddings_list = []
    time_consumed = 0.0
    for data in data_list:
        n_samples = data.shape[0]
        embeddings = None
        ba = 0
        while ba < n_samples:
            bb = min(ba + batch_size, n_samples)
            count = bb - ba
            # The last batch is topped up with preceding rows so the backbone
            # sees a full batch. Clamp at 0: with fewer than ``batch_size``
            # rows in total a negative start would select the wrong rows.
            start = max(bb - batch_size, 0)
            _data = data[start:bb]
            time0 = datetime.datetime.now()
            img = ((_data / 255) - 0.5) / 0.5
            net_out: torch.Tensor = backbone(img)
            _embeddings = net_out.detach().cpu().numpy()
            time_now = datetime.datetime.now()
            diff = time_now - time0
            time_consumed += diff.total_seconds()
            if embeddings is None:
                embeddings = np.zeros((n_samples, _embeddings.shape[1]))
            # Only the trailing ``count`` rows are new; the rest were padding.
            embeddings[ba:bb, :] = _embeddings[-count:, :]
            ba = bb
        embeddings_list.append(embeddings)

    # Mean L2 norm over every raw embedding (original and flipped).
    _xnorm = 0.0
    _xnorm_cnt = 0
    for embed in embeddings_list:
        _xnorm += np.linalg.norm(embed, axis=1).sum()
        _xnorm_cnt += embed.shape[0]
    _xnorm /= _xnorm_cnt

    acc1 = 0.0
    std1 = 0.0
    # Fuse original and flipped embeddings by summation, then re-normalise.
    embeddings = embeddings_list[0] + embeddings_list[1]
    embeddings = sklearn.preprocessing.normalize(embeddings)
    print(embeddings.shape)
    print('infer time', time_consumed)
    _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
    acc2, std2 = np.mean(accuracy), np.std(accuracy)
    return acc1, std1, acc2, std2, _xnorm, embeddings_list
def setup_dataloader_from_dataset(dataset,
                                  is_train,
                                  batch_size,
                                  num_workers,
                                  seed,
                                  fabric,
                                  collate_fn=None):
    """Build a rank-sharded ``DataLoader`` and register it with ``fabric``.

    Args:
        dataset: Dataset to iterate.
        is_train: When true the sampler shuffles, incomplete batches are
            dropped and pinned memory is requested; otherwise the data is
            read in order and nothing is dropped.
        batch_size: Per-process batch size.
        num_workers: Number of loader worker processes.
        seed: Base seed for the sampler and the workers. ``None`` disables
            worker seeding and falls back to sampler seed 0.
        fabric: Object exposing ``world_size``, ``local_rank`` and
            ``setup_dataloaders(dataloader, use_distributed_sampler=...)``.
        collate_fn: Optional batch collation function.

    Returns:
        Whatever ``fabric.setup_dataloaders`` returns for the built loader.
    """
    train = bool(is_train)

    if seed is None:
        init_fn = None
        # DistributedSampler adds the epoch to its seed when shuffling, so a
        # ``None`` seed would raise TypeError on the first iteration.
        sampler_seed = 0
    else:
        init_fn = partial(worker_init_fn, num_workers=num_workers, rank=fabric.local_rank, seed=seed)
        sampler_seed = seed

    sampler = DistributedSampler(dataset=dataset, num_replicas=fabric.world_size,
                                 rank=fabric.local_rank, shuffle=train, drop_last=train,
                                 seed=sampler_seed)
    # Ordering is fully delegated to the sampler, hence shuffle=False here.
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, sampler=sampler,
                            collate_fn=collate_fn, worker_init_fn=init_fn,
                            drop_last=train, shuffle=False, pin_memory=train, )
    dataloader = fabric.setup_dataloaders(dataloader, use_distributed_sampler=False)
    return dataloader
class Fabric():
    """Minimal distributed-training helper.

    Bundles rank information, collective helpers from ``dist_utils``,
    optional mixed-precision gradient scaling, gradient clipping, metric
    logging and a rank-aware ``print`` replacement.
    """

    def __init__(self, local_rank, world_size, precision, grad_max_norm=5, loggers=(), seed=None, cfg=None):
        """
        Args:
            local_rank: Rank of this process; also used as the CUDA index.
            world_size: Number of participating processes.
            precision: ``'16-mixed'`` (uses a ``GradScaler``) or ``'32-true'``.
            grad_max_norm: Max gradient norm applied before optimizer steps.
            loggers: Objects exposing ``log_metrics(dict)``.
            seed: Base seed used when building dataloaders.
            cfg: Optional config; when given, ``print`` is redirected
                (see ``setup_print``).

        Raises:
            ValueError: If ``precision`` is not one of the supported values.
        """
        self.local_rank = local_rank
        self.world_size = world_size
        self.seed = seed

        self.all_gather = dist_utils.all_gather
        self.barrier = dist_utils.barrier
        self.broadcast = dist_utils.broadcast
        self.print = self.setup_print(local_rank, cfg)

        self.precision = precision
        if precision == '16-mixed':
            self.amp = torch.cuda.amp.grad_scaler.GradScaler(growth_interval=100)
        elif precision == '32-true':
            self.amp = None
        else:
            raise ValueError(f'Unknown precision {precision}')
        self.grad_max_norm = grad_max_norm
        self.device = torch.device('cuda', local_rank)
        self.loggers = loggers
        self.cfg = cfg

    def log_dict(self, x):
        """Forward the metrics dict ``x`` to every logger (rank 0 only)."""
        if self.local_rank == 0:
            for logger in self.loggers:
                logger.log_metrics(x)

    def backward(self, loss, optimizer, accumulate=False):
        """Back-propagate ``loss`` and, unless accumulating, take a step.

        Args:
            loss: Scalar loss tensor.
            optimizer: Optimizer whose parameters receive the gradients.
            accumulate: When true only the backward pass runs, so gradients
                add up across calls; clipping, the optimizer step and
                ``zero_grad`` happen on the next call with
                ``accumulate=False``. Any loss rescaling for accumulation is
                left to the caller.
        """
        if self.amp is None:
            loss.backward()
        else:
            self.amp.scale(loss).backward()

        if accumulate:
            # Keep the (possibly scaled) gradients for a later step.
            return

        param_gen = (param for param_group in optimizer.param_groups for param in param_group['params'])
        if self.amp is None:
            torch.nn.utils.clip_grad_norm_(param_gen, max_norm=self.grad_max_norm)
            optimizer.step()
        else:
            # Gradients must be unscaled before their norm is clipped.
            self.amp.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(param_gen, max_norm=self.grad_max_norm)
            self.amp.step(optimizer)
            self.amp.update()
        optimizer.zero_grad()

    def setup(self, model):
        """Move ``model`` to this process' device and wrap it for DDP.

        Models without trainable parameters are returned unwrapped.
        """
        model = model.to(self.device)
        if model.has_trainable_params():
            model_ddp = DDPWithAttribute(
                module=model, broadcast_buffers=False, device_ids=[self.local_rank], bucket_cap_mb=16,
                find_unused_parameters=False)
            return model_ddp
        return model

    def setup_dataloader_from_dataset(self, dataset, is_train, batch_size, num_workers, collate_fn=None):
        """Build a rank-sharded ``DataLoader`` for ``dataset``.

        Training loaders shuffle, drop incomplete batches and pin memory;
        evaluation loaders keep order and every sample. With ``self.seed``
        unset, workers are not seeded and the sampler uses seed 0.
        """
        local_rank = self.local_rank
        world_size = self.world_size
        seed = self.seed
        train = bool(is_train)

        if seed is None:
            init_fn = None
            # DistributedSampler adds the epoch to its seed when shuffling; a
            # ``None`` seed would raise TypeError on the first iteration.
            sampler_seed = 0
        else:
            init_fn = partial(worker_init_fn, num_workers=num_workers, rank=local_rank, seed=seed)
            sampler_seed = seed

        sampler = DistributedSampler(dataset=dataset, num_replicas=world_size,
                                     rank=local_rank, shuffle=train, drop_last=train,
                                     seed=sampler_seed)
        # Ordering is fully delegated to the sampler, hence shuffle=False.
        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, sampler=sampler,
                                collate_fn=collate_fn, worker_init_fn=init_fn,
                                drop_last=train, shuffle=False, pin_memory=train, )

        return dataloader

    def launch(self):
        """No-op placeholder."""
        pass

    def setup_print(self, rank, cfg):
        """Install and return a rank-aware replacement for ``print``.

        With a config, ``builtins.print`` is replaced by a function that
        prints only on rank 0 (or when called with ``force=True``) and,
        unless ``cfg.trainers.debug`` is set, also appends the message to
        ``<cfg.trainers.output_dir>/run_log.txt``. Without a config the
        current ``print`` is returned untouched.
        """
        is_master = rank == 0
        if cfg is not None:
            save_path = os.path.join(cfg.trainers.output_dir, 'run_log.txt')
            is_debug = cfg.trainers.debug
            os.makedirs(os.path.dirname(save_path), exist_ok=True)

            original_print = __builtin__.print

            def custom_print(*args, **kwargs):
                force = kwargs.pop("force", False)
                if is_master or force:
                    builtin_print(*args, **kwargs)
                    if not is_debug:
                        with open(save_path, 'a') as f:
                            original_print(*args, file=f, **kwargs)

            __builtin__.print = custom_print
        else:
            custom_print = __builtin__.print
        return custom_print
def worker_init_fn(worker_id, num_workers, rank, seed):
    """Seed a dataloader worker and let it run on every CPU.

    The seed of each worker equals
    ``num_workers * rank + worker_id + seed`` and is applied to NumPy,
    ``random`` and torch.
    """
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)
    # needed for speeding up dataloader in torch 2.0.1
    # ``os.sched_setaffinity`` is not available on every platform.
    if hasattr(os, 'sched_setaffinity'):
        os.sched_setaffinity(0, range(os.cpu_count()))


def get_dist_info():
    """Return ``(rank, world_size)``; ``(0, 1)`` outside a process group."""
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1

    return rank, world_size


def sync_random_seed(seed=None, device="cuda"):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`,
    because the seed should be identical across all processes
    in the distributed group.
    In distributed sampling, different ranks should sample non-overlapped
    data in the dataset. Therefore, this function is used to make sure that
    each rank shuffles the data indices in the same order based
    on the same seed. Then different ranks could use different indices
    to select non-overlapped data from the same data list.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): The device where the seed will be put on.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    if world_size == 1:
        return seed

    # Rank 0 owns the seed; every other rank receives it via broadcast.
    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)

    dist.broadcast(random_num, src=0)

    return random_num.item()


class DistributedSampler(_DistributedSampler):
    """Distributed sampler whose shuffle seed is synchronised across ranks.

    Indices are repeated as often as needed to reach ``total_size``, so the
    dataset may be shorter than half of ``total_size``.
    """

    def __init__(
        self,
        dataset,
        num_replicas=None,  # world_size
        rank=None,  # local_rank
        shuffle=True,
        seed=0,
    ):

        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)

        # In distributed sampling, different ranks should sample
        # non-overlapped data in the dataset. Therefore, this function
        # is used to make sure that each rank shuffles the data indices
        # in the same order based on the same seed. Then different ranks
        # could use different indices to select non-overlapped data from the
        # same data list.
        self.seed = sync_random_seed(seed)

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            g = torch.Generator()
            # When :attr:`shuffle=True`, this ensures all replicas
            # use a different random ordering for each epoch.
            # Otherwise, the next iteration of this sampler will
            # yield the same ordering.
            g.manual_seed(self.epoch + self.seed)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # add extra samples to make it evenly divisible
        # in case that indices is shorter than half of total_size
        indices = (indices * math.ceil(self.total_size / len(indices)))[
            : self.total_size
        ]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank : self.total_size : self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)


class BackgroundGenerator(threading.Thread):
    """Iterate ``generator`` on a daemon thread, prefetching into a queue.

    ``None`` is used as the end-of-stream marker, so the wrapped generator
    must not yield ``None`` itself.
    """

    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        # Exception raised by the producer thread, re-raised in ``next``.
        self._error = None
        self.daemon = True
        self.start()

    def run(self):
        try:
            torch.cuda.set_device(self.local_rank)
            for item in self.generator:
                self.queue.put(item)
        except Exception as exc:
            self._error = exc
        finally:
            # Always signal the end; otherwise a failing producer would leave
            # the consumer blocked on ``queue.get()`` forever.
            self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            if self._error is not None:
                raise self._error
            raise StopIteration
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self


class DataLoaderX(DataLoader):
    """``DataLoader`` that prefetches batches and copies them to the GPU.

    Batches are produced by a ``BackgroundGenerator`` and every element of a
    batch is moved to ``local_rank`` on a dedicated CUDA stream while the
    previous batch is being consumed. Batches are indexed and assigned by
    position, so they are expected to be mutable sequences of tensors.
    """

    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        """Fetch the next batch and start its asynchronous device copy."""
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        with torch.cuda.stream(self.stream):
            for k in range(len(self.batch)):
                self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)

    def __next__(self):
        # Make sure the copy of the pending batch finished before handing it out.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch
class DDPWithAttribute(torch.nn.parallel.DistributedDataParallel):
    """DDP wrapper that forwards unknown attributes to the wrapped module."""

    def __getattr__(self, name):
        # First, try to get the attribute from the DDP model itself
        try:
            return super().__getattr__(name)
        except AttributeError:
            # ``self.module`` below would re-enter this method; if ``module``
            # itself is missing (e.g. before ``__init__`` finished) stop here
            # instead of recursing without bound.
            if name == 'module':
                raise

        # If not found, try to get it from the original model
        return getattr(self.module, name)


def get_margin_loss(loss_config):
    """Instantiate the margin loss selected by ``loss_config.margin_loss_name``.

    Supported names are ``'margin'`` (``CombinedMarginLoss`` built from
    ``margin_list``), ``'adaface'`` (``AdaFaceLoss``) and ``'none'``
    (returns ``None``). Both losses are created with scale 64.

    Raises:
        ValueError: For any other ``margin_loss_name``.
    """
    if loss_config.margin_loss_name == 'margin':
        margin_loss = CombinedMarginLoss(
            64,
            loss_config.margin_list[0],
            loss_config.margin_list[1],
            loss_config.margin_list[2],
            loss_config.interclass_filtering_threshold
        )
    elif loss_config.margin_loss_name == 'adaface':
        margin_loss = AdaFaceLoss(
            64,
            m=loss_config.m,
            h=loss_config.h,
            t_alpha=loss_config.t_alpha,
            interclass_filtering_threshold=loss_config.interclass_filtering_threshold
        )
    elif loss_config.margin_loss_name == 'none':
        margin_loss = None
    else:
        raise ValueError(
            f"Not implemented loss margin_loss_name: {loss_config.margin_loss_name!r}")
    return margin_loss


class AdaFaceLoss(torch.nn.Module):
    """AdaFace margin: angular and additive margins scaled by feature norm.

    Args:
        s: Scale applied to all logits after the margin.
        m: Base margin.
        h: Factor applied to the standardised feature norm before clipping
            it to [-1, 1].
        t_alpha: Momentum of the running norm mean / std update.
        interclass_filtering_threshold: When > 0, non-target logits above
            this value are zeroed before the margin is applied.
    """

    def __init__(self,
                 s,
                 m,
                 h,
                 t_alpha,
                 interclass_filtering_threshold=0):
        super().__init__()
        self.s = s
        self.m = m
        self.h = h
        self.t_alpha = t_alpha
        self.interclass_filtering_threshold = interclass_filtering_threshold
        self.eps = 1e-3

    def forward(self, logits, labels, norms, batch_mean, batch_std):
        """Apply the margin to the target logits and scale everything by ``s``.

        Rows whose label is -1 are left without a margin. ``logits`` is
        modified in place at the target positions.

        Returns:
            tuple: ``(scaled_logits, batch_mean, batch_std)`` with the
            running statistics updated from this batch's ``norms``.
        """
        index_positive = torch.where(labels != -1)[0]

        if self.interclass_filtering_threshold > 0:
            with torch.no_grad():
                dirty = logits > self.interclass_filtering_threshold
                dirty = dirty.float()
                # Never filter a row's own target class.
                mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
                mask.scatter_(1, labels[index_positive], 0)
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty
            logits = tensor_mul * logits

        safe_norms = torch.clip(norms, min=0.001, max=100)  # for stability
        safe_norms = safe_norms.clone().detach()

        # update batchmean batchstd
        with torch.no_grad():
            mean = safe_norms.mean().detach()
            std = safe_norms.std().detach()
            batch_mean = mean * self.t_alpha + (1 - self.t_alpha) * batch_mean
            batch_std = std * self.t_alpha + (1 - self.t_alpha) * batch_std

        margin_scaler = (safe_norms - batch_mean) / (batch_std + self.eps)  # 66% between -1, 1
        margin_scaler = margin_scaler * self.h  # 68% between -0.333 ,0.333 when h:0.333
        margin_scaler = torch.clip(margin_scaler, -1, 1).view(-1)
        margin_scaler = margin_scaler[index_positive]

        target_logit = logits[index_positive, labels[index_positive].view(-1)]

        with torch.no_grad():
            # g_angular
            target_logit.arccos_()
            margin_final_logit = target_logit + (self.m * margin_scaler * -1)
            margin_final_logit.cos_()
            # g_additive
            margin_final_logit = margin_final_logit - (self.m + (self.m * margin_scaler))
            # make margin_final_logit as same dtype as logits
            margin_final_logit = margin_final_logit.type(logits.dtype)
            logits[index_positive, labels[index_positive].view(-1)] = margin_final_logit

        # scale
        logits = logits * self.s

        return logits, batch_mean, batch_std
class CombinedMarginLoss(torch.nn.Module):
    """Margin on cosine logits covering the ArcFace and CosFace settings.

    Args:
        s: Scale applied to all logits after the margin.
        m1: Must be 1.0 for the ArcFace branch.
        m2: Additive angular margin (ArcFace).
        m3: Additive cosine margin (CosFace).
        interclass_filtering_threshold: When > 0, non-target logits above
            this value are zeroed before the margin is applied.
    """

    def __init__(self,
                 s,
                 m1,
                 m2,
                 m3,
                 interclass_filtering_threshold=0):
        super().__init__()
        self.s = s
        self.m1 = m1
        self.m2 = m2
        self.m3 = m3
        self.interclass_filtering_threshold = interclass_filtering_threshold

        # For ArcFace
        self.cos_m = math.cos(self.m2)
        self.sin_m = math.sin(self.m2)
        self.theta = math.cos(math.pi - self.m2)
        self.sinmm = math.sin(math.pi - self.m2) * self.m2
        self.easy_margin = False

    def forward(self, logits, labels):
        """Apply the configured margin to the target logits and scale by ``s``.

        Rows whose label is -1 receive no margin. ``logits`` is modified in
        place.

        Raises:
            NotImplementedError: If the margins match neither the ArcFace
                (``m1 == 1`` and ``m3 == 0``) nor the CosFace (``m3 > 0``)
                setting.
        """
        index_positive = torch.where(labels != -1)[0]

        if self.interclass_filtering_threshold > 0:
            with torch.no_grad():
                dirty = logits > self.interclass_filtering_threshold
                dirty = dirty.float()
                # Never filter a row's own target class.
                mask = torch.ones([index_positive.size(0), logits.size(1)], device=logits.device)
                mask.scatter_(1, labels[index_positive], 0)
                dirty[index_positive] *= mask
                tensor_mul = 1 - dirty
            logits = tensor_mul * logits

        target_logit = logits[index_positive, labels[index_positive].view(-1)]

        if self.m1 == 1.0 and self.m3 == 0.0:
            with torch.no_grad():
                target_logit.arccos_()
                logits.arccos_()
                final_target_logit = target_logit + self.m2
                logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
                logits.cos_()
            logits = logits * self.s

        elif self.m3 > 0:
            final_target_logit = target_logit - self.m3
            logits[index_positive, labels[index_positive].view(-1)] = final_target_logit
            logits = logits * self.s
        else:
            # NotImplementedError derives from RuntimeError, which is what the
            # former bare ``raise`` produced, so existing handlers still match.
            raise NotImplementedError(
                f"Unsupported margin combination: m1={self.m1}, m2={self.m2}, m3={self.m3}")

        return logits


class ArcFace(torch.nn.Module):
    """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
    """

    def __init__(self, s=64.0, margin=0.5):
        super(ArcFace, self).__init__()
        self.scale = s
        self.margin = margin
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.theta = math.cos(math.pi - margin)
        self.sinmm = math.sin(math.pi - margin) * margin
        self.easy_margin = False

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Add ``margin`` to the target angles and scale the cosines.

        Rows whose label is -1 receive no margin. ``logits`` is modified in
        place.
        """
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]

        with torch.no_grad():
            target_logit.arccos_()
            logits.arccos_()
            final_target_logit = target_logit + self.margin
            logits[index, labels[index].view(-1)] = final_target_logit
            logits.cos_()
        # The scale is stored as ``self.scale``; ``self.s`` does not exist.
        logits = logits * self.scale
        return logits


class CosFace(torch.nn.Module):
    """CosFace: subtract ``m`` from the target cosine, then scale by ``s``."""

    def __init__(self, s=64.0, m=0.40):
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, logits: torch.Tensor, labels: torch.Tensor):
        """Rows whose label is -1 receive no margin; ``logits`` is modified in place."""
        index = torch.where(labels != -1)[0]
        target_logit = logits[index, labels[index].view(-1)]
        final_target_logit = target_logit - self.m
        logits[index, labels[index].view(-1)] = final_target_logit
        logits = logits * self.s
        return logits


def get_model(model_config, task=''):
    """Build the backbone selected by the directory in ``model_config.yaml_path``.

    The architecture is chosen by which ``/<name>/`` path component occurs
    in ``yaml_path``. Afterwards weights are loaded from
    ``model_config.start_from`` when set, and all parameters are frozen when
    ``model_config.freeze`` is true. ``task`` is currently unused.

    Raises:
        NotImplementedError: If no known architecture matches the path.
    """
    if '/vit/' in model_config.yaml_path:
        from .vit import load_model as load_vit_model
        model = load_vit_model(model_config)
        print('Loaded ViT model')
    elif '/vit_irpe/' in model_config.yaml_path:
        from .vit_irpe import load_model as load_vit_irpe_model
        model = load_vit_irpe_model(model_config)
        print('Loaded ViT model with iRPE')
    elif '/vit_kprpe/' in model_config.yaml_path:
        from .vit_kprpe import load_model as load_vit_kprpe_model
        model = load_vit_kprpe_model(model_config)
        print('Loaded ViT model with KPRPE')
    elif '/iresnet/' in model_config.yaml_path:
        from .iresnet import load_model as load_iresnet_model
        model = load_iresnet_model(model_config)
        print('Loaded iResNet model')
    elif '/iresnet_insightface/' in model_config.yaml_path:
        from .iresnet_insightface import load_model as load_iresnet_insightface_model
        model = load_iresnet_insightface_model(model_config)
        print('Loaded iResNet (insightface) model')
    elif '/part_fvit/' in model_config.yaml_path:
        from .part_fvit import load_model as load_part_fvit_model
        model = load_part_fvit_model(model_config)
        print('Loaded PartFVIT model')
    elif '/swin/' in model_config.yaml_path:
        from .swin import load_model as load_swin_model
        model = load_swin_model(model_config)
        print('Loaded Swin model')
    elif '/swin_kprpe/' in model_config.yaml_path:
        from .swin_kprpe import load_model as load_swin_kprpe_model
        model = load_swin_kprpe_model(model_config)
        print('Loaded Swin model with KPRPE')
    else:
        raise NotImplementedError(f"Model {model_config.yaml_path} not implemented")
    if model_config.start_from:
        model.load_state_dict_from_path(model_config.start_from)

    if model_config.freeze:
        for param in model.parameters():
            param.requires_grad = False

    return model
+ + Parameters: + config (object, optional): Configuration object containing model settings. + """ + super(BaseModel, self).__init__() + self.config = config + if self.config.color_space == 'BGR': + self.input_color_flip = True + self._config_color_space = 'BGR' + self.config.color_space = 'RGB' + else: + self.input_color_flip = False + + def forward(self, x): + """ + Forward pass of the model. Needs to be implemented in subclass. + + Parameters: + x (torch.Tensor): Input tensor. + + Raises: + NotImplementedError: If the subclass does not implement this method. + """ + raise NotImplementedError('forward must be implemented in subclass') + + @classmethod + def from_config(cls, config) -> "BaseModel": + """ + Creates an instance of this class from a configuration object. Needs to be implemented in subclass. + + Parameters: + config (object): Configuration object. + + Returns: + BaseModel: An instance of the subclass. + + Raises: + NotImplementedError: If the subclass does not implement this method. + """ + raise NotImplementedError('from_config must be implemented in subclass') + + def make_train_transform(self): + """ + Creates training data transformations. Needs to be implemented in subclass. + + Raises: + NotImplementedError: If the subclass does not implement this method. + """ + raise NotImplementedError('make_train_transform must be implemented in subclass') + + def make_test_transform(self): + """ + Creates testing data transformations. Needs to be implemented in subclass. + + Raises: + NotImplementedError: If the subclass does not implement this method. + """ + raise NotImplementedError('make_test_transform must be implemented in subclass') + + def save_pretrained( + self, + save_dir: Union[str, os.PathLike], + name: str = 'model.pt', + rank: int = 0, + ): + """ + Saves the model's state_dict and configuration to the specified directory. + + Parameters: + save_dir (Union[str, os.PathLike]): The directory to save the model. 
+ name (str, optional): The name of the file to save the model as. Default is 'model.pt'. + rank (int, optional): The rank of the process (used in distributed training). Default is 0. + """ + save_path = os.path.join(save_dir, name) + if rank == 0: + save_state_dict_and_config(self.state_dict(), self.config, save_path) + + def load_state_dict_from_path(self, pretrained_model_path): + state_dict = load_state_dict_from_path(pretrained_model_path) + if 'net.vit' in list(self.state_dict().keys())[-1] and 'pretrained_models' in pretrained_model_path: + state_dict = {k.replace('net', 'net.vit'): v for k, v in state_dict.items()} + + current_state_dict = self.state_dict() + if not any(key in current_state_dict for key in state_dict) and any(key.startswith('net.') for key in current_state_dict): + state_dict = {f'net.{key}': value for key, value in state_dict.items()} + filtered_state_dict = {} + skipped_shape = [] + resized_relative_position_bias = [] + for key, value in state_dict.items(): + if key not in current_state_dict: + continue + if current_state_dict[key].shape != value.shape: + if key.endswith('relative_position_bias_table'): + resized_value = self.resize_relative_position_bias_table( + value, + tuple(current_state_dict[key].shape), + ) + if resized_value is not None: + filtered_state_dict[key] = resized_value + resized_relative_position_bias.append( + (key, tuple(value.shape), tuple(resized_value.shape)) + ) + continue + skipped_shape.append((key, tuple(value.shape), tuple(current_state_dict[key].shape))) + continue + filtered_state_dict[key] = value + + st_keys = list(state_dict.keys()) + self_keys = list(current_state_dict.keys()) + print('compatible keys in state_dict', len(filtered_state_dict), '/', len(st_keys)) + if resized_relative_position_bias: + print('resized relative_position_bias_table keys', + resized_relative_position_bias[:10], 'total', len(resized_relative_position_bias)) + if skipped_shape: + print('skipped shape-mismatched keys', 
skipped_shape[:10], 'total', len(skipped_shape)) + print('Check\n\n') + result = self.load_state_dict(filtered_state_dict, strict=False) + print(result) + print(f"Loaded pretrained model from {pretrained_model_path}") + + @staticmethod + def resize_relative_position_bias_table(value, target_shape): + if value.ndim != 2 or len(target_shape) != 2: + return None + source_length, source_heads = value.shape + target_length, target_heads = target_shape + if source_heads != target_heads: + return None + + source_size = int(math.sqrt(source_length)) + target_size = int(math.sqrt(target_length)) + if source_size * source_size != source_length: + return None + if target_size * target_size != target_length: + return None + + value_float = value.float().permute(1, 0).reshape(1, source_heads, source_size, source_size) + resized = F.interpolate( + value_float, + size=(target_size, target_size), + mode='bicubic', + align_corners=False, + ) + resized = resized.reshape(source_heads, target_length).permute(1, 0).contiguous() + return resized.to(dtype=value.dtype, device=value.device) + + + @property + def device(self) -> device: + """ + Returns the device of the model's parameters. + + Returns: + device: The device the model is on. + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + Returns the data type of the model's parameters. + + Returns: + torch.dtype: The data type of the model. + """ + return get_parameter_dtype(self) + + def num_parameters(self, only_trainable: bool = False) -> int: + """ + Returns the number of parameters in the model, optionally filtering only trainable parameters. + + Parameters: + only_trainable (bool, optional): Whether to count only trainable parameters. Default is False. + + Returns: + int: The number of parameters. + """ + return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) + + def has_trainable_params(self): + """ + Checks if the model has any trainable parameters. 
+ + Returns: + bool: True if the model has trainable parameters, False otherwise. + """ + return any(p.requires_grad for p in self.parameters()) diff --git a/cvlface/research/recognition/code/run_v1/models/base/configs/example.yaml b/cvlface/research/recognition/code/run_v1/models/base/configs/example.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3447d03e5ef6913edac2e178328a6d59fe2aee87 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/base/configs/example.yaml @@ -0,0 +1,6 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: '' +output_dim: 512 +start_from: '' +freeze: False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/base/utils.py b/cvlface/research/recognition/code/run_v1/models/base/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cb9f6fd290ede2f2ced6959ddb57b8ba56fc9fdd --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/base/utils.py @@ -0,0 +1,91 @@ +import itertools +from typing import List, Optional, Tuple, Union +import safetensors +import torch +from torch import Tensor +import os +from pathlib import Path +from omegaconf import DictConfig, OmegaConf + + +def get_parameter_device(parameter: torch.nn.Module): + try: + parameters_and_buffers = itertools.chain(parameter.parameters(), parameter.buffers()) + return next(parameters_and_buffers).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module): + try: + params = tuple(parameter.parameters()) + if len(params) > 0: + return params[0].dtype + + buffers = tuple(parameter.buffers()) 
+ if len(buffers) > 0: + return buffers[0].dtype + + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def get_parent_directory(save_path: Union[str, os.PathLike]) -> Path: + path_obj = Path(save_path) + return path_obj.parent + +def get_base_name(save_path: Union[str, os.PathLike]) -> str: + path_obj = Path(save_path) + return path_obj.name + +def load_state_dict_from_path(path: Union[str, os.PathLike]): + # Load a state dict from a path. + if 'safetensors' in path: + state_dict = safetensors.torch.load_file(path) + else: + state_dict = torch.load(path, map_location="cpu") + return state_dict + +def replace_extension(path, new_extension): + if not new_extension.startswith('.'): + new_extension = '.' 
+ new_extension + return os.path.splitext(path)[0] + new_extension + +def make_config_path(save_path): + config_path = replace_extension(save_path, '.yaml') + return config_path + +def save_config(config, config_path): + assert isinstance(config, dict) or isinstance(config, DictConfig) + os.makedirs(get_parent_directory(config_path), exist_ok=True) + if isinstance(config, dict): + config = OmegaConf.create(config) + OmegaConf.save(config, config_path) + + +def save_state_dict_and_config(state_dict, config, save_path): + os.makedirs(get_parent_directory(save_path), exist_ok=True) + + # save config dict + config_path = make_config_path(save_path) + save_config(config, config_path) + + # Save the model + if 'safetensors' in save_path: + safetensors.torch.save_file(state_dict, save_path, metadata={"format": "pt"}) + else: + torch.save(state_dict, save_path) diff --git a/cvlface/research/recognition/code/run_v1/models/iresnet/__init__.py b/cvlface/research/recognition/code/run_v1/models/iresnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5c6b5b48fb6ef4b1726f37901d9751c8adca45 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/iresnet/__init__.py @@ -0,0 +1,60 @@ +from ..base import BaseModel +from .model import IR_101, IR_50, IR_18 +from torchvision import transforms + + +class IResNetModel(BaseModel): + + """ + A class representing a model for IResNet architectures. It supports creating + models with specific configurations such as IR_50 and IR_101. + + Attributes: + net (torch.nn.Module): The IResNet network (either IR_50 or IR_101). + config (object): The configuration object with model specifications. 
+ """ + + + def __init__(self, net, config): + super(IResNetModel, self).__init__(config) + self.net = net + self.config = config + + + @classmethod + def from_config(cls, config): + if config.name == 'ir50': + net = IR_50(input_size=(112,112), output_dim=config.output_dim) + elif config.name == 'ir101': + net = IR_101(input_size=(112,112), output_dim=config.output_dim) + elif config.name == 'ir18': + net = IR_18(input_size=(112,112), output_dim=config.output_dim) + else: + raise NotImplementedError + + model = cls(net, config) + model.eval() + return model + + def forward(self, x): + if self.input_color_flip: + x = x.flip(1) + return self.net(x) + + def make_train_transform(self): + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + + def make_test_transform(self): + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + +def load_model(model_config): + model = IResNetModel.from_config(model_config) + return model \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/iresnet/configs/v1_ir101.yaml b/cvlface/research/recognition/code/run_v1/models/iresnet/configs/v1_ir101.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a9a0bac53502c9838b746e90bf88977162953e2 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/iresnet/configs/v1_ir101.yaml @@ -0,0 +1,6 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'ir101' +output_dim: 512 +start_from: '' +freeze: False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/iresnet/configs/v1_ir18.yaml b/cvlface/research/recognition/code/run_v1/models/iresnet/configs/v1_ir18.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28a772a344f7837c9dfc0519756c5440137ac6dc --- /dev/null +++ 
b/cvlface/research/recognition/code/run_v1/models/iresnet/configs/v1_ir18.yaml @@ -0,0 +1,6 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'ir18' +output_dim: 512 +start_from: '' +freeze: False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/iresnet/configs/v1_ir50.yaml b/cvlface/research/recognition/code/run_v1/models/iresnet/configs/v1_ir50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..67399c1a6f9173a160451c812a38878807b35fe2 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/iresnet/configs/v1_ir50.yaml @@ -0,0 +1,6 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'ir50' +output_dim: 512 +start_from: '' +freeze: False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/iresnet/model.py b/cvlface/research/recognition/code/run_v1/models/iresnet/model.py new file mode 100644 index 0000000000000000000000000000000000000000..62d65d1076d794abba5259735b651ef067efd17d --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/iresnet/model.py @@ -0,0 +1,340 @@ +from collections import namedtuple +from torch.nn import Dropout +from torch.nn import MaxPool2d +from torch.nn import Sequential +import torch +import torch.nn as nn +from torch.nn import Conv2d, Linear +from torch.nn import BatchNorm1d, BatchNorm2d +from torch.nn import ReLU, Sigmoid +from torch.nn import Module +from torch.nn import PReLU +from fvcore.nn import flop_count +import numpy as np + + +def initialize_weights(modules): + for m in modules: + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, + mode='fan_out', + nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_normal_(m.weight, + mode='fan_out', + nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + + +class 
Flatten(Module): + def forward(self, input): + return input.view(input.size(0), -1) + + +class LinearBlock(Module): + def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1): + super(LinearBlock, self).__init__() + self.conv = Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False) + self.bn = BatchNorm2d(out_c) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + +class SEModule(Module): + def __init__(self, channels, reduction): + super(SEModule, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc1 = Conv2d(channels, channels // reduction, + kernel_size=1, padding=0, bias=False) + + nn.init.xavier_uniform_(self.fc1.weight.data) + + self.relu = ReLU(inplace=True) + self.fc2 = Conv2d(channels // reduction, channels, + kernel_size=1, padding=0, bias=False) + + self.sigmoid = Sigmoid() + + def forward(self, x): + module_input = x + x = self.avg_pool(x) + x = self.fc1(x) + x = self.relu(x) + x = self.fc2(x) + x = self.sigmoid(x) + + return module_input * x + + + +class BasicBlockIR(Module): + def __init__(self, in_channel, depth, stride): + super(BasicBlockIR, self).__init__() + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), + BatchNorm2d(depth), + PReLU(depth), + Conv2d(depth, depth, (3, 3), stride, 1, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + + return res + shortcut + + +class BottleneckIR(Module): + def __init__(self, in_channel, depth, stride): + super(BottleneckIR, self).__init__() + reduction_channel = depth // 4 + if in_channel == depth: + self.shortcut_layer = MaxPool2d(1, stride) + else: + self.shortcut_layer = Sequential( + 
Conv2d(in_channel, depth, (1, 1), stride, bias=False), + BatchNorm2d(depth)) + self.res_layer = Sequential( + BatchNorm2d(in_channel), + Conv2d(in_channel, reduction_channel, (1, 1), (1, 1), 0, bias=False), + BatchNorm2d(reduction_channel), + PReLU(reduction_channel), + Conv2d(reduction_channel, reduction_channel, (3, 3), (1, 1), 1, bias=False), + BatchNorm2d(reduction_channel), + PReLU(reduction_channel), + Conv2d(reduction_channel, depth, (1, 1), stride, 0, bias=False), + BatchNorm2d(depth)) + + def forward(self, x): + shortcut = self.shortcut_layer(x) + res = self.res_layer(x) + + return res + shortcut + + +class BasicBlockIRSE(BasicBlockIR): + def __init__(self, in_channel, depth, stride): + super(BasicBlockIRSE, self).__init__(in_channel, depth, stride) + self.res_layer.add_module("se_block", SEModule(depth, 16)) + + +class BottleneckIRSE(BottleneckIR): + def __init__(self, in_channel, depth, stride): + super(BottleneckIRSE, self).__init__(in_channel, depth, stride) + self.res_layer.add_module("se_block", SEModule(depth, 16)) + + +class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): + pass + + +def get_block(in_channel, depth, num_units, stride=2): + + return [Bottleneck(in_channel, depth, stride)] + \ + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] + + +def get_blocks(num_layers): + if num_layers == 18: + blocks = [ + get_block(in_channel=64, depth=64, num_units=2), + get_block(in_channel=64, depth=128, num_units=2), + get_block(in_channel=128, depth=256, num_units=2), + get_block(in_channel=256, depth=512, num_units=2) + ] + elif num_layers == 34: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + get_block(in_channel=128, depth=256, num_units=6), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 50: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=4), + 
get_block(in_channel=128, depth=256, num_units=14), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 100: + blocks = [ + get_block(in_channel=64, depth=64, num_units=3), + get_block(in_channel=64, depth=128, num_units=13), + get_block(in_channel=128, depth=256, num_units=30), + get_block(in_channel=256, depth=512, num_units=3) + ] + elif num_layers == 152: + blocks = [ + get_block(in_channel=64, depth=256, num_units=3), + get_block(in_channel=256, depth=512, num_units=8), + get_block(in_channel=512, depth=1024, num_units=36), + get_block(in_channel=1024, depth=2048, num_units=3) + ] + elif num_layers == 200: + blocks = [ + get_block(in_channel=64, depth=256, num_units=3), + get_block(in_channel=256, depth=512, num_units=24), + get_block(in_channel=512, depth=1024, num_units=36), + get_block(in_channel=1024, depth=2048, num_units=3) + ] + + return blocks + + +class Backbone(Module): + + def __init__(self, input_size, num_layers, mode='ir', flip=False, output_dim=512): + super(Backbone, self).__init__() + assert input_size[0] in [112, 224], \ + "input_size should be [112, 112] or [224, 224]" + assert num_layers in [18, 34, 50, 100, 152, 200], \ + "num_layers should be 18, 34, 50, 100 or 152" + assert mode in ['ir', 'ir_se'], \ + "mode should be ir or ir_se" + self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), + BatchNorm2d(64), PReLU(64)) + blocks = get_blocks(num_layers) + if num_layers <= 100: + if mode == 'ir': + unit_module = BasicBlockIR + elif mode == 'ir_se': + unit_module = BasicBlockIRSE + output_channel = 512 + else: + if mode == 'ir': + unit_module = BottleneckIR + elif mode == 'ir_se': + unit_module = BottleneckIRSE + output_channel = 2048 + + if input_size[0] == 112: + self.output_layer = Sequential(BatchNorm2d(output_channel), + Dropout(0.4), Flatten(), + Linear(output_channel * 7 * 7, output_dim), + BatchNorm1d(output_dim, affine=False)) + else: + self.output_layer = Sequential( + 
BatchNorm2d(output_channel), Dropout(0.4), Flatten(), + Linear(output_channel * 14 * 14, output_dim), + BatchNorm1d(output_dim, affine=False)) + + modules = [] + for block in blocks: + for bottleneck in block: + modules.append( + unit_module(bottleneck.in_channel, bottleneck.depth, + bottleneck.stride)) + self.body = Sequential(*modules) + + initialize_weights(self.modules()) + + self.flip = flip + + + def forward(self, x): + + if self.flip: + x = x.flip(1) # color channel flip + + x = self.input_layer(x) + for idx, module in enumerate(self.body): + x = module(x) + + x = self.output_layer(x) + return x + + + +def IR_18(input_size, output_dim=512): + model = Backbone(input_size, 18, 'ir', output_dim=output_dim) + + return model + + +def IR_34(input_size, output_dim=512): + model = Backbone(input_size, 34, 'ir', output_dim=output_dim) + + return model + + +def IR_50(input_size, output_dim=512): + model = Backbone(input_size, 50, 'ir', output_dim=output_dim) + + return model + + +def IR_101(input_size, output_dim=512): + model = Backbone(input_size, 100, 'ir', output_dim=output_dim) + + return model + + +def IR_101_FLIP(input_size, output_dim=512): + model = Backbone(input_size, 100, 'ir', flip=True, output_dim=output_dim) + + return model + + + +def IR_152(input_size, output_dim=512): + model = Backbone(input_size, 152, 'ir', output_dim=output_dim) + + return model + + +def IR_200(input_size, output_dim=512): + model = Backbone(input_size, 200, 'ir', output_dim=output_dim) + + return model + + +def IR_SE_50(input_size, output_dim=512): + model = Backbone(input_size, 50, 'ir_se', output_dim=output_dim) + + return model + + +def IR_SE_101(input_size, output_dim=512): + model = Backbone(input_size, 100, 'ir_se', output_dim=output_dim) + + return model + + +def IR_SE_152(input_size, output_dim=512): + model = Backbone(input_size, 152, 'ir_se', output_dim=output_dim) + + return model + + +def IR_SE_200(input_size, output_dim=512): + model = Backbone(input_size, 200, 
'ir_se', output_dim=output_dim) + + return model + + +if __name__ == '__main__': + + inputs_shape = (1, 3, 112, 112) + model = IR_50(input_size=(112,112)) + model.eval() + res = flop_count(model, inputs=torch.randn(inputs_shape), supported_ops={}) + fvcore_flop = np.array(list(res[0].values())).sum() + print('FLOPs: ', fvcore_flop / 1e9, 'G') + print('Num Params: ', sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6, 'M') + + diff --git a/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/__init__.py b/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd67a5170dfdfabe233ba4ceff0cf53fefe9ad59 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/__init__.py @@ -0,0 +1,60 @@ +from ..base import BaseModel +from .model import iresnet100, iresnet50, iresnet18 +from torchvision import transforms + + +class IResNetModel(BaseModel): + + """ + A class representing a model for IResNet architectures. It supports creating + models with specific configurations such as IR_50 and IR_101. + + Attributes: + net (torch.nn.Module): The IResNet network (either IR_50 or IR_101). + config (object): The configuration object with model specifications. 
+ """ + + + def __init__(self, net, config): + super(IResNetModel, self).__init__(config) + self.net = net + self.config = config + + + @classmethod + def from_config(cls, config): + if config.name == 'ir50': + net = iresnet50(input_size=(112,112), output_dim=config.output_dim) + elif config.name == 'ir101': + net = iresnet100() + elif config.name == 'ir18': + net = iresnet18(input_size=(112,112), output_dim=config.output_dim) + else: + raise NotImplementedError + + model = cls(net, config) + model.eval() + return model + + def forward(self, x): + if self.input_color_flip: + x = x.flip(1) + return self.net(x) + + def make_train_transform(self): + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + + def make_test_transform(self): + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + +def load_model(model_config): + model = IResNetModel.from_config(model_config) + return model \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/configs/v1_ir101.yaml b/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/configs/v1_ir101.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3a9a0bac53502c9838b746e90bf88977162953e2 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/configs/v1_ir101.yaml @@ -0,0 +1,6 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'ir101' +output_dim: 512 +start_from: '' +freeze: False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/configs/v1_ir18.yaml b/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/configs/v1_ir18.yaml new file mode 100644 index 0000000000000000000000000000000000000000..28a772a344f7837c9dfc0519756c5440137ac6dc --- 
/dev/null +++ b/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/configs/v1_ir18.yaml @@ -0,0 +1,6 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'ir18' +output_dim: 512 +start_from: '' +freeze: False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/configs/v1_ir50.yaml b/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/configs/v1_ir50.yaml new file mode 100644 index 0000000000000000000000000000000000000000..67399c1a6f9173a160451c812a38878807b35fe2 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/configs/v1_ir50.yaml @@ -0,0 +1,6 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'ir50' +output_dim: 512 +start_from: '' +freeze: False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/model.py b/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/model.py new file mode 100644 index 0000000000000000000000000000000000000000..2f8aff6d6e3a75cb8ed839cd35489f2d3db88023 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/iresnet_insightface/model.py @@ -0,0 +1,184 @@ +import torch +from torch import nn + +__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200'] + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, + out_planes, + kernel_size=1, + stride=stride, + bias=False) + + +class IBasicBlock(nn.Module): + expansion = 1 + def __init__(self, inplanes, planes, stride=1, downsample=None, + groups=1, base_width=64, dilation=1): + super(IBasicBlock, self).__init__() + if groups != 1 or base_width != 64: + 
raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,) + self.conv1 = conv3x3(inplanes, planes) + self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,) + self.prelu = nn.PReLU(planes) + self.conv2 = conv3x3(planes, planes, stride) + self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + out = self.bn1(x) + out = self.conv1(out) + out = self.bn2(out) + out = self.prelu(out) + out = self.conv2(out) + out = self.bn3(out) + if self.downsample is not None: + identity = self.downsample(x) + out += identity + return out + + +class IResNet(nn.Module): + fc_scale = 7 * 7 + def __init__(self, + block, layers, dropout=0, num_features=512, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None): + super(IResNet, self).__init__() + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05) + self.prelu = nn.PReLU(self.inplanes) + self.layer1 = self._make_layer(block, 64, layers[0], stride=2) + self.layer2 = self._make_layer(block, + 128, + layers[1], + stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, + 256, + layers[2], + stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, + 512, + layers[3], + stride=2, + dilate=replace_stride_with_dilation[2]) + self.bn2 = 
nn.BatchNorm2d(512 * block.expansion, eps=1e-05,) + self.dropout = nn.Dropout(p=dropout, inplace=True) + self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features) + self.features = nn.BatchNorm1d(num_features, eps=1e-05) + nn.init.constant_(self.features.weight, 1.0) + self.features.weight.requires_grad = False + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.normal_(m.weight, 0, 0.1) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + if zero_init_residual: + for m in self.modules(): + if isinstance(m, IBasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ), + ) + layers = [] + layers.append( + block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block(self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation)) + + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.prelu(x) + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.bn2(x) + x = torch.flatten(x, 1) + x = self.dropout(x) + x = self.fc(x) + x = self.features(x) + return x + + +def _iresnet(arch, block, layers, pretrained, progress, **kwargs): + model = IResNet(block, layers, **kwargs) + if pretrained: + raise ValueError() + return model + + +def iresnet18(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet18', IBasicBlock, [2, 2, 
2, 2], pretrained, + progress, **kwargs) + + +def iresnet34(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, + progress, **kwargs) + + +def iresnet50(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, + progress, **kwargs) + + +def iresnet100(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, + progress, **kwargs) + + +def iresnet200(pretrained=False, progress=True, **kwargs): + return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained, + progress, **kwargs) diff --git a/cvlface/research/recognition/code/run_v1/models/part_fvit/__init__.py b/cvlface/research/recognition/code/run_v1/models/part_fvit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22ce87d7e79863c317a4f7c2c8d2d495cdedeac7 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/part_fvit/__init__.py @@ -0,0 +1,60 @@ +from ..base import BaseModel +from .vit import VisionTransformer +from torchvision import transforms +from .part_fvit import PartFVIT + +class PartFViTModel(BaseModel): + + """ + A PartFViT Model integrating a Vision Transformer (ViT) with additional functionality for part-based feature vision transformer (PartFVIT). + Sun, Zhonglin, and Georgios Tzimiropoulos. "Part-based face recognition with vision transformers." [arXiv preprint arXiv:2212.00057 (2022)](https://arxiv.org/abs/2212.00057). 
+ """ + + def __init__(self, net, config): + super(PartFViTModel, self).__init__(config) + self.net = net + + @classmethod + def from_config(cls, config): + + if config.name == 'small': + net = VisionTransformer(img_size=112, patch_size=8, num_classes=config.output_dim, embed_dim=512, depth=12, + mlp_ratio=5, num_heads=8, drop_path_rate=0.1, norm_layer="ln", + mask_ratio=config.mask_ratio) + fvit = PartFVIT(net, num_patch=196, patch_size=8) + elif config.name == 'base': + net = VisionTransformer(img_size=112, patch_size=8, num_classes=config.output_dim, embed_dim=512, depth=24, + mlp_ratio=3, num_heads=16, drop_path_rate=0.1, norm_layer="ln", + mask_ratio=config.mask_ratio) + fvit = PartFVIT(net, num_patch=196, patch_size=8) + else: + raise NotImplementedError + + model = cls(fvit, config) + model.eval() + return model + + def forward(self, x): + + if self.input_color_flip: + x = x.flip(1) + return self.net(x) + + def make_train_transform(self): + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + + def make_test_transform(self): + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + +def load_model(model_config): + model = PartFViTModel.from_config(model_config) + return model \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/part_fvit/configs/v1_base.yaml b/cvlface/research/recognition/code/run_v1/models/part_fvit/configs/v1_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f604e8833cdbab6a9cb85fae8fee26f8ac7f8be4 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/part_fvit/configs/v1_base.yaml @@ -0,0 +1,7 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'base' +output_dim: 512 +start_from: '' +freeze: False +mask_ratio: 0.0 \ No newline at end of file diff --git 
a/cvlface/research/recognition/code/run_v1/models/part_fvit/configs/v1_small.yaml b/cvlface/research/recognition/code/run_v1/models/part_fvit/configs/v1_small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7d024cd34519f4ddc38429ba65e9080509f1a9b7 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/part_fvit/configs/v1_small.yaml @@ -0,0 +1,7 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'small' +output_dim: 512 +start_from: '' +freeze: False +mask_ratio: 0.0 \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/part_fvit/part_fvit.py b/cvlface/research/recognition/code/run_v1/models/part_fvit/part_fvit.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1a626c706236b8f976626e0771e57fd6edef8d --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/part_fvit/part_fvit.py @@ -0,0 +1,65 @@ +from torchvision.models import mobilenet_v3_small +import torch +from torch import nn as nn +import torch.nn.functional as F + + + +class PartFVIT(nn.Module): + + def __init__(self, vit, num_patch, patch_size=8): + super(PartFVIT, self).__init__() + self.num_patch = num_patch + self.patch_size = patch_size + self.mobilenet = mobilenet_v3_small(weights=None, num_classes=num_patch * 2) + self.mobilenet.classifier = nn.Identity() + self.mobilenet.avgpool = nn.Identity() + out = self.mobilenet(torch.randn(1, 3, 112, 112)) + c_in_features = out.shape[1] + fc_loc = nn.Linear(c_in_features, num_patch * 2) + + num_step = int(self.num_patch ** 0.5) + linspace = torch.linspace(-1, 1, num_step) + grid_x, grid_y = torch.meshgrid(linspace, linspace, indexing='ij') + grid_x = grid_x.reshape(-1) + grid_y = grid_y.reshape(-1) + bias = torch.stack([grid_x, grid_y], dim=1).view(-1) + fc_loc.weight.data.zero_() + fc_loc.bias.data.copy_(bias) + + self.mobilenet.classifier = fc_loc + self.mobilenet(torch.randn(1, 3, 112, 112)) + + self.patch_emb = nn.Linear(patch_size * patch_size * 
3, vit.embed_dim) + self.vit = vit + + def forward_patch_stn(self, x): + coord_pred = self.mobilenet(x) + coord_pred = coord_pred.view(-1, self.num_patch, 2) + + image_side = x.shape[-1] + scale = (self.patch_size - 1) / (image_side - 1) + batch_size = coord_pred.shape[0] + all_patches = [] + for center in coord_pred.transpose(0, 1): + theta = torch.tensor([[[scale, 0, 0], [0, scale, 0]]], + requires_grad=True, dtype=coord_pred.dtype, device=coord_pred.device) + theta = theta.repeat(batch_size, 1, 1) + theta[:, :, -1] = center + grid = F.affine_grid(theta, [batch_size, 3, self.patch_size, self.patch_size], align_corners=True) + patches = F.grid_sample(x, grid, align_corners=True) + all_patches.append(patches) + all_patches = torch.stack(all_patches, 1).view(batch_size, self.num_patch, -1) + return all_patches + + def forward(self, x): + all_patches = self.forward_patch_stn(x) + all_patches = self.patch_emb(all_patches) + return self.vit(all_patches) + + +if __name__ == '__main__': + model = PartFVIT() + x = torch.randn(3, 3, 112, 112, requires_grad=True) + out = model(x) + diff --git a/cvlface/research/recognition/code/run_v1/models/part_fvit/vit.py b/cvlface/research/recognition/code/run_v1/models/part_fvit/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..434bbee0c7a6a3330f175dd9c2257d00e6a71b07 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/part_fvit/vit.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from typing import Optional, Callable +from fvcore.nn import flop_count +import numpy as np + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = 
nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class VITBatchNorm(nn.Module): + def __init__(self, num_features): + super().__init__() + self.num_features = num_features + self.bn = nn.BatchNorm1d(num_features=num_features) + + def forward(self, x): + return self.bn(x) + + +class Attention(nn.Module): + def __init__(self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_scale: Optional[None] = None, + attn_drop: float = 0., + proj_drop: float = 0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + + batch_size, num_token, embed_dim = x.shape + #qkv is [3,batch_size,num_heads,num_token, embed_dim//num_heads] + qkv = self.qkv(x).reshape( + batch_size, num_token, 3, self.num_heads, embed_dim // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(batch_size, num_token, embed_dim) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + + def __init__(self, + dim: int, + num_heads: int, + num_patches: int, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_scale: Optional[None] = None, + drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.ReLU6, + norm_layer: str = "ln", + patch_n: int = 144): + super().__init__() + + if norm_layer == "bn": + self.norm1 = 
VITBatchNorm(num_features=num_patches) + self.norm2 = VITBatchNorm(num_features=num_patches) + elif norm_layer == "ln": + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=act_layer, drop=drop) + self.extra_gflops = (num_heads * patch_n * (dim//num_heads)*patch_n * 2) / (1000**3) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, img_size=108, patch_size=9, in_channels=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * \ + (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2d(in_channels, embed_dim, + kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + batch_size, channels, height, width = x.shape + assert height == self.img_size[0] and width == self.img_size[1], \ + f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
+ x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + img_size: int = 112, + patch_size: int = 16, + in_channels: int = 3, + num_classes: int = 1000, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_scale: Optional[None] = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + num_patches: Optional[int] = None, + norm_layer: str = "ln", + mask_ratio = 0.1, + using_checkpoint = False, + ): + super().__init__() + self.num_classes = num_classes + # num_features for consistency with other models + self.num_features = self.embed_dim = embed_dim + + self.patch_embed = nn.Identity() + num_patches = (img_size // patch_size) ** 2 + self.mask_ratio = mask_ratio + self.using_checkpoint = using_checkpoint + + self.num_patches = num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + patch_n = (img_size//patch_size)**2 + self.blocks = nn.ModuleList( + [ + Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + num_patches=num_patches, patch_n=patch_n) + for i in range(depth)] + ) + self.extra_gflops = 0.0 + for _block in self.blocks: + self.extra_gflops += _block.extra_gflops + + if norm_layer == "ln": + self.norm = nn.LayerNorm(embed_dim) + elif norm_layer == "bn": + self.norm = VITBatchNorm(self.num_patches) + + # features head + self.feature = nn.Sequential( + nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False), + nn.BatchNorm1d(num_features=embed_dim, eps=2e-5), + nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False), + 
nn.BatchNorm1d(num_features=num_classes, eps=2e-5) + ) + + if self.mask_ratio == 0: + pass + else: + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + torch.nn.init.normal_(self.mask_token, std=.02) + trunc_normal_(self.pos_embed, std=.02) + # trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def random_masking(self, x, mask_ratio=0.1): + N, L, D = x.size() # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + # ascend: small is keep, large is remove + ids_shuffle = torch.argsort(noise, dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + index = ids_keep.unsqueeze(-1).repeat(1, 1, D) + x_masked = torch.gather(x, dim=1, index=index) + + return x_masked, index, ids_restore + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + if self.training and self.mask_ratio > 0: + x, _, ids_restore = self.random_masking(x) + + for func in self.blocks: + if self.using_checkpoint and self.training: + from torch.utils.checkpoint import checkpoint + x = checkpoint(func, x) + else: + x = func(x) + x = self.norm(x.float()) + + if self.training and self.mask_ratio > 0: + mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1) + x_ = torch.cat([x[:, :, :], mask_tokens], dim=1) # no cls token + x_ = torch.gather(x_, dim=1, 
index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle + x = x_ + return torch.reshape(x, (B, self.num_patches * self.embed_dim)) + + def forward(self, x): + x = self.forward_features(x) + x = self.feature(x) + return x diff --git a/cvlface/research/recognition/code/run_v1/models/swin/__init__.py b/cvlface/research/recognition/code/run_v1/models/swin/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07d8acc5da71645c09ea45d27066ef6a92a6bd65 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin/__init__.py @@ -0,0 +1,94 @@ +from ..base import BaseModel +from torchvision import transforms +from .swin.names import swin_s, swin_v2_b, swin_v2_s + +class SWINModel(BaseModel): + + + """ + A modified version of the Swin Transformer, tailored for facial recognition with an input dimension of 112x112 pixels. + + This model inherits from the BaseModel class and utilizes the smaller variants of the Swin Transformer architecture, + such as `swin_v2_b` and `swin_v2_s`. + + The Swin Transformer uses shifted windows to bring greater efficiency and flexibility to the transformer architecture, + allowing for attention mechanisms that adapt to the hierarchical nature of visual data. 
+ + References: + - Swin Transformer paper: https://arxiv.org/abs/2103.14030 + ``` + @inproceedings{liu2021swin, + title={Swin transformer: Hierarchical vision transformer using shifted windows}, + author={Liu, Ze and Lin, Yutong and Cao, Yue and Hu, Han and Wei, Yixuan and Zhang, Zheng and Lin, Stephen and Guo, Baining}, + booktitle={Proceedings of the IEEE/CVF international conference on computer vision}, + pages={10012--10022}, + year={2021} + } + ``` + """ + + def __init__(self, net, config): + super(SWINModel, self).__init__(config) + self.net = net + + + @classmethod + def from_config(cls, config): + + if config.name == 'swin_s': + net = swin_s() + elif config.name == 'small': + net = swin_v2_s() + elif config.name == 'base': + net = swin_v2_b() + else: + raise NotImplementedError + + model = cls(net, config) + model.eval() + return model + + def forward(self, x): + if self.input_color_flip: + x = x.flip(1) + return self.net(x) + + def make_train_transform(self): + image_size = tuple(self.config.input_size[1:3]) + transform = transforms.Compose([ + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + + def make_test_transform(self): + image_size = tuple(self.config.input_size[1:3]) + transform = transforms.Compose([ + transforms.Resize(image_size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + + +class TimmSWINModel(SWINModel): + @classmethod + def from_config(cls, config): + import timm + net = timm.create_model( + config.timm_name, + pretrained=False, + img_size=config.input_size[1], + num_classes=config.output_dim, + ) + model = cls(net, config) + model.eval() + return model + +def load_model(model_config): + if model_config.name == 'timm_swin_s': + model = TimmSWINModel.from_config(model_config) + else: + model = SWINModel.from_config(model_config) + return model diff --git 
a/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_base.yaml b/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f604e8833cdbab6a9cb85fae8fee26f8ac7f8be4 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_base.yaml @@ -0,0 +1,7 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'base' +output_dim: 512 +start_from: '' +freeze: False +mask_ratio: 0.0 \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_small.yaml b/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_small.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7d024cd34519f4ddc38429ba65e9080509f1a9b7 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_small.yaml @@ -0,0 +1,7 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'small' +output_dim: 512 +start_from: '' +freeze: False +mask_ratio: 0.0 \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_swin_s_pretrained.yaml b/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_swin_s_pretrained.yaml new file mode 100644 index 0000000000000000000000000000000000000000..030665de3ade067411e52e63bbd37e109702a734 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_swin_s_pretrained.yaml @@ -0,0 +1,8 @@ +input_size: [3, 224, 224] +color_space: 'RGB' +name: 'timm_swin_s' +timm_name: 'swin_small_patch4_window7_224.ms_in22k_ft_in1k' +output_dim: 512 +start_from: ${oc.env:SWIN_S_PRETRAINED,/root/Adaface/cvlface/cvlface/pretrained_models/model.safetensors} +freeze: False +mask_ratio: 0.0 diff --git a/cvlface/research/recognition/code/run_v1/models/swin/swin/__init__.py b/cvlface/research/recognition/code/run_v1/models/swin/swin/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..3043d7e1259c5c96d15000af812e2579510d081b --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin/swin/__init__.py @@ -0,0 +1,38 @@ +from .names import swin_v2_b, swin_v2_s +import torch + + +if __name__ == '__main__': + + model = swin_v2_b() + model.eval() + inputs_shape = (1, 3, 112, 112) + x = torch.randn(*inputs_shape) + y = model(x) + print(x.shape) + print(y.shape) + + model.eval() + from fvcore.nn import flop_count + import numpy as np + res = flop_count(model, inputs=torch.randn(inputs_shape), supported_ops={}) + fvcore_flop = np.array(list(res[0].values())).sum() + print(f'FLOPs: {fvcore_flop:.2f}', 'G') + print('Num Params: ', sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6, 'M') + + + model = swin_v2_s() + model.eval() + inputs_shape = (1, 3, 112, 112) + x = torch.randn(*inputs_shape) + y = model(x) + print(x.shape) + print(y.shape) + + model.eval() + from fvcore.nn import flop_count + import numpy as np + res = flop_count(model, inputs=torch.randn(inputs_shape), supported_ops={}) + fvcore_flop = np.array(list(res[0].values())).sum() + print(f'FLOPs: {fvcore_flop:.2f}', 'G') + print('Num Params: ', sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6, 'M') \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/swin/swin/model.py b/cvlface/research/recognition/code/run_v1/models/swin/swin/model.py new file mode 100644 index 0000000000000000000000000000000000000000..13d135b093dd19cefddb16614f8fb983bc137bcd --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin/swin/model.py @@ -0,0 +1,179 @@ +from functools import partial +from typing import Any, Callable, List, Optional +from torch import nn, Tensor +from torchvision.ops.misc import MLP, Permute +from torchvision.utils import _log_api_usage_once +from .modules_v1 import PatchMerging, ShiftedWindowAttention, SwinTransformerBlock +from .modules_v2 import 
PatchMergingV2, ShiftedWindowAttentionV2, SwinTransformerBlockV2 +import torch + +class SwinTransformer(nn.Module): + """ + Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using + Shifted Windows" `_ paper. + Args: + patch_size (List[int]): Patch size. + embed_dim (int): Patch embedding dimension. + depths (List(int)): Depth of each Swin Transformer layer. + num_heads (List(int)): Number of attention heads in different layers. + window_size (List[int]): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1. + num_classes (int): Number of classes for classification head. Default: 1000. + block (nn.Module, optional): SwinTransformer Block. Default: None. + norm_layer (nn.Module, optional): Normalization layer. Default: None. + downsample_layer (nn.Module): Downsample layer (patch merging). Default: PatchMerging. 
+ """ + + def __init__( + self, + img_size: List[int], + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.1, + num_classes: int = 1000, + norm_layer: Optional[Callable[..., nn.Module]] = None, + block: Optional[Callable[..., nn.Module]] = None, + downsample_layer: Callable[..., nn.Module] = PatchMerging, + ): + super().__init__() + _log_api_usage_once(self) + self.num_classes = num_classes + + if block is None: + block = SwinTransformerBlock + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-5) + + layers: List[nn.Module] = [] + # split image into non-overlapping patches + layers.append( + nn.Sequential( + nn.Conv2d( + 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) + ), + Permute([0, 2, 3, 1]), + norm_layer(embed_dim), + ) + ) + + total_stage_blocks = sum(depths) + stage_block_id = 0 + # build SwinTransformer blocks + for i_stage in range(len(depths)): + dim = embed_dim * 2**i_stage + for i_layer in range(depths[i_stage]): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) + layers.append( + block( + dim, + num_heads[i_stage], + window_size=window_size, + shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, + ) + ) + stage_block_id += 1 + # add patch merging layer + if i_stage < (len(depths) - 1): + layers.append(downsample_layer(dim, norm_layer)) + self.features = nn.ModuleList(layers) + + num_features = embed_dim * 2 ** (len(depths) - 1) + self.norm = norm_layer(num_features) + self.flatten = nn.Flatten(1) + + # features head + num_patches = 
self.forward_features(torch.rand(1, 3, img_size[0], img_size[1])).shape[1:3] + num_patches = num_patches[0] * num_patches[1] + self.embed_dim = embed_dim # basis feature dim for building blocks + self.num_features = num_features # C for final intermediate feature + self.num_patches = num_patches # N for output feature + self.num_classes = num_classes # C for return feature after flattening and mapping + self.feature = nn.Sequential( + nn.Linear(in_features=num_features * num_patches, out_features=num_features, bias=False), + nn.BatchNorm1d(num_features=num_features, eps=2e-5), + nn.Linear(in_features=num_features, out_features=num_classes, bias=False), + nn.BatchNorm1d(num_features=num_classes, eps=2e-5) + ) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward_features(self, x): + for module in self.features: + x = module(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.norm(x) + x = self.flatten(x) + x = self.feature(x) + return x + + +def _swin_transformer( + img_size: List[int], + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + stochastic_depth_prob: float, + num_classes=512, + **kwargs: Any +) -> SwinTransformer: + + model = SwinTransformer( + img_size=img_size, + patch_size=patch_size, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + stochastic_depth_prob=stochastic_depth_prob, + num_classes=num_classes, + **kwargs + ) + + # x = torch.randn(1, 3, 112, 112) + # x = model.forward_features(x) + # + # # debug + # from models.vit.vit import VisionTransformer + # vit = VisionTransformer(img_size=112, patch_size=8, num_classes=512, embed_dim=512, depth=12, + # mlp_ratio=5, num_heads=8, drop_path_rate=0.1, norm_layer="ln", + # mask_ratio=0) + # + # x = torch.randn(1, 3, 112, 112) + # x = vit.patch_embed(x) + # x = 
x + vit.pos_embed + # x = vit.pos_drop(x) + # for func in vit.blocks: + # if vit.using_checkpoint and vit.training: + # from torch.utils.checkpoint import checkpoint + # x = checkpoint(func, x) + # else: + # x = func(x) + # x = vit.norm(x.float()) + # x.shape + + return model + diff --git a/cvlface/research/recognition/code/run_v1/models/swin/swin/modules_v1.py b/cvlface/research/recognition/code/run_v1/models/swin/swin/modules_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..dcde55cf646da5ec058ca3ed9565d20e43a1405e --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin/swin/modules_v1.py @@ -0,0 +1,178 @@ +from typing import Any, Callable, List, Optional + +import torch +from torchvision.ops.misc import MLP, Permute +from torchvision.ops.stochastic_depth import StochasticDepth + +from torch import nn, Tensor +from torchvision.utils import _log_api_usage_once +from .ops import _patch_merging_pad, shifted_window_attention, _get_relative_position_bias + + +class PatchMerging(nn.Module): + """Patch Merging Layer. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + _log_api_usage_once(self) + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x: Tensor): + """ + Args: + x (Tensor): input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + x = _patch_merging_pad(x) + x = self.norm(x) + x = self.reduction(x) # ... H/2 W/2 2*C + return x + + + + + +class ShiftedWindowAttention(nn.Module): + """ + See :func:`shifted_window_attention`. 
+ """ + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ): + super().__init__() + if len(window_size) != 2 or len(shift_size) != 2: + raise ValueError("window_size and shift_size must be of length 2") + self.window_size = window_size + self.shift_size = shift_size + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.dropout = dropout + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + self.define_relative_position_bias_table() + self.define_relative_position_index() + + def define_relative_position_bias_table(self): + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), self.num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + def define_relative_position_index(self): + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1).flatten() # Wh*Ww*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + def get_relative_position_bias(self) -> 
torch.Tensor: + return _get_relative_position_bias( + self.relative_position_bias_table, self.relative_position_index, self.window_size # type: ignore[arg-type] + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Args: + x (Tensor): Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. [B, H, W, C] + """ + relative_position_bias = self.get_relative_position_bias() + return shifted_window_attention( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + training=self.training, + ) + + + +class SwinTransformerBlock(nn.Module): + """ + Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (List[int]): Window size. + shift_size (List[int]): Shift size for shifted window attention. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + attn_layer (nn.Module): Attention layer. 
Default: ShiftedWindowAttention + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: List[int], + shift_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, + ): + super().__init__() + _log_api_usage_once(self) + + self.norm1 = norm_layer(dim) + self.attn = attn_layer( + dim, + window_size, + shift_size, + num_heads, + attention_dropout=attention_dropout, + dropout=dropout, + ) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + self.norm2 = norm_layer(dim) + self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) + + for m in self.mlp.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x: Tensor): + x = x + self.stochastic_depth(self.attn(self.norm1(x))) + x = x + self.stochastic_depth(self.mlp(self.norm2(x))) + return x diff --git a/cvlface/research/recognition/code/run_v1/models/swin/swin/modules_v2.py b/cvlface/research/recognition/code/run_v1/models/swin/swin/modules_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..4c3a8016d0a99a333557181f10c5adb1ab0c2f4f --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin/swin/modules_v2.py @@ -0,0 +1,173 @@ +from typing import Any, Callable, List, Optional +import torch +from torch import nn, Tensor + + +from torchvision.utils import _log_api_usage_once +from .ops import _patch_merging_pad, shifted_window_attention, _get_relative_position_bias +from .modules_v1 import ShiftedWindowAttention, SwinTransformerBlock + + +class PatchMergingV2(nn.Module): + """Patch Merging Layer for Swin Transformer V2. + Args: + dim (int): Number of input channels. 
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + _log_api_usage_once(self) + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) # difference + + def forward(self, x: Tensor): + """ + Args: + x (Tensor): input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + x = _patch_merging_pad(x) + x = self.reduction(x) # ... H/2 W/2 2*C + x = self.norm(x) + return x + + + +class ShiftedWindowAttentionV2(ShiftedWindowAttention): + """ + See :func:`shifted_window_attention_v2`. + """ + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ): + super().__init__( + dim, + window_size, + shift_size, + num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attention_dropout=attention_dropout, + dropout=dropout, + ) + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False) + ) + if qkv_bias: + length = self.qkv.bias.numel() // 3 + self.qkv.bias[length : 2 * length].data.zero_() + + def define_relative_position_bias_table(self): + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij")) + relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 
class SwinTransformerBlockV2(SwinTransformerBlock):
    """
    Swin Transformer V2 Block.

    Builds exactly the same sub-modules as ``SwinTransformerBlock``; only the
    position of the normalization inside the residual branches differs.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (List[int]): Window size.
        shift_size (List[int]): Shift size for shifted window attention.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: List[int],
        shift_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2,
    ):
        # Everything is delegated to the V1 block; only the default attention layer differs.
        super().__init__(
            dim=dim,
            num_heads=num_heads,
            window_size=window_size,
            shift_size=shift_size,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            attention_dropout=attention_dropout,
            stochastic_depth_prob=stochastic_depth_prob,
            norm_layer=norm_layer,
            attn_layer=attn_layer,
        )

    def forward(self, x: Tensor):
        # V2 uses post-normalization: the norm is applied to the output of the
        # attention / MLP branch, whereas V1 normalizes the branch input.
        attn_branch = self.norm1(self.attn(x))
        x = x + self.stochastic_depth(attn_branch)
        mlp_branch = self.norm2(self.mlp(x))
        return x + self.stochastic_depth(mlp_branch)
def swin_v2_t():
    """Instantiate the ``swin_v2_t`` preset through ``_swin_transformer``."""
    config = dict(
        img_size=[112, 112],
        patch_size=[3, 3],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[5, 5],
        stochastic_depth_prob=0.2,
        num_classes=512,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
    )
    return _swin_transformer(**config)


def swin_v2_s():
    """Instantiate the ``swin_v2_s`` preset through ``_swin_transformer``."""
    config = dict(
        img_size=[112, 112],
        patch_size=[4, 4],
        embed_dim=128,
        depths=[4, 6, 8],
        num_heads=[4, 8, 16],
        window_size=[7, 7],
        stochastic_depth_prob=0.5,
        num_classes=512,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
    )
    return _swin_transformer(**config)


def swin_v2_b():
    """Instantiate the ``swin_v2_b`` preset through ``_swin_transformer``."""
    config = dict(
        img_size=[112, 112],
        patch_size=[4, 4],
        embed_dim=256,
        depths=[4, 6, 8],
        num_heads=[4, 8, 16],
        window_size=[7, 7],
        stochastic_depth_prob=0.5,
        num_classes=512,
        block=SwinTransformerBlockV2,
        downsample_layer=PatchMergingV2,
    )
    return _swin_transformer(**config)
def shifted_window_attention(
    input: Tensor,
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
    window_size: List[int],
    num_heads: int,
    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    logit_scale: Optional[torch.Tensor] = None,
    training: bool = True,
) -> Tensor:
    """
    Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both shifted and non-shifted windows. When ``logit_scale`` is given the
    attention logits are scaled cosine similarities, otherwise scaled dot products.

    Args:
        input (Tensor[N, H, W, C]): The input tensor of 4 dimensions.
        qkv_weight (Tensor[3 * C, C]): The fused query/key/value weight, applied with ``F.linear``.
        proj_weight (Tensor[C, C]): The output projection weight, applied with ``F.linear``.
        relative_position_bias (Tensor): The relative position bias added to the attention logits;
            it must broadcast against [N * num_windows, num_heads, Ws*Ws, Ws*Ws].
        window_size (List[int]): Window size (height, width).
        num_heads (int): Number of attention heads.
        shift_size (List[int]): Shift size for shifted window attention. The list is copied,
            the caller's list is never modified.
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor[3 * C], optional): The bias tensor of query, key, value. Default: None.
        proj_bias (Tensor[C], optional): The bias tensor of projection. Default: None.
        logit_scale (Tensor, optional): Log of the cosine-attention scale; must broadcast against the
            attention logits (``ShiftedWindowAttentionV2`` passes shape [num_heads, 1, 1]). Default: None.
        training (bool, optional): Training flag used by the dropout parameters. Default: True.
    Returns:
        Tensor[N, H, W, C]: The output tensor after shifted window attention.
    """
    B, H, W, C = input.shape
    # pad feature maps to multiples of window size
    pad_r = (window_size[1] - W % window_size[1]) % window_size[1]
    pad_b = (window_size[0] - H % window_size[0]) % window_size[0]
    x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
    _, pad_H, pad_W, _ = x.shape

    # work on a copy so that zeroing a shift below does not leak to the caller
    shift_size = shift_size.copy()
    # If window size is larger than feature size, there is no need to shift window
    if window_size[0] >= pad_H:
        shift_size[0] = 0
    if window_size[1] >= pad_W:
        shift_size[1] = 0

    # cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2))

    # partition windows
    num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1])
    x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C)  # B*nW, Ws*Ws, C

    # multi-head attention
    if logit_scale is not None and qkv_bias is not None:
        # cosine-attention path: zero the middle third of the fused bias, i.e. the key
        # bias (qkv is split as index 0 = q, 1 = k, 2 = v below); clone to keep the
        # caller's tensor untouched
        qkv_bias = qkv_bias.clone()
        length = qkv_bias.numel() // 3
        qkv_bias[length : 2 * length].zero_()
    qkv = F.linear(x, qkv_weight, qkv_bias)
    # (B*nW, Ws*Ws, 3*C) -> (3, B*nW, num_heads, Ws*Ws, C // num_heads)
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    if logit_scale is not None:
        # cosine attention
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        # clamp in log space so that the multiplicative scale never exceeds 100
        logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
        attn = attn * logit_scale
    else:
        # scaled dot-product attention with scale 1 / sqrt(head_dim)
        q = q * (C // num_heads) ** -0.5
        attn = q.matmul(k.transpose(-2, -1))
    # add relative position bias
    attn = attn + relative_position_bias

    if sum(shift_size) > 0:
        # generate attention mask
        # label every position of the (shifted) map with a region id; the slices split
        # each axis into up to three bands delimited by the window and shift sizes
        attn_mask = x.new_zeros((pad_H, pad_W))
        h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None))
        w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None))
        count = 0
        for h in h_slices:
            for w in w_slices:
                attn_mask[h[0] : h[1], w[0] : w[1]] = count
                count += 1
        # partition the region ids exactly like the features, one row of ids per window
        attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1])
        attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1])
        # pairwise id difference: non-zero means the two tokens come from different regions
        attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2)
        # -100 on the logits effectively removes such pairs after the softmax
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))

    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=attention_dropout, training=training)

    # merge the heads back: (B*nW, num_heads, Ws*Ws, head_dim) -> (B*nW, Ws*Ws, C)
    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
    x = F.linear(x, proj_weight, proj_bias)
    x = F.dropout(x, p=dropout, training=training)

    # reverse windows
    x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)

    # reverse cyclic shift
    if sum(shift_size) > 0:
        x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2))

    # unpad features
    x = x[:, :H, :W, :].contiguous()
    return x
@torch.no_grad()
def piecewise_index(relative_position, alpha, beta, gamma, dtype):
    """Piecewise index function (Eq. (18) of the originating paper).

    Offsets with ``|x| <= alpha`` are kept (rounded to the nearest integer),
    larger offsets are compressed logarithmically and clipped at ``beta``.

    Parameters
    ----------
    relative_position: torch.Tensor, integer or floating point
        Relative offsets, of any shape (typically (L, L)).
    alpha, beta, gamma: float
        The coefficients of piecewise index function.
    dtype: torch.dtype
        The integer dtype of the returned indices.

    Returns
    -------
    idx: torch.Tensor, dtype: `dtype`
        A tensor indexing relative distances to corresponding encodings.
        `idx` has the shape of `relative_position` and each element is
        in [-beta, beta].
    """
    rp_abs = relative_position.abs()
    not_mask = ~(rp_abs <= alpha)
    rp_out = relative_position[not_mask]
    rp_abs_out = rp_abs[not_mask]
    # log-spaced bucket for |x| > alpha, clipped so that it never exceeds beta
    compressed = (alpha +
                  torch.log(rp_abs_out / alpha) /
                  math.log(gamma / alpha) *
                  (beta - alpha)).round().clip(max=beta)
    y_out = (torch.sign(rp_out) * compressed).to(dtype)

    idx = relative_position.clone()
    if idx.is_floating_point():
        # round(x) when |x| <= alpha; covers every floating dtype, not only
        # float32 / float64
        idx = idx.round()
    # always honour the requested output dtype, also for integer inputs
    idx = idx.to(dtype)

    # assign the value when |x| > alpha
    idx[not_mask] = y_out
    return idx
@torch.no_grad()
def _rp_2d_product(diff, **kwargs):
    """2D RPE with Product method.

    Row and column offsets are bucketed independently and the two bucket
    ids are combined into a single id ``row_id * S + col_id``.

    Parameters
    ----------
    diff: torch.Tensor
        The shape of `diff` is (L, L, 2),
        where L is the sequence length,
        and 2 represents a 2D offset (row_offset, col_offset).
    **kwargs:
        Forwarded to `piecewise_index`; must contain `beta`.

    Returns
    -------
    index: torch.Tensor, dtype: long
        index to corresponding encodings.
        The shape of `index` is (L, L),
        where L is the sequence length.
    """
    # beta is a float; the bucket ids produced per axis lie in
    # [-half_range, half_range]
    half_range = int(kwargs['beta'])
    side = 2 * half_range + 1

    row_offsets = diff[:, :, 0]
    col_offsets = diff[:, :, 1]
    # shift both axes to [0, 2 * half_range] before combining them
    row_ids = piecewise_index(row_offsets, **kwargs) + half_range
    col_ids = piecewise_index(col_offsets, **kwargs) + half_range

    return row_ids * side + col_ids
+ + Parameters + ---------- + diff: torch.Tensor + The shape of `diff` is (L, L, 2), + where L is the sequence length, + and 2 represents a 2D offset (row_offset, col_offset). + + Returns + ------- + index: torch.Tensor, dtype: long + index to corresponding encodings. + The shape of `index` is (L, L), + where L is the sequence length. + """ + + dis = diff[:, :, 1] + return piecewise_index(dis, **kwargs) + diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/KPRPE/kprpe_shared.py b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/KPRPE/kprpe_shared.py new file mode 100644 index 0000000000000000000000000000000000000000..a29eac5f03945d853c7a426d4f3806b9d6143f16 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/KPRPE/kprpe_shared.py @@ -0,0 +1,750 @@ +from easydict import EasyDict as edict +import math +import torch +import torch.nn as nn +from .dist import _rp_2d_cross_cols, _rp_2d_cross_rows, _rp_2d_euclidean, _rp_2d_product, _rp_2d_quant + +try: + from ..rpe_ops.rpe_index import RPEIndexFunction +except Exception as e: + print('Failed to import cuda/cpp RPEIndexFunction') + RPEIndexFunction = None + + + +def get_absolute_positions(height, width, dtype, device): + '''Get absolute positions + + Take height = 3, width = 3 as an example: + rows: cols: + 1 1 1 1 2 3 + 2 2 2 1 2 3 + 3 3 3 1 2 3 + + return stack([rows, cols], 2) + + Parameters + ---------- + height, width: int + The height and width of feature map + dtype: torch.dtype + the data type of returned value + device: torch.device + the device of returned value + + Return + ------ + 2D absolute positions: torch.Tensor + The shape is (height, width, 2), + where 2 represents a 2D position (row, col). 
@torch.no_grad()
def get_bucket_ids_2d_without_skip(method, height, width,
                                   alpha, beta, gamma,
                                   dtype=torch.long, device=torch.device('cpu')):
    """Get bucket IDs for image relative position encodings without skip token

    The table is cached in `BUCKET_IDS_BUF` per
    (method, alpha, beta, gamma, dtype, device) and rebuilt only when a
    larger feature map than the cached one is requested; smaller requests
    are served by slicing the cached table.

    Parameters
    ----------
    method: METHOD
        The method ID of image relative position encoding.
    height, width: int
        The height and width of the feature map.
        The sequence length is equal to `height * width`.
    alpha, beta, gamma: float
        The coefficients of piecewise index function.
    dtype: torch.dtype
        the data type of returned `bucket_ids`
    device: torch.device
        the device of returned `bucket_ids`

    Returns
    -------
    bucket_ids: torch.Tensor, dtype: `dtype`
        The bucket IDs which index to corresponding encodings.
        The shape of `bucket_ids` is (L, L),
        where `L = height * width`.
    num_buckets: int
        The number of buckets, not counting any `skip` token.
    L: int
        The sequence length

    Raises
    ------
    NotImplementedError
        If `method` has no entry in `_METHOD_FUNC`.
    """

    key = (method, alpha, beta, gamma, dtype, device)
    value = BUCKET_IDS_BUF.get(key, None)
    if value is None or value[-2] < height or value[-1] < width:
        if value is None:
            max_height, max_width = height, width
        else:
            # never shrink the cached table along either axis
            max_height = max(value[-2], height)
            max_width = max(value[-1], width)
        # relative position encoding mapping function
        func = _METHOD_FUNC.get(method, None)
        if func is None:
            raise NotImplementedError(
                f"[Error] The method ID {method} does not exist.")
        pos = get_absolute_positions(max_height, max_width, dtype, device)

        # compute the offset of a pair of 2D relative positions
        max_L = max_height * max_width
        pos1 = pos.view((max_L, 1, 2))
        pos2 = pos.view((1, max_L, 2))
        # diff: shape of (L, L, 2)
        diff = pos1 - pos2

        # bucket_ids: shape of (L, L)
        bucket_ids = func(diff, alpha=alpha, beta=beta,
                          gamma=gamma, dtype=dtype)
        beta_int = int(beta)
        if method != METHOD.PRODUCT:
            # shift ids from [-beta_int, beta_int] to [0, 2 * beta_int];
            # the product method already returns non-negative ids
            bucket_ids += beta_int
        bucket_ids = bucket_ids.view(
            max_height, max_width, max_height, max_width)

        num_buckets = get_num_buckets(method, alpha, beta, gamma)
        # record the size of the table that was actually built (not the
        # requested size), otherwise later calls rebuild it needlessly
        value = (bucket_ids, num_buckets, max_height, max_width)
        BUCKET_IDS_BUF[key] = value
    L = height * width
    bucket_ids = value[0][:height, :width, :height, :width].reshape(L, L)
    num_buckets = value[1]

    return bucket_ids, num_buckets, L
class KPRPE(nn.Module):
    """The implementation of image relative position encoding (excluding Cross method).

    In transposed 'bias' mode the module owns a learnable
    `lookup_table_bias`. In transposed 'contextual' mode it owns no
    parameter: the tensor given to `forward` is itself indexed as a
    per-query lookup table over the buckets.

    Parameters
    ----------
    head_dim: int
        The dimension for each head.
    num_heads: int
        The number of parallel attention heads.
    mode: str or None
        The mode of image relative position encoding.
        Choices: [None, 'bias', 'contextual']
    method: METHOD
        The method ID of image relative position encoding.
        The `METHOD` class is defined in this module.
    transposed: bool
        Whether to transpose the input feature.
        For KPRPE on queries or keys, transposed should be `True`.
        For KPRPE on values, transposed should be `False`
        (construction currently raises for every non-transposed mode
        except `None`).
    num_buckets: int
        The number of buckets, which store encodings.
    initializer: None or an inplace function
        [Optional] The initializer to `lookup_table`.
        Initalize `lookup_table` as zero by default.
    rpe_config: RPEConfig
        The config generated by the function `get_single_rpe_config`.
    """
    # a buffer to store bucket index
    # (key, rp_bucket, _ctx_rp_bucket_flatten)
    _rp_bucket_buf = (None, None, None)

    def __init__(self, head_dim, num_heads=8,
                 mode=None, method=None,
                 transposed=True, num_buckets=None,
                 initializer=None, rpe_config=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        # relative position
        assert mode in [None, 'bias', 'contextual']
        self.mode = mode

        assert method is not None, 'method should be a METHOD ID rather than None'
        self.method = method

        self.transposed = transposed
        self.num_buckets = num_buckets

        if initializer is None:
            # default: leave the zero-initialized table untouched
            def initializer(x): return None
        self.initializer = initializer

        self.reset_parameters()

        self.rpe_config = rpe_config

    @torch.no_grad()
    def reset_parameters(self):
        """Create (and initialize) the parameters required by the selected mode.

        Raises
        ------
        NotImplementedError
            For the non-transposed 'bias' mode.
        ValueError
            For the non-transposed 'contextual' mode.
        """
        if self.transposed:
            if self.mode == 'bias':
                self.lookup_table_bias = nn.Parameter(
                    torch.zeros(self.num_heads, self.num_buckets))
                self.initializer(self.lookup_table_bias)
            elif self.mode == 'contextual':
                # shared and initialized from vit
                pass
        else:
            if self.mode == 'bias':
                raise NotImplementedError(
                    "[Error] Bias non-transposed RPE does not exist.")
            elif self.mode == 'contextual':
                raise ValueError('may not work, check')

    def forward(self, x, height=None, width=None):
        """forward function for KPRPE.

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor whose shape is (B, H, L, D),
            where B is batch size,
            H is the number of heads,
            L is the sequence length,
            equal to height * width (+ the number of extra tokens).
            In transposed 'contextual' mode without `RPEIndexFunction`, D
            is indexed per bucket, so it is expected to equal
            `num_buckets`.
        height: int or None
            [Optional] The height of the input
            If not defined, height = width = floor(sqrt(L))
        width: int or None
            [Optional] The width of the input
            If not defined, width = floor(sqrt(L))

        Returns
        -------
        rpe_encoding: torch.Tensor
            image Relative Position Encoding,
            whose shape is (B or 1, H, L, L)
        """
        rp_bucket, self._ctx_rp_bucket_flatten = \
            self._get_rp_bucket(x, height=height, width=width)

        if self.transposed:
            return self.forward_rpe_transpose(x, rp_bucket)
        return self.forward_rpe_no_transpose(x, rp_bucket)

    def _get_rp_bucket(self, x, height=None, width=None):
        """Get relative position encoding buckets IDs corresponding the input shape

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor whose shape is (B, H, L, D); only `L` and the
            device of `x` are used here.
        height: int or None
            [Optional] The height of the input
            If not defined, height = width = floor(sqrt(L))
        width: int or None
            [Optional] The width of the input
            If not defined, width = floor(sqrt(L))

        Returns
        -------
        rp_bucket: torch.Tensor
            relative position encoding buckets IDs
            The shape is (L, L)
        _ctx_rp_bucket_flatten: torch.Tensor or None
            It is a private tensor for efficient computation.
        """
        B, H, L, D = x.shape
        device = x.device
        side = int(math.sqrt(L))
        if height is None:
            height = width = side
        elif width is None:
            # documented default; previously this case failed with a TypeError
            width = side
        # L is part of the key because the number of skip tokens
        # (L - height * width) changes the bucket table even when the
        # spatial shape and the device stay the same.
        key = (height, width, L, device)

        # reuse the buffered result when nothing relevant changed
        if self._rp_bucket_buf[0] == key:
            return self._rp_bucket_buf[1:3]

        skip = L - height * width
        config = self.rpe_config
        if RPEIndexFunction is not None and self.mode == 'contextual' and self.transposed:
            # RPEIndexFunction uses int32 index.
            dtype = torch.int32
        else:
            dtype = torch.long
        rp_bucket, num_buckets = get_bucket_ids_2d(method=self.method,
                                                   height=height, width=width,
                                                   skip=skip, alpha=config.alpha,
                                                   beta=config.beta, gamma=config.gamma,
                                                   dtype=dtype, device=device)
        assert num_buckets == self.num_buckets

        # transposed contextual
        _ctx_rp_bucket_flatten = None
        if self.mode == 'contextual' and self.transposed:
            if RPEIndexFunction is None:
                # flat index i * num_buckets + rp_bucket[i, j] into the
                # (L * num_buckets) flattened per-query table
                offset = torch.arange(0, L * self.num_buckets, self.num_buckets,
                                      dtype=rp_bucket.dtype, device=rp_bucket.device).view(-1, 1)
                _ctx_rp_bucket_flatten = (rp_bucket + offset).flatten()
        self._rp_bucket_buf = (key, rp_bucket, _ctx_rp_bucket_flatten)
        return rp_bucket, _ctx_rp_bucket_flatten

    def forward_rpe_transpose(self, x, rp_bucket):
        """Forward function for KPRPE (transposed version)
        This version is utilized by RPE on Query or Key

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor whose shape is (B, H, L, D); see `forward`.
            Unused in 'bias' mode.
        rp_bucket: torch.Tensor
            relative position encoding buckets IDs
            The shape is (L, L)

        Weights
        -------
        lookup_table_bias: torch.Tensor
            The shape is (H, num_buckets); 'bias' mode only.

        Returns
        -------
        output: torch.Tensor or None
            Relative position encoding on queries or keys.
            The shape is (1, H, L, L) in 'bias' mode and (B, H, L, L)
            in 'contextual' mode. `None` when `mode` is `None`.
        """

        B = len(x)  # batch_size
        L_query, L_key = rp_bucket.shape
        if self.mode == 'bias':
            flat_ids = rp_bucket.flatten()
            return self.lookup_table_bias[:, flat_ids].view(
                1, self.num_heads, L_query, L_key)

        elif self.mode == 'contextual':
            # ret[b, h, i, j] = x[b, h, i, rp_bucket[i, j]]
            #
            # equivalently, on the flattened last two dimensions:
            # ret[b, h, i * L_key + j] =
            #     x[b, h, i * num_buckets + rp_bucket[i, j]]
            if RPEIndexFunction is not None:
                return RPEIndexFunction.apply(x, rp_bucket)
            else:
                return x.flatten(2)[:, :, self._ctx_rp_bucket_flatten].view(
                    B, -1, L_query, L_key)

    def forward_rpe_no_transpose(self, x, rp_bucket):
        """Forward function for KPRPE (non-transposed version)
        This version is utilized by RPE on Value.

        NOTE(review): this path reads `self.lookup_table_weight`, which
        `reset_parameters` never creates (it raises for the non-transposed
        'contextual' mode), so it cannot be reached through normal
        construction.

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor whose shape is (B, H, L_query, L_key).
        rp_bucket: torch.Tensor
            relative position encoding buckets IDs
            The shape is (L, L)

        Weights
        -------
        lookup_table_weight: torch.Tensor
            The shape is (H, num_buckets, head_dim)

        Returns
        -------
        output: torch.Tensor
            Relative position encoding on values.
            The shape is (B, H, L, D),
            where D is the output dimension for each head.
        """

        L_query, L_key = rp_bucket.shape
        assert self.mode == 'contextual', \
            "Only support contextual version in non-transposed version"
        weight = self.lookup_table_weight[:, rp_bucket.flatten()].view(
            self.num_heads, L_query, L_key, self.head_dim)
        # (H, L_query, B, L_key) @ (H, L_query, L_key, D) = (H, L_query, B, D)
        # -> (B, H, L_query, D)
        return torch.matmul(x.permute(1, 2, 0, 3), weight).permute(2, 0, 1, 3)

    def __repr__(self):
        return ('KPRPE(head_dim={rpe.head_dim}, num_heads={rpe.num_heads}, '
                'mode="{rpe.mode}", method={rpe.method}, transposed={rpe.transposed}, '
                'num_buckets={rpe.num_buckets}, initializer={rpe.initializer}, '
                'rpe_config={rpe.rpe_config})').format(rpe=self)
+ + Parameters + ---------- + x: torch.Tensor + Input Tensor whose shape is (B, H, L, head_dim), + where B is batch size, + H is the number of heads, + L is the sequence length, + equal to height * width (+1 if class token exists) + head_dim is the dimension of each head + height: int or None + [Optional] The height of the input + If not defined, height = floor(sqrt(L)) + width: int or None + [Optional] The width of the input + If not defined, width = floor(sqrt(L)) + + Returns + ------- + rpe_encoding: torch.Tensor + Image Relative Position Encoding, + whose shape is (B, H, L, L) + """ + + rows = self.rp_rows(x, height=height, width=width) + cols = self.rp_cols(x, height=height, width=width) + return rows + cols + + def __repr__(self): + return 'KPRPE_Cross(head_dim={rpe.head_dim}, \ +num_heads={rpe.num_heads}, mode="{rpe.mode}", method={rpe.method}, \ +transposed={rpe.transposed}, num_buckets={rpe.num_buckets}, \ +initializer={rpe.initializer}, \ +rpe_config={rpe.rpe_config})'.format(rpe=self.rp_rows) + + +def get_single_rpe_config(ratio=1.9, + method=METHOD.PRODUCT, + mode='contextual', + shared_head=True, + skip=0): + """Get the config of single relative position encoding + + Parameters + ---------- + ratio: float + The ratio to control the number of buckets. + method: METHOD + The method ID of image relative position encoding. + The `METHOD` class is defined in `irpe.py`. + mode: str or None + The mode of image relative position encoding. + Choices: [None, 'bias', 'contextual'] + shared_head: bool + Whether to share weight among different heads. + skip: int + The number of skip token before spatial tokens. + When skip is 0, no classification token. + When skip is 1, there is a classification token before spatial tokens. + When skip > 1, there are `skip` extra tokens before spatial tokens. + + Returns + ------- + config: RPEConfig + The config of single relative position encoding. 
+ """ + config = edict() + # whether to share encodings across different heads + config.shared_head = shared_head + # mode: None, bias, contextual + config.mode = mode + # method: None, Bias, Quant, Cross, Product + config.method = method + # the coefficients of piecewise index function + config.alpha = 1 * ratio + config.beta = 2 * ratio + config.gamma = 8 * ratio + + # set the number of buckets + config.num_buckets = get_num_buckets(method, + config.alpha, + config.beta, + config.gamma) + # add extra bucket for `skip` token (e.g. class token) + if skip > 0: + config.num_buckets += 1 + return config + + +def get_rpe_config(ratio=1.9, + method=METHOD.PRODUCT, + mode='contextual', + shared_head=True, + skip=0, + rpe_on='k'): + """Get the config of relative position encoding on queries, keys and values + + Parameters + ---------- + ratio: float + The ratio to control the number of buckets. + method: METHOD or str + The method ID (or name) of image relative position encoding. + The `METHOD` class is defined in `irpe.py`. + mode: str or None + The mode of image relative position encoding. + Choices: [None, 'bias', 'contextual'] + shared_head: bool + Whether to share weight among different heads. + skip: int + The number of skip token before spatial tokens. + When skip is 0, no classification token. + When skip is 1, there is a classification token before spatial tokens. + When skip > 1, there are `skip` extra tokens before spatial tokens. + rpe_on: str + Where RPE attaches. 
+ "q": RPE on queries + "k": RPE on keys + "v": RPE on values + "qk": RPE on queries and keys + "qkv": RPE on queries, keys and values + + Returns + ------- + config: RPEConfigs + config.rpe_q: the config of relative position encoding on queries + config.rpe_k: the config of relative position encoding on keys + config.rpe_v: the config of relative position encoding on values + """ + + # alias + if isinstance(method, str): + method_mapping = dict( + euc=METHOD.EUCLIDEAN, + quant=METHOD.QUANT, + cross=METHOD.CROSS, + product=METHOD.PRODUCT, + ) + method = method_mapping[method.lower()] + if mode == 'ctx': + mode = 'contextual' + config = edict() + # relative position encoding on queries, keys and values + kwargs = dict( + ratio=ratio, + method=method, + mode=mode, + shared_head=shared_head, + skip=skip, + ) + config.rpe_q = get_single_rpe_config(**kwargs) if 'q' in rpe_on else None + config.rpe_k = get_single_rpe_config(**kwargs) if 'k' in rpe_on else None + config.rpe_v = get_single_rpe_config(**kwargs) if 'v' in rpe_on else None + return config + + +def build_rpe(config, head_dim, num_heads): + """Build KPRPE modules on queries, keys and values. + + Parameters + ---------- + config: RPEConfigs + config.rpe_q: the config of relative position encoding on queries + config.rpe_k: the config of relative position encoding on keys + config.rpe_v: the config of relative position encoding on values + None when RPE is not used. + head_dim: int + The dimension for each head. + num_heads: int + The number of parallel attention heads. + + Returns + ------- + modules: a list of nn.Module + The KPRPE Modules on [queries, keys, values]. + None when RPE is not used. 
+ """ + if config is None: + return None, None, None + rpes = [config.rpe_q, config.rpe_k, config.rpe_v] + transposeds = [True, True, False] + + def _build_single_rpe(rpe, transposed): + if rpe is None: + return None + + rpe_cls = KPRPE if rpe.method != METHOD.CROSS else KPRPE_Cross + return rpe_cls( + head_dim=head_dim, + num_heads=1 if rpe.shared_head else num_heads, + mode=rpe.mode, + method=rpe.method, + transposed=transposed, + num_buckets=rpe.num_buckets, + rpe_config=rpe, + ) + return [_build_single_rpe(rpe, transposed) + for rpe, transposed in zip(rpes, transposeds)] + + + +if __name__ == '__main__': + + import lovely_tensors as lt + from models.vit import Attention + lt.monkey_patch() + rpe_config = get_rpe_config(skip=0) + rpe_config.name = 'iRPE' + attn = Attention(dim=512, num_heads=8, qkv_bias=False, qk_scale=None, + attn_drop=0, proj_drop=0, rpe_config=rpe_config) + + x = torch.randn(1, 196, 512) + y = attn(x) + diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/KPRPE/relative_keypoints.py b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/KPRPE/relative_keypoints.py new file mode 100644 index 0000000000000000000000000000000000000000..ddfd4f5e42a43299a2a267199d88f89b9738adad --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/KPRPE/relative_keypoints.py @@ -0,0 +1,39 @@ +import torch +import math + +@torch.no_grad() +def make_rel_keypoints(keypoints, query): + seq_length = query.shape[1] + side = int(math.sqrt(seq_length)) + assert side == math.sqrt(seq_length) + + # make a grid of points from 0 to 1 + coord = torch.linspace(0, 1, side+1, device=query.device, dtype=query.dtype) + coord = (coord[:-1] + coord[1:]) / 2 # get center of patches + + x, y = torch.meshgrid(coord, coord, indexing='ij') + grid = torch.stack([y, x], dim=-1).reshape(-1, 2).unsqueeze(0).unsqueeze(-2) # BxNx1x2 + _keypoints = keypoints.unsqueeze(-3) # Bx1x5x2 + diff = (grid - _keypoints) # BxNx5x2 + diff = 
diff.flatten(2) # BxNx10 + return diff + + +def make_grid_0_1(side, device, dtype): + if isinstance(side, tuple): + one_side = side[0] + assert side[0] == side[1] + else: + one_side = side + # make a grid of points from 0 to 1 + coord = torch.linspace(0, 1, one_side+1, device=device, dtype=dtype) + coord = (coord[:-1] + coord[1:]) / 2 # get center of patches + x, y = torch.meshgrid(coord, coord, indexing='ij') + grid = torch.stack([y, x], dim=-1).reshape(-1, 2).unsqueeze(0).unsqueeze(-2) # BxNx1x2 + return grid + +def calc_rel_keypoints(keypoints, grid): + _keypoints = keypoints.unsqueeze(-3) # Bx1x5x2 + diff = (grid - _keypoints) # BxNx5x2 + diff = diff.flatten(2) # BxNx10 + return diff \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/__init__.py b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37e22a3f8275581c2eaa43ac2575aa8718785bcd --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/__init__.py @@ -0,0 +1,47 @@ +from .KPRPE import kprpe_shared +import subprocess +import sys +import os +import warnings + +try: + from .rpe_ops.rpe_index import RPEIndexFunction +except ImportError: + try: + # Attempt to install the module from the setup.py script + dirname = os.path.dirname(os.path.abspath(__file__)) + cwd = os.getcwd() + os.chdir(os.path.join(dirname, 'rpe_ops')) + subprocess.check_call([sys.executable, 'setup.py', 'install', '--user']) + GREEN_STR = "\033[92m{}\033[00m" + print(GREEN_STR.format("\n[INFO] Successfully installed `rpe_ops`. Restart Application"),) + sys.exit() + except subprocess.CalledProcessError as install_error: + RED_STR = "\033[91m{}\033[00m" + warnings.warn(RED_STR.format("\n[WARNING] Failed to install `rpe_ops`. 
" + "Please check the installation script."),) + except ImportError as import_error: + RED_STR = "\033[91m{}\033[00m" + warnings.warn(RED_STR.format("\n[WARNING] The module `rpe_ops` is not built. " + "For better training performance, please build `rpe_ops`."),) + + +def build_rpe(rpe_config, head_dim, num_heads): + if rpe_config is None: + return None + else: + name = rpe_config.name + if name == 'KPRPE_shared': + rpe_config = kprpe_shared.get_rpe_config( + ratio=rpe_config.ratio, + method=rpe_config.method, + mode=rpe_config.mode, + shared_head=rpe_config.shared_head, + skip=0, + rpe_on=rpe_config.rpe_on, + ) + return kprpe_shared.build_rpe(rpe_config, head_dim=head_dim, num_heads=num_heads) + + else: + raise NotImplementedError(f"Unknow RPE: {name}") + diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/README.md b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d3140f928917a37f464c0e96e62cc5383e242d8b --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/README.md @@ -0,0 +1,6 @@ + +# Installation +``` +cd models/vit_kprpe/RPE/rpe_ops +python setup.py install --user +``` \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/rpe_index.cpp b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/rpe_index.cpp new file mode 100644 index 0000000000000000000000000000000000000000..766142bd9ac31a327204a7d827477ef82100ec7a --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/rpe_index.cpp @@ -0,0 +1,142 @@ +#include + +#include +#include + +using index_t = int; + +at::Tensor rpe_index_forward_cpu(torch::Tensor input, torch::Tensor index) { + /* + - Inputs + input: float32 (B, H, L_query, num_buckets) + index: index_t (L_query, L_key) + - Outputs + Y: float32 (B, H, L_query, L_key) + */ + 
AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(index.device().is_cpu(), "index must be a CPU tensor"); + AT_ASSERTM(input.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + const index_t B = input.size(0); + const index_t H = input.size(1); + const index_t num_buckets = input.size(3); + const index_t L_query = index.size(0); + const index_t L_key = index.size(1); + const index_t L_qk = L_query * L_key; + at::Tensor Y = at::empty({B, H, L_query, L_key}, input.options()); + auto input_ = input.contiguous(); + auto index_ = index.contiguous(); + const index_t grain_size = 3000; + const index_t numel = Y.numel(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "rpe_index_forward_cpu", [&] { + const scalar_t *p_input = input_.data_ptr(); + const index_t *p_index = index_.data_ptr(); + scalar_t *p_Y = Y.data_ptr(); + at::parallel_for(0, numel, grain_size, [&](index_t begin, index_t end) { + /* + // we optimize the following function to + // reduce the number of operators, namely divide and multiply. 
+ for (index_t i = begin; i < end; ++i) { + p_Y[i] = p_input[i / L_key * num_buckets + p_index[i % L_qk]]; + } + */ + + index_t aligned_begin = (begin + L_qk - 1) / L_qk * L_qk; + if (aligned_begin > end) aligned_begin = end; + index_t aligned_end = end / L_qk * L_qk; + for (index_t i = begin; i < aligned_begin; ++i) { + p_Y[i] = p_input[i / L_key * num_buckets + p_index[i % L_qk]]; + } + + // [aligned_begin, aligned_end) + // where aligned_begin % L_qk == 0, aligned_end % L_qk == 0 + index_t base = aligned_begin / L_key * num_buckets; + const index_t base_end = aligned_end / L_key * num_buckets; + index_t i = aligned_begin; + while (base < base_end) { + for (index_t q = 0, j = 0; q < L_query; ++q) { + for (index_t k = 0; k < L_key; ++k) { + p_Y[i++] = p_input[base + p_index[j++]]; + } + base += num_buckets; + } + } + + for (index_t i = aligned_end; i < end; ++i) { + p_Y[i] = p_input[i / L_key * num_buckets + p_index[i % L_qk]]; + } + }); + }); + return Y; +} + +template +inline scalar_t cpuAtomicAdd(scalar_t *address, const scalar_t val) { +#pragma omp critical + *address += val; + return *address; +} + +void rpe_index_backward_cpu(torch::Tensor grad_input, torch::Tensor grad_output, + torch::Tensor index) { + /* + - Inputs + grad_output: float32 (B, H, L_query, L_key) + index: index_t (L_query, L_key) + - Outputs + grad_input: float32 (B, H, L_query, num_buckets) + */ + AT_ASSERTM(grad_input.device().is_cpu(), "grad_input must be a CPU tensor"); + AT_ASSERTM(grad_output.device().is_cpu(), "grad_output must be a CPU tensor"); + AT_ASSERTM(index.device().is_cpu(), "grad_index must be a CPU tensor"); + AT_ASSERTM(grad_input.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(grad_output.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + + const index_t num_buckets = grad_input.size(3); + const index_t L_query = 
index.size(0); + const index_t L_key = index.size(1); + const index_t L_qk = L_query * L_key; + + auto grad_input_ = grad_input.contiguous(); + auto grad_output_ = grad_output.contiguous(); + auto index_ = index.contiguous(); + + const index_t grain_size = 3000; + const index_t numel = grad_output.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_input.scalar_type(), "rpe_index_backward_atomic_cpu", [&] { + scalar_t *p_grad_input = grad_input_.data_ptr(); + const index_t *p_index = index_.data_ptr(); + const scalar_t *p_grad_output = grad_output_.data_ptr(); + at::parallel_for(0, numel, grain_size, [&](index_t begin, index_t end) { + for (index_t i = begin; i < end; ++i) { + const index_t input_i = i / L_key * num_buckets + p_index[i % L_qk]; + const scalar_t v = p_grad_output[i]; + cpuAtomicAdd(p_grad_input + input_i, v); + } + }); + }); +} + +std::string version() { + return "1.2.0"; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("version", &version, "The version of the package `rpe_index_cpp`"); + m.def("forward_cpu", &rpe_index_forward_cpu, "2D RPE Index Forward (CPU)"); + m.def("backward_cpu", &rpe_index_backward_cpu, "2D RPE Index Backward (CPU)"); + +#if defined(WITH_CUDA) + at::Tensor rpe_index_forward_gpu(torch::Tensor input, torch::Tensor index); + void rpe_index_backward_gpu(torch::Tensor grad_input, + torch::Tensor grad_output, torch::Tensor index); + m.def("forward_gpu", &rpe_index_forward_gpu, "2D RPE Index Forward (GPU)"); + m.def("backward_gpu", &rpe_index_backward_gpu, "2D RPE Index Backward (GPU)"); +#endif +} diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/rpe_index.py b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/rpe_index.py new file mode 100644 index 0000000000000000000000000000000000000000..1d915e1a56bdde6ed634ac4a8e632c6f833adc63 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/rpe_index.py @@ -0,0 +1,100 @@ +import torch 
+import rpe_index_cpp + + +EXPECTED_VERSION = "1.2.0" +assert rpe_index_cpp.version() == EXPECTED_VERSION, \ + f"""Unmatched `rpe_index_cpp` version: {rpe_index_cpp.version()}, expected version: {EXPECTED_VERSION} +Please re-build the package `rpe_ops`.""" + + +class RPEIndexFunction(torch.autograd.Function): + '''Y[b, h, i, j] = input[b, h, i, index[i, j]]''' + @staticmethod + def forward(ctx, input, index): + ''' + Y[b, h, i, j] = input[b, h, i, index[i, j]] + + Parameters + ---------- + input: torch.Tensor, float32 + The shape is (B, H, L_query, num_buckets) + index: torch.Tensor, int32 + The shape is (L_query, L_key) + + where B is the batch size, and H is the number of attention heads. + + Returns + ------- + Y: torch.Tensor, float32 + The shape is (B, H, L_query, L_key) + ''' + + num_buckets = input.size(-1) + ctx.save_for_backward(index) + ctx.input_shape = input.shape + forward_fn = rpe_index_cpp.forward_cpu if \ + input.device.type == 'cpu' else rpe_index_cpp.forward_gpu + output = forward_fn(input, index) + return output + + @staticmethod + def backward(ctx, grad_output): + ''' + - Inputs + grad_output: float32 (B, H, L_query, L_key) + - Outputs + grad_input: float32 (B, H, L_query, num_buckets) + ''' + index = ctx.saved_tensors[0] + if ctx.needs_input_grad[0]: + grad_input = grad_output.new_zeros(ctx.input_shape) + backward_fn = rpe_index_cpp.backward_cpu if \ + grad_output.device.type == 'cpu' else rpe_index_cpp.backward_gpu + backward_fn(grad_input, grad_output, index) + return grad_input, None + return None, None + + +if __name__ == '__main__': + import numpy as np + import time + B = 128 + H = 32 + L_query = 50 + L_key = L_query + num_buckets = 50 + + x = torch.randn(B, H, L_query, num_buckets) + + index = torch.randint(low=0, high=num_buckets, size=(L_query, L_key)) + index = index.to(torch.int) + offset = torch.arange(0, L_query * num_buckets, num_buckets).view(-1, 1) + + def test(x, index, offset): + tic = time.time() + x1 = x.clone() + 
x1.requires_grad = True + x2 = x.clone() + x2.requires_grad = True + + y = RPEIndexFunction.apply(x1, index) + gt_y = x2.flatten(2)[:, :, (index + offset).flatten() + ].view(B, H, L_query, L_key) + + np.testing.assert_almost_equal( + gt_y.detach().cpu().numpy(), y.detach().cpu().numpy()) + + mask = torch.randn(gt_y.shape, device=x.device) + (gt_y * mask).sum().backward() + (y * mask).sum().backward() + + print("X1:", x1.grad.cpu().numpy().flatten().sum()) + print("X2:", x2.grad.cpu().numpy().flatten().sum()) + np.testing.assert_almost_equal( + x1.grad.cpu().numpy(), x2.grad.cpu().numpy(), decimal=5) + print("Test over", x.device) + print("Cost:", time.time() - tic) + test(x, index, offset) + if torch.cuda.is_available(): + test(x.cuda(), index.cuda(), offset.cuda()) diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/rpe_index_cuda.cu b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/rpe_index_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..3ddb53158d8b232248785abe3347215dc21423b7 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/rpe_index_cuda.cu @@ -0,0 +1,140 @@ +#include +#include + +#include +#include + +using index_t = int; + +const int HIP_MAX_GRID_NUM = 65535; +const int HIP_MAX_NUM_THREADS = 512; + +inline int HIP_GET_NUM_THREADS(const int n) { + return std::min(HIP_MAX_NUM_THREADS, ((n + 31) / 32) * 32); +} + +inline int HIP_GET_BLOCKS(const int n, const int num_threads) { + return std::min(HIP_MAX_GRID_NUM, n + num_threads - 1) / num_threads; +} + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void rpe_index_forward_gpu_kernel( + index_t n, scalar_t *p_Y, const scalar_t *__restrict__ p_input, + const index_t *__restrict__ p_index, index_t num_buckets, index_t H, + index_t L_query, index_t L_key, index_t L_qk, index_t s0, 
index_t s1, + index_t s2, index_t s3) { + CUDA_KERNEL_LOOP(i, n) { + index_t gi = i / L_key; + const index_t qi = gi % L_query; + gi /= L_query; + const index_t hi = gi % H; + gi /= H; + const index_t bi = gi; + const index_t ind = bi * s0 + hi * s1 + qi * s2 + p_index[i % L_qk] * s3; + p_Y[i] = __ldg(&p_input[ind]); + } +} + +template +__global__ void rpe_index_backward_gpu_kernel( + index_t n, scalar_t *p_grad_input, const index_t *__restrict__ p_index, + const scalar_t *__restrict__ p_grad_output, index_t num_buckets, + index_t L_key, index_t L_qk) { + CUDA_KERNEL_LOOP(i, n) { + const index_t input_i = i / L_key * num_buckets + p_index[i % L_qk]; + const scalar_t v = p_grad_output[i]; + gpuAtomicAdd(p_grad_input + input_i, v); + } +} + +at::Tensor rpe_index_forward_gpu(torch::Tensor input, torch::Tensor index) { + /* + - Inputs + input: float32 (B, H, L_query, num_buckets) + index: index_t (L_query, L_key) + - Outputs + Y: float32 (B, H, L_query, L_key) + */ + AT_ASSERTM(input.device().is_cuda(), "input must be a GPU tensor"); + AT_ASSERTM(index.device().is_cuda(), "index must be a GPU tensor"); + AT_ASSERTM(input.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + AT_ASSERTM(index.is_contiguous(), "index should be contiguous"); + const index_t B = input.size(0); + const index_t H = input.size(1); + const index_t num_buckets = input.size(3); + const index_t L_query = index.size(0); + const index_t L_key = index.size(1); + const index_t L_qk = L_query * L_key; + at::Tensor Y = at::empty({B, H, L_query, L_key}, input.options()); + const index_t numel = Y.numel(); + const at::IntArrayRef strides = input.strides(); + + const int threadsPerBlock = HIP_GET_NUM_THREADS(numel); + const int blocks = HIP_GET_BLOCKS(numel, threadsPerBlock); + + at::cuda::CUDAGuard device_guard(input.device()); + cudaStream_t stream = 
at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "rpe_index_forward_gpu", [&] { + const scalar_t *p_input = input.data_ptr(); + const index_t *p_index = index.data_ptr(); + scalar_t *p_Y = Y.data_ptr(); + rpe_index_forward_gpu_kernel<<>>( + numel, p_Y, p_input, p_index, num_buckets, H, L_query, L_key, L_qk, + strides[0], strides[1], strides[2], strides[3]); + }); + return Y; +} + +void rpe_index_backward_gpu(torch::Tensor grad_input, torch::Tensor grad_output, + torch::Tensor index) { + /* + - Inputs + grad_output: float32 (B, H, L_query, L_key) + index: index_t (L_query, L_key) + - Outputs + grad_input: float32 (B, H, L_query, num_buckets) + */ + AT_ASSERTM(grad_input.device().is_cuda(), "grad_input must be a GPU tensor"); + AT_ASSERTM(grad_output.device().is_cuda(), + "grad_output must be a GPU tensor"); + AT_ASSERTM(index.device().is_cuda(), "grad_index must be a GPU tensor"); + AT_ASSERTM(grad_input.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(grad_output.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + + const index_t num_buckets = grad_input.size(3); + const index_t L_query = grad_output.size(2); + const index_t L_key = grad_output.size(3); + const index_t L_qk = L_query * L_key; + + auto grad_input_ = grad_input.contiguous(); + auto grad_output_ = grad_output.contiguous(); + auto index_ = index.contiguous(); + + const index_t numel = grad_output.numel(); + + const int threadsPerBlock = HIP_GET_NUM_THREADS(numel); + const int blocks = HIP_GET_BLOCKS(numel, threadsPerBlock); + + at::cuda::CUDAGuard device_guard(grad_output.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_output.scalar_type(), "rpe_index_backward_gpu", [&] { + scalar_t *p_grad_input = grad_input_.data_ptr(); + const 
index_t *p_index = index_.data_ptr(); + const scalar_t *p_grad_output = grad_output_.data_ptr(); + rpe_index_backward_gpu_kernel<<>>( + numel, p_grad_input, p_index, p_grad_output, num_buckets, L_key, + L_qk); + }); +} diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/setup.py b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..c145714e9595b0364f070babe3ecf0b926a01ec6 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/RPE/rpe_ops/setup.py @@ -0,0 +1,24 @@ +from setuptools import setup, Extension +import torch +from torch.utils import cpp_extension + +ext_t = cpp_extension.CppExtension +ext_fnames = ['rpe_index.cpp'] +define_macros = [] +extra_compile_args = dict(cxx=['-fopenmp', '-O3'], + nvcc=['-O3']) + +if torch.cuda.is_available(): + ext_t = cpp_extension.CUDAExtension + ext_fnames.append('rpe_index_cuda.cu') + define_macros.append(('WITH_CUDA', None)) + +setup(name='rpe_index', + version="1.2.0", + ext_modules=[ext_t( + 'rpe_index_cpp', + ext_fnames, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + )], + cmdclass={'build_ext': cpp_extension.BuildExtension}) diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/__init__.py b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2be5bda980b9aa68b178e959094a9078e8bd0150 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/__init__.py @@ -0,0 +1,79 @@ +from ..base import BaseModel +from torchvision import transforms +from .swin.names import swin_v2_b, swin_v2_s +from .rpe_options import make_kprpe_shared, make_kprpe_bias +from torch import nn +from .RPE import build_rpe + + +class SWINKPRPEModelWithKPRPE(BaseModel): + + def __init__(self, net, config): + super(SWINKPRPEModelWithKPRPE, 
self).__init__(config) + self.net = net + + self.rpe_config = config.rpe_config + self.keypoint_linear, self.num_buckets, self.swin_config = make_kprpe_shared(self.rpe_config, self.net) + kprpes = [] + for feature_size in self.swin_config.keys(): + num_heads = self.swin_config[feature_size][0]['num_heads'] + _, rpe_k, _ = build_rpe(self.rpe_config, head_dim=None, num_heads=num_heads) + kprpes.append(rpe_k) + self.kprpes = nn.ModuleList(kprpes) + + # remove unused params in swin + for mod in self.net.features: + if hasattr(mod, 'attn'): + if hasattr(mod.attn, 'relative_position_bias_table'): + del mod.attn.relative_position_bias_table + if hasattr(mod.attn, 'cpb_mlp'): + del mod.attn.cpb_mlp + + + assert config.mask_ratio == 0 + + @classmethod + def from_config(cls, config): + + if config.name == 'small': + net = swin_v2_s() + elif config.name == 'base': + net = swin_v2_b() + else: + raise NotImplementedError + + model = cls(net, config) + model.eval() + return model + + def forward(self, x, keypoints): + if self.input_color_flip: + x = x.flip(1) + + if self.rpe_config is None: + extra_ctx = None + else: + extra_ctx = make_kprpe_bias(keypoints, x, + self.keypoint_linear, self.rpe_config, self.swin_config, self.num_buckets, + self.kprpes) + + return self.net(x, extra_ctx) + + def make_train_transform(self): + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + + def make_test_transform(self): + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + +def load_model(model_config): + model = SWINKPRPEModelWithKPRPE.from_config(model_config) + return model + diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/configs/v1_base_kprpe_splithead_unshared.yaml b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/configs/v1_base_kprpe_splithead_unshared.yaml 
new file mode 100644 index 0000000000000000000000000000000000000000..5bdfec14039f53d8c6fcfedfef4f4d1a3351207c --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/configs/v1_base_kprpe_splithead_unshared.yaml @@ -0,0 +1,17 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'base' +output_dim: 512 +start_from: '' +freeze: False + +mask_ratio: 0.0 +rpe_config: + name: KPRPE_shared + rpe_on: k + shared_head: True + mode: ctx + method: product + ratio: 1.9 + ctx_type: 'rel_keypoint_splithead_unshared' + num_keypoints: 5 \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/configs/v1_small_kprpe_splithead_unshared.yaml b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/configs/v1_small_kprpe_splithead_unshared.yaml new file mode 100644 index 0000000000000000000000000000000000000000..946db1af29595dbf1068de92bfdeacd2d4f80e98 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/configs/v1_small_kprpe_splithead_unshared.yaml @@ -0,0 +1,17 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'small' +output_dim: 512 +start_from: '' +freeze: False + +mask_ratio: 0.0 +rpe_config: + name: KPRPE_shared + rpe_on: k + shared_head: True + mode: ctx + method: product + ratio: 1.9 + ctx_type: 'rel_keypoint_splithead_unshared' + num_keypoints: 5 \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/models/swin_kprpe/rpe_options.py b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/rpe_options.py new file mode 100644 index 0000000000000000000000000000000000000000..edfc24277fca62ee884d18b2eafaceeabfb573cd --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/models/swin_kprpe/rpe_options.py @@ -0,0 +1,152 @@ +import torch +import torch.nn as nn + +from .RPE.KPRPE.kprpe_shared import get_rpe_config +from .RPE.KPRPE import relative_keypoints +import torch.nn.functional as F + + +def make_kprpe_shared(rpe_config, net): + + 
assert rpe_config.rpe_on == 'k' + num_buckets = get_rpe_config( + ratio=rpe_config.ratio, + method=rpe_config.method, + mode=rpe_config.mode, + shared_head=rpe_config.shared_head, + skip=0, + rpe_on=rpe_config.rpe_on, + )['rpe_k']['num_buckets'] + + if rpe_config.ctx_type == 'rel_keypoint_splithead_unshared': + swin_config = get_swin_config(net) + + module_list = [] + for side, blocks_cfg in swin_config.items(): + total_heads = sum([block_cfg['num_heads'] for block_cfg in blocks_cfg]) + keypoint_linear = nn.Linear(2 * rpe_config.num_keypoints, num_buckets * total_heads) + # init zero + keypoint_linear.weight.data.zero_() + keypoint_linear.bias.data.zero_() + module_list.append(keypoint_linear) + keypoint_linear = nn.ModuleList(module_list) + else: + raise ValueError(f'Not support ctx_type: {rpe_config.ctx_type}') + + return keypoint_linear, num_buckets, swin_config + + + +def make_kprpe_bias(keypoints, x, keypoint_linear, rpe_config, swin_config, num_buckets, kprpes): + B = x.shape[0] + ctx_type = rpe_config.get('ctx_type', '') + num_kp = rpe_config.num_keypoints + if ctx_type == 'rel_keypoint_splithead_unshared': + + extra_ctx = [] + for feature_size, linear, kprpe in zip(swin_config.keys(), keypoint_linear, kprpes): + grid = relative_keypoints.make_grid_0_1(feature_size, x.device, x.dtype) + rel_keypoints = relative_keypoints.calc_rel_keypoints(keypoints, grid)[:, :, :2 * num_kp] + rel_keypoints = linear(rel_keypoints) # B H N D + blocks_cfg = swin_config[feature_size] + heads = [block_cfg['num_heads'] for block_cfg in blocks_cfg] + rel_keypoints = rel_keypoints.view(B, -1, sum(heads), num_buckets).transpose(1, 2) + rel_keypoints = torch.split(rel_keypoints, heads, dim=1) + for rel_kp, block_cfg in zip(rel_keypoints, blocks_cfg): + # make table + windowed_rel_kp = split_window(rel_kp, feature_size[0], feature_size[1], + block_cfg['window_size'], block_cfg['shift_size']) + + per_window_rel_kp = torch.split(windowed_rel_kp, 1, dim=2) + + window_rpe_biases = [] + 
def split_window(ctx, x_H, x_W, window_size, shift_size):
    """Partition a per-head token context into (optionally shifted) windows.

    Applies the same pad -> cyclic shift -> window partition sequence that the
    windowed attention uses, so the result lines up window-for-window.

    Args:
        ctx: tensor indexed as ``(B, num_head, N, D)`` with ``N == x_H * x_W``.
        x_H: feature-map height.
        x_W: feature-map width.
        window_size: ``[window_h, window_w]``.
        shift_size: ``[shift_h, shift_w]``; it is never modified.

    Returns:
        Contiguous tensor of shape
        ``(B, num_head, num_windows, window_h * window_w, D)``.
    """
    batch, num_head, channels = ctx.shape[0], ctx.shape[1], ctx.shape[3]
    win_h, win_w = window_size[0], window_size[1]

    ctx = ctx.reshape(batch * num_head, x_H, x_W, channels)

    # Pad bottom/right so both spatial dims are multiples of the window size.
    pad_r = (win_w - x_W % win_w) % win_w
    pad_b = (win_h - x_H % win_h) % win_h
    ctx = F.pad(ctx, (0, 0, 0, pad_r, 0, pad_b))
    pad_H, pad_W = ctx.shape[1], ctx.shape[2]

    # A window that already spans the padded map makes shifting pointless.
    shift_h = 0 if win_h >= pad_H else shift_size[0]
    shift_w = 0 if win_w >= pad_W else shift_size[1]
    if shift_h + shift_w > 0:
        ctx = torch.roll(ctx, shifts=(-shift_h, -shift_w), dims=(1, 2))

    grid_h, grid_w = pad_H // win_h, pad_W // win_w
    ctx = ctx.reshape(batch * num_head, grid_h, win_h, grid_w, win_w, channels)
    ctx = ctx.permute(0, 1, 3, 2, 4, 5)
    ctx = ctx.reshape(batch, num_head, grid_h * grid_w, win_h * win_w, channels)
    # e.g. [32, 4, 64, 25, 49]
    return ctx.contiguous()


def get_swin_config(net, input_size=(3, 112, 112)):
    """Probe a Swin backbone and group its attention blocks by feature size.

    A random batch of one image is pushed through ``net.features`` module by
    module. Before each module that has an ``attn`` attribute runs, its head
    count, shift size, window size and the spatial size of its input are
    recorded.

    Args:
        net: object with an iterable ``features`` attribute of callables.
        input_size: ``(C, H, W)`` of the probe image. Defaults to the
            112x112 RGB input that was previously hard-coded.

    Returns:
        Dict keyed by ``(H, W)`` feature size, in order of first appearance.
        Each value lists one dict per attention block with keys
        ``num_heads``, ``shift_size``, ``feature_size`` and ``window_size``,
        e.g. ``{(37, 37): [{'num_heads': 4, 'shift_size': [0, 0], ...},
        {'num_heads': 4, 'shift_size': [2, 2], ...}], (19, 19): [...]}``.

    Raises:
        AssertionError: if ``net`` has no ``features`` attribute.
    """
    assert hasattr(net, 'features'), 'net must expose a `features` attribute'

    # Create the probe where the model lives so this also works once the
    # backbone has been moved to another device or cast to another dtype.
    factory_kwargs = {}
    parameters = getattr(net, 'parameters', None)
    if callable(parameters):
        reference = next(iter(parameters()), None)
        if reference is not None:
            factory_kwargs['device'] = reference.device
            if reference.is_floating_point():
                factory_kwargs['dtype'] = reference.dtype

    probe = torch.rand(1, *input_size, **factory_kwargs)
    config = {}
    # Only shapes are needed, so do not record an autograd graph.
    with torch.no_grad():
        for module in net.features:
            if hasattr(module, 'attn'):
                attn = module.attn
                # Activations are indexed (B, H, W, C) at this point.
                feature_size = tuple(probe.shape[1:3])
                config.setdefault(feature_size, []).append({
                    'num_heads': attn.num_heads,
                    'shift_size': attn.shift_size,
                    'feature_size': feature_size,
                    'window_size': attn.window_size,
                })
            probe = module(probe)
    return config
class SwinTransformer(nn.Module):
    """
    Swin Transformer backbone from the "Swin Transformer: Hierarchical Vision
    Transformer using Shifted Windows" paper, with a flatten + linear feature
    head in place of a pooled classification head.

    Args:
        img_size (List[int]): Input image size ``[H, W]``.
        patch_size (List[int]): Patch size.
        embed_dim (int): Patch embedding dimension.
        depths (List(int)): Depth of each Swin Transformer layer.
        num_heads (List(int)): Number of attention heads in different layers.
        window_size (List[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob (float): Stochastic depth rate. Default: 0.1.
        num_classes (int): Dimension of the returned feature vector. Default: 1000.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
            (``nn.LayerNorm`` with ``eps=1e-5``).
        block (nn.Module, optional): SwinTransformer Block. Default:
            SwinTransformerBlockV2; ``None`` selects SwinTransformerBlock.
        downsample_layer (nn.Module): Downsample layer (patch merging).
            Default: PatchMergingV2.
    """

    def __init__(
        self,
        img_size: List[int],
        patch_size: List[int],
        embed_dim: int,
        depths: List[int],
        num_heads: List[int],
        window_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.1,
        num_classes: int = 1000,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        block: Optional[Callable[..., nn.Module]] = SwinTransformerBlockV2,
        downsample_layer: Callable[..., nn.Module] = PatchMergingV2,
    ):
        super().__init__()
        _log_api_usage_once(self)
        self.num_classes = num_classes

        if block is None:
            block = SwinTransformerBlock
        if norm_layer is None:
            norm_layer = partial(nn.LayerNorm, eps=1e-5)

        layers: List[nn.Module] = []
        # split image into non-overlapping patches
        layers.append(
            nn.Sequential(
                nn.Conv2d(
                    3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1])
                ),
                Permute([0, 2, 3, 1]),
                norm_layer(embed_dim),
            )
        )

        total_stage_blocks = sum(depths)
        # Linear stochastic-depth schedule from 0 to stochastic_depth_prob.
        # The denominator is clamped so a single-block model does not divide
        # by zero (its only block then gets probability 0).
        sd_denominator = max(total_stage_blocks - 1, 1)
        stage_block_id = 0
        # build SwinTransformer blocks
        for i_stage in range(len(depths)):
            dim = embed_dim * 2**i_stage
            for i_layer in range(depths[i_stage]):
                sd_prob = stochastic_depth_prob * float(stage_block_id) / sd_denominator
                layers.append(
                    block(
                        dim,
                        num_heads[i_stage],
                        window_size=window_size,
                        # every second block in a stage uses shifted windows
                        shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size],
                        mlp_ratio=mlp_ratio,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        stochastic_depth_prob=sd_prob,
                        norm_layer=norm_layer,
                    )
                )
                stage_block_id += 1
            # add patch merging layer
            if i_stage < (len(depths) - 1):
                layers.append(downsample_layer(dim, norm_layer))
        self.features = nn.ModuleList(layers)

        num_features = embed_dim * 2 ** (len(depths) - 1)
        self.norm = norm_layer(num_features)
        self.flatten = nn.Flatten(1)

        # features head: flattens every remaining spatial position, so the
        # size of the last feature map must be known up front.
        final_patch_size = self._final_feature_size(img_size, patch_size, len(depths))
        num_patches = final_patch_size[0] * final_patch_size[1]
        self.embed_dim = embed_dim  # basis feature dim for building blocks
        self.num_features = num_features  # C for final intermediate feature
        self.num_patches = num_patches  # N for output feature
        self.num_classes = num_classes  # C for return feature after flattening and mapping
        self.feature = nn.Sequential(
            nn.Linear(in_features=num_features * num_patches, out_features=num_features, bias=False),
            nn.BatchNorm1d(num_features=num_features, eps=2e-5),
            nn.Linear(in_features=num_features, out_features=num_classes, bias=False),
            nn.BatchNorm1d(num_features=num_classes, eps=2e-5)
        )

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @staticmethod
    def _final_feature_size(img_size, patch_size, num_stages):
        """Return ``[H, W]`` of the feature map after the last stage.

        The patch embedding is an unpadded convolution with stride equal to
        its kernel, giving ``img // patch``. Each of the ``num_stages - 1``
        patch-merging layers pads odd sizes to even and halves them, i.e. a
        ceiling division by two (112 / 3 -> 37 -> 19 -> 10 -> 5).
        """
        size = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        for _ in range(num_stages - 1):
            size = [(side + 1) // 2 for side in size]
        return size

    def forward(self, x, extra_ctx=None):
        """Return the ``num_classes``-dimensional feature for each input.

        Args:
            x: image batch fed to the patch-embedding convolution.
            extra_ctx: optional context forwarded unchanged to every
                transformer block; other layers do not receive it.
        """
        for module in self.features:
            if isinstance(module, (SwinTransformerBlock, SwinTransformerBlockV2)):
                x = module(x, extra_ctx)
            else:
                x = module(x)
        x = self.norm(x)
        x = self.flatten(x)
        x = self.feature(x)
        return x


def _swin_transformer(
    img_size: List[int],
    patch_size: List[int],
    embed_dim: int,
    depths: List[int],
    num_heads: List[int],
    window_size: List[int],
    stochastic_depth_prob: float,
    num_classes=512,
    **kwargs: Any
) -> SwinTransformer:
    """Construct a ``SwinTransformer``; extra ``kwargs`` go to its constructor."""
    model = SwinTransformer(
        img_size=img_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        stochastic_depth_prob=stochastic_depth_prob,
        num_classes=num_classes,
        **kwargs
    )

    return model
class PatchMerging(nn.Module):
    """Patch Merging Layer (normalise the merged patches, then reduce).

    Args:
        dim (int): Number of input channels.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
    """

    def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm):
        super().__init__()
        _log_api_usage_once(self)
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x: Tensor):
        """
        Args:
            x (Tensor): input tensor with expected layout of [..., H, W, C]
        Returns:
            Tensor with layout of [..., H/2, W/2, 2*C]
        """
        merged = _patch_merging_pad(x)  # ... H/2 W/2 4*C
        return self.reduction(self.norm(merged))


class ShiftedWindowAttention(nn.Module):
    """
    Module wrapper around :func:`shifted_window_attention` that owns the
    qkv / output projections and a learned relative position bias table.
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__()
        if len(window_size) != 2 or len(shift_size) != 2:
            raise ValueError("window_size and shift_size must be of length 2")
        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.dropout = dropout

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)

        # Hooks, so that subclasses can swap in their own bias parameterisation.
        self.define_relative_position_bias_table()
        self.define_relative_position_index()

    def define_relative_position_bias_table(self):
        """Create the learnable bias table, one row per relative offset."""
        win_h, win_w = self.window_size[0], self.window_size[1]
        num_offsets = (2 * win_h - 1) * (2 * win_w - 1)  # 2*Wh-1 * 2*Ww-1
        self.relative_position_bias_table = nn.Parameter(torch.zeros(num_offsets, self.num_heads))
        nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02)

    def define_relative_position_index(self):
        """Register the flat table index for every (query, key) pair of a window."""
        win_h, win_w = self.window_size[0], self.window_size[1]
        grid = torch.stack(
            torch.meshgrid(torch.arange(win_h), torch.arange(win_w), indexing="ij")
        )  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        # Shift both offsets to start at 0, then combine them row-major.
        row_term = (offsets[0] + (win_h - 1)) * (2 * win_w - 1)
        col_term = offsets[1] + (win_w - 1)
        relative_position_index = (row_term + col_term).flatten()  # Wh*Ww*Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

    def get_relative_position_bias(self) -> torch.Tensor:
        """Look the bias table up through the pair index."""
        return _get_relative_position_bias(
            self.relative_position_bias_table, self.relative_position_index, self.window_size  # type: ignore[arg-type]
        )

    def forward(self, x: Tensor, extra_ctx=None) -> Tensor:
        """
        Args:
            x (Tensor): Tensor with layout of [B, H, W, C]
            extra_ctx: accepted for call compatibility; not used here.
        Returns:
            Tensor with same layout as input, i.e. [B, H, W, C]
        """
        bias = self.get_relative_position_bias()
        return shifted_window_attention(
            x,
            self.qkv.weight,
            self.proj.weight,
            bias,
            self.window_size,
            self.num_heads,
            shift_size=self.shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            training=self.training,
        )


class SwinTransformerBlock(nn.Module):
    """
    Swin Transformer Block (normalisation applied before attention and MLP).

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (List[int]): Window size.
        shift_size (List[int]): Shift size for shifted window attention.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: List[int],
        shift_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention,
    ):
        super().__init__()
        _log_api_usage_once(self)

        self.norm1 = norm_layer(dim)
        self.attn = attn_layer(
            dim,
            window_size,
            shift_size,
            num_heads,
            attention_dropout=attention_dropout,
            dropout=dropout,
        )
        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")
        self.norm2 = norm_layer(dim)
        hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(dim, [hidden_dim, dim], activation_layer=nn.GELU, inplace=None, dropout=dropout)

        # Re-initialise only the linear layers of the MLP.
        for layer in self.mlp.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.normal_(layer.bias, std=1e-6)

    def forward(self, x: Tensor, extra_ctx=None):
        """Apply the attention and MLP residual branches in turn."""
        attn_out = self.attn(self.norm1(x), extra_ctx)
        x = x + self.stochastic_depth(attn_out)
        mlp_out = self.mlp(self.norm2(x))
        return x + self.stochastic_depth(mlp_out)
class ShiftedWindowAttentionV2(ShiftedWindowAttention):
    """
    Swin V2 window attention: cosine attention with a learned per-head logit
    scale, and a relative position bias that is either produced by a small
    MLP over log-spaced coordinates or supplied by the caller (KPRPE).
    """

    def __init__(
        self,
        dim: int,
        window_size: List[int],
        shift_size: List[int],
        num_heads: int,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attention_dropout: float = 0.0,
        dropout: float = 0.0,
    ):
        super().__init__(
            dim,
            window_size,
            shift_size,
            num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attention_dropout=attention_dropout,
            dropout=dropout,
        )

        self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
        # mlp to generate continuous relative position bias
        self.cpb_mlp = nn.Sequential(
            nn.Linear(2, 512, bias=True), nn.ReLU(inplace=True), nn.Linear(512, num_heads, bias=False)
        )
        if qkv_bias:
            # Zero the key third of the qkv bias.
            length = self.qkv.bias.numel() // 3
            self.qkv.bias.data[length : 2 * length].zero_()

    def define_relative_position_bias_table(self):
        """Register the fixed log-spaced coordinate table fed to ``cpb_mlp``.

        Replaces the parent's learnable table: V2 stores coordinates only and
        derives the bias from them in :meth:`get_relative_position_bias`.
        """
        # get relative_coords_table
        relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
        relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
        relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w], indexing="ij"))
        relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous().unsqueeze(0)  # 1, 2*Wh-1, 2*Ww-1, 2

        relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
        relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1

        relative_coords_table *= 8  # normalize to -8, 8
        relative_coords_table = (
            torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / 3.0
        )
        self.register_buffer("relative_coords_table", relative_coords_table)

    def get_relative_position_bias(self, ctx=None) -> torch.Tensor:
        """Return the attention bias.

        Args:
            ctx: optional precomputed bias. When given it is returned as-is,
                without the sigmoid scaling. Defaults to ``None`` so the
                method can still be called with no argument, as on the parent.

        Returns:
            ``ctx`` if provided, otherwise ``16 * sigmoid(...)`` of the
            MLP-generated bias looked up through ``relative_position_index``.
        """
        if ctx is not None:
            return ctx
        relative_position_bias = _get_relative_position_bias(
            self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads),
            self.relative_position_index,  # type: ignore[arg-type]
            self.window_size,
        )
        return 16 * torch.sigmoid(relative_position_bias)

    def forward(self, x: Tensor, extra_ctx=None):
        """
        Args:
            x (Tensor): Tensor with layout of [B, H, W, C]
            extra_ctx (list, optional): KPRPE rows, one per attention block in
                execution order. The first row is removed from the list
                (the caller's list is mutated) and its ``'rel_keypoints'``
                entry replaces the learned bias.
        Returns:
            Tensor with same layout as input, i.e. [B, H, W, C]
        Raises:
            IndexError: if ``extra_ctx`` is given but already exhausted.
        """
        ctx = None
        if extra_ctx is not None:
            if len(extra_ctx) == 0:
                raise IndexError(
                    "extra_ctx is empty: it must hold one row per attention block"
                )
            ctx = extra_ctx.pop(0)['rel_keypoints']

        relative_position_bias = self.get_relative_position_bias(ctx)
        return shifted_window_attention(
            x,
            self.qkv.weight,
            self.proj.weight,
            relative_position_bias,
            self.window_size,
            self.num_heads,
            shift_size=self.shift_size,
            attention_dropout=self.attention_dropout,
            dropout=self.dropout,
            qkv_bias=self.qkv.bias,
            proj_bias=self.proj.bias,
            logit_scale=self.logit_scale,
            training=self.training,
        )


class SwinTransformerBlockV2(SwinTransformerBlock):
    """
    Swin Transformer V2 Block.

    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (List[int]): Window size.
        shift_size (List[int]): Shift size for shifted window attention.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        dropout (float): Dropout rate. Default: 0.0.
        attention_dropout (float): Attention dropout rate. Default: 0.0.
        stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttentionV2.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        window_size: List[int],
        shift_size: List[int],
        mlp_ratio: float = 4.0,
        dropout: float = 0.0,
        attention_dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_layer: Callable[..., nn.Module] = ShiftedWindowAttentionV2,
    ):
        super().__init__(
            dim,
            num_heads,
            window_size,
            shift_size,
            mlp_ratio=mlp_ratio,
            dropout=dropout,
            attention_dropout=attention_dropout,
            stochastic_depth_prob=stochastic_depth_prob,
            norm_layer=norm_layer,
            attn_layer=attn_layer,
        )

    def forward(self, x: Tensor, extra_ctx=None):
        """Apply attention and MLP residual branches with post-normalisation."""
        # Here is the difference, we apply norm after the attention in V2.
        # In V1 we applied norm before the attention.
        x = x + self.stochastic_depth(self.norm1(self.attn(x, extra_ctx)))
        x = x + self.stochastic_depth(self.norm2(self.mlp(x)))
        return x
def _patch_merging_pad(x: torch.Tensor) -> torch.Tensor:
    """Gather every 2x2 neighbourhood of ``[..., H, W, C]`` into channels.

    Odd ``H`` / ``W`` are zero-padded by one at the bottom / right first.
    The result is ``[..., ceil(H/2), ceil(W/2), 4*C]`` with the four
    neighbours concatenated in the order top-left, bottom-left, top-right,
    bottom-right.
    """
    height, width = x.shape[-3], x.shape[-2]
    padded = F.pad(x, (0, 0, 0, width % 2, 0, height % 2))
    # Column offset is the outer loop so rows alternate fastest.
    neighbours = [padded[..., row::2, col::2, :] for col in (0, 1) for row in (0, 1)]
    return torch.cat(neighbours, -1)  # ... H/2 W/2 4*C


torch.fx.wrap("_patch_merging_pad")


def _get_relative_position_bias(
    relative_position_bias_table: torch.Tensor, relative_position_index: torch.Tensor, window_size: List[int],
) -> torch.Tensor:
    """Expand a bias table into a ``[1, num_heads, N, N]`` attention bias.

    ``N`` is ``window_size[0] * window_size[1]``; ``relative_position_index``
    holds one table row index per (query, key) pair, flattened.
    """
    tokens = window_size[0] * window_size[1]
    gathered = relative_position_bias_table[relative_position_index]  # type: ignore[index]
    per_pair = gathered.view(tokens, tokens, -1)
    return per_pair.permute(2, 0, 1).contiguous().unsqueeze(0)


torch.fx.wrap("_get_relative_position_bias")


def shifted_window_attention(
    input: Tensor,
    qkv_weight: Tensor,
    proj_weight: Tensor,
    relative_position_bias: Tensor,
    window_size: List[int],
    num_heads: int,
    shift_size: List[int],
    attention_dropout: float = 0.0,
    dropout: float = 0.0,
    qkv_bias: Optional[Tensor] = None,
    proj_bias: Optional[Tensor] = None,
    logit_scale: Optional[torch.Tensor] = None,
    training: bool = True,
) -> Tensor:
    """
    Window based multi-head self attention (W-MSA) with relative position
    bias, for both shifted and non-shifted windows.

    Args:
        input (Tensor[N, H, W, C]): The input tensor of 4 dimensions.
        qkv_weight (Tensor): Weight of the joint query/key/value projection.
        proj_weight (Tensor): Weight of the output projection.
        relative_position_bias (Tensor): Bias added to the attention logits.
        window_size (List[int]): Window size.
        num_heads (int): Number of attention heads.
        shift_size (List[int]): Shift size for shifted window attention.
        attention_dropout (float): Dropout ratio of attention weight. Default: 0.0.
        dropout (float): Dropout ratio of output. Default: 0.0.
        qkv_bias (Tensor, optional): Bias of the qkv projection. Default: None.
        proj_bias (Tensor, optional): Bias of the output projection. Default: None.
        logit_scale (Tensor, optional): Logit scale of cosine attention for
            Swin Transformer V2; ``None`` selects scaled dot-product attention.
        training (bool, optional): Training flag used by the dropout calls. Default: True.
    Returns:
        Tensor[N, H, W, C]: The output tensor after shifted window attention.
    """
    B, H, W, C = input.shape
    win_h, win_w = window_size[0], window_size[1]

    # pad feature maps to multiples of window size
    pad_r = (win_w - W % win_w) % win_w
    pad_b = (win_h - H % win_h) % win_h
    x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b))
    pad_H, pad_W = x.shape[1], x.shape[2]

    # If window size is larger than feature size, there is no need to shift window
    shift_h = 0 if win_h >= pad_H else shift_size[0]
    shift_w = 0 if win_w >= pad_W else shift_size[1]
    shifted = shift_h + shift_w > 0

    # cyclic shift
    if shifted:
        x = torch.roll(x, shifts=(-shift_h, -shift_w), dims=(1, 2))

    # partition windows
    grid_h, grid_w = pad_H // win_h, pad_W // win_w
    num_windows = grid_h * grid_w
    window_tokens = win_h * win_w
    x = x.view(B, grid_h, win_h, grid_w, win_w, C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_tokens, C)  # B*nW, Ws*Ws, C

    # multi-head attention
    if logit_scale is not None and qkv_bias is not None:
        # Cosine attention uses no key bias; zero it on a copy.
        qkv_bias = qkv_bias.clone()
        third = qkv_bias.numel() // 3
        qkv_bias[third : 2 * third].zero_()
    qkv = F.linear(x, qkv_weight, qkv_bias)
    qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]
    if logit_scale is None:
        q = q * (C // num_heads) ** -0.5
        attn = q.matmul(k.transpose(-2, -1))
    else:
        # cosine attention
        attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
        logit_scale = torch.clamp(logit_scale, max=math.log(100.0)).exp()
        attn = attn * logit_scale
    # add relative position bias
    attn = attn + relative_position_bias

    if shifted:
        # Label each region of the rolled map; tokens with different labels
        # inside one window must not attend to each other.
        region_map = x.new_zeros((pad_H, pad_W))
        h_bounds = ((0, -win_h), (-win_h, -shift_h), (-shift_h, None))
        w_bounds = ((0, -win_w), (-win_w, -shift_w), (-shift_w, None))
        region_id = 0
        for h_start, h_stop in h_bounds:
            for w_start, w_stop in w_bounds:
                region_map[h_start:h_stop, w_start:w_stop] = region_id
                region_id += 1
        region_map = region_map.view(grid_h, win_h, grid_w, win_w)
        region_map = region_map.permute(0, 2, 1, 3).reshape(num_windows, window_tokens)
        label_diff = region_map.unsqueeze(1) - region_map.unsqueeze(2)
        attn_mask = label_diff.masked_fill(label_diff != 0, float(-100.0)).masked_fill(
            label_diff == 0, float(0.0)
        )
        attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1))
        attn = attn + attn_mask.unsqueeze(1).unsqueeze(0)
        attn = attn.view(-1, num_heads, x.size(1), x.size(1))

    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=attention_dropout, training=training)

    x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C)
    x = F.linear(x, proj_weight, proj_bias)
    x = F.dropout(x, p=dropout, training=training)

    # reverse windows
    x = x.view(B, grid_h, grid_w, win_h, win_w, C)
    x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C)

    # reverse cyclic shift
    if shifted:
        x = torch.roll(x, shifts=(shift_h, shift_w), dims=(1, 2))

    # unpad features
    return x[:, :H, :W, :].contiguous()


torch.fx.wrap("shifted_window_attention")
class ViTModel(BaseModel):

    """
    Vision Transformer (ViT) face recognition model built on ``BaseModel``.

    Wraps a ``VisionTransformer`` network, which processes an image as a
    sequence of patches with attention.
    https://arxiv.org/abs/2010.11929
    ```
    @article{dosovitskiy2020image,
      title={An image is worth 16x16 words: Transformers for image recognition at scale},
      author={Dosovitskiy, Alexey and Beyer, Lucas and Kolesnikov, Alexander and Weissenborn, Dirk and Zhai, Xiaohua and Unterthiner, Thomas and Dehghani, Mostafa and Minderer, Matthias and Heigold, Georg and Gelly, Sylvain and others},
      journal={arXiv preprint arXiv:2010.11929},
      year={2020}
    }
    ```
    """

    # Per-variant hyper-parameters; everything else is shared by both sizes.
    _VARIANTS = {
        'small': {'depth': 12, 'mlp_ratio': 5, 'num_heads': 8},
        'base': {'depth': 24, 'mlp_ratio': 3, 'num_heads': 16},
    }

    def __init__(self, net, config):
        super().__init__(config)
        self.net = net

    @classmethod
    def from_config(cls, config):
        """Build the variant named by ``config.name`` and return it in eval mode.

        Raises:
            NotImplementedError: if ``config.name`` is not ``'small'`` or ``'base'``.
        """
        variant = cls._VARIANTS.get(config.name)
        if variant is None:
            raise NotImplementedError

        net = VisionTransformer(
            img_size=112,
            patch_size=8,
            num_classes=config.output_dim,
            embed_dim=512,
            depth=variant['depth'],
            mlp_ratio=variant['mlp_ratio'],
            num_heads=variant['num_heads'],
            drop_path_rate=0.1,
            norm_layer="ln",
            mask_ratio=config.mask_ratio,
        )

        model = cls(net, config)
        model.eval()
        return model

    def forward(self, x):
        """Run the network, reversing dim 1 of ``x`` first if ``input_color_flip`` is set."""
        # NOTE(review): input_color_flip is not set in this class; presumably
        # BaseModel provides it - confirm.
        flipped = x.flip(1) if self.input_color_flip else x
        return self.net(flipped)

    @staticmethod
    def _to_normalized_tensor():
        """Return a fresh ToTensor + Normalize(mean=0.5, std=0.5) pipeline."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def make_train_transform(self):
        """Return the preprocessing used for training images."""
        return self._to_normalized_tensor()

    def make_test_transform(self):
        """Return the preprocessing used for evaluation images."""
        return self._to_normalized_tensor()


def load_model(model_config):
    """Create a ``ViTModel`` from ``model_config``."""
    return ViTModel.from_config(model_config)
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of tokens.

    Parameters
    ----------
    dim: int
        Embedding size of each token. Must be divisible by `num_heads`.
    num_heads: int
        Number of attention heads.
    qkv_bias: bool
        Whether the fused query/key/value projection has a bias term.
    qk_scale: float or None
        Scale applied to the attention logits. A falsy value (None or 0)
        selects the default `head_dim ** -0.5`.
    attn_drop: float
        Dropout probability applied to the attention weights.
    proj_drop: float
        Dropout probability applied after the output projection.

    Raises
    ------
    ValueError
        If `dim` is not divisible by `num_heads`.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 qk_scale: Optional[float] = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.):
        super().__init__()
        if dim % num_heads != 0:
            # Fail at construction time; otherwise the reshape in `forward`
            # fails later with an unrelated-looking shape error.
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        head_dim = dim // num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to `x` of shape (batch, tokens, dim); the shape is preserved."""
        batch_size, num_token, embed_dim = x.shape
        # qkv: (3, batch_size, num_heads, num_token, embed_dim // num_heads)
        qkv = self.qkv(x).reshape(
            batch_size, num_token, 3, self.num_heads, embed_dim // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # merge the heads back: (batch, heads, tokens, head_dim) -> (batch, tokens, dim)
        x = (attn @ v).transpose(1, 2).reshape(batch_size, num_token, embed_dim)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
"ln", + patch_n: int = 144): + super().__init__() + + if norm_layer == "bn": + self.norm1 = VITBatchNorm(num_features=num_patches) + self.norm2 = VITBatchNorm(num_features=num_patches) + elif norm_layer == "ln": + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.extra_gflops = (num_heads * patch_n * (dim//num_heads)*patch_n * 2) / (1000**3) + + def forward(self, x): + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, img_size=108, patch_size=9, in_channels=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2d(in_channels, embed_dim, + kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + batch_size, channels, height, width = x.shape + assert height == self.img_size[0] and width == self.img_size[1], \ + f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
class VisionTransformer(nn.Module):
    """Vision Transformer backbone without a class token.

    All patch tokens are flattened and passed through the `feature` head
    (Linear -> BatchNorm1d -> Linear -> BatchNorm1d) to obtain an embedding
    of size `num_classes`.

    Parameters
    ----------
    img_size, patch_size, in_channels:
        Passed to `PatchEmbed` when `num_patches` is None.
    num_classes: int
        Size of the output embedding.
    embed_dim: int
        Token embedding size.
    depth: int
        Number of transformer blocks.
    num_heads, mlp_ratio, qkv_bias, qk_scale:
        Passed to every `Block`.
    drop_rate, attn_drop_rate: float
        Dropout rates passed to the blocks (`drop_rate` is also applied to
        the position-embedded tokens).
    drop_path_rate: float
        Maximum stochastic-depth rate; block i uses a linearly increasing rate.
    num_patches: int or None
        When given, the patch embedding is replaced by `nn.Identity()` and the
        input is used as a token sequence directly.
    norm_layer: str
        "ln" (LayerNorm) or "bn" (VITBatchNorm).
    mask_ratio: float
        Fraction of tokens dropped at random during training; 0 disables masking.
    using_checkpoint: bool
        Use activation checkpointing for the blocks during training.

    Raises
    ------
    ValueError
        If `norm_layer` is neither "ln" nor "bn".
    """

    def __init__(self,
                 img_size: int = 112,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 num_classes: int = 1000,
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 qk_scale: Optional[float] = None,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_patches: Optional[int] = None,
                 norm_layer: str = "ln",
                 mask_ratio=0.1,
                 using_checkpoint=False,
                 ):
        super().__init__()
        if norm_layer not in ("ln", "bn"):
            # Previously an unknown value left `self.norm` undefined and only
            # failed later, inside `forward_features`.
            raise ValueError(f"Unsupported norm_layer: {norm_layer!r} (expected 'ln' or 'bn')")

        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        if num_patches is not None:
            self.patch_embed = nn.Identity()
        else:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size,
                                          in_channels=in_channels, embed_dim=embed_dim)
            num_patches = self.patch_embed.num_patches
        self.mask_ratio = mask_ratio
        self.using_checkpoint = using_checkpoint

        self.num_patches = num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        patch_n = (img_size // patch_size) ** 2
        self.blocks = nn.ModuleList(
            [
                Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                      drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                      num_patches=num_patches, patch_n=patch_n)
                for i in range(depth)]
        )
        # sum of the per-block attention cost estimates
        self.extra_gflops = 0.0
        for _block in self.blocks:
            self.extra_gflops += _block.extra_gflops

        if norm_layer == "ln":
            self.norm = nn.LayerNorm(embed_dim)
        else:
            self.norm = VITBatchNorm(self.num_patches)

        # features head
        self.feature = nn.Sequential(
            nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False),
            nn.BatchNorm1d(num_features=embed_dim, eps=2e-5),
            nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False),
            nn.BatchNorm1d(num_features=num_classes, eps=2e-5)
        )

        # The mask token only exists when masking is enabled.
        if self.mask_ratio != 0:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            torch.nn.init.normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear weights (truncated normal) and LayerNorm affine parameters."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names reported as exempt from weight decay."""
        # 'cls_token' is kept for compatibility; this model defines no such parameter.
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the embedding head (`self.feature`)."""
        # The model has no `head` attribute; returning `self.head` raised AttributeError.
        return self.feature

    def random_masking(self, x, mask_ratio=0.1):
        """Randomly keep `int(L * (1 - mask_ratio))` tokens of every sample.

        Parameters
        ----------
        x: torch.Tensor
            Tokens of shape (N, L, D).
        mask_ratio: float
            Fraction of tokens to drop.

        Returns
        -------
        x_masked: torch.Tensor
            The kept tokens, shape (N, len_keep, D).
        index: torch.Tensor
            Gather index used to select the kept tokens, shape (N, len_keep, D).
        ids_restore: torch.Tensor
            Permutation of shape (N, L) that undoes the shuffle.
        """
        N, L, D = x.size()  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        # ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        index = ids_keep.unsqueeze(-1).repeat(1, 1, D)
        x_masked = torch.gather(x, dim=1, index=index)

        return x_masked, index, ids_restore

    def forward_features(self, x):
        """Return the flattened token features of shape (B, num_patches * embed_dim)."""
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        masking = self.training and self.mask_ratio > 0
        if masking:
            # Use the configured ratio; the call used to rely on the
            # method's default (0.1) regardless of `self.mask_ratio`.
            x, _, ids_restore = self.random_masking(x, mask_ratio=self.mask_ratio)

        use_checkpoint = self.using_checkpoint and self.training
        if use_checkpoint:
            from torch.utils.checkpoint import checkpoint
        for func in self.blocks:
            if use_checkpoint:
                x = checkpoint(func, x)
            else:
                x = func(x)
        x = self.norm(x.float())

        if masking:
            # Fill the dropped positions with the mask token and undo the shuffle.
            mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
            x_ = torch.cat([x[:, :, :], mask_tokens], dim=1)  # no cls token
            x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
            x = x_
        return torch.reshape(x, (B, self.num_patches * self.embed_dim))

    def forward(self, x):
        """Return the embedding of shape (B, num_classes)."""
        x = self.forward_features(x)
        x = self.feature(x)
        return x
def build_rpe(rpe_config, head_dim, num_heads):
    """Build the relative position encoding modules described by `rpe_config`.

    Parameters
    ----------
    rpe_config: object or None
        Configuration exposing `name` and, for "iRPE", `ratio`, `method`,
        `mode`, `shared_head` and `rpe_on`. None disables RPE.
    head_dim: int
        The dimension of each attention head.
    num_heads: int
        The number of attention heads.

    Returns
    -------
    None when `rpe_config` is None, otherwise the result of `irpe.build_rpe`.

    Raises
    ------
    NotImplementedError
        If `rpe_config.name` is not "iRPE".
    """
    if rpe_config is None:
        return None

    name = rpe_config.name
    if name != "iRPE":
        raise NotImplementedError(f"Unknown RPE: {name}")

    irpe_config = irpe.get_rpe_config(
        ratio=rpe_config.ratio,
        method=rpe_config.method,
        mode=rpe_config.mode,
        shared_head=rpe_config.shared_head,
        skip=0,  # always built without extra (e.g. class) tokens
        rpe_on=rpe_config.rpe_on,
    )
    return irpe.build_rpe(irpe_config, head_dim=head_dim, num_heads=num_heads)
@torch.no_grad()
def _rp_2d_product(diff, **kwargs):
    """2D RPE with Product method.

    The row offset and the column offset are bucketed independently with
    `piecewise_index`, and the two bucket indices are combined into one ID.

    Parameters
    ----------
    diff: torch.Tensor
        The shape of `diff` is (L, L, 2),
        where L is the sequence length,
        and 2 represents a 2D offset (row_offset, col_offset).
    **kwargs:
        Forwarded unchanged to `piecewise_index`; `beta` is also read here.

    Returns
    -------
    index: torch.Tensor
        Index to corresponding encodings.
        The shape of `index` is (L, L),
        where L is the sequence length.
    """
    # beta is a float; its integer part bounds the piecewise index.
    half_range = int(kwargs['beta'])
    side = 2 * half_range + 1

    row_offsets = diff[:, :, 0]
    col_offsets = diff[:, :, 1]

    # piecewise_index yields values in [-half_range, half_range];
    # shift them into [0, 2 * half_range].
    row_bucket = piecewise_index(row_offsets, **kwargs) + half_range
    col_bucket = piecewise_index(col_offsets, **kwargs) + half_range

    # flatten the (row_bucket, col_bucket) pair into a single ID
    return row_bucket * side + col_bucket
def get_absolute_positions(height, width, dtype, device):
    '''Get absolute positions

    Positions are zero-based. Take height = 3, width = 3 as an example:
    rows:       cols:
    0 0 0       0 1 2
    1 1 1       0 1 2
    2 2 2       0 1 2

    return stack([rows, cols], 2)

    Parameters
    ----------
    height, width: int
        The height and width of feature map
    dtype: torch.dtype
        the data type of returned value
    device: torch.device
        the device of returned value

    Return
    ------
    2D absolute positions: torch.Tensor
        The shape is (height, width, 2),
        where 2 represents a 2D position (row, col).
    '''
    # `expand` creates broadcast views; `stack` then materialises the single
    # output tensor, so no intermediate (height, width) copies are needed.
    rows = torch.arange(height, dtype=dtype, device=device).view(
        height, 1).expand(height, width)
    cols = torch.arange(width, dtype=dtype, device=device).view(
        1, width).expand(height, width)
    return torch.stack([rows, cols], 2)
@torch.no_grad()
def get_bucket_ids_2d_without_skip(method, height, width,
                                   alpha, beta, gamma,
                                   dtype=torch.long, device=torch.device('cpu')):
    """Get bucket IDs for image relative position encodings without skip token

    The full table is cached in `BUCKET_IDS_BUF` per
    (method, alpha, beta, gamma, dtype, device) and is only recomputed when a
    larger height or width than the cached table is requested.

    Parameters
    ----------
    method: METHOD
        The method ID of image relative position encoding.
    height, width: int
        The height and width of the feature map.
        The sequence length is equal to `height * width`.
    alpha, beta, gamma: float
        The coefficients of piecewise index function.
    dtype: torch.dtype
        the data type of returned `bucket_ids`
    device: torch.device
        the device of returned `bucket_ids`

    Returns
    -------
    bucket_ids: torch.Tensor
        The bucket IDs which index to corresponding encodings.
        The shape of `bucket_ids` is (L, L),
        where `L = height * width`.
    num_buckets: int
        The number of buckets (no extra bucket for skip tokens).
    L: int
        The sequence length

    Raises
    ------
    NotImplementedError
        If `method` has no entry in `_METHOD_FUNC`.
    """

    key = (method, alpha, beta, gamma, dtype, device)
    cached = BUCKET_IDS_BUF.get(key, None)
    if cached is None or cached[2] < height or cached[3] < width:
        if cached is None:
            max_height, max_width = height, width
        else:
            # grow the table so that it covers every size requested so far
            max_height = max(cached[2], height)
            max_width = max(cached[3], width)
        # relative position encoding mapping function
        func = _METHOD_FUNC.get(method, None)
        if func is None:
            raise NotImplementedError(
                f"[Error] The method ID {method} does not exist.")
        pos = get_absolute_positions(max_height, max_width, dtype, device)

        # compute the offset of a pair of 2D relative positions
        max_L = max_height * max_width
        pos1 = pos.view((max_L, 1, 2))
        pos2 = pos.view((1, max_L, 2))
        # diff: shape of (L, L, 2)
        diff = pos1 - pos2

        # bucket_ids: shape of (L, L)
        bucket_ids = func(diff, alpha=alpha, beta=beta,
                          gamma=gamma, dtype=dtype)
        beta_int = int(beta)
        if method != METHOD.PRODUCT:
            # shift IDs from [-beta_int, beta_int] to [0, 2 * beta_int]
            bucket_ids += beta_int
        bucket_ids = bucket_ids.view(
            max_height, max_width, max_height, max_width)

        num_buckets = get_num_buckets(method, alpha, beta, gamma)
        # Record the size of the table that was actually built. Recording the
        # requested size instead made the cache forget larger tables and
        # recompute them on alternating request sizes.
        cached = (bucket_ids, num_buckets, max_height, max_width)
        BUCKET_IDS_BUF[key] = cached

    L = height * width
    # crop the cached table to the requested feature-map size
    bucket_ids = cached[0][:height, :width, :height, :width].reshape(L, L)
    num_buckets = cached[1]

    return bucket_ids, num_buckets, L
class iRPE(nn.Module):
    """The implementation of image relative position encoding (excluding Cross method).

    Parameters
    ----------
    head_dim: int
        The dimension for each head.
    num_heads: int
        The number of parallel attention heads.
    mode: str or None
        The mode of image relative position encoding.
        Choices: [None, 'bias', 'contextual']
    method: METHOD
        The method ID of image relative position encoding.
        The `METHOD` class is defined in `irpe.py`.
    transposed: bool
        Whether to transpose the input feature.
        For iRPE on queries or keys, transposed should be `True`.
        For iRPE on values, transposed should be `False`.
    num_buckets: int
        The number of buckets, which store encodings.
    initializer: None or an inplace function
        [Optional] The initializer to `lookup_table`.
        Initialize `lookup_table` as zero by default.
    rpe_config: RPEConfig
        The config generated by the function `get_single_rpe_config`.
        Its `alpha`, `beta` and `gamma` are read when bucket IDs are computed.
    """
    # a buffer to store bucket index
    # (key, rp_bucket, _ctx_rp_bucket_flatten)
    _rp_bucket_buf = (None, None, None)

    def __init__(self, head_dim, num_heads=8,
                 mode=None, method=None,
                 transposed=True, num_buckets=None,
                 initializer=None, rpe_config=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        # relative position
        assert mode in [None, 'bias', 'contextual']
        self.mode = mode

        assert method is not None, 'method should be a METHOD ID rather than None'
        self.method = method

        self.transposed = transposed
        self.num_buckets = num_buckets

        if initializer is None:
            # default: keep the zero initialisation
            def initializer(x): return None
        self.initializer = initializer

        self.reset_parameters()

        self.rpe_config = rpe_config

    @torch.no_grad()
    def reset_parameters(self):
        """Create the lookup table for the configured mode and initialise it."""
        if self.transposed:
            if self.mode == 'bias':
                self.lookup_table_bias = nn.Parameter(
                    torch.zeros(self.num_heads, self.num_buckets))
                self.initializer(self.lookup_table_bias)
            elif self.mode == 'contextual':
                self.lookup_table_weight = nn.Parameter(
                    torch.zeros(self.num_heads,
                                self.head_dim, self.num_buckets))
                self.initializer(self.lookup_table_weight)
        else:
            if self.mode == 'bias':
                raise NotImplementedError(
                    "[Error] Bias non-transposed RPE does not exist.")
            elif self.mode == 'contextual':
                self.lookup_table_weight = nn.Parameter(
                    torch.zeros(self.num_heads,
                                self.num_buckets, self.head_dim))
                self.initializer(self.lookup_table_weight)

    def forward(self, x, height=None, width=None):
        """forward function for iRPE.

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor whose shape is (B, H, L, head_dim) for the transposed
            version (queries or keys), or (B, H, L, L) for the non-transposed
            version (attention applied to values).
        height, width: int or None
            [Optional] Spatial size of the input; each defaults to
            floor(sqrt(L)).

        Returns
        -------
        rpe_encoding: torch.Tensor
            image Relative Position Encoding; see `forward_rpe_transpose`
            and `forward_rpe_no_transpose` for the shapes.
        """
        rp_bucket, self._ctx_rp_bucket_flatten = \
            self._get_rp_bucket(x, height=height, width=width)

        if self.transposed:
            return self.forward_rpe_transpose(x, rp_bucket)
        return self.forward_rpe_no_transpose(x, rp_bucket)

    def _get_rp_bucket(self, x, height=None, width=None):
        """Get relative position encoding buckets IDs corresponding the input shape

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor; its third dimension is the sequence length L,
            equal to height * width (+ the number of extra tokens).
        height: int or None
            [Optional] The height of the input
            If not defined, height = floor(sqrt(L))
        width: int or None
            [Optional] The width of the input
            If not defined, width = floor(sqrt(L))

        Returns
        -------
        rp_bucket: torch.Tensor
            relative position encoding buckets IDs
            The shape is (L, L)
        _ctx_rp_bucket_flatten: torch.Tensor or None
            It is a private tensor for efficient computation.
        """
        B, H, L, D = x.shape
        device = x.device
        # Default each missing side independently. Previously a given
        # `height` with `width=None` raised TypeError, and a given `width`
        # was discarded whenever `height` was None.
        side = int(math.sqrt(L))
        if height is None:
            height = side
        if width is None:
            width = side
        # L is part of the key: the number of extra tokens (L - height * width)
        # changes the bucket layout even when height and width are unchanged.
        key = (height, width, L, device)

        # reuse the buffer while the spatial shape, length and device are unchanged
        if self._rp_bucket_buf[0] == key:
            return self._rp_bucket_buf[1:3]

        skip = L - height * width
        config = self.rpe_config
        if RPEIndexFunction is not None and self.mode == 'contextual' and self.transposed:
            # RPEIndexFunction uses int32 index.
            dtype = torch.int32
        else:
            dtype = torch.long
        rp_bucket, num_buckets = get_bucket_ids_2d(method=self.method,
                                                   height=height, width=width,
                                                   skip=skip, alpha=config.alpha,
                                                   beta=config.beta, gamma=config.gamma,
                                                   dtype=dtype, device=device)
        assert num_buckets == self.num_buckets

        # transposed contextual
        _ctx_rp_bucket_flatten = None
        if self.mode == 'contextual' and self.transposed:
            if RPEIndexFunction is None:
                # flat index into a (L * num_buckets) lookup table, used by the
                # pure PyTorch fallback in `forward_rpe_transpose`
                offset = torch.arange(0, L * self.num_buckets, self.num_buckets,
                                      dtype=rp_bucket.dtype, device=rp_bucket.device).view(-1, 1)
                _ctx_rp_bucket_flatten = (rp_bucket + offset).flatten()
        self._rp_bucket_buf = (key, rp_bucket, _ctx_rp_bucket_flatten)
        return rp_bucket, _ctx_rp_bucket_flatten

    def forward_rpe_transpose(self, x, rp_bucket):
        """Forward function for iRPE (transposed version)
        This version is utilized by RPE on Query or Key

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor whose shape is (B, H, L, head_dim),
            where B is batch size,
                  H is the number of heads,
                  L is the sequence length,
                  head_dim is the dimension of each head
        rp_bucket: torch.Tensor
            relative position encoding buckets IDs
            The shape is (L, L)

        Weights
        -------
        lookup_table_bias: torch.Tensor
            The shape is (num_heads, num_buckets)

        or

        lookup_table_weight: torch.Tensor
            The shape is (num_heads, head_dim, num_buckets)

        Returns
        -------
        output: torch.Tensor or None
            Relative position encoding on queries or keys.
            The shape is (1, num_heads, L, L) in 'bias' mode and
            (B, H, L, L) in 'contextual' mode.
            None is returned when `mode` is None.
        """

        B = len(x)  # batch_size
        L_query, L_key = rp_bucket.shape
        if self.mode == 'bias':
            return self.lookup_table_bias[:, rp_bucket.flatten()]. \
                view(1, self.num_heads, L_query, L_key)

        elif self.mode == 'contextual':
            # ret[b, h, i, j] = lookup_table[b, h, i, rp_bucket[i, j]]
            #
            # computational cost
            # ------------------
            # matmul: B * H * L_query * head_dim * num_buckets
            # index: L_query + L_query * L_key + B * H * L_query * L_key
            # total: O(B * H * L_query * (head_dim * num_buckets + L_key))
            lookup_table = torch.matmul(
                x.transpose(0, 1).reshape(-1, B * L_query, self.head_dim),
                self.lookup_table_weight). \
                view(-1, B, L_query, self.num_buckets).transpose(0, 1)

            if RPEIndexFunction is not None:
                return RPEIndexFunction.apply(lookup_table, rp_bucket)
            else:
                return lookup_table.flatten(2)[:, :, self._ctx_rp_bucket_flatten]. \
                    view(B, -1, L_query, L_key)

    def forward_rpe_no_transpose(self, x, rp_bucket):
        """Forward function for iRPE (non-transposed version)
        This version is utilized by RPE on Value.

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor whose shape is (B, H, L_query, L_key),
            where B is batch size and H is the number of heads.
        rp_bucket: torch.Tensor
            relative position encoding buckets IDs
            The shape is (L_query, L_key)

        Weights
        -------
        lookup_table_weight: torch.Tensor
            The shape is (num_heads, num_buckets, head_dim)

        Returns
        -------
        output: torch.Tensor
            Relative position encoding on values.
            The shape is (B, H, L_query, head_dim).
        """

        B = len(x)  # batch_size
        L_query, L_key = rp_bucket.shape
        assert self.mode == 'contextual', ("Only support contextual "
                                           "version in non-transposed version")
        weight = self.lookup_table_weight[:, rp_bucket.flatten()]. \
            view(self.num_heads, L_query, L_key, self.head_dim)
        # (H, L_query, B, L_key) @ (H, L_query, L_key, D) = (H, L_query, B, D)
        # -> (B, H, L_query, D)
        return torch.matmul(x.permute(1, 2, 0, 3), weight).permute(2, 0, 1, 3)

    def __repr__(self):
        return ('iRPE(head_dim={rpe.head_dim}, num_heads={rpe.num_heads}, '
                'mode="{rpe.mode}", method={rpe.method}, transposed={rpe.transposed}, '
                'num_buckets={rpe.num_buckets}, initializer={rpe.initializer}, '
                'rpe_config={rpe.rpe_config})').format(rpe=self)
def get_single_rpe_config(ratio=1.9,
                          method=METHOD.PRODUCT,
                          mode='contextual',
                          shared_head=True,
                          skip=0):
    """Build the config for one relative position encoding.

    Parameters
    ----------
    ratio: float
        Scales the piecewise-index coefficients and therefore the number
        of buckets.
    method: METHOD
        The method ID of image relative position encoding.
    mode: str or None
        The mode of image relative position encoding.
        Choices: [None, 'bias', 'contextual']
    shared_head: bool
        Whether to share weight among different heads.
    skip: int
        The number of extra tokens (e.g. a class token) placed before the
        spatial tokens. Any positive value adds exactly one extra bucket.

    Returns
    -------
    config: RPEConfig
        An attribute dict with `shared_head`, `mode`, `method`, `alpha`,
        `beta`, `gamma` and `num_buckets`.
    """
    # coefficients of the piecewise index function
    alpha = 1 * ratio
    beta = 2 * ratio
    gamma = 8 * ratio

    # one extra bucket is reserved for the `skip` tokens (e.g. class token)
    extra_buckets = 1 if skip > 0 else 0

    config = edict()
    config.shared_head = shared_head
    config.mode = mode
    config.method = method
    config.alpha = alpha
    config.beta = beta
    config.gamma = gamma
    config.num_buckets = get_num_buckets(method, alpha, beta, gamma)
    config.num_buckets += extra_buckets
    return config
def build_rpe(config, head_dim, num_heads):
    """Build the iRPE modules for queries, keys and values.

    Parameters
    ----------
    config: RPEConfigs or None
        `config.rpe_q`, `config.rpe_k` and `config.rpe_v` hold the
        per-target configs; each may be None when RPE is not used there.
    head_dim: int
        The dimension for each head.
    num_heads: int
        The number of parallel attention heads.

    Returns
    -------
    modules: three entries, ordered [queries, keys, values]
        A tuple of three None values when `config` is None, otherwise a
        list in which an entry is None if its config is None.
    """
    if config is None:
        return None, None, None

    # RPE on queries and keys works on transposed input, RPE on values does not.
    targets = (
        (config.rpe_q, True),
        (config.rpe_k, True),
        (config.rpe_v, False),
    )

    modules = []
    for rpe, transposed in targets:
        if rpe is None:
            modules.append(None)
            continue

        if rpe.method != METHOD.CROSS:
            rpe_cls = iRPE
        else:
            rpe_cls = iRPE_Cross

        # a shared lookup table is modelled as a single head
        table_heads = 1 if rpe.shared_head else num_heads
        modules.append(rpe_cls(
            head_dim=head_dim,
            num_heads=table_heads,
            mode=rpe.mode,
            method=rpe.method,
            transposed=transposed,
            num_buckets=rpe.num_buckets,
            rpe_config=rpe,
        ))
    return modules
tensor"); + AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + const index_t B = input.size(0); + const index_t H = input.size(1); + const index_t num_buckets = input.size(3); + const index_t L_query = index.size(0); + const index_t L_key = index.size(1); + const index_t L_qk = L_query * L_key; + at::Tensor Y = at::empty({B, H, L_query, L_key}, input.options()); + auto input_ = input.contiguous(); + auto index_ = index.contiguous(); + const index_t grain_size = 3000; + const index_t numel = Y.numel(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "rpe_index_forward_cpu", [&] { + const scalar_t *p_input = input_.data_ptr(); + const index_t *p_index = index_.data_ptr(); + scalar_t *p_Y = Y.data_ptr(); + at::parallel_for(0, numel, grain_size, [&](index_t begin, index_t end) { + /* + // we optimize the following function to + // reduce the number of operators, namely divide and multiply. 
// Adds `val` to `*address` inside an OpenMP critical section and returns the
// updated value.
//
// The template parameter list is written out explicitly (the function is
// generic over the dispatched scalar type), and the value to return is read
// inside the same critical section as the update. `#pragma omp critical`
// guards only the single statement or block that follows it, so reading
// `*address` afterwards could observe another thread's later update.
template <typename scalar_t>
inline scalar_t cpuAtomicAdd(scalar_t *address, const scalar_t val) {
  scalar_t updated;
#pragma omp critical
  {
    *address += val;
    updated = *address;
  }
  return updated;
}
// Version string of the compiled extension. rpe_index.py compares it with its
// own EXPECTED_VERSION at import time, so both must be bumped together.
std::string version() {
  return "1.2.0";
}

// Python bindings of the extension module: the version query plus the CPU
// forward/backward index ops, and the GPU ones when built with WITH_CUDA.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("version", &version, "The version of the package `rpe_index_cpp`");
  m.def("forward_cpu", &rpe_index_forward_cpu, "2D RPE Index Forward (CPU)");
  m.def("backward_cpu", &rpe_index_backward_cpu, "2D RPE Index Backward (CPU)");

#if defined(WITH_CUDA)
  // Declarations only: the definitions live in rpe_index_cuda.cu, which
  // setup.py adds to the build (together with WITH_CUDA) when CUDA is
  // available at build time.
  at::Tensor rpe_index_forward_gpu(torch::Tensor input, torch::Tensor index);
  void rpe_index_backward_gpu(torch::Tensor grad_input,
                              torch::Tensor grad_output, torch::Tensor index);
  m.def("forward_gpu", &rpe_index_forward_gpu, "2D RPE Index Forward (GPU)");
  m.def("backward_gpu", &rpe_index_backward_gpu, "2D RPE Index Backward (GPU)");
#endif
}
class RPEIndexFunction(torch.autograd.Function):
    '''Y[b, h, i, j] = input[b, h, i, index[i, j]]'''

    @staticmethod
    def forward(ctx, input, index):
        '''
        Y[b, h, i, j] = input[b, h, i, index[i, j]]

        Parameters
        ----------
        input: torch.Tensor, floating point
            The shape is (B, H, L_query, num_buckets)
        index: torch.Tensor, int32
            The shape is (L_query, L_key)

        where B is the batch size, and H is the number of attention heads.

        Returns
        -------
        Y: torch.Tensor, same dtype as `input`
            The shape is (B, H, L_query, L_key)
        '''
        # The GPU kernel asserts a contiguous index; normalising here makes
        # the CPU and GPU paths accept the same inputs. No copy is made when
        # the index is already contiguous.
        index = index.contiguous()
        ctx.save_for_backward(index)
        ctx.input_shape = input.shape
        if input.device.type == 'cpu':
            forward_fn = rpe_index_cpp.forward_cpu
        else:
            forward_fn = rpe_index_cpp.forward_gpu
        return forward_fn(input, index)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        - Inputs
            grad_output: (B, H, L_query, L_key)
        - Outputs
            grad_input: (B, H, L_query, num_buckets), or None when the
            input does not require a gradient. `index` never gets one.
        '''
        if not ctx.needs_input_grad[0]:
            return None, None

        index = ctx.saved_tensors[0]
        # The extension accumulates into this buffer in place, so it has to
        # start out zeroed.
        grad_input = grad_output.new_zeros(ctx.input_shape)
        if grad_output.device.type == 'cpu':
            backward_fn = rpe_index_cpp.backward_cpu
        else:
            backward_fn = rpe_index_cpp.backward_gpu
        backward_fn(grad_input, grad_output, index)
        return grad_input, None
const int HIP_MAX_GRID_NUM = 65535;
const int HIP_MAX_NUM_THREADS = 512;

// Threads per block for `n` work items: `n` rounded up to a multiple of 32
// and clamped to [32, HIP_MAX_NUM_THREADS]. The lower clamp keeps the result
// non-zero for n == 0, which would otherwise give an invalid launch
// configuration and a division by zero in HIP_GET_BLOCKS.
inline int HIP_GET_NUM_THREADS(const int n) {
  const int rounded = ((n + 31) / 32) * 32;
  return std::max(32, std::min(HIP_MAX_NUM_THREADS, rounded));
}

// Blocks needed to cover `n` work items with `num_threads` threads each,
// clamped to [1, HIP_MAX_GRID_NUM].
//
// The cap is applied to the block count itself. The previous form,
// min(HIP_MAX_GRID_NUM, n + num_threads - 1) / num_threads, capped the
// numerator instead and so never returned more than
// HIP_MAX_GRID_NUM / num_threads blocks. Kernels iterate with
// CUDA_KERNEL_LOOP, so any block count >= 1 produces the same results.
inline int HIP_GET_BLOCKS(const int n, const int num_threads) {
  const int threads = std::max(1, num_threads);
  const int needed = (n + threads - 1) / threads;
  return std::max(1, std::min(HIP_MAX_GRID_NUM, needed));
}
// Gathers Y[b, h, i, j] = input[b, h, i, index[i, j]] on the GPU.
//
// - Inputs
//     input: floating point or half, (B, H, L_query, num_buckets);
//            may be non-contiguous, its strides are passed to the kernel
//     index: index_t (int32), (L_query, L_key), must be contiguous
// - Outputs
//     Y: same dtype as input, (B, H, L_query, L_key)
at::Tensor rpe_index_forward_gpu(torch::Tensor input, torch::Tensor index) {
  AT_ASSERTM(input.device().is_cuda(), "input must be a GPU tensor");
  AT_ASSERTM(index.device().is_cuda(), "index must be a GPU tensor");
  AT_ASSERTM(input.ndimension() == 4, "input must be a 4D tensor");
  AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor");
  AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type");
  AT_ASSERTM(index.is_contiguous(), "index should be contiguous");
  const index_t B = input.size(0);
  const index_t H = input.size(1);
  const index_t num_buckets = input.size(3);
  const index_t L_query = index.size(0);
  const index_t L_key = index.size(1);
  const index_t L_qk = L_query * L_key;
  at::Tensor Y = at::empty({B, H, L_query, L_key}, input.options());
  const index_t numel = Y.numel();
  // Nothing to gather; also avoids computing a launch configuration for
  // zero elements.
  if (numel == 0) {
    return Y;
  }
  const at::IntArrayRef strides = input.strides();

  const int threadsPerBlock = HIP_GET_NUM_THREADS(numel);
  const int blocks = HIP_GET_BLOCKS(numel, threadsPerBlock);

  // Launch on the input's device and on the current stream of that device.
  at::cuda::CUDAGuard device_guard(input.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      input.scalar_type(), "rpe_index_forward_gpu", [&] {
        const scalar_t *p_input = input.data_ptr<scalar_t>();
        const index_t *p_index = index.data_ptr<index_t>();
        scalar_t *p_Y = Y.data_ptr<scalar_t>();
        rpe_index_forward_gpu_kernel<<<blocks, threadsPerBlock, 0, stream>>>(
            numel, p_Y, p_input, p_index, num_buckets, H, L_query, L_key, L_qk,
            strides[0], strides[1], strides[2], strides[3]);
      });
  return Y;
}
class ViTiRPEModel(BaseModel):
    """
    Face recognition Vision Transformer with image Relative Position
    Encoding (ViT-iRPE).

    ```
    @inproceedings{wu2021rethinking,
      title={Rethinking and improving relative position encoding for vision transformer},
      author={Wu, Kan and Peng, Houwen and Chen, Minghao and Fu, Jianlong and Chao, Hongyang},
      booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
      pages={10033--10041},
      year={2021}
    }
    ```
    """

    def __init__(self, net, config):
        super().__init__(config)
        self.net = net

    @classmethod
    def from_config(cls, config):
        """Build the backbone selected by `config.name` ('small' or 'base').

        The returned model is put in eval mode. Any other name raises
        NotImplementedError.
        """
        # The two variants differ only in depth, MLP ratio and head count.
        if config.name == 'small':
            variant = dict(depth=12, mlp_ratio=5, num_heads=8)
        elif config.name == 'base':
            variant = dict(depth=24, mlp_ratio=3, num_heads=16)
        else:
            raise NotImplementedError

        net = VisionTransformerWithiRPE(
            img_size=112,
            patch_size=8,
            num_classes=config.output_dim,
            embed_dim=512,
            drop_path_rate=0.1,
            norm_layer="ln",
            mask_ratio=config.mask_ratio,
            rpe_config=config.rpe_config,
            **variant,
        )

        model = cls(net, config)
        model.eval()
        return model

    def forward(self, x, *args, **kwargs):
        """Run the backbone, reversing the channel axis first if requested."""
        inputs = x.flip(1) if self.input_color_flip else x
        return self.net(inputs, *args, **kwargs)

    def make_train_transform(self):
        """Training transform: to tensor, then normalise with mean/std 0.5."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def make_test_transform(self):
        """Test transform: to tensor, then normalise with mean/std 0.5."""
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])


def load_model(model_config):
    """Create a `ViTiRPEModel` from `model_config`."""
    return ViTiRPEModel.from_config(model_config)
class Mlp(nn.Module):
    """Two-layer perceptron: fc1 -> activation -> dropout -> fc2 -> dropout.

    `hidden_features` and `out_features` fall back to `in_features` when
    they are not given. The same dropout module is applied after both
    linear layers.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
        super().__init__()
        out_dim = out_features or in_features
        hidden_dim = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class VITBatchNorm(nn.Module):
    """Thin wrapper that applies `nn.BatchNorm1d(num_features)` to its input."""

    def __init__(self, num_features):
        super().__init__()
        self.num_features = num_features
        self.bn = nn.BatchNorm1d(num_features=num_features)

    def forward(self, x):
        normalized = self.bn(x)
        return normalized
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each behind a residual.

    Parameters
    ----------
    dim: embedding dimension of the tokens.
    num_heads: number of attention heads.
    num_patches: number of tokens; only used as the feature count of the
        batch-norm variant (`norm_layer="bn"`).
    mlp_ratio: hidden size of the MLP as a multiple of `dim`.
    qkv_bias, qk_scale, attn_drop, rpe_config: forwarded to `Attention`.
    drop: dropout rate of the attention projection and the MLP.
    drop_path: stochastic-depth rate; 0 disables it.
    act_layer: activation class used inside the MLP.
    norm_layer: "ln" for LayerNorm or "bn" for `VITBatchNorm`.
    patch_n: token count used only for the `extra_gflops` estimate.

    Raises
    ------
    ValueError
        If `norm_layer` is neither "ln" nor "bn". Without this check the
        block was constructed without `norm1`/`norm2` and only failed later,
        in `forward`, with an AttributeError.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int,
                 num_patches: int,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 qk_scale: Optional[float] = None,
                 drop: float = 0.,
                 attn_drop: float = 0.,
                 drop_path: float = 0.,
                 act_layer: Callable = nn.ReLU6,
                 norm_layer: str = "ln",
                 patch_n: int = 144,
                 rpe_config=None):
        super().__init__()

        if norm_layer == "bn":
            self.norm1 = VITBatchNorm(num_features=num_patches)
            self.norm2 = VITBatchNorm(num_features=num_patches)
        elif norm_layer == "ln":
            self.norm1 = nn.LayerNorm(dim)
            self.norm2 = nn.LayerNorm(dim)
        else:
            raise ValueError(
                f"Unsupported norm_layer: {norm_layer!r} (expected 'ln' or 'bn')")

        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, rpe_config=rpe_config)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)
        # Cost estimate that the enclosing model sums over its blocks.
        self.extra_gflops = (num_heads * patch_n * (dim // num_heads) * patch_n * 2) / (1000 ** 3)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
+ x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class VisionTransformerWithiRPE(nn.Module): + + def __init__(self, + img_size: int = 112, + patch_size: int = 16, + in_channels: int = 3, + num_classes: int = 1000, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_scale: Optional[None] = None, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0., + num_patches: Optional[int] = None, + norm_layer: str = "ln", + mask_ratio = 0.1, + using_checkpoint = False, + rpe_config=None, + ): + super().__init__() + self.num_classes = num_classes + # num_features for consistency with other models + self.num_features = self.embed_dim = embed_dim + + if num_patches is not None: + self.patch_embed = nn.Identity() + else: + self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + self.mask_ratio = mask_ratio + self.using_checkpoint = using_checkpoint + + self.num_patches = num_patches + + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + patch_n = (img_size//patch_size)**2 + self.blocks = nn.ModuleList( + [ + Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + num_patches=num_patches, patch_n=patch_n, rpe_config=rpe_config) + for i in range(depth)] + ) + self.extra_gflops = 0.0 + for _block in self.blocks: + self.extra_gflops += _block.extra_gflops + + if norm_layer == "ln": + self.norm = nn.LayerNorm(embed_dim) + elif norm_layer == "bn": + self.norm = VITBatchNorm(self.num_patches) + + # features head + self.feature = nn.Sequential( + nn.Linear(in_features=embed_dim 
* num_patches, out_features=embed_dim, bias=False), + nn.BatchNorm1d(num_features=embed_dim, eps=2e-5), + nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False), + nn.BatchNorm1d(num_features=num_classes, eps=2e-5) + ) + + if self.mask_ratio == 0: + pass + else: + self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + torch.nn.init.normal_(self.mask_token, std=.02) + trunc_normal_(self.pos_embed, std=.02) + # trunc_normal_(self.cls_token, std=.02) + self.apply(self._init_weights) + + self.num_heads = num_heads + self.depth = depth + + self.rpe_config = rpe_config + + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def random_masking(self, x, mask_ratio=0.1): + N, L, D = x.size() # batch, length, dim + len_keep = int(L * (1 - mask_ratio)) + + noise = torch.rand(N, L, device=x.device) # noise in [0, 1] + + # sort noise for each sample + # ascend: small is keep, large is remove + ids_shuffle = torch.argsort(noise, dim=1) + ids_restore = torch.argsort(ids_shuffle, dim=1) + + # keep the first subset + ids_keep = ids_shuffle[:, :len_keep] + index = ids_keep.unsqueeze(-1).repeat(1, 1, D) + x_masked = torch.gather(x, dim=1, index=index) + + return x_masked, index, ids_restore + + def forward_features(self, x): + B = x.shape[0] + x = self.patch_embed(x) + x = x + self.pos_embed + x = self.pos_drop(x) + + if self.training and self.mask_ratio > 0: + x, _, ids_restore = self.random_masking(x) + + for block_idx, func in enumerate(self.blocks): + if self.using_checkpoint and self.training: + from torch.utils.checkpoint import checkpoint + x = checkpoint(func, x) + else: + x = func(x) + x 
@torch.no_grad()
def piecewise_index(relative_position, alpha, beta, gamma, dtype):
    """Map relative positions to bucket indices with a piecewise function.

    Positions with `|x| <= alpha` keep their (rounded) value. Larger
    positions are compressed logarithmically:
    `sign(x) * min(beta, round(alpha + log(|x| / alpha) / log(gamma / alpha) * (beta - alpha)))`.

    Parameters
    ----------
    relative_position: torch.Tensor, dtype: long or float
        The relative positions to convert.
    alpha, beta, gamma: float
        The coefficients of the piecewise index function.
    dtype: torch.dtype
        Target dtype of the compressed values (and of the whole result when
        the input is float32/float64).

    Returns
    -------
    idx: torch.Tensor
        Same shape as `relative_position`. Compressed entries are clipped
        to at most `beta` in magnitude.
    """
    magnitude = relative_position.abs()
    # Entries outside the linear region (anything that is not |x| <= alpha).
    beyond_alpha = ~(magnitude <= alpha)

    signs = torch.sign(relative_position[beyond_alpha])
    log_ratio = torch.log(magnitude[beyond_alpha] / alpha) / math.log(gamma / alpha)
    compressed = (alpha + log_ratio * (beta - alpha)).round().clip(max=beta)
    far_values = (signs * compressed).to(dtype)

    if relative_position.dtype in (torch.float32, torch.float64):
        # Floating inputs: round(x) covers the |x| <= alpha region.
        idx = relative_position.round().to(dtype)
    else:
        # Non-float inputs are copied as-is and keep their own dtype.
        idx = relative_position.clone()

    # Overwrite the entries with |x| > alpha by their compressed value.
    idx[beyond_alpha] = far_values
    return idx
@torch.no_grad()
def _rp_2d_cross_rows(diff, **kwargs):
    """2D RPE index for the row direction of the Cross method.

    Parameters
    ----------
    diff: torch.Tensor
        Pairwise 2D offsets; component 0 of the last axis is used here.
        The original docstring gives the shape as (L, L, 2) with
        (row_offset, col_offset) in the last axis.
    **kwargs
        Forwarded unchanged to `piecewise_index`
        (`alpha`, `beta`, `gamma`, `dtype`).

    Returns
    -------
    index: torch.Tensor
        The result of `piecewise_index` applied to the first offset
        component, one value per (i, j) pair.
    """
    # Only the first offset component takes part in this direction.
    return piecewise_index(diff[:, :, 0], **kwargs)
def get_absolute_positions(height, width, dtype, device):
    '''Build the grid of absolute (row, col) positions of a feature map.

    Positions are zero-based. For height = 3, width = 3:
        rows:        cols:
        0 0 0        0 1 2
        1 1 1        0 1 2
        2 2 2        0 1 2

    The result is stack([rows, cols], 2).

    Parameters
    ----------
    height, width: int
        The height and width of the feature map
    dtype: torch.dtype
        the data type of returned value
    device: torch.device
        the device of returned value

    Return
    ------
    2D absolute positions: torch.Tensor
        The shape is (height, width, 2),
        where 2 represents a 2D position (row, col).
    '''
    row_ids = torch.arange(height, dtype=dtype, device=device)
    col_ids = torch.arange(width, dtype=dtype, device=device)
    # Broadcast each 1D index over the other axis; stack materialises the result.
    row_grid = row_ids.unsqueeze(1).expand(height, width)
    col_grid = col_ids.unsqueeze(0).expand(height, width)
    return torch.stack((row_grid, col_grid), dim=2)
@torch.no_grad()
def get_bucket_ids_2d_without_skip(method, height, width,
                                   alpha, beta, gamma,
                                   dtype=torch.long, device=torch.device('cpu')):
    """Get bucket IDs for image relative position encodings without skip token

    Results are memoised in `BUCKET_IDS_BUF` per
    `(method, alpha, beta, gamma, dtype, device)`. The cached table covers
    the largest height and width requested so far, and smaller requests are
    served by slicing it.

    Parameters
    ----------
    method: METHOD
        The method ID of image relative position encoding.
    height, width: int
        The height and width of the feature map.
        The sequence length is equal to `height * width`.
    alpha, beta, gamma: float
        The coefficients of piecewise index function.
    dtype: torch.dtype
        the data type of returned `bucket_ids`
    device: torch.device
        the device of returned `bucket_ids`

    Returns
    -------
    bucket_ids: torch.Tensor
        The bucket IDs which index to corresponding encodings.
        The shape of `bucket_ids` is (L, L),
        where `L = height * width`.
    num_buckets: int
        The number of buckets (no `skip` token is counted here).
    L: int
        The sequence length

    Raises
    ------
    NotImplementedError
        If `method` has no entry in `_METHOD_FUNC`.
    """
    key = (method, alpha, beta, gamma, dtype, device)
    cached = BUCKET_IDS_BUF.get(key)
    # cached = (bucket_ids, num_buckets, cached_height, cached_width)
    if cached is None or cached[2] < height or cached[3] < width:
        if cached is None:
            max_height, max_width = height, width
        else:
            max_height = max(cached[2], height)
            max_width = max(cached[3], width)

        # relative position encoding mapping function
        func = _METHOD_FUNC.get(method)
        if func is None:
            raise NotImplementedError(
                f"[Error] The method ID {method} does not exist.")
        pos = get_absolute_positions(max_height, max_width, dtype, device)

        # offsets between every pair of 2D positions: shape (max_L, max_L, 2)
        max_L = max_height * max_width
        diff = pos.view((max_L, 1, 2)) - pos.view((1, max_L, 2))

        # bucket_ids: shape (max_L, max_L)
        bucket_ids = func(diff, alpha=alpha, beta=beta,
                          gamma=gamma, dtype=dtype)
        if method != METHOD.PRODUCT:
            # shift the signed ids to start from zero
            bucket_ids += int(beta)
        bucket_ids = bucket_ids.view(
            max_height, max_width, max_height, max_width)

        num_buckets = get_num_buckets(method, alpha, beta, gamma)
        # Record the extent that was actually computed (not just the extent
        # requested), so later requests that fit in the table hit the cache.
        cached = (bucket_ids, num_buckets, max_height, max_width)
        BUCKET_IDS_BUF[key] = cached

    L = height * width
    bucket_ids = cached[0][:height, :width, :height, :width].reshape(L, L)
    num_buckets = cached[1]

    return bucket_ids, num_buckets, L
class iRPE(nn.Module):
    """The implementation of image relative position encoding (excluding Cross method).

    Parameters
    ----------
    head_dim: int
        The dimension for each head.
    num_heads: int
        The number of parallel attention heads.
    mode: str or None
        The mode of image relative position encoding.
        Choices: [None, 'bias', 'contextual']
    method: METHOD
        The method ID of image relative position encoding.
    transposed: bool
        Whether to transpose the input feature.
        For iRPE on queries or keys, transposed should be `True`.
        For iRPE on values, transposed should be `False`.
    num_buckets: int
        The number of buckets, which store encodings.
    initializer: None or an inplace function
        [Optional] Called on `lookup_table_bias` right after it is created
        as zeros. A no-op by default.
    rpe_config: RPEConfig
        The config generated by the function `get_single_rpe_config`;
        its `alpha`, `beta` and `gamma` are read when bucket IDs are built.
    """
    # Cache of the most recent bucket index:
    # (key, rp_bucket, _ctx_rp_bucket_flatten)
    _rp_bucket_buf = (None, None, None)

    def __init__(self, head_dim, num_heads=8,
                 mode=None, method=None,
                 transposed=True, num_buckets=None,
                 initializer=None, rpe_config=None):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = head_dim

        # relative position
        assert mode in [None, 'bias', 'contextual']
        self.mode = mode

        assert method is not None, 'method should be a METHOD ID rather than None'
        self.method = method

        self.transposed = transposed
        self.num_buckets = num_buckets

        if initializer is None:
            def initializer(x): return None
        self.initializer = initializer

        self.reset_parameters()

        self.rpe_config = rpe_config

    @torch.no_grad()
    def reset_parameters(self):
        """Create the parameters owned by this module.

        Only the transposed 'bias' mode owns a parameter
        (`lookup_table_bias`). The transposed 'contextual' mode creates
        nothing here, and both non-transposed modes raise.
        """
        if self.transposed:
            if self.mode == 'bias':
                self.lookup_table_bias = nn.Parameter(
                    torch.zeros(self.num_heads, self.num_buckets))
                self.initializer(self.lookup_table_bias)
            elif self.mode == 'contextual':
                # shared and initialized from vit
                pass
        else:
            if self.mode == 'bias':
                raise NotImplementedError(
                    "[Error] Bias non-transposed RPE does not exist.")
            elif self.mode == 'contextual':
                raise ValueError('may not work, check')

    def forward(self, x, height=None, width=None):
        """forward function for iRPE.

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor whose shape is (B, H, L, head_dim),
            where B is batch size,
                  H is the number of heads,
                  L is the sequence length,
                    equal to height * width (+1 if class token exists)
                  head_dim is the dimension of each head
        height, width: int or None
            [Optional] Spatial size of the input, forwarded to
            `_get_rp_bucket`.

        Returns
        -------
        rpe_encoding: torch.Tensor
            image Relative Position Encoding,
            whose shape is (B, H, L, L)
        """
        rp_bucket, self._ctx_rp_bucket_flatten = \
            self._get_rp_bucket(x, height=height, width=width)

        if self.transposed:
            return self.forward_rpe_transpose(x, rp_bucket)
        return self.forward_rpe_no_transpose(x, rp_bucket)

    def _get_rp_bucket(self, x, height=None, width=None):
        """Get relative position encoding buckets IDs corresponding the input shape

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor whose shape is (B, H, L, head_dim),
            where B is batch size,
                  H is the number of heads,
                  L is the sequence length,
                    equal to height * width (+1 if class token exists)
                  head_dim is the dimension of each head
        height: int or None
            [Optional] The height of the input
            If not defined, height = width = floor(sqrt(L))
        width: int or None
            [Optional] The width of the input

        Returns
        -------
        rp_bucket: torch.Tensor
            relative position encoding buckets IDs
            The shape is (L, L)
        _ctx_rp_bucket_flatten: torch.Tensor or None
            It is a private tensor for efficient computation.
        """
        _, _, L, _ = x.shape
        device = x.device
        if height is None:
            E = int(math.sqrt(L))
            height = width = E
        # The cached tensors depend on L as well as on the grid size
        # (skip = L - height * width, and the flatten offset spans L rows),
        # so L must be part of the key to avoid returning a stale index.
        key = (height, width, L, device)

        # reuse the buffer while the spatial shape, length and device are unchanged
        if self._rp_bucket_buf[0] == key:
            return self._rp_bucket_buf[1:3]

        skip = L - height * width
        config = self.rpe_config
        if RPEIndexFunction is not None and self.mode == 'contextual' and self.transposed:
            # RPEIndexFunction uses int32 index.
            dtype = torch.int32
        else:
            dtype = torch.long
        rp_bucket, num_buckets = get_bucket_ids_2d(method=self.method,
                                                   height=height, width=width,
                                                   skip=skip, alpha=config.alpha,
                                                   beta=config.beta, gamma=config.gamma,
                                                   dtype=dtype, device=device)
        assert num_buckets == self.num_buckets

        # transposed contextual
        _ctx_rp_bucket_flatten = None
        if self.mode == 'contextual' and self.transposed:
            if RPEIndexFunction is None:
                # Pure-PyTorch fallback: precompute flat indices into
                # x.flatten(2), one block of num_buckets entries per query row.
                offset = torch.arange(0, L * self.num_buckets, self.num_buckets,
                                      dtype=rp_bucket.dtype, device=rp_bucket.device).view(-1, 1)
                _ctx_rp_bucket_flatten = (rp_bucket + offset).flatten()
        self._rp_bucket_buf = (key, rp_bucket, _ctx_rp_bucket_flatten)
        return rp_bucket, _ctx_rp_bucket_flatten

    def forward_rpe_transpose(self, x, rp_bucket):
        """Forward function for iRPE (transposed version)
        This version is utilized by RPE on Query or Key

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor. In 'contextual' mode its last axis is indexed by
            bucket ID; in 'bias' mode only its batch length is read.
        rp_bucket: torch.Tensor
            relative position encoding buckets IDs
            The shape is (L, L)

        Weights
        -------
        lookup_table_bias: torch.Tensor
            The shape is (H or 1, num_buckets); used in 'bias' mode only.

        Returns
        -------
        output: torch.Tensor
            Relative position encoding on queries or keys.
            The shape is (1, H, L, L) in 'bias' mode and (B, H, L, L) in
            'contextual' mode. Nothing is returned for any other mode.
        """

        B = len(x)  # batch_size
        L_query, L_key = rp_bucket.shape
        if self.mode == 'bias':
            return self.lookup_table_bias[:, rp_bucket.flatten()].\
                view(1, self.num_heads, L_query, L_key)

        elif self.mode == 'contextual':
            # ret[b, h, i, j] = x[b, h, i, rp_bucket[i, j]]
            # which, on the flattened last two axes, reads
            # ret[b, h, i * L_key + j] = x[b, h, i * num_buckets + rp_bucket[i, j]]
            if RPEIndexFunction is not None:
                return RPEIndexFunction.apply(x, rp_bucket)
            else:
                return x.flatten(2)[:, :, self._ctx_rp_bucket_flatten].\
                    view(B, -1, L_query, L_key)

    def forward_rpe_no_transpose(self, x, rp_bucket):
        """Forward function for iRPE (non-transposed version)
        This version is utilized by RPE on Value.

        NOTE(review): `reset_parameters` never creates
        `lookup_table_weight` (it raises for every non-transposed mode), so
        this path looks unreachable as written - confirm before relying on it.

        Parameters
        ----------
        x: torch.Tensor
            Input Tensor; per the matmul below it is treated as
            (B, H, L_query, L_key).
        rp_bucket: torch.Tensor
            relative position encoding buckets IDs
            The shape is (L, L)

        Weights
        -------
        lookup_table_weight: torch.Tensor
            The shape is (H or 1, num_buckets, head_dim)

        Returns
        -------
        output: torch.Tensor
            Relative position encoding on values.
            The shape is (B, H, L, D),
            where D is the output dimension for each head.
        """

        L_query, L_key = rp_bucket.shape
        assert self.mode == 'contextual', \
            "Only support contextual version in non-transposed version"
        weight = self.lookup_table_weight[:, rp_bucket.flatten()].\
            view(self.num_heads, L_query, L_key, self.head_dim)
        # (H, L_query, B, L_key) @ (H, L_query, L_key, D) = (H, L_query, B, D)
        # -> (B, H, L_query, D)
        return torch.matmul(x.permute(1, 2, 0, 3), weight).permute(2, 0, 1, 3)

    def __repr__(self):
        return ('iRPE(head_dim={rpe.head_dim}, num_heads={rpe.num_heads}, '
                'mode="{rpe.mode}", method={rpe.method}, transposed={rpe.transposed}, '
                'num_buckets={rpe.num_buckets}, initializer={rpe.initializer}, '
                'rpe_config={rpe.rpe_config})').format(rpe=self)
def get_single_rpe_config(ratio=1.9,
                          method=METHOD.PRODUCT,
                          mode='contextual',
                          shared_head=True,
                          skip=0):
    """Get the config of single relative position encoding

    Parameters
    ----------
    ratio: float
        The ratio to control the number of buckets; the piecewise
        coefficients are `alpha = 1 * ratio`, `beta = 2 * ratio`,
        `gamma = 8 * ratio`.
    method: METHOD
        The method ID of image relative position encoding.
    mode: str or None
        The mode of image relative position encoding.
        Choices: [None, 'bias', 'contextual']
    shared_head: bool
        Whether to share weight among different heads.
    skip: int
        The number of skip token before spatial tokens.
        Any value above 0 adds exactly one extra bucket.

    Returns
    -------
    config: RPEConfig
        An `edict` with the fields `shared_head`, `mode`, `method`,
        `alpha`, `beta`, `gamma` and `num_buckets`.
    """
    # the coefficients of piecewise index function
    alpha, beta, gamma = 1 * ratio, 2 * ratio, 8 * ratio

    # number of buckets, plus one shared bucket for `skip` tokens (e.g. class token)
    num_buckets = get_num_buckets(method, alpha, beta, gamma)
    if skip > 0:
        num_buckets += 1

    config = edict()
    settings = (
        ('shared_head', shared_head),
        ('mode', mode),
        ('method', method),
        ('alpha', alpha),
        ('beta', beta),
        ('gamma', gamma),
        ('num_buckets', num_buckets),
    )
    for field, value in settings:
        setattr(config, field, value)
    return config
+ "q": RPE on queries + "k": RPE on keys + "v": RPE on values + "qk": RPE on queries and keys + "qkv": RPE on queries, keys and values + + Returns + ------- + config: RPEConfigs + config.rpe_q: the config of relative position encoding on queries + config.rpe_k: the config of relative position encoding on keys + config.rpe_v: the config of relative position encoding on values + """ + + # alias + if isinstance(method, str): + method_mapping = dict( + euc=METHOD.EUCLIDEAN, + quant=METHOD.QUANT, + cross=METHOD.CROSS, + product=METHOD.PRODUCT, + ) + method = method_mapping[method.lower()] + if mode == 'ctx': + mode = 'contextual' + config = edict() + # relative position encoding on queries, keys and values + kwargs = dict( + ratio=ratio, + method=method, + mode=mode, + shared_head=shared_head, + skip=skip, + ) + config.rpe_q = get_single_rpe_config(**kwargs) if 'q' in rpe_on else None + config.rpe_k = get_single_rpe_config(**kwargs) if 'k' in rpe_on else None + config.rpe_v = get_single_rpe_config(**kwargs) if 'v' in rpe_on else None + return config + + +def build_rpe(config, head_dim, num_heads): + """Build iRPE modules on queries, keys and values. + + Parameters + ---------- + config: RPEConfigs + config.rpe_q: the config of relative position encoding on queries + config.rpe_k: the config of relative position encoding on keys + config.rpe_v: the config of relative position encoding on values + None when RPE is not used. + head_dim: int + The dimension for each head. + num_heads: int + The number of parallel attention heads. + + Returns + ------- + modules: a list of nn.Module + The iRPE Modules on [queries, keys, values]. + None when RPE is not used. 
@torch.no_grad()
def make_rel_keypoints(keyponints, query):
    """Offsets from every patch centre to every keypoint.

    The sequence length `query.shape[1]` must be a perfect square
    `side * side`. Patch centres are the midpoints of `side` equal cells
    on [0, 1] along each axis.

    Parameters
    ----------
    keyponints: torch.Tensor
        Keypoint coordinates; `unsqueeze(-3)` is applied so that they
        broadcast against the (1, N, 1, 2) grid. Presumably (B, K, 2) with
        K = 5, per the original shape comments - confirm against the caller.
    query: torch.Tensor
        Only its `shape[1]`, `device` and `dtype` are used.

    Returns
    -------
    diff: torch.Tensor
        `grid - keypoints`, flattened from dimension 2 onward, i.e.
        (B, N, K * 2) for (B, K, 2) keypoints.
    """
    num_tokens = query.shape[1]
    side = int(math.sqrt(num_tokens))
    assert side == math.sqrt(num_tokens)

    # Cell edges on [0, 1]; the midpoint of consecutive edges is a patch centre.
    edges = torch.linspace(0, 1, side + 1, device=query.device, dtype=query.dtype)
    centres = (edges[:-1] + edges[1:]) / 2

    # Flattened patch index n = i * side + j maps to (centres[j], centres[i]):
    # the first coordinate varies fastest.
    fast = centres.repeat(side)
    slow = centres.repeat_interleave(side)
    grid = torch.stack((fast, slow), dim=-1)[None, :, None, :]  # 1 x N x 1 x 2

    offsets = grid - keyponints.unsqueeze(-3)
    return offsets.flatten(2)
# Import the compiled `rpe_ops` extension. If it is missing, make one
# best-effort attempt to build and install it from the bundled `setup.py`.
try:
    from .rpe_ops.rpe_index import RPEIndexFunction
except ImportError:
    try:
        # Run the installer inside the `rpe_ops` directory through the
        # subprocess `cwd` argument, so this process's working directory is
        # left untouched even when the build fails.
        rpe_ops_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'rpe_ops')
        subprocess.check_call(
            [sys.executable, 'setup.py', 'install', '--user'],
            cwd=rpe_ops_dir)
        GREEN_STR = "\033[92m{}\033[00m"
        print(GREEN_STR.format("\n[INFO] Successfully installed `rpe_ops`. Restart Application"),)
        # The freshly installed extension is only picked up by a new process.
        sys.exit()
    except (subprocess.CalledProcessError, OSError) as install_error:
        # OSError covers a missing `rpe_ops` directory or interpreter; the
        # install is best-effort, so warn instead of failing the import.
        RED_STR = "\033[91m{}\033[00m"
        warnings.warn(RED_STR.format("\n[WARNING] Failed to install `rpe_ops`. "
                                     "Please check the installation script."),)
    except ImportError as import_error:
        RED_STR = "\033[91m{}\033[00m"
        warnings.warn(RED_STR.format("\n[WARNING] The module `rpe_ops` is not built. "
                                     "For better training performance, please build `rpe_ops`."),)
AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + const index_t B = input.size(0); + const index_t H = input.size(1); + const index_t num_buckets = input.size(3); + const index_t L_query = index.size(0); + const index_t L_key = index.size(1); + const index_t L_qk = L_query * L_key; + at::Tensor Y = at::empty({B, H, L_query, L_key}, input.options()); + auto input_ = input.contiguous(); + auto index_ = index.contiguous(); + const index_t grain_size = 3000; + const index_t numel = Y.numel(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "rpe_index_forward_cpu", [&] { + const scalar_t *p_input = input_.data_ptr(); + const index_t *p_index = index_.data_ptr(); + scalar_t *p_Y = Y.data_ptr(); + at::parallel_for(0, numel, grain_size, [&](index_t begin, index_t end) { + /* + // we optimize the following function to + // reduce the number of operators, namely divide and multiply. + for (index_t i = begin; i < end; ++i) { + p_Y[i] = p_input[i / L_key * num_buckets + p_index[i % L_qk]]; + } + */ + + index_t aligned_begin = (begin + L_qk - 1) / L_qk * L_qk; + if (aligned_begin > end) aligned_begin = end; + index_t aligned_end = end / L_qk * L_qk; + for (index_t i = begin; i < aligned_begin; ++i) { + p_Y[i] = p_input[i / L_key * num_buckets + p_index[i % L_qk]]; + } + + // [aligned_begin, aligned_end) + // where aligned_begin % L_qk == 0, aligned_end % L_qk == 0 + index_t base = aligned_begin / L_key * num_buckets; + const index_t base_end = aligned_end / L_key * num_buckets; + index_t i = aligned_begin; + while (base < base_end) { + for (index_t q = 0, j = 0; q < L_query; ++q) { + for (index_t k = 0; k < L_key; ++k) { + p_Y[i++] = p_input[base + p_index[j++]]; + } + base += num_buckets; + } + } + + for (index_t i = aligned_end; i < end; ++i) { + p_Y[i] = p_input[i / L_key * num_buckets + p_index[i % L_qk]]; + } + }); + }); + return Y; +} + +template +inline scalar_t 
cpuAtomicAdd(scalar_t *address, const scalar_t val) { +#pragma omp critical + *address += val; + return *address; +} + +void rpe_index_backward_cpu(torch::Tensor grad_input, torch::Tensor grad_output, + torch::Tensor index) { + /* + - Inputs + grad_output: float32 (B, H, L_query, L_key) + index: index_t (L_query, L_key) + - Outputs + grad_input: float32 (B, H, L_query, num_buckets) + */ + AT_ASSERTM(grad_input.device().is_cpu(), "grad_input must be a CPU tensor"); + AT_ASSERTM(grad_output.device().is_cpu(), "grad_output must be a CPU tensor"); + AT_ASSERTM(index.device().is_cpu(), "grad_index must be a CPU tensor"); + AT_ASSERTM(grad_input.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(grad_output.ndimension() == 4, "input must be a 4D tensor"); + AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor"); + AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type"); + + const index_t num_buckets = grad_input.size(3); + const index_t L_query = index.size(0); + const index_t L_key = index.size(1); + const index_t L_qk = L_query * L_key; + + auto grad_input_ = grad_input.contiguous(); + auto grad_output_ = grad_output.contiguous(); + auto index_ = index.contiguous(); + + const index_t grain_size = 3000; + const index_t numel = grad_output.numel(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad_input.scalar_type(), "rpe_index_backward_atomic_cpu", [&] { + scalar_t *p_grad_input = grad_input_.data_ptr(); + const index_t *p_index = index_.data_ptr(); + const scalar_t *p_grad_output = grad_output_.data_ptr(); + at::parallel_for(0, numel, grain_size, [&](index_t begin, index_t end) { + for (index_t i = begin; i < end; ++i) { + const index_t input_i = i / L_key * num_buckets + p_index[i % L_qk]; + const scalar_t v = p_grad_output[i]; + cpuAtomicAdd(p_grad_input + input_i, v); + } + }); + }); +} + +std::string version() { + return "1.2.0"; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("version", &version, "The version of the 
class RPEIndexFunction(torch.autograd.Function):
    '''Y[b, h, i, j] = input[b, h, i, index[i, j]]

    Autograd wrapper around the compiled `rpe_index_cpp` kernels. The CPU or
    GPU kernel is chosen from the device of the tensor being processed.
    '''
    @staticmethod
    def forward(ctx, input, index):
        '''
        Y[b, h, i, j] = input[b, h, i, index[i, j]]

        Parameters
        ----------
        input: torch.Tensor, float32
            The shape is (B, H, L_query, num_buckets)
        index: torch.Tensor, int32
            The shape is (L_query, L_key)

        where B is the batch size, and H is the number of attention heads.

        Returns
        -------
        Y: torch.Tensor, float32
            The shape is (B, H, L_query, L_key)
        '''
        # The GPU kernel asserts a contiguous index while the CPU kernel
        # copies it internally; normalise here so both devices accept the
        # same inputs (no-op when the index is already contiguous).
        index = index.contiguous()
        ctx.save_for_backward(index)
        # Only the shape is needed to allocate the gradient in backward.
        ctx.input_shape = input.shape
        forward_fn = rpe_index_cpp.forward_cpu if \
            input.device.type == 'cpu' else rpe_index_cpp.forward_gpu
        return forward_fn(input, index)

    @staticmethod
    def backward(ctx, grad_output):
        '''
        - Inputs
            grad_output: float32 (B, H, L_query, L_key)
        - Outputs
            grad_input: float32 (B, H, L_query, num_buckets)

        The integer `index` argument never receives a gradient.
        '''
        if not ctx.needs_input_grad[0]:
            return None, None
        index = ctx.saved_tensors[0]
        # The kernels accumulate in place, so start from zeros.
        grad_input = grad_output.new_zeros(ctx.input_shape)
        backward_fn = rpe_index_cpp.backward_cpu if \
            grad_output.device.type == 'cpu' else rpe_index_cpp.backward_gpu
        backward_fn(grad_input, grad_output, index)
        return grad_input, None
// Upper bounds used when choosing a 1-D kernel launch configuration.
const int HIP_MAX_GRID_NUM = 65535;
const int HIP_MAX_NUM_THREADS = 512;

// Threads per block for `n` work items: `n` rounded up to a multiple of 32,
// clamped to [32, HIP_MAX_NUM_THREADS]. The lower clamp keeps the result
// usable as a divisor and as a launch dimension when `n` is 0.
inline int HIP_GET_NUM_THREADS(const int n) {
  if (n >= HIP_MAX_NUM_THREADS) {
    return HIP_MAX_NUM_THREADS;  // also avoids overflow of n + 31
  }
  return std::max(32, ((n + 31) / 32) * 32);
}

// Number of blocks needed to cover `n` work items with `num_threads` threads
// per block: ceil(n / num_threads), clamped to [1, HIP_MAX_GRID_NUM].
// The previous form divided AFTER clamping (min(MAX, n + t - 1) / t), which
// limited the grid to MAX / t blocks regardless of `n`.
inline int HIP_GET_BLOCKS(const int n, const int num_threads) {
  if (n <= 0 || num_threads <= 0) {
    return 1;  // degenerate request: one block, the kernel loop does nothing
  }
  // Written without n + num_threads - 1 so large n cannot overflow int.
  const int needed = n / num_threads + (n % num_threads != 0 ? 1 : 0);
  return std::min(HIP_MAX_GRID_NUM, needed);
}
void rpe_index_backward_gpu(torch::Tensor grad_input, torch::Tensor grad_output,
                            torch::Tensor index) {
  /*
    Backward pass of Y[b, h, q, k] = input[b, h, q, index[q, k]] on the GPU.

    Every element of grad_output is atomically added into
    grad_input[b, h, q, index[q, k]]. grad_input is updated IN PLACE and is
    expected to be zero-initialised by the caller.

    - Inputs
        grad_output: floating (B, H, L_query, L_key)
        index: index_t (L_query, L_key)
    - Outputs
        grad_input: floating (B, H, L_query, num_buckets), contiguous
  */
  AT_ASSERTM(grad_input.device().is_cuda(), "grad_input must be a GPU tensor");
  AT_ASSERTM(grad_output.device().is_cuda(),
             "grad_output must be a GPU tensor");
  AT_ASSERTM(index.device().is_cuda(), "index must be a GPU tensor");
  AT_ASSERTM(grad_input.ndimension() == 4, "grad_input must be a 4D tensor");
  AT_ASSERTM(grad_output.ndimension() == 4, "grad_output must be a 4D tensor");
  AT_ASSERTM(index.ndimension() == 2, "index must be a 2D tensor");
  AT_ASSERTM(index.scalar_type() == at::kInt, "index must be Int type");
  // The kernel writes through grad_input's own storage; .contiguous() on a
  // non-contiguous tensor would hand the kernel a temporary copy and the
  // gradient would be lost, so reject that case explicitly.
  AT_ASSERTM(grad_input.is_contiguous(), "grad_input must be contiguous");
  AT_ASSERTM(grad_input.scalar_type() == grad_output.scalar_type(),
             "grad_input and grad_output must have the same dtype");
  // The kernel reads p_index[i % (L_query * L_key)] with the sizes taken
  // from grad_output, so a smaller index tensor would be read out of bounds.
  AT_ASSERTM(grad_output.size(2) == index.size(0) &&
                 grad_output.size(3) == index.size(1),
             "index shape must match the last two dims of grad_output");

  const index_t num_buckets = grad_input.size(3);
  const index_t L_query = grad_output.size(2);
  const index_t L_key = grad_output.size(3);
  const index_t L_qk = L_query * L_key;

  const index_t numel = grad_output.numel();
  if (numel == 0) {
    return;  // nothing to do; also avoids a launch with zero threads
  }

  // Select the device before any allocation (.contiguous() may copy).
  at::cuda::CUDAGuard device_guard(grad_output.device());
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto grad_output_ = grad_output.contiguous();
  auto index_ = index.contiguous();

  const int threadsPerBlock = HIP_GET_NUM_THREADS(numel);
  const int blocks = HIP_GET_BLOCKS(numel, threadsPerBlock);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad_output.scalar_type(), "rpe_index_backward_gpu", [&] {
        scalar_t *p_grad_input = grad_input.data_ptr<scalar_t>();
        const index_t *p_index = index_.data_ptr<index_t>();
        const scalar_t *p_grad_output = grad_output_.data_ptr<scalar_t>();
        rpe_index_backward_gpu_kernel<scalar_t>
            <<<blocks, threadsPerBlock, 0, stream>>>(
                numel, p_grad_input, p_index, p_grad_output, num_buckets,
                L_key, L_qk);
      });
}
class ViTKPRPEModel(BaseModel):
    """
    Face recognition model: a Vision Transformer backbone that uses
    KeyPoint Relative Position Encoding (KP-RPE).

    ```
    @article{kim2024keypoint,
      title={KeyPoint Relative Position Encoding for Face Recognition},
      author={Kim, Minchul and Su, Yiyang and Liu, Feng and Jain, Anil and Liu, Xiaoming},
      journal={CVPR},
      year={2024}
    }
    ```
    """

    def __init__(self, net, config):
        super().__init__(config)
        self.net = net

    @classmethod
    def from_config(cls, config):
        """Build the backbone variant named by ``config.name`` ('small' or 'base').

        The returned model is put in eval mode. Any other name raises
        ``NotImplementedError``.
        """
        # The two variants differ only in depth, MLP ratio and head count.
        variants = {
            'small': dict(depth=12, mlp_ratio=5, num_heads=8),
            'base': dict(depth=24, mlp_ratio=3, num_heads=16),
        }
        if config.name not in variants:
            raise NotImplementedError

        backbone = VisionTransformerWithKPRPE(
            img_size=112,
            patch_size=8,
            num_classes=config.output_dim,
            embed_dim=512,
            drop_path_rate=0.1,
            norm_layer="ln",
            mask_ratio=config.mask_ratio,
            rpe_config=config.rpe_config,
            **variants[config.name],
        )
        wrapper = cls(backbone, config)
        wrapper.eval()
        return wrapper

    def forward(self, x, *args, **kwargs):
        """Run the backbone, reversing the channel axis first when
        ``self.input_color_flip`` is set."""
        inputs = x.flip(1) if self.input_color_flip else x
        return self.net(inputs, *args, **kwargs)

    def make_train_transform(self):
        """Return the training preprocessing: to-tensor, then normalise with
        mean 0.5 and std 0.5 per channel."""
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        return transforms.Compose(steps)

    def make_test_transform(self):
        """Return the evaluation preprocessing (same steps as training)."""
        steps = [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        return transforms.Compose(steps)
def make_kprpe_shared(rpe_config, depth, num_heads):
    """Create the zero-initialised keypoint projection used by KP-RPE.

    The projection maps ``2 * rpe_config.num_keypoints`` input features to a
    multiple of ``num_buckets``; the multiple depends on ``ctx_type``
    (per-layer, per-head, both, or neither). The ``*_v2`` types use a small
    MLP instead of a single linear layer. Only the output layer is zeroed.

    Returns:
        (keypoint_linear, num_buckets)

    Raises:
        ValueError: if ``rpe_config.ctx_type`` is not supported.
    """
    assert rpe_config.rpe_on == 'k'
    bucket_cfg = get_rpe_config(
        ratio=rpe_config.ratio,
        method=rpe_config.method,
        mode=rpe_config.mode,
        shared_head=rpe_config.shared_head,
        skip=0,
        rpe_on=rpe_config.rpe_on,
    )
    num_buckets = bucket_cfg['rpe_k']['num_buckets']

    # How many bucket tables each ctx_type needs.
    table_count = {
        'rel_keypoint': 1,
        'rel_keypoint_unshared': depth,
        'rel_keypoint_unshared_v2': depth,
        'rel_keypoint_splithead': num_heads,
        'rel_keypoint_splithead_unshared': num_heads * depth,
        'rel_keypoint_v2': 1,
        'keypoint': 1,
    }
    mlp_types = ('rel_keypoint_unshared_v2', 'rel_keypoint_v2')

    ctx_type = rpe_config.ctx_type
    if ctx_type not in table_count:
        raise ValueError(f'Not support ctx_type: {rpe_config.ctx_type}')

    in_features = 2 * rpe_config.num_keypoints
    out_features = num_buckets * table_count[ctx_type]

    if ctx_type in mlp_types:
        keypoint_linear = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ReLU(inplace=True),
            nn.LayerNorm(256),
            nn.Linear(256, out_features),
        )
        output_layer = keypoint_linear[-1]
    else:
        keypoint_linear = nn.Linear(in_features, out_features)
        output_layer = keypoint_linear

    # init zero
    output_layer.weight.data.zero_()
    output_layer.bias.data.zero_()

    return keypoint_linear, num_buckets
make_kprpe_input(keypoints, x, keypoint_linear, rpe_config, mask_ratio, depth, num_heads, num_buckets): + B = x.shape[0] + ctx_type = rpe_config.get('ctx_type', '') + num_kp = rpe_config.num_keypoints + if ctx_type == 'rel_keypoint': + assert mask_ratio == 0 + rel_keypoints = relative_keypoints.make_rel_keypoints(keypoints, x)[:, :, :2 * num_kp] + rel_keypoints = keypoint_linear(rel_keypoints).unsqueeze(1) # B H N D + extra_ctx = {'rel_keypoints': rel_keypoints} + + elif ctx_type == 'rel_keypoint_unshared': + assert mask_ratio == 0 + rel_keypoints = relative_keypoints.make_rel_keypoints(keypoints, x)[:, :, :2 * num_kp] + rel_keypoints = keypoint_linear(rel_keypoints) # B H N D + rel_keypoints = rel_keypoints.view(B, -1, depth, num_buckets).transpose(1, 2) + rel_keypoints = torch.chunk(rel_keypoints, depth, dim=1) + extra_ctx = [{'rel_keypoints': rel_keypoint} for rel_keypoint in rel_keypoints] + + elif ctx_type == 'rel_keypoint_unshared_v2': + assert mask_ratio == 0 + rel_keypoints = relative_keypoints.make_rel_keypoints(keypoints, x)[:, :, :2 * num_kp] + rel_keypoints = keypoint_linear(rel_keypoints) # B H N D + rel_keypoints = rel_keypoints.view(B, -1, depth, num_buckets).transpose(1, 2) + rel_keypoints = torch.chunk(rel_keypoints, depth, dim=1) + extra_ctx = [{'rel_keypoints': rel_keypoint} for rel_keypoint in rel_keypoints] + + elif ctx_type == 'rel_keypoint_splithead': + assert mask_ratio == 0 + rel_keypoints = relative_keypoints.make_rel_keypoints(keypoints, x)[:, :, :2 * num_kp] + rel_keypoints = keypoint_linear(rel_keypoints) # B H N D + rel_keypoints = rel_keypoints.view(B, -1, num_heads, num_buckets).transpose(1, 2) + extra_ctx = {'rel_keypoints': rel_keypoints} + elif ctx_type == 'rel_keypoint_splithead_unshared': + assert mask_ratio == 0 + rel_keypoints = relative_keypoints.make_rel_keypoints(keypoints, x)[:, :, :2 * num_kp] + rel_keypoints = keypoint_linear(rel_keypoints) # B H N D + rel_keypoints = rel_keypoints.view(B, -1, num_heads * depth, 
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when not given (or falsy). The same dropout module is applied after the
    activation and after the second linear layer.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU6, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class Attention(nn.Module):
    """Multi-head self-attention with optional relative position encodings.

    ``build_rpe`` supplies up to three optional modules (on queries, keys and
    values); each one that is not ``None`` adds a bias term in ``forward``.
    The key-side module is fed ``extra_ctx['rel_keypoints']``.
    """

    def __init__(self,
                 dim: int,
                 num_heads: int = 8,
                 qkv_bias: bool = False,
                 qk_scale: Optional[None] = None,
                 attn_drop: float = 0.,
                 proj_drop: float = 0.,
                 rpe_config=None):
        super().__init__()
        per_head_dim = dim // num_heads
        self.num_heads = num_heads
        # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
        self.scale = qk_scale or per_head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # image relative position encoding
        self.rpe_config = rpe_config
        rpe_modules = build_rpe(rpe_config, head_dim=per_head_dim, num_heads=num_heads)
        self.rpe_q, self.rpe_k, self.rpe_v = rpe_modules

    def forward(self, x, extra_ctx=None):
        n_batch, n_tokens, width = x.shape
        per_head_dim = width // self.num_heads

        # (3, batch, heads, tokens, per_head_dim)
        packed = self.qkv(x)
        packed = packed.reshape(n_batch, n_tokens, 3, self.num_heads, per_head_dim)
        packed = packed.permute(2, 0, 3, 1, 4)
        query, key, value = packed[0], packed[1], packed[2]

        # scale the queries in place, then form the raw attention logits
        query.mul_(self.scale)
        logits = query @ key.transpose(-2, -1)

        # image relative position on keys
        if self.rpe_k is not None:
            key_bias = self.rpe_k(extra_ctx['rel_keypoints'])
            logits += key_bias

        # image relative position on queries
        if self.rpe_q is not None:
            query_bias = self.rpe_q(key * self.scale).transpose(2, 3)
            logits += query_bias

        weights = self.attn_drop(logits.softmax(dim=-1))
        mixed = weights @ value

        # image relative position on values
        if self.rpe_v is not None:
            mixed += self.rpe_v(weights)

        # merge heads back into the embedding dimension
        merged = mixed.transpose(1, 2).reshape(n_batch, n_tokens, width)
        return self.proj_drop(self.proj(merged))
+ drop: float = 0., + attn_drop: float = 0., + drop_path: float = 0., + act_layer: Callable = nn.ReLU6, + norm_layer: str = "ln", + patch_n: int = 144, + rpe_config=None): + super().__init__() + + if norm_layer == "bn": + self.norm1 = VITBatchNorm(num_features=num_patches) + self.norm2 = VITBatchNorm(num_features=num_patches) + elif norm_layer == "ln": + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.attn = Attention( + dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + rpe_config=rpe_config) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.extra_gflops = (num_heads * patch_n * (dim//num_heads)*patch_n * 2) / (1000**3) + + def forward(self, x, extra_ctx=None): + norm_x = self.norm1(x) + attn_out = self.attn(norm_x, extra_ctx=extra_ctx) + x = x + self.drop_path(attn_out) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class PatchEmbed(nn.Module): + def __init__(self, img_size=108, patch_size=9, in_channels=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2d(in_channels, embed_dim, + kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + batch_size, channels, height, width = x.shape + assert height == self.img_size[0] and width == self.img_size[1], \ + f"Input image size ({height}*{width}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
class VisionTransformerWithKPRPE(nn.Module):
    """
    Vision Transformer with auxiliary keypoint inputs for KP-RPE

    The image is split into patches, a learned position embedding is added,
    and the tokens go through ``depth`` transformer blocks. Each block
    receives a keypoint-derived context built by ``make_kprpe_input``.
    All patch tokens are then flattened and projected to ``num_classes``
    features by ``self.feature``.

    When ``mask_ratio`` is non-zero, a random subset of tokens is dropped
    during training and replaced by a learned mask token after the blocks.
    """

    def __init__(self,
                 img_size: int = 112,
                 patch_size: int = 16,
                 in_channels: int = 3,
                 num_classes: int = 1000,
                 embed_dim: int = 768,
                 depth: int = 12,
                 num_heads: int = 12,
                 mlp_ratio: float = 4.,
                 qkv_bias: bool = False,
                 qk_scale: Optional[float] = None,
                 drop_rate: float = 0.,
                 attn_drop_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 num_patches: Optional[int] = None,
                 norm_layer: str = "ln",
                 mask_ratio=0.1,
                 using_checkpoint=False,
                 rpe_config=None,
                 ):
        super().__init__()
        # Fail early: an unknown value would otherwise leave `self.norm`
        # undefined and only surface as an AttributeError in forward.
        if norm_layer not in ("ln", "bn"):
            raise ValueError(f"Unsupported norm_layer: {norm_layer!r} (expected 'ln' or 'bn')")

        self.num_classes = num_classes
        # num_features for consistency with other models
        self.num_features = self.embed_dim = embed_dim

        if num_patches is not None:
            # Caller feeds pre-embedded tokens; no patch projection needed.
            self.patch_embed = nn.Identity()
        else:
            self.patch_embed = PatchEmbed(img_size=img_size, patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)
            num_patches = self.patch_embed.num_patches
        self.mask_ratio = mask_ratio
        self.using_checkpoint = using_checkpoint

        self.num_patches = num_patches

        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        patch_n = (img_size//patch_size)**2
        self.blocks = nn.ModuleList(
            [
                Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                      drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
                      num_patches=num_patches, patch_n=patch_n, rpe_config=rpe_config)
                for i in range(depth)]
        )
        self.extra_gflops = sum((block.extra_gflops for block in self.blocks), 0.0)

        if norm_layer == "ln":
            self.norm = nn.LayerNorm(embed_dim)
        else:  # "bn"
            self.norm = VITBatchNorm(self.num_patches)

        # features head
        self.feature = nn.Sequential(
            nn.Linear(in_features=embed_dim * num_patches, out_features=embed_dim, bias=False),
            nn.BatchNorm1d(num_features=embed_dim, eps=2e-5),
            nn.Linear(in_features=embed_dim, out_features=num_classes, bias=False),
            nn.BatchNorm1d(num_features=num_classes, eps=2e-5)
        )

        # The mask token only exists when masking is enabled.
        if self.mask_ratio != 0:
            self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
            torch.nn.init.normal_(self.mask_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.apply(self._init_weights)

        self.num_heads = num_heads
        self.depth = depth

        # Created AFTER self.apply(...) so the zero initialisation of the
        # keypoint projection is not overwritten by _init_weights.
        self.rpe_config = rpe_config
        self.keypoint_linear, self.num_buckets = make_kprpe_shared(rpe_config, depth, num_heads)

    def _init_weights(self, m):
        """Truncated-normal init for linear weights, constants for biases and LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Names of parameters that should be excluded from weight decay."""
        return {'pos_embed', 'cls_token'}

    def get_classifier(self):
        """Return the output head (``self.feature``)."""
        # This model has no `head` attribute; the projection head is `feature`.
        return self.feature

    def random_masking(self, x, mask_ratio=0.1):
        """Keep a random ``1 - mask_ratio`` fraction of the tokens of each sample.

        Args:
            x: tensor of shape (batch, length, dim).
            mask_ratio: fraction of tokens to drop.

        Returns:
            x_masked: the kept tokens, (batch, len_keep, dim).
            index: gather index used to select them, same shape as x_masked.
            ids_restore: (batch, length) permutation that undoes the shuffle.
        """
        N, L, D = x.size()  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        # ascend: small is keep, large is remove
        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        index = ids_keep.unsqueeze(-1).repeat(1, 1, D)
        x_masked = torch.gather(x, dim=1, index=index)

        return x_masked, index, ids_restore

    def forward_features(self, x, keypoints=None):
        """Return the flattened token features, shape (batch, num_patches * embed_dim)."""
        B = x.shape[0]
        x = self.patch_embed(x)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        masking = self.training and self.mask_ratio > 0
        if masking:
            # Honour the configured ratio instead of random_masking's default.
            x, _, ids_restore = self.random_masking(x, self.mask_ratio)

        extra_ctx = make_kprpe_input(keypoints, x, self.keypoint_linear, self.rpe_config, self.mask_ratio,
                                     self.depth, self.num_heads, self.num_buckets)

        use_checkpoint = self.using_checkpoint and self.training
        if use_checkpoint:
            from torch.utils.checkpoint import checkpoint

        for block_idx, func in enumerate(self.blocks):
            # A list carries one context per block; otherwise it is shared.
            if isinstance(extra_ctx, list):
                extra_ctx_ = extra_ctx[block_idx]
            else:
                extra_ctx_ = extra_ctx
            if use_checkpoint:
                x = checkpoint(func, x, extra_ctx_)
            else:
                x = func(x, extra_ctx=extra_ctx_)
        x = self.norm(x.float())
        if masking:
            # Append mask tokens for the dropped positions, then unshuffle.
            mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
            x_ = torch.cat([x[:, :, :], mask_tokens], dim=1)  # no cls token
            x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
            x = x_
        return torch.reshape(x, (B, self.num_patches * self.embed_dim))

    def forward(self, x, keypoints=None):
        """Embed images ``x`` (with their ``keypoints``) into ``num_classes`` features."""
        x = self.forward_features(x, keypoints=keypoints)
        x = self.feature(x)
        return x
+momentum: 0.9 +weight_decay: 0.3 +scheduler: 'cosine' +filter_bias_and_bn: true +warmup_epoch: 3 +max_grad_norm: 5.0 +lr_milestones: [] + diff --git a/cvlface/research/recognition/code/run_v1/optims/configs/cosine_wd5.yaml b/cvlface/research/recognition/code/run_v1/optims/configs/cosine_wd5.yaml new file mode 100644 index 0000000000000000000000000000000000000000..aada7387845630921fd468f4f23dc5c997ef46f5 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/optims/configs/cosine_wd5.yaml @@ -0,0 +1,12 @@ + +num_epoch: 36 +optimizer: 'adamw' +lr: 0.0001 +momentum: 0.9 +weight_decay: 0.05 +scheduler: 'cosine' +filter_bias_and_bn: true +warmup_epoch: 3 +max_grad_norm: 5.0 +lr_milestones: [] + diff --git a/cvlface/research/recognition/code/run_v1/optims/configs/cosine_wd5_extra_long.yaml b/cvlface/research/recognition/code/run_v1/optims/configs/cosine_wd5_extra_long.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c19992ded561ca4ea7f076e29f4fad81251e649d --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/optims/configs/cosine_wd5_extra_long.yaml @@ -0,0 +1,12 @@ + +num_epoch: 50 +optimizer: 'adamw' +lr: 0.0001 +momentum: 0.9 +weight_decay: 0.05 +scheduler: 'cosine' +filter_bias_and_bn: true +warmup_epoch: 3 +max_grad_norm: 5.0 +lr_milestones: [] + diff --git a/cvlface/research/recognition/code/run_v1/optims/configs/cosine_wd5_long.yaml b/cvlface/research/recognition/code/run_v1/optims/configs/cosine_wd5_long.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5bb6169e3c0f2d8929d7be3b900fc02076de5979 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/optims/configs/cosine_wd5_long.yaml @@ -0,0 +1,12 @@ + +num_epoch: 40 +optimizer: 'adamw' +lr: 0.0001 +momentum: 0.9 +weight_decay: 0.05 +scheduler: 'cosine' +filter_bias_and_bn: true +warmup_epoch: 3 +max_grad_norm: 5.0 +lr_milestones: [] + diff --git a/cvlface/research/recognition/code/run_v1/optims/configs/plateau_adamw_agedb.yaml 
b/cvlface/research/recognition/code/run_v1/optims/configs/plateau_adamw_agedb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f153c08739ce9f89630b7304e81a42d2a2b6ac3b --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/optims/configs/plateau_adamw_agedb.yaml @@ -0,0 +1,15 @@ +num_epoch: 50 +optimizer: 'adamw' +lr: 0.001 +momentum: 0.9 +weight_decay: 0.05 +scheduler: 'plateau' +filter_bias_and_bn: true +warmup_epoch: 0 +max_grad_norm: 5.0 +lr_milestones: [] +plateau_factor: 0.5 +plateau_patience: 2 +plateau_threshold: 0.01 +plateau_cooldown: 0 +min_lr: 0.000001 diff --git a/cvlface/research/recognition/code/run_v1/optims/configs/poly.yaml b/cvlface/research/recognition/code/run_v1/optims/configs/poly.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fa5351a9f38251fcd2e8c9908ac79d61e813efa1 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/optims/configs/poly.yaml @@ -0,0 +1,12 @@ + +num_epoch: 34 +optimizer: 'adamw' +lr: 0.001 +momentum: 0.9 +weight_decay: 0.3 +scheduler: 'poly_2' +filter_bias_and_bn: true +warmup_epoch: 3 +max_grad_norm: 5.0 +lr_milestones: [] + diff --git a/cvlface/research/recognition/code/run_v1/optims/configs/step_sgd.yaml b/cvlface/research/recognition/code/run_v1/optims/configs/step_sgd.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f28c13011bb8c81769854b80d3a9adae7052624 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/optims/configs/step_sgd.yaml @@ -0,0 +1,12 @@ + +num_epoch: 26 +optimizer: 'sgd' +lr: 0.1 +momentum: 0.9 +weight_decay: 0.0005 +scheduler: 'step' +filter_bias_and_bn: true +warmup_epoch: 1 +lr_milestones: [12, 20, 24] +lr_lambda: 0.1 +max_grad_norm: 5.0 diff --git a/cvlface/research/recognition/code/run_v1/optims/lr_scheduler.py b/cvlface/research/recognition/code/run_v1/optims/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..f9eb4a871178c2e5c4419f1094893522ad26a264 --- /dev/null +++ 
def param_groups_weight_decay(
        named_parameters,
        weight_decay=1e-5,
        no_weight_decay_list=(),
        no_weight_decay_value=0.0,
):
    """Split trainable parameters into two optimizer parameter groups.

    A parameter goes into the first ("exempt") group when it has at most one
    dimension, its name ends with ``".bias"``, or its name is listed in
    ``no_weight_decay_list``. Every other trainable parameter goes into the
    second group. Parameters with ``requires_grad == False`` are dropped.

    Args:
        named_parameters: Iterable of ``(name, parameter)`` pairs.
        weight_decay: Weight decay assigned to the second group.
        no_weight_decay_list: Names forced into the exempt group.
        no_weight_decay_value: Weight decay assigned to the exempt group.

    Returns:
        ``[exempt_group, regular_group]``, each a dict with the keys
        ``'params'`` and ``'weight_decay'``.
    """
    exempt_names = set(no_weight_decay_list)
    exempt_params, regular_params = [], []

    for param_name, tensor in named_parameters:
        if not tensor.requires_grad:
            continue
        is_exempt = (
            tensor.ndim <= 1
            or param_name.endswith(".bias")
            or param_name in exempt_names
        )
        bucket = exempt_params if is_exempt else regular_params
        bucket.append(tensor)

    exempt_group = {'params': exempt_params, 'weight_decay': no_weight_decay_value}
    regular_group = {'params': regular_params, 'weight_decay': weight_decay}
    return [exempt_group, regular_group]
class StepScheduler(_LRScheduler):
    """Step-decay learning-rate schedule preceded by a linear warmup.

    With ``t`` denoting ``last_epoch`` (incremented on every ``step()``):

    * ``t < warmup_steps``: ``lr = base_lr * t / warmup_steps``;
    * otherwise: ``lr = base_lr * lr_lambda ** k`` where ``k`` is the number
      of milestones strictly smaller than ``t``.

    The same learning rate is applied to every parameter group of the
    optimizer. ``max_steps`` is stored but not used by the schedule itself.

    Args:
        optimizer: Wrapped optimizer.
        base_lr: Learning rate reached at the end of the warmup.
        max_steps: Total number of scheduler steps (stored only).
        warmup_steps: Number of linear warmup steps; may be ``0``.
        warmup_lr_init: Value returned by ``get_lr`` only while
            ``last_epoch == -1``.
        lr_milestones: Step indices after which the rate is multiplied by
            ``lr_lambda``. The iterable is copied.
        lr_lambda: Multiplicative decay applied per passed milestone.
        last_epoch: Value ``last_epoch`` is reset to after construction.
    """

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, warmup_lr_init=0.0001,
                 lr_milestones=(), lr_lambda=0.1, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = warmup_lr_init
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        # The base constructor performs an initial ``step()`` which calls
        # ``get_lr()``; with ``warmup_steps == 0`` that call already reads the
        # milestones, so they must be set *before* ``super().__init__``.
        # Copying also prevents sharing one list between instances.
        self.lr_milestones = list(lr_milestones)
        self.lr_lambda = lr_lambda
        # ``verbose`` is left at its default: passing it positionally is
        # deprecated in recent PyTorch releases and its old default was False.
        super().__init__(optimizer, -1)
        # Undo the counter increment done by the initial step of the base class.
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        """Return the linearly scaled warmup rate for every parameter group."""
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        """Return the learning rate of the current step for every parameter group."""
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()

        alpha = 1.0
        for milestone in self.lr_milestones:
            # Strict comparison: the decay kicks in one step after the milestone.
            if self.last_epoch > milestone:
                alpha = alpha * self.lr_lambda
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]
def _count_params(module):
    """Return ``(total, trainable)`` parameter element counts of ``module``."""
    total = 0
    trainable = 0
    for param in module.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return total, trainable


def _trainable_param_group(module, weight_decay):
    """Return one optimizer param group holding ``module``'s trainable parameters."""
    return {"params": [p for p in module.parameters() if p.requires_grad],
            'weight_decay': weight_decay}


def make_optimizer(cfg, model, classifier, aligner):
    """Create the optimizer over the trainable parameters of all components.

    Parameter groups are appended in the order model, classifier, aligner.
    A component contributes (and is counted in the printed statistics) only
    if its ``has_trainable_params()`` returns true.

    * model: split by ``param_groups_weight_decay``; parameters whose name
      contains ``'token'``, or ``'emb'`` but not ``'patch_embed'``, are put on
      the no-weight-decay list.
    * classifier / aligner: one group each with ``cfg.optims.weight_decay``.
      Either may be ``None``, in which case it is skipped.

    Args:
        cfg: Config exposing ``cfg.optims`` with ``filter_bias_and_bn``,
            ``weight_decay``, ``optimizer``, ``lr`` and (for SGD) ``momentum``.
        model: Backbone module.
        classifier: Classifier module or ``None``.
        aligner: Aligner module or ``None``.

    Returns:
        A ``torch.optim.SGD`` or ``torch.optim.AdamW`` instance.

    Raises:
        ValueError: If ``cfg.optims.optimizer`` is neither ``"sgd"`` nor
            ``"adamw"``.
    """
    params = []
    num_total_params = 0
    num_trainable_params = 0

    # With filtering enabled the exempt group gets no decay at all; otherwise
    # it receives the regular weight decay.
    if cfg.optims.filter_bias_and_bn:
        no_weight_decay_value = 0.0
    else:
        no_weight_decay_value = cfg.optims.weight_decay

    # get backbone param groups
    if model.has_trainable_params():
        no_weight_decay_list = [
            name for name, _ in model.named_parameters()
            if ('emb' in name and 'patch_embed' not in name) or ('token' in name)
        ]
        total, trainable = _count_params(model)
        num_total_params += total
        num_trainable_params += trainable
        model_param_groups = param_groups_weight_decay(model.named_parameters(),
                                                       weight_decay=cfg.optims.weight_decay,
                                                       no_weight_decay_value=no_weight_decay_value,
                                                       no_weight_decay_list=no_weight_decay_list)
        params = params + model_param_groups

    # get classifier and aligner param groups (order matters, see TODO below)
    for module in (classifier, aligner):
        if module is None or not module.has_trainable_params():
            continue
        total, trainable = _count_params(module)
        num_total_params += total
        num_trainable_params += trainable
        params = params + [_trainable_param_group(module, cfg.optims.weight_decay)]

    # print number of params
    print(f"Total params: {num_total_params}")
    print(f"Trainable params: {num_trainable_params}")
    # Guard the ratio: nothing is counted when no component is trainable.
    percentage = num_trainable_params / num_total_params * 100 if num_total_params else 0.0
    print(f"Percentage of trainable params: {percentage:.2f}%")

    if cfg.optims.optimizer == "sgd":
        # TODO the params of partial fc must be last in the params list
        opt = torch.optim.SGD(params=params, lr=cfg.optims.lr, momentum=cfg.optims.momentum,
                              weight_decay=cfg.optims.weight_decay)
    elif cfg.optims.optimizer == "adamw":
        opt = torch.optim.AdamW(params=params, lr=cfg.optims.lr, weight_decay=cfg.optims.weight_decay)
    else:
        raise ValueError(f"Unknown optimizer: {cfg.optims.optimizer}")

    return opt
def apply_peft_to_model(peft_config, model):
    """Prepare ``model`` for fine-tuning according to ``peft_config.name``.

    Supported modes:

    * ``'lora'``: wrap the model in a ``peft.LoraModel``; the adapted module
      names are looked up from ``peft_config.target_modules`` and the rank
      (also used as ``lora_alpha``) from ``peft_config.lora_rank``.
    * ``'part_freeze'``: set ``requires_grad`` on every parameter to whether
      its name contains one of the target substrings selected by
      ``peft_config.target_modules``.
    * ``'full'``: leave the model untouched.
    * ``'freeze'``: set ``requires_grad = False`` on every parameter.

    Trainable-parameter statistics are printed before returning.

    Raises:
        KeyError: If ``peft_config.target_modules`` is not a known preset
            (``'lora'`` and ``'part_freeze'`` modes).
        ValueError: If ``peft_config.name`` is not one of the modes above.
    """
    mode = peft_config.name

    if mode == 'lora':
        from peft import LoraConfig, LoraModel
        lora_presets = {
            'att_qkv': ['qkv'],
            'att_qkv_keypoint_linear': ['qkv', 'keypoint_linear'],
            'att_qkv_feature': ['qkv', 'feature.0', 'feature.2'],
            'att_qkv_feature_keypoint_linear': ['qkv', 'feature.0', 'feature.2', 'keypoint_linear'],
        }
        adapted_modules = lora_presets[peft_config.target_modules]
        lora_settings = LoraConfig(r=peft_config.lora_rank, lora_alpha=peft_config.lora_rank,
                                   target_modules=adapted_modules,
                                   lora_dropout=0.1, bias="none")
        perf_model = LoraModel(model, lora_settings, adapter_name='default')

    elif mode == 'part_freeze':
        # Preset "X.k" keeps X.k and every later X.* trainable; optional
        # suffixes add further named modules.
        num_blocks, num_body = 24, 49
        freeze_presets = {}
        for start in range(num_blocks):
            tail_blocks = [f'blocks.{i}' for i in range(start, num_blocks)]
            freeze_presets[f'blocks.{start}'] = tail_blocks
            freeze_presets[f'blocks.{start}_feature'] = tail_blocks + ['feature']
            freeze_presets[f'blocks.{start}_keypoint_linear'] = tail_blocks + ['keypoint_linear']
            freeze_presets[f'blocks.{start}_feature_keypoint_linear'] = (
                tail_blocks + ['feature'] + ['keypoint_linear'])
        for start in range(num_body):
            freeze_presets[f'body.{start}'] = [f'body.{i}' for i in range(start, num_body)] + ['output_layer']
        trainable_markers = freeze_presets[peft_config.target_modules]

        for param_name, param in model.named_parameters():
            # Substring match against the parameter's qualified name.
            param.requires_grad = any(marker in param_name for marker in trainable_markers)
        perf_model = model

    elif mode == 'full':
        # full train
        perf_model = model

    elif mode == 'freeze':
        for _, param in model.named_parameters():
            param.requires_grad = False
        perf_model = model

    else:
        raise ValueError(f"peft_config.name: {peft_config.name}")

    print_trainable_parameters(perf_model)
    return perf_model
def print_trainable_parameters(model):
    """Print parameter statistics and return the parameter names by trainability.

    Args:
        model: Object exposing ``named_parameters()`` yielding
            ``(name, parameter)`` pairs.

    Returns:
        ``(trainable_names, untrainable_names)``: names of the parameters with
        and without ``requires_grad``, in iteration order.
    """
    trainable_params = 0
    all_param = 0
    trainable_names = []
    untrainable_names = []
    for name, param in model.named_parameters():
        count = param.numel()
        all_param += count
        if param.requires_grad:
            trainable_params += count
            trainable_names.append(name)
        else:
            untrainable_names.append(name)

    # A module without any parameters would otherwise divide by zero.
    trainable_percent = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_percent:.2f}"
    )
    return trainable_names, untrainable_names
newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/pefts/configs/part_freeze.yaml b/cvlface/research/recognition/code/run_v1/pefts/configs/part_freeze.yaml new file mode 100644 index 0000000000000000000000000000000000000000..16e15fa92b4a18e8a1f099e656ef4757fce952e0 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/pefts/configs/part_freeze.yaml @@ -0,0 +1,5 @@ +name: part_freeze +model_ckpt_dir: '' +classifier_ckpt_dir: '' +center_paths: [] +target_modules: blocks.20 \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/pipelines/__init__.py b/cvlface/research/recognition/code/run_v1/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f535906ecf722f0c06239d1ce619e12ea64e53 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/pipelines/__init__.py @@ -0,0 +1,50 @@ +from omegaconf import DictConfig +from typing import Union, Any +from .train_model_cls_pipeline import TrainModelClsPipeline +from .train_keypoint_model_cls_pipeline import TrainKeypointModelClsPipeline +from .infer_model_pipeline import InferModelPipeline +from .infer_aligner_model_pipeline import InferAlignerModelPipeline +from .infer_aligner_keypoint_model_pipeline import InferAlignerKeypointModelPipeline +from .infer_aligner_keypoint_model_nmescore_pipeline import InferAlignerKeypointModelNMEScorePipeline +def pipeline_from_config(pipeline_config: Union[DictConfig, dict], + model: Any=None, + classifier: Any=None, + aligner: Any=None, + optimizer: Any=None, + lr_scheduler: Any=None): + + if pipeline_config.name == 'TrainModelClsPipeline': + pipeline = TrainModelClsPipeline(model, classifier, optimizer, lr_scheduler) + elif pipeline_config.name == 'TrainKeypointModelClsPipeline': + pipeline = TrainKeypointModelClsPipeline(model, classifier, optimizer, lr_scheduler) + else: + raise NotImplementedError(f"pipeline {pipeline_config.name} not implemented") + + if pipeline_config.resume: + epoch, 
class BasePipeline():
    """Base class for pipelines that bundle several modules.

    Subclasses expose the attribute names of their modules through
    ``module_names_list``; checkpoint saving and resuming iterate over it.
    """

    def __init__(self):
        self.start_epoch = 0
        # Directory of the most recent checkpoint written by ``save``.
        self.last_save_name = None
        self.color_space = None

    @property
    def module_names_list(self):
        """Attribute names of the modules handled by ``save`` / ``resume_from_dir``."""
        raise NotImplementedError

    def __call__(self, *args, **kwargs):
        raise NotImplementedError

    def train(self):
        raise NotImplementedError

    def eval(self):
        raise NotImplementedError

    def integrity_check(self, dataset_color_space):
        raise NotImplementedError

    def resume_from_dir(self, resume_dir):
        """Restore pipeline state and all module weights from ``resume_dir``.

        Reads ``pipeline.pt`` plus one ``<name>.pt`` per entry of
        ``module_names_list`` whose attribute is not ``None``.

        Returns:
            ``(epoch, step, n_images_seen)`` as stored in ``pipeline.pt``.
        """
        pipeline_st = torch.load(os.path.join(resume_dir, 'pipeline.pt'), map_location='cpu')
        self.start_epoch = pipeline_st['epoch']
        self.last_save_name = pipeline_st['last_save_name']
        self.color_space = pipeline_st['color_space']
        assert self.module_names_list == pipeline_st['module_names_list']

        for name in self.module_names_list:
            mod = getattr(self, name)
            if mod is None:
                continue

            # Build the sibling path from the directory instead of string
            # replacement, which would also rewrite matching directory names.
            module_path = os.path.join(resume_dir, f'{name}.pt')
            if hasattr(mod, 'load_state_dict_from_path'):
                mod.load_state_dict_from_path(module_path)
            else:
                mod_st = torch.load(module_path, map_location='cpu')
                mod.load_state_dict(mod_st)

        epoch = pipeline_st['epoch']
        step = pipeline_st['step']
        n_images_seen = pipeline_st['n_images_seen']
        return epoch, step, n_images_seen

    def save(self, fabric, pipeline, cfg, epoch, step, n_images_seen, is_best=False):
        """Write a checkpoint and keep only the newest one (plus ``best``).

        The checkpoint goes to
        ``<cfg.trainers.output_dir>/checkpoints/epoch:<epoch>_step:<step>``.
        On local rank 0 it is additionally copied to ``checkpoints/best`` when
        ``is_best`` is true, and the previously saved checkpoint directory is
        removed.
        """
        # save model (it could happen in more than rank 0)
        save_dir = os.path.join(cfg.trainers.output_dir, 'checkpoints', f'epoch:{epoch}_step:{step}')
        self.save_pipelines_and_configs(save_dir, fabric, pipeline, cfg, epoch, step, n_images_seen)

        fabric.barrier()

        if fabric.local_rank == 0:
            if is_best:
                best_save_dir = os.path.join(cfg.trainers.output_dir, 'checkpoints', 'best')
                if os.path.exists(best_save_dir):
                    shutil.rmtree(best_save_dir)
                shutil.copytree(save_dir, best_save_dir, dirs_exist_ok=True)

            # remove old checkpoints
            previous_dir = self.last_save_name
            # Never delete the checkpoint that was just written (same
            # epoch/step saved twice), and avoid the shell so that paths with
            # spaces or metacharacters are handled correctly.
            if previous_dir is not None and previous_dir != save_dir and os.path.exists(previous_dir):
                shutil.rmtree(previous_dir, ignore_errors=True)
            self.last_save_name = save_dir

        fabric.barrier()

    @staticmethod
    def save_pipelines_and_configs(save_dir, fabric, pipeline, cfg, epoch, step, n_images_seen):
        """Save every module of ``pipeline`` plus config and bookkeeping into ``save_dir``.

        Modules with ``save_pretrained`` are asked to save themselves on every
        rank (they receive the rank); other non-``None`` modules have their
        ``state_dict`` saved on local rank 0 only, as are ``config.yaml`` and
        ``pipeline.pt``.
        """
        os.makedirs(save_dir, exist_ok=True)
        for name in pipeline.module_names_list:
            mod = getattr(pipeline, name)
            if hasattr(mod, 'save_pretrained'):
                # model, classifier, aligner, etc
                mod.save_pretrained(save_dir=save_dir, name=f'{name}.pt', rank=fabric.local_rank)
            elif mod is None:
                pass
            else:
                # optimizer and lr_scheduler
                if fabric.local_rank == 0:
                    torch.save(mod.state_dict(), os.path.join(save_dir, f'{name}.pt'))

        if fabric.local_rank == 0:
            # save omega config to yaml
            omegaconf.OmegaConf.save(cfg, os.path.join(save_dir, 'config.yaml'))
            torch.save({'epoch': epoch, 'step': step, 'n_images_seen': n_images_seen,
                        'cfg': cfg, 'module_names_list': pipeline.module_names_list,
                        'last_save_name': pipeline.last_save_name,
                        'color_space': pipeline.color_space, },
                       os.path.join(save_dir, 'pipeline.pt'))
class InferAlignerKeypointModelNMEScorePipeline(BasePipeline):
    """Inference pipeline whose output feature norm encodes a quality score.

    The aligner predicts landmarks on the input batch, the model computes
    features from the unaligned inputs plus those landmarks, and the features
    are rescaled so that their L2 norm equals
    ``clip(landmark_score * aligner_score, 1e-6, 1)``. ``landmark_score`` is
    derived from the normalized mean error (NME) between the predicted
    landmarks and a fixed five-point reference layout.
    """

    def __init__(self,
                 aligner: BaseAligner,
                 model: BaseModel,
                 ):
        super(InferAlignerKeypointModelNMEScorePipeline, self).__init__()

        self.aligner = aligner
        self.model = model
        self.eval()

        # Five reference landmarks given in pixels and divided by 112 to get
        # relative coordinates.
        # NOTE(review): the numpy array is float64, so the NME and the
        # rescaled features may be promoted to float64 - confirm intended.
        self._reference_landmark = torch.from_numpy(np.array([[38.29459953, 51.69630051],
                                                              [73.53179932, 51.50139999],
                                                              [56.02519989, 71.73660278],
                                                              [41.54930115, 92.3655014],
                                                              [70.72990036, 92.20410156]]) / 112)
        # Device-resident copy with a leading batch axis, created lazily on
        # the first call.
        self.reference_landmark = None

    @property
    def module_names_list(self):
        return ['aligner', 'model', ]

    def integrity_check(self, dataset_color_space):
        """Assert that dataset, model and aligner agree on color space and test transform."""
        # color space check
        assert dataset_color_space == self.model.config.color_space
        self.color_space = dataset_color_space
        assert dataset_color_space == self.aligner.config.color_space
        self.make_test_transform()

    def make_test_transform(self):
        """Return the model's test transform after checking it matches the aligner's."""
        # check that aligner and model have the same transform
        aligner_transform = self.aligner.make_test_transform()
        model_transform = self.model.make_test_transform()
        # Compare both transforms on one random 112x112 RGB image.
        x = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
        x = Image.fromarray(x)
        assert (aligner_transform(x) == model_transform(x)).all()
        return model_transform

    def __call__(self, batch):
        """Return quality-scaled features for ``batch``.

        Args:
            batch: Image tensor whose dims 2 and 3 must be equal.

        Returns:
            Features whose per-sample L2 norm lies in ``[1e-6, 1]``.
        """
        inputs = batch
        padding_ratio_override = None

        aligner_result = self.aligner(inputs.to(self.aligner.device), padding_ratio_override=padding_ratio_override)
        aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, normalized_bbox = aligner_result
        assert inputs.size(2) == inputs.size(3)  # we can only use orig_pred_ldmks if the image is not altered
        feats = self.model(inputs, orig_pred_ldmks)

        if self.reference_landmark is None:
            self.reference_landmark = self._reference_landmark.to(orig_pred_ldmks.device).unsqueeze(0)

        # Landmark coordinates are compared in a unit box, so the NME is not rescaled.
        nme = self.calc_nme_batched(self.reference_landmark, orig_pred_ldmks, bbox=[0, 0, 1, 1])
        # Map NME linearly: 0 -> 1 and >= 0.2 -> ~0 (1e-6), then clamp to [0, 1].
        ldmk_score = torch.clip((0.2 - torch.clip(nme, 0, 0.2)) / 0.2 + 1e-6, 0, 1)
        norm_norms = ldmk_score * score
        norm_norms = torch.clip(norm_norms, 1e-6, 1)
        # Replace each feature's norm by the combined score.
        norms = torch.norm(feats, p=2, dim=1, keepdim=True)
        feats = feats / norms * norm_norms

        return feats

    def train(self):
        raise NotImplementedError('InferAlignerKeypointModelNMEScorePipeline does not support train mode')

    def eval(self):
        """Put aligner and model into evaluation mode."""
        self.aligner.eval()
        self.model.eval()

    def calc_nme_batched(self, ldmk_gt, ldmk_pred, bbox):
        """Return the per-sample normalized mean landmark error.

        Args:
            ldmk_gt: Reference landmarks, a 3-dim tensor broadcastable to
                ``ldmk_pred``.
            ldmk_pred: Predicted landmarks, a 3-dim tensor; dim 1 is averaged
                over and dim 2 holds the coordinates.
            bbox: ``(minx, miny, maxx, maxy)``; the error is divided by the
                square root of the box area.

        Returns:
            Tensor with dim 1 kept as size 1: mean Euclidean landmark distance
            divided by the normalization length.
        """
        minx, miny, maxx, maxy = bbox
        llength = sqrt((maxx - minx) * (maxy - miny))

        assert ldmk_gt.ndim == 3
        assert ldmk_pred.ndim == 3

        # nme
        dis = (ldmk_pred - ldmk_gt) ** 2
        dis = torch.sqrt(torch.sum(dis, 2))
        dis = torch.mean(dis, 1, keepdim=True)
        nme = dis / llength
        return nme
@property + def module_names_list(self): + return ['aligner', 'model', ] + + def integrity_check(self, dataset_color_space): + # color space check + assert dataset_color_space == self.model.config.color_space + self.color_space = dataset_color_space + assert dataset_color_space == self.aligner.config.color_space + self.make_test_transform() + + def make_test_transform(self): + # check that aligner and model have the same transform + aligner_transform = self.aligner.make_test_transform() + model_transform = self.model.make_test_transform() + x = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) + x = Image.fromarray(x) + assert (aligner_transform(x) == model_transform(x)).all() + return model_transform + + def __call__(self, batch): + inputs = batch + + # if inputs.size(2) == 112: + # # we can pad it to be 160 + # padding_ratio_override = 0.215 + # else: + # padding_ratio_override = None + padding_ratio_override = None + + aligner_result = self.aligner(inputs.to(self.aligner.device), padding_ratio_override=padding_ratio_override) + aligned_x, orig_pred_ldmks, aligned_ldmks, score, thetas, normalized_bbox = aligner_result + assert inputs.size(2) == inputs.size(3) # we can only use orig_pred_ldmks if the image is not altered + feats = self.model(inputs, orig_pred_ldmks) + return feats + + + def train(self): + raise NotImplementedError('InferAlignerKeypointModelPipeline does not support train mode') + + + def eval(self): + self.aligner.eval() + self.model.eval() + + diff --git a/cvlface/research/recognition/code/run_v1/pipelines/infer_aligner_model_pipeline.py b/cvlface/research/recognition/code/run_v1/pipelines/infer_aligner_model_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..94d57ccd0df45b2906b7dac86f28aae56d01b0c1 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/pipelines/infer_aligner_model_pipeline.py @@ -0,0 +1,55 @@ +import numpy as np + +from .base import BasePipeline +from models.base import BaseModel 
+from aligners.base import BaseAligner +from PIL import Image + +class InferAlignerModelPipeline(BasePipeline): + + def __init__(self, + aligner:BaseAligner, + model:BaseModel, + ): + super(InferAlignerModelPipeline, self).__init__() + + self.aligner = aligner + self.model = model + self.eval() + + @property + def module_names_list(self): + return ['aligner', 'model', ] + + def integrity_check(self, dataset_color_space): + # color space check + assert dataset_color_space == self.model.config.color_space + self.color_space = dataset_color_space + assert dataset_color_space == self.aligner.config.color_space + self.make_test_transform() + + def make_test_transform(self): + # check that aligner and model have the same transform + aligner_transform = self.aligner.make_test_transform() + model_transform = self.model.make_test_transform() + x = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8) + x = Image.fromarray(x) + assert (aligner_transform(x) == model_transform(x)).all() + return model_transform + + def __call__(self, batch): + inputs = batch + alinged_inputs = self.aligner(inputs.to(self.aligner.device))[0] + feats = self.model(alinged_inputs) + return feats + + + def train(self): + raise NotImplementedError('InferAlignerModelPipeline does not support train mode') + + + def eval(self): + self.aligner.eval() + self.model.eval() + + diff --git a/cvlface/research/recognition/code/run_v1/pipelines/infer_model_pipeline.py b/cvlface/research/recognition/code/run_v1/pipelines/infer_model_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..ce0682750c75ef424a6a17c39b64df42971a0e20 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/pipelines/infer_model_pipeline.py @@ -0,0 +1,42 @@ + + +from .base import BasePipeline +from models.base import BaseModel + +class InferModelPipeline(BasePipeline): + + def __init__(self, + model:BaseModel, + ): + super(InferModelPipeline, self).__init__() + + self.model = model + self.eval() + + 
@property + def module_names_list(self): + return ['model', ] + + def integrity_check(self, dataset_color_space): + # color space check + assert dataset_color_space == self.model.config.color_space + self.color_space = dataset_color_space + self.make_test_transform() + + def make_test_transform(self): + return self.model.make_test_transform() + + def __call__(self, batch): + inputs = batch + feats = self.model(inputs) + return feats + + + def train(self): + raise NotImplementedError('InferModelPipeline does not support train mode') + + + def eval(self): + self.model.eval() + + diff --git a/cvlface/research/recognition/code/run_v1/pipelines/train_keypoint_model_cls_pipeline.py b/cvlface/research/recognition/code/run_v1/pipelines/train_keypoint_model_cls_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..5ea4258ff06714ac96805858cf80c50b68ba8942 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/pipelines/train_keypoint_model_cls_pipeline.py @@ -0,0 +1,62 @@ +from .base import BasePipeline +from models.base import BaseModel +import torch + +class TrainKeypointModelClsPipeline(BasePipeline): + + def __init__(self, + model:BaseModel, + classifier:BaseModel, + optimizer, + lr_scheduler): + super(TrainKeypointModelClsPipeline, self).__init__() + + self.model = model + self.classifier = classifier + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + @property + def module_names_list(self): + return ['model', 'classifier', 'optimizer', 'lr_scheduler'] + + def integrity_check(self, dataset): + # color space check + dataset_color_space = dataset.color_space + assert dataset_color_space == self.model.config.color_space + self.color_space = dataset_color_space + self.make_train_transform() + + def make_train_transform(self): + return self.model.make_train_transform() + + def __call__(self, batch): + if len(batch) == 7: + sample1, targets, ldmk1, theta1, sample2, ldmk2, theta2 = batch + else: + raise ValueError('not 
supported batch format') + if sample2.ndim == 1: + inputs = sample1 + ldmks = ldmk1 + else: + inputs = torch.cat([sample1, sample2], dim=0) + ldmks = torch.cat([ldmk1, ldmk2], dim=0) + targets = torch.cat([targets, targets], dim=0) + + feats = self.model(inputs, ldmks) + loss = self.classifier(feats, targets.to(self.classifier.device)) + return loss + + + def train(self): + if not self.model.config.freeze: + self.model.train() + if not self.classifier.config.freeze: + self.classifier.train() + + + def eval(self): + self.model.eval() + self.classifier.eval() + + diff --git a/cvlface/research/recognition/code/run_v1/pipelines/train_model_cls_pipeline.py b/cvlface/research/recognition/code/run_v1/pipelines/train_model_cls_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..589e39f505bb47c4f2a86e2d58d7e88780c9cf54 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/pipelines/train_model_cls_pipeline.py @@ -0,0 +1,62 @@ +from .base import BasePipeline +from models.base import BaseModel +import torch + +class TrainModelClsPipeline(BasePipeline): + + def __init__(self, + model:BaseModel, + classifier:BaseModel, + optimizer, + lr_scheduler): + super(TrainModelClsPipeline, self).__init__() + + self.model = model + self.classifier = classifier + self.optimizer = optimizer + self.lr_scheduler = lr_scheduler + + @property + def module_names_list(self): + return ['model', 'classifier', 'optimizer', 'lr_scheduler'] + + def integrity_check(self, dataset): + # color space check + dataset_color_space = dataset.color_space + assert dataset_color_space == self.model.config.color_space + self.color_space = dataset_color_space + self.make_train_transform() + + def make_train_transform(self): + return self.model.make_train_transform() + + def __call__(self, batch): + if len(batch) == 2: + inputs, targets = batch + elif len(batch) == 4: + inputs, placeholder, targets, thetas = batch + elif len(batch) == 7: + inputs, targets, ldmk1, theta1, sample2, 
ldmk2, theta2 = batch + if sample2.ndim != 1: + inputs = torch.cat([inputs, sample2], dim=0) + targets = torch.cat([targets, targets], dim=0) + + else: + raise ValueError('not supported batch format') + feats = self.model(inputs) + loss = self.classifier(feats, targets.to(self.classifier.device)) + return loss + + + def train(self): + if not self.model.config.freeze: + self.model.train() + if not self.classifier.config.freeze: + self.classifier.train() + + + def eval(self): + self.model.eval() + self.classifier.eval() + + diff --git a/cvlface/research/recognition/code/run_v1/scripts/configs/agedb567_lora_r4_fgnet_age30.yaml b/cvlface/research/recognition/code/run_v1/scripts/configs/agedb567_lora_r4_fgnet_age30.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c2e8a1014843f27d6935c73f379e8c733f06fa0e --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/configs/agedb567_lora_r4_fgnet_age30.yaml @@ -0,0 +1,26 @@ +train_root: /root/Lab1/data/processed/agedb_harmonliu05_224_by_identity +output_dir: /root/Lab1/experiments/agedb567_lora_r4_yaml +ckpt: /root/Lab1/project/cvlface/pretrained_models/model.safetensors + +num_classes: 567 +image_size: 224 +batch_size: 128 +num_workers: 8 +epochs: 20 + +amp: true +freeze_backbone: true +use_lora: true +lora_rank: 4 +lora_alpha: 4 +lora_dropout: 0.1 +lora_target_modules: qkv +lora_lr: 0.0003 +head_lr: 0.001 +save_lora_only: true + +save_every: 200 +eval_every_epochs: 5 +eval_batch_size: 128 +eval_pairs_csv: + - /root/Lab1/data/processed/fgnet_age30_protocol/facerec_val/fgnet_age30/pairs.csv diff --git a/cvlface/research/recognition/code/run_v1/scripts/debug/debug.sh b/cvlface/research/recognition/code/run_v1/scripts/debug/debug.sh new file mode 100644 index 0000000000000000000000000000000000000000..0a4473c8d733f94f1310088e53a41164b0408193 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/debug/debug.sh @@ -0,0 +1,10 @@ +trainers=configs/debug +models=vit/configs/v1_base 
+data_augs=configs/gridsample_v1 +dataset=configs/casia +pipelines=configs/train_model_cls +aligners=configs/none +classifiers=configs/partial_fc +evaluations=configs/quick +trainers.batch_size=8 +trainers.limit_num_batch=128 diff --git a/cvlface/research/recognition/code/run_v1/scripts/debug/inspect_swin_load.py b/cvlface/research/recognition/code/run_v1/scripts/debug/inspect_swin_load.py new file mode 100644 index 0000000000000000000000000000000000000000..36bc9180bf46b652c4a9d19c85d00fa417a06b3e --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/debug/inspect_swin_load.py @@ -0,0 +1,260 @@ +import argparse +import math +from pathlib import Path + +import torch +import torch.nn.functional as F + + +CRITICAL_PATTERNS = [ + "patch_embed", + "downsample", + "attn.qkv", + "attn.proj", + "mlp.fc1", + "mlp.fc2", + "norm", +] + + +def load_checkpoint(path, key="state_dict"): + path = str(path) + if path.endswith(".safetensors"): + from safetensors.torch import load_file + + return load_file(path) + + ckpt = torch.load(path, map_location="cpu") + if isinstance(ckpt, dict) and key in ckpt: + return ckpt[key] + if isinstance(ckpt, dict): + return ckpt + raise TypeError(f"Unsupported checkpoint type: {type(ckpt)}") + + +def strip_prefix(state_dict, prefix): + if not prefix: + return state_dict + out = {} + for key, value in state_dict.items(): + if key.startswith(prefix): + key = key[len(prefix) :] + out[key] = value + return out + + +def maybe_add_model_prefix(state_dict, model_state, prefix): + if not prefix: + return state_dict, False + if any(key in model_state for key in state_dict): + return state_dict, False + if not any(key.startswith(prefix) for key in model_state): + return state_dict, False + return {prefix + key: value for key, value in state_dict.items()}, True + + +def shape_of(value): + return tuple(value.shape) if hasattr(value, "shape") else None + + +def is_critical(key): + return any(pattern in key for pattern in CRITICAL_PATTERNS) + + 
+def resize_relative_position_bias_table(value, target_shape): + if value.ndim != 2 or len(target_shape) != 2: + return None + source_length, source_heads = value.shape + target_length, target_heads = target_shape + if source_heads != target_heads: + return None + + source_size = int(math.sqrt(source_length)) + target_size = int(math.sqrt(target_length)) + if source_size * source_size != source_length: + return None + if target_size * target_size != target_length: + return None + + value_float = value.float().permute(1, 0).reshape(1, source_heads, source_size, source_size) + resized = F.interpolate( + value_float, + size=(target_size, target_size), + mode="bicubic", + align_corners=False, + ) + resized = resized.reshape(source_heads, target_length).permute(1, 0).contiguous() + return resized.to(dtype=value.dtype, device=value.device) + + +def inspect_load( + model, + ckpt_path, + key="state_dict", + model_prefix_to_strip=None, + model_prefix_to_add="net.", + resize_relative_position_bias=False, +): + state_dict = load_checkpoint(ckpt_path, key=key) + state_dict = strip_prefix(state_dict, model_prefix_to_strip) + + model_state = model.state_dict() + state_dict, added_prefix = maybe_add_model_prefix( + state_dict, model_state, model_prefix_to_add + ) + + loaded = [] + resized_relative_position_bias = [] + skipped_shape = [] + missing = [] + unexpected = [] + + for name, value in state_dict.items(): + if name not in model_state: + unexpected.append((name, shape_of(value))) + continue + if shape_of(value) != shape_of(model_state[name]): + if resize_relative_position_bias and name.endswith("relative_position_bias_table"): + resized_value = resize_relative_position_bias_table( + value, + shape_of(model_state[name]), + ) + if resized_value is not None: + state_dict[name] = resized_value + loaded.append(name) + resized_relative_position_bias.append( + (name, shape_of(value), shape_of(resized_value)) + ) + continue + skipped_shape.append((name, shape_of(value), 
shape_of(model_state[name]))) + else: + loaded.append(name) + + for name, value in model_state.items(): + if name not in state_dict: + missing.append((name, shape_of(value))) + elif shape_of(state_dict[name]) != shape_of(value): + missing.append((name, shape_of(value))) + + critical_missing = [item for item in missing if is_critical(item[0])] + critical_skipped = [item for item in skipped_shape if is_critical(item[0])] + critical_unexpected = [item for item in unexpected if is_critical(item[0])] + + print("=" * 100) + print(f"checkpoint: {ckpt_path}") + print(f"added prefix: {model_prefix_to_add if added_prefix else ''}") + print(f"total model keys: {len(model_state)}") + print(f"checkpoint keys: {len(state_dict)}") + print(f"loaded: {len(loaded)}") + print(f"resized_rpb: {len(resized_relative_position_bias)}") + print(f"skipped_shape: {len(skipped_shape)}") + print(f"missing: {len(missing)}") + print(f"unexpected: {len(unexpected)}") + + print("\n[RESIZED RELATIVE_POSITION_BIAS_TABLE KEYS]") + for item in resized_relative_position_bias: + print(item) + + print("\n[SKIPPED SHAPE-MISMATCHED KEYS]") + for item in skipped_shape: + print(item) + + print("\n[MISSING KEYS]") + for item in missing: + print(item) + + print("\n[UNEXPECTED KEYS]") + for item in unexpected: + print(item) + + print("\n" + "=" * 100) + print("[CRITICAL CHECK]") + print(f"critical_missing: {len(critical_missing)}") + print(f"critical_skipped: {len(critical_skipped)}") + print(f"critical_unexpected: {len(critical_unexpected)}") + + if critical_skipped: + print("\n[CRITICAL SKIPPED]") + for item in critical_skipped: + print(item) + + if critical_missing: + print("\n[CRITICAL MISSING]") + for item in critical_missing: + print(item) + + if critical_unexpected: + print("\n[CRITICAL UNEXPECTED]") + for item in critical_unexpected: + print(item) + + grouped = { + "patch_merging": [x for x in critical_missing + critical_skipped if "downsample" in x[0]], + "qkv": [x for x in critical_missing + 
critical_skipped if "attn.qkv" in x[0]], + "attn_proj": [x for x in critical_missing + critical_skipped if "attn.proj" in x[0]], + "mlp": [x for x in critical_missing + critical_skipped if "mlp.fc1" in x[0] or "mlp.fc2" in x[0]], + "patch_embed": [x for x in critical_missing + critical_skipped if "patch_embed" in x[0]], + "norm": [x for x in critical_missing + critical_skipped if "norm" in x[0]], + } + + print("\n[CRITICAL GROUP COUNTS: missing + skipped_shape]") + for name, items in grouped.items(): + print(f"{name}: {len(items)}") + + return { + "loaded": loaded, + "resized_relative_position_bias": resized_relative_position_bias, + "skipped_shape": skipped_shape, + "missing": missing, + "unexpected": unexpected, + "critical_missing": critical_missing, + "critical_skipped": critical_skipped, + "critical_unexpected": critical_unexpected, + "critical_group_counts": {key: len(value) for key, value in grouped.items()}, + } + + +class ProjectLikeTimmSwin(torch.nn.Module): + def __init__(self, timm_name, img_size, output_dim): + super().__init__() + import timm + + self.net = timm.create_model( + timm_name, + pretrained=False, + img_size=img_size, + num_classes=output_dim, + ) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--ckpt", required=True, type=Path) + parser.add_argument( + "--timm-name", + default="swin_small_patch4_window7_224.ms_in22k_ft_in1k", + ) + parser.add_argument("--img-size", default=112, type=int) + parser.add_argument("--output-dim", default=512, type=int) + parser.add_argument("--key", default="state_dict") + parser.add_argument("--strip-prefix", default=None) + parser.add_argument("--add-prefix", default="net.") + parser.add_argument("--resize-relative-position-bias", action="store_true") + return parser.parse_args() + + +def main(): + args = parse_args() + model = ProjectLikeTimmSwin(args.timm_name, args.img_size, args.output_dim) + inspect_load( + model=model, + ckpt_path=args.ckpt, + key=args.key, + 
model_prefix_to_strip=args.strip_prefix, + model_prefix_to_add=args.add_prefix, + resize_relative_position_bias=args.resize_relative_position_bias, + ) + + +if __name__ == "__main__": + main() diff --git a/cvlface/research/recognition/code/run_v1/scripts/debug/test_relative_position_bias_resize.py b/cvlface/research/recognition/code/run_v1/scripts/debug/test_relative_position_bias_resize.py new file mode 100644 index 0000000000000000000000000000000000000000..080f97c721c7420752e93c5bea3ff5907dcf1cf3 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/debug/test_relative_position_bias_resize.py @@ -0,0 +1,26 @@ +import sys +from pathlib import Path + +import torch + +RUN_DIR = Path(__file__).resolve().parents[2] +sys.path.insert(0, str(RUN_DIR)) + +from models.base import BaseModel + + +def main(): + source = torch.arange(169 * 24, dtype=torch.float32).reshape(169, 24) + resized = BaseModel.resize_relative_position_bias_table(source, (25, 24)) + + assert resized is not None + assert tuple(resized.shape) == (25, 24) + assert resized.dtype == source.dtype + assert resized.device == source.device + + not_resizable = BaseModel.resize_relative_position_bias_table(source, (26, 24)) + assert not_resizable is None + + +if __name__ == "__main__": + main() diff --git a/cvlface/research/recognition/code/run_v1/scripts/eval/run_multi_gpu.sh b/cvlface/research/recognition/code/run_v1/scripts/eval/run_multi_gpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..8eee9c6d46dcda449fcfc6e615c46efea2e3ebbd --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/eval/run_multi_gpu.sh @@ -0,0 +1,97 @@ +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ +# --strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir18_casia +# +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning 
run model \ +# --strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir18_vgg2 +# +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ +# --strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir18_webface4m +# +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ +# --strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir50_casia +# +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ +# --strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir50_webface4m +# +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ +# --strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir50_ms1mv2 +# +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ +# --strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir101_ms1mv2 +# +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ +# --strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir101_ms1mv3 +# +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ +# 
--strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/arcface_ir101_webface4m +# +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ +# --strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir101_webface4m +# +#LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ +# --strategy=ddp \ +# --devices=4 \ +# --main_port=9999 \ +# --precision="32-true" \ +# eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir101_webface12m + +LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ + --strategy=ddp \ + --devices=4 \ + --main_port=9999 \ + --precision="32-true" \ + eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_vit_base_webface4m + +LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ + --strategy=ddp \ + --devices=4 \ + --main_port=9999 \ + --precision="32-true" \ + eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_vit_base_kprpe_webface4m + +LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3 lightning run model \ + --strategy=ddp \ + --devices=4 \ + --main_port=9999 \ + --precision="32-true" \ + eval.py --num_gpu 4 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_vit_base_kprpe_webface12m diff --git a/cvlface/research/recognition/code/run_v1/scripts/eval/run_single_gpu.sh b/cvlface/research/recognition/code/run_v1/scripts/eval/run_single_gpu.sh new file mode 100644 index 0000000000000000000000000000000000000000..f341436c3b33d7812a8fb3a62b8ec8fe1ae6f29d --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/eval/run_single_gpu.sh @@ 
-0,0 +1,14 @@ +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir18_casia +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir18_vgg2 +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir18_webface4m +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir50_casia +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir50_webface4m +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir50_ms1mv2 +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir101_ms1mv2 +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir101_ms1mv3 +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/arcface_ir101_webface4m +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir101_webface4m +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_ir101_webface12m +#python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_vit_base_webface4m +python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_vit_base_kprpe_webface4m +python eval.py --num_gpu 1 --eval_config_name full --ckpt_dir ../../../../pretrained_models/recognition/adaface_vit_base_kprpe_webface12m diff --git a/cvlface/research/recognition/code/run_v1/scripts/examples/run_adaface_ir101_webface4m.sh 
b/cvlface/research/recognition/code/run_v1/scripts/examples/run_adaface_ir101_webface4m.sh new file mode 100644 index 0000000000000000000000000000000000000000..c326b4bf4877763e62460cd58fff869c255bafb4 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/examples/run_adaface_ir101_webface4m.sh @@ -0,0 +1,21 @@ + +LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 lightning run model \ + --strategy=ddp \ + --devices=7 \ + --precision="32-true" \ + train.py trainers.prefix=ir101_WF4M_adaface \ + trainers.num_gpu=7 \ + trainers.batch_size=256 \ + trainers.gradient_acc=1 \ + trainers.num_workers=8 \ + trainers.precision='32-true' \ + trainers.float32_matmul_precision='high' \ + dataset=configs/webface4m.yaml \ + data_augs=configs/basic_v1.yaml \ + models=iresnet/configs/v1_ir101.yaml \ + pipelines=configs/train_model_cls.yaml \ + evaluations=configs/full.yaml \ + classifiers=configs/fc.yaml \ + optims=configs/step_sgd.yaml \ + losses=configs/adaface.yaml \ + trainers.skip_final_eval=False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/scripts/examples/run_adaface_ir50_casia.sh b/cvlface/research/recognition/code/run_v1/scripts/examples/run_adaface_ir50_casia.sh new file mode 100644 index 0000000000000000000000000000000000000000..0a5ca53c804b2a972e2a8678cc2f86ac7cd652c6 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/examples/run_adaface_ir50_casia.sh @@ -0,0 +1,21 @@ + +LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 lightning run model \ + --strategy=ddp \ + --devices=7 \ + --precision="32-true" \ + train.py trainers.prefix=ir50_casia_adaface \ + trainers.num_gpu=7 \ + trainers.batch_size=256 \ + trainers.gradient_acc=1 \ + trainers.num_workers=8 \ + trainers.precision='32-true' \ + trainers.float32_matmul_precision='high' \ + dataset=configs/casia.yaml \ + data_augs=configs/basic_v1.yaml \ + models=iresnet/configs/v1_ir50.yaml \ + pipelines=configs/train_model_cls.yaml \ + 
evaluations=configs/full.yaml \ + classifiers=configs/fc.yaml \ + optims=configs/step_sgd.yaml \ + losses=configs/adaface.yaml \ + trainers.skip_final_eval=False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/scripts/examples/run_swin_s_agedb.sh b/cvlface/research/recognition/code/run_v1/scripts/examples/run_swin_s_agedb.sh new file mode 100644 index 0000000000000000000000000000000000000000..20a133301924065d8796e822e63001394e3de2d8 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/examples/run_swin_s_agedb.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../../../../../../../" && pwd)" +RUN_DIR="$REPO_ROOT/cvlface/research/recognition/code/run_v1" + +export AGEDB_PROTOCOL_ROOT="$REPO_ROOT/data/agedb_protocol" +export SWIN_S_PRETRAINED="${SWIN_S_PRETRAINED:-$REPO_ROOT/cvlface/pretrained_models/model.safetensors}" + +cd "$RUN_DIR" + +python scripts/prepare_agedb_protocol.py \ + --source "$REPO_ROOT/AgeDB_aligned_224" \ + --output "$AGEDB_PROTOCOL_ROOT" \ + --train-ratio 0.8 \ + --seed 2048 \ + --max-pairs 3000 + +python train.py \ + trainers=configs/agedb_swin.yaml \ + dataset=configs/agedb_80.yaml \ + data_augs=configs/basic_v1.yaml \ + models=swin/configs/v1_swin_s_pretrained.yaml \ + pipelines=configs/train_model_cls.yaml \ + evaluations=configs/agedb_30_1to1.yaml \ + classifiers=configs/fc.yaml \ + optims=configs/plateau_adamw_agedb.yaml \ + losses=configs/adaface.yaml \ + aligners=configs/none.yaml \ + pefts=configs/none.yaml diff --git a/cvlface/research/recognition/code/run_v1/scripts/examples/run_vit_kprpe_webface4m.sh b/cvlface/research/recognition/code/run_v1/scripts/examples/run_vit_kprpe_webface4m.sh new file mode 100644 index 0000000000000000000000000000000000000000..adc2e54435a60961cda5f2bd6f9816e77f55422f --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/examples/run_vit_kprpe_webface4m.sh @@ -0,0 +1,22 @@ + 
+LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 lightning run model \ + --strategy=ddp \ + --devices=7 \ + --precision="32-true" \ + train.py trainers.prefix=vit_base_KPRPE_WF4M_adaface \ + trainers.num_gpu=7 \ + trainers.batch_size=256 \ + trainers.gradient_acc=1 \ + trainers.num_workers=8 \ + trainers.precision='32-true' \ + trainers.float32_matmul_precision='high' \ + dataset=configs/webface4m_ldmktheta_RA10.yaml \ + data_augs=configs/gridsample_v1.yaml \ + models=vit_kprpe/configs/v1_base_kprpe_splithead_unshared.yaml \ + pipelines=configs/train_keypoint_model_cls.yaml \ + aligners=config/dfa \ + evaluations=configs/full.yaml \ + classifiers=configs/fc.yaml \ + optims=configs/step_sgd.yaml \ + losses=configs/adaface.yaml \ + trainers.skip_final_eval=False \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/scripts/examples/run_vit_webface4m.sh b/cvlface/research/recognition/code/run_v1/scripts/examples/run_vit_webface4m.sh new file mode 100644 index 0000000000000000000000000000000000000000..e90f2d7f5e95bb800366cacb0257adaf58dc1097 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/examples/run_vit_webface4m.sh @@ -0,0 +1,21 @@ + +LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 lightning run model \ + --strategy=ddp \ + --devices=7 \ + --precision="32-true" \ + train.py trainers.prefix=vit_base_WF4M_adaface \ + trainers.num_gpu=7 \ + trainers.batch_size=256 \ + trainers.gradient_acc=1 \ + trainers.num_workers=8 \ + trainers.precision='32-true' \ + trainers.float32_matmul_precision='high' \ + dataset=configs/webface4m_ldmktheta_RA10.yaml \ + data_augs=configs/gridsample_v1.yaml \ + models=vit/configs/v1_base.yaml \ + pipelines=configs/train_model_cls.yaml \ + evaluations=configs/full.yaml \ + classifiers=configs/fc.yaml \ + optims=configs/step_sgd.yaml \ + losses=configs/adaface.yaml \ + trainers.skip_final_eval=False \ No newline at end of file diff --git 
"""Generate the Plan8 experiment report.

Script: cvlface/research/recognition/code/run_v1/scripts/generate_plan8_report.py

Reads the AgeDB-567 LoRA epoch metrics / training config and the WebFace
fine-tune training log, then writes three PNG plots, ``plan8_results.json``
and ``plan8_report.md`` into ``REPORT_DIR``.
"""
import csv
import json
from pathlib import Path

import matplotlib

# The non-interactive backend has to be selected before pyplot is imported.
matplotlib.use("Agg")
import matplotlib.pyplot as plt


REPORT_DIR = Path("/root/Lab1/experiments/plan8_report")
AGEDB_METRICS = Path("/root/Lab1/experiments/agedb567_lora_r4_epochlog/logs/epoch_metrics.csv")
AGEDB_CONFIG = Path("/root/Lab1/experiments/agedb567_lora_r4_epochlog/logs/training_config.json")
# Kept for reference; the WebFace numbers in the report are pinned literals
# in main(), so this file is not read by the script.
WEBFACE_SUMMARY = Path("/root/Lab1/experiments/finetune_backbone_1epoch_bs128_amp/summary.json")
WEBFACE_LOG = Path("/root/Lab1/experiments/finetune_backbone_1epoch_bs128_amp/train_log.csv")


def read_csv(path):
    """Return every row of the CSV file at ``path`` as a list of dicts."""
    with path.open(newline="", encoding="utf-8") as handle:
        return list(csv.DictReader(handle))


def save_line_plot(path, x, y, xlabel, ylabel, title, marker=None):
    """Draw a single line chart of ``y`` over ``x`` and save it to ``path``."""
    plt.figure(figsize=(8, 4.5))
    try:
        plt.plot(x, y, marker=marker, linewidth=1.8)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.title(title)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(path, dpi=160)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close()


def main():
    """Write the plots, the JSON summary and the Markdown report.

    Raises:
        ValueError: if the epoch metrics file has no rows, or none of its
            rows carries a ``fgnet_age30_acc`` value.
    """
    REPORT_DIR.mkdir(parents=True, exist_ok=True)

    agedb_rows = read_csv(AGEDB_METRICS)
    agedb_config = json.loads(AGEDB_CONFIG.read_text(encoding="utf-8"))
    webface_rows = read_csv(WEBFACE_LOG)

    if not agedb_rows:
        raise ValueError(f"No epoch metrics found in {AGEDB_METRICS}")

    epochs = [int(row["epoch"]) for row in agedb_rows]
    losses = [float(row["train_loss"]) for row in agedb_rows]
    # Only epochs with a recorded FGNET accuracy take part in the accuracy curve.
    evaluated_rows = [row for row in agedb_rows if row.get("fgnet_age30_acc")]
    if not evaluated_rows:
        raise ValueError(f"No fgnet_age30_acc values found in {AGEDB_METRICS}")
    fgnet_epochs = [int(row["epoch"]) for row in evaluated_rows]
    fgnet_acc = [float(row["fgnet_age30_acc"]) for row in evaluated_rows]
    # First index holding the maximum, same tie-breaking as list.index(max(...)).
    best_index = max(range(len(fgnet_acc)), key=fgnet_acc.__getitem__)

    loss_plot = REPORT_DIR / "agedb_lora_epoch_loss.png"
    acc_plot = REPORT_DIR / "agedb_lora_fgnet_acc.png"
    web_loss_plot = REPORT_DIR / "webface_full_1epoch_loss.png"
    save_line_plot(loss_plot, epochs, losses, "Epoch", "Train loss", "AgeDB-567 LoRA Training Loss", marker="o")
    save_line_plot(acc_plot, fgnet_epochs, fgnet_acc, "Epoch", "FGNET AGE-30 accuracy (%)", "AgeDB-567 LoRA Validation Accuracy", marker="o")
    save_line_plot(
        web_loss_plot,
        [int(row["step"]) for row in webface_rows],
        [float(row["loss"]) for row in webface_rows],
        "Step",
        "Train loss",
        "WebFace Full Fine-tune 1 Epoch Loss",
    )

    results = {
        "webface_full_1epoch": {
            "checkpoint": "/root/Lab1/experiments/finetune_backbone_1epoch_bs128_amp/checkpoints/best.pt",
            "method": "Full backbone fine-tune, not LoRA",
            "train_dataset": "WebFace-224",
            "train_images": 490623,
            "classes": 10572,
            "epochs": 1,
            "batch_size": 128,
            "fgnet_age30_acc": 61.8118572292801,
            "fgnet_age30_std": 4.725344134795952,
            "agedb_30_acc": 77.45,
            "agedb_30_std": 1.8679311193581702,
        },
        "agedb567_lora_r4": {
            "checkpoint_latest": "/root/Lab1/experiments/agedb567_lora_r4_epochlog/checkpoints/latest.pt",
            "checkpoint_best": "/root/Lab1/experiments/agedb567_lora_r4_epochlog/checkpoints/best.pt",
            "method": "LoRA qkv r=4 alpha=4 + AdaFace head",
            "train_dataset": "AgeDB-567 224 by identity",
            "train_images": agedb_config["train_images"],
            "classes": agedb_config["train_classes"],
            "epochs": agedb_config["epochs"],
            "batch_size": agedb_config["batch_size"],
            "fgnet_age30_best_acc": fgnet_acc[best_index],
            "fgnet_age30_best_epoch": fgnet_epochs[best_index],
            "fgnet_age30_final_acc": fgnet_acc[-1],
            "agedb_30_best_checkpoint_acc": 65.68333333333334,
            "agedb_30_best_checkpoint_std": 1.4915502747886933,
            "agedb_30_latest_checkpoint_acc": 68.0,
            "agedb_30_latest_checkpoint_std": 1.4719601443879762,
        },
    }
    (REPORT_DIR / "plan8_results.json").write_text(json.dumps(results, indent=2), encoding="utf-8")

    metrics_table = "\n".join(
        f"| {row['epoch']} | {row['train_loss']} | {row.get('fgnet_age30_acc') or '-'} |"
        for row in agedb_rows
    )

    # Short aliases keep the report template readable; the rendered text is
    # the same as indexing ``results`` directly.
    webface = results["webface_full_1epoch"]
    lora = results["agedb567_lora_r4"]

    report = f"""# Plan8 实验报告

生成时间:2026-04-16

## 数据与协议

| 项目 | 路径 | 说明 |
|---|---|---|
| WebFace-224 | `/root/Lab1/data/processed/webface_224` | 490,623 张,10,572 类 |
| AgeDB-567 train | `/root/Lab1/data/processed/agedb_harmonliu05_224_by_identity` | 16,488 张,567 类 |
| FGNET AGE-30 | `/root/Lab1/data/processed/fgnet_age30_protocol/facerec_val/fgnet_age30/pairs.csv` | 576 pairs,正负各 288 |
| AgeDB-30 | `/root/Lab1/data/processed/kaggle_verification_protocol/facerec_val/agedb_30/pairs.csv` | Kaggle verification protocol |

## 记录 1:WebFace 全量微调 1 epoch

| 字段 | 值 |
|---|---|
| checkpoint | `/root/Lab1/experiments/finetune_backbone_1epoch_bs128_amp/checkpoints/best.pt` |
| 微调方式 | Full backbone fine-tune,不是 LoRA |
| 训练集 | WebFace-224 |
| 训练图片/类别 | 490,623 / 10,572 |
| batch size | 128 |
| 训练 epoch | 1 |
| FGNET AGE-30 acc | {webface['fgnet_age30_acc']:.4f} ± {webface['fgnet_age30_std']:.4f} |
| AgeDB-30 acc | {webface['agedb_30_acc']:.4f} ± {webface['agedb_30_std']:.4f} |

![WebFace loss](webface_full_1epoch_loss.png)

## 记录 2:AgeDB-567 LoRA 训练 20 epoch

| 字段 | 值 |
|---|---|
| output dir | `/root/Lab1/experiments/agedb567_lora_r4_epochlog` |
| latest checkpoint | `/root/Lab1/experiments/agedb567_lora_r4_epochlog/checkpoints/latest.pt` |
| best checkpoint | `/root/Lab1/experiments/agedb567_lora_r4_epochlog/checkpoints/best.pt` |
| 微调方式 | LoRA qkv, r=4, alpha=4, dropout=0.1 + AdaFace head |
| 训练集 | AgeDB-567 224 by identity |
| 训练图片/类别 | {agedb_config['train_images']} / {agedb_config['train_classes']} |
| batch size | {agedb_config['batch_size']} |
| epoch | {agedb_config['epochs']} |
| eval every | {agedb_config['eval_every_epochs']} epochs |
| FGNET AGE-30 best acc | {lora['fgnet_age30_best_acc']:.4f} @ epoch {lora['fgnet_age30_best_epoch']} |
| FGNET AGE-30 final acc | {lora['fgnet_age30_final_acc']:.4f} @ epoch {fgnet_epochs[-1]} |
| AgeDB-30 best checkpoint acc | {lora['agedb_30_best_checkpoint_acc']:.4f} ± {lora['agedb_30_best_checkpoint_std']:.4f} |
| AgeDB-30 latest checkpoint acc | {lora['agedb_30_latest_checkpoint_acc']:.4f} ± {lora['agedb_30_latest_checkpoint_std']:.4f} |

![AgeDB LoRA loss](agedb_lora_epoch_loss.png)

![AgeDB LoRA FGNET acc](agedb_lora_fgnet_acc.png)

### Epoch 指标

| epoch | train loss | FGNET AGE-30 acc |
|---:|---:|---:|
{metrics_table}

## 简要结论

1. WebFace 全量微调 1 epoch 在 FGNET AGE-30 上为 {webface['fgnet_age30_acc']:.2f}%,在 AgeDB-30 上为 {webface['agedb_30_acc']:.2f}%。
2. AgeDB-567 LoRA 训练 loss 从 {losses[0]:.4f} 下降到 {losses[-1]:.4f},但 FGNET AGE-30 accuracy 没有随 epoch 稳定提升,最好在 epoch {lora['fgnet_age30_best_epoch']},为 {lora['fgnet_age30_best_acc']:.2f}%。
3. AgeDB-567 LoRA 的 latest checkpoint 在 AgeDB-30 上为 {lora['agedb_30_latest_checkpoint_acc']:.2f}%,高于 best checkpoint 的 {lora['agedb_30_best_checkpoint_acc']:.2f}%。这里的 best 是按 FGNET 指标保存的,不是按 AgeDB-30 保存的。

## 产物

| 文件 | 说明 |
|---|---|
| `plan8_report.md` | 本报告 |
| `plan8_results.json` | 结构化结果 |
| `agedb_lora_epoch_loss.png` | AgeDB-LoRA epoch loss 曲线 |
| `agedb_lora_fgnet_acc.png` | AgeDB-LoRA FGNET AGE-30 acc 曲线 |
| `webface_full_1epoch_loss.png` | WebFace full fine-tune step loss 曲线 |
"""
    (REPORT_DIR / "plan8_report.md").write_text(report, encoding="utf-8")
    print(REPORT_DIR / "plan8_report.md")


if __name__ == "__main__":
    main()

# Script: cvlface/research/recognition/code/run_v1/scripts/mock/run.sh
# Two launch commands for train.py with the synthetic dataset config, both
# limited to 128 batches (trainers.limit_num_batch=128).
# NOTE(review): presumably a quick smoke test of the training pipeline — confirm.

# 1 gpu
python train.py trainers.prefix=test_run \
    trainers.num_gpu=1 \
    trainers.batch_size=32 \
    trainers.limit_num_batch=128 \
    trainers.gradient_acc=1 \
    trainers.num_workers=8 \
    trainers.precision='32-true' \
    trainers.float32_matmul_precision='high' \
    dataset=configs/synthetic.yaml \
    data_augs=configs/basic_v1.yaml \
    models=iresnet/configs/v1_ir50.yaml \
    pipelines=configs/train_model_cls.yaml \
    evaluations=configs/base.yaml \
    classifiers=configs/fc.yaml \
    optims=configs/step_sgd.yaml \
    losses=configs/cosface.yaml

# multi gpu
# Same configuration launched through `lightning run model` with DDP on
# seven devices (GPUs 0-6).
# NOTE(review): the variable is spelled LIGHTING_TESTING here — confirm this
# is the name the training code actually reads.
LIGHTING_TESTING=1 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6 lightning run model \
    --strategy=ddp \
    --devices=7 \
    --precision="32-true" \
    train.py trainers.prefix=test_run \
    trainers.num_gpu=7 \
    trainers.batch_size=32 \
    trainers.limit_num_batch=128 \
    trainers.gradient_acc=1 \
    trainers.num_workers=8 \
    trainers.precision='32-true' \
    trainers.float32_matmul_precision='high' \
    dataset=configs/synthetic.yaml \
    data_augs=configs/basic_v1.yaml \
    models=iresnet/configs/v1_ir50.yaml \
    pipelines=configs/train_model_cls.yaml \
    evaluations=configs/base.yaml \
    classifiers=configs/fc.yaml \
    optims=configs/step_sgd.yaml \
    losses=configs/cosface.yaml
"""Split AgeDB into a training folder and an AgeDB-30 style verification set.

Script: cvlface/research/recognition/code/run_v1/scripts/prepare_agedb_protocol.py

Images are grouped by identity, split per identity into train/test, the
train part is linked (or copied) into an identity-per-folder layout, and a
``pairs.csv`` with positive (age gap >= 30) and negative pairs is written
for the test part.
"""
import argparse
import csv
import os
import random
import shutil
from collections import defaultdict
from pathlib import Path


def parse_agedb_name(path):
    """Split an AgeDB file name into ``(identity, age, gender)``.

    The stem is read as ``<prefix>_<identity>_<age>_<gender>``; the identity
    itself may contain underscores.

    Raises:
        ValueError: if the stem does not follow that layout.
    """
    try:
        left, age, gender = path.stem.rsplit('_', 2)
        _, identity = left.split('_', 1)
        return identity, int(age), gender
    except ValueError as err:
        raise ValueError(f'Unexpected AgeDB filename: {path.name}') from err


def link_or_copy(src, dst, copy_files):
    """Materialise ``src`` at ``dst`` as a copy or a relative symlink.

    An existing ``dst`` (including a dangling symlink) is left untouched.
    """
    dst.parent.mkdir(parents=True, exist_ok=True)
    if dst.exists() or dst.is_symlink():
        return
    if copy_files:
        shutil.copy2(src, dst)
    else:
        os.symlink(os.path.relpath(src, dst.parent), dst)


def build_positive_pairs(test_rows, max_pairs):
    """Return up to ``max_pairs`` same-identity pairs with an age gap >= 30.

    Pairs are ranked by descending age gap, ties broken by file name, so the
    result is deterministic.
    """
    by_identity = defaultdict(list)
    for row in test_rows:
        by_identity[row['identity']].append(row)

    candidates = []
    for rows in by_identity.values():
        rows = sorted(rows, key=lambda x: (x['age'], x['path'].name))
        for i, left in enumerate(rows):
            for right in rows[i + 1:]:
                age_gap = abs(left['age'] - right['age'])
                if age_gap >= 30:
                    candidates.append((age_gap, left, right))
    candidates.sort(key=lambda x: (-x[0], x[1]['path'].name, x[2]['path'].name))
    return [(left, right) for _, left, right in candidates[:max_pairs]]


def build_negative_pairs(test_rows, count, seed):
    """Randomly draw ``count`` distinct pairs of images from different identities.

    Raises:
        RuntimeError: if fewer than two identities are available, or if
            ``count`` distinct pairs could not be drawn within
            ``count * 100`` attempts.
    """
    rng = random.Random(seed)
    rows = sorted(test_rows, key=lambda x: x['path'].name)
    by_identity = defaultdict(list)
    for row in rows:
        by_identity[row['identity']].append(row)

    identities = sorted(by_identity)
    if count > 0 and len(identities) < 2:
        # random.sample would otherwise fail with an unrelated-looking error.
        raise RuntimeError(
            f'Need at least 2 identities to build negative pairs, got {len(identities)}'
        )

    pairs = []
    used = set()
    attempts = 0
    while len(pairs) < count and attempts < count * 100:
        attempts += 1
        left_id, right_id = rng.sample(identities, 2)
        left = rng.choice(by_identity[left_id])
        right = rng.choice(by_identity[right_id])
        # Order-independent key so (a, b) and (b, a) count as the same pair.
        key = tuple(sorted([left['path'].name, right['path'].name]))
        if key in used:
            continue
        used.add(key)
        pairs.append((left, right))
    if len(pairs) < count:
        raise RuntimeError(f'Only built {len(pairs)} negative pairs, expected {count}')
    return pairs


def main():
    """Parse the command line, build the split and write ``pairs.csv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--source', default='AgeDB_aligned_224')
    parser.add_argument('--output', default='data/agedb_protocol')
    parser.add_argument('--train-ratio', type=float, default=0.8)
    parser.add_argument('--seed', type=int, default=2048)
    parser.add_argument('--max-pairs', type=int, default=3000)
    parser.add_argument('--copy-files', action='store_true')
    args = parser.parse_args()

    source = Path(args.source).resolve()
    output = Path(args.output).resolve()
    if not source.is_dir():
        # Fail before creating any output instead of reporting "no pairs" later.
        raise NotADirectoryError(source)
    train_dir = output / 'agedb_train_80'
    val_dir = output / 'facerec_val' / 'agedb_30_1to1'
    val_dir.mkdir(parents=True, exist_ok=True)

    by_identity = defaultdict(list)
    for path in sorted(source.glob('*.jpg')):
        identity, age, gender = parse_agedb_name(path)
        by_identity[identity].append({'path': path, 'identity': identity, 'age': age, 'gender': gender})

    rng = random.Random(args.seed)
    train_rows = []
    test_rows = []
    for identity, rows in sorted(by_identity.items()):
        rows = sorted(rows, key=lambda x: x['path'].name)
        indices = list(range(len(rows)))
        rng.shuffle(indices)
        # Identities with a single image go entirely to the training split.
        if len(rows) > 1:
            n_test = max(1, int(round(len(rows) * (1 - args.train_ratio))))
        else:
            n_test = 0
        test_indices = set(indices[:n_test])
        for idx, row in enumerate(rows):
            if idx in test_indices:
                test_rows.append(row)
            else:
                train_rows.append(row)

    for row in train_rows:
        dst = train_dir / row['identity'] / row['path'].name
        link_or_copy(row['path'], dst, args.copy_files)

    positive_pairs = build_positive_pairs(test_rows, args.max_pairs)
    if not positive_pairs:
        raise RuntimeError('No AgeDB-30 positive pairs found in the test split')
    negative_pairs = build_negative_pairs(test_rows, len(positive_pairs), args.seed)

    # Each pair occupies two consecutive rows: index 2k (left) and 2k + 1 (right).
    pair_rows = []
    pair_index = 0
    for is_same, pairs in [(True, positive_pairs), (False, negative_pairs)]:
        for left, right in pairs:
            pair_rows.append({'path': str(left['path']), 'index': pair_index * 2, 'is_same': is_same})
            pair_rows.append({'path': str(right['path']), 'index': pair_index * 2 + 1, 'is_same': is_same})
            pair_index += 1

    with open(val_dir / 'pairs.csv', 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['path', 'index', 'is_same'])
        writer.writeheader()
        writer.writerows(pair_rows)

    print(f'identities: {len(by_identity)}')
    print(f'train images: {len(train_rows)}')
    print(f'test images: {len(test_rows)}')
    print(f'verification pairs: {len(pair_rows) // 2} ({len(positive_pairs)} positive, {len(negative_pairs)} negative)')
    print(f'train_dir: {train_dir}')
    print(f'val_dir: {val_dir}')


if __name__ == '__main__':
    main()
"""Align FG-NET images and build the AGE-30 verification protocol.

Script: cvlface/research/recognition/code/run_v1/scripts/prepare_fgnet_age30_protocol.py

Every image under ``<fgnet-root>/images`` is aligned with its 68-point
landmark file from ``<fgnet-root>/points``; then ``pairs.csv``,
``manifest.csv`` and ``summary.csv`` are written below
``<output-root>/facerec_val/fgnet_age30``.
"""
import argparse
import csv
import random
import re
from dataclasses import dataclass
from itertools import combinations
from pathlib import Path

import numpy as np


# File stems look like "001A43a": three-digit subject, "A", age, optional letter.
# Only the "subject" and "age" groups are read by the code below.
IMAGE_PATTERN = re.compile(r"^(?P<subject>\d{3})[Aa](?P<age>\d{2,3})(?P<suffix>[A-Za-z]?)$")
# Five reference points for a 112x112 crop; scaled to the requested size in align_image.
ARCFACE_REFERENCE_112 = np.array(
    [
        [38.2946, 51.6963],
        [73.5318, 51.5014],
        [56.0252, 71.7366],
        [41.5493, 92.3655],
        [70.7299, 92.2041],
    ],
    dtype=np.float32,
)


@dataclass(frozen=True)
class FgnetImage:
    """One FG-NET image: subject id, age, original path and aligned path."""

    subject: str
    age: int
    source_path: Path
    aligned_path: Path

    @property
    def stem(self):
        """File stem of the original image."""
        return self.source_path.stem


@dataclass(frozen=True)
class VerificationPair:
    """Two images plus the same-identity label."""

    left: FgnetImage
    right: FgnetImage
    is_same: bool


def parse_fgnet_filename(path):
    """Parse subject and age from an FG-NET file name.

    The returned record uses ``path`` for both the source and aligned path.

    Raises:
        ValueError: if the stem does not match ``IMAGE_PATTERN``.
    """
    match = IMAGE_PATTERN.match(path.stem)
    if not match:
        raise ValueError(f"Unexpected FGNET filename: {path.name}")
    return FgnetImage(
        subject=match.group("subject"),
        age=int(match.group("age")),
        source_path=path,
        aligned_path=path,
    )


def read_pts(path):
    """Read the 68 landmark coordinates enclosed in ``{ ... }`` in a .pts file.

    Returns:
        A float32 array of shape (68, 2).

    Raises:
        ValueError: if the file does not contain exactly 68 points.
    """
    points = []
    in_points = False
    with path.open("r", encoding="utf-8", errors="replace") as handle:
        for line in handle:
            line = line.strip()
            if line == "{":
                in_points = True
                continue
            if line == "}":
                break
            if in_points and line:
                x, y = line.split()[:2]
                points.append([float(x), float(y)])
    if len(points) != 68:
        raise ValueError(f"Expected 68 landmarks in {path}, got {len(points)}")
    return np.asarray(points, dtype=np.float32)


def five_point_landmarks(points_68):
    """Reduce 68 landmarks to five points: eye centres, nose, mouth corners.

    Each eye centre is the mean of six landmarks (rows 36-41 and 42-47);
    nose and mouth corners are rows 30, 48 and 54.
    """
    left_eye = points_68[36:42].mean(axis=0)
    right_eye = points_68[42:48].mean(axis=0)
    nose = points_68[30]
    mouth_left = points_68[48]
    mouth_right = points_68[54]
    return np.stack([left_eye, right_eye, nose, mouth_left, mouth_right]).astype(np.float32)


def align_image(image_path, points_path, output_path, image_size):
    """Warp one image onto the scaled reference landmarks and save it.

    Raises:
        ValueError: if the image cannot be read or no transform is found.
        OSError: if the aligned image cannot be written.
    """
    # Imported here because this is the only function that needs OpenCV; the
    # parsing and pair-building helpers stay importable without it.
    import cv2

    image = cv2.imread(str(image_path))
    if image is None:
        raise ValueError(f"Failed to read image: {image_path}")
    landmarks = five_point_landmarks(read_pts(points_path))
    reference = ARCFACE_REFERENCE_112 * (image_size / 112.0)
    transform, _ = cv2.estimateAffinePartial2D(landmarks, reference, method=cv2.LMEDS)
    if transform is None:
        raise ValueError(f"Failed to estimate affine transform for {image_path}")
    aligned = cv2.warpAffine(image, transform, (image_size, image_size), borderValue=0.0)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    # cv2.imwrite reports failure through its return value instead of raising.
    if not cv2.imwrite(str(output_path), aligned):
        raise OSError(f"Failed to write aligned image: {output_path}")


def collect_and_align(fgnet_root, output_dir, image_size):
    """Align every image under ``fgnet_root/images`` and return its records.

    Raises:
        FileNotFoundError: if an image has no matching ``.pts`` file.
    """
    images_dir = fgnet_root / "images"
    points_dir = fgnet_root / "points"
    records = []
    for source_path in sorted(images_dir.iterdir()):
        if source_path.suffix.lower() not in {".jpg", ".jpeg", ".png", ".bmp"}:
            continue
        parsed = parse_fgnet_filename(source_path)
        # Landmark files are looked up by the lower-cased image stem.
        points_path = points_dir / f"{source_path.stem.lower()}.pts"
        if not points_path.is_file():
            raise FileNotFoundError(points_path)
        aligned_path = output_dir / "images" / source_path.name.lower()
        align_image(source_path, points_path, aligned_path, image_size)
        records.append(
            FgnetImage(
                subject=parsed.subject,
                age=parsed.age,
                source_path=source_path,
                aligned_path=aligned_path.resolve(),
            )
        )
    return records


def build_age30_pairs(records, seed=2048):
    """Build a balanced, shuffled list of pairs whose age gap is >= 30.

    All same-subject pairs are kept; an equal number of different-subject
    pairs is sampled with ``seed``.

    Raises:
        ValueError: if there are no positive pairs or too few negative ones.
    """
    positives = []
    negatives = []
    for left, right in combinations(records, 2):
        if abs(left.age - right.age) < 30:
            continue
        if left.subject == right.subject:
            positives.append(VerificationPair(left, right, True))
        else:
            negatives.append(VerificationPair(left, right, False))

    if not positives:
        raise ValueError("No positive AGE-30 pairs found")
    if len(negatives) < len(positives):
        raise ValueError(f"Not enough negative AGE-30 pairs: {len(negatives)} < {len(positives)}")

    rng = random.Random(seed)
    negatives = rng.sample(negatives, len(positives))
    pairs = positives + negatives
    rng.shuffle(pairs)
    return pairs


def write_pairs_csv(pairs, output_dir):
    """Write ``pairs.csv`` (two rows per pair) and return its path."""
    output_dir.mkdir(parents=True, exist_ok=True)
    pairs_csv = output_dir / "pairs.csv"
    with pairs_csv.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=["path", "index", "is_same"])
        writer.writeheader()
        for pair_index, pair in enumerate(pairs):
            # Left image gets the even index, right image the following odd one.
            writer.writerow(
                {
                    "path": str(pair.left.aligned_path.resolve()),
                    "index": pair_index * 2,
                    "is_same": pair.is_same,
                }
            )
            writer.writerow(
                {
                    "path": str(pair.right.aligned_path.resolve()),
                    "index": pair_index * 2 + 1,
                    "is_same": pair.is_same,
                }
            )
    return pairs_csv


def write_manifest(records, pairs, output_dir):
    """Write ``manifest.csv`` (one row per image) and ``summary.csv`` (counts).

    Returns:
        ``(manifest_csv, summary_path)``.
    """
    manifest_csv = output_dir / "manifest.csv"
    with manifest_csv.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=["path", "subject", "age", "source_path"])
        writer.writeheader()
        for record in records:
            writer.writerow(
                {
                    "path": str(record.aligned_path),
                    "subject": record.subject,
                    "age": record.age,
                    "source_path": str(record.source_path.resolve()),
                }
            )
    summary_path = output_dir / "summary.csv"
    same_count = sum(1 for pair in pairs if pair.is_same)
    diff_count = len(pairs) - same_count
    with summary_path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle)
        writer.writerow(["images", len(records)])
        writer.writerow(["pairs", len(pairs)])
        writer.writerow(["same_pairs", same_count])
        writer.writerow(["different_pairs", diff_count])
    return manifest_csv, summary_path


def main():
    """Parse the command line and run alignment plus protocol generation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--fgnet-root", required=True, type=Path)
    parser.add_argument("--output-root", required=True, type=Path)
    parser.add_argument("--image-size", type=int, default=224)
    parser.add_argument("--seed", type=int, default=2048)
    args = parser.parse_args()

    output_dir = args.output_root / "facerec_val" / "fgnet_age30"
    records = collect_and_align(args.fgnet_root, output_dir, args.image_size)
    pairs = build_age30_pairs(records, seed=args.seed)
    pairs_csv = write_pairs_csv(pairs, output_dir)
    manifest_csv, summary_csv = write_manifest(records, pairs, output_dir)
    print(f"images: {len(records)}")
    print(f"pairs: {len(pairs)}")
    print(f"pairs_csv: {pairs_csv}")
    print(f"manifest_csv: {manifest_csv}")
    print(f"summary_csv: {summary_csv}")


if __name__ == "__main__":
    main()
"""Convert Kaggle verification annotation files into ``pairs.csv`` protocols.

Script: cvlface/research/recognition/code/run_v1/scripts/prepare_kaggle_verification_protocol.py

Each ``<name>_ann.txt`` line is ``<label> <left image> <right image>`` with
image paths relative to the raw validation directory. The output CSV points
at the original images by absolute path; nothing is copied.
"""
import argparse
import csv
from pathlib import Path


# Output sub-directory name -> annotation file expected in the raw validation root.
DATASETS = {
    "agedb_30": "agedb_30_ann.txt",
    "calfw": "calfw_ann.txt",
    "cplfw": "cplfw_ann.txt",
    "lfw": "lfw_ann.txt",
    "combined": "combined_ann.txt",
}


def write_pairs_csv(annotation_path, raw_val_dir, output_dir):
    """Write ``output_dir/pairs.csv`` from one annotation file.

    Blank lines are ignored and do not consume an index, so the ``index``
    column always runs 0, 1, 2, ... without gaps.

    Returns:
        ``(pairs_csv, pair_count)``.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields.
        FileNotFoundError: if a referenced image does not exist.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    pair_rows = []
    pair_count = 0
    with annotation_path.open("r", encoding="utf-8") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line:
                continue
            fields = line.split()
            if len(fields) != 3:
                raise ValueError(
                    f"{annotation_path}:{line_number}: expected '<label> <left> <right>', got {line!r}"
                )
            label, left_rel, right_rel = fields
            left_path = raw_val_dir / left_rel
            right_path = raw_val_dir / right_rel
            if not left_path.is_file():
                raise FileNotFoundError(left_path)
            if not right_path.is_file():
                raise FileNotFoundError(right_path)
            is_same = label == "1"
            # Left image gets the even index, right image the following odd one.
            pair_rows.append(
                {
                    "path": str(left_path.resolve()),
                    "index": pair_count * 2,
                    "is_same": is_same,
                }
            )
            pair_rows.append(
                {
                    "path": str(right_path.resolve()),
                    "index": pair_count * 2 + 1,
                    "is_same": is_same,
                }
            )
            pair_count += 1

    pairs_csv = output_dir / "pairs.csv"
    with pairs_csv.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=["path", "index", "is_same"])
        writer.writeheader()
        writer.writerows(pair_rows)
    return pairs_csv, pair_count


def main():
    """Generate a ``pairs.csv`` for every entry in ``DATASETS``.

    Raises:
        NotADirectoryError: if ``--raw-val-root`` is not a directory.
        FileNotFoundError: if an annotation file is missing.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--raw-val-root", required=True)
    parser.add_argument("--output-root", required=True)
    args = parser.parse_args()

    raw_val_dir = Path(args.raw_val_root).resolve()
    output_root = Path(args.output_root).resolve()
    if not raw_val_dir.is_dir():
        raise NotADirectoryError(raw_val_dir)

    for name, ann_name in DATASETS.items():
        annotation_path = raw_val_dir / ann_name
        if not annotation_path.is_file():
            raise FileNotFoundError(annotation_path)
        pairs_csv, pair_count = write_pairs_csv(
            annotation_path=annotation_path,
            raw_val_dir=raw_val_dir,
            output_dir=output_root / "facerec_val" / name,
        )
        print(f"{name}: {pair_count} pairs -> {pairs_csv}")


if __name__ == "__main__":
    main()
"""Resize every image of a folder tree to a square size, keeping its layout.

Script: cvlface/research/recognition/code/run_v1/scripts/resize_image_folder.py
"""
import argparse
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path

from PIL import Image, ImageOps
from tqdm import tqdm


IMAGE_EXTENSIONS = {".bmp", ".jpeg", ".jpg", ".png", ".webp"}


def iter_image_paths(input_root, max_images=None):
    """Yield image files below ``input_root`` in sorted order.

    Args:
        input_root: directory that is searched recursively.
        max_images: optional upper bound on the number of yielded paths;
            ``None`` means no limit, zero or a negative value yields nothing.
    """
    if max_images is not None and max_images <= 0:
        # Without this guard the limit was only checked after the first yield.
        return
    count = 0
    for path in sorted(input_root.rglob("*")):
        if not path.is_file() or path.suffix.lower() not in IMAGE_EXTENSIONS:
            continue
        yield path
        count += 1
        if max_images is not None and count >= max_images:
            return


def output_path_for(input_path, input_root, output_root, output_ext):
    """Map ``input_path`` to its location below ``output_root``.

    The path relative to ``input_root`` is preserved; when ``output_ext`` is
    non-empty it replaces the file suffix.
    """
    relative_path = input_path.relative_to(input_root)
    if output_ext:
        relative_path = relative_path.with_suffix(output_ext)
    return output_root / relative_path


def resize_one(input_path, input_root, output_root, size, output_ext, quality, overwrite):
    """Resize one image to ``size`` x ``size`` RGB and save it.

    Returns:
        ``"skipped"`` if the output exists and ``overwrite`` is false,
        otherwise ``"written"``.
    """
    output_path = output_path_for(input_path, input_root, output_root, output_ext)
    if output_path.exists() and not overwrite:
        return "skipped"

    output_path.parent.mkdir(parents=True, exist_ok=True)
    with Image.open(input_path) as image:
        # Apply the EXIF orientation before resizing.
        image = ImageOps.exif_transpose(image).convert("RGB")
        image = image.resize((size, size), Image.Resampling.BICUBIC)
        save_kwargs = {}
        if output_path.suffix.lower() in {".jpg", ".jpeg"}:
            save_kwargs = {"quality": quality, "optimize": True}
        image.save(output_path, **save_kwargs)
    return "written"


def resize_folder(input_root, output_root, size, workers, output_ext, quality, overwrite, max_images):
    """Resize all images below ``input_root`` into ``output_root`` using a thread pool.

    Returns:
        ``(counts, total)`` where ``counts`` maps ``"written"`` / ``"skipped"``
        to the number of images and ``total`` is the number of inputs.

    Raises:
        NotADirectoryError: if ``input_root`` is not a directory.
        ValueError: if ``output_root`` equals ``input_root``.
    """
    input_root = Path(input_root).resolve()
    output_root = Path(output_root).resolve()
    if not input_root.is_dir():
        raise NotADirectoryError(input_root)
    if input_root == output_root:
        raise ValueError("output_root must be different from input_root")

    # The input list is fixed before anything is written.
    image_paths = list(iter_image_paths(input_root, max_images=max_images))
    output_root.mkdir(parents=True, exist_ok=True)
    counts = {"written": 0, "skipped": 0}

    def work(path):
        return resize_one(
            input_path=path,
            input_root=input_root,
            output_root=output_root,
            size=size,
            output_ext=output_ext,
            quality=quality,
            overwrite=overwrite,
        )

    with ThreadPoolExecutor(max_workers=workers) as executor:
        for status in tqdm(executor.map(work, image_paths), total=len(image_paths), desc="resize"):
            counts[status] += 1

    return counts, len(image_paths)


def parse_args():
    """Parse the command-line options of the script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--input-root", required=True)
    parser.add_argument("--output-root", required=True)
    parser.add_argument("--size", type=int, default=224)
    parser.add_argument("--workers", type=int, default=8)
    parser.add_argument("--output-ext", default=".jpg")
    parser.add_argument("--quality", type=int, default=95)
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument("--max-images", type=int, default=None)
    return parser.parse_args()


def main():
    """Run the resize and print a short summary."""
    args = parse_args()
    counts, total = resize_folder(
        input_root=args.input_root,
        output_root=args.output_root,
        size=args.size,
        workers=args.workers,
        output_ext=args.output_ext,
        quality=args.quality,
        overwrite=args.overwrite,
        max_images=args.max_images,
    )
    print(f"images: {total}")
    print(f"written: {counts['written']}")
    print(f"skipped: {counts['skipped']}")
    print(f"output_root: {Path(args.output_root).resolve()}")


if __name__ == "__main__":
    main()
"""Tests for prepare_fgnet_age30_protocol.py.

Script: cvlface/research/recognition/code/run_v1/scripts/test_prepare_fgnet_age30_protocol.py
"""
import importlib.util
from pathlib import Path

from PIL import Image


# The script under test is loaded from its file path next to this test module.
SCRIPT_PATH = Path(__file__).with_name("prepare_fgnet_age30_protocol.py")
SPEC = importlib.util.spec_from_file_location("prepare_fgnet_age30_protocol", SCRIPT_PATH)
fgnet_script = importlib.util.module_from_spec(SPEC)
SPEC.loader.exec_module(fgnet_script)


def test_parse_fgnet_filename_reads_subject_age_and_suffix():
    """Subject and age are parsed from the stem; the stem itself is kept."""
    item = fgnet_script.parse_fgnet_filename(Path("001A43a.JPG"))

    assert item.subject == "001"
    assert item.age == 43
    assert item.stem == "001A43a"


def test_build_age30_pairs_balances_positive_and_negative_pairs():
    """Two subjects with a >= 30 year gap each give 2 positive and 2 negative pairs."""
    records = [
        fgnet_script.FgnetImage("001", 2, Path("001A02.JPG"), Path("a.jpg")),
        fgnet_script.FgnetImage("001", 35, Path("001A35.JPG"), Path("b.jpg")),
        fgnet_script.FgnetImage("002", 1, Path("002A01.JPG"), Path("c.jpg")),
        fgnet_script.FgnetImage("002", 40, Path("002A40.JPG"), Path("d.jpg")),
    ]

    pairs = fgnet_script.build_age30_pairs(records, seed=7)

    assert len([pair for pair in pairs if pair.is_same]) == 2
    assert len([pair for pair in pairs if not pair.is_same]) == 2
    assert all(abs(pair.left.age - pair.right.age) >= 30 for pair in pairs)


def test_write_pairs_csv_uses_absolute_paths_and_alternating_indexes(tmp_path):
    """pairs.csv has the expected header, 0/1 indexes and absolute paths.

    Uses pytest's ``tmp_path`` so nothing is written next to the test file.
    """
    image = tmp_path / "img.jpg"
    Image.new("RGB", (8, 8)).save(image)
    records = [
        fgnet_script.FgnetImage("001", 2, image, image),
        fgnet_script.FgnetImage("001", 35, image, image),
        fgnet_script.FgnetImage("002", 1, image, image),
        fgnet_script.FgnetImage("002", 40, image, image),
    ]
    pairs = fgnet_script.build_age30_pairs(records, seed=7)

    pairs_csv = fgnet_script.write_pairs_csv(pairs, tmp_path / "out")

    rows = pairs_csv.read_text(encoding="utf-8").splitlines()
    assert rows[0] == "path,index,is_same"
    assert rows[1].split(",")[1:] == ["0", "True"]
    assert rows[2].split(",")[1:] == ["1", "True"]
    assert Path(rows[1].split(",")[0]).is_absolute()
"""Tests for prepare_kaggle_verification_protocol.py.

Script: cvlface/research/recognition/code/run_v1/scripts/test_prepare_kaggle_verification_protocol.py
"""
import csv
import importlib.util
from pathlib import Path


# The script under test is loaded from its file path next to this test module.
SCRIPT_PATH = Path(__file__).with_name("prepare_kaggle_verification_protocol.py")
SPEC = importlib.util.spec_from_file_location("prepare_kaggle_verification_protocol", SCRIPT_PATH)
prepare_protocol = importlib.util.module_from_spec(SPEC)
SPEC.loader.exec_module(prepare_protocol)


def test_write_pairs_csv_uses_absolute_source_paths_without_copying_images(tmp_path):
    """pairs.csv references the source images by absolute path and is the only output.

    Uses pytest's ``tmp_path`` so nothing is written next to the test file.
    """
    raw_val_dir = tmp_path / "raw_val"
    output_dir = tmp_path / "protocol" / "facerec_val" / "lfw"
    image_paths = [
        raw_val_dir / "Alice" / "Alice_0001.jpg",
        raw_val_dir / "Alice" / "Alice_0002.jpg",
        raw_val_dir / "Bob" / "Bob_0001.jpg",
        raw_val_dir / "Carol" / "Carol_0001.jpg",
    ]
    for image_path in image_paths:
        image_path.parent.mkdir(parents=True, exist_ok=True)
        # Only the existence of the file is checked, not its content.
        image_path.write_bytes(b"not-a-real-image")

    annotation_path = raw_val_dir / "lfw_ann.txt"
    annotation_path.write_text(
        "\n".join(
            [
                "1 Alice/Alice_0001.jpg Alice/Alice_0002.jpg",
                "0 Bob/Bob_0001.jpg Carol/Carol_0001.jpg",
            ]
        ),
        encoding="utf-8",
    )

    pairs_csv, pair_count = prepare_protocol.write_pairs_csv(
        annotation_path=annotation_path,
        raw_val_dir=raw_val_dir,
        output_dir=output_dir,
    )

    assert pair_count == 2
    assert pairs_csv == output_dir / "pairs.csv"
    assert sorted(path.name for path in output_dir.iterdir()) == ["pairs.csv"]

    with pairs_csv.open(newline="", encoding="utf-8") as handle:
        rows = list(csv.DictReader(handle))

    assert rows == [
        {"path": str(image_paths[0].resolve()), "index": "0", "is_same": "True"},
        {"path": str(image_paths[1].resolve()), "index": "1", "is_same": "True"},
        {"path": str(image_paths[2].resolve()), "index": "2", "is_same": "False"},
        {"path": str(image_paths[3].resolve()), "index": "3", "is_same": "False"},
    ]


def test_write_pairs_csv_rejects_missing_source_image(tmp_path):
    """A pair that references a missing image raises FileNotFoundError naming it."""
    raw_val_dir = tmp_path / "raw_val"
    raw_val_dir.mkdir()
    existing_image = raw_val_dir / "Alice" / "Alice_0001.jpg"
    existing_image.parent.mkdir()
    existing_image.write_bytes(b"not-a-real-image")
    annotation_path = raw_val_dir / "lfw_ann.txt"
    annotation_path.write_text(
        "1 Alice/Alice_0001.jpg Alice/Alice_0002.jpg\n",
        encoding="utf-8",
    )

    try:
        prepare_protocol.write_pairs_csv(
            annotation_path=annotation_path,
            raw_val_dir=raw_val_dir,
            output_dir=tmp_path / "protocol",
        )
    except FileNotFoundError as exc:
        assert str(raw_val_dir / "Alice" / "Alice_0002.jpg") in str(exc)
    else:
        raise AssertionError("Expected FileNotFoundError for missing pair image")
"""Tests for resize_image_folder.py.

Script: cvlface/research/recognition/code/run_v1/scripts/test_resize_image_folder.py
"""
import importlib.util
from pathlib import Path

from PIL import Image


# The script under test is loaded from its file path next to this test module.
SCRIPT_PATH = Path(__file__).with_name("resize_image_folder.py")
SPEC = importlib.util.spec_from_file_location("resize_image_folder", SCRIPT_PATH)
resize_script = importlib.util.module_from_spec(SPEC)
SPEC.loader.exec_module(resize_script)


def test_resize_folder_preserves_relative_class_structure(tmp_path):
    """A PNG in ``id_1/`` becomes a 224x224 RGB JPEG in the same sub-folder.

    Uses pytest's ``tmp_path`` so nothing is written next to the test file.
    """
    input_root = tmp_path / "input"
    output_root = tmp_path / "output"
    source_image = input_root / "id_1" / "sample.png"
    source_image.parent.mkdir(parents=True)
    Image.new("RGB", (112, 112), color=(10, 20, 30)).save(source_image)

    counts, total = resize_script.resize_folder(
        input_root=input_root,
        output_root=output_root,
        size=224,
        workers=1,
        output_ext=".jpg",
        quality=90,
        overwrite=False,
        max_images=None,
    )

    resized_path = output_root / "id_1" / "sample.jpg"
    assert total == 1
    assert counts == {"written": 1, "skipped": 0}
    assert resized_path.is_file()
    with Image.open(resized_path) as image:
        assert image.size == (224, 224)
        assert image.mode == "RGB"


def test_resize_folder_skips_existing_outputs_without_overwrite(tmp_path):
    """An existing output file is left untouched when ``overwrite`` is false."""
    input_root = tmp_path / "input"
    output_root = tmp_path / "output"
    source_image = input_root / "id_1" / "sample.jpg"
    output_image = output_root / "id_1" / "sample.jpg"
    source_image.parent.mkdir(parents=True)
    output_image.parent.mkdir(parents=True)
    Image.new("RGB", (112, 112), color=(10, 20, 30)).save(source_image)
    # A differently sized placeholder makes an accidental overwrite detectable.
    Image.new("RGB", (32, 32), color=(30, 20, 10)).save(output_image)

    counts, total = resize_script.resize_folder(
        input_root=input_root,
        output_root=output_root,
        size=224,
        workers=1,
        output_ext=".jpg",
        quality=90,
        overwrite=False,
        max_images=None,
    )

    assert total == 1
    assert counts == {"written": 0, "skipped": 1}
    with Image.open(output_image) as image:
        assert image.size == (32, 32)
resized_path.is_file() + with Image.open(resized_path) as image: + assert image.size == (224, 224) + assert image.mode == "RGB" + + +def test_resize_folder_skips_existing_outputs_without_overwrite(): + workspace = fresh_workspace("skip_existing") + input_root = workspace / "input" + output_root = workspace / "output" + source_image = input_root / "id_1" / "sample.jpg" + output_image = output_root / "id_1" / "sample.jpg" + source_image.parent.mkdir(parents=True) + output_image.parent.mkdir(parents=True) + Image.new("RGB", (112, 112), color=(10, 20, 30)).save(source_image) + Image.new("RGB", (32, 32), color=(30, 20, 10)).save(output_image) + + counts, total = resize_script.resize_folder( + input_root=input_root, + output_root=output_root, + size=224, + workers=1, + output_ext=".jpg", + quality=90, + overwrite=False, + max_images=None, + ) + + assert total == 1 + assert counts == {"written": 0, "skipped": 1} + with Image.open(output_image) as image: + assert image.size == (32, 32) diff --git a/cvlface/research/recognition/code/run_v1/scripts/test_train_plan4_epoch_logging.py b/cvlface/research/recognition/code/run_v1/scripts/test_train_plan4_epoch_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..936dd5f735490b43935df4989c4dd32047eb0a91 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/test_train_plan4_epoch_logging.py @@ -0,0 +1,86 @@ +import argparse +import csv +import importlib.util +import json +import sys +import types +from pathlib import Path + + +SCRIPT_PATH = Path(__file__).with_name("train_plan4_adaface_swin.py") + +losses_module = types.ModuleType("losses") +adaface_module = types.ModuleType("losses.adaface") +adaface_module.AdaFaceLoss = object +models_module = types.ModuleType("models") +models_module.get_model = lambda cfg: None +sys.modules.setdefault("losses", losses_module) +sys.modules.setdefault("losses.adaface", adaface_module) +sys.modules.setdefault("models", models_module) + +SPEC = 
importlib.util.spec_from_file_location("train_plan4_adaface_swin", SCRIPT_PATH) +train_script = importlib.util.module_from_spec(SPEC) +SPEC.loader.exec_module(train_script) + + +def test_training_config_keeps_basic_fields_only(tmp_path): + args = argparse.Namespace( + train_root="/data/train", + output_dir="/exp/out", + ckpt=Path("/ckpt/model.safetensors"), + num_classes=567, + image_size=224, + batch_size=128, + epochs=10, + max_steps=None, + eval_every_epochs=5, + use_lora=True, + freeze_backbone=True, + lora_rank=4, + lora_alpha=4, + lora_dropout=0.1, + lora_lr=3e-4, + lora_target_modules="qkv", + head_lr=1e-3, + backbone_lr=3e-5, + eval_pairs_csv=["/data/fgnet/pairs.csv"], + amp=True, + save_lora_only=True, + seed=2048, + ) + + path = train_script.write_training_config(tmp_path, args, train_images=1002, train_classes=82, steps_per_epoch=7) + + config = json.loads(path.read_text(encoding="utf-8")) + assert config["train_images"] == 1002 + assert config["train_classes"] == 82 + assert config["batch_size"] == 128 + assert config["epochs"] == 10 + assert config["eval_every_epochs"] == 5 + assert config["lora"] == {"enabled": True, "rank": 4, "alpha": 4, "dropout": 0.1, "lr": 3e-4, "targets": ["qkv"]} + + +def test_epoch_metrics_csv_writes_loss_and_dataset_acc_columns(tmp_path): + metrics_path = tmp_path / "epoch_metrics.csv" + writer_handle, writer = train_script.open_epoch_metrics_writer(metrics_path, ["lfw", "fgnet_age30"]) + try: + train_script.write_epoch_metrics( + writer, + epoch=5, + step=640, + train_loss=12.345678, + eval_results={"lfw": {"acc": 98.1, "std": 0.2}, "fgnet_age30": {"acc": 81.2, "std": 1.1}}, + ) + finally: + writer_handle.close() + + rows = list(csv.DictReader(metrics_path.open(newline="", encoding="utf-8"))) + assert rows == [ + { + "epoch": "5", + "step": "640", + "train_loss": "12.345678", + "lfw_acc": "98.100000", + "fgnet_age30_acc": "81.200000", + } + ] diff --git 
a/cvlface/research/recognition/code/run_v1/scripts/test_train_plan4_lora.py b/cvlface/research/recognition/code/run_v1/scripts/test_train_plan4_lora.py new file mode 100644 index 0000000000000000000000000000000000000000..21febf57aad55ce54abdd1c504a007aab4da5f6d --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/test_train_plan4_lora.py @@ -0,0 +1,92 @@ +import argparse +import importlib.util +import sys +import types +from pathlib import Path + +import torch + + +SCRIPT_PATH = Path(__file__).with_name("train_plan4_adaface_swin.py") + +losses_module = types.ModuleType("losses") +adaface_module = types.ModuleType("losses.adaface") +adaface_module.AdaFaceLoss = object +models_module = types.ModuleType("models") +models_module.get_model = lambda cfg: None +sys.modules.setdefault("losses", losses_module) +sys.modules.setdefault("losses.adaface", adaface_module) +sys.modules.setdefault("models", models_module) + +SPEC = importlib.util.spec_from_file_location("train_plan4_adaface_swin", SCRIPT_PATH) +train_script = importlib.util.module_from_spec(SPEC) +SPEC.loader.exec_module(train_script) + + +class TinyBackbone(torch.nn.Module): + def __init__(self): + super().__init__() + self.base = torch.nn.Linear(4, 4) + self.lora_a = torch.nn.Parameter(torch.ones(4, 2)) + self.lora_b = torch.nn.Parameter(torch.ones(2, 4)) + + +def test_parse_lora_target_modules_trims_empty_values(): + assert train_script.parse_lora_target_modules(" qkv, proj, ") == ["qkv", "proj"] + + +def test_lora_optimizer_uses_lora_group_and_head_group_only(): + model = TinyBackbone() + for name, param in model.named_parameters(): + param.requires_grad = "lora_" in name + classifier = torch.nn.Linear(4, 3) + args = argparse.Namespace( + use_lora=True, + freeze_backbone=True, + backbone_lr=3e-5, + lora_lr=3e-4, + head_lr=1e-3, + weight_decay=0.05, + ) + + optimizer = train_script.build_optimizer(model, classifier, args) + + assert [group["name"] for group in optimizer.param_groups] == 
["lora_backbone", "adaface_head"] + assert optimizer.param_groups[0]["lr"] == args.lora_lr + optimizer_param_ids = {id(param) for group in optimizer.param_groups for param in group["params"]} + assert id(model.base.weight) not in optimizer_param_ids + assert id(model.base.bias) not in optimizer_param_ids + assert id(model.lora_a) in optimizer_param_ids + assert id(model.lora_b) in optimizer_param_ids + + +def test_lora_state_dict_keeps_only_lora_parameters(): + model = TinyBackbone() + + state = train_script.get_lora_state_dict(model) + + assert sorted(state) == ["lora_a", "lora_b"] + + +def test_load_lora_only_checkpoint_updates_lora_params_without_base_params(tmp_path): + model = TinyBackbone() + checkpoint = tmp_path / "lora_only.pt" + torch.save( + { + "model_state_kind": "lora_only", + "model_lora": { + "lora_a": torch.full_like(model.lora_a, 3.0), + "lora_b": torch.full_like(model.lora_b, 5.0), + }, + }, + checkpoint, + ) + original_base = model.base.weight.detach().clone() + + report = train_script.load_model_weights(checkpoint, model, torch.device("cpu")) + + assert report["model_state_kind"] == "lora_only" + assert report["loaded_lora_keys"] == 2 + assert torch.equal(model.base.weight, original_base) + assert torch.equal(model.lora_a, torch.full_like(model.lora_a, 3.0)) + assert torch.equal(model.lora_b, torch.full_like(model.lora_b, 5.0)) diff --git a/cvlface/research/recognition/code/run_v1/scripts/test_train_plan4_yaml_config.py b/cvlface/research/recognition/code/run_v1/scripts/test_train_plan4_yaml_config.py new file mode 100644 index 0000000000000000000000000000000000000000..1de5410976ae6e0ca803e6bf9b3a8cdb7ea262c9 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/test_train_plan4_yaml_config.py @@ -0,0 +1,84 @@ +import importlib.util +import sys +import types +from pathlib import Path + + +SCRIPT_PATH = Path(__file__).with_name("train_plan4_adaface_swin.py") + +losses_module = types.ModuleType("losses") +adaface_module = 
types.ModuleType("losses.adaface") +adaface_module.AdaFaceLoss = object +models_module = types.ModuleType("models") +models_module.get_model = lambda cfg: None +sys.modules.setdefault("losses", losses_module) +sys.modules.setdefault("losses.adaface", adaface_module) +sys.modules.setdefault("models", models_module) + +SPEC = importlib.util.spec_from_file_location("train_plan4_adaface_swin", SCRIPT_PATH) +train_script = importlib.util.module_from_spec(SPEC) +SPEC.loader.exec_module(train_script) + + +def test_yaml_config_provides_required_arguments(tmp_path): + config = tmp_path / "train.yaml" + config.write_text( + """ +train_root: /data/train +output_dir: /exp/out +ckpt: /ckpt/model.safetensors +num_classes: 567 +epochs: 20 +batch_size: 128 +use_lora: true +freeze_backbone: true +eval_pairs_csv: + - /data/fgnet/pairs.csv + - /data/agedb/pairs.csv +""", + encoding="utf-8", + ) + + args = train_script.parse_args(["--config", str(config)]) + + assert args.train_root == "/data/train" + assert args.output_dir == "/exp/out" + assert args.ckpt == Path("/ckpt/model.safetensors") + assert args.num_classes == 567 + assert args.epochs == 20 + assert args.batch_size == 128 + assert args.use_lora is True + assert args.freeze_backbone is True + assert args.eval_pairs_csv == ["/data/fgnet/pairs.csv", "/data/agedb/pairs.csv"] + + +def test_cli_arguments_override_yaml_config(tmp_path): + config = tmp_path / "train.yaml" + config.write_text( + """ +train_root: /data/train +output_dir: /exp/out +ckpt: /ckpt/model.safetensors +batch_size: 128 +use_lora: false +eval_pairs_csv: + - /data/fgnet/pairs.csv +""", + encoding="utf-8", + ) + + args = train_script.parse_args( + [ + "--config", + str(config), + "--batch-size", + "64", + "--use-lora", + "--eval-pairs-csv", + "/override/pairs.csv", + ] + ) + + assert args.batch_size == 64 + assert args.use_lora is True + assert args.eval_pairs_csv == ["/override/pairs.csv"] diff --git 
a/cvlface/research/recognition/code/run_v1/scripts/train_plan4_adaface_swin.py b/cvlface/research/recognition/code/run_v1/scripts/train_plan4_adaface_swin.py new file mode 100644 index 0000000000000000000000000000000000000000..de32043cd0d0c52196eed3b84a933f1a33a50461 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/scripts/train_plan4_adaface_swin.py @@ -0,0 +1,673 @@ +import argparse +import csv +import json +import math +import sys +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F +from omegaconf import OmegaConf +from PIL import Image +from sklearn.model_selection import KFold +from sklearn.preprocessing import normalize +from torch.utils.data import DataLoader, Dataset +from torchvision import transforms +from torchvision.datasets import ImageFolder +from tqdm import tqdm + + +RUN_V1_ROOT = Path(__file__).resolve().parents[1] +if str(RUN_V1_ROOT) not in sys.path: + sys.path.insert(0, str(RUN_V1_ROOT)) + +from losses.adaface import AdaFaceLoss +from models import get_model + + +class AdaFaceHead(torch.nn.Module): + def __init__(self, margin_loss, embedding_size, num_classes): + super().__init__() + self.weight = torch.nn.Parameter(torch.normal(0, 0.01, (num_classes, embedding_size))) + self.margin_loss = margin_loss + self.cross_entropy = torch.nn.CrossEntropyLoss() + self.register_buffer("batch_mean", torch.ones(1) * 20) + self.register_buffer("batch_std", torch.ones(1) * 100) + + def forward(self, embeddings, labels): + norms = embeddings.norm(p=2, dim=1, keepdim=True).clamp_min(1e-8) + norm_embeddings = embeddings / norms + norm_weight = F.normalize(self.weight) + logits = F.linear(norm_embeddings, norm_weight).clamp(-1, 1) + logits, batch_mean, batch_std = self.margin_loss( + logits=logits, + labels=labels, + norms=norms, + batch_mean=self.batch_mean, + batch_std=self.batch_std, + ) + self.batch_mean.data = batch_mean.data + self.batch_std.data = batch_std.data + return self.cross_entropy(logits, 
labels) + + +class PairDataset(Dataset): + def __init__(self, pairs_csv, transform): + with Path(pairs_csv).open(newline="", encoding="utf-8") as handle: + self.rows = list(csv.DictReader(handle)) + self.transform = transform + + def __len__(self): + return len(self.rows) + + def __getitem__(self, index): + row = self.rows[index] + with Image.open(row["path"]) as image: + image = image.convert("RGB") + is_same = row["is_same"] + if isinstance(is_same, str): + is_same = is_same.lower() == "true" + return self.transform(image), int(row["index"]), bool(is_same) + + +def collate_pairs(batch): + images, indexes, is_sames = zip(*batch) + return ( + torch.stack(images), + torch.tensor(indexes, dtype=torch.long), + torch.tensor(is_sames, dtype=torch.bool), + ) + + +def make_transform(size): + return transforms.Compose( + [ + transforms.Resize((size, size)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ] + ) + + +def load_swin_model(ckpt_path, device): + cfg = OmegaConf.load(RUN_V1_ROOT / "models" / "swin" / "configs" / "v1_swin_s_pretrained.yaml") + cfg.yaml_path = "/swin/configs/v1_swin_s_pretrained.yaml" + cfg.start_from = str(ckpt_path) + model = get_model(cfg).to(device) + return model + + +def parse_lora_target_modules(target_modules): + modules = [module.strip() for module in target_modules.split(",") if module.strip()] + if not modules: + raise ValueError("--lora-target-modules must contain at least one module name") + return modules + + +def apply_lora(model, args): + try: + from peft import LoraConfig, LoraModel + except ImportError as exc: + raise ImportError("LoRA training requires the peft package. 
Install it with: pip install peft") from exc + + target_modules = parse_lora_target_modules(args.lora_target_modules) + lora_config = LoraConfig( + r=args.lora_rank, + lora_alpha=args.lora_alpha, + target_modules=target_modules, + lora_dropout=args.lora_dropout, + bias="none", + ) + model = LoraModel(model, lora_config, adapter_name="default") + if not any("lora_" in name and param.requires_grad for name, param in model.named_parameters()): + raise RuntimeError(f"LoRA was not applied to any trainable parameters. targets={target_modules}") + return model + + +def configure_backbone(model, freeze_backbone): + if freeze_backbone: + for param in model.parameters(): + param.requires_grad = False + model.eval() + else: + model.train() + + +def build_optimizer(model, classifier, args): + param_groups = [] + backbone_params = [p for p in model.parameters() if p.requires_grad] + if backbone_params: + if args.use_lora: + param_groups.append({"params": backbone_params, "lr": args.lora_lr, "name": "lora_backbone"}) + elif not args.freeze_backbone: + param_groups.append({"params": backbone_params, "lr": args.backbone_lr, "name": "backbone"}) + + head_params = [p for p in classifier.parameters() if p.requires_grad] + if not head_params: + raise ValueError("AdaFace head has no trainable parameters") + param_groups.append({"params": head_params, "lr": args.head_lr, "name": "adaface_head"}) + + return torch.optim.AdamW(param_groups, weight_decay=args.weight_decay) + + +def dataset_name_from_pairs_csv(pairs_csv): + return Path(pairs_csv).parent.name + + +def write_training_config(log_dir, args, train_images, train_classes, steps_per_epoch): + log_dir = Path(log_dir) + log_dir.mkdir(parents=True, exist_ok=True) + config = { + "train_root": str(args.train_root), + "output_dir": str(args.output_dir), + "ckpt": str(args.ckpt), + "train_images": train_images, + "train_classes": train_classes, + "num_classes": args.num_classes, + "image_size": args.image_size, + "batch_size": 
args.batch_size, + "epochs": args.epochs, + "max_steps": args.max_steps, + "steps_per_epoch": steps_per_epoch, + "eval_every_epochs": args.eval_every_epochs, + "eval_pairs_csv": [str(path) for path in args.eval_pairs_csv], + "freeze_backbone": args.freeze_backbone, + "amp": args.amp, + "seed": args.seed, + "backbone_lr": args.backbone_lr, + "head_lr": args.head_lr, + "save_lora_only": args.save_lora_only, + "lora": { + "enabled": args.use_lora, + "rank": args.lora_rank if args.use_lora else None, + "alpha": args.lora_alpha if args.use_lora else None, + "dropout": args.lora_dropout if args.use_lora else None, + "lr": args.lora_lr if args.use_lora else None, + "targets": parse_lora_target_modules(args.lora_target_modules) if args.use_lora else [], + }, + } + config_path = log_dir / "training_config.json" + config_path.write_text(json.dumps(config, indent=2), encoding="utf-8") + return config_path + + +def open_epoch_metrics_writer(metrics_path, dataset_names): + metrics_path = Path(metrics_path) + metrics_path.parent.mkdir(parents=True, exist_ok=True) + fieldnames = ["epoch", "step", "train_loss"] + [f"{name}_acc" for name in dataset_names] + handle = metrics_path.open("w", newline="", encoding="utf-8") + writer = csv.DictWriter(handle, fieldnames=fieldnames) + writer.writeheader() + return handle, writer + + +def write_epoch_metrics(writer, epoch, step, train_loss, eval_results): + row = { + "epoch": epoch, + "step": step, + "train_loss": f"{train_loss:.6f}", + } + for name, result in eval_results.items(): + row[f"{name}_acc"] = f"{result['acc']:.6f}" + writer.writerow(row) + + +def self_check(args): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = load_swin_model(args.ckpt, device) + if args.use_lora: + model = apply_lora(model, args) + model.train() + else: + configure_backbone(model, args.freeze_backbone) + margin = AdaFaceLoss(s=args.scale, m=args.margin, h=args.h, t_alpha=args.t_alpha) + classifier = 
AdaFaceHead(margin_loss=margin, embedding_size=512, num_classes=args.num_classes).to(device).train() + optimizer = build_optimizer(model, classifier, args) + + backbone_named_params = list(model.named_parameters()) + backbone_params = [param for _, param in backbone_named_params] + head_params = list(classifier.parameters()) + optimizer_param_ids = { + id(param) + for group in optimizer.param_groups + for param in group["params"] + } + report = { + "freeze_backbone": args.freeze_backbone, + "use_lora": args.use_lora, + "lora_rank": args.lora_rank if args.use_lora else None, + "lora_alpha": args.lora_alpha if args.use_lora else None, + "lora_dropout": args.lora_dropout if args.use_lora else None, + "lora_target_modules": parse_lora_target_modules(args.lora_target_modules) if args.use_lora else [], + "backbone_total_params": sum(param.numel() for param in backbone_params), + "backbone_trainable_params": sum(param.numel() for param in backbone_params if param.requires_grad), + "lora_trainable_params": sum( + param.numel() + for name, param in backbone_named_params + if param.requires_grad and "lora_" in name + ), + "non_lora_backbone_trainable_params": sum( + param.numel() + for name, param in backbone_named_params + if param.requires_grad and "lora_" not in name + ), + "head_total_params": sum(param.numel() for param in head_params), + "head_trainable_params": sum(param.numel() for param in head_params if param.requires_grad), + "backbone_params_in_optimizer": sum(1 for param in backbone_params if id(param) in optimizer_param_ids), + "head_params_in_optimizer": sum(1 for param in head_params if id(param) in optimizer_param_ids), + "optimizer_group_names": [group.get("name", "unnamed") for group in optimizer.param_groups], + "model_training": model.training, + "head_training": classifier.training, + "sample_trainable_backbone_names": [ + name for name, param in backbone_named_params if param.requires_grad + ][:20], + } + print(json.dumps(report, indent=2)) + + +def 
evaluate_embeddings(embeddings, is_same, folds=10): + embeddings = normalize(embeddings) + thresholds = np.arange(-1.0, 1.0, 0.01) + indices = np.arange(len(is_same)) + folds = min(folds, len(is_same)) + k_fold = KFold(n_splits=folds, shuffle=False) + + accuracies = [] + for train_set, test_set in k_fold.split(indices): + train_scores = np.sum(embeddings[0::2][train_set] * embeddings[1::2][train_set], axis=1) + test_scores = np.sum(embeddings[0::2][test_set] * embeddings[1::2][test_set], axis=1) + + train_acc = [ + np.mean((train_scores > threshold) == is_same[train_set]) + for threshold in thresholds + ] + best_threshold = thresholds[int(np.argmax(train_acc))] + accuracies.append(np.mean((test_scores > best_threshold) == is_same[test_set])) + + return float(np.mean(accuracies) * 100.0), float(np.std(accuracies) * 100.0) + + +@torch.no_grad() +def evaluate_pairs(model, pairs_csv, transform, device, batch_size, num_workers, max_pairs=None): + dataset = PairDataset(pairs_csv, transform) + if max_pairs is not None: + max_rows = min(len(dataset), max_pairs * 2) + dataset.rows = dataset.rows[:max_rows] + dataloader = DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=False, + pin_memory=True, + collate_fn=collate_pairs, + ) + model.eval() + features = [] + features_flip = [] + indexes = [] + is_sames = [] + for images, index, is_same in tqdm(dataloader, desc=f"eval {Path(pairs_csv).parent.name}"): + images = images.to(device, non_blocking=True) + features.append(model(images).float().cpu()) + features_flip.append(model(torch.flip(images, dims=[3])).float().cpu()) + indexes.append(index) + is_sames.append(is_same) + + index = torch.cat(indexes).numpy() + order = np.argsort(index) + combined = (torch.cat(features) + torch.cat(features_flip)).numpy()[order] + same = torch.cat(is_sames).numpy()[order][0::2] + return evaluate_embeddings(combined, same) + + +def get_lora_state_dict(model): + return { + name: tensor.detach().cpu() + for name, 
tensor in model.state_dict().items() + if "lora_" in name + } + + +def load_model_weights(checkpoint_path, model, device): + checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False) + state_kind = checkpoint.get("model_state_kind", "full") + if state_kind == "lora_only": + lora_state = checkpoint["model_lora"] + missing_keys, unexpected_keys = model.load_state_dict(lora_state, strict=False) + unexpected_lora_keys = [key for key in unexpected_keys if "lora_" in key] + if unexpected_lora_keys: + raise RuntimeError(f"Unexpected LoRA keys when loading {checkpoint_path}: {unexpected_lora_keys}") + return { + "model_state_kind": state_kind, + "loaded_lora_keys": len(lora_state), + "missing_lora_keys": [key for key in missing_keys if "lora_" in key], + "unexpected_keys": unexpected_keys, + } + if state_kind == "full": + missing_keys, unexpected_keys = model.load_state_dict(checkpoint["model"], strict=False) + return { + "model_state_kind": state_kind, + "loaded_model_keys": len(checkpoint["model"]), + "missing_keys": missing_keys, + "unexpected_keys": unexpected_keys, + } + raise ValueError(f"Unsupported model_state_kind in {checkpoint_path}: {state_kind}") + + +def save_checkpoint(path, model, classifier, optimizer, step, epoch, best_metric, args): + path.parent.mkdir(parents=True, exist_ok=True) + payload = { + "classifier": classifier.state_dict(), + "optimizer": optimizer.state_dict(), + "step": step, + "epoch": epoch, + "best_metric": best_metric, + "args": vars(args), + } + if args.use_lora and args.save_lora_only: + payload["model_state_kind"] = "lora_only" + payload["model_lora"] = get_lora_state_dict(model) + else: + payload["model_state_kind"] = "full" + payload["model"] = model.state_dict() + torch.save(payload, path) + + +def eval_only(args): + if not args.eval_pairs_csv: + raise ValueError("--eval-only requires at least one --eval-pairs-csv") + if not args.eval_checkpoint: + raise ValueError("--eval-only requires --eval-checkpoint") + 
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.set_float32_matmul_precision("high") + transform = make_transform(args.image_size) + model = load_swin_model(args.ckpt, device) + if args.use_lora: + model = apply_lora(model, args) + load_report = load_model_weights(args.eval_checkpoint, model, device) + print(json.dumps(load_report, indent=2)) + + eval_results = {} + for pairs_csv in args.eval_pairs_csv: + acc, std = evaluate_pairs( + model, + pairs_csv=pairs_csv, + transform=transform, + device=device, + batch_size=args.eval_batch_size, + num_workers=args.num_workers, + max_pairs=args.max_eval_pairs, + ) + eval_results[Path(pairs_csv).parent.name] = {"acc": acc, "std": std} + print(json.dumps({"eval": eval_results}, indent=2)) + + +def train(args): + output_dir = Path(args.output_dir) + output_dir.mkdir(parents=True, exist_ok=True) + checkpoints_dir = output_dir / "checkpoints" + log_dir = Path(args.log_dir) if args.log_dir else output_dir / "logs" + metrics_path = Path(args.metrics_csv) if args.metrics_csv else log_dir / "epoch_metrics.csv" + summary_path = output_dir / "summary.json" + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.set_float32_matmul_precision("high") + torch.manual_seed(args.seed) + np.random.seed(args.seed) + + transform = make_transform(args.image_size) + train_dataset = ImageFolder(args.train_root, transform) + if len(train_dataset.classes) != args.num_classes: + raise ValueError(f"Expected {args.num_classes} classes, got {len(train_dataset.classes)}") + train_loader = DataLoader( + train_dataset, + batch_size=args.batch_size, + num_workers=args.num_workers, + shuffle=True, + pin_memory=True, + drop_last=True, + ) + + model = load_swin_model(args.ckpt, device) + if args.use_lora: + model = apply_lora(model, args) + model.train() + else: + configure_backbone(model, args.freeze_backbone) + margin = AdaFaceLoss(s=args.scale, m=args.margin, h=args.h, t_alpha=args.t_alpha) + 
classifier = AdaFaceHead(margin_loss=margin, embedding_size=512, num_classes=args.num_classes).to(device).train() + optimizer = build_optimizer(model, classifier, args) + scaler = torch.amp.GradScaler("cuda", enabled=args.amp and device.type == "cuda") + + steps_per_epoch = len(train_loader) + if args.epochs is None and args.max_steps is None: + raise ValueError("Set --epochs for epoch-based training or --max-steps for step-based training") + total_epochs = args.epochs if args.epochs is not None else math.ceil(args.max_steps / steps_per_epoch) + total_steps = args.max_steps if args.max_steps is not None else total_epochs * steps_per_epoch + eval_dataset_names = [dataset_name_from_pairs_csv(path) for path in args.eval_pairs_csv] + config_path = write_training_config(log_dir, args, len(train_dataset), len(train_dataset.classes), steps_per_epoch) + print(f"train images: {len(train_dataset)}") + print(f"classes: {len(train_dataset.classes)}") + print(f"batch_size: {args.batch_size}") + print(f"epochs: {total_epochs}") + print(f"steps_per_epoch: {steps_per_epoch}") + print(f"max_steps: {total_steps}") + print(f"device: {device}") + print(f"freeze_backbone: {args.freeze_backbone}") + print(f"use_lora: {args.use_lora}") + print(f"log_dir: {log_dir}") + print(f"training_config: {config_path}") + print(f"epoch_metrics: {metrics_path}") + if args.use_lora: + print( + "lora: " + f"rank={args.lora_rank}, alpha={args.lora_alpha}, dropout={args.lora_dropout}, " + f"targets={parse_lora_target_modules(args.lora_target_modules)}, lr={args.lora_lr}" + ) + print(f"save_lora_only: {args.save_lora_only}") + + best_metric = -math.inf + latest_eval_results = {} + metrics_handle, metrics_writer = open_epoch_metrics_writer(metrics_path, eval_dataset_names) + try: + step = 0 + last_epoch = 0 + for epoch in range(1, total_epochs + 1): + last_epoch = epoch + if args.freeze_backbone and not args.use_lora: + model.eval() + else: + model.train() + classifier.train() + epoch_losses = [] + 
progress = tqdm(train_loader, desc=f"epoch {epoch}/{total_epochs}", leave=False) + for images, labels in progress: + step += 1 + images = images.to(device, non_blocking=True) + labels = labels.to(device, non_blocking=True) + optimizer.zero_grad(set_to_none=True) + with torch.amp.autocast("cuda", enabled=args.amp and device.type == "cuda"): + embeddings = model(images) + loss = classifier(embeddings, labels) + scaler.scale(loss).backward() + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_( + [param for param in list(model.parameters()) + list(classifier.parameters()) if param.requires_grad], + args.grad_clip, + ) + scaler.step(optimizer) + scaler.update() + loss_value = float(loss.detach().cpu()) + epoch_losses.append(loss_value) + progress.set_postfix(loss=f"{loss_value:.4f}") + + if step % args.save_every == 0 or step == total_steps: + save_checkpoint(checkpoints_dir / "latest.pt", model, classifier, optimizer, step, epoch, best_metric, args) + + if step >= total_steps: + break + + mean_loss = float(np.mean(epoch_losses)) if epoch_losses else math.nan + should_eval = bool(args.eval_pairs_csv) and ( + epoch % args.eval_every_epochs == 0 or epoch == total_epochs or step >= total_steps + ) + epoch_eval_results = {} + if should_eval: + for pairs_csv in args.eval_pairs_csv: + acc, std = evaluate_pairs( + model, + pairs_csv=pairs_csv, + transform=transform, + device=device, + batch_size=args.eval_batch_size, + num_workers=args.num_workers, + max_pairs=args.max_eval_pairs, + ) + name = dataset_name_from_pairs_csv(pairs_csv) + epoch_eval_results[name] = {"acc": acc, "std": std} + best_metric = max(best_metric, acc) + latest_eval_results = epoch_eval_results + save_checkpoint(checkpoints_dir / "best.pt", model, classifier, optimizer, step, epoch, best_metric, args) + + write_epoch_metrics(metrics_writer, epoch, step, mean_loss, epoch_eval_results) + metrics_handle.flush() + acc_text = " ".join(f"{name}_acc={result['acc']:.3f}" for name, result in 
def add_training_arguments(parser):
    """Register every training / evaluation command-line option on ``parser``.

    Options that are mandatory (``--train-root``, ``--output-dir``, ``--ckpt``)
    are deliberately NOT marked ``required=True`` here: they may be supplied by
    the YAML file given with ``--config``, so their presence is checked after
    parsing by ``validate_required_args``.

    Args:
        parser: The ``argparse.ArgumentParser`` to extend in place.
    """
    # Data / output locations.
    parser.add_argument("--train-root")
    parser.add_argument("--output-dir")
    parser.add_argument("--ckpt", type=Path)
    # Model and data-loading shape.
    parser.add_argument("--num-classes", type=int, default=10572)
    parser.add_argument("--image-size", type=int, default=224)
    parser.add_argument("--batch-size", type=int, default=24)
    parser.add_argument("--num-workers", type=int, default=4)
    # Training length.
    parser.add_argument("--epochs", type=int, default=None)
    parser.add_argument("--max-steps", type=int, default=None)
    # Learning rates and LoRA settings.
    parser.add_argument("--backbone-lr", type=float, default=3e-5)
    parser.add_argument("--head-lr", type=float, default=1e-3)
    parser.add_argument("--use-lora", action="store_true")
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--lora-alpha", type=int, default=16)
    parser.add_argument("--lora-dropout", type=float, default=0.1)
    parser.add_argument("--lora-lr", type=float, default=3e-4)
    parser.add_argument("--lora-target-modules", default="qkv")
    parser.add_argument("--save-lora-only", action="store_true")
    # Optimisation and margin-loss hyper-parameters.
    parser.add_argument("--weight-decay", type=float, default=0.05)
    parser.add_argument("--grad-clip", type=float, default=5.0)
    parser.add_argument("--scale", type=float, default=64.0)
    parser.add_argument("--margin", type=float, default=0.4)
    parser.add_argument("--h", type=float, default=0.333)
    parser.add_argument("--t-alpha", type=float, default=0.01)
    parser.add_argument("--amp", action="store_true")
    parser.add_argument("--freeze-backbone", action="store_true")
    # Run modes.
    parser.add_argument("--self-check", action="store_true")
    parser.add_argument("--eval-only", action="store_true")
    parser.add_argument("--eval-checkpoint", type=Path)
    parser.add_argument("--eval-every-epochs", type=int, default=5)
    # Logging / checkpoint cadence.
    parser.add_argument("--log-dir", type=Path)
    parser.add_argument("--metrics-csv", type=Path)
    parser.add_argument("--log-every", type=int, default=25)
    parser.add_argument("--save-every", type=int, default=100)
    parser.add_argument("--seed", type=int, default=2048)
    # Evaluation inputs; the option may be repeated to give several CSVs.
    parser.add_argument("--eval-pairs-csv", action="append", default=[])
    parser.add_argument("--eval-batch-size", type=int, default=64)
    parser.add_argument("--max-eval-pairs", type=int, default=None)


def load_yaml_config(config_path):
    """Load a YAML config file into a flat ``dict`` keyed by argparse dest names.

    Args:
        config_path: Path to the YAML file, or ``None`` when no ``--config``
            was given.

    Returns:
        A dict whose keys have ``-`` replaced by ``_`` so that both
        ``batch-size`` and ``batch_size`` spellings map onto the argparse
        destination.  Empty when ``config_path`` is ``None`` or the file is
        empty.

    Raises:
        ValueError: If the top-level YAML node is not a mapping.
    """
    if config_path is None:
        return {}
    data = OmegaConf.to_container(OmegaConf.load(config_path), resolve=True)
    if data is None:
        return {}
    if not isinstance(data, dict):
        raise ValueError(f"YAML config must contain a mapping at top level: {config_path}")
    return {str(key).replace("-", "_"): value for key, value in data.items()}


def normalize_config_defaults(config):
    """Coerce YAML values to the Python types the argparse options produce.

    Values installed through ``parser.set_defaults`` bypass the per-option
    ``type=`` / ``action=`` handling unless they are strings, so the
    conversions are done here.

    Args:
        config: Mapping returned by ``load_yaml_config``; it is not mutated.

    Returns:
        A new dict in which path-like options are ``Path`` objects and
        ``eval_pairs_csv`` (when present) is always a list.
    """
    defaults = dict(config)
    for key in ("ckpt", "eval_checkpoint", "log_dir", "metrics_csv"):
        if defaults.get(key) is not None:
            defaults[key] = Path(defaults[key])
    if "eval_pairs_csv" in defaults:
        pairs = defaults["eval_pairs_csv"]
        if pairs is None:
            defaults["eval_pairs_csv"] = []
        elif isinstance(pairs, (str, Path)):
            # A scalar `eval_pairs_csv: a.csv` must not leak through as a bare
            # string: consumers of an `append` option expect a list, and
            # iterating a string would yield single characters.
            defaults["eval_pairs_csv"] = [pairs]
        elif isinstance(pairs, tuple):
            defaults["eval_pairs_csv"] = list(pairs)
    return defaults


def cli_provides_option(argv, option):
    """Return True when ``option`` appears in ``argv`` as ``--opt`` or ``--opt=value``.

    Args:
        argv: Argument list to inspect; ``None`` means ``sys.argv[1:]``.
        option: The full option string, e.g. ``"--eval-pairs-csv"``.
    """
    if argv is None:
        argv = sys.argv[1:]
    return option in argv or any(item.startswith(f"{option}=") for item in argv)


def validate_required_args(args):
    """Exit with a readable message if a mandatory option is still unset.

    Raises:
        SystemExit: Listing every missing option in its ``--dashed-form``.
    """
    missing = [name for name in ["train_root", "output_dir", "ckpt"] if getattr(args, name) is None]
    if missing:
        joined = ", ".join(f"--{name.replace('_', '-')}" for name in missing)
        raise SystemExit(f"Missing required arguments: {joined}. Provide them in YAML --config or on the command line.")


def parse_args(argv=None):
    """Parse the command line, using an optional YAML file as the defaults layer.

    Precedence is: explicit command-line value > YAML ``--config`` value >
    built-in argparse default.

    Args:
        argv: Argument list; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace`` (``args.config`` holds the YAML
        path or ``None``).

    Raises:
        SystemExit: On unknown YAML keys or missing mandatory options.
    """
    # First pass: pull out --config only, leaving everything else untouched.
    config_parser = argparse.ArgumentParser(add_help=False)
    config_parser.add_argument("--config", type=Path)
    config_args, remaining_argv = config_parser.parse_known_args(argv)

    parser = argparse.ArgumentParser(parents=[config_parser])
    add_training_arguments(parser)
    defaults = normalize_config_defaults(load_yaml_config(config_args.config))
    # `append` extends its default list, so a YAML list would otherwise be
    # merged with the CLI values instead of being overridden by them.
    if defaults.get("eval_pairs_csv") is not None and cli_provides_option(remaining_argv, "--eval-pairs-csv"):
        defaults["eval_pairs_csv"] = []
    if defaults:
        known_dests = {action.dest for action in parser._actions}
        unknown = sorted(set(defaults) - known_dests)
        if unknown:
            raise SystemExit(f"Unknown YAML config keys: {', '.join(unknown)}")
        parser.set_defaults(**defaults)
    args = parser.parse_args(remaining_argv)
    args.config = config_args.config
    validate_required_args(args)
    return args


if __name__ == "__main__":
    parsed_args = parse_args()
    if parsed_args.self_check:
        self_check(parsed_args)
    elif parsed_args.eval_only:
        eval_only(parsed_args)
    else:
        train(parsed_args)
"""Smoke test: build each backbone config and run one forward pass on random input.

The script only checks that every listed model can be constructed and that a
forward pass succeeds, printing the input and output shapes.  No weights are
evaluated and no real images are read.
"""
import pyrootutils
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=["__root__.txt"],
    pythonpath=True,
    dotenv=True,
)
import os, sys
sys.path.append(os.path.join(root))
import numpy as np

# Re-create the aliases that newer NumPy removed (needed by mxnet 1.9.1).
np.bool = np.bool_
np.object = np.object_
# `np.float_` no longer exists on NumPy >= 2.0; `np.float64` is the same type
# on every NumPy version and matches the alias used in train.py.
np.float = np.float64

import torch
from general_utils.config_utils import load_config
from models import get_model

if __name__ == '__main__':

    inputs_shape = (2, 3, 112, 112)
    inputs = torch.randn(inputs_shape)

    # setting 1: input is image
    for config_name in [
        'models/iresnet/configs/v1_ir50.yaml',
        'models/vit/configs/v1_base.yaml',
        'models/swin/configs/v1_base.yaml',
        'models/vit_irpe/configs/v1_base_irpe.yaml',
        'models/part_fvit/configs/v1_base.yaml',
    ]:
        config = load_config(config_name)
        config.yaml_path = config_name
        model = get_model(config, task='run_v1')
        out = model(inputs)
        print(f'{config_name} has input shape {inputs_shape} and output shape {out.shape}')

    # setting 2: input is image + keypoints
    for config_name in [
        'models/vit_kprpe/configs/v1_base_kprpe_splithead_unshared.yaml',
        'models/swin_kprpe/configs/v1_base_kprpe_splithead_unshared.yaml',
    ]:
        config = load_config(config_name)
        config.yaml_path = config_name
        keypoints = torch.randn(2, 49, 2)
        model = get_model(config, task='run_v1')
        out = model(inputs, keypoints)
        print(f'{config_name} has input shape {inputs_shape} and output shape {out.shape}')
"""Fabric training entry point: build model/classifier/aligner, train, evaluate, save.

Flow: load the composed config, create the Fabric runtime and loggers, build
the model, dataloader, classifier, aligner and optimizer, then run the epoch
loop (optional backbone warm-up freeze, gradient accumulation, periodic
evaluation, checkpointing) and finally re-evaluate the best checkpoint.
"""
import pyrootutils
root = pyrootutils.setup_root(
    search_from=__file__,
    indicator=["__root__.txt"],
    pythonpath=True,
    dotenv=True,
)
import os, sys

sys.path.append(os.path.join(root))
import numpy as np
np.bool = np.bool_  # fix bug for mxnet 1.9.1
np.object = np.object_
np.float = np.float64

import pandas as pd
import torch
import config
from config import Config
from models import get_model
from classifiers import get_classifier
from aligners import get_aligner
from losses import get_margin_loss
from dataset import get_train_dataset, visualize_dataset, set_epoch
from evaluations import get_evaluator_by_name
from general_utils import random_utils, os_utils
from optims.optims import make_optimizer
from lightning.fabric.loggers import CSVLogger
from lightning.pytorch.loggers import WandbLogger
from optims.lr_scheduler import make_scheduler, scheduler_step, scheduler_step_on_metric, get_last_lr
from pipelines import pipeline_from_config, pipeline_from_name
import omegaconf
import lovely_tensors as lt
lt.monkey_patch()
from tqdm import tqdm
from evaluations import IsBestTracker, summary
import time
from lightning.fabric import Fabric
from pefts import apply_peft
from general_utils.dist_utils import verify_ddp_weights_equal
from functools import partial
from fabric.fabric import setup_dataloader_from_dataset


if __name__ == '__main__':

    cfg: Config = config.init(root)
    torch.set_float32_matmul_precision(cfg.trainers.float32_matmul_precision)
    print('matmul precision', cfg.trainers.float32_matmul_precision)
    print('precision', cfg.trainers.precision)

    random_utils.setup_seed(seed=cfg.trainers.seed, cuda_deterministic=False)

    loggers = []
    csv_logger = CSVLogger(root_dir=cfg.trainers.output_dir, flush_logs_every_n_steps=1)
    loggers.append(csv_logger)
    if cfg.trainers.using_wandb:
        wandb_logger = WandbLogger(project=cfg.trainers.task, save_dir=cfg.trainers.output_dir,
                                   name=os.path.basename(cfg.trainers.output_dir),
                                   log_model=False)
        loggers.append(wandb_logger)

    # grad_max_norm?
    fabric = Fabric(precision=cfg.trainers.precision,
                    loggers=loggers,
                    accelerator="auto",
                    strategy="ddp",
                    devices=cfg.trainers.num_gpu)
    fabric.seed_everything(cfg.trainers.seed)
    if cfg.trainers.num_gpu == 1:
        fabric.launch()
    fabric.setup_dataloader_from_dataset = partial(setup_dataloader_from_dataset, fabric=fabric, seed=cfg.trainers.seed)

    cfg.trainers.local_rank = fabric.local_rank
    cfg.trainers.world_size = fabric.world_size
    print = fabric.print

    # get model
    model = get_model(cfg.models, cfg.trainers.task)
    train_transform = model.make_train_transform()
    test_transform = model.make_test_transform()

    # get dataloader
    dataset, label_mapping = get_train_dataset(cfg.dataset, train_transform, cfg.data_augs, local_rank=cfg.trainers.local_rank)
    dataloader = fabric.setup_dataloader_from_dataset(dataset=dataset,
                                                      is_train=True,
                                                      batch_size=cfg.trainers.batch_size,
                                                      num_workers=cfg.trainers.num_workers)
    cfg.trainers.total_batch_size = cfg.trainers.batch_size * cfg.trainers.world_size
    batch_length = len(dataloader.dataset) // cfg.trainers.total_batch_size
    batch_length = batch_length if cfg.trainers.limit_num_batch <= 0 else cfg.trainers.limit_num_batch
    cfg.trainers.warmup_step = batch_length * cfg.optims.warmup_epoch
    cfg.trainers.total_step = batch_length * cfg.optims.num_epoch
    if fabric.local_rank == 0:
        visualize_dataset(dataloader, os.path.join(cfg.trainers.output_dir, 'train_data.png'))

    # get classifier
    margin_loss_fn = get_margin_loss(cfg.losses)

    extra_classes = 0
    classifier = get_classifier(cfg.classifiers,
                                margin_loss_fn=margin_loss_fn,
                                model_cfg=cfg.models,
                                num_classes=cfg.dataset.num_classes+extra_classes,
                                rank=fabric.local_rank,
                                world_size=fabric.world_size)

    # get aligner
    aligner = get_aligner(cfg.aligners)

    # apply peft if needed
    model, classifier = apply_peft(cfg.pefts, model=model, classifier=classifier, data_cfg=cfg.dataset, label_mapping=label_mapping)

    # get optimizer
    optimizer = make_optimizer(cfg, model, classifier, aligner)
    lr_scheduler = make_scheduler(cfg, optimizer)

    # prepare accelerator
    if model.has_trainable_params():
        model, optimizer = fabric.setup(model, optimizer)
        # module whose gradient sync is suppressed while accumulating
        sync_module = model
    else:
        # A fully frozen backbone cannot be wrapped together with the
        # optimizer, so a tiny placeholder module carries the optimizer.
        model = model.to(fabric.device)
        dummy_model = torch.nn.Linear(1, 1).to(fabric.device)
        dummy_model, optimizer = fabric.setup(dummy_model, optimizer)
        sync_module = dummy_model
    if classifier is not None:
        if classifier.apply_ddp:
            classifier = fabric.setup(classifier)
        else:
            classifier = classifier.to(fabric.device)  # no ddp as it divides fc into multiple GPUs
    # Test for None before calling a method on the aligner.
    if aligner is not None:
        if aligner.has_trainable_params():
            aligner = fabric.setup(aligner)
        else:
            aligner = aligner.to(fabric.device)

    verify_ddp_weights_equal(model)
    if classifier is not None:
        verify_ddp_weights_equal(classifier)

    # make train pipe (after accelerator setup)
    train_pipeline = pipeline_from_config(cfg.pipelines, model, classifier, aligner, optimizer, lr_scheduler)
    train_pipeline.integrity_check(dataloader.dataset)

    # make inference pipe (after accelerator setup)
    eval_pipeline = pipeline_from_name(cfg.pipelines.eval_pipeline_name, model, aligner)
    eval_pipeline.integrity_check(dataloader.dataset.color_space)

    # evaluation callbacks
    evaluators = []
    for name, info in cfg.evaluations.per_epoch_evaluations.items():
        eval_data_path = os.path.join(cfg.evaluations.data_root, info.path)
        eval_type = info.evaluation_type
        eval_batch_size = info.batch_size * 4
        eval_num_workers = info.num_workers
        evaluator = get_evaluator_by_name(eval_type=eval_type, name=name, eval_data_path=eval_data_path,
                                          transform=eval_pipeline.make_test_transform(),
                                          fabric=fabric, batch_size=eval_batch_size, num_workers=eval_num_workers)
        evaluator.integrity_check(info.color_space, eval_pipeline.color_space)
        evaluator.config = info
        evaluators.append(evaluator)

    # copy project files
    if fabric.local_rank == 0:
        code_dir = os.path.dirname(os.path.abspath(__file__))
        os_utils.copy_project_files(code_dir, cfg.trainers.output_dir)
        omegaconf.OmegaConf.save(cfg, os.path.join(cfg.trainers.output_dir, 'config.yaml'))
        os.makedirs(os.path.join(cfg.trainers.output_dir, 'lightning_logs'), exist_ok=True)

    # train
    step = train_pipeline.step
    n_images_seen = train_pipeline.n_images_seen
    n_epochs = cfg.optims.num_epoch - train_pipeline.start_epoch
    print(f"start at {train_pipeline.start_epoch} and training for {n_epochs} epochs")
    is_best_tracker = IsBestTracker(fabric)
    tic = time.time()
    epoch = train_pipeline.start_epoch

    freeze_backbone_epochs = cfg.trainers.get('freeze_backbone_epochs', 0)
    # Only parameters that are trainable by configuration take part in the
    # warm-up freeze.  Toggling every parameter would silently un-freeze
    # weights frozen on purpose (models.freeze, PEFT base weights).
    backbone_trainable_params = [p for p in model.parameters() if p.requires_grad]

    for epoch in range(train_pipeline.start_epoch, cfg.optims.num_epoch):
        epoch_start_time = time.time()
        freeze_backbone = epoch < freeze_backbone_epochs
        if freeze_backbone_epochs > 0:
            for param in backbone_trainable_params:
                param.requires_grad = not freeze_backbone
        if freeze_backbone:
            print(f'Backbone frozen for epoch {epoch} / {freeze_backbone_epochs}')
        elif freeze_backbone_epochs > 0 and epoch == freeze_backbone_epochs:
            print(f'Backbone unfrozen at epoch {epoch}')
        train_pipeline.train()
        if freeze_backbone:
            # keep the frozen backbone in eval mode after train() switched it
            model.eval()
        set_epoch(dataloader, epoch, cfg)
        batch_length = len(dataloader) if cfg.trainers.limit_num_batch <= 0 else cfg.trainers.limit_num_batch
        pbar = tqdm(total=batch_length, disable=fabric.local_rank != 0)
        if cfg.trainers.local_rank == 0:
            print('\nRun Name', os.path.basename(cfg.trainers.output_dir))
        for batch_idx, batch in enumerate(dataloader):

            if cfg.trainers.limit_num_batch > 0 and batch_idx >= cfg.trainers.limit_num_batch:
                break

            if cfg.trainers.mock_lr_run:
                # dry run: exercise the LR schedule without any compute
                loss = 0
            else:
                is_accumulating = batch_idx % cfg.trainers.gradient_acc != 0
                with fabric.no_backward_sync(sync_module, enabled=is_accumulating):
                    with fabric.autocast():
                        loss = train_pipeline(batch)
                    fabric.backward(loss)
                if not is_accumulating:
                    fabric.clip_gradients(model, optimizer, max_norm=cfg.optims.max_grad_norm)
                    optimizer.step()
                    optimizer.zero_grad()

            scheduler_step(lr_scheduler, step)
            last_lr = get_last_lr(optimizer)

            n_images_seen += cfg.trainers.total_batch_size
            step += 1

            if batch_idx % 50 == 0:
                log_dict = {}
                log_dict['epoch'] = epoch
                log_dict['step'] = step
                log_dict['n_images_seen'] = n_images_seen
                log_dict['train/loss'] = loss
                log_dict['train/lr'] = last_lr
                log_dict['trainer/global_step'] = step
                log_dict['trainer/epoch'] = epoch
                fabric.log_dict(log_dict, step=step)

            speed = cfg.trainers.batch_size / (time.time() - tic)
            speed_total = speed * fabric.world_size
            pbar.set_description(f"Epoch {epoch} | Step {step} | Batch {batch_idx} | Speed {speed_total:.0f} | LR {last_lr:.5f} | Loss {loss:.4f}")
            pbar.update(1)
            tic = time.time()
        pbar.close()

        # validation
        if cfg.evaluations.eval_every_n_epochs > 0:
            # Rank-independent decision, so every rank takes the same branch
            # around the tracker / scheduler calls below.
            run_eval = (epoch % cfg.evaluations.eval_every_n_epochs == 0  # every n epochs
                        or epoch == (cfg.optims.num_epoch - 1)  # last epoch
                        or epoch + 1 in cfg.optims.lr_milestones  # lr decay
                        )
            evaluated = run_eval and len(evaluators) > 0

            print('Evaluation Started')
            eval_start_time = time.time()
            all_result = {}
            if run_eval:
                for evaluator in evaluators:
                    print(f"Evaluating {evaluator.name}")
                    result = evaluator.evaluate(eval_pipeline, epoch=epoch, step=step, n_images_seen=n_images_seen)
                    all_result.update({evaluator.name + "/" + k: v for k, v in result.items()})
            eval_time = (time.time() - eval_start_time) / 60
            print(f'Evaluation Time: {eval_time:.2f} mins')

            if evaluated:
                if fabric.local_rank == 0:
                    if all_result:
                        os.makedirs(os.path.join(cfg.trainers.output_dir, 'result'), exist_ok=True)
                        save_result = pd.DataFrame(pd.Series(all_result), columns=['val'])
                        save_result.to_csv(os.path.join(cfg.trainers.output_dir, f'result/eval_{epoch}_{step}.csv'))
                        mean, summary_dict = summary(save_result, epoch, step, n_images_seen)
                        fabric.log_dict(summary_dict)
                        summary_result = pd.DataFrame(pd.Series(summary_dict), columns=['val'])
                        summary_result.to_csv(os.path.join(cfg.trainers.output_dir, f'result/eval_summary_{epoch}_{step}.csv'))
                    else:
                        mean = is_best_tracker.prev_best_metric
                else:
                    mean = -1.0
                is_best_tracker.set_is_best(mean)
                # Step the metric-driven scheduler only on a fresh metric.
                # Feeding it the stale previous best on skipped epochs counts
                # as "no improvement" and decays the LR between evaluations.
                scheduler_step_on_metric(lr_scheduler, mean)
                is_best = is_best_tracker.is_best()
            else:
                print('Skipped evaluation. So best is not updated')
                # An unevaluated model must never overwrite the best checkpoint.
                is_best = False

            if fabric.local_rank == 0:
                fabric.log_dict({'is_best': float(is_best)})
                print(f'Epoch {epoch} | Step {step} | Best {is_best}')
                if evaluated and all_result:
                    print(summary_result.round(2).to_markdown())

            # save model
            train_pipeline.save(fabric, train_pipeline, cfg, epoch, step, n_images_seen,
                                is_best=is_best)
            print('Evaluation Finished and Model Saved')

        epoch_time = (time.time() - epoch_start_time) / 60
        print(f'Epoch Time: {epoch_time:.2f} mins')

    # load best model and do final eval
    is_best_path = os.path.join(cfg.trainers.output_dir, 'checkpoints', 'best')
    # advance the counters so the final results are logged past the last epoch
    epoch = epoch + 1
    step = step + 1
    n_images_seen = n_images_seen + 1
    if os.path.exists(is_best_path) and cfg.trainers.skip_final_eval is False:
        fabric.barrier()
        time.sleep(fabric.local_rank * 5)  # prevent concurrent file access
        eval_pipeline.model.load_state_dict_from_path(os.path.join(is_best_path, 'model.pt'))
        print('Final Evaluation Started')

        # evaluation callbacks
        cfg.evaluations = config.load_yaml('final', directory='evaluations')
        evaluators = []
        for name, info in cfg.evaluations.per_epoch_evaluations.items():
            eval_data_path = os.path.join(cfg.evaluations.data_root, info.path)
            eval_type = info.evaluation_type
            eval_batch_size = info.batch_size
            eval_num_workers = info.num_workers
            evaluator = get_evaluator_by_name(eval_type=eval_type, name=name, eval_data_path=eval_data_path,
                                              transform=eval_pipeline.make_test_transform(),
                                              fabric=fabric, batch_size=eval_batch_size, num_workers=eval_num_workers)
            evaluator.integrity_check(info.color_space, eval_pipeline.color_space)
            evaluators.append(evaluator)

        all_result = {}
        for evaluator in evaluators:
            print(f"Evaluating {evaluator.name}")
            result = evaluator.evaluate(eval_pipeline, epoch=epoch, step=step, n_images_seen=n_images_seen)
            all_result.update({evaluator.name + "/" + k: v for k, v in result.items()})

        if fabric.local_rank == 0:
            os.makedirs(os.path.join(cfg.trainers.output_dir, 'result'), exist_ok=True)
            save_result = pd.DataFrame(pd.Series(all_result), columns=['val'])
            save_result.to_csv(os.path.join(cfg.trainers.output_dir, 'result/eval_best.csv'))
            mean, summary_dict = summary(save_result, epoch, step, n_images_seen)
            summary_dict = {k.replace('summary/', 'final/'): v for k, v in summary_dict.items()}
            # round to 2 decimal places
            summary_dict = {k: np.round(v, 2) for k, v in summary_dict.items()}
            fabric.log_dict(summary_dict)
            pd.DataFrame(pd.Series(summary_dict), columns=['val']).to_csv(
                os.path.join(cfg.trainers.output_dir, 'result/eval_summary_best.csv'))
    else:
        print('Skip final evaluation')

    # close
    print('done')
+warmup_step: -1 +total_step: -1 +output_dir: '' diff --git a/cvlface/research/recognition/code/run_v1/trainers/configs/debug.yaml b/cvlface/research/recognition/code/run_v1/trainers/configs/debug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..22835f1259d23dfbfc62a0c6cea1b4bb0426f7f9 --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/trainers/configs/debug.yaml @@ -0,0 +1,35 @@ + +task: '' +prefix: 'debug' +debug: True +mock_lr_run: False +limit_num_batch: 10 +skip_final_eval: False + + +batch_size: 32 # per gpu +num_gpu: 1 +#precision: '16-mixed' # 32-true +precision: '32-true' +float32_matmul_precision: highest + + +gradient_acc: 1 +seed: 2048 +num_workers: 0 + +wandb_key: '' +suffix_run_name: None +using_wandb: False +wandb_entity: 'mckim' +wandb_log_all: False + +resume: '' + + +local_rank: -1 # placeholder +world_size: -1 # placeholder +total_batch_size: -1 # placeholder +warmup_step: -1 # placeholder +total_step: -1 # placeholder +output_dir: '' # placeholder \ No newline at end of file diff --git a/cvlface/research/recognition/code/run_v1/trainers/configs/default.yaml b/cvlface/research/recognition/code/run_v1/trainers/configs/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..de7f7b263c24b61ad872a5de8f7b87be25378e7c --- /dev/null +++ b/cvlface/research/recognition/code/run_v1/trainers/configs/default.yaml @@ -0,0 +1,32 @@ + +task: '' +prefix: 'default' +debug: False +mock_lr_run: False +limit_num_batch: -1 +skip_final_eval: False + +batch_size: 512 # per gpu +num_gpu: 1 +#precision: '16-mixed' # 32-true +precision: '32-true' +float32_matmul_precision: highest + +gradient_acc: 1 +seed: 2048 +num_workers: 3 + +wandb_key: '' +suffix_run_name: None +using_wandb: True +wandb_entity: 'mckim' +wandb_log_all: False + +resume: '' + +local_rank: -1 # placeholder +world_size: -1 # placeholder +total_batch_size: -1 # placeholder +warmup_step: -1 # placeholder +total_step: -1 # placeholder +output_dir: 
'' # placeholder \ No newline at end of file diff --git "a/cvlface/research/recognition/code/run_v1/\350\257\264\346\230\216.md" "b/cvlface/research/recognition/code/run_v1/\350\257\264\346\230\216.md" new file mode 100644 index 0000000000000000000000000000000000000000..15907f42ed2af3745b01207a1c5954f753bfabf4 --- /dev/null +++ "b/cvlface/research/recognition/code/run_v1/\350\257\264\346\230\216.md" @@ -0,0 +1,106 @@ +# run_v1 目录说明 + +位置: +E:\JingSai\参考timm的swim-s架构+adaface_onlyagedb\cvlface\research\recognition\code\run_v1 + +此文档概述该目录下的主要文件与子模块,便于快速上手与二次开发。 + +--- + +## 目录与文件一览(摘要) + +- README.md +- aligners/ — 对齐器(如 RetinaFace、差分可微的对齐器)实现与流水线 +- base.yaml — Hydra 的默认配置集合(指定 models/、losses/、pipelines/ 等默认 yaml) +- classifiers/ — 分类头/识别器实现(partial_fc、fc 等) +- config.py — 配置加载与输出目录准备逻辑(Hydra + OmegaConf 封装) +- data_augs/ — 数据增强工具与封装 +- dataset/ — 数据集适配器与加载器(support for repeated_dataset, subset, general_dataset) +- eval.py — 评估入口脚本 +- evaluations/ — 多种评估器(verification、ijb、tinyface 等)实现 +- fabric/ — Fabric(Lightning Fabric)相关辅助代码与 dataloader setup +- losses/ — 损失函数实现(如 AdaFace、margin loss 等) +- models/ — 模型实现(Swin、ViT、iResNet、part_fvit 等);查看 `models/swin/swin` 得到 Swin 具体实现 +- optims/ — 优化器、学习率策略(make_optimizer、lr_scheduler) +- pefts/ — PEFT(参数高效微调)相关代码 +- pipelines/ — 训练/推理流水线实现(train_model_cls_pipeline.py、infer_model_pipeline.py 等) +- scripts/ — 辅助脚本(例如生成评估协议的 prepare_agedb_protocol.py) +- test_model.py — 单张图片/批量推理与测试脚本 +- train.py — 主训练脚本,负责:配置解析、模型/optimizer 构建、训练循环、评估与保存 +- trainers/ — 训练器相关逻辑(若存在自定义训练封装) + +--- + +## 核心文件说明(更详细) + +- `train.py` + - 使用 Hydra/OmegaConf (`config.py`) 加载配置 + - 初始化 Fabric(分布式、mixed-precision、logger)并设置随机种子 + - 构建模型(`get_model`)、分类器(`get_classifier`)、aligner(`get_aligner`)、损失函数(`get_margin_loss`) + - 准备 dataloader(`setup_dataloader_from_dataset`),并计算 `total_step`/`warmup_step` + - 构建 train_pipeline 与 eval_pipeline(`pipelines/` 中定义),训练主循环包含梯度累积、混合精度、学习率 step、评估回调、模型保存 + +- `config.py` / `base.yaml` + - `base.yaml` 
定义了运行时默认使用的子配置(模型选择、loss、pipeline 等) + - `config.init(root)`:用 Hydra 组合配置、为每个子配置记录 yaml_path,并创建输出目录(experiments/...) + +- `models/` + - 包括多个 backbone(swin、vit、iresnet 等)与适配器;例如 Swin 的实现分布在 `models/swin/swin/` 中(`model.py`、`modules_v1.py` 等) + - `get_model` 根据配置构建实例并提供 `make_train_transform()` / `make_test_transform()` 等方法 + +- `pipelines/` + - 将训练/推理拆成可复用的 pipeline(例如数据预处理、forward、loss 计算、后处理),方便在 train.py 中以统一接口调用 + +- `evaluations/` + - 封装了不同任务/数据集的评估器(verification、tinyface、ijb 等),提供 `evaluate()` 接口并输出指标 + +- `losses/`(如 `adaface.py`) + - 实现 AdaFace 或其它 margin loss,供 `get_margin_loss` 动态加载 + +- `scripts/`(辅助) + - 包含生成评估协议等工具脚本,例如 `prepare_agedb_protocol.py` 用于准备 agedb 验证协议 + +--- + +## 快速上手示例 + +1. 进入代码目录并安装依赖 + +```powershell +cd E:\JingSai\参考timm的swim-s架构+adaface_onlyagedb\cvlface\research\recognition\code\run_v1 +python -m venv venv +venv\Scripts\Activate +pip install -r requirements.txt +``` + +2. 运行训练(示例) + +```powershell +python train.py # 若需覆盖默认配置,使用 hydra 的 overrides,例如:python train.py models=vit/configs/v1_small optims.cosine=True +``` + +3. 运行评估(将 CONFIG_YAML 与 CHECKPOINT_DIR 替换为实际路径) + +```powershell +python eval.py --config CONFIG_YAML --checkpoint CHECKPOINT_DIR +``` + +输出目录位于 `research/recognition/experiments/{task}/{run_name}/`(例如 `experiments/run_v1/agedb_swin_s_04-15_7/`),由 `config.prepare_output_dir` 自动生成并存放 checkpoints、result、config.yaml 等。 + +--- + +## 推荐阅读顺序(快速理解代码) +1. `config.py` 与 `base.yaml`(理解配置体系) +2. `train.py`(主流程) +3. `pipelines/`(理解 pipeline 封装) +4. `models/`(选择并阅读目标 backbone 实现) +5. 
`losses/` 与 `evaluations/`(训练目标与验证指标) + +--- + +如果需要,我可以: +- 将该说明追加到本目录下的 `README.md`; +- 为每个子目录自动生成更详细的 README(包括重要函数/类的摘录); +- 生成一份可打印的目录结构树(带文件大小与行数)。 + +*说明文档已生成。* \ No newline at end of file diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/config.yaml b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..805c6d9d8b45fe0fe50f8ff816a4b4a5d413ec4b --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/config.yaml @@ -0,0 +1,111 @@ +trainers: + task: run_v1 + prefix: agedb_swin_s + debug: false + mock_lr_run: false + limit_num_batch: -1 + skip_final_eval: false + batch_size: 64 + num_gpu: 1 + precision: 32-true + float32_matmul_precision: high + gradient_acc: 1 + seed: 2048 + num_workers: 4 + freeze_backbone_epochs: 10 + wandb_key: '' + suffix_run_name: None + using_wandb: false + wandb_entity: '' + wandb_log_all: false + resume: '' + local_rank: 0 + world_size: 1 + total_batch_size: 64 + warmup_step: 0 + total_step: 10300 + output_dir: /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7 + yaml_path: /configs/agedb_swin.yaml +optims: + num_epoch: 50 + optimizer: adamw + lr: 0.001 + momentum: 0.9 + weight_decay: 0.05 + scheduler: plateau + filter_bias_and_bn: true + warmup_epoch: 0 + max_grad_norm: 5.0 + lr_milestones: [] + plateau_factor: 0.5 + plateau_patience: 2 + plateau_threshold: 0.01 + plateau_cooldown: 0 + min_lr: 1.0e-06 + yaml_path: /configs/plateau_adamw_agedb.yaml +pefts: + name: none + yaml_path: /configs/none.yaml +models: + input_size: + - 3 + - 112 + - 112 + color_space: RGB + name: timm_swin_s + timm_name: swin_small_patch4_window7_224.ms_in22k_ft_in1k + output_dim: 512 + start_from: ${oc.env:SWIN_S_PRETRAINED,/root/Adaface/cvlface/cvlface/pretrained_models/model.safetensors} + freeze: false + mask_ratio: 0.0 + yaml_path: 
/swin/configs/v1_swin_s_pretrained.yaml +classifiers: + name: fc + sample_rate: 1.0 + start_from: '' + freeze: false + yaml_path: /configs/fc.yaml +aligners: + name: none + start_from: '' + freeze: false + yaml_path: /configs/none.yaml +dataset: + data_root: ${oc.env:AGEDB_PROTOCOL_ROOT} + rec: agedb_train_80 + color_space: RGB + num_classes: 567 + num_image: 13200 + repeated_sampling_cfg: null + semi_sampling_cfg: null + yaml_path: /configs/agedb_80.yaml +data_augs: + augmentation_version: basic + aug_params: + crop_augmentation_prob: 0.2 + photometric_augmentation_prob: 0.2 + low_res_augmentation_prob: 0.2 + yaml_path: /configs/basic_v1.yaml +losses: + margin_loss_name: adaface + interclass_filtering_threshold: 0 + m: 0.4 + h: 0.333 + t_alpha: 0.01 + yaml_path: /configs/adaface.yaml +pipelines: + name: TrainModelClsPipeline + resume: ${trainers.resume} + eval_pipeline_name: infer_model_pipeline + yaml_path: /configs/train_model_cls.yaml +evaluations: + data_root: ${oc.env:AGEDB_PROTOCOL_ROOT} + eval_every_n_epochs: 5 + per_epoch_evaluations: + agedb_30: + path: facerec_val/agedb_30_1to1 + evaluation_type: verification + color_space: RGB + batch_size: 32 + num_workers: 4 + yaml_path: /configs/agedb_30_1to1.yaml diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/lightning_logs/version_0/metrics.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/lightning_logs/version_0/metrics.csv new file mode 100644 index 0000000000000000000000000000000000000000..4cdc46ec22f7ff95a5836fa0478843e11b6ef1d3 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/lightning_logs/version_0/metrics.csv @@ -0,0 +1,325 @@ +epoch,final/agedb_30_acc,is_best,n_images_seen,step,summary/agedb_30_acc,train/loss,train/lr,trainer/epoch,trainer/global_step,val/agedb_30/acc,val/agedb_30/std +0,,,64,1,,35.794403076171875,0.001,0,1,, +0,,,3264,51,,32.532981872558594,0.001,0,51,, 
+0,,,6464,101,,32.071876525878906,0.001,0,101,, +0,,,9664,151,,32.09893035888672,0.001,0,151,, +0,,,12864,201,,31.980119705200195,0.001,0,201,, +0,,,13184,0,,,,,,29.394812680115272,4.318922402151985 +0,,,13184,0,29.394812680115272,,,0,206,, +,,1.0,,0,,,,,,, +1,,,13248,207,,28.538619995117188,0.001,1,207,, +1,,,16448,257,,30.23782730102539,0.001,1,257,, +1,,,19648,307,,29.089630126953125,0.001,1,307,, +1,,,22848,357,,29.241548538208008,0.001,1,357,, +1,,,26048,407,,30.040714263916016,0.001,1,407,, +,,1.0,,0,,,,,,, +2,,,26432,413,,27.22234344482422,0.001,2,413,, +2,,,29632,463,,28.176830291748047,0.001,2,463,, +2,,,32832,513,,28.409963607788086,0.001,2,513,, +2,,,36032,563,,28.01868438720703,0.001,2,563,, +2,,,39232,613,,28.416799545288086,0.001,2,613,, +,,0.0,,0,,,,,,, +3,,,39616,619,,26.2569637298584,0.001,3,619,, +3,,,42816,669,,28.069446563720703,0.001,3,669,, +3,,,46016,719,,28.62200927734375,0.001,3,719,, +3,,,49216,769,,27.78118133544922,0.001,3,769,, +3,,,52416,819,,28.470672607421875,0.001,3,819,, +,,0.0,,0,,,,,,, +4,,,52800,825,,26.465429306030273,0.0005,4,825,, +4,,,56000,875,,25.881080627441406,0.0005,4,875,, +4,,,59200,925,,26.93671417236328,0.0005,4,925,, +4,,,62400,975,,26.526100158691406,0.0005,4,975,, +4,,,65600,1025,,26.093551635742188,0.0005,4,1025,, +,,0.0,,0,,,,,,, +5,,,65984,1031,,25.727582931518555,0.0005,5,1031,, +5,,,69184,1081,,25.943117141723633,0.0005,5,1081,, +5,,,72384,1131,,26.025737762451172,0.0005,5,1131,, +5,,,75584,1181,,26.077341079711914,0.0005,5,1181,, +5,,,78784,1231,,26.131643295288086,0.0005,5,1231,, +5,,,79104,0,,,,,,29.394812680115272,4.318922402151985 +5,,,79104,0,29.394812680115272,,,5,1236,, +,,0.0,,0,,,,,,, +6,,,79168,1237,,26.172409057617188,0.0005,6,1237,, +6,,,82368,1287,,25.917251586914062,0.0005,6,1287,, +6,,,85568,1337,,25.951278686523438,0.0005,6,1337,, +6,,,88768,1387,,25.928762435913086,0.0005,6,1387,, +6,,,91968,1437,,26.48139762878418,0.0005,6,1437,, +,,0.0,,0,,,,,,, 
+7,,,92352,1443,,25.07032012939453,0.00025,7,1443,, +7,,,95552,1493,,25.209875106811523,0.00025,7,1493,, +7,,,98752,1543,,25.512048721313477,0.00025,7,1543,, +7,,,101952,1593,,25.850997924804688,0.00025,7,1593,, +7,,,105152,1643,,26.017515182495117,0.00025,7,1643,, +,,0.0,,0,,,,,,, +8,,,105536,1649,,24.898712158203125,0.00025,8,1649,, +8,,,108736,1699,,25.6771240234375,0.00025,8,1699,, +8,,,111936,1749,,25.276187896728516,0.00025,8,1749,, +8,,,115136,1799,,25.30967903137207,0.00025,8,1799,, +8,,,118336,1849,,25.892379760742188,0.00025,8,1849,, +,,0.0,,0,,,,,,, +9,,,118720,1855,,25.167795181274414,0.00025,9,1855,, +9,,,121920,1905,,24.824111938476562,0.00025,9,1905,, +9,,,125120,1955,,24.600887298583984,0.00025,9,1955,, +9,,,128320,2005,,24.488906860351562,0.00025,9,2005,, +9,,,131520,2055,,25.356266021728516,0.00025,9,2055,, +,,0.0,,0,,,,,,, +10,,,131904,2061,,25.911590576171875,0.000125,10,2061,, +10,,,135104,2111,,26.838092803955078,0.000125,10,2111,, +10,,,138304,2161,,26.692649841308594,0.000125,10,2161,, +10,,,141504,2211,,26.338706970214844,0.000125,10,2211,, +10,,,144704,2261,,26.159988403320312,0.000125,10,2261,, +10,,,145024,0,,,,,,57.34870317002882,6.954744958615956 +10,,,145024,0,57.34870317002882,,,10,2266,, +,,1.0,,0,,,,,,, +11,,,145088,2267,,24.127971649169922,0.000125,11,2267,, +11,,,148288,2317,,24.254913330078125,0.000125,11,2317,, +11,,,151488,2367,,22.46211814880371,0.000125,11,2367,, +11,,,154688,2417,,22.432842254638672,0.000125,11,2417,, +11,,,157888,2467,,23.89962387084961,0.000125,11,2467,, +,,0.0,,0,,,,,,, +12,,,158272,2473,,20.329334259033203,0.000125,12,2473,, +12,,,161472,2523,,20.036849975585938,0.000125,12,2523,, +12,,,164672,2573,,19.660903930664062,0.000125,12,2573,, +12,,,167872,2623,,20.772428512573242,0.000125,12,2623,, +12,,,171072,2673,,20.4401798248291,0.000125,12,2673,, +,,0.0,,0,,,,,,, +13,,,171456,2679,,17.140426635742188,0.000125,13,2679,, +13,,,174656,2729,,17.520263671875,0.000125,13,2729,, 
+13,,,177856,2779,,17.25941276550293,0.000125,13,2779,, +13,,,181056,2829,,16.521154403686523,0.000125,13,2829,, +13,,,184256,2879,,17.12468719482422,0.000125,13,2879,, +,,0.0,,0,,,,,,, +14,,,184640,2885,,15.420740127563477,6.25e-05,14,2885,, +14,,,187840,2935,,12.112133026123047,6.25e-05,14,2935,, +14,,,191040,2985,,12.645011901855469,6.25e-05,14,2985,, +14,,,194240,3035,,11.752267837524414,6.25e-05,14,3035,, +14,,,197440,3085,,15.279365539550781,6.25e-05,14,3085,, +,,0.0,,0,,,,,,, +15,,,197824,3091,,10.041845321655273,6.25e-05,15,3091,, +15,,,201024,3141,,9.269998550415039,6.25e-05,15,3141,, +15,,,204224,3191,,10.880020141601562,6.25e-05,15,3191,, +15,,,207424,3241,,8.773850440979004,6.25e-05,15,3241,, +15,,,210624,3291,,10.71726131439209,6.25e-05,15,3291,, +15,,,210944,0,,,,,,69.10662824207493,5.857734177049752 +15,,,210944,0,69.10662824207493,,,15,3296,, +,,1.0,,0,,,,,,, +16,,,211008,3297,,6.62135124206543,6.25e-05,16,3297,, +16,,,214208,3347,,8.529400825500488,6.25e-05,16,3347,, +16,,,217408,3397,,10.441800117492676,6.25e-05,16,3397,, +16,,,220608,3447,,8.276819229125977,6.25e-05,16,3447,, +16,,,223808,3497,,8.192054748535156,6.25e-05,16,3497,, +,,1.0,,0,,,,,,, +17,,,224192,3503,,7.108463287353516,6.25e-05,17,3503,, +17,,,227392,3553,,6.143529891967773,6.25e-05,17,3553,, +17,,,230592,3603,,7.9150590896606445,6.25e-05,17,3603,, +17,,,233792,3653,,7.240469455718994,6.25e-05,17,3653,, +17,,,236992,3703,,6.405365943908691,6.25e-05,17,3703,, +,,0.0,,0,,,,,,, +18,,,237376,3709,,5.5641350746154785,6.25e-05,18,3709,, +18,,,240576,3759,,6.510687828063965,6.25e-05,18,3759,, +18,,,243776,3809,,4.588730812072754,6.25e-05,18,3809,, +18,,,246976,3859,,4.331736087799072,6.25e-05,18,3859,, +18,,,250176,3909,,5.104048252105713,6.25e-05,18,3909,, +,,0.0,,0,,,,,,, +19,,,250560,3915,,3.5035879611968994,3.125e-05,19,3915,, +19,,,253760,3965,,2.81095290184021,3.125e-05,19,3965,, +19,,,256960,4015,,2.673839807510376,3.125e-05,19,4015,, 
+19,,,260160,4065,,3.734635829925537,3.125e-05,19,4065,, +19,,,263360,4115,,3.4280762672424316,3.125e-05,19,4115,, +,,0.0,,0,,,,,,, +20,,,263744,4121,,1.8836697340011597,3.125e-05,20,4121,, +20,,,266944,4171,,2.8146469593048096,3.125e-05,20,4171,, +20,,,270144,4221,,2.82645320892334,3.125e-05,20,4221,, +20,,,273344,4271,,2.600339412689209,3.125e-05,20,4271,, +20,,,276544,4321,,3.2413434982299805,3.125e-05,20,4321,, +20,,,276864,0,,,,,,76.34005763688761,8.719615357523367 +20,,,276864,0,76.34005763688761,,,20,4326,, +,,1.0,,0,,,,,,, +21,,,276928,4327,,1.9926902055740356,3.125e-05,21,4327,, +21,,,280128,4377,,1.949006199836731,3.125e-05,21,4377,, +21,,,283328,4427,,2.46518611907959,3.125e-05,21,4427,, +21,,,286528,4477,,1.970572829246521,3.125e-05,21,4477,, +21,,,289728,4527,,1.7479913234710693,3.125e-05,21,4527,, +,,0.0,,0,,,,,,, +22,,,290112,4533,,1.6653087139129639,3.125e-05,22,4533,, +22,,,293312,4583,,1.7446742057800293,3.125e-05,22,4583,, +22,,,296512,4633,,1.781663179397583,3.125e-05,22,4633,, +22,,,299712,4683,,2.518563747406006,3.125e-05,22,4683,, +22,,,302912,4733,,1.203094244003296,3.125e-05,22,4733,, +,,0.0,,0,,,,,,, +23,,,303296,4739,,1.7377445697784424,3.125e-05,23,4739,, +23,,,306496,4789,,1.3356142044067383,3.125e-05,23,4789,, +23,,,309696,4839,,1.4560878276824951,3.125e-05,23,4839,, +23,,,312896,4889,,2.476775884628296,3.125e-05,23,4889,, +23,,,316096,4939,,0.7002238035202026,3.125e-05,23,4939,, +,,0.0,,0,,,,,,, +24,,,316480,4945,,1.4986075162887573,1.5625e-05,24,4945,, +24,,,319680,4995,,1.5489041805267334,1.5625e-05,24,4995,, +24,,,322880,5045,,0.8734735250473022,1.5625e-05,24,5045,, +24,,,326080,5095,,0.9835149049758911,1.5625e-05,24,5095,, +24,,,329280,5145,,1.5519623756408691,1.5625e-05,24,5145,, +,,0.0,,0,,,,,,, +25,,,329664,5151,,1.0576550960540771,1.5625e-05,25,5151,, +25,,,332864,5201,,1.3905539512634277,1.5625e-05,25,5201,, +25,,,336064,5251,,0.8133316040039062,1.5625e-05,25,5251,, +25,,,339264,5301,,0.5323072075843811,1.5625e-05,25,5301,, 
+25,,,342464,5351,,1.0930503606796265,1.5625e-05,25,5351,, +25,,,342784,0,,,,,,75.21613832853026,8.110228979521779 +25,,,342784,0,75.21613832853026,,,25,5356,, +,,0.0,,0,,,,,,, +26,,,342848,5357,,0.5212942361831665,1.5625e-05,26,5357,, +26,,,346048,5407,,0.8142703771591187,1.5625e-05,26,5407,, +26,,,349248,5457,,0.7132399678230286,1.5625e-05,26,5457,, +26,,,352448,5507,,0.9409947991371155,1.5625e-05,26,5507,, +26,,,355648,5557,,0.5178017616271973,1.5625e-05,26,5557,, +,,0.0,,0,,,,,,, +27,,,356032,5563,,0.41073673963546753,7.8125e-06,27,5563,, +27,,,359232,5613,,1.059584140777588,7.8125e-06,27,5613,, +27,,,362432,5663,,0.6552743911743164,7.8125e-06,27,5663,, +27,,,365632,5713,,0.6959383487701416,7.8125e-06,27,5713,, +27,,,368832,5763,,1.0343782901763916,7.8125e-06,27,5763,, +,,0.0,,0,,,,,,, +28,,,369216,5769,,0.7762857675552368,7.8125e-06,28,5769,, +28,,,372416,5819,,0.6086372137069702,7.8125e-06,28,5819,, +28,,,375616,5869,,1.2398935556411743,7.8125e-06,28,5869,, +28,,,378816,5919,,0.5947744846343994,7.8125e-06,28,5919,, +28,,,382016,5969,,1.36199951171875,7.8125e-06,28,5969,, +,,0.0,,0,,,,,,, +29,,,382400,5975,,0.6881289482116699,7.8125e-06,29,5975,, +29,,,385600,6025,,0.8477485179901123,7.8125e-06,29,6025,, +29,,,388800,6075,,0.964682400226593,7.8125e-06,29,6075,, +29,,,392000,6125,,0.6017898321151733,7.8125e-06,29,6125,, +29,,,395200,6175,,0.4798910915851593,7.8125e-06,29,6175,, +,,0.0,,0,,,,,,, +30,,,395584,6181,,0.8289844989776611,3.90625e-06,30,6181,, +30,,,398784,6231,,0.7183796763420105,3.90625e-06,30,6231,, +30,,,401984,6281,,0.44069749116897583,3.90625e-06,30,6281,, +30,,,405184,6331,,0.9298021793365479,3.90625e-06,30,6331,, +30,,,408384,6381,,0.4179686903953552,3.90625e-06,30,6381,, +30,,,408704,0,,,,,,75.67723342939483,7.622682672965119 +30,,,408704,0,75.67723342939483,,,30,6386,, +,,0.0,,0,,,,,,, +31,,,408768,6387,,0.475381076335907,3.90625e-06,31,6387,, +31,,,411968,6437,,0.7154990434646606,3.90625e-06,31,6437,, 
+31,,,415168,6487,,0.4589882493019104,3.90625e-06,31,6487,, +31,,,418368,6537,,0.8743399381637573,3.90625e-06,31,6537,, +31,,,421568,6587,,0.9172885417938232,3.90625e-06,31,6587,, +,,0.0,,0,,,,,,, +32,,,421952,6593,,0.2661115527153015,3.90625e-06,32,6593,, +32,,,425152,6643,,0.7672258615493774,3.90625e-06,32,6643,, +32,,,428352,6693,,0.5697412490844727,3.90625e-06,32,6693,, +32,,,431552,6743,,0.6254387497901917,3.90625e-06,32,6743,, +32,,,434752,6793,,0.9145208597183228,3.90625e-06,32,6793,, +,,0.0,,0,,,,,,, +33,,,435136,6799,,0.7847840189933777,1.953125e-06,33,6799,, +33,,,438336,6849,,0.38142886757850647,1.953125e-06,33,6849,, +33,,,441536,6899,,0.5422579050064087,1.953125e-06,33,6899,, +33,,,444736,6949,,0.6979982852935791,1.953125e-06,33,6949,, +33,,,447936,6999,,0.48248857259750366,1.953125e-06,33,6999,, +,,0.0,,0,,,,,,, +34,,,448320,7005,,0.7065858840942383,1.953125e-06,34,7005,, +34,,,451520,7055,,0.5315800905227661,1.953125e-06,34,7055,, +34,,,454720,7105,,0.7515548467636108,1.953125e-06,34,7105,, +34,,,457920,7155,,0.3155518174171448,1.953125e-06,34,7155,, +34,,,461120,7205,,0.3042921721935272,1.953125e-06,34,7205,, +,,0.0,,0,,,,,,, +35,,,461504,7211,,0.231420636177063,1.953125e-06,35,7211,, +35,,,464704,7261,,0.26518887281417847,1.953125e-06,35,7261,, +35,,,467904,7311,,0.43423977494239807,1.953125e-06,35,7311,, +35,,,471104,7361,,0.524854838848114,1.953125e-06,35,7361,, +35,,,474304,7411,,0.2501969337463379,1.953125e-06,35,7411,, +35,,,474624,0,,,,,,76.10951008645533,8.067568800688592 +35,,,474624,0,76.10951008645533,,,35,7416,, +,,0.0,,0,,,,,,, +36,,,474688,7417,,0.24859872460365295,1e-06,36,7417,, +36,,,477888,7467,,0.5140712857246399,1e-06,36,7467,, +36,,,481088,7517,,0.34933316707611084,1e-06,36,7517,, +36,,,484288,7567,,0.2907460331916809,1e-06,36,7567,, +36,,,487488,7617,,0.7722290754318237,1e-06,36,7617,, +,,0.0,,0,,,,,,, +37,,,487872,7623,,0.617621123790741,1e-06,37,7623,, +37,,,491072,7673,,0.37979137897491455,1e-06,37,7673,, 
+37,,,494272,7723,,0.6015068292617798,1e-06,37,7723,, +37,,,497472,7773,,0.539482593536377,1e-06,37,7773,, +37,,,500672,7823,,0.3536941409111023,1e-06,37,7823,, +,,0.0,,0,,,,,,, +38,,,501056,7829,,0.42461299896240234,1e-06,38,7829,, +38,,,504256,7879,,0.5581293702125549,1e-06,38,7879,, +38,,,507456,7929,,0.6669777631759644,1e-06,38,7929,, +38,,,510656,7979,,0.715545654296875,1e-06,38,7979,, +38,,,513856,8029,,0.40226268768310547,1e-06,38,8029,, +,,0.0,,0,,,,,,, +39,,,514240,8035,,0.7434515953063965,1e-06,39,8035,, +39,,,517440,8085,,0.30899831652641296,1e-06,39,8085,, +39,,,520640,8135,,0.459061861038208,1e-06,39,8135,, +39,,,523840,8185,,0.5822697281837463,1e-06,39,8185,, +39,,,527040,8235,,0.4388253688812256,1e-06,39,8235,, +,,0.0,,0,,,,,,, +40,,,527424,8241,,0.4390608072280884,1e-06,40,8241,, +40,,,530624,8291,,0.34162667393684387,1e-06,40,8291,, +40,,,533824,8341,,0.4816899001598358,1e-06,40,8341,, +40,,,537024,8391,,0.2661440670490265,1e-06,40,8391,, +40,,,540224,8441,,0.6104631423950195,1e-06,40,8441,, +40,,,540544,0,,,,,,75.61959654178676,7.391467366699655 +40,,,540544,0,75.61959654178676,,,40,8446,, +,,0.0,,0,,,,,,, +41,,,540608,8447,,0.5636454820632935,1e-06,41,8447,, +41,,,543808,8497,,0.7672128677368164,1e-06,41,8497,, +41,,,547008,8547,,0.7117125988006592,1e-06,41,8547,, +41,,,550208,8597,,0.7243976593017578,1e-06,41,8597,, +41,,,553408,8647,,0.24143250286579132,1e-06,41,8647,, +,,0.0,,0,,,,,,, +42,,,553792,8653,,1.0013597011566162,1e-06,42,8653,, +42,,,556992,8703,,0.5134041905403137,1e-06,42,8703,, +42,,,560192,8753,,0.36960017681121826,1e-06,42,8753,, +42,,,563392,8803,,1.0246589183807373,1e-06,42,8803,, +42,,,566592,8853,,0.7379381060600281,1e-06,42,8853,, +,,0.0,,0,,,,,,, +43,,,566976,8859,,0.783515989780426,1e-06,43,8859,, +43,,,570176,8909,,0.2705845236778259,1e-06,43,8909,, +43,,,573376,8959,,0.5207534432411194,1e-06,43,8959,, +43,,,576576,9009,,0.5415868759155273,1e-06,43,9009,, +43,,,579776,9059,,0.535515308380127,1e-06,43,9059,, 
+,,0.0,,0,,,,,,, +44,,,580160,9065,,0.26169973611831665,1e-06,44,9065,, +44,,,583360,9115,,0.3804200291633606,1e-06,44,9115,, +44,,,586560,9165,,0.5284651517868042,1e-06,44,9165,, +44,,,589760,9215,,0.6615191698074341,1e-06,44,9215,, +44,,,592960,9265,,0.610895037651062,1e-06,44,9265,, +,,0.0,,0,,,,,,, +45,,,593344,9271,,0.8143655061721802,1e-06,45,9271,, +45,,,596544,9321,,0.4570792317390442,1e-06,45,9321,, +45,,,599744,9371,,0.5556531548500061,1e-06,45,9371,, +45,,,602944,9421,,0.3338557481765747,1e-06,45,9421,, +45,,,606144,9471,,0.5180177092552185,1e-06,45,9471,, +45,,,606464,0,,,,,,76.05187319884726,8.615172435302341 +45,,,606464,0,76.05187319884726,,,45,9476,, +,,0.0,,0,,,,,,, +46,,,606528,9477,,0.602348804473877,1e-06,46,9477,, +46,,,609728,9527,,0.215061753988266,1e-06,46,9527,, +46,,,612928,9577,,0.471391499042511,1e-06,46,9577,, +46,,,616128,9627,,0.26439300179481506,1e-06,46,9627,, +46,,,619328,9677,,0.6358875632286072,1e-06,46,9677,, +,,0.0,,0,,,,,,, +47,,,619712,9683,,0.7661458253860474,1e-06,47,9683,, +47,,,622912,9733,,0.6411811709403992,1e-06,47,9733,, +47,,,626112,9783,,0.46273455023765564,1e-06,47,9783,, +47,,,629312,9833,,0.5521481037139893,1e-06,47,9833,, +47,,,632512,9883,,0.4139159321784973,1e-06,47,9883,, +,,0.0,,0,,,,,,, +48,,,632896,9889,,0.7341700792312622,1e-06,48,9889,, +48,,,636096,9939,,0.3882235288619995,1e-06,48,9939,, +48,,,639296,9989,,0.5051562786102295,1e-06,48,9989,, +48,,,642496,10039,,0.3080351948738098,1e-06,48,10039,, +48,,,645696,10089,,0.5477005243301392,1e-06,48,10089,, +,,0.0,,0,,,,,,, +49,,,646080,10095,,0.8554441332817078,1e-06,49,10095,, +49,,,649280,10145,,0.9449437856674194,1e-06,49,10145,, +49,,,652480,10195,,0.4805559813976288,1e-06,49,10195,, +49,,,655680,10245,,0.7383802533149719,1e-06,49,10245,, +49,,,658880,10295,,0.26509514451026917,1e-06,49,10295,, +49,,,659200,0,,,,,,76.10951008645533,7.208586799222195 +49,,,659200,0,76.10951008645533,,,49,10300,, +,,0.0,,0,,,,,,, 
+50,,,659201,0,,,,,,76.34005763688761,8.719615357523367 +50,76.34,,659201,0,,,,50,10301,, diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/README.md b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2d37ed8a28805de41e98871385cd5ff24cc418dc --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/README.md @@ -0,0 +1,58 @@ +# Reproducibility Notes + +This directory records the information needed to reproduce the AgeDB Swin-S run `agedb_swin_s_04-15_7`. + +## Run Summary + +- Repository root: `/root/Adaface/cvlface` +- Git commit at run time: `308142aa50adf2e187711354f7524635d3414f1e` +- Branch: `main` +- Training code snapshot: parent experiment directory, for example `train.py`, `models/`, `dataset/`, `evaluations/`, `optims/`, and related modules were copied into the experiment directory by the trainer. 
+- Merged Hydra config: `../config.yaml` +- Training log: `/root/Adaface/cvlface/logs/agedb_swin_s_train.log` +- Main result CSVs: `../result/` +- Best checkpoint: `../checkpoints/best/model.pt` +- Last checkpoint: `../checkpoints/epoch:49_step:10300/model.pt` + +## Inputs + +- Source aligned images: `/root/Adaface/cvlface/AgeDB_aligned_224` +- Prepared protocol root: `/root/Adaface/cvlface/data/agedb_protocol` +- Train split: `agedb_train_80` +- Verification split: `facerec_val/agedb_30_1to1/pairs.csv` +- Swin-S pretrained weights: `/root/Adaface/cvlface/cvlface/pretrained_models/model.safetensors` +- Pretrained source: `timm/swin_small_patch4_window7_224.ms_in22k_ft_in1k` + +## Key Settings + +- Backbone: `timm_swin_s` +- Input size: `3x112x112` +- Optimizer: `adamw` +- Initial learning rate: `0.001` +- Scheduler: `ReduceLROnPlateau` +- Epochs: `50` +- Freeze backbone epochs: `10` +- Batch size: `64` +- Seed: `2048` +- Precision: `32-true` +- Eval interval: every `5` epochs + +## Commands + +- Data preparation command: `prepare_data_command.sh` +- Training command: `run_command.sh` + +## Captured Files + +- `environment.txt`: Python, OS, PyTorch, CUDA, GPU, timm, and Lightning versions. +- `pip_freeze.txt`: Installed Python packages in the environment after the run. +- `git_status.txt`: Worktree status when this reproducibility bundle was created. +- `git_diff.patch`: Uncommitted source/config changes needed beyond the base git commit. +- `source_files/`: Copies of newly added untracked YAML files and launcher/preparation scripts, kept under their original relative paths. +- `hashes.txt`: SHA256 hashes for important inputs, outputs, and generated manifest files. +- `dataset_manifest.json`: Dataset/protocol counts and file-list hashes. +- `results_summary.csv`: Per-evaluation accuracy summary. + +## Caveat + +The config fixes the seed, and this record fixes code, data split, weights, and dependency versions. 
Exact bit-for-bit replay on GPU is still not guaranteed unless the same driver, CUDA, PyTorch build, GPU, and deterministic kernel settings are also preserved. Metric-level reproduction should be possible from these records. diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/dataset_manifest.json b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/dataset_manifest.json new file mode 100644 index 0000000000000000000000000000000000000000..f2f13f82f3dacf2f518330f5be1da3744df9f820 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/dataset_manifest.json @@ -0,0 +1,19 @@ +{ + "max_pairs_argument": 3000, + "pairs_csv": "/root/Adaface/cvlface/data/agedb_protocol/facerec_val/agedb_30_1to1/pairs.csv", + "pairs_csv_sha256": "e648a7af3534c41acd5504d80a5e99fc380d403bf21f2f317164cc1325e710ac", + "prepare_command": "reproducibility/prepare_data_command.sh", + "protocol_root": "/root/Adaface/cvlface/data/agedb_protocol", + "seed": 2048, + "source_dir": "/root/Adaface/cvlface/AgeDB_aligned_224", + "source_file_list_sha256": "3e3caf3383374015366d22cfd9734d55606febf328f91b4e8d83e55f739e27cd", + "source_jpg_count": 16488, + "train_dir": "/root/Adaface/cvlface/data/agedb_protocol/agedb_train_80", + "train_identity_count": 567, + "train_image_count": 13200, + "train_ratio": 0.8, + "train_symlink_target_list_sha256": "9cc26f47459efd8b3300416bf29e0de5d5b28f70b205ff3ab1a0a1a26d182638", + "verification_negative_pairs": 1735, + "verification_pair_count": 3470, + "verification_positive_pairs": 1735 +} diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/environment.txt b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/environment.txt new file mode 100644 index 0000000000000000000000000000000000000000..164b50ec920337243f8041d4be914c0c23e764f3 --- /dev/null +++ 
b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/environment.txt @@ -0,0 +1,41 @@ +python: 3.12.13 | packaged by conda-forge | (main, Mar 5 2026, 16:50:00) [GCC 14.3.0] +executable: /venv/main/bin/python +platform: Linux-6.8.0-83-generic-x86_64-with-glibc2.39 + +nvidia-smi: +Wed Apr 15 13:47:09 2026 ++-----------------------------------------------------------------------------------------+ +| NVIDIA-SMI 580.82.09 Driver Version: 580.82.09 CUDA Version: 13.0 | ++-----------------------------------------+------------------------+----------------------+ +| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC | +| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. | +| | | MIG M. | +|=========================================+========================+======================| +| 0 NVIDIA GeForce RTX 5070 On | 00000000:C1:00.0 Off | N/A | +| 0% 30C P8 2W / 250W | 2MiB / 12227MiB | 0% Default | +| | | N/A | ++-----------------------------------------+------------------------+----------------------+ + ++-----------------------------------------------------------------------------------------+ +| Processes: | +| GPU GI CI PID Type Process name GPU Memory | +| ID ID Usage | +|=========================================================================================| +| No running processes found | ++-----------------------------------------------------------------------------------------+ + +torch: 2.11.0+cu130 +torch_cuda: 13.0 +cuda_available: True +cuda_device_count: 1 +cuda_device_0: NVIDIA GeForce RTX 5070 +torchvision: 0.26.0+cu130 +torchaudio: 2.11.0+cu130 +timm: 1.0.26 +lightning: 2.6.1 +hydra: 1.3.2 +numpy: 2.4.3 +pandas: 3.0.2 +sklearn: 1.8.0 +cv2: 4.13.0 +wandb: 0.26.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/git_diff.patch b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/git_diff.patch new file mode 100644 index 
0000000000000000000000000000000000000000..8b4e849828502305b63063994f17f6a875294be3 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/git_diff.patch @@ -0,0 +1,527 @@ +diff --git a/cvlface/research/recognition/code/run_v1/data_augs/__init__.py b/cvlface/research/recognition/code/run_v1/data_augs/__init__.py +index 444925e..8ce6ad0 100644 +--- a/cvlface/research/recognition/code/run_v1/data_augs/__init__.py ++++ b/cvlface/research/recognition/code/run_v1/data_augs/__init__.py +@@ -1,5 +1,4 @@ + from .basic_augmenter import BasicAugmenter +-from .gridsample_augmenter import GridSampleAugmenter + + def make_augmenter(augmentation_version, aug_params): + if augmentation_version == 'basic': +@@ -8,7 +7,8 @@ def make_augmenter(augmentation_version, aug_params): + low_res_augmentation_prob=aug_params.low_res_augmentation_prob, + ) + elif augmentation_version == 'gridsample': ++ from .gridsample_augmenter import GridSampleAugmenter + augmenter = GridSampleAugmenter(aug_params, input_size=112) + else: + raise ValueError('not correct augmentation version') +- return augmenter +\ No newline at end of file ++ return augmenter +diff --git a/cvlface/research/recognition/code/run_v1/dataset/__init__.py b/cvlface/research/recognition/code/run_v1/dataset/__init__.py +index 50524c1..423183e 100644 +--- a/cvlface/research/recognition/code/run_v1/dataset/__init__.py ++++ b/cvlface/research/recognition/code/run_v1/dataset/__init__.py +@@ -7,12 +7,6 @@ import cv2 + from PIL import Image + import random + +-from .base_dataset import SyntheticDataset, MXFaceDataset +-from .augment_dataset import AugmentMXDataset +-from .repeated_dataset import RepeatedSamplingMXDataset +-from .repeated_dataset_with_ldmk_theta import RepeatedWithLdmkThetaMXDataset +-from .subset_dataset import SubsetDataset +- + def get_train_dataset(dataset_cfg, train_transform, aug_cfg, local_rank=0): + + # batch_size = cfg.trainers.batch_size +@@ -26,29 +20,34 @@ def 
get_train_dataset(dataset_cfg, train_transform, aug_cfg, local_rank=0): + + # Synthetic + if dataset_cfg.rec == "synthetic": ++ from .base_dataset import SyntheticDataset + train_set = SyntheticDataset(dataset_cfg.num_classes, dataset_cfg.num_image) + label_mapping = None + + # Mxnet RecordIO + elif os.path.exists(rec) and os.path.exists(idx): + if aug_cfg.augmentation_version == 'none': ++ from .base_dataset import MXFaceDataset + assert dataset_cfg.repeated_sampling_cfg is None + train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank) + else: + if dataset_cfg.repeated_sampling_cfg is not None: + if dataset_cfg.repeated_sampling_cfg.ldmk_path: ++ from .repeated_dataset_with_ldmk_theta import RepeatedWithLdmkThetaMXDataset + # repeated sampling + augmentation + ldmk + train_set = RepeatedWithLdmkThetaMXDataset(root_dir=root_dir, local_rank=local_rank, + augmentation_version=aug_cfg.augmentation_version, + aug_params=aug_cfg.aug_params, + repeated_sampling_cfg=dataset_cfg.repeated_sampling_cfg) + else: ++ from .repeated_dataset import RepeatedSamplingMXDataset + # repeated sampling + augmentation + train_set = RepeatedSamplingMXDataset(root_dir=root_dir, local_rank=local_rank, + augmentation_version=aug_cfg.augmentation_version, + aug_params=aug_cfg.aug_params, + repeated_sampling_cfg=dataset_cfg.repeated_sampling_cfg) + else: ++ from .augment_dataset import AugmentMXDataset + # augmentation + train_set = AugmentMXDataset(root_dir=root_dir, local_rank=local_rank, + augmentation_version=aug_cfg.augmentation_version, +@@ -58,6 +57,7 @@ def get_train_dataset(dataset_cfg, train_transform, aug_cfg, local_rank=0): + + # resample dataset if needed + if hasattr(dataset_cfg, 'resample_dataset') and dataset_cfg.resample_dataset: ++ from .subset_dataset import SubsetDataset + + if dataset_cfg.resample_dataset == 'one_half': + removing_index = list(set(range(0, len(train_set))) - set(range(0, len(train_set), 2))) +diff --git 
a/cvlface/research/recognition/code/run_v1/dataset/base_dataset.py b/cvlface/research/recognition/code/run_v1/dataset/base_dataset.py +index 218fe93..8e8d74d 100644 +--- a/cvlface/research/recognition/code/run_v1/dataset/base_dataset.py ++++ b/cvlface/research/recognition/code/run_v1/dataset/base_dataset.py +@@ -1,6 +1,5 @@ + import numbers + import os +-import mxnet as mx + import numpy as np + import torch + from torch.utils.data import Dataset +@@ -10,7 +9,18 @@ import atexit + import pandas as pd + from tqdm import tqdm + ++mx = None ++ ++ ++def get_mxnet(): ++ global mx ++ if mx is None: ++ import mxnet as mxnet ++ mx = mxnet ++ return mx ++ + def iterate_record(imgidx, record): ++ mx = get_mxnet() + # make one yourself + record_info = [] + for idx in tqdm(imgidx, total=len(imgidx), desc='Iterating Dataset for extracting info (done only once)'): +@@ -29,6 +39,7 @@ def iterate_record(imgidx, record): + class MXFaceDataset(Dataset): + def __init__(self, root_dir, local_rank): + super(MXFaceDataset, self).__init__() ++ mx = get_mxnet() + self.to_PIL = transforms.ToPILImage() + self.root_dir = root_dir + self.local_rank = local_rank +@@ -103,4 +114,3 @@ class SyntheticDataset(Dataset): + + def __len__(self): + return self.num_sample +- +diff --git a/cvlface/research/recognition/code/run_v1/evaluations/__init__.py b/cvlface/research/recognition/code/run_v1/evaluations/__init__.py +index 390a5e0..220dcfc 100644 +--- a/cvlface/research/recognition/code/run_v1/evaluations/__init__.py ++++ b/cvlface/research/recognition/code/run_v1/evaluations/__init__.py +@@ -2,8 +2,6 @@ import os + import torch + + from .verification_evaluator import VerificationEvaluator +-from .ijbbc_evaluator import IJBBCEvaluator +-from .tinyface_evaluator import TinyFaceEvaluator + + def get_evaluator_by_name(eval_type, name, eval_data_path, transform, fabric, batch_size, num_workers): + +@@ -13,8 +11,10 @@ def get_evaluator_by_name(eval_type, name, eval_data_path, transform, fabric, ba + if 
eval_type == 'verification': + return VerificationEvaluator(name, eval_data_path, transform, fabric, batch_size, num_workers) + elif eval_type == 'ijbbc': ++ from .ijbbc_evaluator import IJBBCEvaluator + return IJBBCEvaluator(name, eval_data_path, transform, fabric, batch_size, num_workers) + elif eval_type == 'tinyface': ++ from .tinyface_evaluator import TinyFaceEvaluator + return TinyFaceEvaluator(name, eval_data_path, transform, fabric, batch_size, num_workers) + else: + raise ValueError('Unknown evaluation type: %s' % eval_type) +@@ -74,4 +74,4 @@ class IsBestTracker(): + + + def is_best(self): +- return self._is_best +\ No newline at end of file ++ return self._is_best +diff --git a/cvlface/research/recognition/code/run_v1/evaluations/verification_evaluator.py b/cvlface/research/recognition/code/run_v1/evaluations/verification_evaluator.py +index 99c39ca..bd3485c 100644 +--- a/cvlface/research/recognition/code/run_v1/evaluations/verification_evaluator.py ++++ b/cvlface/research/recognition/code/run_v1/evaluations/verification_evaluator.py +@@ -1,4 +1,3 @@ +-from datasets import Dataset + import torch + from functools import partial + from .base_evaluator import BaseEvaluator +@@ -6,6 +5,10 @@ from .verifications.verification import evaluate + import sklearn + from tqdm import tqdm + import numpy as np ++import os ++import pandas as pd ++from PIL import Image ++from torch.utils.data import Dataset as TorchDataset + + def preprocess_transform(examples, image_transforms): + images = [image.convert("RGB") for image in examples['image']] +@@ -33,9 +36,14 @@ class VerificationEvaluator(BaseEvaluator): + super().__init__(name, fabric, batch_size) + self.name = name + self.data_path = data_path +- dataset = Dataset.load_from_disk(data_path) +- preprocess = partial(preprocess_transform, image_transforms=transform) +- dataset = dataset.with_transform(preprocess) ++ pairs_csv = os.path.join(data_path, 'pairs.csv') ++ if os.path.isfile(pairs_csv): ++ dataset = 
LocalVerificationPairDataset(pairs_csv, transform) ++ else: ++ from datasets import Dataset ++ dataset = Dataset.load_from_disk(data_path) ++ preprocess = partial(preprocess_transform, image_transforms=transform) ++ dataset = dataset.with_transform(preprocess) + self.dataloader = fabric.setup_dataloader_from_dataset(dataset, + is_train=False, + batch_size=batch_size, +@@ -108,3 +116,23 @@ class VerificationEvaluator(BaseEvaluator): + result = {'acc': acc, 'std': std} + return result + ++ ++class LocalVerificationPairDataset(TorchDataset): ++ def __init__(self, pairs_csv, transform): ++ self.rows = pd.read_csv(pairs_csv) ++ self.transform = transform ++ ++ def __len__(self): ++ return len(self.rows) ++ ++ def __getitem__(self, index): ++ row = self.rows.iloc[index] ++ image = Image.open(row['path']).convert('RGB') ++ is_same = row['is_same'] ++ if isinstance(is_same, str): ++ is_same = is_same.lower() == 'true' ++ return { ++ 'pixel_values': self.transform(image), ++ 'index': int(row['index']), ++ 'is_same': bool(is_same), ++ } +diff --git a/cvlface/research/recognition/code/run_v1/evaluations/verifications/verification.py b/cvlface/research/recognition/code/run_v1/evaluations/verifications/verification.py +index 1548984..ad1d640 100644 +--- a/cvlface/research/recognition/code/run_v1/evaluations/verifications/verification.py ++++ b/cvlface/research/recognition/code/run_v1/evaluations/verifications/verification.py +@@ -1,16 +1,27 @@ + import datetime + import pickle + +-import mxnet as mx + import numpy as np + import sklearn + import torch +-from mxnet import ndarray as nd + from scipy import interpolate + from sklearn.decomposition import PCA + from sklearn.model_selection import KFold + import matplotlib.pyplot as plt + ++mx = None ++nd = None ++ ++ ++def get_mxnet(): ++ global mx, nd ++ if mx is None: ++ import mxnet as mxnet ++ from mxnet import ndarray as ndarray ++ mx = mxnet ++ nd = ndarray ++ return mx, nd ++ + class LFold: + def __init__(self, n_splits=2, 
shuffle=False): + self.n_splits = n_splits +@@ -197,16 +208,21 @@ def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0): + nrof_folds=nrof_folds, + pca=pca) + thresholds = np.arange(0, 4, 0.001) +- val, val_std, far = calculate_val(thresholds, +- embeddings1, +- embeddings2, +- np.asarray(actual_issame), +- 1e-3, +- nrof_folds=nrof_folds) ++ try: ++ val, val_std, far = calculate_val(thresholds, ++ embeddings1, ++ embeddings2, ++ np.asarray(actual_issame), ++ 1e-3, ++ nrof_folds=nrof_folds) ++ except ValueError as exc: ++ print(f'calculate_val failed: {exc}') ++ val, val_std, far = 0.0, 0.0, 0.0 + return tpr, fpr, accuracy, val, val_std, far + + @torch.no_grad() + def load_bin(path, image_size): ++ mx, nd = get_mxnet() + try: + with open(path, 'rb') as f: + bins, issame_list = pickle.load(f) # py2 +@@ -281,5 +297,3 @@ def test(data_set, backbone, batch_size, nfolds=10): + _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds) + acc2, std2 = np.mean(accuracy), np.std(accuracy) + return acc1, std1, acc2, std2, _xnorm, embeddings_list +- +- +diff --git a/cvlface/research/recognition/code/run_v1/models/base/__init__.py b/cvlface/research/recognition/code/run_v1/models/base/__init__.py +index 151ffa7..0fb6853 100644 +--- a/cvlface/research/recognition/code/run_v1/models/base/__init__.py ++++ b/cvlface/research/recognition/code/run_v1/models/base/__init__.py +@@ -100,11 +100,26 @@ class BaseModel(torch.nn.Module): + if 'net.vit' in list(self.state_dict().keys())[-1] and 'pretrained_models' in pretrained_model_path: + state_dict = {k.replace('net', 'net.vit'): v for k, v in state_dict.items()} + ++ current_state_dict = self.state_dict() ++ if not any(key in current_state_dict for key in state_dict) and any(key.startswith('net.') for key in current_state_dict): ++ state_dict = {f'net.{key}': value for key, value in state_dict.items()} ++ filtered_state_dict = {} ++ skipped_shape = [] ++ for key, value in state_dict.items(): ++ if 
key not in current_state_dict: ++ continue ++ if current_state_dict[key].shape != value.shape: ++ skipped_shape.append((key, tuple(value.shape), tuple(current_state_dict[key].shape))) ++ continue ++ filtered_state_dict[key] = value ++ + st_keys = list(state_dict.keys()) +- self_keys = list(self.state_dict().keys()) +- print('compatible keys in state_dict', len(set(st_keys).intersection(set(self_keys))), '/', len(st_keys)) ++ self_keys = list(current_state_dict.keys()) ++ print('compatible keys in state_dict', len(filtered_state_dict), '/', len(st_keys)) ++ if skipped_shape: ++ print('skipped shape-mismatched keys', skipped_shape[:10], 'total', len(skipped_shape)) + print('Check\n\n') +- result = self.load_state_dict(state_dict, strict=False) ++ result = self.load_state_dict(filtered_state_dict, strict=False) + print(result) + print(f"Loaded pretrained model from {pretrained_model_path}") + +diff --git a/cvlface/research/recognition/code/run_v1/models/swin/__init__.py b/cvlface/research/recognition/code/run_v1/models/swin/__init__.py +index 83e1c0c..ecf6f90 100644 +--- a/cvlface/research/recognition/code/run_v1/models/swin/__init__.py ++++ b/cvlface/research/recognition/code/run_v1/models/swin/__init__.py +@@ -1,6 +1,6 @@ + from ..base import BaseModel + from torchvision import transforms +-from .swin.names import swin_v2_b, swin_v2_s ++from .swin.names import swin_s, swin_v2_b, swin_v2_s + + class SWINModel(BaseModel): + +@@ -35,7 +35,9 @@ class SWINModel(BaseModel): + @classmethod + def from_config(cls, config): + +- if config.name == 'small': ++ if config.name == 'swin_s': ++ net = swin_s() ++ elif config.name == 'small': + net = swin_v2_s() + elif config.name == 'base': + net = swin_v2_b() +@@ -53,6 +55,7 @@ class SWINModel(BaseModel): + + def make_train_transform(self): + transform = transforms.Compose([ ++ transforms.Resize((112, 112)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) +@@ -60,11 +63,30 @@ class 
SWINModel(BaseModel): + + def make_test_transform(self): + transform = transforms.Compose([ ++ transforms.Resize((112, 112)), + transforms.ToTensor(), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), + ]) + return transform + ++ ++class TimmSWINModel(SWINModel): ++ @classmethod ++ def from_config(cls, config): ++ import timm ++ net = timm.create_model( ++ config.timm_name, ++ pretrained=False, ++ img_size=config.input_size[1], ++ num_classes=config.output_dim, ++ ) ++ model = cls(net, config) ++ model.eval() ++ return model ++ + def load_model(model_config): +- model = SWINModel.from_config(model_config) +- return model +\ No newline at end of file ++ if model_config.name == 'timm_swin_s': ++ model = TimmSWINModel.from_config(model_config) ++ else: ++ model = SWINModel.from_config(model_config) ++ return model +diff --git a/cvlface/research/recognition/code/run_v1/optims/lr_scheduler.py b/cvlface/research/recognition/code/run_v1/optims/lr_scheduler.py +index faf13c3..f9eb4a8 100644 +--- a/cvlface/research/recognition/code/run_v1/optims/lr_scheduler.py ++++ b/cvlface/research/recognition/code/run_v1/optims/lr_scheduler.py +@@ -1,4 +1,4 @@ +-from torch.optim.lr_scheduler import _LRScheduler ++from torch.optim.lr_scheduler import _LRScheduler, ReduceLROnPlateau + import torch + from timm.scheduler.cosine_lr import CosineLRScheduler + import numpy as np +@@ -66,6 +66,17 @@ def make_scheduler(cfg, opt): + warmup_steps=cfg.trainers.warmup_step, + lr_milestones=step_milestones, + lr_lambda=cfg.optims.lr_lambda) ++ elif cfg.optims.scheduler == 'plateau': ++ print('plateau scheduler') ++ lr_scheduler = ReduceLROnPlateau( ++ opt, ++ mode='max', ++ factor=cfg.optims.plateau_factor, ++ patience=cfg.optims.plateau_patience, ++ threshold=cfg.optims.plateau_threshold, ++ cooldown=cfg.optims.plateau_cooldown, ++ min_lr=cfg.optims.min_lr, ++ ) + + else: + raise ValueError('') +@@ -74,11 +85,17 @@ def make_scheduler(cfg, opt): + + + def scheduler_step(scheduler, 
global_step): ++ if isinstance(scheduler, ReduceLROnPlateau): ++ return + if isinstance(scheduler, _LRScheduler): + scheduler.step() + else: + scheduler.step(global_step) + ++def scheduler_step_on_metric(scheduler, metric): ++ if isinstance(scheduler, ReduceLROnPlateau): ++ scheduler.step(metric) ++ + def get_last_lr(optimizer): + lrs = [group['lr'] for group in optimizer.param_groups] + return float(np.mean(lrs)) +@@ -193,4 +210,4 @@ if __name__ == '__main__': + + from matplotlib import pyplot as plt + plt.plot(lrs) +- plt.show() +\ No newline at end of file ++ plt.show() +diff --git a/cvlface/research/recognition/code/run_v1/pefts/__init__.py b/cvlface/research/recognition/code/run_v1/pefts/__init__.py +index ccd5954..cb68267 100644 +--- a/cvlface/research/recognition/code/run_v1/pefts/__init__.py ++++ b/cvlface/research/recognition/code/run_v1/pefts/__init__.py +@@ -1,4 +1,3 @@ +-from peft import LoraConfig, LoraModel + import torch + import os + +@@ -25,6 +24,7 @@ def apply_peft_to_model(peft_config, model): + + + if peft_config.name == 'lora': ++ from peft import LoraConfig, LoraModel + target_modules_mapping = { + 'att_qkv': ['qkv'], + 'att_qkv_keypoint_linear': ['qkv', 'keypoint_linear'], +@@ -172,4 +172,4 @@ def print_trainable_parameters(model): + print( + f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}" + ) +- return trainable_names, untrainable_names +\ No newline at end of file ++ return trainable_names, untrainable_names +diff --git a/cvlface/research/recognition/code/run_v1/train.py b/cvlface/research/recognition/code/run_v1/train.py +index e3c77f5..cf388d3 100644 +--- a/cvlface/research/recognition/code/run_v1/train.py ++++ b/cvlface/research/recognition/code/run_v1/train.py +@@ -11,7 +11,7 @@ sys.path.append(os.path.join(root)) + import numpy as np + np.bool = np.bool_ # fix bug for mxnet 1.9.1 + np.object = np.object_ +-np.float = np.float_ ++np.float = np.float64 + + import 
pandas as pd + import torch +@@ -27,7 +27,7 @@ from general_utils import random_utils, os_utils + from optims.optims import make_optimizer + from lightning.fabric.loggers import CSVLogger + from lightning.pytorch.loggers import WandbLogger +-from optims.lr_scheduler import make_scheduler, scheduler_step, get_last_lr ++from optims.lr_scheduler import make_scheduler, scheduler_step, scheduler_step_on_metric, get_last_lr + from pipelines import pipeline_from_config, pipeline_from_name + import omegaconf + import lovely_tensors as lt +@@ -178,7 +178,17 @@ if __name__ == '__main__': + epoch = train_pipeline.start_epoch + for epoch in range(train_pipeline.start_epoch, cfg.optims.num_epoch): + epoch_start_time = time.time() ++ freeze_backbone_epochs = cfg.trainers.get('freeze_backbone_epochs', 0) ++ freeze_backbone = epoch < freeze_backbone_epochs ++ for param in model.parameters(): ++ param.requires_grad = not freeze_backbone ++ if freeze_backbone: ++ print(f'Backbone frozen for epoch {epoch} / {freeze_backbone_epochs}') ++ elif freeze_backbone_epochs > 0 and epoch == freeze_backbone_epochs: ++ print(f'Backbone unfrozen at epoch {epoch}') + train_pipeline.train() ++ if freeze_backbone: ++ model.eval() + set_epoch(dataloader, epoch, cfg) + batch_length = len(dataloader) if cfg.trainers.limit_num_batch <= 0 else cfg.trainers.limit_num_batch + pbar = tqdm(total=batch_length, disable=fabric.local_rank != 0) +@@ -193,8 +203,8 @@ if __name__ == '__main__': + loss = 0 + else: + is_accumulating = batch_idx % cfg.trainers.gradient_acc != 0 +- with fabric.no_backward_sync(model if model.has_trainable_params() else dummy_model, +- enabled=is_accumulating): ++ sync_module = model if 'dummy_model' not in locals() else dummy_model ++ with fabric.no_backward_sync(sync_module, enabled=is_accumulating): + with fabric.autocast(): + loss = train_pipeline(batch) + fabric.backward(loss) +@@ -257,6 +267,7 @@ if __name__ == '__main__': + else: + mean = -1.0 + is_best_tracker.set_is_best(mean) 
++ scheduler_step_on_metric(lr_scheduler, mean) + if fabric.local_rank == 0: + fabric.log_dict({'is_best': float(is_best_tracker.is_best())}) + print(f'Epoch {epoch} | Step {step} | Best {is_best_tracker.is_best()}') +@@ -318,4 +329,4 @@ if __name__ == '__main__': + print('Skip final evaluation') + + # close +- print('done') +\ No newline at end of file ++ print('done') diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/git_status.txt b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/git_status.txt new file mode 100644 index 0000000000000000000000000000000000000000..8ab8a9ff5de26b2f8cf3d17c5b530deceb9aafb8 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/git_status.txt @@ -0,0 +1,23 @@ + M cvlface/research/recognition/code/run_v1/data_augs/__init__.py + M cvlface/research/recognition/code/run_v1/dataset/__init__.py + M cvlface/research/recognition/code/run_v1/dataset/base_dataset.py + M cvlface/research/recognition/code/run_v1/evaluations/__init__.py + M cvlface/research/recognition/code/run_v1/evaluations/verification_evaluator.py + M cvlface/research/recognition/code/run_v1/evaluations/verifications/verification.py + M cvlface/research/recognition/code/run_v1/models/base/__init__.py + M cvlface/research/recognition/code/run_v1/models/swin/__init__.py + M cvlface/research/recognition/code/run_v1/optims/lr_scheduler.py + M cvlface/research/recognition/code/run_v1/pefts/__init__.py + M cvlface/research/recognition/code/run_v1/train.py +?? AGENTS.md +?? AgeDB_aligned_224/ +?? cvlface/pretrained_models/model.safetensors +?? cvlface/research/recognition/code/run_v1/dataset/configs/agedb_80.yaml +?? cvlface/research/recognition/code/run_v1/evaluations/configs/agedb_30_1to1.yaml +?? cvlface/research/recognition/code/run_v1/evaluations/configs/final.yaml +?? 
cvlface/research/recognition/code/run_v1/models/swin/configs/v1_swin_s_pretrained.yaml +?? cvlface/research/recognition/code/run_v1/optims/configs/plateau_adamw_agedb.yaml +?? cvlface/research/recognition/code/run_v1/scripts/examples/run_swin_s_agedb.sh +?? cvlface/research/recognition/code/run_v1/scripts/prepare_agedb_protocol.py +?? cvlface/research/recognition/code/run_v1/trainers/configs/agedb_swin.yaml +?? data/ diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/hashes.txt b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/hashes.txt new file mode 100644 index 0000000000000000000000000000000000000000..2a36f718a0ba22468101b6e04669c46b9f396eaf --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/hashes.txt @@ -0,0 +1,24 @@ +e17adf843764761f138cb8cab081d9a37824db4a948c0082c7af7f0092d19581 /root/Adaface/cvlface/cvlface/pretrained_models/model.safetensors +ce85f68b7e0859d9362b660667f4a59b41f168a083e39fa947064d9308795817 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/config.yaml +81393539c31632c62904a38816d405755d4a5b0d03e3aa90638d22c4104c7111 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/checkpoints/best/model.pt +e06b747fa8a881bf4cd328ca089c7881e7d2a5b3393d9d6da71377ccd90d8d13 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/checkpoints/best/classifier_rank0.pt +3cbcebec6d5a45fd31d4fb017a9a53bc5c3d5b5fd535be35917c53cfbe44a6a2 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/checkpoints/epoch:49_step:10300/model.pt +b8e59f7522af0f1d337574ab84d58b4d1c9fcf396d52587f943053f234da1b6d /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_best.csv +c9120df1d32f6a5d5f6fc576dda9190a50cc7056f88e11e88aae51d536a4beec 
/root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/README.md +867ea1ccb24cb9bf2c8573f19be57e949cda74593ceef352826e241ac2b98754 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/dataset_manifest.json +655a5807bed6f7083d7bb4a1fd8f280dd75222d38159648f134b3c836ead77c3 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/results_summary.csv +cdd0b368e77ad060e199ec44aed0346bf2c2b8eff5bb1f4140126422229dacb1 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/run_command.sh +f0a375fcfa80cadeee3ff5c2fd76e8590054e955c7eba21a823eac6012d20f36 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/prepare_data_command.sh +90fad6aff646c7bb16ca1a33666af4994305fb4f44185c371b3696bc7b8ebd22 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/environment.txt +cbb23194e9a8817b600862a0d9dc4a508ef10b7f606e5f25cbc0a484c7328b17 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/pip_freeze.txt +7bf0186dd79efadeb33fd7feadf2a2bb29e848befa93f07ca9521944831213ed /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/git_status.txt +202efb6c2bae63df91e34c07cefe08cbd000897608da19819cc7fea2a826697b /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/git_diff.patch +c5d121d779d61eba42d438d3b702eeaa868ce9226bd1793f7697b0113eb2c816 /root/Adaface/cvlface/logs/agedb_swin_s_train.log +aeea31390b540bdc23b2e7620438f21c76d6b2a23216bdb2512ceb282b369c79 
/root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/dataset/configs/agedb_80.yaml +3ffce4c70add62708e14e20a96f525f8ac640350de28e5111e621ed95216ca9b /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/evaluations/configs/agedb_30_1to1.yaml +eff08c789cb83219053ba966513b34e2fd97afa9c6723de0d25c36b3133818e7 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/evaluations/configs/final.yaml +0ba2f2d96a801198216e3decb9fcb227082c04bd0b21b28bcc70da757149815b /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_swin_s_pretrained.yaml +183463ec74609e1b5d0dd8c88d8177f680cb473f3c5bb2768703a8af3347b459 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/optims/configs/plateau_adamw_agedb.yaml +215875cfd2a2ec930125cd68c5b005850b7457986819da0e8f71b84538fc4730 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/scripts/examples/run_swin_s_agedb.sh +bfcd1a2c4483230ea4fd962b282801c08dc809724837505eeffc71d4acefbae8 /root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/scripts/prepare_agedb_protocol.py +e14f225ea703f57a99a4d3d97ef438881b29afee517fac04960693ae1358c7b5 
/root/Adaface/cvlface/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/trainers/configs/agedb_swin.yaml diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/pip_freeze.txt b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/pip_freeze.txt new file mode 100644 index 0000000000000000000000000000000000000000..da7dc7dae35e7a61611664067303c2c2bfeb25f6 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/pip_freeze.txt @@ -0,0 +1,132 @@ +aiohappyeyeballs==2.6.1 +aiohttp==3.13.5 +aiosignal==1.4.0 +annotated-doc==0.0.4 +annotated-types==0.7.0 +antlr4-python3-runtime==4.9.3 +anyio==4.13.0 +asttokens==3.0.1 +attrs==26.1.0 +certifi==2026.2.25 +charset-normalizer==3.4.7 +click==8.3.1 +comm==0.2.3 +contourpy==1.3.3 +cuda-bindings==13.2.0 +cuda-pathfinder==1.5.0 +cuda-toolkit==13.0.2 +cycler==0.12.1 +debugpy==1.8.20 +decorator==5.2.1 +executing==2.2.1 +fastcore==1.12.39 +filelock==3.25.2 +fonttools==4.62.1 +frozenlist==1.8.0 +fsspec==2026.2.0 +gitdb==4.0.12 +GitPython==3.1.46 +h11==0.16.0 +hf-xet==1.4.2 +httpcore==1.0.9 +httpx==0.28.1 +huggingface_hub==1.8.0 +hydra-core==1.3.2 +idna==3.11 +ipykernel==7.2.0 +ipython==9.11.0 +ipython_pygments_lexers==1.1.1 +ipywidgets==8.1.8 +jedi==0.19.2 +Jinja2==3.1.6 +joblib==1.5.3 +jupyter_client==8.8.0 +jupyter_core==5.9.1 +jupyterlab_widgets==3.0.16 +kiwisolver==1.5.0 +lightning==2.6.1 +lightning-utilities==0.15.3 +lovely-numpy==0.2.22 +lovely-tensors==0.1.22 +markdown-it-py==4.0.0 +MarkupSafe==3.0.3 +matplotlib==3.10.8 +matplotlib-inline==0.2.1 +mdurl==0.1.2 +mpmath==1.3.0 +multidict==6.7.1 +nest-asyncio==1.6.0 +networkx==3.6.1 +numpy==2.4.3 +nvidia-cublas==13.1.0.3 +nvidia-cuda-cupti==13.0.85 +nvidia-cuda-nvrtc==13.0.88 +nvidia-cuda-runtime==13.0.96 +nvidia-cudnn-cu13==9.19.0.56 +nvidia-cufft==12.0.0.61 
+nvidia-cufile==1.15.1.6 +nvidia-curand==10.4.0.35 +nvidia-cusolver==12.0.4.66 +nvidia-cusparse==12.6.3.3 +nvidia-cusparselt-cu13==0.8.0 +nvidia-nccl-cu13==2.28.9 +nvidia-nvjitlink==13.0.88 +nvidia-nvshmem-cu13==3.4.5 +nvidia-nvtx==13.0.85 +omegaconf==2.3.0 +opencv-python-headless==4.13.0.92 +packaging @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_packaging_1769093650/work +pandas==3.0.2 +parso==0.8.6 +pexpect==4.9.0 +pillow==12.1.1 +platformdirs==4.9.4 +prompt_toolkit==3.0.52 +propcache==0.4.1 +protobuf==7.34.1 +psutil==7.2.2 +ptyprocess==0.7.0 +pure_eval==0.2.3 +pydantic==2.13.0 +pydantic_core==2.46.0 +Pygments==2.19.2 +pyparsing==3.3.2 +pyrootutils==1.0.4 +python-dateutil==2.9.0.post0 +python-dotenv==1.2.2 +pytorch-lightning==2.6.1 +PyYAML==6.0.3 +pyzmq==27.1.0 +requests==2.33.1 +rich==14.3.3 +safetensors==0.7.0 +scikit-learn==1.8.0 +scipy==1.17.1 +sentry-sdk==2.58.0 +setuptools==81.0.0 +shellingham==1.5.4 +six==1.17.0 +smmap==5.0.3 +stack-data==0.6.3 +sympy==1.14.0 +tabulate==0.10.0 +threadpoolctl==3.6.0 +timm==1.0.26 +torch==2.11.0+cu130 +torchaudio==2.11.0+cu130 +torchcodec==0.11.0+cu130 +torchmetrics==1.9.0 +torchvision==0.26.0+cu130 +tornado==6.5.5 +tqdm==4.67.3 +traitlets==5.14.3 +triton==3.6.0 +typer==0.24.1 +typing-inspection==0.4.2 +typing_extensions==4.15.0 +urllib3==2.6.3 +wandb==0.26.0 +wcwidth==0.6.0 +wheel==0.46.3 +widgetsnbextension==4.0.15 +yarl==1.23.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/prepare_data_command.sh b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/prepare_data_command.sh new file mode 100644 index 0000000000000000000000000000000000000000..2a6113ab5daddf3e5a10feb7c7bf323d1ed729ff --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/prepare_data_command.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd /root/Adaface/cvlface + +python 
cvlface/research/recognition/code/run_v1/scripts/prepare_agedb_protocol.py \ + --source /root/Adaface/cvlface/AgeDB_aligned_224 \ + --output /root/Adaface/cvlface/data/agedb_protocol \ + --train-ratio 0.8 \ + --seed 2048 \ + --max-pairs 3000 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/results_summary.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/results_summary.csv new file mode 100644 index 0000000000000000000000000000000000000000..82ee7d46b6c53022d4acf50d98f7a270490618c6 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/results_summary.csv @@ -0,0 +1,12 @@ +file,epoch,step,agedb_30_acc +eval_summary_0_206.csv,0,206,29.394812680115 +eval_summary_5_1236.csv,5,1236,29.394812680115 +eval_summary_10_2266.csv,10,2266,57.348703170029 +eval_summary_15_3296.csv,15,3296,69.106628242075 +eval_summary_20_4326.csv,20,4326,76.340057636888 +eval_summary_25_5356.csv,25,5356,75.216138328530 +eval_summary_30_6386.csv,30,6386,75.677233429395 +eval_summary_35_7416.csv,35,7416,76.109510086455 +eval_summary_40_8446.csv,40,8446,75.619596541787 +eval_summary_45_9476.csv,45,9476,76.051873198847 +eval_summary_49_10300.csv,49,10300,76.109510086455 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/run_command.sh b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/run_command.sh new file mode 100644 index 0000000000000000000000000000000000000000..c6b3c7624494d5b2ba293a42b1057566b20ecbdd --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/run_command.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +set -euo pipefail + +cd /root/Adaface/cvlface + +export AGEDB_PROTOCOL_ROOT=/root/Adaface/cvlface/data/agedb_protocol +export SWIN_S_PRETRAINED=/root/Adaface/cvlface/cvlface/pretrained_models/model.safetensors + +cd 
cvlface/research/recognition/code/run_v1 + +python train.py \ + trainers=configs/agedb_swin.yaml \ + dataset=configs/agedb_80.yaml \ + data_augs=configs/basic_v1.yaml \ + models=swin/configs/v1_swin_s_pretrained.yaml \ + pipelines=configs/train_model_cls.yaml \ + evaluations=configs/agedb_30_1to1.yaml \ + classifiers=configs/fc.yaml \ + optims=configs/plateau_adamw_agedb.yaml \ + losses=configs/adaface.yaml \ + aligners=configs/none.yaml \ + pefts=configs/none.yaml \ + 2>&1 | tee /root/Adaface/cvlface/logs/agedb_swin_s_train.log diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/dataset/configs/agedb_80.yaml b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/dataset/configs/agedb_80.yaml new file mode 100644 index 0000000000000000000000000000000000000000..430eed6c4072e606f189d767a63680931894b0dd --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/dataset/configs/agedb_80.yaml @@ -0,0 +1,8 @@ +data_root: ${oc.env:AGEDB_PROTOCOL_ROOT} +rec: 'agedb_train_80' +color_space: 'RGB' +num_classes: 567 +num_image: 13200 + +repeated_sampling_cfg: null +semi_sampling_cfg: null diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/evaluations/configs/agedb_30_1to1.yaml b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/evaluations/configs/agedb_30_1to1.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e79f49472808567ccf4a771e5f74db62fb304825 --- /dev/null +++ 
b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/evaluations/configs/agedb_30_1to1.yaml @@ -0,0 +1,11 @@ +data_root: ${oc.env:AGEDB_PROTOCOL_ROOT} +eval_every_n_epochs: 5 +per_epoch_evaluations: { + "agedb_30": { + 'path': 'facerec_val/agedb_30_1to1', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, +} diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/evaluations/configs/final.yaml b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/evaluations/configs/final.yaml new file mode 100644 index 0000000000000000000000000000000000000000..88b415d9a600a4f4b899125d48eea0e360858e4e --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/evaluations/configs/final.yaml @@ -0,0 +1,11 @@ +data_root: ${oc.env:AGEDB_PROTOCOL_ROOT} +eval_every_n_epochs: 1 +per_epoch_evaluations: { + "agedb_30": { + 'path': 'facerec_val/agedb_30_1to1', + 'evaluation_type': 'verification', + 'color_space': 'RGB', + 'batch_size': 32, + 'num_workers': 4 + }, +} diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_swin_s_pretrained.yaml b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_swin_s_pretrained.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c653b9ffdaa80c9b09b9781ba4acdecf436ae737 --- /dev/null +++ 
b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/models/swin/configs/v1_swin_s_pretrained.yaml @@ -0,0 +1,8 @@ +input_size: [3, 112, 112] +color_space: 'RGB' +name: 'timm_swin_s' +timm_name: 'swin_small_patch4_window7_224.ms_in22k_ft_in1k' +output_dim: 512 +start_from: ${oc.env:SWIN_S_PRETRAINED,/root/Adaface/cvlface/cvlface/pretrained_models/model.safetensors} +freeze: False +mask_ratio: 0.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/optims/configs/plateau_adamw_agedb.yaml b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/optims/configs/plateau_adamw_agedb.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f153c08739ce9f89630b7304e81a42d2a2b6ac3b --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/optims/configs/plateau_adamw_agedb.yaml @@ -0,0 +1,15 @@ +num_epoch: 50 +optimizer: 'adamw' +lr: 0.001 +momentum: 0.9 +weight_decay: 0.05 +scheduler: 'plateau' +filter_bias_and_bn: true +warmup_epoch: 0 +max_grad_norm: 5.0 +lr_milestones: [] +plateau_factor: 0.5 +plateau_patience: 2 +plateau_threshold: 0.01 +plateau_cooldown: 0 +min_lr: 0.000001 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/scripts/examples/run_swin_s_agedb.sh b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/scripts/examples/run_swin_s_agedb.sh new file mode 100644 index 0000000000000000000000000000000000000000..20a133301924065d8796e822e63001394e3de2d8 --- /dev/null +++ 
b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/scripts/examples/run_swin_s_agedb.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash +set -euo pipefail + +REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/../../../../../../../" && pwd)" +RUN_DIR="$REPO_ROOT/cvlface/research/recognition/code/run_v1" + +export AGEDB_PROTOCOL_ROOT="$REPO_ROOT/data/agedb_protocol" +export SWIN_S_PRETRAINED="${SWIN_S_PRETRAINED:-$REPO_ROOT/cvlface/pretrained_models/model.safetensors}" + +cd "$RUN_DIR" + +python scripts/prepare_agedb_protocol.py \ + --source "$REPO_ROOT/AgeDB_aligned_224" \ + --output "$AGEDB_PROTOCOL_ROOT" \ + --train-ratio 0.8 \ + --seed 2048 \ + --max-pairs 3000 + +python train.py \ + trainers=configs/agedb_swin.yaml \ + dataset=configs/agedb_80.yaml \ + data_augs=configs/basic_v1.yaml \ + models=swin/configs/v1_swin_s_pretrained.yaml \ + pipelines=configs/train_model_cls.yaml \ + evaluations=configs/agedb_30_1to1.yaml \ + classifiers=configs/fc.yaml \ + optims=configs/plateau_adamw_agedb.yaml \ + losses=configs/adaface.yaml \ + aligners=configs/none.yaml \ + pefts=configs/none.yaml diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/scripts/prepare_agedb_protocol.py b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/scripts/prepare_agedb_protocol.py new file mode 100644 index 0000000000000000000000000000000000000000..c162c8aa13234cdd669d39b325da64c3344358f1 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/scripts/prepare_agedb_protocol.py @@ -0,0 +1,139 @@ +import argparse +import csv +import os +import random +import shutil +from collections import defaultdict +from pathlib import Path + + 
+def parse_agedb_name(path): + left, age, gender = path.stem.rsplit('_', 2) + _, identity = left.split('_', 1) + return identity, int(age), gender + + +def link_or_copy(src, dst, copy_files): + dst.parent.mkdir(parents=True, exist_ok=True) + if dst.exists() or dst.is_symlink(): + return + if copy_files: + shutil.copy2(src, dst) + else: + os.symlink(os.path.relpath(src, dst.parent), dst) + + +def build_positive_pairs(test_rows, max_pairs): + by_identity = defaultdict(list) + for row in test_rows: + by_identity[row['identity']].append(row) + + candidates = [] + for rows in by_identity.values(): + rows = sorted(rows, key=lambda x: (x['age'], x['path'].name)) + for i, left in enumerate(rows): + for right in rows[i + 1:]: + age_gap = abs(left['age'] - right['age']) + if age_gap >= 30: + candidates.append((age_gap, left, right)) + candidates.sort(key=lambda x: (-x[0], x[1]['path'].name, x[2]['path'].name)) + return [(left, right) for _, left, right in candidates[:max_pairs]] + + +def build_negative_pairs(test_rows, count, seed): + rng = random.Random(seed) + rows = sorted(test_rows, key=lambda x: x['path'].name) + by_identity = defaultdict(list) + for row in rows: + by_identity[row['identity']].append(row) + + identities = sorted(by_identity) + pairs = [] + used = set() + attempts = 0 + while len(pairs) < count and attempts < count * 100: + attempts += 1 + left_id, right_id = rng.sample(identities, 2) + left = rng.choice(by_identity[left_id]) + right = rng.choice(by_identity[right_id]) + key = tuple(sorted([left['path'].name, right['path'].name])) + if key in used: + continue + used.add(key) + pairs.append((left, right)) + if len(pairs) < count: + raise RuntimeError(f'Only built {len(pairs)} negative pairs, expected {count}') + return pairs + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--source', default='AgeDB_aligned_224') + parser.add_argument('--output', default='data/agedb_protocol') + parser.add_argument('--train-ratio', type=float, 
default=0.8) + parser.add_argument('--seed', type=int, default=2048) + parser.add_argument('--max-pairs', type=int, default=3000) + parser.add_argument('--copy-files', action='store_true') + args = parser.parse_args() + + source = Path(args.source).resolve() + output = Path(args.output).resolve() + train_dir = output / 'agedb_train_80' + val_dir = output / 'facerec_val' / 'agedb_30_1to1' + val_dir.mkdir(parents=True, exist_ok=True) + + by_identity = defaultdict(list) + for path in sorted(source.glob('*.jpg')): + identity, age, gender = parse_agedb_name(path) + by_identity[identity].append({'path': path, 'identity': identity, 'age': age, 'gender': gender}) + + rng = random.Random(args.seed) + train_rows = [] + test_rows = [] + for identity, rows in sorted(by_identity.items()): + rows = sorted(rows, key=lambda x: x['path'].name) + indices = list(range(len(rows))) + rng.shuffle(indices) + if len(rows) > 1: + n_test = max(1, int(round(len(rows) * (1 - args.train_ratio)))) + else: + n_test = 0 + test_indices = set(indices[:n_test]) + for idx, row in enumerate(rows): + if idx in test_indices: + test_rows.append(row) + else: + train_rows.append(row) + + for row in train_rows: + dst = train_dir / row['identity'] / row['path'].name + link_or_copy(row['path'], dst, args.copy_files) + + positive_pairs = build_positive_pairs(test_rows, args.max_pairs) + if not positive_pairs: + raise RuntimeError('No AgeDB-30 positive pairs found in the test split') + negative_pairs = build_negative_pairs(test_rows, len(positive_pairs), args.seed) + + pair_rows = [] + pair_index = 0 + for is_same, pairs in [(True, positive_pairs), (False, negative_pairs)]: + for left, right in pairs: + pair_rows.append({'path': str(left['path']), 'index': pair_index * 2, 'is_same': is_same}) + pair_rows.append({'path': str(right['path']), 'index': pair_index * 2 + 1, 'is_same': is_same}) + pair_index += 1 + + with open(val_dir / 'pairs.csv', 'w', newline='') as f: + writer = csv.DictWriter(f, 
fieldnames=['path', 'index', 'is_same']) + writer.writeheader() + writer.writerows(pair_rows) + + print(f'identities: {len(by_identity)}') + print(f'train images: {len(train_rows)}') + print(f'test images: {len(test_rows)}') + print(f'verification pairs: {len(pair_rows) // 2} ({len(positive_pairs)} positive, {len(negative_pairs)} negative)') + print(f'train_dir: {train_dir}') + print(f'val_dir: {val_dir}') + + +if __name__ == '__main__': + main() diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/trainers/configs/agedb_swin.yaml b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/trainers/configs/agedb_swin.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b7ce0b1d7b8090ef45b9c00ac45eaa2f8d2799ee --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/reproducibility/source_files/cvlface/research/recognition/code/run_v1/trainers/configs/agedb_swin.yaml @@ -0,0 +1,31 @@ +task: '' +prefix: 'agedb_swin_s' +debug: False +mock_lr_run: False +limit_num_batch: -1 +skip_final_eval: False + +batch_size: 64 +num_gpu: 1 +precision: '32-true' +float32_matmul_precision: high + +gradient_acc: 1 +seed: 2048 +num_workers: 4 +freeze_backbone_epochs: 10 + +wandb_key: '' +suffix_run_name: None +using_wandb: False +wandb_entity: '' +wandb_log_all: False + +resume: '' + +local_rank: -1 +world_size: -1 +total_batch_size: -1 +warmup_step: -1 +total_step: -1 +output_dir: '' diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_0_206.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_0_206.csv new file mode 100644 index 0000000000000000000000000000000000000000..d0ee1537cc375d06288d2025eddffc2f8a7c8fcc --- /dev/null +++ 
b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_0_206.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,29.394812680115272 +agedb_30/std,4.318922402151985 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_10_2266.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_10_2266.csv new file mode 100644 index 0000000000000000000000000000000000000000..bfeb8acffe5a8a499bbdd388d0bb0cb594151e29 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_10_2266.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,57.34870317002882 +agedb_30/std,6.954744958615956 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_15_3296.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_15_3296.csv new file mode 100644 index 0000000000000000000000000000000000000000..13a049a9d6b4ea0c61a6ae1cb631319c1f0852bb --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_15_3296.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,69.10662824207493 +agedb_30/std,5.857734177049752 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_20_4326.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_20_4326.csv new file mode 100644 index 0000000000000000000000000000000000000000..487ce7885cb150db26a0fb922bcc206cf3b73728 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_20_4326.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,76.34005763688761 +agedb_30/std,8.719615357523367 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_25_5356.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_25_5356.csv new file mode 100644 index 
0000000000000000000000000000000000000000..f9c2217bd59a298ebb533e8229c279530d059d2e --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_25_5356.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,75.21613832853026 +agedb_30/std,8.110228979521779 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_30_6386.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_30_6386.csv new file mode 100644 index 0000000000000000000000000000000000000000..84564339a8823327a2450146d237335a99cfbc13 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_30_6386.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,75.67723342939483 +agedb_30/std,7.622682672965119 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_35_7416.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_35_7416.csv new file mode 100644 index 0000000000000000000000000000000000000000..7acbc6f308fd61824d1f62ed70d0d09a8161ff46 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_35_7416.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,76.10951008645533 +agedb_30/std,8.067568800688592 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_40_8446.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_40_8446.csv new file mode 100644 index 0000000000000000000000000000000000000000..a853a3b8429f07f749f0969743d1b8aa77f7b772 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_40_8446.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,75.61959654178676 +agedb_30/std,7.391467366699655 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_45_9476.csv 
b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_45_9476.csv new file mode 100644 index 0000000000000000000000000000000000000000..659e654b2d57d56b393d8add8ef1029b444b3dc3 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_45_9476.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,76.05187319884726 +agedb_30/std,8.615172435302341 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_49_10300.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_49_10300.csv new file mode 100644 index 0000000000000000000000000000000000000000..361447d0aabb988e89d5bb83b044698b58339ee8 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_49_10300.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,76.10951008645533 +agedb_30/std,7.208586799222195 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_5_1236.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_5_1236.csv new file mode 100644 index 0000000000000000000000000000000000000000..d0ee1537cc375d06288d2025eddffc2f8a7c8fcc --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_5_1236.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,29.394812680115272 +agedb_30/std,4.318922402151985 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_best.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_best.csv new file mode 100644 index 0000000000000000000000000000000000000000..487ce7885cb150db26a0fb922bcc206cf3b73728 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_best.csv @@ -0,0 +1,3 @@ +,val +agedb_30/acc,76.34005763688761 +agedb_30/std,8.719615357523367 diff --git 
a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_0_206.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_0_206.csv new file mode 100644 index 0000000000000000000000000000000000000000..8f9979ec777c9a3b7f93f11420cd61b80e277404 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_0_206.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,29.394812680115272 +epoch,0.0 +step,206.0 +n_images_seen,13184.0 +trainer/global_step,206.0 +trainer/epoch,0.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_10_2266.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_10_2266.csv new file mode 100644 index 0000000000000000000000000000000000000000..684d36e1e2e5fe32b41bfa45599a0e22e8972c9f --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_10_2266.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,57.34870317002882 +epoch,10.0 +step,2266.0 +n_images_seen,145024.0 +trainer/global_step,2266.0 +trainer/epoch,10.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_15_3296.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_15_3296.csv new file mode 100644 index 0000000000000000000000000000000000000000..fb88791c76c8c8e05bceacd782c20ac4c3370b65 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_15_3296.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,69.10662824207493 +epoch,15.0 +step,3296.0 +n_images_seen,210944.0 +trainer/global_step,3296.0 +trainer/epoch,15.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_20_4326.csv 
b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_20_4326.csv new file mode 100644 index 0000000000000000000000000000000000000000..66ef292e99ec2c37c72b07ab0b220359e512c51d --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_20_4326.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,76.34005763688761 +epoch,20.0 +step,4326.0 +n_images_seen,276864.0 +trainer/global_step,4326.0 +trainer/epoch,20.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_25_5356.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_25_5356.csv new file mode 100644 index 0000000000000000000000000000000000000000..e2a7ba5c351f63acbaa62b0d29e6dcca8c8a2c1b --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_25_5356.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,75.21613832853026 +epoch,25.0 +step,5356.0 +n_images_seen,342784.0 +trainer/global_step,5356.0 +trainer/epoch,25.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_30_6386.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_30_6386.csv new file mode 100644 index 0000000000000000000000000000000000000000..b53261f6fef284ce4b06fd68a34e37b2f155e981 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_30_6386.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,75.67723342939483 +epoch,30.0 +step,6386.0 +n_images_seen,408704.0 +trainer/global_step,6386.0 +trainer/epoch,30.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_35_7416.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_35_7416.csv new file mode 100644 index 
0000000000000000000000000000000000000000..94af338c5c878da95fbf3be313c1de4332040667 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_35_7416.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,76.10951008645533 +epoch,35.0 +step,7416.0 +n_images_seen,474624.0 +trainer/global_step,7416.0 +trainer/epoch,35.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_40_8446.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_40_8446.csv new file mode 100644 index 0000000000000000000000000000000000000000..c6f98f04dd8c7dd1751fd2ea2a5a0a90eefed68a --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_40_8446.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,75.61959654178676 +epoch,40.0 +step,8446.0 +n_images_seen,540544.0 +trainer/global_step,8446.0 +trainer/epoch,40.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_45_9476.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_45_9476.csv new file mode 100644 index 0000000000000000000000000000000000000000..24f6788bff9afa9ee84fafe414e3a9da730c95ba --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_45_9476.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,76.05187319884726 +epoch,45.0 +step,9476.0 +n_images_seen,606464.0 +trainer/global_step,9476.0 +trainer/epoch,45.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_49_10300.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_49_10300.csv new file mode 100644 index 0000000000000000000000000000000000000000..32f9bd9a744523d0f84d26442b3806e4306aa48e --- /dev/null +++ 
b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_49_10300.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,76.10951008645533 +epoch,49.0 +step,10300.0 +n_images_seen,659200.0 +trainer/global_step,10300.0 +trainer/epoch,49.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_5_1236.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_5_1236.csv new file mode 100644 index 0000000000000000000000000000000000000000..c34a85be0ce99078eb69d36f3b158f9497ad9005 --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_5_1236.csv @@ -0,0 +1,7 @@ +,val +summary/agedb_30_acc,29.394812680115272 +epoch,5.0 +step,1236.0 +n_images_seen,79104.0 +trainer/global_step,1236.0 +trainer/epoch,5.0 diff --git a/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_best.csv b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_best.csv new file mode 100644 index 0000000000000000000000000000000000000000..8ce0f0356c85eeb08a72eaaf44b19945f410621e --- /dev/null +++ b/cvlface/research/recognition/experiments/run_v1/agedb_swin_s_04-15_7/result/eval_summary_best.csv @@ -0,0 +1,7 @@ +,val +final/agedb_30_acc,76.34 +epoch,50.0 +step,10301.0 +n_images_seen,659201.0 +trainer/global_step,10301.0 +trainer/epoch,50.0