Spaces:
Sleeping
Sleeping
| """ | |
| Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu. | |
| """ | |
| import os, time | |
| import numpy as np | |
| from tqdm import tqdm | |
| from cellpose import utils, models, io, train | |
| from .version import version_str | |
| from cellpose.cli import get_arg_parser | |
# Guarded import of the GUI modules. The outcome is recorded in module-level
# flags that main() consults before trying to launch a GUI.
try:
    from cellpose.gui import gui3d, gui
    GUI_ENABLED = True
except ImportError as err:
    # The import failed with ImportError: remember the error instead of
    # raising, so the rest of this module still loads. GUI_IMPORT = True makes
    # main() print the "pip install 'cellpose[gui]'" hint.
    GUI_ERROR = err
    GUI_ENABLED = False
    GUI_IMPORT = True
except Exception as err:
    # Any other failure is recorded in the same flags and then re-raised, so
    # importing this module fails with the original exception.
    GUI_ENABLED = False
    GUI_ERROR = err
    GUI_IMPORT = False
    raise
| import logging | |
def main():
    """Run cellpose from the command line.

    Parses the command-line arguments and then does exactly one of:

    * prints the version string (``--version``);
    * registers a model (``--add_model``) or launches the 2D/3D GUI, when
      neither ``--dir`` nor ``--image_path`` is given;
    * trains a model (``--train``);
    * runs segmentation on the requested images (default).

    Raises:
        ValueError: if ``--save_each`` is given without ``--save_every``, or
            if ``--train`` is combined with ``--image_path``.
    """
    args = get_arg_parser().parse_args()  # this has to be in a separate file for autodoc to work

    if args.version:
        print(version_str)
        return

    ######## if no image arguments are provided, run GUI or add model and exit ########
    if len(args.dir) == 0 and len(args.image_path) == 0:
        if args.add_model:
            io.add_model(args.add_model)
        elif not GUI_ENABLED:
            print("GUI ERROR: %s" % GUI_ERROR)
            if GUI_IMPORT:
                print(
                    "GUI FAILED: GUI dependencies may not be installed, to install, run"
                )
                print(" pip install 'cellpose[gui]'")
        elif args.Zstack:
            gui3d.run()
        else:
            gui.run()
        return

    ############################## run cellpose on images ##############################
    if args.verbose:
        from .io import logger_setup
        logger, _ = logger_setup()
    else:
        print(
            ">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
        print("No --verbose => no progress or info printed")
        logger = logging.getLogger(__name__)

    # find images
    image_filter = args.img_filter if len(args.img_filter) > 0 else None

    device, _ = models.assign_device(use_torch=True, gpu=args.use_gpu,
                                     device=args.gpu_device)

    # Any "no model" spelling falls back to the 'cpsam' model.
    if args.pretrained_model is None or args.pretrained_model in ("None", "False", "0"):
        pretrained_model = "cpsam"
        logger.warning("training from scratch is disabled, using 'cpsam' model")
    else:
        pretrained_model = args.pretrained_model

    # Warn users about old arguments from CP3. Each entry is
    # (flag was supplied, message); every flag is reported at most once.
    unused = "the '--%s' flag is deprecated in v4.0.1+ and no longer used"
    deprecated = [
        (args.pretrained_model_ortho, unused % "pretrained_model_ortho"),
        (args.train_size, unused % "train_size"),
        (args.chan or args.chan2,
         '--chan and --chan2 are deprecated, all channels are used by default'),
        (args.all_channels, unused % "all_channels"),
        (args.restore_type, unused % "restore_type"),
        (args.transformer, unused % "transformer"),
        (args.invert, unused % "invert"),
        (args.chan2_restore, unused % "chan2_restore"),
        (args.diam_mean, unused % "diam_mean"),
    ]
    for was_supplied, message in deprecated:
        if was_supplied:
            logger.warning(message)

    # `normalize` is either a dict carrying explicit percentiles or a bool.
    if args.norm_percentile is not None:
        value1, value2 = args.norm_percentile
        normalize = {'percentile': (float(value1), float(value2))}
    else:
        normalize = (not args.no_norm)

    if args.save_each and not args.save_every:
        raise ValueError("ERROR: --save_each requires --save_every")

    if len(args.image_path) > 0 and args.train:
        raise ValueError("ERROR: cannot train model with single image input")

    if args.train:
        ## Train a model ##
        _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
    else:
        ## Run evaluation on images
        _evaluate_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize):
    """Train a segmentation model from parsed command-line arguments.

    The training set comes from one of two places:

    * ``args.file_list`` - a ``.npy`` file holding a dict with the key
      ``"train_files"`` and the optional keys ``"test_files"``,
      ``"train_probs"``, ``"test_probs"`` and ``"compute_flows"``;
    * otherwise ``args.dir`` (and ``args.test_dir``), read through
      ``io.load_train_test_data``.

    Args:
        args: parsed command-line namespace.
        logger: logger used for progress and error messages.
        image_filter: image filter forwarded to ``io.load_train_test_data``.
        device: device forwarded to ``models.CellposeModel``.
        pretrained_model: model used to initialise the network.
        normalize: normalization setting forwarded to ``train.train_seg``.

    Returns:
        The ``models.CellposeModel`` whose network was trained, with
        ``pretrained_model`` set to the path of the saved weights.

    Raises:
        FileNotFoundError: if ``args.file_list`` is given but does not exist.
    """
    test_dir = None if len(args.test_dir) == 0 else args.test_dir
    images, labels, image_names, train_probs = None, None, None, None
    test_images, test_labels, image_names_test, test_probs = None, None, None, None
    compute_flows = False

    if len(args.file_list) > 0:
        if not os.path.exists(args.file_list):
            # Stop here: continuing without a file list would only fail later,
            # after the model has been built, with an unrelated error.
            message = f"ERROR: {args.file_list} does not exist"
            logger.critical(message)
            raise FileNotFoundError(message)
        # NOTE(security): allow_pickle=True unpickles the file's contents, which
        # can execute arbitrary code - only load file lists from trusted sources.
        dat = np.load(args.file_list, allow_pickle=True).item()
        image_names = dat["train_files"]
        image_names_test = dat.get("test_files", None)
        train_probs = dat.get("train_probs", None)
        test_probs = dat.get("test_probs", None)
        compute_flows = dat.get("compute_flows", False)
        load_files = False
    else:
        output = io.load_train_test_data(args.dir, test_dir, image_filter,
                                         args.mask_filter,
                                         args.look_one_level_down)
        images, labels, image_names, test_images, test_labels, image_names_test = output
        load_files = True

    # initialize model
    model = models.CellposeModel(device=device, pretrained_model=pretrained_model)

    # train segmentation model; the first element of the result is the saved path
    cpmodel_path = train.train_seg(
        model.net, images, labels, train_files=image_names,
        test_data=test_images, test_labels=test_labels,
        test_files=image_names_test, train_probs=train_probs,
        test_probs=test_probs, compute_flows=compute_flows,
        load_files=load_files, normalize=normalize,
        channel_axis=args.channel_axis,
        learning_rate=args.learning_rate, weight_decay=args.weight_decay,
        SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.train_batch_size,
        min_train_masks=args.min_train_masks,
        nimg_per_epoch=args.nimg_per_epoch,
        nimg_test_per_epoch=args.nimg_test_per_epoch,
        save_path=os.path.realpath(args.dir),
        save_every=args.save_every,
        save_each=args.save_each,
        model_name=args.model_name_out)[0]
    model.pretrained_model = cpmodel_path
    logger.info(">>>> model trained and saved to %s" % cpmodel_path)
    return model
