MCP-MedSAM / train.py
Leo-Lyu's picture
Upload 13 files
1715fda verified
import os
import random
import monai
from os import listdir, makedirs
from os.path import join, exists, isfile, isdir, basename
from tqdm import tqdm
from time import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from datetime import datetime
from shutil import copyfile
from models import PromptEncoder, TwoWayTransformer, TinyViT, MaskDecoder_F4
import torch.nn.functional as F
import gc
from matplotlib import pyplot as plt
import argparse
from modality_npz_dataset import ModalityNpzDataset
torch.cuda.empty_cache()
os.environ["OMP_NUM_THREADS"] = "4" # export OMP_NUM_THREADS=4
os.environ["OPENBLAS_NUM_THREADS"] = "4" # export OPENBLAS_NUM_THREADS=4
os.environ["MKL_NUM_THREADS"] = "6" # export MKL_NUM_THREADS=6
os.environ["VECLIB_MAXIMUM_THREADS"] = "4" # export VECLIB_MAXIMUM_THREADS=4
os.environ["NUMEXPR_NUM_THREADS"] = "6" # export NUMEXPR_NUM_THREADS=6
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
setup_seed(2024)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--data_root",
type=str,
default="",
help="Path to the npy data root.")
parser.add_argument('--task_name', type=str, default='MedSAM-Lite-All')
parser.add_argument("--pretrained_checkpoint",
type=str,
default=None,
help="Path to the pretrained Lite-MedSAM checkpoint.")
parser.add_argument("--resume",
type=str,
default=None,
help="Path to the checkpoint to continue training.")
parser.add_argument(
"--work_dir",
type=str,
default="./work_dir",
help=
"Path to the working directory where checkpoints and logs will be saved."
)
parser.add_argument('--data_aug',
action='store_true',
default=False,
help='use data augmentation during training')
parser.add_argument("--num_epochs",
type=int,
default=25,
help="Number of epochs to train.")
parser.add_argument("--batch_size",
type=int,
default=16,
help="Batch size.")
parser.add_argument("--num_workers",
type=int,
default=8,
help="Number of workers for dataloader.")
parser.add_argument(
"--bbox_shift",
type=int,
default=5,
help="Perturbation to bounding box coordinates during training.")
parser.add_argument("-lr", type=float, default=2e-4, help="Learning rate.")
parser.add_argument("-weight_decay",
type=float,
default=0.001,
help="Weight decay.")
parser.add_argument("-iou_loss_weight",
type=float,
default=1.0,
help="Weight of IoU loss.")
parser.add_argument("-seg_loss_weight",
type=float,
default=1.0,
help="Weight of segmentation loss.")
parser.add_argument("-ce_loss_weight",
type=float,
default=1.0,
help="Weight of cross entropy loss.")
parser.add_argument("--sanity_check",
action="store_true",
default=True,
help="Whether to do sanity check for dataloading.")
args = parser.parse_args()
return args
def show_mask(mask, ax, random_color=True):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.45])], axis=0)
else:
color = np.array([251 / 255, 252 / 255, 30 / 255, 0.45])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(
plt.Rectangle((x0, y0),
w,
h,
edgecolor='blue',
facecolor=(0, 0, 0, 0),
lw=2))
def show_points(points, ax):
for i, (x, y) in enumerate(points):
ax.scatter(x, y, color='red', s=10)
def cal_iou(result, reference):
intersection = torch.count_nonzero(torch.logical_and(result, reference),
dim=[i for i in range(1, result.ndim)])
union = torch.count_nonzero(torch.logical_or(result, reference),
dim=[i for i in range(1, result.ndim)])
iou = intersection.float() / union.float()
return iou.unsqueeze(1)
def sanity_check_dataset(args):
tr_dataset = ModalityNpzDataset(args.data_root, data_aug=True)
tr_dataloader = DataLoader(tr_dataset, batch_size=8, shuffle=True)
for step, batch in enumerate(tr_dataloader):
# show the example
_, axs = plt.subplots(1, 2, figsize=(10, 10))
idx = random.randint(0, 4)
image = batch["image"]
gt = batch["gt2D"]
bboxes = batch["bboxes"]
names_temp = batch["image_name"]
axs[0].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
show_mask(gt[idx].cpu().squeeze().numpy(), axs[0])
show_box(bboxes[idx].numpy().squeeze(), axs[0])
axs[0].axis('off')
# set title
axs[0].set_title(names_temp[idx])
idx = random.randint(4, 7)
axs[1].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
show_mask(gt[idx].cpu().squeeze().numpy(), axs[1])
show_box(bboxes[idx].numpy().squeeze(), axs[1])
axs[1].axis('off')
# set title
axs[1].set_title(names_temp[idx])
plt.subplots_adjust(wspace=0.01, hspace=0)
plt.savefig(join(args.work_dir, 'Sanitycheck_DA.png'),
bbox_inches='tight',
dpi=300)
plt.close()
break
class MedSAM_Lite(nn.Module):
def __init__(
self,
image_encoder,
mask_decoder,
prompt_encoder,
):
super().__init__()
self.image_encoder = image_encoder
self.mask_decoder = mask_decoder
self.prompt_encoder = prompt_encoder
encoder_weight_file = "" # path for vision encoder (tiny vit) weights
self.image_encoder.load_state_dict(torch.load(encoder_weight_file))
def forward(self, image, points, boxes, masks, features, crops,
text_features, category_idx):
image_embedding = self.image_encoder(image)
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=points,
boxes=boxes,
masks=masks,
features=features,
crops=crops,
text_features=text_features,
category_idx=category_idx)
low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec = self.mask_decoder(
image_embeddings=image_embedding, # (B, 256, 64, 64)
image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
multimask_output=False,
) # (B, 1, 256, 256)
return low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec
@torch.no_grad()
def postprocess_masks(self, masks, new_size, original_size):
"""
Do cropping and resizing
"""
# Crop
masks = masks[:, :, :new_size[0], :new_size[1]]
# Resize
masks = F.interpolate(
masks,
size=(original_size[0], original_size[1]),
mode="bilinear",
align_corners=False,
)
return masks
def collate_fn(batch):
"""
Collate function for PyTorch DataLoader.
"""
batch_dict = {}
for key in batch[0].keys():
if key == "image_name" or key == "category_idx":
batch_dict[key] = [sample[key] for sample in batch]
else:
batch_dict[key] = torch.stack([sample[key] for sample in batch],
dim=0)
return batch_dict
if __name__ == "__main__":
args = get_args()
sanity_check_dataset(args)
run_id = datetime.now().strftime("%Y%m%d-%H%M")
print(f"Run ID: {run_id}")
model_save_path = join(args.work_dir, args.task_name + "-" + run_id)
makedirs(model_save_path, exist_ok=True)
copyfile(__file__,
join(model_save_path, run_id + "_" + os.path.basename(__file__)))
device = torch.device("cuda")
num_epochs = args.num_epochs
batch_size = args.batch_size
num_workers = args.num_workers
medsam_lite_image_encoder = TinyViT(
img_size=256,
in_chans=3,
embed_dims=[
64, ## (64, 256, 256)
128, ## (128, 128, 128)
160, ## (160, 64, 64)
320 ## (320, 64, 64)
],
depths=[2, 2, 6, 2],
num_heads=[2, 4, 5, 10],
window_sizes=[7, 7, 14, 7],
mlp_ratio=4.,
drop_rate=0.,
drop_path_rate=0.0,
use_checkpoint=False,
mbconv_expand_ratio=4.0,
local_conv_size=3,
layer_lr_decay=0.8)
medsam_lite_prompt_encoder = PromptEncoder(embed_dim=256,
image_embedding_size=(64, 64),
input_image_size=(256, 256),
mask_in_chans=16)
medsam_lite_mask_decoder = MaskDecoder_F4(
num_multimask_outputs=3,
transformer=TwoWayTransformer(
depth=2,
embedding_dim=256,
mlp_dim=2048,
num_heads=8,
),
modality=True,
contents=True,
transformer_dim=256,
iou_head_depth=3,
iou_head_hidden_dim=256,
)
medsam_lite_model = MedSAM_Lite(image_encoder=medsam_lite_image_encoder,
mask_decoder=medsam_lite_mask_decoder,
prompt_encoder=medsam_lite_prompt_encoder)
if args.resume is None and args.pretrained_checkpoint is not None:
## Load pretrained checkpoint if there's no checkpoint to resume from and there's a pretrained checkpoint
print(
f"Loading pretrained checkpoint from {args.pretrained_checkpoint}")
medsam_lite_checkpoint = torch.load(args.pretrained_checkpoint,
map_location="cpu")
medsam_lite_model.load_state_dict(medsam_lite_checkpoint["model"],
strict=True)
medsam_lite_model = medsam_lite_model.to(device)
medsam_lite_model.train()
print(
f"MedSAM Lite size: {sum(p.numel() for p in medsam_lite_model.parameters())}"
)
print('lr:', args.lr)
optimizer = optim.AdamW(
medsam_lite_model.parameters(),
lr=args.lr,
betas=(0.9, 0.999),
eps=1e-08,
weight_decay=args.weight_decay,
)
lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
mode='min',
factor=0.9,
patience=5,
cooldown=0)
seg_loss = monai.losses.DiceLoss(sigmoid=True,
squared_pred=True,
reduction='mean')
bce_loss = nn.BCEWithLogitsLoss(reduction='mean')
iou_loss = nn.MSELoss(reduction='mean')
ce_loss = nn.CrossEntropyLoss(reduction='mean')
train_dataset = ModalityNpzDataset(data_root=args.data_root, data_aug=True)
train_loader = DataLoader(train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=True)
if args.resume is not None:
ckpt_folders = sorted(listdir(args.resume))
ckpt_folders = [
f for f in ckpt_folders
if (f.startswith(args.task_name)
and isfile(join(args.resume, f, 'medsam_lite_latest.pth')))
]
print('*' * 20)
print('existing ckpts in', args.resume, ckpt_folders)
# find the latest ckpt folders
time_strings = [
f.split(args.task_name + '-')[-1] for f in ckpt_folders
]
dates = [datetime.strptime(f, '%Y%m%d-%H%M') for f in time_strings]
latest_date = max(dates)
latest_ckpt = join(
args.work_dir,
args.task_name + '-' + latest_date.strftime('%Y%m%d-%H%M'),
'medsam_lite_latest.pth')
print('Loading from', latest_ckpt)
checkpoint = torch.load(latest_ckpt, map_location=device)
medsam_lite_model.module.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
start_epoch = checkpoint["epoch"] + 1
best_loss = checkpoint["loss"]
print(f"Loaded checkpoint from epoch {start_epoch}")
else:
start_epoch = 0
best_loss = 1e10
train_losses = []
epoch_times = []
print("Training")
for epoch in range(start_epoch, num_epochs):
if epoch == num_epochs - 1:
for param_group in optimizer.param_groups:
param_group['lr'] = 5e-5
epoch_loss = [1e10 for _ in range(len(train_loader))]
epoch_start_time = time()
pbar = tqdm(train_loader)
for step, batch in enumerate(pbar):
gc.collect()
torch.cuda.empty_cache()
image = batch["image"]
gt2D = batch["gt2D"]
boxes = batch["bboxes"]
coords = batch["coords"]
crops = batch["image_crop"]
features = batch["image_feature"]
text_features = batch["text_feature"]
class_idx = batch["category_idx"]
class_idx = torch.tensor(class_idx)
optimizer.zero_grad()
image, gt2D, boxes, coords, crops, features, text_features, class_idx = image.to(
device), gt2D.to(device), boxes.to(device), coords.to(
device), crops.to(device), features.to(
device), text_features.to(device), class_idx.to(device)
labels_torch = torch.ones(coords.shape[0]).long()
labels_torch = labels_torch.unsqueeze(1).expand(-1, 4)
labels_torch = labels_torch.to(device)
point_prompt = (coords, labels_torch)
logits_pred, iou_pred, category_predictions, clip_vec, img_vec = medsam_lite_model(
image, None, boxes, None, features, crops, text_features, class_idx)
clip_img_features = clip_vec / clip_vec.norm(dim=-1, keepdim=True)
img_features = img_vec / img_vec.norm(dim=-1, keepdim=True)
similarity1 = torch.matmul(clip_img_features, img_features.T)
similarity2 = torch.matmul(img_features, clip_img_features.T)
sim_labels = torch.arange(similarity1.shape[0]).to(image.device)
l_seg = seg_loss(logits_pred, gt2D)
l_bce = bce_loss(logits_pred, gt2D.float())
l_ce_sim = 0.5 * (ce_loss(similarity1, sim_labels.long()) +
ce_loss(similarity2, sim_labels.long()))
l_ce = ce_loss(category_predictions, class_idx.long())
mask_loss = l_seg + l_bce
with torch.no_grad():
iou_gt = cal_iou(torch.sigmoid(logits_pred) > 0.5, gt2D.bool())
l_iou = iou_loss(iou_pred, iou_gt)
loss = mask_loss + l_iou + 0.01 * l_ce_sim + 0.01 * l_ce
epoch_loss[step] = loss.item()
loss.backward()
optimizer.step()
optimizer.zero_grad()
pbar.set_description(
f"Epoch {epoch} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, loss: {loss.item():.4f}"
)
epoch_end_time = time()
epoch_duration = epoch_end_time - epoch_start_time
epoch_times.append(epoch_duration)
epoch_loss_reduced = sum(epoch_loss) / len(epoch_loss)
train_losses.append(epoch_loss_reduced)
lr_scheduler.step(epoch_loss_reduced)
model_weights = medsam_lite_model.state_dict()
checkpoint = {
"model": model_weights,
"epoch": epoch,
"optimizer": optimizer.state_dict(),
"loss": epoch_loss_reduced,
"best_loss": best_loss,
}
torch.save(checkpoint, join(model_save_path, "medsam_lite_latest.pth"))
if epoch_loss_reduced < best_loss:
print(
f"New best loss: {best_loss:.4f} -> {epoch_loss_reduced:.4f}")
best_loss = epoch_loss_reduced
checkpoint["best_loss"] = best_loss
torch.save(checkpoint, join(model_save_path,
"medsam_lite_best.pth"))
epoch_loss_reduced = 1e10
fig, axes = plt.subplots(2, 1, figsize=(10, 8))
axes[0].title.set_text("Dice + Binary Cross Entropy + IoU Loss")
axes[0].plot(train_losses)
axes[0].set_ylabel("Loss")
axes[1].plot(epoch_times)
axes[1].title.set_text("Epoch Duration")
axes[1].set_ylabel("Duration (s)")
axes[1].set_xlabel("Epoch")
plt.tight_layout()
plt.savefig(join(model_save_path, "log.png"))
plt.close()