Spaces:
Running
on
Zero
Running
on
Zero
| # stdlib imports | |
| import os, getpass, datetime, pathlib, tempfile, functools, glob | |
| # non-stdlib core dependencies | |
| import numpy as np | |
| import scipy | |
| import cellpose.io | |
| import cellpose.models | |
| import tifffile | |
| import imagecodecs | |
| # distributed dependencies | |
| import dask | |
| import distributed | |
| import dask_image.ndmeasure | |
| import yaml | |
| import zarr | |
| import dask_jobqueue | |
| ######################## File format functions ################################ | |
def numpy_array_to_zarr(write_path, array, chunks):
    """
    Write an in-memory numpy array to disk as a chunked Zarr array.

    Parameters
    ----------
    write_path : string
        Filepath where the Zarr array will be created
    array : numpy.ndarray
        The already loaded in-memory numpy array to store as zarr
    chunks : tuple, must be array.ndim length
        Chunk shape for the Zarr array

    Returns
    -------
    zarr.core.Array
        A read+write reference to the zarr array on disk
    """
    # create the empty chunked container first, then fill it in one shot
    target = zarr.open(
        write_path,
        mode='w',
        shape=array.shape,
        dtype=array.dtype,
        chunks=chunks,
    )
    target[...] = array
    return target
def wrap_folder_of_tiffs(
    filename_pattern,
    block_index_pattern=r'_(Z)(\d+)(Y)(\d+)(X)(\d+)',
):
    r"""
    Wrap a folder of tiff files with a zarr array without duplicating data.

    All tiff files must contain images with the same shape and data type,
    and file names must encode where each file lies in the block grid.

    Distributed computing requires parallel access to small regions of the
    image from different processes, which chunked formats like Zarr and N5
    provide. This function can accommodate a folder of tiff files, but it
    is not equivalent to reformatting your data as Zarr or N5: each tiff
    becomes one chunk, so if individual tiles are huge, distributed
    performance will be poor or the workflow will not run at all.

    It does not make sense to use this function with a single tiff file —
    that file would be the only chunk, so every worker would load the whole
    image to fetch its crop. Reformat single images with
    numpy_array_to_zarr instead. Single tiff files too large for system
    memory are not supported.

    Parameters
    ----------
    filename_pattern : string
        A glob pattern that will match all needed tif files
    block_index_pattern : regular expression string
        (default: r'_(Z)(\d+)(Y)(\d+)(X)(\d+)')
        How to parse tiff filenames to determine where each file lies in
        the overall block grid. The default assumes filenames like:
            {any_prefix}_Z000Y000X000{any_suffix}
            {any_prefix}_Z000Y000X001{any_suffix}
            ... and so on

    Returns
    -------
    zarr.core.Array
    """
    # reader for individual files: decode raw bytes with imagecodecs
    def _decode_tiff(path):
        with open(path, 'rb') as handle:
            return imagecodecs.tiff_decode(handle.read(), index=None)

    # build a zarr store over the tiled files and open it as a zarr array
    store = tifffile.imread(
        filename_pattern,
        aszarr=True,
        imread=_decode_tiff,
        pattern=block_index_pattern,
        axestiled={axis: axis for axis in range(3)},
    )
    return zarr.open(store=store)
| ######################## Cluster related functions ############################ | |
| #----------------------- config stuff ----------------------------------------# | |
# Name of the yaml file dumped into ~/.config/dask/ so that dask workers
# inherit the same configuration as the client process
DEFAULT_CONFIG_FILENAME = 'distributed_cellpose_dask_config.yaml'
| def _config_path(config_name): | |
| """Add config directory path to config filename""" | |
| return str(pathlib.Path.home()) + '/.config/dask/' + config_name | |
def _modify_dask_config(
    config,
    config_name=DEFAULT_CONFIG_FILENAME,
):
    """
    Apply `config` to the live dask configuration, then persist the full
    merged configuration as yaml under ~/.config/dask/ so that workers
    inherit the same options as the client process.
    """
    dask.config.set(config)
    config_path = _config_path(config_name)
    with open(config_path, 'w') as config_file:
        yaml.dump(dask.config.config, config_file, default_flow_style=False)
