| import os |
| import random |
| import monai |
| from os import listdir, makedirs |
| from os.path import join, exists, isfile, isdir, basename |
| from tqdm import tqdm |
| from time import time |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader |
| from datetime import datetime |
| from shutil import copyfile |
| from models import PromptEncoder, TwoWayTransformer, TinyViT, MaskDecoder_F4 |
| import torch.nn.functional as F |
| import gc |
| from matplotlib import pyplot as plt |
| import argparse |
| from modality_npz_dataset import ModalityNpzDataset |
|
|
# Release any cached CUDA allocator memory at import time.
torch.cuda.empty_cache()

# Cap the thread pools of the numerical back-ends used by this process.
# NOTE(review): numpy and torch are already imported above, so libraries that
# read these variables only when they are first loaded may ignore them —
# confirm whether they should be set before the imports.
os.environ.update({
    "OMP_NUM_THREADS": "4",
    "OPENBLAS_NUM_THREADS": "4",
    "MKL_NUM_THREADS": "6",
    "VECLIB_MAXIMUM_THREADS": "4",
    "NUMEXPR_NUM_THREADS": "6",
})
|
|
def setup_seed(seed):
    """Seed the stdlib, NumPy and PyTorch (CPU and all CUDA devices) RNGs.

    Args:
        seed: integer seed applied to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


setup_seed(2024)
|
|
def get_args(argv=None):
    """Define and parse the command-line options for training.

    Args:
        argv: optional list of argument strings to parse. ``None`` (the
            default) parses ``sys.argv[1:]`` exactly as before; passing a
            list makes the parser usable from tests or notebooks.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()

    # ---- paths -----------------------------------------------------------
    parser.add_argument("--data_root",
                        type=str,
                        default="",
                        help="Path to the npy data root.")
    parser.add_argument('--task_name', type=str, default='MedSAM-Lite-All')
    parser.add_argument("--pretrained_checkpoint",
                        type=str,
                        default=None,
                        help="Path to the pretrained Lite-MedSAM checkpoint.")
    # The value is used as a directory that is scanned for run folders
    # (listdir), not as a single checkpoint file.
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Directory holding '<task_name>-<YYYYmmdd-HHMM>' run folders "
        "to resume training from.")
    parser.add_argument(
        "--work_dir",
        type=str,
        default="./work_dir",
        help=
        "Path to the working directory where checkpoints and logs will be saved."
    )

    # ---- data / loader ---------------------------------------------------
    parser.add_argument('--data_aug',
                        action='store_true',
                        default=False,
                        help='use data augmentation during training')
    parser.add_argument("--num_epochs",
                        type=int,
                        default=25,
                        help="Number of epochs to train.")
    parser.add_argument("--batch_size",
                        type=int,
                        default=16,
                        help="Batch size.")
    parser.add_argument("--num_workers",
                        type=int,
                        default=8,
                        help="Number of workers for dataloader.")
    parser.add_argument(
        "--bbox_shift",
        type=int,
        default=5,
        help="Perturbation to bounding box coordinates during training.")

    # ---- optimisation / loss weights ---------------------------------------
    # The single-dash spellings ("-lr", ...) are kept so existing launch
    # scripts keep working. The "--" aliases match every other option. The
    # destination names (args.lr, args.weight_decay, ...) are unchanged
    # because argparse derives dest from the first "--" option string.
    parser.add_argument("-lr",
                        "--lr",
                        type=float,
                        default=2e-4,
                        help="Learning rate.")
    parser.add_argument("-weight_decay",
                        "--weight_decay",
                        type=float,
                        default=0.001,
                        help="Weight decay.")
    parser.add_argument("-iou_loss_weight",
                        "--iou_loss_weight",
                        type=float,
                        default=1.0,
                        help="Weight of IoU loss.")
    parser.add_argument("-seg_loss_weight",
                        "--seg_loss_weight",
                        type=float,
                        default=1.0,
                        help="Weight of segmentation loss.")
    parser.add_argument("-ce_loss_weight",
                        "--ce_loss_weight",
                        type=float,
                        default=1.0,
                        help="Weight of cross entropy loss.")

    # ---- misc ------------------------------------------------------------
    # "store_true" with default=True can never be switched off. The paired
    # "--no_sanity_check" flag writes False into the same destination.
    parser.add_argument("--sanity_check",
                        action="store_true",
                        default=True,
                        help="Whether to do sanity check for dataloading.")
    parser.add_argument("--no_sanity_check",
                        dest="sanity_check",
                        action="store_false",
                        help="Skip the dataloading sanity check.")

    return parser.parse_args(argv)
