Spaces:
No application file
No application file
| import argparse | |
| import wandb | |
| import datetime | |
| import os | |
| from samhi.modeling.model import SamHI | |
| from samhi.data_processing.dataset import prepare_data | |
| import lightning.pytorch as pl | |
| from lightning.pytorch.loggers import WandbLogger | |
| import torch | |
# Timestamp of script start; used below as the default run name and stored in the config.
# NOTE(review): the name `time` would shadow the stdlib module if it were ever imported;
# it is kept because later parts of the script read it.
time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Spellings accepted for boolean command-line options (compared case-insensitively).
_TRUE_STRINGS = frozenset({"true", "t", "yes", "y", "on", "1"})
# The empty string stays falsy, as it was with the former ``type=bool``.
_FALSE_STRINGS = frozenset({"false", "f", "no", "n", "off", "0", ""})


def _str2bool(value):
    """Convert a command-line token such as ``True``, ``false``, ``1`` or ``no`` to a bool.

    ``type=bool`` calls ``bool()`` on the raw string, so every non-empty value
    (including ``"False"``) became ``True``. This converter interprets the text.

    Args:
        value: The raw option value (a string, or an actual bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in _TRUE_STRINGS:
        return True
    if text in _FALSE_STRINGS:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Define arguments
parser = argparse.ArgumentParser()
parser.add_argument("--cluster", type=str, choices=["denbi", "helmholtz", "custom"], default="denbi")
parser.add_argument("--num_workers", type=int, default=5)

# W&B parameters
parser.add_argument("--project_name", type=str, default="SAMHI")
parser.add_argument("--entity", type=str, default="philippresearch")

# Model Info
parser.add_argument("--model_type", type=str, choices=["vit_b", "vit_l", "vit_h"], default="vit_b")
parser.add_argument("--model_dir", type=str, default="/vol/data/models/")
parser.add_argument("--base_model", type=str, default="sam_vit_b_01ec64.pth")
parser.add_argument(
    "--loss",
    type=str,
    choices=["diceCE", "diceFocal", "dice", "generalized_dice", "generalized_diceFocal", "tversky"],
    default="diceCE",
)
parser.add_argument("--compile_model", type=_str2bool, default=False)
parser.add_argument("--save_model", type=_str2bool, default=False)
parser.add_argument("--resume_training", type=_str2bool, default=False)
parser.add_argument("--resume_ckpt", type=str, default="model-czpfnbsk:v0")

# Finetuning approach
parser.add_argument("--freeze", nargs='+', choices=["prompt", "mask", "image", ""], default=["prompt", "image"])
parser.add_argument("--image_encoder_size", type=int, default=1024)
parser.add_argument("--p_tuning", type=_str2bool, default=False)

# Dataset type and location
datasets = ["BCSS", "CAMELYON", "CellSeg", "CoCaHis", "CoNIC", "CPM", "CRAG", "CryoNuSeg", "GlaS", "ICIA2018",
            "Janowczyk", "KPI", "Kumar", "MoNuSAC", "MoNuSeg", "NuClick", "PAIP2023", "PanNuke", "SegPath", "SegPC",
            "TIGER", "TNBC", "WSSS4LUAD"]
parser.add_argument("--datasets", nargs='+', choices=datasets, default=["Janowczyk"])
parser.add_argument("--use_holdout_testset", type=_str2bool, default=False)
parser.add_argument("--test_datasets", nargs='+', choices=datasets, default=["CPM"])
parser.add_argument("--data_directory", type=str, default="/vol/data/histo_datasets/")

# NOTE(review): this list is not passed as `choices` to --data_augmentations below, and
# "RandomResizedCrop" appears twice. It is left unchanged so that augmentation names
# accepted today keep working; confirm whether validation was intended.
augmentations = ["AdvancedBlur", "Blur", "GaussianBlur", "ZoomBlur", "CLAHE", "Emboss", "GaussNoise", "IsoNoise",
                 "ImageCompression", "Posterize", "RingingOvershoot", "Sharpen", "ToGray", "Downscale",
                 "ChannelShuffle", "ChromaticAberration", "ColorJitter", "HueSaturationValue", "MultiplicativeNoise",
                 "PlanckianJitter", "RGBShift", "RandomBrightnessContrast", "RandomGamma", "RandomToneCurve",
                 "FancyPCA",
                 "Affine", "CropNonEmptyMaskIfExists", "ElasticTransform", "GridDistortion", "OpticalDistortion",
                 "RandomCrop", "RandomGridShuffle", "RandomResizedCrop", "RandomResizedCrop", "Rotate",
                 "ShiftScaleRotate", "CropAndPad", "D4", "PadIfNeeded", "Perspective", "RandomScale",
                 "NoOp"]
parser.add_argument("--data_augmentations", nargs='+', default=["NoOp"])
parser.add_argument("--mask_augmentation_tries", type=int, default=5)
parser.add_argument("--threshold_connected_components", type=int, default=2)

# Training parameters
parser.add_argument("--lr", type=float, default=1e-5)
parser.add_argument("--epochs", type=int, default=10)
parser.add_argument("--batch_size", type=int, default=4)
parser.add_argument("--accumulate_grad_batches", type=int, default=1)
parser.add_argument("--shuffle", type=_str2bool, default=False)
parser.add_argument('--seed', type=int, default=1)
parser.add_argument("--mask_threshold", type=float, default=0.0)
parser.add_argument("--bbox_shift", type=int, default=10)
parser.add_argument("--run_directory", type=str, default="/vol/data/runs/")

# Prompt parameters
parser.add_argument("--random_prompt_type", type=_str2bool, default=False)
parser.add_argument("--prompt_type", type=str, choices=["points", "boxes", "both"], default="both")
# Parsed as int so that a value given on the command line has the same type as the default.
parser.add_argument("--prompt_batch_size", type=int, default=8)

# Data split
parser.add_argument("--data_split", type=_str2bool, default=True)
parser.add_argument("--train_split", type=float, default=0.6)
parser.add_argument("--val_split", type=float, default=0.2)
parser.add_argument("--test_split", type=float, default=0.2)

# DataLoader parameters
parser.add_argument("--drop_last", type=_str2bool, default=True)

# LoRA parameters
parser.add_argument("--lora_rank", type=int, default=1)
parser.add_argument("--lora_layer", nargs='+', default=None)  # ["-1"]

# Interactive parameters
parser.add_argument("--random_mode", type=_str2bool, default=False)
parser.add_argument("--mode", type=str, choices=["random", "interactive"], default="interactive")
parser.add_argument("--random_nr_of_interactive_points", type=_str2bool, default=False)
parser.add_argument("--max_nr_of_interactive_points", type=int, default=1)
parser.add_argument("--nr_of_interactive_points", type=int, default=0)

# Point parameters
parser.add_argument("--random_nr_of_points", type=_str2bool, default=False)
parser.add_argument("--max_nr_of_points", type=int, default=1)
parser.add_argument("--nr_of_initial_points", type=int, default=1)

# Positive point parameters
parser.add_argument("--only_positive_points", type=_str2bool, default=False)
parser.add_argument("--random_nr_of_positive_points", type=_str2bool, default=False)
parser.add_argument("--max_nr_of_positive_points", type=int, default=1)
parser.add_argument("--nr_of_initial_positive_points", type=int, default=1)

parser.add_argument("--display_name", type=str, default="")
# Parse the command line once; `cli_values` is the same data as a plain mapping.
args = parser.parse_args()
cli_values = vars(args)


def _take(*names):
    """Return ``{name: parsed CLI value}`` for the given option names, in the given order."""
    return {name: cli_values[name] for name in names}


# Choose display_name: fall back to the start timestamp when none was supplied.
display_name = args.display_name if args.display_name != "" else f"{time}"

# Run-level settings (logging, checkpointing, bookkeeping).
initial_config = {
    **_take("compile_model", "save_model"),
    "display_name": display_name,
    "epochs": args.epochs,
    "time": time,
    **_take(
        "mask_threshold",
        "accumulate_grad_batches",
        "seed",
        "run_directory",
        "project_name",
        "entity",
        "resume_training",
        "model_dir",
        "resume_ckpt",
    ),
}

# Model settings; "model_mode" is fixed to "train" by this script.
model_config = {
    **_take("base_model", "model_type", "lora_layer", "lora_rank"),
    "model_mode": "train",
    **_take("model_dir", "p_tuning"),
}

# Optimisation settings; the --lr option is stored under "learning_rate".
train_config = {
    "loss": args.loss,
    "learning_rate": args.lr,
    **_take(
        "freeze",
        "mode",
        "nr_of_interactive_points",
        "compile_model",
        "mask_threshold",
        "prompt_batch_size",
        "prompt_type",
        "batch_size",
    ),
}

# Dataset and data-loader settings, later handed to prepare_data().
data_config = _take(
    "datasets",
    "data_directory",
    "cluster",
    "image_encoder_size",
    "batch_size",
    "drop_last",
    "num_workers",
    "use_holdout_testset",
    "test_datasets",
    "data_augmentations",
    "mask_augmentation_tries",
    "threshold_connected_components",
    "data_split",
    "train_split",
    "val_split",
    "test_split",
    "shuffle",
    "seed",
)

# Prompt settings; the "initial" point counts are stored under shorter keys.
prompt_config = {
    **_take("prompt_batch_size", "prompt_type"),
    "nr_of_points": args.nr_of_initial_points,
    "nr_of_positive_points": args.nr_of_initial_positive_points,
    "bbox_shift": args.bbox_shift,
}

# Switches and upper bounds for the randomised prompt options.
random_prompt_config = _take(
    "random_mode",
    "random_prompt_type",
    "random_nr_of_points",
    "random_nr_of_interactive_points",
    "random_nr_of_positive_points",
    "max_nr_of_interactive_points",
    "max_nr_of_points",
    "max_nr_of_positive_points",
    "only_positive_points",
)

# Everything bundled into the single nested dict passed to the model and the logger.
config = dict(
    initial_config=initial_config,
    model_config=model_config,
    training_config=train_config,
    data_config=data_config,
    prompt_config=prompt_config,
    random_prompt_config=random_prompt_config,
)
print("CONFIG: ", config)
### TRAINING
# Pull the run-level settings back out of the assembled configuration.
(
    compile_model,
    save_model,
    seed,
    run_directory,
    project_name,
    entity,
    resume_training,
    model_dir,
    resume_ckpt,
) = (
    initial_config[key]
    for key in (
        "compile_model",
        "save_model",
        "seed",
        "run_directory",
        "project_name",
        "entity",
        "resume_training",
        "model_dir",
        "resume_ckpt",
    )
)

torch.set_float32_matmul_precision("high")
# Seed before the model is constructed.
pl.seed_everything(seed, workers=True)

model = SamHI(config)
if compile_model:
    model = torch.compile(model)

# Checkpoints are logged to W&B ("all") only when --save_model is set.
wandb_logger = WandbLogger(
    log_model="all" if save_model else False,
    project=project_name,
    entity=entity,
    name=display_name,
    config=config,
    save_dir=run_directory,
)

# The script is fixed to a single GPU device.
trainer_options = dict(
    logger=wandb_logger,
    accelerator="gpu",
    devices=1,
    deterministic=False,
    fast_dev_run=False,
    max_epochs=initial_config["epochs"],
    accumulate_grad_batches=initial_config["accumulate_grad_batches"],
)
trainer = pl.Trainer(**trainer_options)

train_loader, val_loader, test_loader = prepare_data(data_config)
if resume_training:
    # Resume from a checkpoint stored as a W&B model artifact.
    model_path = os.path.join(model_dir, resume_ckpt)
    os.makedirs(model_path, exist_ok=True)
    api = wandb.Api()
    # Build the artifact name from the configured --entity / --project_name instead of
    # a hard-coded "philippresearch/SAMHI/" prefix; with the defaults it is the same name.
    artifact = api.artifact(f"{entity}/{project_name}/{resume_ckpt}", type="model")
    # download() returns the directory the files were written to.
    artifact_dir = artifact.download(root=model_path)
    # The artifact is expected to contain a file named "model.ckpt".
    trainer.fit(
        model,
        train_loader,
        val_loader,
        ckpt_path=os.path.join(artifact_dir, "model.ckpt"),
    )
else:
    # Fresh training run.
    trainer.fit(model, train_loader, val_loader)
# Re-seed before testing.
pl.seed_everything(seed, workers=True)

# Switch every randomised prompt option off for testing and use positive points only.
# NOTE(review): unlike the training-time dict, this one has no "max_nr_of_*" keys;
# confirm the model does not read them during testing.
random_prompt_config = dict.fromkeys(
    (
        "random_mode",
        "random_prompt_type",
        "random_nr_of_points",
        "random_nr_of_interactive_points",
        "random_nr_of_positive_points",
    ),
    False,
)
random_prompt_config["only_positive_points"] = True
model.random_prompt_config = random_prompt_config

# Fixed prompt settings for testing, independent of the command-line options.
prompt_config = dict(
    prompt_batch_size=1,
    nr_of_points=1,
    prompt_type="both",
    nr_of_positive_points=1,
    bbox_shift=10,
)
model.prompt_config = prompt_config

# Set the matching attributes directly on the model, in the original order.
model.prompt_type = "both"
model.nr_of_interactive_points = 0
model.prompt_batch_size = 1

trainer.test(model, dataloaders=test_loader)