| | |
| | |
| | |
| |
|
| | import torch |
| | from torch.nn.parallel import DataParallel, DistributedDataParallel |
| |
|
| | import numpy as np |
| | import time |
| | import torch |
| | from torch.nn.parallel import DataParallel, DistributedDataParallel |
| | import copy |
| | import random |
| | import torch.nn.functional as F |
| | from detectron2.structures.instances import Instances |
| | from detectron2.structures import BitMasks |
| |
|
| | from detectron2.engine import SimpleTrainer |
| |
|
| | __all__ = ["CustomSimpleTrainer", "CustomAMPTrainer"] |
| |
|
| | class CustomSimpleTrainer(SimpleTrainer): |
| | """ |
| | A simple trainer for the most common type of task: |
| | single-cost single-optimizer single-data-source iterative optimization, |
| | optionally using data-parallelism. |
| | It assumes that every step, you: |
| | |
| | 1. Compute the loss with a data from the data_loader. |
| | 2. Compute the gradients with the above loss. |
| | 3. Update the model with the optimizer. |
| | |
| | All other tasks during training (checkpointing, logging, evaluation, LR schedule) |
| | are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`. |
| | |
| | If you want to do anything fancier than this, |
| | either subclass TrainerBase and implement your own `run_step`, |
| | or write your own training loop. |
| | """ |
| |
|
| | def __init__(self, model, data_loader, optimizer, cfg=None, use_copy_paste=False, |
| | copy_paste_rate=-1, copy_paste_random_num=None, copy_paste_min_ratio=-1, |
| | copy_paste_max_ratio=-1, visualize_copy_paste=False): |
| | """ |
| | Args: |
| | model: a torch Module. Takes a data from data_loader and returns a |
| | dict of losses. |
| | data_loader: an iterable. Contains data to be used to call model. |
| | optimizer: a torch optimizer. |
| | """ |
| | super().__init__(model, data_loader, optimizer) |
| |
|
| | """ |
| | We set the model to training mode in the trainer. |
| | However it's valid to train a model that's in eval mode. |
| | If you want your model (or a submodule of it) to behave |
| | like evaluation during training, you can overwrite its train() method. |
| | """ |
| | self.cfg = cfg |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | self.use_copy_paste = use_copy_paste if self.cfg is None else self.cfg.DATALOADER.COPY_PASTE |
| | self.cfg_COPY_PASTE_RATE = copy_paste_rate if self.cfg is None else self.cfg.DATALOADER.COPY_PASTE_RATE |
| | self.cfg_COPY_PASTE_RANDOM_NUM = copy_paste_random_num if self.cfg is None else self.cfg.DATALOADER.COPY_PASTE_RANDOM_NUM |
| | self.cfg_COPY_PASTE_MIN_RATIO = copy_paste_min_ratio if self.cfg is None else self.cfg.DATALOADER.COPY_PASTE_MIN_RATIO |
| | self.cfg_COPY_PASTE_MAX_RATIO = copy_paste_max_ratio if self.cfg is None else self.cfg.DATALOADER.COPY_PASTE_MAX_RATIO |
| | self.cfg_VISUALIZE_COPY_PASTE = visualize_copy_paste if self.cfg is None else self.cfg.DATALOADER.VISUALIZE_COPY_PASTE |
| |
|
| | def IoU(self, mask1, mask2): |
| | mask1, mask2 = (mask1>0.5).to(torch.bool), (mask2>0.5).to(torch.bool) |
| | intersection = torch.sum(mask1 * (mask1 == mask2), dim=[-1, -2]).squeeze() |
| | union = torch.sum(mask1 + mask2, dim=[-1, -2]).squeeze() |
| | return (intersection.to(torch.float) / union).mean().view(1, -1) |
| |
|
| | def IoY(self, mask1, mask2): |
| | |
| | mask1, mask2 = mask1.squeeze(), mask2.squeeze() |
| | mask1, mask2 = (mask1>0.5).to(torch.bool), (mask2>0.5).to(torch.bool) |
| | intersection = torch.sum(mask1 * (mask1 == mask2), dim=[-1, -2]).squeeze() |
| | union = torch.sum(mask2, dim=[-1, -2]).squeeze() |
| | return (intersection.to(torch.float) / union).mean().view(1, -1) |
| |
|
| | def copy_and_paste(self, labeled_data, unlabeled_data): |
| | new_unlabeled_data = [] |
| | def mask_iou_matrix(x, y, mode='iou'): |
| | x = x.reshape(x.shape[0], -1).float() |
| | y = y.reshape(y.shape[0], -1).float() |
| | inter_matrix = x @ y.transpose(1, 0) |
| | sum_x = x.sum(1)[:, None].expand(x.shape[0], y.shape[0]) |
| | sum_y = y.sum(1)[None, :].expand(x.shape[0], y.shape[0]) |
| | if mode == 'ioy': |
| | iou_matrix = inter_matrix / (sum_y) |
| | else: |
| | iou_matrix = inter_matrix / (sum_x + sum_y - inter_matrix) |
| | return iou_matrix |
| |
|
| | def visualize_data(data, save_path = './sample.jpg'): |
| | from data import detection_utils as utils |
| | from detectron2.data import DatasetCatalog, MetadataCatalog |
| | from detectron2.utils.visualizer import Visualizer |
| | data["instances"] = data["instances"].to(device='cpu') |
| | img = data["image"].permute(1, 2, 0).cpu().detach().numpy() |
| | img = utils.convert_image_to_rgb(img, 'RGB') |
| | metadata = MetadataCatalog.get('imagenet_train_tau0.15') |
| | visualizer = Visualizer(img, metadata=metadata, scale=1.0) |
| | target_fields = data["instances"].get_fields() |
| | labels = [metadata.thing_classes[i] for i in target_fields["gt_classes"]] |
| | vis = visualizer.overlay_instances( |
| | labels=labels, |
| | boxes=target_fields.get("gt_boxes"), |
| | masks=target_fields.get("gt_masks"), |
| | keypoints=target_fields.get("gt_keypoints", None), |
| | ) |
| | print("Saving to {} ...".format(save_path)) |
| | vis.save(save_path) |
| |
|
| | for cur_labeled_data, cur_unlabeled_data in zip(labeled_data, unlabeled_data): |
| | cur_labeled_instances = cur_labeled_data["instances"] |
| | cur_labeled_image = cur_labeled_data["image"] |
| | cur_unlabeled_instances = cur_unlabeled_data["instances"] |
| | cur_unlabeled_image = cur_unlabeled_data["image"] |
| |
|
| | num_labeled_instances = len(cur_labeled_instances) |
| | copy_paste_rate = random.random() |
| | |
| | if self.cfg_COPY_PASTE_RATE >= copy_paste_rate and num_labeled_instances > 0: |
| | if self.cfg_COPY_PASTE_RANDOM_NUM: |
| | num_copy = 1 if num_labeled_instances == 1 else np.random.randint(1, max(1, num_labeled_instances)) |
| | else: |
| | num_copy = num_labeled_instances |
| | else: |
| | num_copy = 0 |
| | if num_labeled_instances == 0 or num_copy == 0: |
| | new_unlabeled_data.append(cur_unlabeled_data) |
| | else: |
| | |
| | choice = np.random.choice(num_labeled_instances, num_copy, replace=False) |
| | copied_instances = cur_labeled_instances[choice].to(device=cur_unlabeled_instances.gt_boxes.device) |
| | copied_masks = copied_instances.gt_masks |
| | copied_boxes = copied_instances.gt_boxes |
| | _, labeled_h, labeled_w = cur_labeled_image.shape |
| | _, unlabeled_h, unlabeled_w = cur_unlabeled_image.shape |
| |
|
| | |
| | if isinstance(copied_masks, torch.Tensor): |
| | masks_new = copied_masks[None, ...].float() |
| | else: |
| | masks_new = copied_masks.tensor[None, ...].float() |
| | |
| | resize_ratio = random.uniform(self.cfg_COPY_PASTE_MIN_RATIO, self.cfg_COPY_PASTE_MAX_RATIO) |
| | w_new = int(resize_ratio * unlabeled_w) |
| | h_new = int(resize_ratio * unlabeled_h) |
| |
|
| | w_shift = random.randint(0, unlabeled_w - w_new) |
| | h_shift = random.randint(0, unlabeled_h - h_new) |
| |
|
| | cur_labeled_image_new = F.interpolate(cur_labeled_image[None, ...].float(), size=(h_new, w_new), mode="bilinear", align_corners=False).byte().squeeze(0) |
| | if isinstance(copied_masks, torch.Tensor): |
| | masks_new = F.interpolate(copied_masks[None, ...].float(), size=(h_new, w_new), mode="bilinear", align_corners=False).bool().squeeze(0) |
| | else: |
| | masks_new = F.interpolate(copied_masks.tensor[None, ...].float(), size=(h_new, w_new), mode="bilinear", align_corners=False).bool().squeeze(0) |
| | copied_boxes.scale(1. * unlabeled_w / labeled_w * resize_ratio, 1. * unlabeled_h / labeled_h * resize_ratio) |
| |
|
| | if isinstance(cur_unlabeled_instances.gt_masks, torch.Tensor): |
| | _, mask_w, mask_h = cur_unlabeled_instances.gt_masks.size() |
| | else: |
| | _, mask_w, mask_h = cur_unlabeled_instances.gt_masks.tensor.size() |
| | masks_new_all = torch.zeros(num_copy, mask_w, mask_h) |
| | image_new_all = torch.zeros_like(cur_unlabeled_image) |
| |
|
| | image_new_all[:, h_shift:h_shift+h_new, w_shift:w_shift+w_new] += cur_labeled_image_new |
| | masks_new_all[:, h_shift:h_shift+h_new, w_shift:w_shift+w_new] += masks_new |
| |
|
| | cur_labeled_image = image_new_all.byte() |
| | if isinstance(copied_masks, torch.Tensor): |
| | copied_masks = masks_new_all.bool() |
| | else: |
| | copied_masks.tensor = masks_new_all.bool() |
| | copied_boxes.tensor[:, 0] += h_shift |
| | copied_boxes.tensor[:, 2] += h_shift |
| | copied_boxes.tensor[:, 1] += w_shift |
| | copied_boxes.tensor[:, 3] += w_shift |
| |
|
| | copied_instances.gt_masks = copied_masks |
| | copied_instances.gt_boxes = copied_boxes |
| | copied_instances._image_size = (unlabeled_h, unlabeled_w) |
| | if len(cur_unlabeled_instances) == 0: |
| | if isinstance(copied_instances.gt_masks, torch.Tensor): |
| | alpha = copied_instances.gt_masks.sum(0) > 0 |
| | else: |
| | alpha = copied_instances.gt_masks.tensor.sum(0) > 0 |
| | |
| | alpha = alpha.cpu() |
| | composited_image = (alpha * cur_labeled_image) + (~alpha * cur_unlabeled_image) |
| | cur_unlabeled_data["image"] = composited_image |
| | cur_unlabeled_data["instances"] = copied_instances |
| | else: |
| | |
| | if isinstance(copied_masks, torch.Tensor): |
| | iou_matrix = mask_iou_matrix(copied_masks, cur_unlabeled_instances.gt_masks, mode='ioy') |
| | else: |
| | iou_matrix = mask_iou_matrix(copied_masks.tensor, cur_unlabeled_instances.gt_masks.tensor, mode='ioy') |
| |
|
| | keep = iou_matrix.max(1)[0] < 0.5 |
| | if keep.sum() == 0: |
| | new_unlabeled_data.append(cur_unlabeled_data) |
| | continue |
| | copied_instances = copied_instances[keep] |
| | |
| | if isinstance(copied_instances.gt_masks, torch.Tensor): |
| | alpha = copied_instances.gt_masks.sum(0) > 0 |
| | cur_unlabeled_instances.gt_masks = ~alpha * cur_unlabeled_instances.gt_masks |
| | areas_unlabeled = cur_unlabeled_instances.gt_masks.sum((1,2)) |
| | else: |
| | alpha = copied_instances.gt_masks.tensor.sum(0) > 0 |
| | cur_unlabeled_instances.gt_masks.tensor = ~alpha * cur_unlabeled_instances.gt_masks.tensor |
| | areas_unlabeled = cur_unlabeled_instances.gt_masks.tensor.sum((1,2)) |
| | |
| | alpha = alpha.cpu() |
| | composited_image = (alpha * cur_labeled_image) + (~alpha * cur_unlabeled_image) |
| | |
| | merged_instances = Instances.cat([cur_unlabeled_instances[areas_unlabeled > 0], copied_instances]) |
| | |
| | if isinstance(merged_instances.gt_masks, torch.Tensor): |
| | merged_instances.gt_boxes = BitMasks(merged_instances.gt_masks).get_bounding_boxes() |
| | |
| | else: |
| | merged_instances.gt_boxes = merged_instances.gt_masks.get_bounding_boxes() |
| |
|
| | cur_unlabeled_data["image"] = composited_image |
| | cur_unlabeled_data["instances"] = merged_instances |
| | if self.cfg_VISUALIZE_COPY_PASTE: |
| | visualize_data(cur_unlabeled_data, save_path = 'sample_{}.jpg'.format(np.random.randint(5))) |
| | new_unlabeled_data.append(cur_unlabeled_data) |
| | return new_unlabeled_data |
| |
|
| | def run_step(self): |
| | """ |
| | Implement the standard training logic described above. |
| | """ |
| | assert self.model.training, "[SimpleTrainer] model was changed to eval mode!" |
| | start = time.perf_counter() |
| | """ |
| | If you want to do something with the data, you can wrap the dataloader. |
| | """ |
| | data = next(self._data_loader_iter) |
| | |
| | if self.use_copy_paste: |
| | |
| | data = self.copy_and_paste(copy.deepcopy(data[::-1]), data) |
| | data_time = time.perf_counter() - start |
| |
|
| | """ |
| | If you want to do something with the losses, you can wrap the model. |
| | """ |
| | loss_dict = self.model(data) |
| | if isinstance(loss_dict, torch.Tensor): |
| | losses = loss_dict |
| | loss_dict = {"total_loss": loss_dict} |
| | else: |
| | losses = sum(loss_dict.values()) |
| |
|
| | """ |
| | If you need to accumulate gradients or do something similar, you can |
| | wrap the optimizer with your custom `zero_grad()` method. |
| | """ |
| | if not torch.isnan(losses): |
| | self.optimizer.zero_grad() |
| | losses.backward() |
| | else: |
| | print('Nan loss. Skipped.') |
| |
|
| | self._write_metrics(loss_dict, data_time) |
| |
|
| | """ |
| | If you need gradient clipping/scaling or other processing, you can |
| | wrap the optimizer with your custom `step()` method. But it is |
| | suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4 |
| | """ |
| | self.optimizer.step() |
| |
|
| |
|
| | class CustomAMPTrainer(CustomSimpleTrainer): |
| | """ |
| | Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision |
| | in the training loop. |
| | """ |
| |
|
| | def __init__(self, model, data_loader, optimizer, cfg=None, grad_scaler=None, use_copy_paste=False, |
| | copy_paste_rate=-1, copy_paste_random_num=None, copy_paste_min_ratio=-1, |
| | copy_paste_max_ratio=-1, visualize_copy_paste=False): |
| | """ |
| | Args: |
| | model, data_loader, optimizer: same as in :class:`SimpleTrainer`. |
| | grad_scaler: torch GradScaler to automatically scale gradients. |
| | """ |
| | unsupported = "AMPTrainer does not support single-process multi-device training!" |
| | if isinstance(model, DistributedDataParallel): |
| | assert not (model.device_ids and len(model.device_ids) > 1), unsupported |
| | assert not isinstance(model, DataParallel), unsupported |
| |
|
| | super().__init__(model, data_loader, optimizer, cfg=cfg, use_copy_paste=use_copy_paste, \ |
| | copy_paste_rate=copy_paste_rate, copy_paste_random_num=copy_paste_random_num, \ |
| | copy_paste_min_ratio=copy_paste_min_ratio, copy_paste_max_ratio=copy_paste_max_ratio, \ |
| | visualize_copy_paste=visualize_copy_paste) |
| |
|
| | if grad_scaler is None: |
| | from torch.cuda.amp import GradScaler |
| |
|
| | grad_scaler = GradScaler() |
| | self.grad_scaler = grad_scaler |
| |
|
| | def run_step(self): |
| | """ |
| | Implement the AMP training logic. |
| | """ |
| | assert self.model.training, "[AMPTrainer] model was changed to eval mode!" |
| | assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!" |
| | from torch.cuda.amp import autocast |
| |
|
| | start = time.perf_counter() |
| | data = next(self._data_loader_iter) |
| | if self.use_copy_paste: |
| | |
| | data = self.copy_and_paste(copy.deepcopy(data[::-1]), data) |
| | data_time = time.perf_counter() - start |
| |
|
| | with autocast(): |
| | loss_dict = self.model(data) |
| | if isinstance(loss_dict, torch.Tensor): |
| | losses = loss_dict |
| | loss_dict = {"total_loss": loss_dict} |
| | else: |
| | losses = sum(loss_dict.values()) |
| |
|
| | if not torch.isnan(losses): |
| | self.optimizer.zero_grad() |
| | self.grad_scaler.scale(losses).backward() |
| | else: |
| | print('Nan loss.') |
| |
|
| | self._write_metrics(loss_dict, data_time) |
| |
|
| | self.grad_scaler.step(self.optimizer) |
| | self.grad_scaler.update() |
| |
|
| | def state_dict(self): |
| | ret = super().state_dict() |
| | ret["grad_scaler"] = self.grad_scaler.state_dict() |
| | return ret |
| |
|
| | def load_state_dict(self, state_dict): |
| | super().load_state_dict(state_dict) |
| | self.grad_scaler.load_state_dict(state_dict["grad_scaler"]) |
| |
|