|
|
|
|
def show_mask(mask, ax, random_color=True):
    """Overlay a mask on an axis as a translucent RGBA layer.

    Args:
        mask: array whose last two dims are (H, W) and that holds exactly
            H*W elements, so it can be reshaped to (H, W, 1).
        ax: axis object exposing ``imshow``.
        random_color: when true, draw a random RGB colour (this consumes
            NumPy's global RNG); otherwise use a fixed yellow.
    """
    if not random_color:
        rgba = np.array([251 / 255, 252 / 255, 30 / 255, 0.45])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.45])], axis=0)
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
|
|
|
|
def show_box(box, ax):
    """Outline a bounding box on a matplotlib axis.

    Args:
        box: indexable ``(x0, y0, x1, y1)`` corner coordinates.
        ax: axis that receives the rectangle through ``add_patch``.
    """
    x_start, y_start, x_end, y_end = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_start, y_start),
        x_end - x_start,
        y_end - y_start,
        edgecolor='blue',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)
|
|
|
|
def show_points(points, ax, color='red', s=10):
    """Scatter each ``(x, y)`` point on ``ax`` as a separate marker.

    Args:
        points: iterable of ``(x, y)`` pairs; an empty iterable draws nothing.
        ax: axis object exposing ``scatter``.
        color: marker colour (default ``'red'``, as before).
        s: marker size (default ``10``, as before).
    """
    # The enumerate() index in the previous version was never used.
    for x, y in points:
        ax.scatter(x, y, color=color, s=s)
|
|
|
|
def cal_iou(result, reference, empty_value=1.0):
    """Per-sample intersection-over-union of two batches of binary masks.

    Args:
        result: boolean (or 0/1) tensor. Dim 0 is treated as the batch
            dimension and every remaining dim is reduced.
        reference: tensor combined element-wise with ``result`` via
            ``logical_and``/``logical_or`` (presumably the same shape —
            confirm at call sites).
        empty_value: IoU reported for samples where both masks are empty
            (union == 0). Previously such samples produced 0/0 = NaN, which
            propagated into the IoU regression loss.

    Returns:
        float tensor of shape ``(batch, 1)``.
    """
    reduce_dims = tuple(range(1, result.ndim))
    intersection = torch.count_nonzero(torch.logical_and(result, reference),
                                       dim=reduce_dims).float()
    union = torch.count_nonzero(torch.logical_or(result, reference),
                                dim=reduce_dims).float()
    # clamp() keeps the branch that torch.where discards finite as well.
    iou = torch.where(union > 0,
                      intersection / union.clamp(min=1.0),
                      torch.full_like(union, empty_value))
    return iou.unsqueeze(1)
