"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""
import logging
import numpy as np
from tqdm import trange
from . import transforms, utils

import torch

TORCH_ENABLED = True
core_logger = logging.getLogger(__name__)
tqdm_out = utils.TqdmToLogger(core_logger, level=logging.INFO)


def use_gpu(gpu_number=0, use_torch=True):
    """ Check if a GPU is available for use.

    Args:
        gpu_number (int): The index of the GPU to be used. Default is 0.
        use_torch (bool): Whether to use PyTorch for the GPU check. Default is True.

    Returns:
        bool: True if a GPU is available, False otherwise.

    Raises:
        ValueError: If use_torch is False, as cellpose only runs with PyTorch now.
    """
    if not use_torch:
        raise ValueError("cellpose only runs with PyTorch now")
    return _use_gpu_torch(gpu_number)


def _use_gpu_torch(gpu_number=0):
    """ Check whether CUDA or MPS is available and working with PyTorch.

    A tiny tensor is moved to the candidate device; CUDA is tried first, then MPS.

    Args:
        gpu_number (int): The GPU device number to use (default is 0).

    Returns:
        bool: True if CUDA or MPS is available and working, False otherwise.
    """
    for backend, label in (("cuda", "CUDA"), ("mps", "MPS")):
        try:
            device = torch.device(f"{backend}:{gpu_number}")
            _ = torch.zeros((1, 1)).to(device)
        except Exception:
            # Backend missing or not functional: try the next one. `Exception`
            # (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
            continue
        core_logger.info("** TORCH %s version installed and working. **", label)
        return True
    core_logger.info("Neither TORCH CUDA nor MPS version installed/working.")
    return False


def _mps_available():
    """ Return True if this torch build exposes an MPS backend that reports itself available. """
    mps = getattr(torch.backends, "mps", None)
    return mps is not None and mps.is_available()


def assign_device(use_torch=True, gpu=False, device=0):
    """ Assign the device (CPU, CUDA GPU or MPS) to be used for computation.

    Args:
        use_torch (bool, optional): Kept for backward compatibility; the GPU check
            always uses torch. Defaults to True.
        gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
        device (int or str, optional): The device index, or "mps". Any string other than
            a usable "mps" is converted with int(). Defaults to 0.

    Returns:
        torch.device, bool (True if GPU is used, False otherwise)

    Raises:
        ValueError: If device is a string that is neither a usable "mps" nor an integer.
    """
    if isinstance(device, str):
        if device != "mps" or not (gpu and _mps_available()):
            try:
                device = int(device)
            except ValueError as err:
                raise ValueError(
                    "device must be a GPU index or 'mps' (with gpu=True and MPS "
                    f"available), got {device!r}") from err

    selected = None
    if gpu and use_gpu(use_torch=True):
        try:
            if torch.cuda.is_available():
                selected = torch.device(f"cuda:{device}")
                core_logger.info(">>>> using GPU (CUDA)")
        except Exception as err:
            core_logger.warning("could not select CUDA device %r: %s", device, err)
        # The MPS probe runs after CUDA (same precedence as before), but a failure
        # here no longer discards a CUDA device that was already selected.
        try:
            if _mps_available():
                selected = torch.device("mps")
                core_logger.info(">>>> using GPU (MPS)")
        except Exception as err:
            core_logger.warning("could not select MPS device: %s", err)

    if selected is None:
        # Covers gpu=False, a failed GPU check, and the case where the check passed
        # but neither backend reports itself available.
        core_logger.info(">>>> using CPU")
        return torch.device("cpu"), False
    return selected, True


def _to_device(x, device, dtype=torch.float32):
    """ Convert a numpy array to a tensor on the specified device.

    Args:
        x (torch.Tensor or numpy.ndarray): The input tensor or numpy array.
        device (torch.device): The target device.
        dtype (torch.dtype, optional): dtype of the created tensor. Defaults to torch.float32.

    Returns:
        torch.Tensor: A new tensor on `device` with `dtype` if `x` was an array.
            If `x` is already a tensor it is returned unchanged (not moved or cast).
    """
    if isinstance(x, torch.Tensor):
        return x
    return torch.from_numpy(x).to(device, dtype=dtype)