def _remove_config_file(
    config_name=DEFAULT_CONFIG_FILENAME,
):
    """Delete the persisted dask config file from disk, if it exists."""
    path = _config_path(config_name)
    if os.path.exists(path):
        os.remove(path)
| #----------------------- clusters --------------------------------------------# | |
class myLocalCluster(distributed.LocalCluster):
    """
    This is a thin wrapper extending dask.distributed.LocalCluster to set
    configs before the cluster or workers are initialized.

    For a list of full arguments (how to specify your worker resources) see:
    https://distributed.dask.org/en/latest/api.html#distributed.LocalCluster

    You need to know how many cpu cores and how much RAM your machine has.
    Most users will only need to specify:
        n_workers
        ncpus (number of physical cpu cores per worker)
        memory_limit (which is the limit per worker, should be a string like '16GB')
        threads_per_worker (for most workflows this should be 1)

    You can also modify any dask configuration option through the
    config argument.

    If your workstation has a GPU, one of the workers will have exclusive
    access to it by default. That worker will be much faster than the others.
    You may want to consider creating only one worker (which will have access
    to the GPU) and letting that worker process all blocks serially.
    """

    def __init__(
        self,
        ncpus,
        config=None,
        config_name=DEFAULT_CONFIG_FILENAME,
        persist_config=False,
        **kwargs,
    ):
        # BUGFIX: mutable default argument ({}) replaced with None sentinel
        config = {} if config is None else config
        # config: user options override defaults; persist to disk so that
        # workers inherit the same options
        self.config_name = config_name
        self.persist_config = persist_config
        scratch_dir = f"{os.getcwd()}/"
        scratch_dir += f".{getpass.getuser()}_distributed_cellpose/"
        config_defaults = {'temporary-directory': scratch_dir}
        config = {**config_defaults, **config}
        _modify_dask_config(config, config_name)
        # construct the underlying LocalCluster and a client attached to it
        if "host" not in kwargs: kwargs["host"] = ""
        super().__init__(**kwargs)
        self.client = distributed.Client(self)
        # set threading environment variables on all workers: low level
        # libraries do the threading, 2 threads per physical core
        # NOTE(review): 'NUM_MKL_THREADS' looks like a transposed duplicate
        # of 'MKL_NUM_THREADS'; harmless, kept as-is for parity
        environment_vars = {
            'MKL_NUM_THREADS': str(2*ncpus),
            'NUM_MKL_THREADS': str(2*ncpus),
            'OPENBLAS_NUM_THREADS': str(2*ncpus),
            'OPENMP_NUM_THREADS': str(2*ncpus),
            'OMP_NUM_THREADS': str(2*ncpus),
        }
        def set_environment_vars():
            for k, v in environment_vars.items():
                os.environ[k] = v
        self.client.run(set_environment_vars)
        print("Cluster dashboard link: ", self.dashboard_link)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # remove the persisted config unless the user asked to keep it
        if not self.persist_config:
            _remove_config_file(self.config_name)
        self.client.close()
        super().__exit__(exc_type, exc_value, traceback)
