MMedAgent_demo / src /MedSAM /train_one_gpu.py
YPan0's picture
Upload folder using huggingface_hub
b6deff2 verified
# -*- coding: utf-8 -*-
"""
train the image encoder and mask decoder
freeze prompt image encoder
"""
# %% setup environment
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
from tqdm import tqdm
from skimage import transform
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import monai
from segment_anything import sam_model_registry
import torch.nn.functional as F
import argparse
import random
from datetime import datetime
import shutil
import glob
# set seeds
torch.manual_seed(2023)
torch.cuda.empty_cache()
# torch.distributed.init_process_group(backend="gloo")
os.environ["OMP_NUM_THREADS"] = "4" # export OMP_NUM_THREADS=4
os.environ["OPENBLAS_NUM_THREADS"] = "4" # export OPENBLAS_NUM_THREADS=4
os.environ["MKL_NUM_THREADS"] = "6" # export MKL_NUM_THREADS=6
os.environ["VECLIB_MAXIMUM_THREADS"] = "4" # export VECLIB_MAXIMUM_THREADS=4
os.environ["NUMEXPR_NUM_THREADS"] = "6" # export NUMEXPR_NUM_THREADS=6
def show_mask(mask, ax, random_color=False):
if random_color:
color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
else:
color = np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
h, w = mask.shape[-2:]
mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
ax.imshow(mask_image)
def show_box(box, ax):
x0, y0 = box[0], box[1]
w, h = box[2] - box[0], box[3] - box[1]
ax.add_patch(
plt.Rectangle((x0, y0), w, h, edgecolor="blue", facecolor=(0, 0, 0, 0), lw=2)
)
class NpyDataset(Dataset):
def __init__(self, data_root, bbox_shift=20):
self.data_root = data_root
self.gt_path = join(data_root, "gts")
self.img_path = join(data_root, "imgs")
self.gt_path_files = sorted(
glob.glob(join(self.gt_path, "**/*.npy"), recursive=True)
)
self.gt_path_files = [
file
for file in self.gt_path_files
if os.path.isfile(join(self.img_path, os.path.basename(file)))
]
self.bbox_shift = bbox_shift
print(f"number of images: {len(self.gt_path_files)}")
def __len__(self):
return len(self.gt_path_files)
def __getitem__(self, index):
# load npy image (1024, 1024, 3), [0,1]
img_name = os.path.basename(self.gt_path_files[index])
img_1024 = np.load(
join(self.img_path, img_name), "r", allow_pickle=True
) # (1024, 1024, 3)
# convert the shape to (3, H, W)
img_1024 = np.transpose(img_1024, (2, 0, 1))
assert (
np.max(img_1024) <= 1.0 and np.min(img_1024) >= 0.0
), "image should be normalized to [0, 1]"
gt = np.load(
self.gt_path_files[index], "r", allow_pickle=True
) # multiple labels [0, 1,4,5...], (256,256)
assert img_name == os.path.basename(self.gt_path_files[index]), (
"img gt name error" + self.gt_path_files[index] + self.npy_files[index]
)
label_ids = np.unique(gt)[1:]
gt2D = np.uint8(
gt == random.choice(label_ids.tolist())
) # only one label, (256, 256)
assert np.max(gt2D) == 1 and np.min(gt2D) == 0.0, "ground truth should be 0, 1"
y_indices, x_indices = np.where(gt2D > 0)
x_min, x_max = np.min(x_indices), np.max(x_indices)
y_min, y_max = np.min(y_indices), np.max(y_indices)
# add perturbation to bounding box coordinates
H, W = gt2D.shape
x_min = max(0, x_min - random.randint(0, self.bbox_shift))
x_max = min(W, x_max + random.randint(0, self.bbox_shift))
y_min = max(0, y_min - random.randint(0, self.bbox_shift))
y_max = min(H, y_max + random.randint(0, self.bbox_shift))
bboxes = np.array([x_min, y_min, x_max, y_max])
return (
torch.tensor(img_1024).float(),
torch.tensor(gt2D[None, :, :]).long(),
torch.tensor(bboxes).float(),
img_name,
)
# %% sanity test of dataset class
tr_dataset = NpyDataset("data/npy/CT_Abd")
tr_dataloader = DataLoader(tr_dataset, batch_size=8, shuffle=True)
for step, (image, gt, bboxes, names_temp) in enumerate(tr_dataloader):
print(image.shape, gt.shape, bboxes.shape)
# show the example
_, axs = plt.subplots(1, 2, figsize=(25, 25))
idx = random.randint(0, 7)
axs[0].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
show_mask(gt[idx].cpu().numpy(), axs[0])
show_box(bboxes[idx].numpy(), axs[0])
axs[0].axis("off")
# set title
axs[0].set_title(names_temp[idx])
idx = random.randint(0, 7)
axs[1].imshow(image[idx].cpu().permute(1, 2, 0).numpy())
show_mask(gt[idx].cpu().numpy(), axs[1])
show_box(bboxes[idx].numpy(), axs[1])
axs[1].axis("off")
# set title
axs[1].set_title(names_temp[idx])
# plt.show()
plt.subplots_adjust(wspace=0.01, hspace=0)
plt.savefig("./data_sanitycheck.png", bbox_inches="tight", dpi=300)
plt.close()
break
# %% set up parser
parser = argparse.ArgumentParser()
parser.add_argument(
"-i",
"--tr_npy_path",
type=str,
default="data/npy/CT_Abd",
help="path to training npy files; two subfolders: gts and imgs",
)
parser.add_argument("-task_name", type=str, default="MedSAM-ViT-B")
parser.add_argument("-model_type", type=str, default="vit_b")
parser.add_argument(
"-checkpoint", type=str, default="work_dir/SAM/sam_vit_b_01ec64.pth"
)
# parser.add_argument('-device', type=str, default='cuda:0')
parser.add_argument(
"--load_pretrain", type=bool, default=True, help="load pretrain model"
)
parser.add_argument("-pretrain_model_path", type=str, default="")
parser.add_argument("-work_dir", type=str, default="./work_dir")
# train
parser.add_argument("-num_epochs", type=int, default=1000)
parser.add_argument("-batch_size", type=int, default=2)
parser.add_argument("-num_workers", type=int, default=0)
# Optimizer parameters
parser.add_argument(
"-weight_decay", type=float, default=0.01, help="weight decay (default: 0.01)"
)
parser.add_argument(
"-lr", type=float, default=0.0001, metavar="LR", help="learning rate (absolute lr)"
)
parser.add_argument(
"-use_wandb", type=bool, default=False, help="use wandb to monitor training"
)
parser.add_argument("-use_amp", action="store_true", default=False, help="use amp")
parser.add_argument(
"--resume", type=str, default="", help="Resuming training from checkpoint"
)
parser.add_argument("--device", type=str, default="cuda:0")
args = parser.parse_args()
if args.use_wandb:
import wandb
wandb.login()
wandb.init(
project=args.task_name,
config={
"lr": args.lr,
"batch_size": args.batch_size,
"data_path": args.tr_npy_path,
"model_type": args.model_type,
},
)
# %% set up model for training
# device = args.device
run_id = datetime.now().strftime("%Y%m%d-%H%M")
model_save_path = join(args.work_dir, args.task_name + "-" + run_id)
device = torch.device(args.device)
# %% set up model
class MedSAM(nn.Module):
def __init__(
self,
image_encoder,
mask_decoder,
prompt_encoder,
):
super().__init__()
self.image_encoder = image_encoder
self.mask_decoder = mask_decoder
self.prompt_encoder = prompt_encoder
# freeze prompt encoder
for param in self.prompt_encoder.parameters():
param.requires_grad = False
def forward(self, image, box):
image_embedding = self.image_encoder(image) # (B, 256, 64, 64)
# do not compute gradients for prompt encoder
with torch.no_grad():
box_torch = torch.as_tensor(box, dtype=torch.float32, device=image.device)
if len(box_torch.shape) == 2:
box_torch = box_torch[:, None, :] # (B, 1, 4)
sparse_embeddings, dense_embeddings = self.prompt_encoder(
points=None,
boxes=box_torch,
masks=None,
)
low_res_masks, _ = self.mask_decoder(
image_embeddings=image_embedding, # (B, 256, 64, 64)
image_pe=self.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
multimask_output=False,
)
ori_res_masks = F.interpolate(
low_res_masks,
size=(image.shape[2], image.shape[3]),
mode="bilinear",
align_corners=False,
)
return ori_res_masks
def main():
os.makedirs(model_save_path, exist_ok=True)
shutil.copyfile(
__file__, join(model_save_path, run_id + "_" + os.path.basename(__file__))
)
sam_model = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
medsam_model = MedSAM(
image_encoder=sam_model.image_encoder,
mask_decoder=sam_model.mask_decoder,
prompt_encoder=sam_model.prompt_encoder,
).to(device)
medsam_model.train()
print(
"Number of total parameters: ",
sum(p.numel() for p in medsam_model.parameters()),
) # 93735472
print(
"Number of trainable parameters: ",
sum(p.numel() for p in medsam_model.parameters() if p.requires_grad),
) # 93729252
img_mask_encdec_params = list(medsam_model.image_encoder.parameters()) + list(
medsam_model.mask_decoder.parameters()
)
optimizer = torch.optim.AdamW(
img_mask_encdec_params, lr=args.lr, weight_decay=args.weight_decay
)
print(
"Number of image encoder and mask decoder parameters: ",
sum(p.numel() for p in img_mask_encdec_params if p.requires_grad),
) # 93729252
seg_loss = monai.losses.DiceLoss(sigmoid=True, squared_pred=True, reduction="mean")
# cross entropy loss
ce_loss = nn.BCEWithLogitsLoss(reduction="mean")
# %% train
num_epochs = args.num_epochs
iter_num = 0
losses = []
best_loss = 1e10
train_dataset = NpyDataset(args.tr_npy_path)
print("Number of training samples: ", len(train_dataset))
train_dataloader = DataLoader(
train_dataset,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.num_workers,
pin_memory=True,
)
start_epoch = 0
if args.resume is not None:
if os.path.isfile(args.resume):
## Map model to be loaded to specified single GPU
checkpoint = torch.load(args.resume, map_location=device)
start_epoch = checkpoint["epoch"] + 1
medsam_model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
if args.use_amp:
scaler = torch.cuda.amp.GradScaler()
for epoch in range(start_epoch, num_epochs):
epoch_loss = 0
for step, (image, gt2D, boxes, _) in enumerate(tqdm(train_dataloader)):
optimizer.zero_grad()
boxes_np = boxes.detach().cpu().numpy()
image, gt2D = image.to(device), gt2D.to(device)
if args.use_amp:
## AMP
with torch.autocast(device_type="cuda", dtype=torch.float16):
medsam_pred = medsam_model(image, boxes_np)
loss = seg_loss(medsam_pred, gt2D) + ce_loss(
medsam_pred, gt2D.float()
)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
else:
medsam_pred = medsam_model(image, boxes_np)
loss = seg_loss(medsam_pred, gt2D) + ce_loss(medsam_pred, gt2D.float())
loss.backward()
optimizer.step()
optimizer.zero_grad()
epoch_loss += loss.item()
iter_num += 1
epoch_loss /= step
losses.append(epoch_loss)
if args.use_wandb:
wandb.log({"epoch_loss": epoch_loss})
print(
f'Time: {datetime.now().strftime("%Y%m%d-%H%M")}, Epoch: {epoch}, Loss: {epoch_loss}'
)
## save the latest model
checkpoint = {
"model": medsam_model.state_dict(),
"optimizer": optimizer.state_dict(),
"epoch": epoch,
}
torch.save(checkpoint, join(model_save_path, "medsam_model_latest.pth"))
## save the best model
if epoch_loss < best_loss:
best_loss = epoch_loss
checkpoint = {
"model": medsam_model.state_dict(),
"optimizer": optimizer.state_dict(),
"epoch": epoch,
}
torch.save(checkpoint, join(model_save_path, "medsam_model_best.pth"))
# %% plot loss
plt.plot(losses)
plt.title("Dice + Cross Entropy Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.savefig(join(model_save_path, args.task_name + "train_loss.png"))
plt.close()
if __name__ == "__main__":
main()