# Distributed Cellpose: utilities to segment very large images block-wise with dask.
# stdlib imports
import os, getpass, datetime, pathlib, tempfile, functools, glob
# non-stdlib core dependencies
import numpy as np
import scipy
import cellpose.io
import cellpose.models
import tifffile
import imagecodecs
# distributed dependencies
import dask
import distributed
import dask_image.ndmeasure
import yaml
import zarr
import dask_jobqueue
######################## File format functions ################################
def numpy_array_to_zarr(write_path, array, chunks):
    """
    Persist an in-memory numpy array to disk as a chunked Zarr array.

    Parameters
    ----------
    write_path : string
        Filepath where the Zarr array will be created
    array : numpy.ndarray
        The already loaded in-memory numpy array to store as zarr
    chunks : tuple, must be array.ndim length
        Chunk shape for the Zarr array

    Returns
    -------
    zarr.core.Array
        A read+write handle to the zarr array on disk
    """
    destination = zarr.open(
        write_path,
        mode='w',
        shape=array.shape,
        chunks=chunks,
        dtype=array.dtype,
    )
    # one bulk assignment writes every chunk to disk
    destination[...] = array
    return destination
def wrap_folder_of_tiffs(
    filename_pattern,
    block_index_pattern=r'_(Z)(\d+)(Y)(\d+)(X)(\d+)',
):
    r"""
    Wrap a folder of tiff files with a zarr array without duplicating data.
    Tiff files must all contain images with the same shape and data type.
    Tiff file names must contain a pattern indicating where individual files
    lie in the block grid.

    Distributed computing requires parallel access to small regions of your
    image from different processes. This is best accomplished with chunked
    file formats like Zarr and N5. This function can accommodate a folder of
    tiff files, but it is not equivalent to reformatting your data as Zarr or
    N5. If your individual tiff files/tiles are huge, distributed performance
    will be poor or not work at all.

    It does not make sense to use this function if you have only one tiff
    file. That tiff file will become the only chunk in the zarr array, which
    means all workers will have to load the entire image to fetch their crop
    of data anyway. If you have a single tiff image, you should just reformat
    it with the numpy_array_to_zarr function. Single tiff files too large to
    fit into system memory are not supported.

    Parameters
    ----------
    filename_pattern : string
        A glob pattern that will match all needed tif files
    block_index_pattern : regular expression string (default: r'_(Z)(\d+)(Y)(\d+)(X)(\d+)')
        A regular expression pattern that indicates how to parse tiff
        filenames to determine where each tiff file lies in the overall
        block grid. The default pattern assumes filenames like:
            {any_prefix}_Z000Y000X000{any_suffix}
            {any_prefix}_Z000Y000X001{any_suffix}
            ... and so on

    Returns
    -------
    zarr.core.Array
    """
    # reader for individual tiles: decode raw bytes directly, which avoids
    # tifffile constructing a heavier TiffFile object per tile
    def imread(fname):
        with open(fname, 'rb') as fh:
            return imagecodecs.tiff_decode(fh.read(), index=None)
    # build a virtual zarr store over all tiles (no data duplication);
    # axestiled maps the three file-grid axes onto the three image axes
    store = tifffile.imread(
        filename_pattern,
        aszarr=True,
        imread=imread,
        pattern=block_index_pattern,
        axestiled={x: x for x in range(3)},
    )
    return zarr.open(store=store)
######################## Cluster related functions ############################
#----------------------- config stuff ----------------------------------------#
DEFAULT_CONFIG_FILENAME = 'distributed_cellpose_dask_config.yaml'
def _config_path(config_name):
"""Add config directory path to config filename"""
return str(pathlib.Path.home()) + '/.config/dask/' + config_name
def _modify_dask_config(
    config,
    config_name=DEFAULT_CONFIG_FILENAME,
):
    """
    Apply the given options to the live dask config, then persist the full
    merged config as a yaml file in ~/.config/dask/ so that newly spawned
    workers inherit the same options.
    """
    dask.config.set(config)
    destination = _config_path(config_name)
    with open(destination, 'w') as handle:
        yaml.dump(dask.config.config, handle, default_flow_style=False)
def _remove_config_file(
    config_name=DEFAULT_CONFIG_FILENAME,
):
    """Delete a persisted dask config file from disk, if present"""
    target = _config_path(config_name)
    if os.path.exists(target):
        os.remove(target)