def _from_device(X):
    """ Convert a PyTorch tensor from the device to a NumPy array on the CPU.

    Args:
        X (torch.Tensor): The input PyTorch tensor.

    Returns:
        numpy.ndarray: The converted float32 NumPy array.
    """
    # The cast is so numpy conversion always works
    return X.detach().cpu().to(torch.float32).numpy()


def _forward(net, x, feat=None):
    """ Convert images to torch tensors, run the network model, and return numpy arrays.

    Args:
        net (torch.nn.Module): The network model; must expose `device` and `dtype`.
        x (numpy.ndarray): The input images.
        feat (numpy.ndarray, optional): Extra input passed to the network as `feat`.
            Defaults to None.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob)
            and style features.
    """
    X = _to_device(x, device=net.device, dtype=net.dtype)
    if feat is not None:
        feat = _to_device(feat, device=net.device, dtype=net.dtype)
    net.eval()
    with torch.no_grad():
        y, style = net(X, feat=feat)[:2]
    del X
    y = _from_device(y)
    style = _from_device(style)
    return y, style


def run_net(net, imgi, feat=None, batch_size=8, augment=False, tile_overlap=0.1, bsize=224,
            rsz=None):
    """ Run network on stack of images.

    (faster if augment is False)

    Args:
        net (class): cellpose network (model.net)
        imgi (np.ndarray): The input image or stack of images of size [Lz x Ly x Lx x nchan].
        feat (np.ndarray, optional): Extra per-plane input, indexed and tiled like `imgi`
            (channels last) and passed to the network as `feat`. Defaults to None.
        batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
        augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
        tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
        bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
        rsz (float, optional): Resize coefficient(s) for image; a scalar is applied to both
            Y and X. Defaults to None (no resizing).

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style.
            `y` has size [Lz x Ly x Lx x nout] (Ly, Lx after any resizing) and is averaged
            in tile overlaps. For the standard 3-output network y[...,0] is Y flow;
            y[...,1] is X flow; y[...,2] is cell probability.
            `style` has size [Lz x nstyle]; per plane it is the sum over tiles of the
            network style, normalized to unit L2 norm.

    Raises:
        ValueError: If `imgi` contains no planes.
    """
    Lz, Ly0, Lx0, nchan = imgi.shape
    if Lz == 0:
        raise ValueError("imgi must contain at least one plane")
    if rsz is not None:
        if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
            rsz = [rsz, rsz]
        Lyr, Lxr = int(Ly0 * rsz[0]), int(Lx0 * rsz[1])
    else:
        Lyr, Lxr = Ly0, Lx0
    ly, lx = bsize, bsize
    ypad1, ypad2, xpad1, xpad2 = transforms.get_pad_yx(Lyr, Lxr, min_size=(bsize, bsize))
    Ly, Lx = Lyr + ypad1 + ypad2, Lxr + xpad1 + xpad2
    pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])

    if augment:
        ny = max(2, int(np.ceil(2. * Ly / bsize)))
        nx = max(2, int(np.ceil(2. * Lx / bsize)))
    else:
        ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
        nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))

    # run multiple slices at the same time
    ntiles = ny * nx
    nimgs = max(1, batch_size // ntiles)  # number of imgs to run in the same batch
    niter = int(np.ceil(Lz / nimgs))
    ziterator = (trange(niter, file=tqdm_out, mininterval=30)
                 if niter > 10 or Lz > 1 else range(niter))

    # feat gets its own channel count instead of assuming it matches the image
    nfeat = feat.shape[-1] if feat is not None else None

    yf = None
    styles = None
    for k in ziterator:
        inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs))
        IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")
        if feat is not None:
            FEATa = np.zeros((ntiles * len(inds), nfeat, ly, lx), "float32")
        else:
            FEATa = None
        for i, b in enumerate(inds):
            # pad image for net (padding amounts come from transforms.get_pad_yx)
            imgb = (transforms.resize_image(imgi[b], rsz=rsz)
                    if rsz is not None else imgi[b].copy())
            imgb = np.pad(imgb.transpose(2, 0, 1), pads, mode="constant")
            IMG, ysub, xsub, Lyt, Lxt = transforms.make_tiles(
                imgb, bsize=bsize, augment=augment, tile_overlap=tile_overlap)
            IMGa[i * ntiles:(i + 1) * ntiles] = np.reshape(IMG, (ny * nx, nchan, ly, lx))
            if feat is not None:
                featb = (transforms.resize_image(feat[b], rsz=rsz)
                         if rsz is not None else feat[b].copy())
                featb = np.pad(featb.transpose(2, 0, 1), pads, mode="constant")
                # only the tiles are needed; the tile layout was taken from the image above
                FEAT = transforms.make_tiles(featb, bsize=bsize, augment=augment,
                                             tile_overlap=tile_overlap)[0]
                FEATa[i * ntiles:(i + 1) * ntiles] = np.reshape(
                    FEAT, (ny * nx, nfeat, ly, lx))

        # run network
        ya = None
        stylea = None
        for j in range(0, IMGa.shape[0], batch_size):
            bslc = slice(j, min(j + batch_size, IMGa.shape[0]))
            ya0, stylea0 = _forward(net, IMGa[bslc],
                                    feat=FEATa[bslc] if FEATa is not None else None)
            if ya is None:
                # output sizes are only known after the first forward pass
                nout = ya0.shape[1]
                ya = np.zeros((IMGa.shape[0], nout, ly, lx), "float32")
                stylea = np.zeros((IMGa.shape[0], stylea0.shape[-1]), "float32")
            ya[bslc] = ya0
            stylea[bslc] = stylea0

        if yf is None:
            yf = np.zeros((Lz, nout, Ly, Lx), "float32")
            styles = np.zeros((Lz, stylea.shape[1]), "float32")

        # average tiles
        for i, b in enumerate(inds):
            y = ya[i * ntiles:(i + 1) * ntiles]
            if augment:
                y = np.reshape(y, (ny, nx, nout, ly, lx))
                y = transforms.unaugment_tiles(y)
                y = np.reshape(y, (-1, nout, ly, lx))
            yfi = transforms.average_tiles(y, ysub, xsub, Lyt, Lxt)
            yf[b] = yfi[:, :imgb.shape[-2], :imgb.shape[-1]]
            stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0)
            stylei /= (stylei**2).sum()**0.5
            styles[b] = stylei

    # slices from padding
    yf = yf[:, :, ypad1:Ly - ypad2, xpad1:Lx - xpad2]
    yf = yf.transpose(0, 2, 3, 1)
    return yf, styles