class janeliaLSFCluster(dask_jobqueue.LSFCluster):
    """
    This is a thin wrapper extending dask_jobqueue.LSFCluster,
    which in turn extends dask.distributed.SpecCluster. This wrapper
    sets configs before the cluster or workers are initialized. This is
    an adaptive cluster and will scale the number of workers, between user
    specified limits, based on the number of pending tasks. This wrapper
    also enforces conventions specific to the Janelia LSF cluster.

    For a full list of arguments see
    https://jobqueue.dask.org/en/latest/generated/dask_jobqueue.LSFCluster.html

    Most users will only need to specify:
        ncpus (the number of cpu cores per worker)
        min_workers
        max_workers
    """

    def __init__(
        self,
        ncpus,
        min_workers,
        max_workers,
        config=None,
        config_name=DEFAULT_CONFIG_FILENAME,
        persist_config=False,
        **kwargs
    ):
        # BUGFIX: mutable default argument ({}) replaced with None sentinel
        config = {} if config is None else config
        # store args in case needed later (e.g. change_worker_attributes);
        # 'self' is excluded to avoid a pointless circular reference
        self.locals_store = {k: v for k, v in locals().items() if k != 'self'}
        # config: user options override defaults; persist to disk so that
        # workers inherit the same options
        self.config_name = config_name
        self.persist_config = persist_config
        scratch_dir = f"/scratch/{getpass.getuser()}/"
        config_defaults = {
            'temporary-directory': scratch_dir,
            'distributed.comm.timeouts.connect': '180s',
            'distributed.comm.timeouts.tcp': '360s',
        }
        config = {**config_defaults, **config}
        _modify_dask_config(config, config_name)
        # threading is best in low level libraries
        job_script_prologue = [
            f"export MKL_NUM_THREADS={2*ncpus}",
            f"export NUM_MKL_THREADS={2*ncpus}",
            f"export OPENBLAS_NUM_THREADS={2*ncpus}",
            f"export OPENMP_NUM_THREADS={2*ncpus}",
            f"export OMP_NUM_THREADS={2*ncpus}",
        ]
        # set scratch and log directories
        if "local_directory" not in kwargs:
            kwargs["local_directory"] = scratch_dir
        if "log_directory" not in kwargs:
            log_dir = f"{os.getcwd()}/dask_worker_logs_{os.getpid()}/"
            pathlib.Path(log_dir).mkdir(parents=False, exist_ok=True)
            kwargs["log_directory"] = log_dir
        # graceful exit for lsf jobs (adds -d flag)
        class quietLSFJob(dask_jobqueue.lsf.LSFJob):
            cancel_command = "bkill -d"
        # construct; Janelia convention of 15GB RAM per cpu core
        super().__init__(
            ncpus=ncpus,
            processes=1,
            cores=1,
            memory=str(15*ncpus)+'GB',
            mem=int(15e9*ncpus),
            job_script_prologue=job_script_prologue,
            job_cls=quietLSFJob,
            **kwargs,
        )
        self.client = distributed.Client(self)
        print("Cluster dashboard link: ", self.dashboard_link)
        # set adaptive cluster bounds
        self.adapt_cluster(min_workers, max_workers)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # remove the persisted config unless the user asked to keep it
        if not self.persist_config:
            _remove_config_file(self.config_name)
        self.client.close()
        super().__exit__(exc_type, exc_value, traceback)

    def adapt_cluster(self, min_workers, max_workers):
        """Scale worker count between the given bounds based on load."""
        _ = self.adapt(
            minimum_jobs=min_workers,
            maximum_jobs=max_workers,
            interval='10s',
            wait_count=6,
        )

    def change_worker_attributes(
        self,
        min_workers,
        max_workers,
        **kwargs,
    ):
        """WARNING: this function is dangerous if you don't know what
        you're doing. Don't call this unless you know exactly what
        this does."""
        # scale to zero, edit the worker spec, then resume adaptive scaling
        self.scale(0)
        for k, v in kwargs.items():
            self.new_spec['options'][k] = v
        self.adapt_cluster(min_workers, max_workers)
| #----------------------- decorator -------------------------------------------# | |
def cluster(func):
    """
    This decorator ensures a function will run inside a cluster
    as a context manager. The decorated function, "func", must
    accept "cluster" and "cluster_kwargs" as parameters. If
    "cluster" is not None then the user has provided an existing
    cluster and we just run func. If "cluster" is None then
    "cluster_kwargs" are used to construct a new cluster, and
    the function is run inside that cluster context.
    """
    # BUGFIX: functools.wraps preserves func's name/docstring through the wrapper
    @functools.wraps(func)
    def create_or_pass_cluster(*args, **kwargs):
        # TODO: this only checks if args are explicitly present in function call
        # it does not check if they are set correctly in any way
        assert 'cluster' in kwargs or 'cluster_kwargs' in kwargs, \
            "Either cluster or cluster_kwargs must be defined"
        if 'cluster' not in kwargs:
            # LSF cluster when all LSF-specific kwargs are present,
            # otherwise a local (workstation) cluster
            given = kwargs['cluster_kwargs']
            cluster_constructor = myLocalCluster
            if all(k in given for k in ('ncpus', 'min_workers', 'max_workers')):
                cluster_constructor = janeliaLSFCluster
            with cluster_constructor(**given) as new_cluster:
                kwargs['cluster'] = new_cluster
                return func(*args, **kwargs)
        return func(*args, **kwargs)
    return create_or_pass_cluster