#----------------------- clusters --------------------------------------------#
class myLocalCluster(distributed.LocalCluster):
    """
    This is a thin wrapper extending dask.distributed.LocalCluster to set
    configs before the cluster or workers are initialized.
    For a list of full arguments (how to specify your worker resources) see:
    https://distributed.dask.org/en/latest/api.html#distributed.LocalCluster

    You need to know how many cpu cores and how much RAM your machine has.
    Most users will only need to specify:
        n_workers
        ncpus (number of physical cpu cores per worker)
        memory_limit (limit per worker, a string like '16GB')
        threads_per_worker (for most workflows this should be 1)
    You can also modify any dask configuration option through the
    config argument.

    If your workstation has a GPU, one of the workers will have exclusive
    access to it by default. That worker will be much faster than the others.
    You may want to consider creating only one worker (which will have access
    to the GPU) and letting that worker process all blocks serially.
    """
    def __init__(
        self,
        ncpus,
        config=None,
        config_name=DEFAULT_CONFIG_FILENAME,
        persist_config=False,
        **kwargs,
    ):
        # config: `None` (not `{}`) avoids the mutable-default-argument trap;
        # user options are layered over defaults and persisted for workers
        self.config_name = config_name
        self.persist_config = persist_config
        scratch_dir = f"{os.getcwd()}/"
        scratch_dir += f".{getpass.getuser()}_distributed_cellpose/"
        config_defaults = {'temporary-directory': scratch_dir}
        config = {**config_defaults, **(config or {})}
        _modify_dask_config(config, config_name)
        # construct the underlying LocalCluster and attach a client
        if "host" not in kwargs: kwargs["host"] = ""
        super().__init__(**kwargs)
        self.client = distributed.Client(self)
        # threading belongs in the low level math libraries: export the
        # thread-count variables on every worker process
        environment_vars = {
            'MKL_NUM_THREADS': str(2*ncpus),
            'NUM_MKL_THREADS': str(2*ncpus),
            'OPENBLAS_NUM_THREADS': str(2*ncpus),
            'OPENMP_NUM_THREADS': str(2*ncpus),
            'OMP_NUM_THREADS': str(2*ncpus),
        }
        def set_environment_vars():
            for k, v in environment_vars.items():
                os.environ[k] = v
        self.client.run(set_environment_vars)
        print("Cluster dashboard link: ", self.dashboard_link)
    def __enter__(self): return self
    def __exit__(self, exc_type, exc_value, traceback):
        # optionally remove the persisted config, then shut everything down
        if not self.persist_config:
            _remove_config_file(self.config_name)
        self.client.close()
        super().__exit__(exc_type, exc_value, traceback)
