phoebehxf
init
aff3c6f
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""
import os, time
import numpy as np
from tqdm import tqdm
from cellpose import utils, models, io, train
from .version import version_str
from cellpose.cli import get_arg_parser
# The GUI is an optional component: try to import it and record the outcome in
# module-level flags that main() reads when it is asked to launch the GUI.
try:
    from cellpose.gui import gui3d, gui
    GUI_ENABLED = True
except ImportError as err:
    # A failed import is tolerated. The error is kept so main() can print it,
    # and GUI_IMPORT = True makes main() also print the install hint.
    GUI_ERROR = err
    GUI_ENABLED = False
    GUI_IMPORT = True
except Exception as err:
    # Any other failure is recorded in the same flags and then re-raised, so
    # importing this module fails with the original exception.
    GUI_ENABLED = False
    GUI_ERROR = err
    GUI_IMPORT = False
    raise
import logging
def _warn_deprecated_args(args, logger):
    """Log one warning for every deprecated CP3-era argument that was set.

    Args:
        args: parsed argument namespace; only the deprecated attributes listed
            below are read.
        logger: logger that receives the warnings.
    """
    unused = "the '--%s' flag is deprecated in v4.0.1+ and no longer used"
    # Ordered (value, message) pairs so warnings appear in a stable order and
    # each deprecated flag is reported exactly once.
    notices = [
        (args.pretrained_model_ortho, unused % "pretrained_model_ortho"),
        (args.train_size, unused % "train_size"),
        (args.chan or args.chan2,
         "--chan and --chan2 are deprecated, all channels are used by default"),
        (args.all_channels, unused % "all_channels"),
        (args.restore_type, unused % "restore_type"),
        (args.transformer, unused % "transformer"),
        (args.invert, unused % "invert"),
        (args.chan2_restore, unused % "chan2_restore"),
        (args.diam_mean, unused % "diam_mean"),
    ]
    for is_set, message in notices:
        if is_set:
            logger.warning(message)


