PanCancerSeg-Specialist-Inference / trainers /nnUNetTrainerWandb2000.py
WorkTimer's picture
Add PanCancerSeg Gradio inference app
ce1057b verified
Raw
History Blame Contribute Delete
2.86 kB
"""nnUNet trainer with Weights & Biases logging and 2000 epochs.

Place this file (or symlink it) into the nnUNet trainer variants directory
so that nnUNet can discover it via the -tr flag:

    VARIANTS_DIR=$(python -c "import nnunetv2; print(nnunetv2.__path__[0])")/training/nnUNetTrainer/variants
    ln -sf $(realpath nnUNetTrainerWandb2000.py) "$VARIANTS_DIR/nnUNetTrainerWandb2000.py"

Then train with:

    nnUNetv2_train DATASET CONFIG FOLD -tr nnUNetTrainerWandb2000 ...
"""
import os
import torch
import wandb
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
class nnUNetTrainerWandb2000(nnUNetTrainer):
    """nnU-Net trainer that trains for 2000 epochs and logs to Weights & Biases.

    The trainer hooks into three lifecycle callbacks of the base trainer:

    * ``on_train_start`` opens a W&B run,
    * ``on_epoch_end`` saves extra periodic checkpoints and logs the latest
      metrics,
    * ``on_train_end`` closes the W&B run.

    All W&B calls are restricted to the main process so that a multi-process
    (DDP) training does not create one run per rank.
    """

    # (key in the nnU-Net logger dict, metric name sent to W&B)
    _SCALAR_METRICS = (
        ("train_losses", "train_loss"),
        ("val_losses", "val_loss"),
        ("ema_fg_dice", "ema_fg_dice"),
    )

    def __init__(self, plans: dict, configuration: str, fold: int,
                 dataset_json: dict,
                 device: torch.device = torch.device("cuda")):
        """Build the base trainer, then override the epoch budget to 2000."""
        super().__init__(plans, configuration, fold, dataset_json, device)
        self.num_epochs = 2000

    def _should_log_to_wandb(self) -> bool:
        """Return True only on the process that owns the W&B run.

        NOTE(review): upstream nnU-Net v2 is expected to set ``local_rank`` to
        0 for single-process runs and to the distributed rank under DDP --
        confirm against the installed version. The attribute is read with a
        default of 0 so a missing attribute keeps logging enabled.
        """
        return getattr(self, "local_rank", 0) == 0

    def on_train_start(self):
        """Run the base start-up, then open the W&B run on the main process.

        The project name is taken from the ``WANDB_PROJECT`` environment
        variable and falls back to ``"CVPR2026-PanCancerSeg"``.
        """
        super().on_train_start()
        if not self._should_log_to_wandb():
            return

        dataset_name = self.plans_manager.dataset_name
        wandb.init(
            project=os.environ.get("WANDB_PROJECT", "CVPR2026-PanCancerSeg"),
            name=f"{dataset_name}_fold{self.fold}",
            config={
                "dataset": dataset_name,
                "configuration": self.configuration_name,
                "fold": self.fold,
                "num_epochs": self.num_epochs,
                "batch_size": self.batch_size,
                "patch_size": list(self.configuration_manager.patch_size),
            },
            resume="allow",
        )

    def on_epoch_end(self):
        """Finish the epoch, save periodic checkpoints and log metrics to W&B.

        NOTE(review): upstream nnU-Net v2 appears to increment
        ``current_epoch`` at the end of the base ``on_epoch_end``; if so, the
        value used below is the number of completed epochs -- confirm against
        the installed version.
        """
        super().on_epoch_end()

        # Save periodic checkpoint every 200 epochs after epoch 1000.
        # Called on every rank, as before; NOTE(review): the base
        # save_checkpoint is presumed to write only on rank 0 -- confirm.
        if self.current_epoch > 1000 and self.current_epoch % 200 == 0:
            self.save_checkpoint(
                os.path.join(
                    self.output_folder,
                    f"checkpoint_epoch{self.current_epoch}.pth",
                )
            )

        if not self._should_log_to_wandb():
            return

        logs = self.logger.my_fantastic_logging
        metrics = {"epoch": self.current_epoch}

        # Forward the most recent value of each scalar series that has data.
        for log_key, metric_name in self._SCALAR_METRICS:
            series = logs.get(log_key)
            if series:
                metrics[metric_name] = series[-1]

        # One metric per entry of the latest per-class / per-region Dice list.
        dice_history = logs.get("dice_per_class_or_region")
        if dice_history:
            for i, dice in enumerate(dice_history[-1]):
                metrics[f"dice_class_{i}"] = dice

        metrics["learning_rate"] = self.optimizer.param_groups[0]["lr"]

        # Only report the epoch duration when both timestamps were recorded;
        # a missing or empty series would otherwise yield a bogus value or
        # raise IndexError.
        epoch_starts = logs.get("epoch_start_timestamps")
        epoch_ends = logs.get("epoch_end_timestamps")
        if epoch_starts and epoch_ends:
            metrics["epoch_time"] = epoch_ends[-1] - epoch_starts[-1]

        wandb.log(metrics, step=self.current_epoch)

    def on_train_end(self):
        """Run the base tear-down, then close the W&B run on the main process."""
        super().on_train_end()
        if self._should_log_to_wandb():
            wandb.finish()