class janeliaLSFCluster(dask_jobqueue.LSFCluster):
    """
    This is a thin wrapper extending dask_jobqueue.LSFCluster,
    which in turn extends dask.distributed.SpecCluster. This wrapper
    sets configs before the cluster or workers are initialized. This is
    an adaptive cluster and will scale the number of workers, between user
    specified limits, based on the number of pending tasks. This wrapper
    also enforces conventions specific to the Janelia LSF cluster.
    For a full list of arguments see
    https://jobqueue.dask.org/en/latest/generated/dask_jobqueue.LSFCluster.html
    Most users will only need to specify:
        ncpus (the number of cpu cores per worker)
        min_workers
        max_workers
    """
    def __init__(
        self,
        ncpus,
        min_workers,
        max_workers,
        config={},
        config_name=DEFAULT_CONFIG_FILENAME,
        persist_config=False,
        **kwargs
    ):
        # store all args in case needed later (change_worker_attributes reads
        # min_workers/max_workers from here)
        # NOTE(review): locals() here also captures `self`, creating a
        # self-reference inside locals_store — presumably harmless; confirm
        self.locals_store = {**locals()}
        # config: merge user options over Janelia-appropriate defaults and
        # persist to ~/.config/dask so worker jobs inherit them
        self.config_name = config_name
        self.persist_config = persist_config
        scratch_dir = f"/scratch/{getpass.getuser()}/"
        config_defaults = {
            'temporary-directory':scratch_dir,
            'distributed.comm.timeouts.connect':'180s',
            'distributed.comm.timeouts.tcp':'360s',
        }
        config = {**config_defaults, **config}
        _modify_dask_config(config, config_name)
        # threading is best in low level libraries: export thread counts
        # in each worker job's environment
        job_script_prologue = [
            f"export MKL_NUM_THREADS={2*ncpus}",
            f"export NUM_MKL_THREADS={2*ncpus}",
            f"export OPENBLAS_NUM_THREADS={2*ncpus}",
            f"export OPENMP_NUM_THREADS={2*ncpus}",
            f"export OMP_NUM_THREADS={2*ncpus}",
        ]
        # set scratch and log directories unless the caller chose their own
        if "local_directory" not in kwargs:
            kwargs["local_directory"] = scratch_dir
        if "log_directory" not in kwargs:
            log_dir = f"{os.getcwd()}/dask_worker_logs_{os.getpid()}/"
            pathlib.Path(log_dir).mkdir(parents=False, exist_ok=True)
            kwargs["log_directory"] = log_dir
        # graceful exit for lsf jobs (adds -d flag)
        class quietLSFJob(dask_jobqueue.lsf.LSFJob):
            cancel_command = "bkill -d"
        # construct: one process per job, 15GB of RAM per cpu core
        super().__init__(
            ncpus=ncpus,
            processes=1,
            cores=1,
            memory=str(15*ncpus)+'GB',
            mem=int(15e9*ncpus),
            job_script_prologue=job_script_prologue,
            job_cls=quietLSFJob,
            **kwargs,
        )
        self.client = distributed.Client(self)
        print("Cluster dashboard link: ", self.dashboard_link)
        # set adaptive cluster bounds
        self.adapt_cluster(min_workers, max_workers)
    def __enter__(self): return self
    def __exit__(self, exc_type, exc_value, traceback):
        # optionally remove the persisted config, then shut everything down
        if not self.persist_config:
            _remove_config_file(self.config_name)
        self.client.close()
        super().__exit__(exc_type, exc_value, traceback)
    def adapt_cluster(self, min_workers, max_workers):
        # scale the number of worker jobs with the task backlog, checking
        # every 10s and waiting 6 intervals before scaling down
        _ = self.adapt(
            minimum_jobs=min_workers,
            maximum_jobs=max_workers,
            interval='10s',
            wait_count=6,
        )
    def change_worker_attributes(
        self,
        min_workers,
        max_workers,
        **kwargs,
    ):
        """WARNING: this function is dangerous if you don't know what
        you're doing. Don't call this unless you know exactly what
        this does."""
        # drop all current workers, rewrite the worker spec in place, then
        # let the adaptive scaler relaunch workers with the new attributes
        self.scale(0)
        for k, v in kwargs.items():
            self.new_spec['options'][k] = v
        self.adapt_cluster(min_workers, max_workers)
#----------------------- decorator -------------------------------------------#
def cluster(func):
    """
    Decorator guaranteeing that ``func`` runs with a dask cluster available.
    The decorated function must accept "cluster" and "cluster_kwargs"
    parameters. If "cluster" is provided by the caller, the function runs
    as-is. Otherwise a new cluster is constructed from "cluster_kwargs"
    (an LSF cluster when ncpus/min_workers/max_workers are all present,
    a local cluster otherwise) and the function runs inside that cluster's
    context, which is torn down afterward.
    """
    @functools.wraps(func)
    def create_or_pass_cluster(*args, **kwargs):
        # TODO: this only checks if args are explicitly present in function call
        # it does not check if they are set correctly in any way
        assert 'cluster' in kwargs or 'cluster_kwargs' in kwargs, \
            "Either cluster or cluster_kwargs must be defined"
        # caller supplied a live cluster: just run the function
        if 'cluster' in kwargs:
            return func(*args, **kwargs)
        # otherwise choose a constructor from the provided kwargs
        cluster_args = kwargs['cluster_kwargs']
        lsf_markers = {'ncpus', 'min_workers', 'max_workers'}
        if lsf_markers.issubset(cluster_args):
            constructor = janeliaLSFCluster
        else:
            constructor = myLocalCluster
        with constructor(**cluster_args) as new_cluster:
            kwargs['cluster'] = new_cluster
            return func(*args, **kwargs)
    return create_or_pass_cluster