def main():
    """Run cellpose from the command line.

    Parses the arguments from ``get_arg_parser()`` and then does exactly one
    of the following:

    * prints the version string when ``args.version`` is set;
    * when neither ``args.dir`` nor ``args.image_path`` is given: registers a
      model via ``io.add_model`` if ``args.add_model`` is set, otherwise
      launches the GUI (3D GUI when ``args.Zstack`` is set) or prints why the
      GUI is unavailable;
    * trains a model when ``args.train`` is set;
    * otherwise evaluates the model on the requested images.

    Raises:
        ValueError: if ``--save_each`` is given without ``--save_every``, or
            if training is requested together with a single image input.
    """
    args = get_arg_parser().parse_args()  # this has to be in a separate file for autodoc to work

    if args.version:
        print(version_str)
        return

    ######## if no image arguments are provided, run GUI or add model and exit ########
    if len(args.dir) == 0 and len(args.image_path) == 0:
        if args.add_model:
            io.add_model(args.add_model)
        elif not GUI_ENABLED:
            print("GUI ERROR: %s" % GUI_ERROR)
            if GUI_IMPORT:
                # The GUI import failed with ImportError: point at the extra.
                print(
                    "GUI FAILED: GUI dependencies may not be installed, to install, run"
                )
                print(" pip install 'cellpose[gui]'")
        elif args.Zstack:
            gui3d.run()
        else:
            gui.run()
        return

    ############################## run cellpose on images ##############################
    if args.verbose:
        from .io import logger_setup
        logger, _ = logger_setup()
    else:
        print(
            ">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
        print("No --verbose => no progress or info printed")
        logger = logging.getLogger(__name__)

    # find images
    image_filter = args.img_filter if len(args.img_filter) > 0 else None

    device, _ = models.assign_device(use_torch=True, gpu=args.use_gpu,
                                     device=args.gpu_device)

    # These values all mean "no pretrained model"; that is no longer
    # supported, so fall back to the default model and tell the user.
    if args.pretrained_model in (None, "None", "False", "0"):
        pretrained_model = "cpsam"
        logger.warning("training from scratch is disabled, using 'cpsam' model")
    else:
        pretrained_model = args.pretrained_model

    # Warn users about old arguments from CP3:
    _warn_deprecated_args(args, logger)

    if args.norm_percentile is not None:
        lower, upper = args.norm_percentile
        normalize = {'percentile': (float(lower), float(upper))}
    else:
        normalize = (not args.no_norm)

    if args.save_each and not args.save_every:
        raise ValueError("ERROR: --save_each requires --save_every")

    if len(args.image_path) > 0 and args.train:
        raise ValueError("ERROR: cannot train model with single image input")

    if args.train:
        ## Train a model ##
        _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
    else:
        ## Run evaluation on images
        _evaluate_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize):
    """Train a segmentation model from parsed command-line arguments.

    Training data comes from one of two places:

    * ``args.file_list`` (a ``.npy`` file holding a pickled dict with the key
      ``"train_files"`` and the optional keys ``"test_files"``,
      ``"train_probs"``, ``"test_probs"`` and ``"compute_flows"``); or
    * ``io.load_train_test_data`` on ``args.dir`` / ``args.test_dir`` when no
      file list is given.

    Args:
        args: parsed argument namespace.
        logger: logger used for status and error messages.
        image_filter: image filter passed to ``io.load_train_test_data``.
        device: device passed to ``models.CellposeModel``.
        pretrained_model: model passed to ``models.CellposeModel``.
        normalize: normalization setting forwarded to ``train.train_seg``.

    Returns:
        The ``models.CellposeModel`` instance, with ``pretrained_model`` set
        to the first element returned by ``train.train_seg``.

    Raises:
        FileNotFoundError: if ``args.file_list`` is given but does not exist.
    """
    test_dir = None if len(args.test_dir) == 0 else args.test_dir
    images, labels, image_names, train_probs = None, None, None, None
    test_images, test_labels, image_names_test, test_probs = None, None, None, None
    compute_flows = False

    if len(args.file_list) > 0:
        if not os.path.exists(args.file_list):
            # Fail here with a clear error; continuing would only crash later
            # on unset training inputs.
            message = f"ERROR: {args.file_list} does not exist"
            logger.critical(message)
            raise FileNotFoundError(message)
        # NOTE(security): allow_pickle=True unpickles arbitrary objects, so the
        # file list must come from a trusted source.
        dat = np.load(args.file_list, allow_pickle=True).item()
        image_names = dat["train_files"]
        image_names_test = dat.get("test_files", None)
        train_probs = dat.get("train_probs", None)
        test_probs = dat.get("test_probs", None)
        compute_flows = dat.get("compute_flows", False)
        load_files = False
    else:
        output = io.load_train_test_data(args.dir, test_dir, image_filter,
                                         args.mask_filter,
                                         args.look_one_level_down)
        images, labels, image_names, test_images, test_labels, image_names_test = output
        load_files = True

    # initialize model
    model = models.CellposeModel(device=device, pretrained_model=pretrained_model)

    # train segmentation model
    cpmodel_path = train.train_seg(
        model.net, images, labels, train_files=image_names,
        test_data=test_images, test_labels=test_labels,
        test_files=image_names_test, train_probs=train_probs,
        test_probs=test_probs, compute_flows=compute_flows,
        load_files=load_files, normalize=normalize,
        channel_axis=args.channel_axis,
        learning_rate=args.learning_rate, weight_decay=args.weight_decay,
        SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.train_batch_size,
        min_train_masks=args.min_train_masks,
        nimg_per_epoch=args.nimg_per_epoch,
        nimg_test_per_epoch=args.nimg_test_per_epoch,
        save_path=os.path.realpath(args.dir),
        save_every=args.save_every,
        save_each=args.save_each,
        model_name=args.model_name_out)[0]
    model.pretrained_model = cpmodel_path
    logger.info(">>>> model trained and saved to %s" % cpmodel_path)
    return model
def _mask_suffix(args, image_name):
    """Return the filename suffix to use when saving masks for *image_name*.

    The default is ``"_cp_masks"``. ``args.output_name`` replaces it when
    (1) no ``args.savedir`` is set and the name is non-empty, or
    (2) ``args.savedir`` is set and differs from the input directory, in which
    case the name may also be the empty string.
    """
    suffix = "_cp_masks"
    if args.output_name is None:
        return suffix
    if args.savedir is None:
        if len(args.output_name) > 0:
            suffix = args.output_name
        return suffix
    # With a single-image input args.dir is empty, so compare against the
    # directory that actually contains the image instead.
    if len(args.dir) > 0:
        input_dir = args.dir
    else:
        input_dir = os.path.dirname(os.path.abspath(image_name))
    if not os.path.samefile(args.savedir, input_dir):
        suffix = args.output_name
    return suffix


