Spaces:
No application file
No application file
| # -*- coding: utf-8 -*- | |
| """ | |
| train the image encoder and mask decoder | |
| freeze prompt image encoder | |
| """ | |
| # %% setup environment | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import os | |
| join = os.path.join | |
| from tqdm import tqdm | |
| from skimage import transform | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader | |
| import monai | |
| from segment_anything import sam_model_registry | |
| import torch.nn.functional as F | |
| import argparse | |
| import random | |
| from datetime import datetime | |
| import shutil | |
| import glob | |
# set seeds
# Seeds torch's CPU/GPU RNG for reproducibility of model init and dataloader
# shuffling. NOTE(review): `random` and `numpy` RNGs are NOT seeded here, so
# the bbox jitter / label choice in NpyDataset stays non-deterministic —
# confirm this is intended.
torch.manual_seed(2023)
torch.cuda.empty_cache()
# torch.distributed.init_process_group(backend="gloo")
# Cap the BLAS/OpenMP thread pools so numpy/torch CPU ops don't oversubscribe
# cores when combined with DataLoader workers.
os.environ["OMP_NUM_THREADS"] = "4"  # export OMP_NUM_THREADS=4
os.environ["OPENBLAS_NUM_THREADS"] = "4"  # export OPENBLAS_NUM_THREADS=4
os.environ["MKL_NUM_THREADS"] = "6"  # export MKL_NUM_THREADS=6
os.environ["VECLIB_MAXIMUM_THREADS"] = "4"  # export VECLIB_MAXIMUM_THREADS=4
os.environ["NUMEXPR_NUM_THREADS"] = "6"  # export NUMEXPR_NUM_THREADS=6
def show_mask(mask, ax, random_color=False):
    """Overlay a binary mask on *ax* as a translucent RGBA image.

    mask: array whose last two dimensions are (H, W); nonzero pixels are shown.
    ax: any object exposing ``imshow`` (e.g. a matplotlib Axes).
    random_color: if True, use a random RGB with alpha 0.6 instead of the
        fixed yellow.
    """
    if random_color:
        rgba = np.append(np.random.random(3), 0.6)
    else:
        rgba = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    height, width = mask.shape[-2:]
    # broadcast the (H, W, 1) mask against the (1, 1, 4) color
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_box(box, ax):
    """Draw an unfilled blue rectangle on *ax*.

    box: (x_min, y_min, x_max, y_max) in pixel coordinates.
    ax: a matplotlib Axes (uses ``add_patch``).
    """
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    rect = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor="blue",
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(rect)
class NpyDataset(Dataset):
    """Paired image / ground-truth dataset stored as .npy files.

    Expects ``data_root`` to contain two subfolders:
      imgs/ : (1024, 1024, 3) float images normalized to [0, 1]
      gts/  : (H, W) integer label maps, 0 = background
    Only ground-truth files with a same-named image counterpart are kept.
    """

    def __init__(self, data_root, bbox_shift=20):
        self.data_root = data_root
        self.gt_path = os.path.join(data_root, "gts")
        self.img_path = os.path.join(data_root, "imgs")
        self.gt_path_files = sorted(
            glob.glob(os.path.join(self.gt_path, "**/*.npy"), recursive=True)
        )
        # drop gt files whose image counterpart is missing
        self.gt_path_files = [
            file
            for file in self.gt_path_files
            if os.path.isfile(os.path.join(self.img_path, os.path.basename(file)))
        ]
        # max random perturbation (pixels) applied to each bbox edge
        self.bbox_shift = bbox_shift
        print(f"number of images: {len(self.gt_path_files)}")

    def __len__(self):
        return len(self.gt_path_files)

    def __getitem__(self, index):
        """Return (image (3,1024,1024) float, gt2D (1,H,W) long,
        bbox (4,) float xyxy, filename) for one sample."""
        img_name = os.path.basename(self.gt_path_files[index])
        # load npy image (1024, 1024, 3), values in [0, 1]; "r" -> mmap read
        img_1024 = np.load(
            os.path.join(self.img_path, img_name), "r", allow_pickle=True
        )
        # convert the shape to channel-first (3, H, W)
        img_1024 = np.transpose(img_1024, (2, 0, 1))
        assert (
            np.max(img_1024) <= 1.0 and np.min(img_1024) >= 0.0
        ), "image should be normalized to [0, 1]"
        # multi-label map, e.g. ids [0, 1, 4, 5, ...], shape (H, W)
        gt = np.load(
            self.gt_path_files[index], "r", allow_pickle=True
        )
        # BUGFIX: removed a tautological name assert whose failure message
        # referenced the nonexistent attribute `self.npy_files` (it would have
        # raised AttributeError instead of AssertionError).
        # drop the background id (assumes 0 is present and is background)
        label_ids = np.unique(gt)[1:]
        if len(label_ids) == 0:
            # was an opaque IndexError from random.choice on an empty list
            raise ValueError(
                f"no foreground labels in {self.gt_path_files[index]}"
            )
        # pick one random label -> binary mask (H, W)
        gt2D = np.uint8(gt == random.choice(label_ids.tolist()))
        assert np.max(gt2D) == 1 and np.min(gt2D) == 0.0, "ground truth should be 0, 1"
        y_indices, x_indices = np.where(gt2D > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # add perturbation to bounding box coordinates (data augmentation)
        H, W = gt2D.shape
        x_min = max(0, x_min - random.randint(0, self.bbox_shift))
        x_max = min(W, x_max + random.randint(0, self.bbox_shift))
        y_min = max(0, y_min - random.randint(0, self.bbox_shift))
        y_max = min(H, y_max + random.randint(0, self.bbox_shift))
        bboxes = np.array([x_min, y_min, x_max, y_max])
        return (
            torch.tensor(img_1024).float(),
            torch.tensor(gt2D[None, :, :]).long(),
            torch.tensor(bboxes).float(),
            img_name,
        )
# %% sanity test of dataset class
# NOTE(review): this block runs at import time, before the argparse section
# below, and uses a hard-coded path rather than args.tr_npy_path — confirm
# this notebook-style check is intended to stay in the training script.
tr_dataset = NpyDataset("data/npy/CT_Abd")
tr_dataloader = DataLoader(tr_dataset, batch_size=8, shuffle=True)
for step, (image, gt, bboxes, names_temp) in enumerate(tr_dataloader):
    print(image.shape, gt.shape, bboxes.shape)
    # show the example: two random samples from the first batch, side by side
    _, axs = plt.subplots(1, 2, figsize=(25, 25))
    # NOTE(review): assumes the batch holds 8 samples; a dataset smaller than
    # 8 would make this index go out of range.
    idx = random.randint(0, 7)
    axs[0].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
    show_mask(gt[idx].cpu().numpy(), axs[0])
    show_box(bboxes[idx].numpy(), axs[0])
    axs[0].axis("off")
    # set title
    axs[0].set_title(names_temp[idx])
    idx = random.randint(0, 7)
    axs[1].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
    show_mask(gt[idx].cpu().numpy(), axs[1])
    show_box(bboxes[idx].numpy(), axs[1])
    axs[1].axis("off")
    # set title
    axs[1].set_title(names_temp[idx])
    # plt.show()
    plt.subplots_adjust(wspace=0.01, hspace=0)
    plt.savefig("./data_sanitycheck.png", bbox_inches="tight", dpi=300)
    plt.close()
    # one batch is enough for the sanity check
    break
# %% set up parser
parser = argparse.ArgumentParser()
parser.add_argument(
    "-i",
    "--tr_npy_path",
    type=str,
    default="data/npy/CT_Abd",
    help="path to training npy files; two subfolders: gts and imgs",
)
parser.add_argument("-task_name", type=str, default="MedSAM-ViT-B")
parser.add_argument("-model_type", type=str, default="vit_b")
parser.add_argument(
    "-checkpoint", type=str, default="work_dir/SAM/sam_vit_b_01ec64.pth"
)
# parser.add_argument('-device', type=str, default='cuda:0')
# NOTE(review): argparse `type=bool` does not parse "False" — bool("False")
# is True, so any non-empty value enables the flag. Left unchanged to keep
# the CLI stable; consider action="store_true"/"store_false" instead.
parser.add_argument(
    "--load_pretrain", type=bool, default=True, help="load pretrain model"
)
parser.add_argument("-pretrain_model_path", type=str, default="")
parser.add_argument("-work_dir", type=str, default="./work_dir")
# train
parser.add_argument("-num_epochs", type=int, default=1000)
parser.add_argument("-batch_size", type=int, default=2)
parser.add_argument("-num_workers", type=int, default=0)
# Optimizer parameters
parser.add_argument(
    "-weight_decay", type=float, default=0.01, help="weight decay (default: 0.01)"
)
parser.add_argument(
    "-lr", type=float, default=0.0001, metavar="LR", help="learning rate (absolute lr)"
)
# NOTE(review): same `type=bool` caveat as --load_pretrain above.
parser.add_argument(
    "-use_wandb", type=bool, default=False, help="use wandb to monitor training"
)
parser.add_argument("-use_amp", action="store_true", default=False, help="use amp")
parser.add_argument(
    "--resume", type=str, default="", help="Resuming training from checkpoint"
)
parser.add_argument("--device", type=str, default="cuda:0")
args = parser.parse_args()
# wandb is imported lazily so it is only required when monitoring is enabled
if args.use_wandb:
    import wandb

    wandb.login()
    wandb.init(
        project=args.task_name,
        config={
            "lr": args.lr,
            "batch_size": args.batch_size,
            "data_path": args.tr_npy_path,
            "model_type": args.model_type,
        },
    )
# %% set up model for training
# device = args.device
# checkpoints and a copy of this script go to work_dir/<task>-<timestamp>
run_id = datetime.now().strftime("%Y%m%d-%H%M")
model_save_path = join(args.work_dir, args.task_name + "-" + run_id)
device = torch.device(args.device)
# %% set up model
class MedSAM(nn.Module):
    """SAM wrapper that fine-tunes the image encoder and mask decoder while
    keeping the prompt encoder frozen."""

    def __init__(
        self,
        image_encoder,
        mask_decoder,
        prompt_encoder,
    ):
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        # the prompt encoder stays fixed throughout training
        for p in self.prompt_encoder.parameters():
            p.requires_grad = False

    def forward(self, image, box):
        """Predict one mask per (image, box) pair.

        image: (B, 3, H, W) float tensor.
        box: (B, 4) or (B, 1, 4) xyxy prompt boxes (any array-like).
        Returns mask logits upsampled to the input size, (B, 1, H, W).
        """
        img_embed = self.image_encoder(image)  # (B, 256, 64, 64)
        # prompt embeddings carry no gradients (encoder is frozen)
        with torch.no_grad():
            boxes = torch.as_tensor(box, dtype=torch.float32, device=image.device)
            if boxes.dim() == 2:
                boxes = boxes.unsqueeze(1)  # (B, 1, 4)
            sparse_emb, dense_emb = self.prompt_encoder(
                points=None,
                boxes=boxes,
                masks=None,
            )
        low_res_logits, _ = self.mask_decoder(
            image_embeddings=img_embed,  # (B, 256, 64, 64)
            image_pe=self.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_emb,  # (B, 2, 256)
            dense_prompt_embeddings=dense_emb,  # (B, 256, 64, 64)
            multimask_output=False,
        )
        # upsample the low-resolution logits back to the input resolution
        return F.interpolate(
            low_res_logits,
            size=image.shape[2:],
            mode="bilinear",
            align_corners=False,
        )
def main():
    """Fine-tune SAM's image encoder and mask decoder on npy image/gt pairs.

    Reads configuration from the module-level ``args``; writes checkpoints,
    a copy of this script, and a loss curve under ``model_save_path``.
    """
    os.makedirs(model_save_path, exist_ok=True)
    # keep a copy of the training script next to the checkpoints
    shutil.copyfile(
        __file__, join(model_save_path, run_id + "_" + os.path.basename(__file__))
    )
    sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    medsam_model = MedSAM(
        image_encoder=sam_model.image_encoder,
        mask_decoder=sam_model.mask_decoder,
        prompt_encoder=sam_model.prompt_encoder,
    ).to(device)
    medsam_model.train()
    print(
        "Number of total parameters: ",
        sum(p.numel() for p in medsam_model.parameters()),
    )  # 93735472
    print(
        "Number of trainable parameters: ",
        sum(p.numel() for p in medsam_model.parameters() if p.requires_grad),
    )  # 93729252
    # only the image encoder and mask decoder are optimized (prompt encoder frozen)
    img_mask_encdec_params = list(medsam_model.image_encoder.parameters()) + list(
        medsam_model.mask_decoder.parameters()
    )
    optimizer = torch.optim.AdamW(
        img_mask_encdec_params, lr=args.lr, weight_decay=args.weight_decay
    )
    print(
        "Number of image encoder and mask decoder parameters: ",
        sum(p.numel() for p in img_mask_encdec_params if p.requires_grad),
    )  # 93729252
    # Dice loss (on sigmoid probabilities) + BCE (on logits), summed
    seg_loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, reduction="mean")
    # cross entropy loss
    ce_loss = nn.BCEWithLogitsLoss(reduction="mean")
    # %% train
    num_epochs = args.num_epochs
    iter_num = 0
    losses = []
    best_loss = 1e10
    train_dataset = NpyDataset(args.tr_npy_path)
    print("Number of training samples: ", len(train_dataset))
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    start_epoch = 0
    # BUGFIX: was `if args.resume is not None` — always true since the
    # default is "". Resume only when a real checkpoint path is given.
    if args.resume and os.path.isfile(args.resume):
        ## Map model to be loaded to specified single GPU
        checkpoint = torch.load(args.resume, map_location=device)
        start_epoch = checkpoint["epoch"] + 1
        medsam_model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
    if args.use_amp:
        scaler = torch.cuda.amp.GradScaler()
    for epoch in range(start_epoch, num_epochs):
        epoch_loss = 0
        num_batches = 0
        for step, (image, gt2D, boxes, _) in enumerate(tqdm(train_dataloader)):
            optimizer.zero_grad()
            boxes_np = boxes.detach().cpu().numpy()
            image, gt2D = image.to(device), gt2D.to(device)
            if args.use_amp:
                ## AMP: forward in fp16, scale the loss for the backward pass
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    medsam_pred = medsam_model(image, boxes_np)
                    loss = seg_loss(medsam_pred, gt2D) + ce_loss(
                        medsam_pred, gt2D.float()
                    )
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                medsam_pred = medsam_model(image, boxes_np)
                loss = seg_loss(medsam_pred, gt2D) + ce_loss(medsam_pred, gt2D.float())
                loss.backward()
                optimizer.step()
            # (removed a redundant optimizer.zero_grad() here — gradients are
            # already cleared at the top of every iteration)
            epoch_loss += loss.item()
            iter_num += 1
            num_batches += 1
        # BUGFIX: was `epoch_loss /= step`, which divides by the last batch
        # index (off by one) and raises ZeroDivisionError with a single batch.
        epoch_loss /= max(num_batches, 1)
        losses.append(epoch_loss)
        if args.use_wandb:
            wandb.log({"epoch_loss": epoch_loss})
        print(
            f'Time: {datetime.now().strftime("%Y%m%d-%H%M")}, Epoch: {epoch}, Loss: {epoch_loss}'
        )
        ## save the latest model
        checkpoint = {
            "model": medsam_model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        }
        torch.save(checkpoint, join(model_save_path, "medsam_model_latest.pth"))
        ## save the best model (same state as the latest checkpoint just built)
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(checkpoint, join(model_save_path, "medsam_model_best.pth"))
    # %% plot loss
    plt.plot(losses)
    plt.title("Dice + Cross Entropy Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.savefig(join(model_save_path, args.task_name + "train_loss.png"))
    plt.close()


if __name__ == "__main__":
    main()