| ######################## the function to run on each block #################### | |
| #----------------------- The main function -----------------------------------# | |
def process_block(
    block_index,
    crop,
    input_zarr,
    model_kwargs,
    eval_kwargs,
    blocksize,
    overlap,
    output_zarr,
    preprocessing_steps=None,
    worker_logs_directory=None,
    test_mode=False,
):
    """
    Preprocess and segment one block, of many, with eventual merger
    of all blocks in mind. The block is processed as follows:
    (1) Read block from disk, preprocess, and segment.
    (2) Remove overlaps.
    (3) Get bounding boxes for every segment.
    (4) Remap segment IDs to globally unique values.
    (5) Write segments to disk.
    (6) Get segmented block faces.

    A user may want to test this function on one block before running
    the distributed function. When test_mode=True, steps (5) and (6)
    are omitted and replaced with:
    (5) return remapped segments as a numpy array, boxes, and box_ids

    Parameters
    ----------
    block_index : tuple
        The (i, j, k, ...) index of the block in the overall block grid
    crop : tuple of slice objects
        The bounding box of the data to read from the input_zarr array
    input_zarr : zarr.core.Array
        The image data we want to segment
    model_kwargs : dict
        Arguments passed to cellpose.models.CellposeModel
        This is how you select and parameterize a model.
    eval_kwargs : dict
        Arguments passed to the eval function of the Cellpose model
        This is how you parameterize model evaluation.
    blocksize : iterable (list, tuple, np.ndarray)
        The number of voxels (the shape) of blocks without overlaps
    overlap : int
        The number of voxels added to the blocksize to provide context
        at the edges
    output_zarr : zarr.core.Array
        A location where segments can be stored temporarily before
        merger is complete
    preprocessing_steps : list of tuples (default: None, meaning no steps)
        Optionally apply an arbitrary pipeline of preprocessing steps
        to the image block before running cellpose. Format:
            [(f, {'arg1':val1, ...}), ...]
        Each function must have signature f(image, ..., crop=None); the
        first argument is the image block, all other arguments come from
        the dictionary, and crop is supplied by this module.
    worker_logs_directory : string (default: None)
        A directory path where log files for each worker can be created
        The directory must exist
    test_mode : bool (default: False)
        When False (default), store the segments and return the objects
        needed for merging between blocks. When True, do not store the
        segments; return them to the caller as a numpy array along with
        boxes and box IDs. With test_mode=True you can supply dummy
        values for several inputs, e.g. block_index=(0, 0, 0) and
        output_zarr=None.

    Returns
    -------
    If test_mode == False (the default), three things are returned:
        faces : a list of numpy arrays - the faces of the block segments
        boxes : a list of crops (tuples of slices), bounding boxes of segments
        box_ids : 1D numpy array, parallel to boxes, the segment IDs of the
            boxes
    If test_mode == True, three things are returned:
        segments : np.ndarray containing the segments with globally unique IDs
        boxes : a list of crops (tuples of slices), bounding boxes of segments
        box_ids : 1D numpy array, parallel to boxes, the segment IDs of the
            boxes
    """
    # BUGFIX: mutable default argument ([]) replaced with None sentinel
    if preprocessing_steps is None:
        preprocessing_steps = []
    print('RUNNING BLOCK: ', block_index, '\tREGION: ', crop, flush=True)
    segmentation = read_preprocess_and_segment(
        input_zarr, crop, preprocessing_steps, model_kwargs, eval_kwargs,
        worker_logs_directory,
    )
    segmentation, crop = remove_overlaps(
        segmentation, crop, overlap, blocksize,
    )
    boxes = bounding_boxes_in_global_coordinates(segmentation, crop)
    nblocks = get_nblocks(input_zarr.shape, blocksize)
    segmentation, remap = global_segment_ids(segmentation, block_index, nblocks)
    # background (label 0) is not a segment
    if remap[0] == 0:
        remap = remap[1:]
    if test_mode:
        return segmentation, boxes, remap
    output_zarr[tuple(crop)] = segmentation
    faces = block_faces(segmentation)
    return faces, boxes, remap