######################## the function to run on each block ####################
#----------------------- The main function -----------------------------------#
def process_block(
    block_index,
    crop,
    input_zarr,
    model_kwargs,
    eval_kwargs,
    blocksize,
    overlap,
    output_zarr,
    preprocessing_steps=[],
    worker_logs_directory=None,
    test_mode=False,
):
    """
    Preprocess and segment one block, of many, with eventual merger
    of all blocks in mind. The block is processed as follows:
    (1) Read block from disk, preprocess, and segment.
    (2) Remove overlaps.
    (3) Get bounding boxes for every segment.
    (4) Remap segment IDs to globally unique values.
    (5) Write segments to disk.
    (6) Get segmented block faces.

    A user may want to test this function on one block before running
    the distributed function. When test_mode=True, steps (5) and (6)
    are omitted and replaced with:
    (5) return remapped segments as a numpy array, boxes, and box_ids

    Parameters
    ----------
    block_index : tuple
        The (i, j, k, ...) index of the block in the overall block grid
    crop : tuple of slice objects
        The bounding box of the data to read from the input_zarr array
    input_zarr : zarr.core.Array
        The image data we want to segment
    preprocessing_steps : list of tuples (default: the empty list)
        Optionally apply an arbitrary pipeline of preprocessing steps
        to the image block before running cellpose.
        Must be in the following format:
            [(f, {'arg1':val1, ...}), ...]
        That is, each tuple must contain only two elements, a function
        and a dictionary. The function must have the following signature:
            def F(image, ..., crop=None)
        That is, the first argument must be a numpy array, which will later
        be populated by the image data. The function must also take a keyword
        argument called crop, even if it is not used in the function itself.
        All other arguments to the function are passed using the dictionary.
        Here is an example:
            def F(image, sigma, crop=None):
                return gaussian_filter(image, sigma)
            def G(image, radius, crop=None):
                return median_filter(image, radius)
            preprocessing_steps = [(F, {'sigma':2.0}), (G, {'radius':4})]
    model_kwargs : dict
        Arguments passed to cellpose.models.Cellpose
        This is how you select and parameterize a model.
    eval_kwargs : dict
        Arguments passed to the eval function of the Cellpose model
        This is how you parameterize model evaluation.
    blocksize : iterable (list, tuple, np.ndarray)
        The number of voxels (the shape) of blocks without overlaps
    overlap : int
        The number of voxels added to the blocksize to provide context
        at the edges
    output_zarr : zarr.core.Array
        A location where segments can be stored temporarily before
        merger is complete
    worker_logs_directory : string (default: None)
        A directory path where log files for each worker can be created
        The directory must exist
    test_mode : bool (default: False)
        The primary use case of this function is to be called by
        distributed_eval (defined later in this same module). However
        you may want to call this function manually to test what
        happens to an individual block; this is a good idea before
        ramping up to process big data and also useful for debugging.
        When test_mode is False (default) this function stores
        the segments and returns objects needed for merging between
        blocks.
        When test_mode is True this function does not store the
        segments, and instead returns them to the caller as a numpy
        array. The boxes and box IDs are also returned. When test_mode
        is True, you can supply dummy values for many of the inputs,
        such as:
            block_index = (0, 0, 0)
            output_zarr=None

    Returns
    -------
    If test_mode == False (the default), three things are returned:
        faces : a list of numpy arrays - the faces of the block segments
        boxes : a list of crops (tuples of slices), bounding boxes of segments
        box_ids : 1D numpy array, parallel to boxes, the segment IDs of the
            boxes
    If test_mode == True, three things are returned:
        segments : np.ndarray containing the segments with globally unique IDs
        boxes : a list of crops (tuples of slices), bounding boxes of segments
        box_ids : 1D numpy array, parallel to boxes, the segment IDs of the
            boxes
    """
    # progress statement shows up in worker stdout; flush so it is not lost
    print('RUNNING BLOCK: ', block_index, '\tREGION: ', crop, flush=True)
    # (1) read, preprocess, segment
    segmentation = read_preprocess_and_segment(
        input_zarr, crop, preprocessing_steps, model_kwargs, eval_kwargs,
        worker_logs_directory,
    )
    # (2) trim overlap context; crop shrinks to the non-overlap region
    segmentation, crop = remove_overlaps(
        segmentation, crop, overlap, blocksize,
    )
    # (3) bounding boxes, already shifted into global image coordinates
    boxes = bounding_boxes_in_global_coordinates(segmentation, crop)
    # (4) pack the block index into segment IDs for global uniqueness
    nblocks = get_nblocks(input_zarr.shape, blocksize)
    segmentation, remap = global_segment_ids(segmentation, block_index, nblocks)
    # background (0) is not a segment, so drop it from the ID list
    if remap[0] == 0: remap = remap[1:]
    if test_mode: return segmentation, boxes, remap
    # (5) store the block, (6) return its faces for cross-block merging
    output_zarr[tuple(crop)] = segmentation
    faces = block_faces(segmentation)
    return faces, boxes, remap