def run_3D(net, imgs, batch_size=8, augment=False, tile_overlap=0.1, bsize=224,
           net_ortho=None, progress=None):
    """ Run network on image z-stack.

    The stack is run three times, sliced as YX, ZY and ZX planes, and the outputs
    are summed into a single 4-channel volume.

    (faster if augment is False)

    Args:
        net (class): cellpose network (model.net), used for the YX planes and, if
            `net_ortho` is None, for the ZY and ZX planes as well.
        imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan].
        batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
        augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
        tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
        bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
        net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes. Defaults to None.
        progress (QProgressBar, optional): pyqt progress bar. Defaults to None.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style.
            `y` has size [Lz x Ly x Lx x 4]: y[...,0] is Z flow; y[...,1] is Y flow;
            y[...,2] is X flow; y[...,3] is cell probability. Each flow channel is the
            sum of the two passes that contain that axis; cell probability is the sum
            over all three passes.
            `style` is the style array returned by the last (ZX) pass.
    """
    sstr = ["YX", "ZY", "ZX"]
    pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]  # input transposes per pass
    ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]  # inverse transposes back to (Z, Y, X)
    cp = [(1, 2), (0, 2), (0, 1)]  # destination flow channels in yf per pass
    cpy = [(0, 1), (0, 1), (0, 1)]  # source flow channels in the 2D output
    shape = imgs.shape[:-1]
    yf = np.zeros((*shape, 4), "float32")
    for p in range(3):
        xsl = imgs.transpose(pm[p])
        # per image
        core_logger.info("running %s: %d planes of size (%d, %d)", sstr[p],
                         shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]])
        # orthogonal passes use net_ortho when one was supplied
        plane_net = net if p == 0 or net_ortho is None else net_ortho
        y, style = run_net(plane_net, xsl, batch_size=batch_size, augment=augment,
                           bsize=bsize, tile_overlap=tile_overlap, rsz=None)
        yf[..., -1] += y[..., -1].transpose(ipm[p])
        for j in range(2):
            yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])
        del y
        if progress is not None:
            progress.setValue(25 + 15 * p)
    return yf, style