Spaces:
Runtime error
Runtime error
| import logging | |
| import os | |
| import random | |
| from collections import defaultdict | |
| from copy import deepcopy | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from isegm.utils.distributed import (get_dp_wrapper, get_sampler, | |
| reduce_loss_dict) | |
| from isegm.utils.log import SummaryWriterAvg, TqdmToLogger, logger | |
| from isegm.utils.misc import save_checkpoint | |
| from isegm.utils.serialization import get_config_repr | |
| from isegm.utils.vis import draw_points, draw_probmap | |
| from .optimizer import get_optimizer | |
class ISTrainer(object):
    """Train/validate an interactive segmentation network with simulated clicks.

    The trainer owns the data loaders, the optimizer, an optional LR scheduler,
    TensorBoard logging through ``SummaryWriterAvg``, training visualizations
    and checkpointing. Logging, visualization dumps and checkpoint writes only
    happen on the master process (``cfg.local_rank == 0``, see ``is_master``).
    """

    def __init__(
        self,
        model,
        cfg,
        model_cfg,
        loss_cfg,
        trainset,
        valset,
        optimizer="adam",
        optimizer_params=None,
        image_dump_interval=200,
        checkpoint_interval=10,
        tb_dump_period=25,
        max_interactive_points=0,
        lr_scheduler=None,
        metrics=None,
        additional_val_metrics=None,
        net_inputs=("images", "points"),
        max_num_next_clicks=0,
        click_models=None,
        prev_mask_drop_prob=0.0,
    ):
        """Build data loaders, optimizer, optional scheduler and the model.

        Args:
            model: network to train. It is logged together with
                ``model._config`` on the master process; ``batch_forward``
                also reads ``with_prev_mask`` from it.
            cfg: experiment config. NOTE: when ``cfg.distributed`` is set,
                ``cfg.batch_size`` and ``cfg.val_batch_size`` are divided by
                ``cfg.ngpus`` *in place*.
            model_cfg: model config, stored as-is.
            loss_cfg: mapping holding ``<name>_loss`` criteria and
                ``<name>_loss_weight`` weights. A deep copy is used for
                validation.
            trainset, valset: datasets exposing ``get_samples_number()``.
            optimizer: optimizer name forwarded to ``get_optimizer``.
            optimizer_params: optimizer kwargs; ``optimizer_params["lr"]`` is
                read, so it must be provided.
            image_dump_interval: save a training visualization every N global
                steps (``<= 0`` disables it).
            checkpoint_interval: an int, or a list/tuple of
                ``(start_epoch, interval)`` pairs where the last pair with
                ``start_epoch <= epoch`` applies.
            tb_dump_period: ``dump_period`` for the summary writer.
            max_interactive_points: number of leading points drawn with the
                first color in visualizations; the remaining points use the
                second color.
            lr_scheduler: factory called as ``lr_scheduler(optimizer=...)``.
                It is stepped ``cfg.start_epoch`` times to resume.
            metrics: training metrics; a deep copy is used for validation.
            additional_val_metrics: extra metrics used only for validation.
            net_inputs: stored as ``self.net_inputs``; not read by this class.
            max_num_next_clicks: upper bound of simulated corrective clicks
                per batch.
            click_models: optional frozen models used for the first simulated
                clicks instead of the trained network.
            prev_mask_drop_prob: per-sample probability of zeroing the
                previous-mask input after click simulation.
        """
        self.cfg = cfg
        self.model_cfg = model_cfg
        self.max_interactive_points = max_interactive_points
        self.loss_cfg = loss_cfg
        self.val_loss_cfg = deepcopy(loss_cfg)
        self.tb_dump_period = tb_dump_period
        self.net_inputs = net_inputs
        self.max_num_next_clicks = max_num_next_clicks
        self.click_models = click_models
        self.prev_mask_drop_prob = prev_mask_drop_prob

        if cfg.distributed:
            # Configured batch sizes are global; each process gets a share.
            cfg.batch_size //= cfg.ngpus
            cfg.val_batch_size //= cfg.ngpus

        if metrics is None:
            metrics = []
        self.train_metrics = metrics
        self.val_metrics = deepcopy(metrics)
        if additional_val_metrics is not None:
            self.val_metrics.extend(additional_val_metrics)

        self.checkpoint_interval = checkpoint_interval
        self.image_dump_interval = image_dump_interval
        self.task_prefix = ""
        self.sw = None

        self.trainset = trainset
        self.valset = valset

        logger.info(
            f"Dataset of {trainset.get_samples_number()} samples was loaded for training."
        )
        logger.info(
            f"Dataset of {valset.get_samples_number()} samples was loaded for validation."
        )

        self.train_data = DataLoader(
            trainset,
            cfg.batch_size,
            sampler=get_sampler(trainset, shuffle=True, distributed=cfg.distributed),
            drop_last=True,
            pin_memory=True,
            num_workers=cfg.workers,
        )
        self.val_data = DataLoader(
            valset,
            cfg.val_batch_size,
            sampler=get_sampler(valset, shuffle=False, distributed=cfg.distributed),
            drop_last=True,
            pin_memory=True,
            num_workers=cfg.workers,
        )

        self.optim = get_optimizer(model, optimizer, optimizer_params)

        model = self._load_weights(model)

        if cfg.multi_gpu:
            model = get_dp_wrapper(cfg.distributed)(
                model, device_ids=cfg.gpu_ids, output_device=cfg.gpu_ids[0]
            )

        if self.is_master:
            logger.info(model)
            logger.info(get_config_repr(model._config))

        self.device = cfg.device
        self.net = model.to(self.device)
        self.lr = optimizer_params["lr"]

        if lr_scheduler is not None:
            self.lr_scheduler = lr_scheduler(optimizer=self.optim)
            if cfg.start_epoch > 0:
                # Fast-forward the schedule when resuming from a later epoch.
                for _ in range(cfg.start_epoch):
                    self.lr_scheduler.step()

        self.tqdm_out = TqdmToLogger(logger, level=logging.INFO)

        if self.click_models is not None:
            # Click models only generate clicks; they are frozen and kept in eval mode.
            for click_model in self.click_models:
                for param in click_model.parameters():
                    param.requires_grad = False
                click_model.to(self.device)
                click_model.eval()

    def run(self, num_epochs, start_epoch=None, validation=True):
        """Train for epochs ``[start_epoch, num_epochs)``, validating after each.

        Args:
            num_epochs: exclusive upper bound of the epoch index.
            start_epoch: first epoch; defaults to ``cfg.start_epoch``.
            validation: when False, the validation pass is skipped.
        """
        if start_epoch is None:
            start_epoch = self.cfg.start_epoch

        logger.info(f"Starting Epoch: {start_epoch}")
        logger.info(f"Total Epochs: {num_epochs}")
        for epoch in range(start_epoch, num_epochs):
            self.training(epoch)
            if validation:
                self.validation(epoch)

    def training(self, epoch):
        """Run one training epoch, then log metrics, checkpoint and step the LR.

        Args:
            epoch: current epoch index, used for sampler seeding, global step
                computation, checkpoint naming and the checkpoint schedule.
        """
        if self.sw is None and self.is_master:
            self.sw = SummaryWriterAvg(
                log_dir=str(self.cfg.LOGS_PATH),
                flush_secs=10,
                dump_period=self.tb_dump_period,
            )

        if self.cfg.distributed:
            self.train_data.sampler.set_epoch(epoch)

        log_prefix = "Train" + self.task_prefix.capitalize()
        tbar = (
            tqdm(self.train_data, file=self.tqdm_out, ncols=100)
            if self.is_master
            else self.train_data
        )

        for metric in self.train_metrics:
            metric.reset_epoch_stats()

        self.net.train()
        train_loss = 0.0
        for i, batch_data in enumerate(tbar):
            global_step = epoch * len(self.train_data) + i

            loss, losses_logging, splitted_batch_data, outputs = self.batch_forward(
                batch_data
            )

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()

            losses_logging["overall"] = loss
            reduce_loss_dict(losses_logging)

            train_loss += losses_logging["overall"].item()

            if self.is_master:
                for loss_name, loss_value in losses_logging.items():
                    self.sw.add_scalar(
                        tag=f"{log_prefix}Losses/{loss_name}",
                        value=loss_value.item(),
                        global_step=global_step,
                    )

                # Let active loss criteria log their own internal state.
                for k, v in self.loss_cfg.items():
                    if (
                        "_loss" in k
                        and hasattr(v, "log_states")
                        and self.loss_cfg.get(k + "_weight", 0.0) > 0
                    ):
                        v.log_states(self.sw, f"{log_prefix}Losses/{k}", global_step)

                if (
                    self.image_dump_interval > 0
                    and global_step % self.image_dump_interval == 0
                ):
                    self.save_visualization(
                        splitted_batch_data, outputs, global_step, prefix="train"
                    )

                self.sw.add_scalar(
                    tag=f"{log_prefix}States/learning_rate",
                    value=self._current_lr(),
                    global_step=global_step,
                )

                tbar.set_description(
                    f"Epoch {epoch}, training loss {train_loss/(i+1):.4f}"
                )
                for metric in self.train_metrics:
                    metric.log_states(
                        self.sw, f"{log_prefix}Metrics/{metric.name}", global_step
                    )

        if self.is_master:
            for metric in self.train_metrics:
                self.sw.add_scalar(
                    tag=f"{log_prefix}Metrics/{metric.name}",
                    value=metric.get_epoch_value(),
                    global_step=epoch,
                    disable_avg=True,
                )

            # Always refresh the epoch-less checkpoint.
            save_checkpoint(
                self.net,
                self.cfg.CHECKPOINTS_PATH,
                prefix=self.task_prefix,
                epoch=None,
                multi_gpu=self.cfg.multi_gpu,
            )

            if isinstance(self.checkpoint_interval, (list, tuple)):
                # Schedule of (start_epoch, interval) pairs: the last pair that
                # has already started wins. Raises IndexError if none has started.
                checkpoint_interval = [
                    x for x in self.checkpoint_interval if x[0] <= epoch
                ][-1][1]
            else:
                checkpoint_interval = self.checkpoint_interval

            if epoch % checkpoint_interval == 0:
                save_checkpoint(
                    self.net,
                    self.cfg.CHECKPOINTS_PATH,
                    prefix=self.task_prefix,
                    epoch=epoch,
                    multi_gpu=self.cfg.multi_gpu,
                )

        if hasattr(self, "lr_scheduler"):
            self.lr_scheduler.step()

    def _current_lr(self):
        """Return the learning rate to log for the current step.

        Without a scheduler this is the configured ``self.lr``. With one, the
        last learning rate of the last parameter group is returned.
        """
        if not hasattr(self, "lr_scheduler"):
            return self.lr
        # get_last_lr() reports the rate actually in use. Calling get_lr()
        # outside scheduler.step() is deprecated in PyTorch and, for step-style
        # schedulers (StepLR/MultiStepLR), returns an extra-decayed value at
        # milestones. Fall back to get_lr() for schedulers lacking get_last_lr.
        get_last_lr = getattr(self.lr_scheduler, "get_last_lr", None)
        lrs = get_last_lr() if get_last_lr is not None else self.lr_scheduler.get_lr()
        return lrs[-1]

    def validation(self, epoch):
        """Run one validation pass and log mean losses and epoch metrics.

        Args:
            epoch: current epoch index, used as the TensorBoard step for the
                epoch-level scalars.
        """
        if self.sw is None and self.is_master:
            self.sw = SummaryWriterAvg(
                log_dir=str(self.cfg.LOGS_PATH),
                flush_secs=10,
                dump_period=self.tb_dump_period,
            )

        log_prefix = "Val" + self.task_prefix.capitalize()
        tbar = (
            tqdm(self.val_data, file=self.tqdm_out, ncols=100)
            if self.is_master
            else self.val_data
        )

        for metric in self.val_metrics:
            metric.reset_epoch_stats()

        val_loss = 0
        losses_logging = defaultdict(list)

        self.net.eval()
        for i, batch_data in enumerate(tbar):
            global_step = epoch * len(self.val_data) + i
            (
                loss,
                batch_losses_logging,
                splitted_batch_data,
                outputs,
            ) = self.batch_forward(batch_data, validation=True)

            batch_losses_logging["overall"] = loss
            reduce_loss_dict(batch_losses_logging)
            for loss_name, loss_value in batch_losses_logging.items():
                losses_logging[loss_name].append(loss_value.item())

            val_loss += batch_losses_logging["overall"].item()

            if self.is_master:
                tbar.set_description(
                    f"Epoch {epoch}, validation loss: {val_loss/(i + 1):.4f}"
                )
                for metric in self.val_metrics:
                    metric.log_states(
                        self.sw, f"{log_prefix}Metrics/{metric.name}", global_step
                    )

        if self.is_master:
            for loss_name, loss_values in losses_logging.items():
                self.sw.add_scalar(
                    tag=f"{log_prefix}Losses/{loss_name}",
                    value=np.array(loss_values).mean(),
                    global_step=epoch,
                    disable_avg=True,
                )

            for metric in self.val_metrics:
                self.sw.add_scalar(
                    tag=f"{log_prefix}Metrics/{metric.name}",
                    value=metric.get_epoch_value(),
                    global_step=epoch,
                    disable_avg=True,
                )

    def batch_forward(self, batch_data, validation=False):
        """Simulate corrective clicks, run the network and compute the loss.

        A random number of clicks in ``[0, max_num_next_clicks]`` is simulated
        without gradients. Each click feeds the previous sigmoid prediction
        back into the model when ``with_prev_mask`` is set. The final forward
        pass then runs with gradients enabled unless ``validation`` is True.

        Args:
            batch_data: dict of tensors with at least ``"images"``,
                ``"instances"`` and ``"points"``.
            validation: selects validation metrics and loss config, and
                disables gradients.

        Returns:
            ``(loss, losses_logging, batch_data, output)``: the weighted total
            loss, a dict of unweighted per-term losses, the device-moved batch
            whose ``"points"`` now include the simulated clicks, and the
            network output dict.
        """
        metrics = self.val_metrics if validation else self.train_metrics
        losses_logging = dict()

        with torch.set_grad_enabled(not validation):
            batch_data = {k: v.to(self.device) for k, v in batch_data.items()}
            image, gt_mask, points = (
                batch_data["images"],
                batch_data["instances"],
                batch_data["points"],
            )

            # Previous-mask channel starts empty: one channel, same spatial size as image.
            prev_output = torch.zeros_like(image, dtype=torch.float32)[:, :1, :, :]

            last_click_indx = None

            with torch.no_grad():
                num_iters = random.randint(0, self.max_num_next_clicks)

                for click_indx in range(num_iters):
                    last_click_indx = click_indx

                    if not validation:
                        self.net.eval()

                    if self.click_models is None or click_indx >= len(
                        self.click_models
                    ):
                        eval_model = self.net
                    else:
                        eval_model = self.click_models[click_indx]

                    net_input = (
                        torch.cat((image, prev_output), dim=1)
                        if self.net.with_prev_mask
                        else image
                    )
                    prev_output = torch.sigmoid(
                        eval_model(net_input, points)["instances"]
                    )

                    points = get_next_points(
                        prev_output, gt_mask, points, click_indx + 1
                    )

                    if not validation:
                        self.net.train()

                if (
                    self.net.with_prev_mask
                    and self.prev_mask_drop_prob > 0
                    and last_click_indx is not None
                ):
                    # Randomly hide the previous prediction for some samples.
                    zero_mask = (
                        np.random.random(size=prev_output.size(0))
                        < self.prev_mask_drop_prob
                    )
                    prev_output[zero_mask] = torch.zeros_like(prev_output[zero_mask])

            batch_data["points"] = points

            net_input = (
                torch.cat((image, prev_output), dim=1)
                if self.net.with_prev_mask
                else image
            )
            output = self.net(net_input, points)

            # Loss inputs are lambdas so "instances_aux" is only looked up when
            # its loss weight is positive.
            loss = 0.0
            loss = self.add_loss(
                "instance_loss",
                loss,
                losses_logging,
                validation,
                lambda: (output["instances"], batch_data["instances"]),
            )
            loss = self.add_loss(
                "instance_aux_loss",
                loss,
                losses_logging,
                validation,
                lambda: (output["instances_aux"], batch_data["instances"]),
            )

            if self.is_master:
                with torch.no_grad():
                    for m in metrics:
                        m.update(
                            *(output.get(x) for x in m.pred_outputs),
                            *(batch_data[x] for x in m.gt_outputs),
                        )
        return loss, losses_logging, batch_data, output

    def add_loss(
        self, loss_name, total_loss, losses_logging, validation, lambda_loss_inputs
    ):
        """Add ``weight * mean(criterion(...))`` to ``total_loss`` if the weight is positive.

        Args:
            loss_name: key of the criterion in the loss config; its weight is
                read from ``loss_name + "_weight"`` (default 0.0).
            total_loss: running total to add to.
            losses_logging: dict that receives the unweighted mean loss under
                ``loss_name``.
            validation: selects ``val_loss_cfg`` instead of ``loss_cfg``.
            lambda_loss_inputs: zero-argument callable returning the criterion
                arguments; it is only called when the weight is positive.

        Returns:
            The updated total, or ``total_loss`` unchanged when the weight is 0.
        """
        loss_cfg = self.loss_cfg if not validation else self.val_loss_cfg
        loss_weight = loss_cfg.get(loss_name + "_weight", 0.0)
        if loss_weight > 0.0:
            loss_criterion = loss_cfg.get(loss_name)
            loss = loss_criterion(*lambda_loss_inputs())
            loss = torch.mean(loss)
            losses_logging[loss_name] = loss
            loss = loss_weight * loss
            total_loss = total_loss + loss

        return total_loss

    def save_visualization(self, splitted_batch_data, outputs, global_step, prefix):
        """Save a JPEG of the first batch sample: clicks | GT mask | prediction.

        Args:
            splitted_batch_data: batch dict as returned by ``batch_forward``
                (uses ``"images"``, ``"points"``, ``"instances"``).
            outputs: network output dict; ``outputs["instances"]`` is passed
                through a sigmoid here.
            global_step: zero-padded into the file name.
            prefix: sub-directory of ``cfg.VIS_PATH`` (e.g. ``"train"``).
        """
        output_images_path = self.cfg.VIS_PATH / prefix
        if self.task_prefix:
            output_images_path /= self.task_prefix

        # exist_ok avoids failing if the directory appears between a check and
        # its creation.
        output_images_path.mkdir(parents=True, exist_ok=True)
        image_name_prefix = f"{global_step:06d}"

        def _save_image(suffix, image):
            cv2.imwrite(
                str(output_images_path / f"{image_name_prefix}_{suffix}.jpg"),
                image,
                [cv2.IMWRITE_JPEG_QUALITY, 85],
            )

        images = splitted_batch_data["images"]
        points = splitted_batch_data["points"]
        instance_masks = splitted_batch_data["instances"]

        gt_instance_masks = instance_masks.cpu().numpy()
        predicted_instance_masks = (
            torch.sigmoid(outputs["instances"]).detach().cpu().numpy()
        )
        points = points.detach().cpu().numpy()

        image_blob, points = images[0], points[0]
        gt_mask = np.squeeze(gt_instance_masks[0], axis=0)
        predicted_mask = np.squeeze(predicted_instance_masks[0], axis=0)

        # CHW -> HWC; assumes image values are in [0, 1] (TODO confirm).
        image = image_blob.cpu().numpy() * 255
        image = image.transpose((1, 2, 0))

        image_with_points = draw_points(
            image, points[: self.max_interactive_points], (0, 255, 0)
        )
        image_with_points = draw_points(
            image_with_points, points[self.max_interactive_points :], (0, 0, 255)
        )

        # Negative GT values (presumably ignore regions; confirm) are rendered as 0.25.
        gt_mask[gt_mask < 0] = 0.25
        gt_mask = draw_probmap(gt_mask)
        predicted_mask = draw_probmap(predicted_mask)
        viz_image = np.hstack((image_with_points, gt_mask, predicted_mask)).astype(
            np.uint8
        )

        # Reverse the channel order before cv2.imwrite.
        _save_image("instance_segmentation", viz_image[:, :, ::-1])

    def _load_weights(self, net):
        """Load initial weights from ``cfg.weights`` or a resume checkpoint.

        If ``cfg.weights`` is set, that file is loaded and ``cfg.weights`` is
        cleared afterwards. Otherwise, if ``cfg.resume_exp`` is set, the single
        ``{cfg.resume_prefix}*.pth`` file in ``cfg.CHECKPOINTS_PATH`` is loaded.

        Args:
            net: model to load the weights into (modified in place).

        Returns:
            ``net``.

        Raises:
            RuntimeError: if ``cfg.weights`` is not a file, or the resume glob
                does not match exactly one checkpoint.
        """
        if self.cfg.weights is not None:
            if os.path.isfile(self.cfg.weights):
                load_weights(net, self.cfg.weights)
                self.cfg.weights = None
            else:
                raise RuntimeError(f"=> no checkpoint found at '{self.cfg.weights}'")
        elif self.cfg.resume_exp is not None:
            pattern = f"{self.cfg.resume_prefix}*.pth"
            checkpoints = list(self.cfg.CHECKPOINTS_PATH.glob(pattern))
            # Explicit check instead of `assert`, which is stripped under -O.
            if len(checkpoints) != 1:
                raise RuntimeError(
                    f"=> expected exactly one checkpoint matching '{pattern}' in "
                    f"'{self.cfg.CHECKPOINTS_PATH}', found {len(checkpoints)}"
                )
            checkpoint_path = checkpoints[0]
            logger.info(f"Load checkpoint from path: {checkpoint_path}")
            load_weights(net, str(checkpoint_path))
        return net

    @property
    def is_master(self):
        """True on the process with ``cfg.local_rank == 0``.

        This must be a property: callers test ``self.is_master`` without
        calling it, and a bound method would always be truthy.
        """
        return self.cfg.local_rank == 0