def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, normalize):
    """Run segmentation on the requested images and save the results.

    Images come from ``args.dir`` (via ``io.get_image_files``) or, when no
    directory is given, from the single file ``args.image_path``.

    Args:
        args: parsed command-line namespace.
        logger: logger used for progress messages.
        imf: image filter forwarded to ``io.get_image_files``.
        device: device forwarded to ``models.CellposeModel``.
        pretrained_model: model to load.
        normalize: normalization setting forwarded to ``model.eval``.

    Returns:
        The ``models.CellposeModel`` used for evaluation.

    Raises:
        ValueError: if ``args.image_path`` does not exist.
        FileExistsError: if ``args.savedir`` is given but does not exist
            (exception type kept for backward compatibility).
    """
    # Mask files are only written when at least one of these outputs is requested.
    saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt

    tic = time.time()
    if len(args.dir) > 0:
        image_names = io.get_image_files(
            args.dir, args.mask_filter, imf=imf,
            look_one_level_down=args.look_one_level_down)
    elif os.path.exists(args.image_path):
        image_names = [args.image_path]
    else:
        raise ValueError(f"ERROR: no file found at {args.image_path}")
    nimg = len(image_names)

    if args.savedir and not os.path.exists(args.savedir):
        raise FileExistsError(f"--savedir {args.savedir} does not exist")

    logger.info(
        ">>>> running cellpose on %d images using all channels" % nimg)

    model = models.CellposeModel(device=device, pretrained_model=pretrained_model)

    tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)

    # Settings that are the same for every image are resolved once, up front.
    load_as_3d = args.do_3D or args.stitch_threshold > 0.
    channel_axis = args.channel_axis
    z_axis = args.z_axis
    if load_as_3d:
        # Axis defaults applied to images loaded as z-stacks.
        if channel_axis is None:
            channel_axis = 3
        if z_axis is None:
            z_axis = 0
    suffix = _mask_suffix(args) if saving_something else None

    for image_name in tqdm(image_names, file=tqdm_out):
        if load_as_3d:
            logger.info('loading image as 3D zstack')
            image = io.imread_3D(image_name)
        else:
            image = io.imread_2D(image_name)

        out = model.eval(
            image,
            diameter=args.diameter,
            do_3D=args.do_3D,
            augment=args.augment,
            flow_threshold=args.flow_threshold,
            cellprob_threshold=args.cellprob_threshold,
            stitch_threshold=args.stitch_threshold,
            min_size=args.min_size,
            batch_size=args.batch_size,
            bsize=args.bsize,
            resample=not args.no_resample,
            normalize=normalize,
            channel_axis=channel_axis,
            z_axis=z_axis,
            anisotropy=args.anisotropy,
            niter=args.niter,
            flow3D_smooth=args.flow3D_smooth)
        masks, flows = out[:2]

        if args.exclude_on_edges:
            masks = utils.remove_edge_masks(masks)

        if not args.no_npy:
            io.masks_flows_to_seg(image, masks, flows, image_name,
                                  imgs_restore=None,
                                  restore_type=None,
                                  ratio=1.)

        if saving_something:
            io.save_masks(image, masks, flows, image_name,
                          suffix=suffix, png=args.save_png,
                          tif=args.save_tif, save_flows=args.save_flows,
                          save_outlines=args.save_outlines,
                          dir_above=args.dir_above, savedir=args.savedir,
                          save_txt=args.save_txt, in_folders=args.in_folders,
                          save_mpl=args.save_mpl)

        if args.save_rois:
            io.save_rois(masks, image_name)

    logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))
    return model


def _mask_suffix(args):
    """Return the filename suffix to use for saved mask files.

    The default is ``"_cp_masks"``. ``args.output_name`` replaces it when:

    1. ``args.savedir`` is not set and ``args.output_name`` is non-empty; or
    2. ``args.savedir`` is set and differs from the input directory, in which
       case ``args.output_name`` is used as given (it may be an empty string).

    The input directory is ``args.dir`` or, when the input is a single
    ``args.image_path``, the directory containing that image.
    """
    suffix = "_cp_masks"
    if args.output_name is None:
        return suffix

    if args.savedir is None:
        # (1) Without a separate save directory the suffix must be non-empty.
        if len(args.output_name) > 0:
            suffix = args.output_name
        return suffix

    # `--dir` is empty for single-image input, so os.path.samefile cannot be
    # given it directly; compare against the image's own directory instead.
    if len(args.dir) > 0:
        input_dir = args.dir
    else:
        input_dir = os.path.dirname(os.path.abspath(args.image_path))

    # (2) Output goes to a different directory than the input.
    if not os.path.samefile(args.savedir, input_dir):
        suffix = args.output_name
    return suffix
# Call main() only when this file is executed as the main program,
# not when it is imported as a module.
if __name__ == "__main__":
    main()