|
|
|
|
def sanity_check_dataset(args):
    """Save a two-panel preview of one augmented training batch.

    Builds ``ModalityNpzDataset(args.data_root, data_aug=True)``, takes the
    first shuffled batch (batch size 8), and draws one sample from each half
    of that batch with its ground-truth mask, bounding box and image name.
    The figure is written to ``<args.work_dir>/Sanitycheck_DA.png``.

    Args:
        args: parsed command-line namespace; ``data_root`` and ``work_dir``
            are read.
    """
    tr_dataset = ModalityNpzDataset(args.data_root, data_aug=True)
    tr_dataloader = DataLoader(tr_dataset, batch_size=8, shuffle=True)

    # Only the first batch is inspected.
    batch = next(iter(tr_dataloader), None)
    if batch is None:
        print("Sanity check skipped: the training dataset is empty.")
        return

    # The run sub-directory is created later in __main__, so work_dir itself
    # may not exist yet when this runs.
    makedirs(args.work_dir, exist_ok=True)

    # Pick one sample from each half of the batch. The previous fixed ranges
    # randint(0, 4) / randint(4, 7) could both pick index 4, and they raised
    # IndexError when the batch held fewer than 8 samples.
    n = len(batch["image"])
    half = max(n // 2, 1)
    indices = (random.randint(0, half - 1),
               random.randint(min(half, n - 1), n - 1))

    _, axs = plt.subplots(1, 2, figsize=(10, 10))
    for ax, idx in zip(axs, indices):
        _plot_sample(ax, batch, idx)
    plt.subplots_adjust(wspace=0.01, hspace=0)
    plt.savefig(join(args.work_dir, 'Sanitycheck_DA.png'),
                bbox_inches='tight',
                dpi=300)
    plt.close()


def _plot_sample(ax, batch, idx):
    """Draw sample ``idx`` of ``batch`` (image, GT mask, box, name) on ``ax``."""
    ax.imshow(batch["image"][idx].cpu().permute(1, 2, 0).numpy())
    show_mask(batch["gt2D"][idx].cpu().squeeze().numpy(), ax)
    show_box(batch["bboxes"][idx].numpy().squeeze(), ax)
    ax.axis('off')
    ax.set_title(batch["image_name"][idx])
|
|
|
|
class MedSAM_Lite(nn.Module):
    """Lite MedSAM model: an image encoder, a prompt encoder and a mask decoder.

    Args:
        image_encoder: module applied to the input image.
        mask_decoder: module called with image and prompt embeddings; it must
            return five values (see ``forward``).
        prompt_encoder: module returning ``(sparse, dense)`` prompt embeddings
            and exposing ``get_dense_pe()``.
        encoder_weight_file: optional path to a state dict that is loaded into
            ``image_encoder``. Falsy values (the default ``None`` or ``""``)
            skip loading.
    """

    def __init__(
        self,
        image_encoder,
        mask_decoder,
        prompt_encoder,
        encoder_weight_file=None,
    ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        # The path used to be hard-coded to "" and always passed to
        # torch.load, so constructing the model raised unconditionally.
        # Loading is now opt-in. map_location="cpu" lets weights saved on a
        # GPU be read on any host; load_state_dict copies them into the
        # module's existing parameters.
        if encoder_weight_file:
            state_dict = torch.load(encoder_weight_file, map_location="cpu")
            self.image_encoder.load_state_dict(state_dict)

    def forward(self, image, points, boxes, masks, features, crops,
                text_features, category_idx):
        """Run image encoder -> prompt encoder -> single-mask decoder.

        Every prompt argument is forwarded unchanged, by keyword, to
        ``prompt_encoder``; ``None`` may be passed for prompts that are not
        used.

        Returns:
            ``(low_res_masks, iou_predictions, category_predictions,
            clip_vec, img_vec)`` exactly as returned by ``mask_decoder``,
            which is always called with ``multimask_output=False``.
        """
        image_embedding = self.image_encoder(image)

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=masks,
            features=features,
            crops=crops,
            text_features=text_features,
            category_idx=category_idx)

        low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec = self.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        return low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec

    @torch.no_grad()
    def postprocess_masks(self, masks, new_size, original_size):
        """
        Do cropping and resizing

        Keeps the top-left ``new_size`` region of ``masks`` (presumably
        removing padding added during preprocessing), then bilinearly
        resizes it to ``original_size``. ``masks`` must be 4-D
        ``(N, C, H, W)`` for the slicing and bilinear ``F.interpolate``
        below. Runs without gradient tracking.
        """
        masks = masks[:, :, :new_size[0], :new_size[1]]
        masks = F.interpolate(
            masks,
            size=(original_size[0], original_size[1]),
            mode="bilinear",
            align_corners=False,
        )

        return masks
|
|
|
|
def collate_fn(batch):
    """Merge a list of sample dicts into one batch dict for a DataLoader.

    ``"image_name"`` and ``"category_idx"`` are gathered into plain Python
    lists. Every other key is stacked with ``torch.stack`` along a new
    leading dimension. Keys are taken, in order, from the first sample.
    """
    list_keys = {"image_name", "category_idx"}
    collated = {}
    for key in batch[0]:
        values = [item[key] for item in batch]
        collated[key] = values if key in list_keys else torch.stack(values, dim=0)
    return collated
|
|
|
|
if __name__ == "__main__":
    args = get_args()
    # Previously the sanity check ran regardless of the parsed flag.
    if args.sanity_check:
        sanity_check_dataset(args)

    run_id = datetime.now().strftime("%Y%m%d-%H%M")
    print(f"Run ID: {run_id}")

    model_save_path = join(args.work_dir, args.task_name + "-" + run_id)
    makedirs(model_save_path, exist_ok=True)
    # Snapshot this script next to the checkpoints of the run.
    copyfile(__file__,
             join(model_save_path, run_id + "_" + os.path.basename(__file__)))

    device = torch.device("cuda")

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_workers = args.num_workers

    medsam_lite_image_encoder = TinyViT(
        img_size=256,
        in_chans=3,
        embed_dims=[
            64,
            128,
            160,
            320
        ],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.0,
        use_checkpoint=False,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=0.8)

    medsam_lite_prompt_encoder = PromptEncoder(embed_dim=256,
                                               image_embedding_size=(64, 64),
                                               input_image_size=(256, 256),
                                               mask_in_chans=16)

    medsam_lite_mask_decoder = MaskDecoder_F4(
        num_multimask_outputs=3,
        transformer=TwoWayTransformer(
            depth=2,
            embedding_dim=256,
            mlp_dim=2048,
            num_heads=8,
        ),
        modality=True,
        contents=True,
        transformer_dim=256,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    medsam_lite_model = MedSAM_Lite(image_encoder=medsam_lite_image_encoder,
                                    mask_decoder=medsam_lite_mask_decoder,
                                    prompt_encoder=medsam_lite_prompt_encoder)

    # Pretrained weights are only used for a fresh run; a resumed run loads
    # its own checkpoint below.
    if args.resume is None and args.pretrained_checkpoint is not None:
        print(
            f"Loading pretrained checkpoint from {args.pretrained_checkpoint}")
        medsam_lite_checkpoint = torch.load(args.pretrained_checkpoint,
                                            map_location="cpu")
        medsam_lite_model.load_state_dict(medsam_lite_checkpoint["model"],
                                          strict=True)

    medsam_lite_model = medsam_lite_model.to(device)
    medsam_lite_model.train()

    print(
        f"MedSAM Lite size: {sum(p.numel() for p in medsam_lite_model.parameters())}"
    )

    print('lr:', args.lr)

    optimizer = optim.AdamW(
        medsam_lite_model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.weight_decay,
    )
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                        mode='min',
                                                        factor=0.9,
                                                        patience=5,
                                                        cooldown=0)
    seg_loss = monai.losses.DiceLoss(sigmoid=True,
                                     squared_pred=True,
                                     reduction='mean')
    bce_loss = nn.BCEWithLogitsLoss(reduction='mean')
    iou_loss = nn.MSELoss(reduction='mean')
    ce_loss = nn.CrossEntropyLoss(reduction='mean')

    # NOTE(review): --data_aug is parsed but augmentation is always enabled
    # here; wiring the flag in would turn augmentation off by default —
    # confirm the intent before changing it.
    train_dataset = ModalityNpzDataset(data_root=args.data_root, data_aug=True)

    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True)

    if args.resume is not None:
        # args.resume is a directory of "<task_name>-<YYYYmmdd-HHMM>" run
        # folders; resume from the newest one holding a latest checkpoint.
        ckpt_folders = sorted(listdir(args.resume))
        ckpt_folders = [
            f for f in ckpt_folders
            if (f.startswith(args.task_name)
                and isfile(join(args.resume, f, 'medsam_lite_latest.pth')))
        ]
        print('*' * 20)
        print('existing ckpts in', args.resume, ckpt_folders)
        if not ckpt_folders:
            raise FileNotFoundError(
                f"No '{args.task_name}-<run_id>/medsam_lite_latest.pth' "
                f"checkpoint found in {args.resume}")

        time_strings = [
            f.split(args.task_name + '-')[-1] for f in ckpt_folders
        ]
        dates = [datetime.strptime(f, '%Y%m%d-%H%M') for f in time_strings]
        latest_date = max(dates)
        # Build the path from the directory that was actually searched
        # (previously args.work_dir, which may differ from args.resume).
        latest_ckpt = join(
            args.resume,
            args.task_name + '-' + latest_date.strftime('%Y%m%d-%H%M'),
            'medsam_lite_latest.pth')
        print('Loading from', latest_ckpt)
        checkpoint = torch.load(latest_ckpt, map_location=device)
        # The model is not wrapped in DataParallel/DDP, so it has no .module.
        medsam_lite_model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        # Older checkpoints were saved without scheduler state.
        if "lr_scheduler" in checkpoint:
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        start_epoch = checkpoint["epoch"] + 1
        # "loss" is only the last epoch's loss. "best_loss" is written before
        # that epoch is compared, so the true best is the minimum of the two.
        best_loss = min(checkpoint.get("best_loss", checkpoint["loss"]),
                        checkpoint["loss"])
        print(f"Loaded checkpoint from epoch {start_epoch}")
    else:
        start_epoch = 0
        best_loss = 1e10

    train_losses = []
    epoch_times = []

    print("Training")
    for epoch in range(start_epoch, num_epochs):
        # The final epoch runs at a fixed learning rate of 5e-5.
        if epoch == num_epochs - 1:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 5e-5

        epoch_loss = [1e10 for _ in range(len(train_loader))]
        epoch_start_time = time()
        pbar = tqdm(train_loader)
        for step, batch in enumerate(pbar):
            # NOTE(review): collecting garbage and emptying the CUDA cache on
            # every step is costly; kept as-is in case it guards against OOM.
            gc.collect()
            torch.cuda.empty_cache()
            image = batch["image"].to(device)
            gt2D = batch["gt2D"].to(device)
            boxes = batch["bboxes"].to(device)
            crops = batch["image_crop"].to(device)
            features = batch["image_feature"].to(device)
            text_features = batch["text_feature"].to(device)
            # as_tensor avoids torch.tensor's copy-construct warning when the
            # default collate has already produced a tensor.
            class_idx = torch.as_tensor(batch["category_idx"]).to(device)

            optimizer.zero_grad()
            # Point and mask prompts are not used: only boxes (plus the
            # feature/crop/text inputs) are passed to the model.
            logits_pred, iou_pred, category_predictions, clip_vec, img_vec = medsam_lite_model(
                image, None, boxes, None, features, crops, text_features,
                class_idx)

            # Symmetric contrastive loss between the two L2-normalised
            # embeddings: row i of each is the positive pair for row i of
            # the other. No temperature scaling is applied.
            clip_img_features = clip_vec / clip_vec.norm(dim=-1, keepdim=True)
            img_features = img_vec / img_vec.norm(dim=-1, keepdim=True)
            similarity1 = torch.matmul(clip_img_features, img_features.T)
            similarity2 = torch.matmul(img_features, clip_img_features.T)
            sim_labels = torch.arange(similarity1.shape[0]).to(image.device)

            l_seg = seg_loss(logits_pred, gt2D)
            l_bce = bce_loss(logits_pred, gt2D.float())
            l_ce_sim = 0.5 * (ce_loss(similarity1, sim_labels.long()) +
                              ce_loss(similarity2, sim_labels.long()))
            l_ce = ce_loss(category_predictions, class_idx.long())
            # The parsed loss weights used to be ignored. With their default
            # of 1.0 the total loss is unchanged. ce_loss_weight scales the
            # binary cross-entropy mask term.
            mask_loss = args.seg_loss_weight * l_seg + args.ce_loss_weight * l_bce
            with torch.no_grad():
                iou_gt = cal_iou(torch.sigmoid(logits_pred) > 0.5, gt2D.bool())
            l_iou = iou_loss(iou_pred, iou_gt)
            loss = (mask_loss + args.iou_loss_weight * l_iou +
                    0.01 * l_ce_sim + 0.01 * l_ce)
            epoch_loss[step] = loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            pbar.set_description(
                f"Epoch {epoch} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, loss: {loss.item():.4f}"
            )

        epoch_end_time = time()
        epoch_duration = epoch_end_time - epoch_start_time
        epoch_times.append(epoch_duration)

        epoch_loss_reduced = sum(epoch_loss) / len(epoch_loss)

        train_losses.append(epoch_loss_reduced)
        lr_scheduler.step(epoch_loss_reduced)

        model_weights = medsam_lite_model.state_dict()

        checkpoint = {
            "model": model_weights,
            "epoch": epoch,
            "optimizer": optimizer.state_dict(),
            # Saved so a resumed run keeps its plateau bookkeeping.
            "lr_scheduler": lr_scheduler.state_dict(),
            "loss": epoch_loss_reduced,
            "best_loss": best_loss,
        }
        torch.save(checkpoint, join(model_save_path, "medsam_lite_latest.pth"))

        if epoch_loss_reduced < best_loss:
            print(
                f"New best loss: {best_loss:.4f} -> {epoch_loss_reduced:.4f}")
            best_loss = epoch_loss_reduced
            checkpoint["best_loss"] = best_loss
            torch.save(checkpoint, join(model_save_path,
                                        "medsam_lite_best.pth"))

        # Re-plot the loss/duration curves after every epoch so the log is
        # up to date even if the run is interrupted.
        fig, axes = plt.subplots(2, 1, figsize=(10, 8))
        axes[0].title.set_text("Dice + Binary Cross Entropy + IoU Loss")
        axes[0].plot(train_losses)
        axes[0].set_ylabel("Loss")
        axes[1].plot(epoch_times)
        axes[1].title.set_text("Epoch Duration")
        axes[1].set_ylabel("Duration (s)")
        axes[1].set_xlabel("Epoch")
        plt.tight_layout()
        plt.savefig(join(model_save_path, "log.png"))
        plt.close()
|
|