|
|
import os |
|
|
import random |
|
|
import monai |
|
|
from os import listdir, makedirs |
|
|
from os.path import join, exists, isfile, isdir, basename |
|
|
from tqdm import tqdm |
|
|
from time import time |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import DataLoader |
|
|
from datetime import datetime |
|
|
from shutil import copyfile |
|
|
from models import PromptEncoder, TwoWayTransformer, TinyViT, MaskDecoder_F4 |
|
|
import torch.nn.functional as F |
|
|
import gc |
|
|
from matplotlib import pyplot as plt |
|
|
import argparse |
|
|
from modality_npz_dataset import ModalityNpzDataset |
|
|
|
|
|
# Free any cached GPU memory left over from a previous run in this process.
torch.cuda.empty_cache()

# Cap the thread pools of the common BLAS/OpenMP backends so dataloader
# workers and the main process do not oversubscribe the CPU.
for _env_var, _n_threads in [
    ("OMP_NUM_THREADS", "4"),
    ("OPENBLAS_NUM_THREADS", "4"),
    ("MKL_NUM_THREADS", "6"),
    ("VECLIB_MAXIMUM_THREADS", "4"),
    ("NUMEXPR_NUM_THREADS", "6"),
]:
    os.environ[_env_var] = _n_threads
|
|
|
|
def setup_seed(seed):
    """Seed every RNG used by this script (python, numpy, torch CPU and GPU)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
|
# Fix all RNG seeds at import time so runs are reproducible.
setup_seed(2024)
|
|
|
|
|
def get_args():
    """Build the CLI parser for this training script and parse sys.argv.

    Returns the parsed argparse.Namespace with paths, training
    hyper-parameters, and loss-weight options.
    """
    parser = argparse.ArgumentParser()
    # Paths and run identification.
    parser.add_argument("--data_root", type=str, default="",
                        help="Path to the npy data root.")
    parser.add_argument('--task_name', type=str, default='MedSAM-Lite-All')
    parser.add_argument("--pretrained_checkpoint", type=str, default=None,
                        help="Path to the pretrained Lite-MedSAM checkpoint.")
    parser.add_argument("--resume", type=str, default=None,
                        help="Path to the checkpoint to continue training.")
    parser.add_argument(
        "--work_dir", type=str, default="./work_dir",
        help="Path to the working directory where checkpoints and logs will be saved.")
    # Training configuration.
    parser.add_argument('--data_aug', action='store_true', default=False,
                        help='use data augmentation during training')
    parser.add_argument("--num_epochs", type=int, default=25,
                        help="Number of epochs to train.")
    parser.add_argument("--batch_size", type=int, default=16,
                        help="Batch size.")
    parser.add_argument("--num_workers", type=int, default=8,
                        help="Number of workers for dataloader.")
    parser.add_argument("--bbox_shift", type=int, default=5,
                        help="Perturbation to bounding box coordinates during training.")
    # Optimizer and loss weights (single-dash options, per original style).
    parser.add_argument("-lr", type=float, default=2e-4, help="Learning rate.")
    parser.add_argument("-weight_decay", type=float, default=0.001,
                        help="Weight decay.")
    parser.add_argument("-iou_loss_weight", type=float, default=1.0,
                        help="Weight of IoU loss.")
    parser.add_argument("-seg_loss_weight", type=float, default=1.0,
                        help="Weight of segmentation loss.")
    parser.add_argument("-ce_loss_weight", type=float, default=1.0,
                        help="Weight of cross entropy loss.")
    # NOTE(review): store_true with default=True means this flag can never be
    # turned off from the CLI — confirm whether default should be False.
    parser.add_argument("--sanity_check", action="store_true", default=True,
                        help="Whether to do sanity check for dataloading.")
    return parser.parse_args()
|
|
|
|
|
|
|
|
def show_mask(mask, ax, random_color=True):
    """Overlay a binary mask on *ax* as a translucent RGBA image.

    When random_color is True the RGB part is drawn from np.random; otherwise
    a fixed yellow is used. Alpha is 0.45 in both cases.
    """
    if random_color:
        rgba = np.concatenate([np.random.random(3), np.array([0.45])], axis=0)
    else:
        rgba = np.array([251 / 255, 252 / 255, 30 / 255, 0.45])
    height, width = mask.shape[-2:]
    ax.imshow(mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1))
|
|
|
|
|
|
|
|
def show_box(box, ax):
    """Draw an (x0, y0, x1, y1) bounding box on *ax* as a blue outline."""
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    rect = plt.Rectangle((x0, y0),
                         x1 - x0,
                         y1 - y0,
                         edgecolor='blue',
                         facecolor=(0, 0, 0, 0),
                         lw=2)
    ax.add_patch(rect)
|
|
|
|
|
|
|
|
def show_points(points, ax):
    """Scatter-plot prompt points on *ax* as small red dots.

    Args:
        points: iterable of (x, y) pairs.
        ax: matplotlib axes to draw on.
    """
    # Idiom fix: the original used enumerate() but never used the index.
    for x, y in points:
        ax.scatter(x, y, color='red', s=10)
|
|
|
|
|
|
|
|
def cal_iou(result, reference):
    """Per-sample IoU between two boolean mask tensors.

    Non-zero counts are reduced over every axis except the batch axis, and the
    ratio is returned with a trailing singleton dim, shape (B, 1).
    NOTE(review): a sample with an empty union yields 0/0 = nan — confirm
    callers never hit that case.
    """
    reduce_axes = list(range(1, result.ndim))
    overlap = torch.count_nonzero(torch.logical_and(result, reference),
                                  dim=reduce_axes)
    total = torch.count_nonzero(torch.logical_or(result, reference),
                                dim=reduce_axes)
    return (overlap.float() / total.float()).unsqueeze(1)
|
|
|
|
|
|
|
|
def sanity_check_dataset(args):
    """Visual sanity check of the dataloading/augmentation pipeline.

    Loads one augmented batch of 8, draws two random samples (one from each
    half of the batch) with their masks and boxes, and saves the figure to
    <work_dir>/Sanitycheck_DA.png.
    """
    dataset = ModalityNpzDataset(args.data_root, data_aug=True)
    loader = DataLoader(dataset, batch_size=8, shuffle=True)

    for batch in loader:
        _, axs = plt.subplots(1, 2, figsize=(10, 10))
        first = random.randint(0, 4)

        images = batch["image"]
        masks = batch["gt2D"]
        boxes = batch["bboxes"]
        names = batch["image_name"]

        def draw_panel(ax, sample):
            # Image with its ground-truth mask and prompt box overlaid.
            ax.imshow(images[sample].cpu().permute(1, 2, 0).numpy())
            show_mask(masks[sample].cpu().squeeze().numpy(), ax)
            show_box(boxes[sample].numpy().squeeze(), ax)
            ax.axis('off')
            ax.set_title(names[sample])

        draw_panel(axs[0], first)
        second = random.randint(4, 7)
        draw_panel(axs[1], second)

        plt.subplots_adjust(wspace=0.01, hspace=0)
        plt.savefig(join(args.work_dir, 'Sanitycheck_DA.png'),
                    bbox_inches='tight',
                    dpi=300)
        plt.close()
        # One batch is enough for the check.
        break
|
|
|
|
|
|
|
|
class MedSAM_Lite(nn.Module):
    """Lite MedSAM model: image encoder + prompt encoder + mask decoder."""

    def __init__(
        self,
        image_encoder,
        mask_decoder,
        prompt_encoder,
        encoder_weight_file="",
    ):
        """
        Args:
            image_encoder: backbone producing image embeddings.
            mask_decoder: decoder producing masks, IoU and category outputs.
            prompt_encoder: encoder for point/box/mask/feature/text prompts.
            encoder_weight_file: optional path to a state dict for the image
                encoder; loading is skipped when empty.
        """
        super().__init__()
        self.image_encoder = image_encoder
        self.mask_decoder = mask_decoder
        self.prompt_encoder = prompt_encoder
        # Bug fix: the original hard-coded encoder_weight_file = "" and called
        # torch.load("") unconditionally, which raises on construction. The
        # path is now a parameter and only loaded when provided.
        if encoder_weight_file:
            self.image_encoder.load_state_dict(torch.load(encoder_weight_file))

    def forward(self, image, points, boxes, masks, features, crops,
                text_features, category_idx):
        """Encode the image, encode the prompts, and decode masks.

        Returns:
            (low_res_masks, iou_predictions, category_predictions,
             clip_vec, img_vec) as produced by the mask decoder.
        """
        image_embedding = self.image_encoder(image)

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=points,
            boxes=boxes,
            masks=masks,
            features=features,
            crops=crops,
            text_features=text_features,
            category_idx=category_idx)

        low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec = self.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )

        return low_res_masks, iou_predictions, category_predictions, clip_vec, img_vec

    @torch.no_grad()
    def postprocess_masks(self, masks, new_size, original_size):
        """
        Do cropping and resizing: crop padded masks to new_size, then
        bilinearly resize them to original_size.
        """
        masks = masks[:, :, :new_size[0], :new_size[1]]

        masks = F.interpolate(
            masks,
            size=(original_size[0], original_size[1]),
            mode="bilinear",
            align_corners=False,
        )

        return masks
|
|
|
|
|
|
|
|
def collate_fn(batch):
    """
    Collate function for PyTorch DataLoader.

    String / index fields ("image_name", "category_idx") are kept as plain
    lists; every other field is stacked into a tensor along dim 0.
    """
    collated = {}
    for key in batch[0]:
        values = [sample[key] for sample in batch]
        if key in ("image_name", "category_idx"):
            collated[key] = values
        else:
            collated[key] = torch.stack(values, dim=0)
    return collated
|
|
|
|
|
|
|
|
if __name__ == "__main__":

    args = get_args()
    # Runs unconditionally — args.sanity_check is not consulted here.
    sanity_check_dataset(args)

    # Timestamped run id used for the checkpoint folder and script snapshot.
    run_id = datetime.now().strftime("%Y%m%d-%H%M")
    print(f"Run ID: {run_id}")

    model_save_path = join(args.work_dir, args.task_name + "-" + run_id)
    makedirs(model_save_path, exist_ok=True)
    # Snapshot this training script next to the checkpoints for provenance.
    copyfile(__file__,
             join(model_save_path, run_id + "_" + os.path.basename(__file__)))

    device = torch.device("cuda")

    num_epochs = args.num_epochs
    batch_size = args.batch_size
    num_workers = args.num_workers
|
|
    # TinyViT backbone (Lite-MedSAM image encoder), 256x256 RGB input.
    medsam_lite_image_encoder = TinyViT(
        img_size=256,
        in_chans=3,
        embed_dims=[
            64,
            128,
            160,
            320
        ],
        depths=[2, 2, 6, 2],
        num_heads=[2, 4, 5, 10],
        window_sizes=[7, 7, 14, 7],
        mlp_ratio=4.,
        drop_rate=0.,
        drop_path_rate=0.0,
        use_checkpoint=False,
        mbconv_expand_ratio=4.0,
        local_conv_size=3,
        layer_lr_decay=0.8)

    # Prompt encoder over a 64x64 embedding grid for the 256x256 input.
    medsam_lite_prompt_encoder = PromptEncoder(embed_dim=256,
                                               image_embedding_size=(64, 64),
                                               input_image_size=(256, 256),
                                               mask_in_chans=16)

    # Mask decoder variant with modality and contents heads enabled.
    medsam_lite_mask_decoder = MaskDecoder_F4(
        num_multimask_outputs=3,
        transformer=TwoWayTransformer(
            depth=2,
            embedding_dim=256,
            mlp_dim=2048,
            num_heads=8,
        ),
        modality=True,
        contents=True,
        transformer_dim=256,
        iou_head_depth=3,
        iou_head_hidden_dim=256,
    )

    # Assemble the three components into the full model.
    medsam_lite_model = MedSAM_Lite(image_encoder=medsam_lite_image_encoder,
                                    mask_decoder=medsam_lite_mask_decoder,
                                    prompt_encoder=medsam_lite_prompt_encoder)
|
|
|
|
|
    # Fresh run: optionally initialize the whole model from a pretrained
    # checkpoint (the resume path below takes precedence when set).
    if args.resume is None and args.pretrained_checkpoint is not None:

        print(
            f"Loading pretrained checkpoint from {args.pretrained_checkpoint}")
        medsam_lite_checkpoint = torch.load(args.pretrained_checkpoint,
                                            map_location="cpu")
        medsam_lite_model.load_state_dict(medsam_lite_checkpoint["model"],
                                          strict=True)

    medsam_lite_model = medsam_lite_model.to(device)

    medsam_lite_model.train()

    print(
        f"MedSAM Lite size: {sum(p.numel() for p in medsam_lite_model.parameters())}"
    )

    print('lr:', args.lr)

    optimizer = optim.AdamW(
        medsam_lite_model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=args.weight_decay,
    )
    # Shrink LR by 10% after 5 epochs without improvement in mean train loss.
    lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                        mode='min',
                                                        factor=0.9,
                                                        patience=5,
                                                        cooldown=0)
    # Loss terms: Dice + BCE for masks, MSE for IoU regression, CE for the
    # category head and the CLIP-alignment similarities.
    seg_loss = monai.losses.DiceLoss(sigmoid=True,
                                     squared_pred=True,
                                     reduction='mean')
    bce_loss = nn.BCEWithLogitsLoss(reduction='mean')
    iou_loss = nn.MSELoss(reduction='mean')
    ce_loss = nn.CrossEntropyLoss(reduction='mean')

    train_dataset = ModalityNpzDataset(data_root=args.data_root, data_aug=True)

    # NOTE(review): the custom collate_fn defined above is not passed here —
    # confirm the default collate handles "image_name"/"category_idx" as
    # intended.
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              pin_memory=True)
|
|
|
|
|
if args.resume is not None: |
|
|
ckpt_folders = sorted(listdir(args.resume)) |
|
|
ckpt_folders = [ |
|
|
f for f in ckpt_folders |
|
|
if (f.startswith(args.task_name) |
|
|
and isfile(join(args.resume, f, 'medsam_lite_latest.pth'))) |
|
|
] |
|
|
print('*' * 20) |
|
|
print('existing ckpts in', args.resume, ckpt_folders) |
|
|
|
|
|
time_strings = [ |
|
|
f.split(args.task_name + '-')[-1] for f in ckpt_folders |
|
|
] |
|
|
dates = [datetime.strptime(f, '%Y%m%d-%H%M') for f in time_strings] |
|
|
latest_date = max(dates) |
|
|
latest_ckpt = join( |
|
|
args.work_dir, |
|
|
args.task_name + '-' + latest_date.strftime('%Y%m%d-%H%M'), |
|
|
'medsam_lite_latest.pth') |
|
|
print('Loading from', latest_ckpt) |
|
|
checkpoint = torch.load(latest_ckpt, map_location=device) |
|
|
medsam_lite_model.module.load_state_dict(checkpoint["model"]) |
|
|
optimizer.load_state_dict(checkpoint["optimizer"]) |
|
|
start_epoch = checkpoint["epoch"] + 1 |
|
|
best_loss = checkpoint["loss"] |
|
|
print(f"Loaded checkpoint from epoch {start_epoch}") |
|
|
else: |
|
|
start_epoch = 0 |
|
|
best_loss = 1e10 |
|
|
|
|
|
    # Per-epoch history for the diagnostic plots at the end of each epoch.
    train_losses = []
    epoch_times = []

    print("Training")
    for epoch in range(start_epoch, num_epochs):
        # Drop to a small fixed LR for the final epoch.
        if epoch == num_epochs - 1:
            for param_group in optimizer.param_groups:
                param_group['lr'] = 5e-5

        # Pre-filled with a large sentinel so an interrupted epoch still
        # produces a finite mean below.
        epoch_loss = [1e10 for _ in range(len(train_loader))]
        epoch_start_time = time()
        pbar = tqdm(train_loader)
        for step, batch in enumerate(pbar):
            # Aggressive memory housekeeping once per step.
            gc.collect()
            torch.cuda.empty_cache()
            image = batch["image"]
            gt2D = batch["gt2D"]
            boxes = batch["bboxes"]
            coords = batch["coords"]
            crops = batch["image_crop"]
            features = batch["image_feature"]
            text_features = batch["text_feature"]
            class_idx = batch["category_idx"]
            # category_idx arrives as a python list; tensorize before .to().
            class_idx = torch.tensor(class_idx)

            optimizer.zero_grad()
            image, gt2D, boxes, coords, crops, features, text_features, class_idx = image.to(
                device), gt2D.to(device), boxes.to(device), coords.to(
                    device), crops.to(device), features.to(
                        device), text_features.to(device), class_idx.to(device)
            # NOTE(review): point_prompt is built but never used — the forward
            # call below passes None for points. Confirm whether point prompts
            # were meant to be fed to the model.
            labels_torch = torch.ones(coords.shape[0]).long()
            labels_torch = labels_torch.unsqueeze(1).expand(-1, 4)
            labels_torch = labels_torch.to(device)
            point_prompt = (coords, labels_torch)
            logits_pred, iou_pred, category_predictions, clip_vec, img_vec = medsam_lite_model(
                image, None, boxes, None, features, crops, text_features, class_idx)

            # Symmetric CLIP-style alignment between the decoder's image
            # vector and the CLIP feature vector (diagonal = matching pairs).
            clip_img_features = clip_vec / clip_vec.norm(dim=-1, keepdim=True)
            img_features = img_vec / img_vec.norm(dim=-1, keepdim=True)
            similarity1 = torch.matmul(clip_img_features, img_features.T)
            similarity2 = torch.matmul(img_features, clip_img_features.T)
            sim_labels = torch.arange(similarity1.shape[0]).to(image.device)

            l_seg = seg_loss(logits_pred, gt2D)
            l_bce = bce_loss(logits_pred, gt2D.float())
            l_ce_sim = 0.5 * (ce_loss(similarity1, sim_labels.long()) +
                              ce_loss(similarity2, sim_labels.long()))
            l_ce = ce_loss(category_predictions, class_idx.long())
            mask_loss = l_seg + l_bce
            # The IoU regression target is derived from thresholded
            # predictions and must not receive gradients.
            with torch.no_grad():
                iou_gt = cal_iou(torch.sigmoid(logits_pred) > 0.5, gt2D.bool())
            l_iou = iou_loss(iou_pred, iou_gt)
            # NOTE(review): the CLI loss weights (-seg_loss_weight etc.) are
            # not applied here; the 0.01 factors are hard-coded — confirm.
            loss = mask_loss + l_iou + 0.01 * l_ce_sim + 0.01 * l_ce
            epoch_loss[step] = loss.item()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            pbar.set_description(
                f"Epoch {epoch} at {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}, loss: {loss.item():.4f}"
            )

        epoch_end_time = time()
        epoch_duration = epoch_end_time - epoch_start_time
        epoch_times.append(epoch_duration)

        # Mean training loss over the epoch drives the plateau scheduler.
        epoch_loss_reduced = sum(epoch_loss) / len(epoch_loss)

        train_losses.append(epoch_loss_reduced)
        lr_scheduler.step(epoch_loss_reduced)

        model_weights = medsam_lite_model.state_dict()

        checkpoint = {
            "model": model_weights,
            "epoch": epoch,
            "optimizer": optimizer.state_dict(),
            "loss": epoch_loss_reduced,
            "best_loss": best_loss,
        }
        # Always refresh the rolling "latest" checkpoint.
        torch.save(checkpoint, join(model_save_path, "medsam_lite_latest.pth"))

        # Save a separate "best" checkpoint when the mean loss improves.
        if epoch_loss_reduced < best_loss:
            print(
                f"New best loss: {best_loss:.4f} -> {epoch_loss_reduced:.4f}")
            best_loss = epoch_loss_reduced
            checkpoint["best_loss"] = best_loss
            torch.save(checkpoint, join(model_save_path,
                                        "medsam_lite_best.pth"))
            # NOTE(review): dead assignment — epoch_loss_reduced is recomputed
            # at the top of every epoch, so this reset has no effect.
            epoch_loss_reduced = 1e10

        # Refresh the loss / epoch-duration curves after every epoch.
        fig, axes = plt.subplots(2, 1, figsize=(10, 8))
        axes[0].title.set_text("Dice + Binary Cross Entropy + IoU Loss")
        axes[0].plot(train_losses)
        axes[0].set_ylabel("Loss")
        axes[1].plot(epoch_times)
        axes[1].title.set_text("Epoch Duration")
        axes[1].set_ylabel("Duration (s)")
        axes[1].set_xlabel("Epoch")
        plt.tight_layout()
        plt.savefig(join(model_save_path, "log.png"))
        plt.close()
|
|
|