| #----------------------- component functions ---------------------------------# | |
def read_preprocess_and_segment(
    input_zarr,
    crop,
    preprocessing_steps,
    model_kwargs,
    eval_kwargs,
    worker_logs_directory,
):
    """Read block from zarr array, run all preprocessing steps, run cellpose.

    Returns the cellpose label image for the block as uint32.
    """
    image = input_zarr[crop]
    for step_function, step_kwargs in preprocessing_steps:
        # BUGFIX: pass crop through a merged copy of the kwargs instead of
        # mutating the caller's dictionary in place
        image = step_function(image, **{**step_kwargs, 'crop': crop})
    log_file = None
    if worker_logs_directory is not None:
        # one log file per dask worker
        log_file = f'dask_worker_{distributed.get_worker().name}.log'
        log_file = pathlib.Path(worker_logs_directory).joinpath(log_file)
        cellpose.io.logger_setup(stdout_file_replacement=log_file)
    model = cellpose.models.CellposeModel(**model_kwargs)
    # eval returns (masks, flows, styles); keep only the masks
    return model.eval(image, **eval_kwargs)[0].astype(np.uint32)
def remove_overlaps(array, crop, overlap, blocksize):
    """Trim the overlap regions from a segmented block.

    Overlaps exist only to give context to boundary voxels during
    segmentation and can be removed once segmentation is complete.
    Returns the trimmed array and the matching trimmed crop (a list of
    slices in global coordinates).
    """
    crop_trimmed = list(crop)
    for axis, crop_slice in enumerate(crop):
        # a leading overlap is present whenever the block does not start at 0
        if crop_slice.start != 0:
            selector = [slice(None)] * array.ndim
            selector[axis] = slice(overlap, None)
            array = array[tuple(selector)]
            crop_trimmed[axis] = slice(crop_slice.start + overlap, crop_slice.stop)
        # a trailing overlap is present whenever the block exceeds blocksize
        if array.shape[axis] > blocksize[axis]:
            selector = [slice(None)] * array.ndim
            selector[axis] = slice(None, blocksize[axis])
            array = array[tuple(selector)]
            new_start = crop_trimmed[axis].start
            crop_trimmed[axis] = slice(new_start, new_start + blocksize[axis])
    return array, crop_trimmed
def bounding_boxes_in_global_coordinates(segmentation, crop):
    """Compute the bounding box of every segment, shifted from block-local
    to global coordinates; computed now while things are distributed."""
    def shift(outer, inner):
        # offset a local slice by the global start of the block
        return slice(outer.start + inner.start, outer.start + inner.stop)
    local_boxes = scipy.ndimage.find_objects(segmentation)
    return [
        tuple(shift(o, i) for o, i in zip(crop, box))
        for box in local_boxes if box is not None
    ]
def get_nblocks(shape, blocksize):
    """Return the number of blocks along each axis (ceiling division)."""
    shape = np.asarray(shape)
    blocksize = np.asarray(blocksize)
    return np.ceil(shape / blocksize).astype(int)
def global_segment_ids(segmentation, block_index, nblocks):
    """Make segment IDs globally unique by packing the linearized block
    index into the left digits of each ID; everything is remapped to
    [1..N] later. A uint32 is split into 5 decimal digits on the left
    (block) and 5 on the right (segment), which limits us to 42950
    blocks and 99999 segments per block. Background (0) stays 0."""
    unique, inverse = np.unique(segmentation, return_inverse=True)
    block_id = str(np.ravel_multi_index(block_index, nblocks))
    remap = [np.uint32(block_id + str(label).zfill(5)) for label in unique]
    if unique[0] == 0:
        remap[0] = np.uint32(0)  # 0 should just always be 0
    segmentation = np.array(remap)[inverse.reshape(segmentation.shape)]
    return segmentation, remap
def block_faces(segmentation):
    """Return the 2*ndim boundary faces of the block, ordered
    [axis0-low, axis0-high, axis1-low, axis1-high, ...]."""
    faces = []
    ndim = segmentation.ndim
    for axis in range(ndim):
        for edge in (slice(0, 1), slice(-1, None)):
            selector = [slice(None)] * ndim
            selector[axis] = edge
            faces.append(segmentation[tuple(selector)])
    return faces
