File size: 11,002 Bytes
aff3c6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""
import os, time
import numpy as np
from tqdm import tqdm
from cellpose import utils, models, io, train
from .version import version_str
from cellpose.cli import get_arg_parser

# Optional GUI support. The command-line interface must keep working when the GUI
# extras are not installed, so the import outcome is recorded in module flags:
#   * import succeeds                  -> GUI_ENABLED = True
#                                         (GUI_ERROR / GUI_IMPORT stay undefined)
#   * ImportError (missing deps)       -> GUI_ENABLED = False, GUI_IMPORT = True,
#                                         GUI_ERROR holds the exception
#   * any other exception during import -> flags recorded, then the error propagates
try:
    from cellpose.gui import gui3d, gui
except ImportError as import_exc:
    GUI_ERROR = import_exc
    GUI_IMPORT = True
    GUI_ENABLED = False
except Exception as other_exc:
    GUI_ERROR = other_exc
    GUI_IMPORT = False
    GUI_ENABLED = False
    raise
else:
    GUI_ENABLED = True

import logging


def _warn_deprecated_args(args, logger):
    """Warn once for each Cellpose-3 era CLI flag that is set but no longer used.

    Args:
        args: parsed command-line namespace.
        logger: logger that receives the warnings.
    """
    if args.chan or args.chan2:
        logger.warning('--chan and --chan2 are deprecated, all channels are used by default')
    # Each flag appears exactly once here, so each warning is emitted at most once.
    for flag in ("pretrained_model_ortho", "train_size", "all_channels", "restore_type",
                 "transformer", "invert", "chan2_restore", "diam_mean"):
        if getattr(args, flag):
            logger.warning("the '--%s' flag is deprecated in v4.0.1+ and no longer used", flag)


def main():
    """Run cellpose from the command line.

    If neither ``--dir`` nor ``--image_path`` is given, this either registers a
    model (``--add_model``) or launches the GUI (the 3D GUI with ``--Zstack``) and
    returns. Otherwise it evaluates a model on the selected images or, with
    ``--train``, trains one.

    Raises:
        ValueError: ``--save_each`` was given without ``--save_every``, or
            ``--train`` was combined with a single ``--image_path``.
    """
    args = get_arg_parser().parse_args()  # this has to be in a separate file for autodoc to work

    if args.version:
        print(version_str)
        return

    ######## if no image arguments are provided, run GUI or add model and exit ########
    if len(args.dir) == 0 and len(args.image_path) == 0:
        if args.add_model:
            io.add_model(args.add_model)
        elif not GUI_ENABLED:
            # GUI_ERROR / GUI_IMPORT are only defined when the GUI import failed,
            # which is exactly the case in this branch.
            print("GUI ERROR: %s" % GUI_ERROR)
            if GUI_IMPORT:
                print("GUI FAILED: GUI dependencies may not be installed, to install, run")
                print("     pip install 'cellpose[gui]'")
        elif args.Zstack:
            gui3d.run()
        else:
            gui.run()
        return

    ############################## run cellpose on images ##############################
    if args.verbose:
        from .io import logger_setup
        logger, _ = logger_setup()
    else:
        print(">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
        print("No --verbose => no progress or info printed")
        logger = logging.getLogger(__name__)

    # Reject invalid flag combinations before initializing a device or a model.
    if args.save_each and not args.save_every:
        raise ValueError("ERROR: --save_each requires --save_every")
    if len(args.image_path) > 0 and args.train:
        raise ValueError("ERROR: cannot train model with single image input")

    # find images
    image_filter = args.img_filter if len(args.img_filter) > 0 else None

    device, _ = models.assign_device(use_torch=True, gpu=args.use_gpu,
                                     device=args.gpu_device)

    # Any "empty" spelling of --pretrained_model falls back to the built-in model.
    if args.pretrained_model in (None, "None", "False", "0"):
        pretrained_model = "cpsam"
        logger.warning("training from scratch is disabled, using 'cpsam' model")
    else:
        pretrained_model = args.pretrained_model

    # Warn users about old arguments from CP3:
    _warn_deprecated_args(args, logger)

    if args.norm_percentile is not None:
        value1, value2 = args.norm_percentile
        normalize = {'percentile': (float(value1), float(value2))}
    else:
        normalize = not args.no_norm

    if args.train:
        _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
    else:
        _evaluate_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)