#----------------------- component functions ---------------------------------#
def read_preprocess_and_segment(
    input_zarr,
    crop,
    preprocessing_steps,
    model_kwargs,
    eval_kwargs,
    worker_logs_directory,
):
    """Read block from zarr array, run all preprocessing steps, run cellpose

    NOTE(review): each preprocessing step's kwargs dict has a 'crop' key
    added in place, so caller-supplied dicts are mutated — confirm intended.
    """
    image = input_zarr[crop]
    for pp_step in preprocessing_steps:
        # every preprocessing function must accept a `crop` keyword
        pp_step[1]['crop'] = crop
        image = pp_step[0](image, **pp_step[1])
    log_file = None
    if worker_logs_directory is not None:
        # one log file per dask worker, named after the worker
        log_file = f'dask_worker_{distributed.get_worker().name}.log'
        log_file = pathlib.Path(worker_logs_directory).joinpath(log_file)
        cellpose.io.logger_setup(stdout_file_replacement=log_file)
    model = cellpose.models.CellposeModel(**model_kwargs)
    # eval returns a tuple; element [0] is the label image
    return model.eval(image, **eval_kwargs)[0].astype(np.uint32)
def remove_overlaps(array, crop, overlap, blocksize):
    """
    Trim overlap regions from a segmented block and adjust its crop.

    The overlaps exist only to give context for boundary voxels during
    segmentation; afterward they are discarded. Returns the trimmed array
    and the matching trimmed crop (a list of slices).
    """
    adjusted = list(crop)
    ndim = array.ndim
    for axis in range(ndim):
        # blocks at the low edge of the grid carry no leading overlap
        if crop[axis].start != 0:
            selector = [slice(None)] * ndim
            selector[axis] = slice(overlap, None)
            array = array[tuple(selector)]
            lo, hi = crop[axis].start, crop[axis].stop
            adjusted[axis] = slice(lo + overlap, hi)
        # clip any trailing overlap back to the nominal block size
        if array.shape[axis] > blocksize[axis]:
            selector = [slice(None)] * ndim
            selector[axis] = slice(None, blocksize[axis])
            array = array[tuple(selector)]
            begin = adjusted[axis].start
            adjusted[axis] = slice(begin, begin + blocksize[axis])
    return array, adjusted
def bounding_boxes_in_global_coordinates(segmentation, crop):
    """Per-segment bounding boxes (tuples of slices), shifted from block-local
    to global image coordinates; cheapest to compute while distributed"""
    local_boxes = scipy.ndimage.find_objects(segmentation)
    shift = lambda outer, inner: slice(outer.start + inner.start,
                                       outer.start + inner.stop)
    # find_objects emits None for absent labels; drop those entries
    return [
        tuple(shift(o, i) for o, i in zip(crop, box))
        for box in local_boxes if box is not None
    ]
def get_nblocks(shape, blocksize):
    """Number of blocks along each axis for a grid of the given shape
    partitioned into blocks of the given size (partial blocks count)"""
    blocks_per_axis = np.ceil(np.divide(shape, blocksize))
    return blocks_per_axis.astype(int)
def global_segment_ids(segmentation, block_index, nblocks):
    """
    Pack the linear block index into the segment IDs so they are globally
    unique across blocks; everything gets remapped to [1..N] later.
    A uint32 is split into (up to) 5 leading digits for the block index and
    5 trailing digits for the within-block segment, which limits the grid
    to 42950 blocks and each block to 99999 segments.
    Returns the remapped segmentation and the list of new IDs.
    """
    unique, inverse = np.unique(segmentation, return_inverse=True)
    prefix = str(np.ravel_multi_index(block_index, nblocks))
    remap = [np.uint32(prefix + str(label).zfill(5)) for label in unique]
    # background must remain 0 no matter which block it came from
    if unique[0] == 0:
        remap[0] = np.uint32(0)
    segmentation = np.array(remap)[inverse.reshape(segmentation.shape)]
    return segmentation, remap
def block_faces(segmentation):
    """Return the first and last single-voxel-thick plane along every axis,
    ordered (axis0 low, axis0 high, axis1 low, axis1 high, ...)"""
    ndim = segmentation.ndim
    faces = []
    for axis in range(ndim):
        for edge in (slice(0, 1), slice(-1, None)):
            selector = [slice(None)] * ndim
            selector[axis] = edge
            faces.append(segmentation[tuple(selector)])
    return faces