def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, normalize):
    """Run a model on the requested images and save the requested outputs.

    Images are taken from ``args.dir`` (via ``io.get_image_files``) when it is
    non-empty, otherwise from the single path ``args.image_path``. Each image
    is read as a 3D stack when ``args.do_3D`` is set or
    ``args.stitch_threshold`` is positive, and as 2D otherwise.

    Args:
        args: parsed argument namespace.
        logger: logger used for progress and status messages.
        imf: image filter passed to ``io.get_image_files``.
        device: device passed to ``models.CellposeModel``.
        pretrained_model: model passed to ``models.CellposeModel``.
        normalize: normalization setting forwarded to ``model.eval``.

    Returns:
        The ``models.CellposeModel`` instance used for evaluation.

    Raises:
        ValueError: if ``args.image_path`` is used and no file exists there.
        FileExistsError: if ``args.savedir`` is set but does not exist.
    """
    # Masks are only written through io.save_masks when one of these is set.
    saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt

    tic = time.time()
    if len(args.dir) > 0:
        image_names = io.get_image_files(
            args.dir, args.mask_filter, imf=imf,
            look_one_level_down=args.look_one_level_down)
    elif os.path.exists(args.image_path):
        image_names = [args.image_path]
    else:
        raise ValueError(f"ERROR: no file found at {args.image_path}")
    nimg = len(image_names)

    if args.savedir and not os.path.exists(args.savedir):
        # NOTE(review): FileNotFoundError would describe this better; the
        # exception type is kept so existing callers keep working.
        raise FileExistsError(f"--savedir {args.savedir} does not exist")

    logger.info(
        ">>>> running cellpose on %d images using all channels" % nimg)

    # handle built-in model exceptions
    model = models.CellposeModel(device=device, pretrained_model=pretrained_model,)

    tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)

    # The 2D/3D choice does not depend on the image, so settle it (and the
    # default axes for 3D stacks) once before the loop.
    load_3d = args.do_3D or args.stitch_threshold > 0.
    channel_axis = args.channel_axis
    z_axis = args.z_axis
    if load_3d:
        if channel_axis is None:
            channel_axis = 3
        if z_axis is None:
            z_axis = 0

    for image_name in tqdm(image_names, file=tqdm_out):
        if load_3d:
            logger.info('loading image as 3D zstack')
            image = io.imread_3D(image_name)
        else:
            image = io.imread_2D(image_name)

        out = model.eval(
            image,
            diameter=args.diameter,
            do_3D=args.do_3D,
            augment=args.augment,
            flow_threshold=args.flow_threshold,
            cellprob_threshold=args.cellprob_threshold,
            stitch_threshold=args.stitch_threshold,
            min_size=args.min_size,
            batch_size=args.batch_size,
            bsize=args.bsize,
            resample=not args.no_resample,
            normalize=normalize,
            channel_axis=channel_axis,
            z_axis=z_axis,
            anisotropy=args.anisotropy,
            niter=args.niter,
            flow3D_smooth=args.flow3D_smooth)
        masks, flows = out[:2]

        if args.exclude_on_edges:
            masks = utils.remove_edge_masks(masks)

        if not args.no_npy:
            io.masks_flows_to_seg(image, masks, flows, image_name,
                                  imgs_restore=None,
                                  restore_type=None,
                                  ratio=1.)

        if saving_something:
            io.save_masks(image, masks, flows, image_name,
                          suffix=_mask_suffix(args, image_name),
                          png=args.save_png,
                          tif=args.save_tif, save_flows=args.save_flows,
                          save_outlines=args.save_outlines,
                          dir_above=args.dir_above, savedir=args.savedir,
                          save_txt=args.save_txt, in_folders=args.in_folders,
                          save_mpl=args.save_mpl)

        if args.save_rois:
            io.save_rois(masks, image_name)

    logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))
    return model
# Run the command-line entry point when this module is executed as a script.
if __name__ == "__main__":
    main()