# CellPilot / SAMHI / scripts / train.py
# Source: Hugging Face repository file view (uploaded by philippendres via
# huggingface_hub, commit 907462b, verified; file size 11.1 kB).
import argparse
import wandb
import datetime
import os
from samhi.modeling.model import SamHI
from samhi.data_processing.dataset import prepare_data
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
import torch
# Launch timestamp; used further down as the fallback run name and stored in
# the config. NOTE: the name `time` would shadow the stdlib module of the same
# name if that module were ever imported into this script.
time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so a flag declared that way can
    never be switched off from the command line.

    Args:
        value: Raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False`` for the recognised spellings (case-insensitive).
        The empty string maps to ``False``, as it did with ``type=bool``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Define arguments
parser = argparse.ArgumentParser()
parser.add_argument("--cluster", type=str, choices=["denbi", "helmholtz", "custom"], default="denbi")
parser.add_argument("--num_workers", type=int, default=5)

# W&B parameters
parser.add_argument("--project_name", type=str, default="SAMHI")
parser.add_argument("--entity", type=str, default="philippresearch")

# Model Info
parser.add_argument("--model_type", type=str, choices=["vit_b", "vit_l", "vit_h"], default="vit_b")
parser.add_argument("--model_dir", type=str, default="/vol/data/models/")
parser.add_argument("--base_model", type=str, default="sam_vit_b_01ec64.pth")
parser.add_argument(
    "--loss",
    type=str,
    choices=["diceCE", "diceFocal", "dice", "generalized_dice", "generalized_diceFocal", "tversky"],
    default="diceCE",
)
parser.add_argument("--compile_model", type=_str2bool, default=False)
parser.add_argument("--save_model", type=_str2bool, default=False)
parser.add_argument("--resume_training", type=_str2bool, default=False)
parser.add_argument("--resume_ckpt", type=str, default="model-czpfnbsk:v0")

# Finetuning approach
parser.add_argument("--freeze", nargs='+', choices=["prompt", "mask", "image", ""], default=["prompt", "image"])
parser.add_argument("--image_encoder_size", type=int, default=1024)
parser.add_argument("--p_tuning", type=_str2bool, default=False)

# Dataset type and location
datasets = ["BCSS", "CAMELYON", "CellSeg", "CoCaHis", "CoNIC", "CPM", "CRAG", "CryoNuSeg", "GlaS", "ICIA2018",
            "Janowczyk", "KPI", "Kumar", "MoNuSAC", "MoNuSeg", "NuClick", "PAIP2023", "PanNuke", "SegPath", "SegPC",
            "TIGER", "TNBC", "WSSS4LUAD"]
parser.add_argument("--datasets", nargs='+', choices=datasets, default=["Janowczyk"])
parser.add_argument("--use_holdout_testset", type=_str2bool, default=False)
parser.add_argument("--test_datasets", nargs='+', choices=datasets, default=["CPM"])
parser.add_argument("--data_directory", type=str, default="/vol/data/histo_datasets/")

# Known augmentation names. NOTE(review): this list is not passed as `choices`
# below, so --data_augmentations is currently not validated against it
# ("RandomResizedCrop" also appears twice). Confirm whether validation is wanted.
augmentations = ["AdvancedBlur", "Blur", "GaussianBlur", "ZoomBlur", "CLAHE", "Emboss", "GaussNoise", "IsoNoise",
                 "ImageCompression", "Posterize", "RingingOvershoot", "Sharpen", "ToGray", "Downscale",
                 "ChannelShuffle", "ChromaticAberration", "ColorJitter","HueSaturationValue", "MultiplicativeNoise",
                 "PlanckianJitter", "RGBShift", "RandomBrightnessContrast", "RandomGamma","RandomToneCurve",
                 "FancyPCA",
                 "Affine", "CropNonEmptyMaskIfExists", "ElasticTransform", "GridDistortion", "OpticalDistortion",
                 "RandomCrop", "RandomGridShuffle", "RandomResizedCrop", "RandomResizedCrop", "Rotate",
                 "ShiftScaleRotate", "CropAndPad", "D4", "PadIfNeeded", "Perspective", "RandomScale",
                 "NoOp"]
parser.add_argument("--data_augmentations", nargs='+', default=["NoOp"])
parser.add_argument("--mask_augmentation_tries", type=int, default=5)
parser.add_argument("--threshold_connected_components", type=int, default=2)