| ######################## Distributed Cellpose ################################# | |
| #----------------------- The main function -----------------------------------# | |
def distributed_eval(
    input_zarr,
    blocksize,
    write_path,
    mask=None,
    preprocessing_steps=None,
    model_kwargs=None,
    eval_kwargs=None,
    cluster=None,
    cluster_kwargs=None,
    temporary_directory=None,
):
    """
    Evaluate a cellpose model on overlapping blocks of a big image.
    Distributed over workstation or cluster resources with Dask.
    Optionally run preprocessing steps on the blocks before running cellpose.
    Optionally use a mask to ignore background regions in image.

    Either cluster or cluster_kwargs must be set to a non-default value.
    If using cluster_kwargs, the workstation and Janelia LSF cluster cases
    are distinguished by the arguments present in the dictionary; see the
    myLocalCluster and janeliaLSFCluster docstrings in this module for what
    to put there. Running on a different institute cluster requires
    implementing your own dask cluster class (see dask_jobqueue).

    Parameters
    ----------
    input_zarr : zarr.core.Array
        The image data you want to segment.
    blocksize : iterable
        The size of blocks in voxels. E.g. [128, 256, 256]
    write_path : string
        Location of a zarr file on disk where results will be written
    mask : numpy.ndarray (default: None)
        A foreground mask for the image data; may be at a different
        resolution (e.g. lower) than the image data. If given, only blocks
        that contain foreground are processed. Assumes the mask covers the
        same physical domain as input_zarr, possibly on another voxel grid.
    preprocessing_steps : list of tuples (default: None, meaning no steps)
        Pipeline applied to each block before cellpose. Format:
            [(f, {'arg1':val1, ...}), ...]
        Each function must have signature f(image, ..., crop=None); the
        first argument is the image block, other arguments come from the
        dictionary, and crop is supplied by this module.
    model_kwargs : dict (default: None)
        Arguments passed to cellpose.models.CellposeModel
    eval_kwargs : dict (default: None)
        Arguments passed to the eval function of the Cellpose model;
        'diameter' defaults to 30 if not given
    cluster : A dask cluster object (default: None)
        Only set if you constructed your own static cluster; otherwise a
        cluster is built from cluster_kwargs for the duration of this call.
    cluster_kwargs : dict (default: None)
        Arguments used to parameterize the constructed cluster
        (consumed by the @cluster decorator, not read in this body).
    temporary_directory : string (default: None)
        Where the temporary working folder is created (default: cwd).
        Temporary files are removed if the function completes successfully.

    Returns
    -------
    Two values:
    (1) A reference to the zarr array on disk containing the stitched
        cellpose segments for the entire image
    (2) Bounding boxes for every segment, a list of tuples of slices:
        [(slice(z1, z2), slice(y1, y2), slice(x1, x2)), ...]
        sorted by segment ID (smallest first).
    """
    # BUGFIX: mutable default arguments replaced with None sentinels; in
    # particular eval_kwargs was previously modified in place, which mutated
    # both the shared default dict and the caller's dict
    preprocessing_steps = [] if preprocessing_steps is None else preprocessing_steps
    model_kwargs = {} if model_kwargs is None else model_kwargs
    eval_kwargs = {} if eval_kwargs is None else dict(eval_kwargs)
    timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
    worker_logs_dirname = f'dask_worker_logs_{timestamp}'
    worker_logs_dir = pathlib.Path().absolute().joinpath(worker_logs_dirname)
    worker_logs_dir.mkdir()
    if 'diameter' not in eval_kwargs.keys():
        eval_kwargs['diameter'] = 30
    # overlap of twice the expected cell diameter gives boundary context
    overlap = eval_kwargs['diameter'] * 2
    block_indices, block_crops = get_block_crops(
        input_zarr.shape, blocksize, overlap, mask,
    )
    # the context manager is the only way to really guarantee that the
    # tempdir gets cleaned up, even after unhandled exceptions
    with tempfile.TemporaryDirectory(
        prefix='.', suffix='_distributed_cellpose_tempdir',
        dir=temporary_directory or os.getcwd(),
    ) as temporary_directory:
        temp_zarr_path = temporary_directory + '/segmentation_unstitched.zarr'
        temp_zarr = zarr.open(
            temp_zarr_path, 'w',
            shape=input_zarr.shape,
            chunks=blocksize,
            dtype=np.uint32,
        )
        # segment all blocks in parallel
        futures = cluster.client.map(
            process_block,
            block_indices,
            block_crops,
            input_zarr=input_zarr,
            preprocessing_steps=preprocessing_steps,
            model_kwargs=model_kwargs,
            eval_kwargs=eval_kwargs,
            blocksize=blocksize,
            overlap=overlap,
            output_zarr=temp_zarr,
            worker_logs_directory=str(worker_logs_dir),
        )
        results = cluster.client.gather(futures)
        # segmentation done; release the big (gpu) workers
        if isinstance(cluster, dask_jobqueue.core.JobQueueCluster):
            cluster.scale(0)
        faces, boxes_, box_ids_ = list(zip(*results))
        boxes = [box for sublist in boxes_ for box in sublist]
        box_ids = np.concatenate(box_ids_).astype(int)  # unsure how but without cast these are float64
        new_labeling = determine_merge_relabeling(block_indices, faces, box_ids)
        new_labeling_path = temporary_directory + '/new_labeling.npy'
        np.save(new_labeling_path, new_labeling)
        # stitching step is cheap, we should release gpus and use small workers
        if isinstance(cluster, dask_jobqueue.core.JobQueueCluster):
            cluster.change_worker_attributes(
                min_workers=cluster.locals_store['min_workers'],
                max_workers=cluster.locals_store['max_workers'],
                ncpus=1,
                memory="15GB",
                mem=int(15e9),
                queue=None,
                job_extra_directives=[],
            )
        # apply the relabeling lazily, block by block, and write the result
        segmentation_da = dask.array.from_zarr(temp_zarr)
        relabeled = dask.array.map_blocks(
            lambda block: np.load(new_labeling_path)[block],
            segmentation_da,
            dtype=np.uint32,
            chunks=segmentation_da.chunks,
        )
        dask.array.to_zarr(relabeled, write_path, overwrite=True)
        merged_boxes = merge_all_boxes(boxes, new_labeling[box_ids])
        return zarr.open(write_path, mode='r'), merged_boxes