######################## Distributed Cellpose #################################
#----------------------- The main function -----------------------------------#
@cluster
def distributed_eval(
    input_zarr,
    blocksize,
    write_path,
    mask=None,
    preprocessing_steps=None,
    model_kwargs=None,
    eval_kwargs=None,
    cluster=None,
    cluster_kwargs=None,
    temporary_directory=None,
):
    """
    Evaluate a cellpose model on overlapping blocks of a big image.
    Distributed over workstation or cluster resources with Dask.
    Optionally run preprocessing steps on the blocks before running cellpose.
    Optionally use a mask to ignore background regions in image.

    Either cluster or cluster_kwargs parameter must be set to a
    non-default value; please read these parameter descriptions below.
    If using cluster_kwargs, the workstation and Janelia LSF cluster cases
    are distinguished by the arguments present in the dictionary.
    PC/Mac/Linux workstations and the Janelia LSF cluster are supported;
    running on a different institute cluster will require implementing your
    own dask cluster class. Look at the janeliaLSFCluster class in this
    module as an example, also look at the dask_jobqueue library.

    If running on a workstation, please read the docstring for the
    myLocalCluster class defined in this module. That will tell you what to
    put in the cluster_kwargs dictionary. If using the Janelia cluster,
    please read the docstring for the janeliaLSFCluster class.

    Parameters
    ----------
    input_zarr : zarr.core.Array
        A zarr.core.Array instance containing the image data you want to
        segment.
    blocksize : iterable
        The size of blocks in voxels. E.g. [128, 256, 256]
    write_path : string
        The location of a zarr file on disk where you'd like to write your results
    mask : numpy.ndarray (default: None)
        A foreground mask for the image data; may be at a different resolution
        (e.g. lower) than the image data. If given, only blocks that contain
        foreground will be processed. This can save considerable time and
        expense. It is assumed that the domain of the input_zarr image data
        and the mask is the same in physical units, but they may be on
        different sampling/voxel grids.
    preprocessing_steps : list of tuples (default: None, treated as empty)
        Optionally apply an arbitrary pipeline of preprocessing steps
        to the image blocks before running cellpose.
        Must be in the following format:
            [(f, {'arg1':val1, ...}), ...]
        That is, each tuple must contain only two elements, a function
        and a dictionary. The function must have the following signature:
            def F(image, ..., crop=None)
        That is, the first argument must be a numpy array, which will later
        be populated by the image data. The function must also take a keyword
        argument called crop, even if it is not used in the function itself.
        All other arguments to the function are passed using the dictionary.
        Here is an example:
            def F(image, sigma, crop=None):
                return gaussian_filter(image, sigma)
            def G(image, radius, crop=None):
                return median_filter(image, radius)
            preprocessing_steps = [(F, {'sigma':2.0}), (G, {'radius':4})]
    model_kwargs : dict (default: None, treated as empty)
        Arguments passed to cellpose.models.Cellpose
    eval_kwargs : dict (default: None, treated as empty)
        Arguments passed to cellpose.models.Cellpose.eval
    cluster : A dask cluster object (default: None)
        Only set if you have constructed your own static cluster. The default
        behavior is to construct a dask cluster for the duration of this function,
        then close it when the function is finished.
    cluster_kwargs : dict (default: None)
        Arguments used to parameterize your cluster.
        If you are running locally, see the docstring for the myLocalCluster
        class in this module. If you are running on the Janelia LSF cluster, see
        the docstring for the janeliaLSFCluster class in this module.
    temporary_directory : string (default: None)
        Temporary files are created during segmentation. The temporary files
        will be in their own folder within the temporary_directory. The default
        is the current directory. Temporary files are removed if the function
        completes successfully.

    Returns
    -------
    Two values are returned:
    (1) A reference to the zarr array on disk containing the stitched cellpose
        segments for your entire image
    (2) Bounding boxes for every segment. This is a list of tuples of slices:
        [(slice(z1, z2), slice(y1, y2), slice(x1, x2)), ...]
        The list is sorted according to segment ID. That is the smallest segment
        ID is the first tuple in the list, the largest segment ID is the last
        tuple in the list.
    """
    # normalize mutable defaults and copy so caller-owned dicts are never
    # mutated (the diameter default below is written into a copy)
    preprocessing_steps = list(preprocessing_steps) if preprocessing_steps else []
    model_kwargs = dict(model_kwargs) if model_kwargs else {}
    eval_kwargs = dict(eval_kwargs) if eval_kwargs else {}
    # a directory for worker log files, stamped to avoid collisions
    timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
    worker_logs_dirname = f'dask_worker_logs_{timestamp}'
    worker_logs_dir = pathlib.Path().absolute().joinpath(worker_logs_dirname)
    worker_logs_dir.mkdir()
    # cellpose default diameter; overlaps provide two cell diameters of context
    if 'diameter' not in eval_kwargs:
        eval_kwargs['diameter'] = 30
    overlap = eval_kwargs['diameter'] * 2
    block_indices, block_crops = get_block_crops(
        input_zarr.shape, blocksize, overlap, mask,
    )
    # the context manager guarantees the tempdir is cleaned up
    # even after unhandled exceptions
    with tempfile.TemporaryDirectory(
        prefix='.', suffix='_distributed_cellpose_tempdir',
        dir=temporary_directory or os.getcwd(),
    ) as temporary_directory:
        # unstitched per-block segments are staged in a temporary zarr
        temp_zarr_path = temporary_directory + '/segmentation_unstitched.zarr'
        temp_zarr = zarr.open(
            temp_zarr_path, 'w',
            shape=input_zarr.shape,
            chunks=blocksize,
            dtype=np.uint32,
        )
        # segment all blocks in parallel
        futures = cluster.client.map(
            process_block,
            block_indices,
            block_crops,
            input_zarr=input_zarr,
            preprocessing_steps=preprocessing_steps,
            model_kwargs=model_kwargs,
            eval_kwargs=eval_kwargs,
            blocksize=blocksize,
            overlap=overlap,
            output_zarr=temp_zarr,
            worker_logs_directory=str(worker_logs_dir),
        )
        results = cluster.client.gather(futures)
        # segmentation done: release the (possibly GPU) workers
        if isinstance(cluster, dask_jobqueue.core.JobQueueCluster):
            cluster.scale(0)
        faces, boxes_, box_ids_ = list(zip(*results))
        boxes = [box for sublist in boxes_ for box in sublist]
        box_ids = np.concatenate(box_ids_).astype(int)  # unsure how but without cast these are float64
        new_labeling = determine_merge_relabeling(block_indices, faces, box_ids)
        new_labeling_path = temporary_directory + '/new_labeling.npy'
        np.save(new_labeling_path, new_labeling)
        # stitching step is cheap, we should release gpus and use small workers
        if isinstance(cluster, dask_jobqueue.core.JobQueueCluster):
            cluster.change_worker_attributes(
                min_workers=cluster.locals_store['min_workers'],
                max_workers=cluster.locals_store['max_workers'],
                ncpus=1,
                memory="15GB",
                mem=int(15e9),
                queue=None,
                job_extra_directives=[],
            )
        # push every block through the merge lookup table and write the result
        segmentation_da = dask.array.from_zarr(temp_zarr)
        relabeled = dask.array.map_blocks(
            lambda block: np.load(new_labeling_path)[block],
            segmentation_da,
            dtype=np.uint32,
            chunks=segmentation_da.chunks,
        )
        dask.array.to_zarr(relabeled, write_path, overwrite=True)
        merged_boxes = merge_all_boxes(boxes, new_labeling[box_ids])
        return zarr.open(write_path, mode='r'), merged_boxes