# Training parameters
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--accumulate_grad_batches", type=int, default=1)
parser.add_argument("--shuffle", type=_str2bool, default=False)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument("--mask_threshold", type=float, default=0.0)
parser.add_argument("--bbox_shift", type=int, default=10)
parser.add_argument("--run_directory", type=str, default="/vol/data/runs/")

# Prompt parameters
parser.add_argument("--random_prompt_type", type=_str2bool, default=False)
parser.add_argument("--prompt_type", type=str, choices=["points", "boxes", "both"], default="both")
# Was declared as `type=str` with an int default, so a value given on the
# command line arrived as a string while the default stayed an int.
parser.add_argument("--prompt_batch_size", type=int, default=8)

# Data split
parser.add_argument("--data_split", type=_str2bool, default=True)
parser.add_argument("--train_split", type=float, default=0.6)
parser.add_argument("--val_split", type=float, default=0.2)
parser.add_argument("--test_split", type=float, default=0.2)

# DataLoader parameters
parser.add_argument("--drop_last", type=_str2bool, default=True)

# LoRA parameters
parser.add_argument("--lora_rank", type=int, default=1)
parser.add_argument("--lora_layer", nargs='+', default=None) #["-1"]

# Interactive parameters
parser.add_argument("--random_mode", type=_str2bool, default=False)
parser.add_argument("--mode", type=str, choices=["random", "interactive"], default="interactive")
parser.add_argument("--random_nr_of_interactive_points", type=_str2bool, default=False)
parser.add_argument("--max_nr_of_interactive_points", type=int, default=1)
parser.add_argument("--nr_of_interactive_points", type=int, default=0)

# Point parameters
parser.add_argument("--random_nr_of_points", type=_str2bool, default=False)
parser.add_argument("--max_nr_of_points", type=int, default=1)
parser.add_argument("--nr_of_initial_points", type=int, default=1)

# Positive point parameters
parser.add_argument("--only_positive_points", type=_str2bool, default=False)
parser.add_argument("--random_nr_of_positive_points", type=_str2bool, default=False)
parser.add_argument("--max_nr_of_positive_points", type=int, default=1)
parser.add_argument("--nr_of_initial_positive_points", type=int, default=1)

# Optional explicit run name; an empty string means "use the timestamp".
parser.add_argument("--display_name", type=str, default="")
# Parse the command line once; everything below reads from `args`.
args = parser.parse_args()

# Choose display_name: fall back to the launch timestamp when none was given.
display_name = f"{time}" if args.display_name == "" else args.display_name

# Run-level settings (logging, checkpointing, trainer options).
initial_config = dict(
    compile_model=args.compile_model,
    save_model=args.save_model,
    display_name=display_name,
    epochs=args.epochs,
    time=time,
    mask_threshold=args.mask_threshold,
    accumulate_grad_batches=args.accumulate_grad_batches,
    seed=args.seed,
    run_directory=args.run_directory,
    project_name=args.project_name,
    entity=args.entity,
    resume_training=args.resume_training,
    model_dir=args.model_dir,
    resume_ckpt=args.resume_ckpt,
)

# Model construction settings; "model_mode" is fixed to "train" in this script.
model_config = dict(
    base_model=args.base_model,
    model_type=args.model_type,
    lora_layer=args.lora_layer,
    lora_rank=args.lora_rank,
    model_mode="train",
    model_dir=args.model_dir,
    p_tuning=args.p_tuning,
)

# Optimisation settings. Note the key is "learning_rate" while the flag is --lr.
train_config = dict(
    loss=args.loss,
    learning_rate=args.lr,
    freeze=args.freeze,
    mode=args.mode,
    nr_of_interactive_points=args.nr_of_interactive_points,
    compile_model=args.compile_model,
    mask_threshold=args.mask_threshold,
    prompt_batch_size=args.prompt_batch_size,
    prompt_type=args.prompt_type,
    batch_size=args.batch_size,
)

# Dataset selection, splitting and loader settings.
data_config = dict(
    datasets=args.datasets,
    data_directory=args.data_directory,
    cluster=args.cluster,
    image_encoder_size=args.image_encoder_size,
    batch_size=args.batch_size,
    drop_last=args.drop_last,
    num_workers=args.num_workers,
    use_holdout_testset=args.use_holdout_testset,
    test_datasets=args.test_datasets,
    data_augmentations=args.data_augmentations,
    mask_augmentation_tries=args.mask_augmentation_tries,
    threshold_connected_components=args.threshold_connected_components,
    data_split=args.data_split,
    train_split=args.train_split,
    val_split=args.val_split,
    test_split=args.test_split,
    shuffle=args.shuffle,
    seed=args.seed,
)

