""" Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu. """ import os, gc import numpy as np import cv2 import fastremap from ..io import imread, imread_2D, imread_3D, imsave, outlines_to_text, add_model, remove_model, save_rois from ..models import normalize_default, MODEL_DIR, MODEL_LIST_PATH, get_user_models from ..utils import masks_to_outlines, outlines_list try: import qtpy from qtpy.QtWidgets import QFileDialog GUI = True except: GUI = False try: import matplotlib.pyplot as plt MATPLOTLIB = True except: MATPLOTLIB = False def _init_model_list(parent): MODEL_DIR.mkdir(parents=True, exist_ok=True) parent.model_list_path = MODEL_LIST_PATH parent.model_strings = get_user_models() def _add_model(parent, filename=None, load_model=True): if filename is None: name = QFileDialog.getOpenFileName(parent, "Add model to GUI") filename = name[0] add_model(filename) fname = os.path.split(filename)[-1] parent.ModelChooseC.addItems([fname]) parent.model_strings.append(fname) for ind, model_string in enumerate(parent.model_strings[:-1]): if model_string == fname: _remove_model(parent, ind=ind + 1, verbose=False) parent.ModelChooseC.setCurrentIndex(len(parent.model_strings)) if load_model: parent.model_choose(custom=True) def _remove_model(parent, ind=None, verbose=True): if ind is None: ind = parent.ModelChooseC.currentIndex() if ind > 0: ind -= 1 parent.ModelChooseC.removeItem(ind + 1) del parent.model_strings[ind] # remove model from txt path modelstr = parent.ModelChooseC.currentText() remove_model(modelstr) if len(parent.model_strings) > 0: parent.ModelChooseC.setCurrentIndex(len(parent.model_strings)) else: parent.ModelChooseC.setCurrentIndex(0) else: print("ERROR: no model selected to delete") def _get_train_set(image_names): """ get training data and labels for images in current folder image_names""" train_data, train_labels, train_files = [], [], [] restore = None normalize_params = normalize_default for image_name_full in image_names: image_name = os.path.splitext(image_name_full)[0] label_name = None if os.path.exists(image_name + "_seg.npy"): dat = np.load(image_name + "_seg.npy", allow_pickle=True).item() masks = dat["masks"].squeeze() if masks.ndim == 2: fastremap.renumber(masks, in_place=True) label_name = image_name + "_seg.npy" else: print(f"GUI_INFO: _seg.npy found for {image_name} but masks.ndim!=2") if "img_restore" in dat: data = dat["img_restore"].squeeze() restore = dat["restore"] else: data = imread(image_name_full) normalize_params = dat[ "normalize_params"] if "normalize_params" in dat else normalize_default if label_name is not None: train_files.append(image_name_full) train_data.append(data) train_labels.append(masks) if restore: print(f"GUI_INFO: using {restore} images (dat['img_restore'])") return train_data, train_labels, train_files, restore, normalize_params def _load_image(parent, filename=None, load_seg=True, load_3D=False): """ load image with filename; if None, open QFileDialog if image is grey change view to default to grey scale """ if parent.load_3D: load_3D = True if filename is None: name = QFileDialog.getOpenFileName(parent, "Load image") filename = name[0] if filename == "": return manual_file = os.path.splitext(filename)[0] + "_seg.npy" load_mask = False if load_seg: if os.path.isfile(manual_file) and not parent.autoloadMasks.isChecked(): if filename is not None: image = (imread_2D(filename) if not load_3D else imread_3D(filename)) else: image = None _load_seg(parent, manual_file, image=image, image_file=filename, load_3D=load_3D) return elif parent.autoloadMasks.isChecked(): mask_file = os.path.splitext(filename)[0] + "_masks" + os.path.splitext( filename)[-1] mask_file = os.path.splitext(filename)[ 0] + "_masks.tif" if not os.path.isfile(mask_file) else mask_file load_mask = True if os.path.isfile(mask_file) else False try: print(f"GUI_INFO: loading image: {filename}") if not load_3D: image = imread_2D(filename) else: image = imread_3D(filename) parent.loaded = True except Exception as e: print("ERROR: images not compatible") print(f"ERROR: {e}") if parent.loaded: parent.reset() parent.filename = filename filename = os.path.split(parent.filename)[-1] _initialize_images(parent, image, load_3D=load_3D) parent.loaded = True parent.enable_buttons() if load_mask: _load_masks(parent, filename=mask_file) # check if gray and adjust viewer: if len(np.unique(image[..., 1:])) == 1: parent.color = 4 parent.RGBDropDown.setCurrentIndex(4) # gray parent.update_plot() def _initialize_images(parent, image, load_3D=False): """ format image for GUI assumes image is Z x W x H x C """ load_3D = parent.load_3D if load_3D is False else load_3D parent.stack = image print(f"GUI_INFO: image shape: {image.shape}") if load_3D: parent.NZ = len(parent.stack) parent.scroll.setMaximum(parent.NZ - 1) else: parent.NZ = 1 parent.stack = parent.stack[np.newaxis, ...] img_min = image.min() img_max = image.max() parent.stack = parent.stack.astype(np.float32) parent.stack -= img_min if img_max > img_min + 1e-3: parent.stack /= (img_max - img_min) parent.stack *= 255 if load_3D: print("GUI_INFO: converted to float and normalized values to 0.0->255.0") del image gc.collect() parent.imask = 0 parent.Ly, parent.Lx = parent.stack.shape[-3:-1] parent.Ly0, parent.Lx0 = parent.stack.shape[-3:-1] parent.layerz = 255 * np.ones((parent.Ly, parent.Lx, 4), "uint8") if hasattr(parent, "stack_filtered"): parent.Lyr, parent.Lxr = parent.stack_filtered.shape[-3:-1] elif parent.restore and "upsample" in parent.restore: parent.Lyr, parent.Lxr = int(parent.Ly * parent.ratio), int(parent.Lx * parent.ratio) else: parent.Lyr, parent.Lxr = parent.Ly, parent.Lx parent.clear_all() if not hasattr(parent, "stack_filtered") and parent.restore: print("GUI_INFO: no 'img_restore' found, applying current settings") parent.compute_restore() if parent.autobtn.isChecked(): if parent.restore is None or parent.restore != "filter": print( "GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)" ) parent.compute_saturation() # elif len(parent.saturation) != parent.NZ: # parent.saturation = [] # for r in range(3): # parent.saturation.append([]) # for n in range(parent.NZ): # parent.saturation[-1].append([0, 255]) # parent.sliders[r].setValue([0, 255]) parent.compute_scale() parent.track_changes = [] if load_3D: parent.currentZ = int(np.floor(parent.NZ / 2)) parent.scroll.setValue(parent.currentZ) parent.zpos.setText(str(parent.currentZ)) else: parent.currentZ = 0 def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False): """ load *_seg.npy with filename; if None, open QFileDialog """ if filename is None: name = QFileDialog.getOpenFileName(parent, "Load labelled data", filter="*.npy") filename = name[0] try: dat = np.load(filename, allow_pickle=True).item() # check if there are keys in filename dat["outlines"] parent.loaded = True except: parent.loaded = False print("ERROR: not NPY") return parent.reset() if image is None: found_image = False if "filename" in dat: parent.filename = dat["filename"] if os.path.isfile(parent.filename): parent.filename = dat["filename"] found_image = True else: imgname = os.path.split(parent.filename)[1] root = os.path.split(filename)[0] parent.filename = root + "/" + imgname if os.path.isfile(parent.filename): found_image = True if found_image: try: print(parent.filename) image = (imread_2D(parent.filename) if not load_3D else imread_3D(parent.filename)) except: parent.loaded = False found_image = False print("ERROR: cannot find image file, loading from npy") if not found_image: parent.filename = filename[:-8] print(parent.filename) if "img" in dat: image = dat["img"] else: print("ERROR: no image file found and no image in npy") return else: parent.filename = image_file parent.restore = None parent.ratio = 1. if "normalize_params" in dat: parent.set_normalize_params(dat["normalize_params"]) _initialize_images(parent, image, load_3D=load_3D) print(parent.stack.shape) if "outlines" in dat: if isinstance(dat["outlines"], list): # old way of saving files dat["outlines"] = dat["outlines"][::-1] for k, outline in enumerate(dat["outlines"]): if "colors" in dat: color = dat["colors"][k] else: col_rand = np.random.randint(1000) color = parent.colormap[col_rand, :3] median = parent.add_mask(points=outline, color=color) if median is not None: parent.cellcolors = np.append(parent.cellcolors, color[np.newaxis, :], axis=0) parent.ncells += 1 else: if dat["masks"].min() == -1: dat["masks"] += 1 dat["outlines"] += 1 parent.ncells.set(dat["masks"].max()) if "colors" in dat and len(dat["colors"]) == dat["masks"].max(): colors = dat["colors"] else: colors = parent.colormap[:parent.ncells.get(), :3] _masks_to_gui(parent, dat["masks"], outlines=dat["outlines"], colors=colors) parent.draw_layer() if "manual_changes" in dat: parent.track_changes = dat["manual_changes"] print("GUI_INFO: loaded in previous changes") if "zdraw" in dat: parent.zdraw = dat["zdraw"] else: parent.zdraw = [None for n in range(parent.ncells.get())] parent.loaded = True else: parent.clear_all() parent.ismanual = np.zeros(parent.ncells.get(), bool) if "ismanual" in dat: if len(dat["ismanual"]) == parent.ncells: parent.ismanual = dat["ismanual"] if "current_channel" in dat: parent.color = (dat["current_channel"] + 2) % 5 parent.RGBDropDown.setCurrentIndex(parent.color) if "flows" in dat: parent.flows = dat["flows"] try: if parent.flows[0].shape[-3] != dat["masks"].shape[-2]: Ly, Lx = dat["masks"].shape[-2:] for i in range(len(parent.flows)): parent.flows[i] = cv2.resize( parent.flows[i].squeeze(), (Lx, Ly), interpolation=cv2.INTER_NEAREST)[np.newaxis, ...] if parent.NZ == 1: parent.recompute_masks = True else: parent.recompute_masks = False except: try: if len(parent.flows[0]) > 0: parent.flows = parent.flows[0] except: parent.flows = [[], [], [], [], [[]]] parent.recompute_masks = False parent.enable_buttons() parent.update_layer() del dat gc.collect() def _load_masks(parent, filename=None): """ load zeros-based masks (0=no cell, 1=cell 1, ...) """ if filename is None: name = QFileDialog.getOpenFileName(parent, "Load masks (PNG or TIFF)") filename = name[0] print(f"GUI_INFO: loading masks: {filename}") masks = imread(filename) outlines = None if masks.ndim > 3: # Z x nchannels x Ly x Lx if masks.shape[-1] > 5: parent.flows = list(np.transpose(masks[:, :, :, 2:], (3, 0, 1, 2))) outlines = masks[..., 1] masks = masks[..., 0] else: parent.flows = list(np.transpose(masks[:, :, :, 1:], (3, 0, 1, 2))) masks = masks[..., 0] elif masks.ndim == 3: if masks.shape[-1] < 5: masks = masks[np.newaxis, :, :, 0] elif masks.ndim < 3: masks = masks[np.newaxis, :, :] # masks should be Z x Ly x Lx if masks.shape[0] != parent.NZ: print("ERROR: masks are not same depth (number of planes) as image stack") return _masks_to_gui(parent, masks, outlines) if parent.ncells > 0: parent.draw_layer() parent.toggle_mask_ops() del masks gc.collect() parent.update_layer() parent.update_plot() def _masks_to_gui(parent, masks, outlines=None, colors=None): """ masks loaded into GUI """ # get unique values shape = masks.shape if len(fastremap.unique(masks)) != masks.max() + 1: print("GUI_INFO: renumbering masks") fastremap.renumber(masks, in_place=True) outlines = None masks = masks.reshape(shape) if masks.ndim == 2: outlines = None masks = masks.astype(np.uint16) if masks.max() < 2**16 - 1 else masks.astype( np.uint32) if parent.restore and "upsample" in parent.restore: parent.cellpix_resize = masks.copy() parent.cellpix = parent.cellpix_resize.copy() parent.cellpix_orig = cv2.resize( masks.squeeze(), (parent.Lx0, parent.Ly0), interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :] parent.resize = True else: parent.cellpix = masks if parent.cellpix.ndim == 2: parent.cellpix = parent.cellpix[np.newaxis, :, :] if parent.restore and "upsample" in parent.restore: if parent.cellpix_resize.ndim == 2: parent.cellpix_resize = parent.cellpix_resize[np.newaxis, :, :] if parent.cellpix_orig.ndim == 2: parent.cellpix_orig = parent.cellpix_orig[np.newaxis, :, :] print(f"GUI_INFO: {masks.max()} masks found") # get outlines if outlines is None: # parent.outlinesOn parent.outpix = np.zeros_like(parent.cellpix) if parent.restore and "upsample" in parent.restore: parent.outpix_orig = np.zeros_like(parent.cellpix_orig) for z in range(parent.NZ): outlines = masks_to_outlines(parent.cellpix[z]) parent.outpix[z] = outlines * parent.cellpix[z] if parent.restore and "upsample" in parent.restore: outlines = masks_to_outlines(parent.cellpix_orig[z]) parent.outpix_orig[z] = outlines * parent.cellpix_orig[z] if z % 50 == 0 and parent.NZ > 1: print("GUI_INFO: plane %d outlines processed" % z) if parent.restore and "upsample" in parent.restore: parent.outpix_resize = parent.outpix.copy() else: parent.outpix = outlines if parent.restore and "upsample" in parent.restore: parent.outpix_resize = parent.outpix.copy() parent.outpix_orig = np.zeros_like(parent.cellpix_orig) for z in range(parent.NZ): outlines = masks_to_outlines(parent.cellpix_orig[z]) parent.outpix_orig[z] = outlines * parent.cellpix_orig[z] if z % 50 == 0 and parent.NZ > 1: print("GUI_INFO: plane %d outlines processed" % z) if parent.outpix.ndim == 2: parent.outpix = parent.outpix[np.newaxis, :, :] if parent.restore and "upsample" in parent.restore: if parent.outpix_resize.ndim == 2: parent.outpix_resize = parent.outpix_resize[np.newaxis, :, :] if parent.outpix_orig.ndim == 2: parent.outpix_orig = parent.outpix_orig[np.newaxis, :, :] parent.ncells.set(parent.cellpix.max()) colors = parent.colormap[:parent.ncells.get(), :3] if colors is None else colors print("GUI_INFO: creating cellcolors and drawing masks") parent.cellcolors = np.concatenate((np.array([[255, 255, 255]]), colors), axis=0).astype(np.uint8) if parent.ncells > 0: parent.draw_layer() parent.toggle_mask_ops() parent.ismanual = np.zeros(parent.ncells.get(), bool) parent.zdraw = list(-1 * np.ones(parent.ncells.get(), np.int16)) if hasattr(parent, "stack_filtered"): parent.ViewDropDown.setCurrentIndex(parent.ViewDropDown.count() - 1) print("set denoised/filtered view") else: parent.ViewDropDown.setCurrentIndex(0) def _save_png(parent): """ save masks to png or tiff (if 3D) """ filename = parent.filename base = os.path.splitext(filename)[0] if parent.NZ == 1: if parent.cellpix[0].max() > 65534: print("GUI_INFO: saving 2D masks to tif (too many masks for PNG)") imsave(base + "_cp_masks.tif", parent.cellpix[0]) else: print("GUI_INFO: saving 2D masks to png") imsave(base + "_cp_masks.png", parent.cellpix[0].astype(np.uint16)) else: print("GUI_INFO: saving 3D masks to tiff") imsave(base + "_cp_masks.tif", parent.cellpix) def _save_flows(parent): """ save flows and cellprob to tiff """ filename = parent.filename base = os.path.splitext(filename)[0] print("GUI_INFO: saving flows and cellprob to tiff") if len(parent.flows) > 0: imsave(base + "_cp_cellprob.tif", parent.flows[1]) for i in range(3): imsave(base + f"_cp_flows_{i}.tif", parent.flows[0][..., i]) if len(parent.flows) > 2: imsave(base + "_cp_flows.tif", parent.flows[2]) print("GUI_INFO: saved flows and cellprob") else: print("ERROR: no flows or cellprob found") def _save_rois(parent): """ save masks as rois in .zip file for ImageJ """ filename = parent.filename if parent.NZ == 1: print( f"GUI_INFO: saving {parent.cellpix[0].max()} ImageJ ROIs to .zip archive.") save_rois(parent.cellpix[0], parent.filename) else: print("ERROR: cannot save 3D outlines") def _save_outlines(parent): filename = parent.filename base = os.path.splitext(filename)[0] if parent.NZ == 1: print( "GUI_INFO: saving 2D outlines to text file, see docs for info to load into ImageJ" ) outlines = outlines_list(parent.cellpix[0]) outlines_to_text(base, outlines) else: print("ERROR: cannot save 3D outlines") def _save_sets_with_check(parent): """ Save masks and update *_seg.npy file. Use this function when saving should be optional based on the disableAutosave checkbox. Otherwise, use _save_sets """ if not parent.disableAutosave.isChecked(): _save_sets(parent) def _save_sets(parent): """ save masks to *_seg.npy. This function should be used when saving is forced, e.g. when clicking the save button. Otherwise, use _save_sets_with_check """ filename = parent.filename base = os.path.splitext(filename)[0] flow_threshold = parent.segmentation_settings.flow_threshold cellprob_threshold = parent.segmentation_settings.cellprob_threshold if parent.NZ > 1: dat = { "outlines": parent.outpix, "colors": parent.cellcolors[1:], "masks": parent.cellpix, "current_channel": (parent.color - 2) % 5, "filename": parent.filename, "flows": parent.flows, "zdraw": parent.zdraw, "model_path": parent.current_model_path if hasattr(parent, "current_model_path") else 0, "flow_threshold": flow_threshold, "cellprob_threshold": cellprob_threshold, "normalize_params": parent.get_normalize_params(), "restore": parent.restore, "ratio": parent.ratio, "diameter": parent.segmentation_settings.diameter } if parent.restore is not None: dat["img_restore"] = parent.stack_filtered else: dat = { "outlines": parent.outpix.squeeze() if parent.restore is None or not "upsample" in parent.restore else parent.outpix_resize.squeeze(), "colors": parent.cellcolors[1:], "masks": parent.cellpix.squeeze() if parent.restore is None or not "upsample" in parent.restore else parent.cellpix_resize.squeeze(), "filename": parent.filename, "flows": parent.flows, "ismanual": parent.ismanual, "manual_changes": parent.track_changes, "model_path": parent.current_model_path if hasattr(parent, "current_model_path") else 0, "flow_threshold": flow_threshold, "cellprob_threshold": cellprob_threshold, "normalize_params": parent.get_normalize_params(), "restore": parent.restore, "ratio": parent.ratio, "diameter": parent.segmentation_settings.diameter } if parent.restore is not None: dat["img_restore"] = parent.stack_filtered try: np.save(base + "_seg.npy", dat) print("GUI_INFO: %d ROIs saved to %s" % (parent.ncells.get(), base + "_seg.npy")) except Exception as e: print(f"ERROR: {e}") del dat