BoSAM / train.py
ziyanlu's picture
Upload folder using huggingface_hub
9859ea2 verified
# set up environment
import numpy as np
import random
import datetime
import logging
import matplotlib.pyplot as plt
import os
join = os.path.join
from tqdm import tqdm
from torch.backends import cudnn
import torch
import torch.distributed as dist
import torch.nn.functional as F
import torchio as tio
from torch.utils.data.distributed import DistributedSampler
from segment_anything.build_sam3D import sam_model_registry3D
import argparse
from torch.cuda import amp
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from monai.losses import DiceCELoss
from contextlib import nullcontext
from utils.click_method import get_next_click3D_torch_2
from utils.data_loader import Dataset_Union_ALL, Union_Dataloader
from utils.data_paths import img_datas
# %% set up parser
parser = argparse.ArgumentParser()
parser.add_argument('--task_name', type=str, default='union_train')
parser.add_argument('--click_type', type=str, default='random')
parser.add_argument('--multi_click', action='store_true', default=False)
parser.add_argument('--model_type', type=str, default='vit_b_ori')
parser.add_argument('--checkpoint', type=str, default='ckpt/sam_med3d.pth')
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--work_dir', type=str, default='work_dir')
# train
parser.add_argument('--num_workers', type=int, default=24)
parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0,1])
parser.add_argument('--multi_gpu', action='store_true', default=False)
parser.add_argument('--resume', action='store_true', default=False)
parser.add_argument('--allow_partial_weight', action='store_true', default=False)
# lr_scheduler
parser.add_argument('--lr_scheduler', type=str, default='multisteplr')
parser.add_argument('--step_size', type=list, default=[120, 180])
parser.add_argument('--gamma', type=float, default=0.1)
parser.add_argument('--num_epochs', type=int, default=200)
parser.add_argument('--img_size', type=int, default=128)
parser.add_argument('--batch_size', type=int, default=12)
parser.add_argument('--accumulation_steps', type=int, default=20)
parser.add_argument('--lr', type=float, default=8e-4)
parser.add_argument('--weight_decay', type=float, default=0.1)
parser.add_argument('--port', type=int, default=12361)
args = parser.parse_args()
device = args.device
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join([str(i) for i in args.gpu_ids])
logger = logging.getLogger(__name__)
LOG_OUT_DIR = join(args.work_dir, args.task_name)
click_methods = {
'random': get_next_click3D_torch_2,
}
MODEL_SAVE_PATH = join(args.work_dir, args.task_name)
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)
def build_model(args):
sam_model = sam_model_registry3D[args.model_type](checkpoint=None).to(device)
if args.multi_gpu:
sam_model = DDP(sam_model, device_ids=[args.rank], output_device=args.rank)
return sam_model
def get_dataloaders(args):
train_dataset = Dataset_Union_ALL(paths=img_datas, transform=tio.Compose([
tio.ToCanonical(),
tio.CropOrPad(mask_name='label', target_shape=(args.img_size,args.img_size,args.img_size)), # crop only object region
tio.RandomFlip(axes=(0, 1, 2)),
]),
threshold=1000)
if args.multi_gpu:
train_sampler = DistributedSampler(train_dataset)
shuffle = False
else:
train_sampler = None
shuffle = True
#train_dataloader = tio.SubjectsLoader(
train_dataloader = Union_Dataloader(
dataset=train_dataset,
sampler=train_sampler,
batch_size=args.batch_size,
shuffle=shuffle,
num_workers=args.num_workers,
pin_memory=True,
)
return train_dataloader
class BaseTrainer:
def __init__(self, model, dataloaders, args):
self.model = model
self.dataloaders = dataloaders
self.args = args
self.best_loss = np.inf
self.best_dice = 0.0
self.step_best_loss = np.inf
self.step_best_dice = 0.0
self.losses = []
self.dices = []
self.ious = []
self.set_loss_fn()
self.set_optimizer()
self.set_lr_scheduler()
if(args.resume):
self.init_checkpoint(join(self.args.work_dir, self.args.task_name, 'sam_model_latest.pth'))
else:
self.init_checkpoint(self.args.checkpoint)
self.norm_transform = tio.ZNormalization(masking_method=lambda x: x > 0)
def set_loss_fn(self):
self.seg_loss = DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')
def set_optimizer(self):
if self.args.multi_gpu:
sam_model = self.model.module
else:
sam_model = self.model
self.optimizer = torch.optim.AdamW([
{'params': sam_model.image_encoder.parameters()}, # , 'lr': self.args.lr * 0.1},
{'params': sam_model.prompt_encoder.parameters() , 'lr': self.args.lr * 0.1},
{'params': sam_model.mask_decoder.parameters(), 'lr': self.args.lr * 0.1},
], lr=self.args.lr, betas=(0.9,0.999), weight_decay=self.args.weight_decay)
def set_lr_scheduler(self):
if self.args.lr_scheduler == "multisteplr":
self.lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,
self.args.step_size,
self.args.gamma)
elif self.args.lr_scheduler == "steplr":
self.lr_scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer,
self.args.step_size[0],
self.args.gamma)
elif self.args.lr_scheduler == 'coswarm':
self.lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer)
else:
self.lr_scheduler = torch.optim.lr_scheduler.LinearLR(self.optimizer, 0.1)
def init_checkpoint(self, ckp_path):
last_ckpt = None
if os.path.exists(ckp_path):
if self.args.multi_gpu:
dist.barrier()
last_ckpt = torch.load(ckp_path, map_location=self.args.device)
else:
last_ckpt = torch.load(ckp_path, map_location=self.args.device)
if last_ckpt:
if(self.args.allow_partial_weight):
if self.args.multi_gpu:
self.model.module.load_state_dict(last_ckpt['model_state_dict'], strict=False)
else:
self.model.load_state_dict(last_ckpt['model_state_dict'], strict=False)
else:
if self.args.multi_gpu:
self.model.module.load_state_dict(last_ckpt['model_state_dict'])
else:
self.model.load_state_dict(last_ckpt['model_state_dict'])
if not self.args.resume:
self.start_epoch = 0
else:
self.start_epoch = last_ckpt['epoch']
self.optimizer.load_state_dict(last_ckpt['optimizer_state_dict'])
self.lr_scheduler.load_state_dict(last_ckpt['lr_scheduler_state_dict'])
self.losses = last_ckpt['losses']
self.dices = last_ckpt['dices']
self.best_loss = last_ckpt['best_loss']
self.best_dice = last_ckpt['best_dice']
print(f"Loaded checkpoint from {ckp_path} (epoch {self.start_epoch})")
else:
self.start_epoch = 0
print(f"No checkpoint found at {ckp_path}, start training from scratch")
def save_checkpoint(self, epoch, state_dict, describe="last"):
torch.save({
"epoch": epoch + 1,
"model_state_dict": state_dict,
"optimizer_state_dict": self.optimizer.state_dict(),
"lr_scheduler_state_dict": self.lr_scheduler.state_dict(),
"losses": self.losses,
"dices": self.dices,
"best_loss": self.best_loss,
"best_dice": self.best_dice,
"args": self.args,
"used_datas": img_datas,
}, join(MODEL_SAVE_PATH, f"sam_model_{describe}.pth"))
def batch_forward(self, sam_model, image_embedding, gt3D, low_res_masks, points=None):
sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
points=points,
boxes=None,
masks=low_res_masks,
)
low_res_masks, iou_predictions = sam_model.mask_decoder(
image_embeddings=image_embedding.to(device), # (B, 256, 64, 64)
image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
multimask_output=False,
)
prev_masks = F.interpolate(low_res_masks, size=gt3D.shape[-3:], mode='trilinear', align_corners=False)
return low_res_masks, prev_masks
def get_points(self, prev_masks, gt3D):
batch_points, batch_labels = click_methods[self.args.click_type](prev_masks, gt3D)
points_co = torch.cat(batch_points, dim=0).to(device)
points_la = torch.cat(batch_labels, dim=0).to(device)
self.click_points.append(points_co)
self.click_labels.append(points_la)
points_multi = torch.cat(self.click_points, dim=1).to(device)
labels_multi = torch.cat(self.click_labels, dim=1).to(device)
if self.args.multi_click:
points_input = points_multi
labels_input = labels_multi
else:
points_input = points_co
labels_input = points_la
return points_input, labels_input
def interaction(self, sam_model, image_embedding, gt3D, num_clicks):
return_loss = 0
prev_masks = torch.zeros_like(gt3D).to(gt3D.device)
low_res_masks = F.interpolate(prev_masks.float(), size=(args.img_size//4,args.img_size//4,args.img_size//4))
random_insert = np.random.randint(2, 9)
for num_click in range(num_clicks):
points_input, labels_input = self.get_points(prev_masks, gt3D)
if num_click == random_insert or num_click == num_clicks - 1:
low_res_masks, prev_masks = self.batch_forward(sam_model, image_embedding, gt3D, low_res_masks, points=None)
else:
low_res_masks, prev_masks = self.batch_forward(sam_model, image_embedding, gt3D, low_res_masks, points=[points_input, labels_input])
loss = self.seg_loss(prev_masks, gt3D)
return_loss += loss
return prev_masks, return_loss
def get_dice_score(self, prev_masks, gt3D):
def compute_dice(mask_pred, mask_gt):
mask_threshold = 0.5
mask_pred = (mask_pred > mask_threshold)
mask_gt = (mask_gt > 0)
volume_sum = mask_gt.sum() + mask_pred.sum()
if volume_sum == 0:
return np.NaN
volume_intersect = (mask_gt & mask_pred).sum()
return 2*volume_intersect / volume_sum
pred_masks = (prev_masks > 0.5)
true_masks = (gt3D > 0)
dice_list = []
for i in range(true_masks.shape[0]):
dice_list.append(compute_dice(pred_masks[i], true_masks[i]))
return (sum(dice_list)/len(dice_list)).item()
def train_epoch(self, epoch, num_clicks):
epoch_loss = 0
epoch_iou = 0
self.model.train()
if self.args.multi_gpu:
sam_model = self.model.module
else:
sam_model = self.model
self.args.rank = -1
if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
tbar = tqdm(self.dataloaders)
else:
tbar = self.dataloaders
self.optimizer.zero_grad()
step_loss = 0
epoch_dice = 0
for step, data3D in enumerate(tbar):
try:
image3D, gt3D = data3D["image"], data3D["label"]
except Exception as e:
print(f"Error processing batch at step {step}: {e}")
#import pdb; pdb.set_trace()
my_context = self.model.no_sync if self.args.rank != -1 and step % self.args.accumulation_steps != 0 else nullcontext
with my_context():
image3D = self.norm_transform(image3D.squeeze(dim=1)) # (N, C, W, H, D)
image3D = image3D.unsqueeze(dim=1)
image3D = image3D.to(device)
gt3D = gt3D.to(device).type(torch.long)
with torch.amp.autocast("cuda"):
image_embedding = sam_model.image_encoder(image3D)
self.click_points = []
self.click_labels = []
pred_list = []
prev_masks, loss = self.interaction(sam_model, image_embedding, gt3D, num_clicks=11)
epoch_loss += loss.item()
epoch_dice += self.get_dice_score(prev_masks,gt3D)
cur_loss = loss.item()
loss /= self.args.accumulation_steps
self.scaler.scale(loss).backward()
if step % self.args.accumulation_steps == 0 and step != 0:
self.scaler.step(self.optimizer)
self.scaler.update()
self.optimizer.zero_grad()
print_loss = step_loss / self.args.accumulation_steps
step_loss = 0
print_dice = self.get_dice_score(prev_masks, gt3D)
else:
step_loss += cur_loss
if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
if step % self.args.accumulation_steps == 0 and step != 0:
print(f'Epoch: {epoch}, Step: {step}, Loss: {print_loss}, Dice: {print_dice}')
if print_dice > self.step_best_dice:
self.step_best_dice = print_dice
if print_dice > 0.9:
self.save_checkpoint(
epoch,
sam_model.state_dict(),
describe=f'{epoch}_step_dice:{print_dice}_best'
)
if print_loss < self.step_best_loss:
self.step_best_loss = print_loss
epoch_loss /= step+1
epoch_dice /= step+1
return epoch_loss, epoch_iou, epoch_dice, pred_list
def eval_epoch(self, epoch, num_clicks):
return 0
def plot_result(self, plot_data, description, save_name):
plt.plot(plot_data)
plt.title(description)
plt.xlabel('Epoch')
plt.ylabel(f'{save_name}')
plt.savefig(join(MODEL_SAVE_PATH, f'{save_name}.png'))
plt.close()
def train(self):
self.scaler = torch.amp.GradScaler("cuda")
for epoch in range(self.start_epoch, self.args.num_epochs):
print(f'Epoch: {epoch}/{self.args.num_epochs - 1}')
if self.args.multi_gpu:
dist.barrier()
self.dataloaders.sampler.set_epoch(epoch)
num_clicks = np.random.randint(1, 21)
epoch_loss, epoch_iou, epoch_dice, pred_list = self.train_epoch(epoch, num_clicks)
if self.lr_scheduler is not None:
self.lr_scheduler.step()
if self.args.multi_gpu:
dist.barrier()
if not self.args.multi_gpu or (self.args.multi_gpu and self.args.rank == 0):
self.losses.append(epoch_loss)
self.dices.append(epoch_dice)
print(f'EPOCH: {epoch}, Loss: {epoch_loss}')
print(f'EPOCH: {epoch}, Dice: {epoch_dice}')
logger.info(f'Epoch\t {epoch}\t : loss: {epoch_loss}, dice: {epoch_dice}')
if self.args.multi_gpu:
state_dict = self.model.module.state_dict()
else:
state_dict = self.model.state_dict()
# save latest checkpoint
self.save_checkpoint(
epoch,
state_dict,
describe='latest'
)
# save train loss best checkpoint
if epoch_loss < self.best_loss:
self.best_loss = epoch_loss
self.save_checkpoint(
epoch,
state_dict,
describe='loss_best'
)
# save train dice best checkpoint
if epoch_dice > self.best_dice:
self.best_dice = epoch_dice
self.save_checkpoint(
epoch,
state_dict,
describe='dice_best'
)
self.plot_result(self.losses, 'Dice + Cross Entropy Loss', 'Loss')
self.plot_result(self.dices, 'Dice', 'Dice')
logger.info('=====================================================================')
logger.info(f'Best loss: {self.best_loss}')
logger.info(f'Best dice: {self.best_dice}')
logger.info(f'Total loss: {self.losses}')
logger.info(f'Total dice: {self.dices}')
logger.info('=====================================================================')
logger.info(f'args : {self.args}')
logger.info(f'Used datasets : {img_datas}')
logger.info('=====================================================================')
def init_seeds(seed=0, cuda_deterministic=True):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# Speed-reproducibility tradeoff https://pytorch.org/docs/stable/notes/randomness.html
if cuda_deterministic: # slower, more reproducible
cudnn.deterministic = True
cudnn.benchmark = False
else: # faster, less reproducible
cudnn.deterministic = False
cudnn.benchmark = True
def device_config(args):
try:
if not args.multi_gpu:
# Single GPU
if args.device == 'mps':
args.device = torch.device('mps')
else:
args.device = torch.device(f"cuda:{args.gpu_ids[0]}")
else:
args.nodes = 1
args.ngpus_per_node = len(args.gpu_ids)
args.world_size = args.nodes * args.ngpus_per_node
except RuntimeError as e:
print(e)
def main():
mp.set_sharing_strategy('file_system')
device_config(args)
if args.multi_gpu:
mp.spawn(
main_worker,
nprocs=args.world_size,
args=(args, )
)
else:
random.seed(2023)
np.random.seed(2023)
torch.manual_seed(2023)
# Load datasets
dataloaders = get_dataloaders(args)
# Build model
model = build_model(args)
# Create trainer
trainer = BaseTrainer(model, dataloaders, args)
# Train
trainer.train()
def main_worker(rank, args):
setup(rank, args.world_size)
torch.cuda.set_device(rank)
args.num_workers = int(args.num_workers / args.ngpus_per_node)
args.device = torch.device(f"cuda:{rank}")
args.rank = rank
init_seeds(2023 + rank)
cur_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
logging.basicConfig(
format='[%(asctime)s] - %(message)s',
datefmt='%Y/%m/%d %H:%M:%S',
level=logging.INFO if rank in [-1, 0] else logging.WARN,
filemode='w',
filename=os.path.join(LOG_OUT_DIR, f'output_{cur_time}.log'))
dataloaders = get_dataloaders(args)
model = build_model(args)
trainer = BaseTrainer(model, dataloaders, args)
trainer.train()
cleanup()
def setup(rank, world_size):
# initialize the process group
dist.init_process_group(
backend='nccl',
init_method=f'tcp://127.0.0.1:{args.port}',
world_size=world_size,
rank=rank
)
def cleanup():
dist.destroy_process_group()
if __name__ == '__main__':
main()