def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize):
    """Train a CellposeModel from command-line arguments and return it.

    Training data come from one of two sources:

    * ``--file_list``: an ``.npy`` file holding a pickled dict with a required
      ``"train_files"`` entry and optional ``"test_files"``, ``"train_probs"``,
      ``"test_probs"`` and ``"compute_flows"`` entries. Images/labels are passed
      to ``train.train_seg`` as None with ``load_files=False``.
    * otherwise ``--dir`` / ``--test_dir``, loaded via ``io.load_train_test_data``
      with ``image_filter``, ``--mask_filter`` and ``--look_one_level_down``
      (``load_files=True``).

    Args:
        args: parsed command-line namespace.
        logger: logger for progress and error messages.
        image_filter: image-name filter string, or None.
        device: device returned by ``models.assign_device``.
        pretrained_model: model to initialize training from.
        normalize: normalization setting forwarded to ``train.train_seg``.

    Returns:
        The ``models.CellposeModel`` whose ``pretrained_model`` attribute is set to
        the path returned by ``train.train_seg``.

    Raises:
        FileNotFoundError: ``--file_list`` was given but the file does not exist.
    """
    test_dir = None if len(args.test_dir) == 0 else args.test_dir
    images, labels, image_names, train_probs = None, None, None, None
    test_images, test_labels, image_names_test, test_probs = None, None, None, None
    compute_flows = False
    if len(args.file_list) > 0:
        if not os.path.exists(args.file_list):
            msg = f"ERROR: {args.file_list} does not exist"
            logger.critical(msg)
            # Stop here: continuing would reach train_seg with `load_files` unbound.
            raise FileNotFoundError(msg)
        # NOTE(security): allow_pickle=True can execute code embedded in the file;
        # only pass file lists from trusted sources.
        dat = np.load(args.file_list, allow_pickle=True).item()
        image_names = dat["train_files"]
        image_names_test = dat.get("test_files", None)
        train_probs = dat.get("train_probs", None)
        test_probs = dat.get("test_probs", None)
        compute_flows = dat.get("compute_flows", False)
        load_files = False
    else:
        output = io.load_train_test_data(args.dir, test_dir, image_filter,
                                         args.mask_filter,
                                         args.look_one_level_down)
        images, labels, image_names, test_images, test_labels, image_names_test = output
        load_files = True

    # initialize model
    model = models.CellposeModel(device=device, pretrained_model=pretrained_model)

    # train segmentation model; train_seg returns a sequence whose first item is the saved path
    cpmodel_path = train.train_seg(
            model.net, images, labels, train_files=image_names,
            test_data=test_images, test_labels=test_labels,
            test_files=image_names_test, train_probs=train_probs,
            test_probs=test_probs, compute_flows=compute_flows,
            load_files=load_files, normalize=normalize,
            channel_axis=args.channel_axis,
            learning_rate=args.learning_rate, weight_decay=args.weight_decay,
            SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.train_batch_size,
            min_train_masks=args.min_train_masks,
            nimg_per_epoch=args.nimg_per_epoch,
            nimg_test_per_epoch=args.nimg_test_per_epoch,
            save_path=os.path.realpath(args.dir),
            save_every=args.save_every,
            save_each=args.save_each,
            model_name=args.model_name_out)[0]
    model.pretrained_model = cpmodel_path
    logger.info(">>>> model trained and saved to %s" % cpmodel_path)
    return model


def _mask_suffix(args, image_name):
    """Return the filename suffix used when saving masks for ``image_name``.

    Defaults to ``"_cp_masks"``. ``--output_name`` overrides it when either
    (1) no ``--savedir`` is given and the name is non-empty, or
    (2) ``--savedir`` is a different directory from the input directory, in which
        case an empty name is also accepted.
    """
    default = "_cp_masks"
    if args.output_name is None:
        return default
    if args.savedir is None:
        return args.output_name if len(args.output_name) > 0 else default
    # With --image_path, args.dir is empty and os.path.samefile("") would raise;
    # compare against the image's own directory instead.
    if len(args.dir) > 0:
        input_dir = args.dir
    else:
        input_dir = os.path.dirname(os.path.abspath(image_name))
    if not os.path.samefile(args.savedir, input_dir):
        return args.output_name
    return default