#----------------------- component functions ---------------------------------#
def get_block_crops(shape, blocksize, overlap, mask):
    """
    Build slice tuples for every block on the voxel grid, including overlap
    context, optionally keeping only blocks whose mask region contains
    foreground. Returns two parallel lists: block grid indices and crops.
    """
    blocksize = np.array(blocksize)
    if mask is not None:
        # the mask may live on a coarser voxel grid than the image
        ratio = np.array(mask.shape) / shape
        mask_blocksize = np.round(ratio * blocksize).astype(int)
    indices, crops = [], []
    nblocks = np.ceil(np.array(shape) / blocksize).astype(int)
    for index in np.ndindex(*nblocks):
        # pad with overlap on both sides, then clamp to the image bounds
        begin = blocksize * index - overlap
        end = np.minimum(shape, begin + blocksize + 2 * overlap)
        begin = np.maximum(0, begin)
        crop = tuple(slice(lo, hi) for lo, hi in zip(begin, end))
        keep = True
        if mask is not None:
            m_lo = mask_blocksize * index
            m_hi = np.minimum(mask.shape, m_lo + mask_blocksize)
            mask_crop = tuple(slice(lo, hi) for lo, hi in zip(m_lo, m_hi))
            keep = bool(np.any(mask[mask_crop]))
        if keep:
            indices.append(index)
            crops.append(crop)
    return indices, crops