# Prompt settings; the "nr_of_*" keys are fed from the --nr_of_initial_* flags.
prompt_config = dict(
    prompt_batch_size=args.prompt_batch_size,
    prompt_type=args.prompt_type,
    nr_of_points=args.nr_of_initial_points,
    nr_of_positive_points=args.nr_of_initial_positive_points,
    bbox_shift=args.bbox_shift,
)

# Switches and upper bounds for randomised prompting.
random_prompt_config = dict(
    random_mode=args.random_mode,
    random_prompt_type=args.random_prompt_type,
    random_nr_of_points=args.random_nr_of_points,
    random_nr_of_interactive_points=args.random_nr_of_interactive_points,
    random_nr_of_positive_points=args.random_nr_of_positive_points,
    max_nr_of_interactive_points=args.max_nr_of_interactive_points,
    max_nr_of_points=args.max_nr_of_points,
    max_nr_of_positive_points=args.max_nr_of_positive_points,
    only_positive_points=args.only_positive_points,
)

# Single nested config; note the training section is stored under
# "training_config", not "train_config".
config = dict(
    initial_config=initial_config,
    model_config=model_config,
    training_config=train_config,
    data_config=data_config,
    prompt_config=prompt_config,
    random_prompt_config=random_prompt_config,
)
print("CONFIG: ", config)
### TRAINING
# Pull the run-level settings out of initial_config into plain locals.
compile_model = initial_config["compile_model"]
save_model = initial_config["save_model"]
seed = initial_config["seed"]
run_directory = initial_config["run_directory"]
project_name = initial_config["project_name"]
entity = initial_config["entity"]
resume_training = initial_config["resume_training"]
model_dir = initial_config["model_dir"]
resume_ckpt = initial_config["resume_ckpt"]

torch.set_float32_matmul_precision("high")
pl.seed_everything(seed, workers=True)

model = SamHI(config)
if compile_model:
    model = torch.compile(model)

# The two logger variants differed only in `log_model`: "all" when
# --save_model is set, otherwise False.
wandb_logger = WandbLogger(
    log_model="all" if save_model else False,
    project=project_name,
    entity=entity,
    name=display_name,
    config=config,
    save_dir=run_directory,
)

trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator="gpu",
    devices=1,
    deterministic=False,
    fast_dev_run=False,
    max_epochs=initial_config["epochs"],
    accumulate_grad_batches=initial_config["accumulate_grad_batches"],
)

train_loader, val_loader, test_loader = prepare_data(data_config)

if resume_training:
    # Download the checkpoint artifact into <model_dir>/<resume_ckpt>.
    model_path = os.path.join(model_dir, resume_ckpt)
    os.makedirs(model_path, exist_ok=True)
    api = wandb.Api()
    # Build the artifact path from --entity/--project_name so resuming looks in
    # the same W&B project the logger writes to. It used to be hard-coded to
    # "philippresearch/SAMHI/", which is what the default flags still produce.
    artifact = api.artifact(f"{entity}/{project_name}/{resume_ckpt}", type="model")
    artifact_dir = artifact.download(root=model_path)
    # The artifact is expected to contain a file named "model.ckpt".
    trainer.fit(model, train_loader, val_loader, ckpt_path=os.path.join(artifact_dir, "model.ckpt"))
else:
    trainer.fit(model, train_loader, val_loader)
### TESTING
# Re-seed with the same seed that was used before training.
pl.seed_everything(seed, workers=True)

# Evaluation uses fixed prompting: every random_* switch is off and only
# positive points are used.
# NOTE(review): unlike the training-time dict, this one carries no "max_nr_of_*"
# keys; confirm the model does not read them during testing.
random_prompt_config = dict(
    random_mode=False,
    random_prompt_type=False,
    random_nr_of_points=False,
    random_nr_of_interactive_points=False,
    random_nr_of_positive_points=False,
    only_positive_points=True,
)
model.random_prompt_config = random_prompt_config

# One point and one box per prompt, prompt batch size 1.
prompt_config = dict(
    prompt_batch_size=1,
    nr_of_points=1,
    prompt_type="both",
    nr_of_positive_points=1,
    bbox_shift=10,
)
model.prompt_config = prompt_config

# The same values are also set as plain attributes on the model.
# NOTE(review): when --compile_model is set, `model` is the object returned by
# torch.compile; confirm these assignments reach the underlying module.
model.prompt_type = "both"
model.nr_of_interactive_points = 0
model.prompt_batch_size = 1

trainer.test(model, dataloaders=test_loader)