| #----------------------- component functions ---------------------------------# | |
def get_block_crops(shape, blocksize, overlap, mask):
    """Compute the index and slice tuple of every block on the grid.

    Each slice is extended by `overlap` on both sides, then clamped to
    the array bounds. If a (possibly different-resolution) foreground
    mask is given, blocks whose mask region is entirely background are
    skipped. Returns parallel lists: block indices and slice tuples."""
    blocksize = np.array(blocksize)
    if mask is not None:
        # mask may live on another voxel grid; scale blocksize to match
        ratio = np.array(mask.shape) / shape
        mask_blocksize = np.round(ratio * blocksize).astype(int)
    indices, crops = [], []
    # number of blocks per axis (ceiling division)
    nblocks = np.ceil(np.array(shape) / blocksize).astype(int)
    for index in np.ndindex(*nblocks):
        start = blocksize * index - overlap
        stop = np.minimum(shape, start + blocksize + 2 * overlap)
        start = np.maximum(0, start)
        crop = tuple(slice(a, b) for a, b in zip(start, stop))
        foreground = True
        if mask is not None:
            mask_start = mask_blocksize * index
            mask_stop = np.minimum(mask.shape, mask_start + mask_blocksize)
            mask_crop = tuple(slice(a, b) for a, b in zip(mask_start, mask_stop))
            foreground = bool(np.any(mask[mask_crop]))
        if foreground:
            indices.append(index)
            crops.append(crop)
    return indices, crops
def determine_merge_relabeling(block_indices, faces, used_labels):
    """Determine boundary segment mergers, remap all label IDs to merge
    and put all label IDs in range [1..N] for N global segments found"""
    faces = adjacent_faces(block_indices, faces)
    # guard: values are used as array indices below, so they must be ints
    used_labels = used_labels.astype(int)
    label_range = int(np.max(used_labels))
    label_groups = block_face_adjacency_graph(faces, label_range)
    # connected components over the adjacency graph merge touching segments
    new_labeling = scipy.sparse.csgraph.connected_components(
        label_groups, directed=False)[1]
    # XXX: new_labeling is returned as int32. Loses half range. Potentially a problem.
    # zero out entries for label values that were never actually used
    unused_labels = np.ones(label_range + 1, dtype=bool)
    unused_labels[used_labels] = 0
    new_labeling[unused_labels] = 0
    # compress the surviving labels into consecutive values [0..N]
    unique, unique_inverse = np.unique(new_labeling, return_inverse=True)
    new_labeling = np.arange(len(unique), dtype=np.uint32)[unique_inverse]
    return new_labeling