def determine_merge_relabeling(block_indices, faces, used_labels):
    """
    Determine boundary segment mergers and build a relabeling lookup table:
    segments touching across block faces collapse to one ID, and all
    surviving IDs are compressed onto the contiguous range [1..N].
    Index the returned array with an old label to get its new label.
    """
    faces = adjacent_faces(block_indices, faces)
    # labels can arrive as float64; indexing below requires integers
    used_labels = used_labels.astype(int)
    label_range = int(np.max(used_labels))
    label_groups = block_face_adjacency_graph(faces, label_range)
    new_labeling = scipy.sparse.csgraph.connected_components(
        label_groups, directed=False)[1]
    # XXX: new_labeling is returned as int32. Loses half range. Potentially a problem.
    # labels never used by any block map to background
    unused_labels = np.ones(label_range + 1, dtype=bool)
    unused_labels[used_labels] = 0
    new_labeling[unused_labels] = 0
    # compress surviving component IDs onto [0..N]
    unique, unique_inverse = np.unique(new_labeling, return_inverse=True)
    new_labeling = np.arange(len(unique), dtype=np.uint32)[unique_inverse]
    return new_labeling
def adjacent_faces(block_indices, faces):
    """Pair up block faces that touch: for each block, concatenate its high
    face along every axis with the neighboring block's low face"""
    lookup = dict(zip(block_indices, faces))
    face_pairs = []
    for index in block_indices:
        for axis in range(len(index)):
            neighbor = list(index)
            neighbor[axis] += 1
            neighbor = tuple(neighbor)
            # blocks on the high edge of the grid have no neighbor
            if neighbor not in lookup:
                continue
            near = lookup[index][2 * axis + 1]
            far = lookup[neighbor][2 * axis]
            face_pairs.append(np.concatenate((near, far), axis=axis))
    return face_pairs
def block_face_adjacency_graph(faces, nlabels):
    """
    Shrink labels within each face plane, then build a sparse CSR matrix
    recording which labels touch across the face boundary.
    """
    # labels can arrive as floats; the sparse matrix shape needs an int
    nlabels = int(nlabels)
    all_mappings = []
    structure = scipy.ndimage.generate_binary_structure(3, 1)
    for face in faces:
        # each face is two single-voxel-thick planes stacked on one axis;
        # that axis is the one with extent 2
        sl0 = tuple(slice(0, 1) if d == 2 else slice(None) for d in face.shape)
        sl1 = tuple(slice(1, 2) if d == 2 else slice(None) for d in face.shape)
        a = shrink_labels(face[sl0], 1.0)
        b = shrink_labels(face[sl1], 1.0)
        face = np.concatenate((a, b), axis=np.argmin(a.shape))
        # NOTE(review): relies on a private dask_image helper — re-verify on
        # dask_image upgrades
        mapped = dask_image.ndmeasure._utils._label._across_block_label_grouping(face, structure)
        all_mappings.append(mapped)
    i, j = np.concatenate(all_mappings, axis=1)
    v = np.ones_like(i)
    return scipy.sparse.coo_matrix((v, (i, j)), shape=(nlabels+1, nlabels+1)).tocsr()
def shrink_labels(plane, threshold):
    """Erode labels in a plane: zero out every labeled voxel within
    `threshold` distance of a label boundary"""
    flat = plane.squeeze()
    # nonzero gradient magnitude marks label boundaries
    boundary = np.linalg.norm(np.gradient(flat), axis=0) > 0
    shrunk = flat.copy()
    shrunk[boundary] = 0
    # distance of each remaining labeled voxel to the nearest zero voxel
    distances = scipy.ndimage.distance_transform_edt(shrunk)
    shrunk[distances <= threshold] = 0
    return shrunk.reshape(plane.shape)
def merge_all_boxes(boxes, box_ids):
    """Union all bounding boxes that share the same segment ID, returning
    one merged box per unique ID (sorted by ID)"""
    merged_boxes = []
    boxes_array = np.array(boxes, dtype=object)
    # IDs can arrive as floats; indexing requires integers
    box_ids = box_ids.astype(int)
    for label in np.unique(box_ids):
        hits = np.argwhere(box_ids == label).squeeze()
        # a 0-d result means exactly one box: nothing to merge
        if hits.shape:
            merged_boxes.append(merge_boxes(boxes_array[hits]))
        else:
            merged_boxes.append(boxes_array[hits])
    return merged_boxes
def merge_boxes(boxes):
    """Take the union of two or more parallelepipeds (sequences of slices):
    the smallest box covering all inputs"""
    merged = boxes[0]
    for other in boxes[1:]:
        merged = tuple(
            slice(min(a.start, b.start), max(a.stop, b.stop))
            for a, b in zip(merged, other)
        )
    return merged