import random
import sys
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from PIL import Image
from tqdm import tqdm

from efficientvit.apps.trainer import Trainer
from efficientvit.apps.utils import AverageMeter, get_dist_local_rank, get_dist_size, is_master, sync_tensor
from efficientvit.models.utils import list_join
from efficientvit.samcore.data_provider import SAMDataProvider
from efficientvit.samcore.trainer import SAMRunConfig
from efficientvit.samcore.trainer.utils import (
    compute_boundary_iou,
    compute_iou,
    loss_masks,
    mask_iou_batch,
    masks_sample_points,
)

__all__ = ["SAMTrainer"]


class SAMTrainer(Trainer):
    """Distributed trainer for EfficientViT-SAM.

    Trains a promptable segmentation model with randomly chosen box or point
    prompts, a best-of-N multimask loss, AMP, and DDP; validation is always
    box-prompted. The master process additionally logs metrics to wandb.
    """

    def __init__(
        self,
        path: str,
        model: nn.Module,
        data_provider: SAMDataProvider,
    ) -> None:
        """Initialize the base Trainer and, on the master rank, a wandb run."""
        super().__init__(
            path=path,
            model=model,
            data_provider=data_provider,
        )
        # Only the master process owns a wandb run; non-master ranks never
        # touch self.wandb_log (all .log() calls below are is_master()-guarded).
        if is_master():
            self.wandb_log = wandb.init(project="efficientvit-sam")

    def _validate(self, model, data_loader, epoch: int, sub_epoch: int) -> dict[str, Any]:
        """Run box-prompted evaluation over ``data_loader``.

        Returns a dict with the averaged validation loss, mask IoU, and
        boundary IoU, synchronized across distributed workers.
        """
        val_loss = AverageMeter()
        val_iou = AverageMeter()
        val_iou_boundary = AverageMeter()

        with torch.no_grad():
            with tqdm(
                total=len(data_loader),
                desc=f"Validate Epoch #{epoch + 1}, Sub Epoch #{sub_epoch + 1}",
                disable=not is_master(),
                file=sys.stdout,
            ) as t:
                for i, data in enumerate(data_loader):
                    image = data["image"].cuda()
                    masks = data["masks"].cuda()
                    # NOTE(review): box coordinates are doubled for 512px
                    # inputs — presumably prompts are annotated at half that
                    # resolution; confirm against SAMDataProvider.
                    bboxs = data["bboxs"].cuda() * 2 if image.shape[2] == 512 else data["bboxs"].cuda()

                    # Convert boxes in place from xywh to xyxy.
                    bboxs[..., 2] = bboxs[..., 0] + bboxs[..., 2]
                    bboxs[..., 3] = bboxs[..., 1] + bboxs[..., 3]

                    batched_input = []
                    for b_i in range(len(image)):
                        dict_input = dict()
                        dict_input["image"] = image[b_i]
                        dict_input["boxes"] = bboxs[b_i]
                        batched_input.append(dict_input)

                    # Positional True == multimask_output=True (see run_step).
                    output, iou_predictions = model(batched_input, True)

                    # output is (B, M, N, H, W): N candidate masks per prompt.
                    # For each of the M prompts keep the candidate the model
                    # itself ranks highest by predicted IoU.
                    B, M, N, H, W = output.shape
                    output = torch.stack(
                        [
                            output[k][torch.arange(M), iou_predictions[k].argmax(-1).squeeze()]
                            for k in range(len(output))
                        ],
                        dim=0,
                    )

                    output = (
                        F.interpolate(output, size=(image.shape[2], image.shape[3]), mode="bilinear")
                        .reshape(-1, image.shape[2], image.shape[3])
                        .unsqueeze(1)
                    )
                    masks = masks.reshape(-1, image.shape[2], image.shape[3]).unsqueeze(1)

                    loss_mask, loss_dice = loss_masks(output, masks, len(output))
                    loss = loss_mask * 20 + loss_dice

                    # NOTE(review): the * 255 suggests compute_iou expects
                    # 0/255-valued ground truth while masks are 0/1 — confirm
                    # against samcore.trainer.utils.
                    iou = compute_iou(output, masks * 255)
                    boundary_iou = compute_boundary_iou(output, masks * 255)

                    # Average metrics across distributed workers before logging.
                    loss = sync_tensor(loss)
                    iou = sync_tensor(iou)
                    boundary_iou = sync_tensor(boundary_iou)

                    val_loss.update(loss, image.shape[0] * get_dist_size())
                    val_iou.update(iou, image.shape[0] * get_dist_size())
                    val_iou_boundary.update(boundary_iou, image.shape[0] * get_dist_size())

                    t.set_postfix(
                        {
                            "loss": val_loss.avg,
                            "iou": val_iou.avg,
                            "boundary_iou": val_iou_boundary.avg,
                            "bs": image.shape[0] * get_dist_size(),
                        }
                    )
                    t.update()

        if is_master():
            self.wandb_log.log(
                {"val_loss": val_loss.avg, "val_iou": val_iou.avg, "val_boundary_iou": val_iou_boundary.avg}
            )

        return {
            "val_loss": val_loss.avg,
            "val_iou": val_iou.avg,
            "val_boundary_iou": val_iou_boundary.avg,
        }

    def validate(self, model=None, data_loader=None, epoch=0, sub_epoch=0) -> dict[str, Any]:
        """Evaluate ``model`` (default: self.eval_network) on the validation split."""
        model = model or self.eval_network
        if data_loader is None:
            data_loader = self.data_provider.valid

        model.eval()
        return self._validate(model, data_loader, epoch, sub_epoch)

    def before_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        """Move a training batch to GPU and convert boxes from xywh to xyxy.

        Box/point coordinates are doubled for 512px inputs (same convention
        as validation — see the note in _validate).
        """
        image = feed_dict["image"].cuda()
        masks = feed_dict["masks"].cuda()
        bboxs = feed_dict["bboxs"].cuda() * 2 if image.shape[2] == 512 else feed_dict["bboxs"].cuda()
        points = feed_dict["points"].cuda() * 2 if image.shape[2] == 512 else feed_dict["points"].cuda()

        # xywh -> xyxy, in place.
        bboxs[..., 2] = bboxs[..., 0] + bboxs[..., 2]
        bboxs[..., 3] = bboxs[..., 1] + bboxs[..., 3]

        return {
            "image": image,
            "masks": masks,
            "points": points,
            "bboxs": bboxs,
        }

    def run_step(self, feed_dict: dict[str, Any]) -> dict[str, Any]:
        """One optimization step.

        Each sample is prompted with either its box or 1-10 points sampled
        from its ground-truth masks (50/50). The forward pass randomly uses
        single- or multi-mask output; per-candidate losses are computed and
        only the best (lowest-loss) candidate per mask is supervised.
        Scales the loss and calls backward(); the optimizer step happens in
        after_step().
        """
        image = feed_dict["image"]
        masks = feed_dict["masks"]
        bboxs = feed_dict["bboxs"]
        points = feed_dict["points"]

        batched_input = []
        for b_i in range(len(image)):
            dict_input = dict()
            dict_input["image"] = image[b_i]
            if random.random() >= 0.5:
                dict_input["boxes"] = bboxs[b_i]
            else:
                try:
                    # 1..10 foreground points per mask.
                    n_p = int(random.random() * 10 + 1)
                    dict_input["point_coords"] = masks_sample_points(masks[b_i], k=n_p)
                    if image.shape[2] == 512:
                        # Sampled coordinates follow the mask resolution; scale
                        # to the same prompt space as the doubled boxes.
                        dict_input["point_coords"] = dict_input["point_coords"] * 2
                    dict_input["point_labels"] = torch.ones((points[b_i].shape[0], n_p), device=image.device)
                except Exception:
                    # masks_sample_points can fail (e.g. degenerate/empty
                    # masks); fall back to box prompting for this sample.
                    # Was a bare `except:` — narrowed so Ctrl-C still works.
                    dict_input["boxes"] = bboxs[b_i]
            batched_input.append(dict_input)

        with torch.autocast(device_type="cuda", dtype=self.amp_dtype, enabled=self.enable_amp):
            # Randomly alternate single-/multi-mask training.
            if random.random() >= 0.5:
                output, iou_predictions = self.model(batched_input, multimask_output=True)
            else:
                output, iou_predictions = self.model(batched_input, multimask_output=False)

            masks = masks.reshape(-1, image.shape[2], image.shape[3]).unsqueeze(1)

            # Per-candidate (unreduced) losses over the N mask candidates.
            loss_list = []
            for i in range(output.shape[2]):
                output_i = (
                    F.interpolate(output[:, :, i], size=(image.shape[2], image.shape[3]), mode="bilinear")
                    .reshape(-1, image.shape[2], image.shape[3])
                    .unsqueeze(1)
                )
                loss_mask_i, loss_dice_i = loss_masks(output_i, masks, len(output_i), mode="none")
                loss_i = loss_mask_i * 20 + loss_dice_i
                loss_list.append(loss_i)

            # Best-of-N: keep only the minimum-loss candidate per mask.
            # The * loss.shape[-1] factor undoes the mean over the zeroed-out
            # candidates so the magnitude matches a plain per-mask mean.
            loss = torch.stack(loss_list, -1)
            min_indices = torch.argmin(loss, dim=1)
            mask = torch.zeros_like(loss, device=loss.device)
            mask.scatter_(1, min_indices.unsqueeze(1), 1)
            loss = (loss * mask).mean() * loss.shape[-1]

        self.scaler.scale(loss).backward()

        return {"loss": loss, "output": output}

    def _train_one_sub_epoch(self, epoch: int, sub_epoch: int) -> dict[str, Any]:
        """Inner loop over one sub-epoch of the training loader."""
        train_loss = AverageMeter()

        with tqdm(
            total=len(self.data_provider.train),
            desc=f"Train Epoch #{epoch + 1}, Sub Epoch #{sub_epoch + 1}",
            disable=not is_master(),
            file=sys.stdout,
        ) as t:
            for i, data in enumerate(self.data_provider.train):
                feed_dict = data

                # preprocessing
                feed_dict = self.before_step(feed_dict)
                # clear gradient
                self.optimizer.zero_grad()
                # forward & backward
                output_dict = self.run_step(feed_dict)
                # update: optimizer, lr_scheduler
                self.after_step()

                loss = output_dict["loss"]
                loss = sync_tensor(loss)
                train_loss.update(loss, data["image"].shape[0] * get_dist_size())

                if is_master():
                    self.wandb_log.log(
                        {
                            "train_loss": train_loss.avg,
                            "epoch": epoch,
                            "sub_epoch": sub_epoch,
                            # Log the smallest lr across param groups.
                            "learning_rate": sorted({group["lr"] for group in self.optimizer.param_groups})[0],
                        }
                    )

                t.set_postfix(
                    {
                        "loss": train_loss.avg,
                        "bs": data["image"].shape[0] * get_dist_size(),
                        "res": data["image"].shape[2],
                        "lr": list_join(
                            sorted({group["lr"] for group in self.optimizer.param_groups}),
                            "#",
                            "%.1E",
                        ),
                        "progress": self.run_config.progress,
                    }
                )
                t.update()

        return {
            "train_loss": train_loss.avg,
        }

    def train_one_sub_epoch(self, epoch: int, sub_epoch: int) -> dict[str, Any]:
        """Train for one sub-epoch: set train mode, seed the sampler, run the loop."""
        self.model.train()

        self.data_provider.set_epoch_and_sub_epoch(epoch, sub_epoch)
        train_info_dict = self._train_one_sub_epoch(epoch, sub_epoch)

        return train_info_dict

    def train(self) -> None:
        """Full schedule: alternate sub-epoch training and validation.

        Tracks the best validation IoU in self.best_val and saves a full
        checkpoint after every sub-epoch.
        """
        for sub_epoch in range(self.start_epoch, self.run_config.n_epochs):
            # Several sub-epochs make up one data epoch.
            epoch = sub_epoch // self.data_provider.sub_epochs_per_epoch

            train_info_dict = self.train_one_sub_epoch(epoch, sub_epoch)

            val_info_dict = self.validate(epoch=epoch, sub_epoch=sub_epoch)

            # (Removed an unused `is_best` flag — checkpoints are saved
            # unconditionally below.)
            val_iou = val_info_dict["val_iou"]
            self.best_val = max(val_iou, self.best_val)

            self.save_model(
                only_state_dict=False,
                epoch=sub_epoch,
                model_name=f"checkpoint_{epoch}_{sub_epoch}.pt",
            )

    def prep_for_training(self, run_config: SAMRunConfig, amp="fp32") -> None:
        """Wrap the model in DDP, build the optimizer/scheduler, and set up AMP."""
        self.run_config = run_config
        self.model = nn.parallel.DistributedDataParallel(
            self.model.cuda(),
            device_ids=[get_dist_local_rank()],
            find_unused_parameters=True,
        )

        self.run_config.global_step = 0
        self.run_config.batch_per_epoch = len(self.data_provider.train)
        assert self.run_config.batch_per_epoch > 0, "Training set is empty"

        # build optimizer
        self.optimizer, self.lr_scheduler = self.run_config.build_optimizer(self.model)

        # amp — the scaler is a no-op when enable_amp is False.
        self.amp = amp
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.enable_amp)