def get_next_points(pred, gt, points, click_indx, pred_thresh=0.49):
    """Add one simulated corrective click per sample, based on prediction errors.

    For each sample, false-negative pixels (GT set, prediction below
    ``pred_thresh``) and false-positive pixels (GT unset, prediction above
    ``pred_thresh``) are compared by their largest distance-transform value.
    The click goes into the larger error region: positive for false
    negatives, negative for false positives. Its location is drawn uniformly
    from the pixels deeper than half of that largest distance. Samples with
    no such pixel are left unchanged.

    Args:
        pred: 4-D tensor of probabilities; channel 0 is used.
        gt: 4-D tensor of masks; channel 0 is binarised with ``> 0.5``.
        points: 3-D tensor of clicks. Positive clicks are written into the
            first half of dim 1 and negative clicks into the second half, each
            filled backwards from the end of its half; entries 0, 1, 2 of the
            last dim receive (row, col, click_indx).
        click_indx: 1-based index of the click being simulated (must be > 0).
        pred_thresh: probability threshold for the error masks.

    Returns:
        A cloned ``points`` tensor with the new clicks written in.
    """
    assert click_indx > 0

    def _padded_uint8(mask):
        # A 1-pixel zero border makes the image boundary count as "outside"
        # the error region for cv2.distanceTransform (distance to the nearest
        # zero pixel); the border is cropped off after the transform.
        return np.pad(mask, ((0, 0), (1, 1), (1, 1)), "constant").astype(np.uint8)

    prob_np = pred.cpu().numpy()[:, 0, :, :]
    target_np = gt.cpu().numpy()[:, 0, :, :] > 0.5

    false_neg = _padded_uint8(np.logical_and(target_np, prob_np < pred_thresh))
    false_pos = _padded_uint8(np.logical_and(~target_np, prob_np > pred_thresh))

    half = points.size(1) // 2
    points = points.clone()

    for sample_idx, (fn_padded, fp_padded) in enumerate(zip(false_neg, false_pos)):
        fn_dist = cv2.distanceTransform(fn_padded, cv2.DIST_L2, 5)[1:-1, 1:-1]
        fp_dist = cv2.distanceTransform(fp_padded, cv2.DIST_L2, 5)[1:-1, 1:-1]

        fn_peak = fn_dist.max()
        fp_peak = fp_dist.max()

        positive = fn_peak > fp_peak
        chosen_dist = fn_dist if positive else fp_dist
        candidates = np.argwhere(chosen_dist > max(fn_peak, fp_peak) / 2.0)
        if len(candidates) == 0:
            continue

        row, col = candidates[np.random.randint(0, len(candidates))]
        slot = (half if positive else 2 * half) - click_indx
        points[sample_idx, slot, 0] = float(row)
        points[sample_idx, slot, 1] = float(col)
        points[sample_idx, slot, 2] = float(click_indx)

    return points
def load_weights(model, path_to_weights):
    """Overlay the checkpoint's ``"state_dict"`` onto ``model`` in place.

    Parameters missing from the checkpoint keep their current values. Keys in
    the checkpoint that the model does not have end up in the merged dict, so
    the (default strict) ``load_state_dict`` call raises for them.

    NOTE: ``torch.load`` may unpickle arbitrary objects depending on the
    installed PyTorch version; only load trusted checkpoints.

    Args:
        model: ``torch.nn.Module`` to update.
        path_to_weights: path to a file saved with a ``"state_dict"`` entry.
    """
    # Updating the module's own state_dict (rather than building a new dict)
    # keeps its ``_metadata``, which load_state_dict uses for versioning.
    merged = model.state_dict()
    checkpoint = torch.load(path_to_weights, map_location="cpu")
    merged.update(checkpoint["state_dict"])
    model.load_state_dict(merged)