def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, normalize):
    """Segment the images selected on the command line and save the requested outputs.

    Images come from ``--dir`` (listed by ``io.get_image_files`` with ``imf``,
    ``--mask_filter`` and ``--look_one_level_down``) or from the single file
    ``--image_path``. Each image is read with ``io.imread_3D`` when ``--do_3D`` is
    set or ``--stitch_threshold`` is positive (``channel_axis`` defaults to 3 and
    ``z_axis`` to 0 if not given), otherwise with ``io.imread_2D``. The image is then
    passed to ``model.eval``.

    Per image, results are written with ``io.masks_flows_to_seg`` (unless
    ``--no_npy``), with ``io.save_masks`` (if any of ``--save_png``, ``--save_tif``,
    ``--save_flows`` or ``--save_txt`` is set) and with ``io.save_rois`` (if
    ``--save_rois``).

    Returns:
        The ``models.CellposeModel`` used for evaluation.

    Raises:
        ValueError: ``--image_path`` does not point to an existing file.
        FileNotFoundError: ``--savedir`` does not exist.
    """
    saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt

    tic = time.time()
    if len(args.dir) > 0:
        image_names = io.get_image_files(
                args.dir, args.mask_filter, imf=imf,
                look_one_level_down=args.look_one_level_down)
    elif os.path.exists(args.image_path):
        image_names = [args.image_path]
    else:
        raise ValueError(f"ERROR: no file found at {args.image_path}")
    nimg = len(image_names)

    if args.savedir and not os.path.exists(args.savedir):
        # The directory is missing, so FileNotFoundError (not FileExistsError).
        raise FileNotFoundError(f"--savedir {args.savedir} does not exist")

    logger.info(
            ">>>> running cellpose on %d images using all channels" % nimg)

    model = models.CellposeModel(device=device, pretrained_model=pretrained_model)

    tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)

    channel_axis = args.channel_axis
    z_axis = args.z_axis
    load_as_3D = args.do_3D or args.stitch_threshold > 0.

    for image_name in tqdm(image_names, file=tqdm_out):
        if load_as_3D:
            logger.info('loading image as 3D zstack')
            image = io.imread_3D(image_name)
            if channel_axis is None:
                channel_axis = 3
            if z_axis is None:
                z_axis = 0
        else:
            image = io.imread_2D(image_name)

        out = model.eval(
                image,
                diameter=args.diameter,
                do_3D=args.do_3D,
                augment=args.augment,
                flow_threshold=args.flow_threshold,
                cellprob_threshold=args.cellprob_threshold,
                stitch_threshold=args.stitch_threshold,
                min_size=args.min_size,
                batch_size=args.batch_size,
                bsize=args.bsize,
                resample=not args.no_resample,
                normalize=normalize,
                channel_axis=channel_axis,
                z_axis=z_axis,
                anisotropy=args.anisotropy,
                niter=args.niter,
                flow3D_smooth=args.flow3D_smooth)
        masks, flows = out[:2]

        if args.exclude_on_edges:
            masks = utils.remove_edge_masks(masks)
        if not args.no_npy:
            io.masks_flows_to_seg(image, masks, flows, image_name,
                                  imgs_restore=None,
                                  restore_type=None,
                                  ratio=1.)
        if saving_something:
            io.save_masks(image, masks, flows, image_name,
                          suffix=_mask_suffix(args, image_name), png=args.save_png,
                          tif=args.save_tif, save_flows=args.save_flows,
                          save_outlines=args.save_outlines,
                          dir_above=args.dir_above, savedir=args.savedir,
                          save_txt=args.save_txt, in_folders=args.in_folders,
                          save_mpl=args.save_mpl)
        if args.save_rois:
            io.save_rois(masks, image_name)
    logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))

    return model


# Entry point when this module is executed directly (e.g. via `python -m`);
# importing the module does not run the CLI.
if __name__ == "__main__":
    main()