def adjacent_faces(block_indices, faces):
    """Pair up touching faces of neighboring blocks.

    For each block, look for the +1 neighbor along every axis; when it
    exists, concatenate this block's high face with the neighbor's low
    face along that axis. Returns the list of concatenated face pairs."""
    face_pairs = []
    lookup = dict(zip(block_indices, faces))
    for index in block_indices:
        for axis in range(len(index)):
            neighbor = list(index)
            neighbor[axis] += 1
            neighbor = tuple(neighbor)
            if neighbor not in lookup:
                continue
            # faces are ordered [ax0-low, ax0-high, ax1-low, ax1-high, ...]
            high_face = lookup[index][2 * axis + 1]
            low_face = lookup[neighbor][2 * axis]
            face_pairs.append(np.concatenate((high_face, low_face), axis=axis))
    return face_pairs
def block_face_adjacency_graph(faces, nlabels):
    """Shrink labels in face plane, then find which labels touch across the
    face boundary. Returns a sparse CSR adjacency matrix over label IDs."""
    # guard: used as a matrix dimension below, so must be an int
    nlabels = int(nlabels)
    all_mappings = []
    structure = scipy.ndimage.generate_binary_structure(3, 1)
    for face in faces:
        # each face pair was concatenated along its normal axis, so that
        # axis has length 2; slice out the two touching planes
        # NOTE(review): assumes no in-plane dimension equals 2 — confirm
        sl0 = tuple(slice(0, 1) if d==2 else slice(None) for d in face.shape)
        sl1 = tuple(slice(1, 2) if d==2 else slice(None) for d in face.shape)
        a = shrink_labels(face[sl0], 1.0)
        b = shrink_labels(face[sl1], 1.0)
        face = np.concatenate((a, b), axis=np.argmin(a.shape))
        # private dask_image helper pairs up labels touching across the seam
        mapped = dask_image.ndmeasure._utils._label._across_block_label_grouping(face, structure)
        all_mappings.append(mapped)
    i, j = np.concatenate(all_mappings, axis=1)
    v = np.ones_like(i)
    return scipy.sparse.coo_matrix((v, (i, j)), shape=(nlabels+1, nlabels+1)).tocsr()
def shrink_labels(plane, threshold):
    """Shrink labels in plane by `threshold` distance from their boundary."""
    squeezed = plane.squeeze()
    shrunk = np.copy(squeezed)
    # zero the boundary voxels themselves (nonzero gradient magnitude)
    boundary = np.linalg.norm(np.gradient(squeezed), axis=0) > 0
    shrunk[boundary] = 0
    # then zero everything within `threshold` of an already-zeroed voxel
    distances = scipy.ndimage.distance_transform_edt(shrunk)
    shrunk[distances <= threshold] = 0
    return shrunk.reshape(plane.shape)
def merge_all_boxes(boxes, box_ids):
    """Merge all boxes that map to the same box_ids"""
    merged_boxes = []
    boxes_array = np.array(boxes, dtype=object)
    # guard: np.unique/argwhere comparisons below need integer IDs
    box_ids = box_ids.astype(int)
    for segment_id in np.unique(box_ids):
        merge_indices = np.argwhere(box_ids == segment_id).squeeze()
        if merge_indices.shape:
            # several boxes share this ID; take their union
            merged_box = merge_boxes(boxes_array[merge_indices])
        else:
            # single box for this ID (0-d index after squeeze)
            merged_box = boxes_array[merge_indices]
        merged_boxes.append(merged_box)
    return merged_boxes
def merge_boxes(boxes):
    """Take union of two or more parallelpipeds (tuples of slices)."""
    def pairwise_union(box_a, box_b):
        # per-axis union: earliest start, latest stop
        return tuple(
            slice(min(sa.start, sb.start), max(sa.stop, sb.stop))
            for sa, sb in zip(box_a, box_b)
        )
    return functools.reduce(pairwise_union, boxes)