phoebehxf
init
aff3c6f
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
"""
import os, gc
import numpy as np
import cv2
import fastremap
from ..io import imread, imread_2D, imread_3D, imsave, outlines_to_text, add_model, remove_model, save_rois
from ..models import normalize_default, MODEL_DIR, MODEL_LIST_PATH, get_user_models
from ..utils import masks_to_outlines, outlines_list
# Optional dependencies: qtpy provides the file dialogs used below, matplotlib
# is optional for plotting. `except Exception` (instead of a bare except) keeps
# KeyboardInterrupt/SystemExit propagating; it is deliberately broader than
# ImportError because qtpy may raise other error types when no Qt binding is
# usable (version dependent -- TODO confirm if narrowing).
try:
    import qtpy
    from qtpy.QtWidgets import QFileDialog
    GUI = True
except Exception:
    GUI = False
try:
    import matplotlib.pyplot as plt
    MATPLOTLIB = True
except Exception:
    MATPLOTLIB = False
def _init_model_list(parent):
    """Set up custom-model bookkeeping on the GUI object.

    Creates the models directory if needed, then stores the model-list path
    and the names returned by ``get_user_models()`` on ``parent``.
    """
    MODEL_DIR.mkdir(exist_ok=True, parents=True)
    parent.model_list_path = MODEL_LIST_PATH
    user_models = get_user_models()
    parent.model_strings = user_models
def _add_model(parent, filename=None, load_model=True):
    """Register a custom model file and add it to the GUI model dropdown.

    Parameters
    ----------
    parent : GUI main window
        Holds ``ModelChooseC`` (model dropdown) and ``model_strings`` (custom
        model names; dropdown index = ``model_strings`` index + 1).
    filename : str, optional
        Path to the model file. If None, a file dialog is opened; cancelling
        the dialog adds nothing.
    load_model : bool
        If True, call ``parent.model_choose(custom=True)`` after selecting the
        new entry.
    """
    if filename is None:
        name = QFileDialog.getOpenFileName(parent, "Add model to GUI")
        filename = name[0]
    if not filename:
        # dialog cancelled (empty path): nothing to register
        return
    add_model(filename)
    fname = os.path.split(filename)[-1]

    # Drop stale dropdown entries with the same name so the model is listed
    # once, at the end. Only GUI state is edited here: going through
    # _remove_model would also call remove_model(), which could drop the model
    # that was just registered from the user model list.
    for ind in reversed(range(len(parent.model_strings))):
        if parent.model_strings[ind] == fname:
            parent.ModelChooseC.removeItem(ind + 1)
            del parent.model_strings[ind]

    parent.ModelChooseC.addItems([fname])
    parent.model_strings.append(fname)
    # the new entry is the last one: dropdown index len(model_strings)
    parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
    if load_model:
        parent.model_choose(custom=True)
def _remove_model(parent, ind=None, verbose=True):
    """Remove a custom model entry from the GUI dropdown and the user model list.

    Parameters
    ----------
    parent : GUI main window
        Holds ``ModelChooseC`` and ``model_strings`` (dropdown index =
        ``model_strings`` index + 1; index 0 is not a custom model).
    ind : int, optional
        Dropdown index to remove; defaults to the currently selected entry.
    verbose : bool
        Unused; kept for backward compatibility.
    """
    if ind is None:
        ind = parent.ModelChooseC.currentIndex()
    if ind <= 0:
        print("ERROR: no model selected to delete")
        return
    ind -= 1
    # Capture the name BEFORE removing the entry: once removed, it can no
    # longer be the dropdown's current text.
    modelstr = parent.model_strings[ind]
    parent.ModelChooseC.removeItem(ind + 1)
    del parent.model_strings[ind]
    # Only drop the name from the user model list if no remaining GUI entry
    # uses it (e.g. a stale duplicate replaced by a re-added model).
    if modelstr not in parent.model_strings:
        remove_model(modelstr)
    # select the last custom model, or index 0 when none are left
    parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
def _get_train_set(image_names):
    """ get training data and labels for images in current folder image_names

    For each image with a sibling ``<name>_seg.npy`` whose squeezed masks are
    2D, the masks are renumbered in place and collected together with the
    image data: ``dat["img_restore"]`` if saved, otherwise the image file.
    Images whose masks are not 2D are skipped without reading the image.

    Returns
    -------
    train_data, train_labels, train_files : list
    restore : value of ``dat["restore"]`` from the last ``_seg.npy`` that
        contains ``img_restore``, else None
    normalize_params : from the last ``_seg.npy`` read (``normalize_default``
        if that file has none, or if no file was read)
    """
    train_data, train_labels, train_files = [], [], []
    restore = None
    normalize_params = normalize_default
    for image_name_full in image_names:
        image_name = os.path.splitext(image_name_full)[0]
        seg_file = image_name + "_seg.npy"
        if not os.path.exists(seg_file):
            continue
        # NOTE: allow_pickle=True unpickles the file; only use trusted _seg.npy files
        dat = np.load(seg_file, allow_pickle=True).item()
        masks = dat["masks"].squeeze()
        has_restore = "img_restore" in dat
        if has_restore:
            restore = dat["restore"]
        normalize_params = dat.get("normalize_params", normalize_default)
        if masks.ndim != 2:
            print(f"GUI_INFO: _seg.npy found for {image_name} but masks.ndim!=2")
            continue
        fastremap.renumber(masks, in_place=True)
        # read the image only once the labels are known to be usable
        data = dat["img_restore"].squeeze() if has_restore else imread(image_name_full)
        train_files.append(image_name_full)
        train_data.append(data)
        train_labels.append(masks)
    if restore:
        print(f"GUI_INFO: using {restore} images (dat['img_restore'])")
    return train_data, train_labels, train_files, restore, normalize_params
def _load_image(parent, filename=None, load_seg=True, load_3D=False):
    """ load image with filename; if None, open QFileDialog

    - If ``load_seg`` and a ``<name>_seg.npy`` exists while mask autoloading is
      off, the image is handed to ``_load_seg`` instead.
    - If mask autoloading is on, ``<name>_masks<ext>`` (falling back to
      ``<name>_masks.tif``) is loaded after the image when it exists.
    - If every value in channels 1 onward is identical (e.g. padded extra
      channels), the view switches to grey scale (RGBDropDown index 4).
    - If the image cannot be read, an error is printed and the previously
      loaded state is left untouched.
    """
    if parent.load_3D:
        load_3D = True

    if filename is None:
        name = QFileDialog.getOpenFileName(parent, "Load image")
        filename = name[0]
        if filename == "":
            return

    root = os.path.splitext(filename)[0]
    manual_file = root + "_seg.npy"
    load_mask = False
    mask_file = None
    if load_seg:
        autoload = parent.autoloadMasks.isChecked()
        if os.path.isfile(manual_file) and not autoload:
            image = imread_2D(filename) if not load_3D else imread_3D(filename)
            _load_seg(parent, manual_file, image=image, image_file=filename,
                      load_3D=load_3D)
            return
        elif autoload:
            mask_file = root + "_masks" + os.path.splitext(filename)[-1]
            if not os.path.isfile(mask_file):
                mask_file = root + "_masks.tif"
            load_mask = os.path.isfile(mask_file)

    try:
        print(f"GUI_INFO: loading image: {filename}")
        image = imread_2D(filename) if not load_3D else imread_3D(filename)
    except Exception as e:
        print("ERROR: images not compatible")
        print(f"ERROR: {e}")
        # `image` is undefined here; keep whatever was loaded before
        return
    parent.loaded = True

    parent.reset()
    parent.filename = filename
    _initialize_images(parent, image, load_3D=load_3D)
    # set again after reset()/_initialize_images, matching the previous flow
    parent.loaded = True
    parent.enable_buttons()
    if load_mask:
        _load_masks(parent, filename=mask_file)

    # check if gray and adjust viewer:
    if len(np.unique(image[..., 1:])) == 1:
        parent.color = 4
        parent.RGBDropDown.setCurrentIndex(4)  # gray
        parent.update_plot()
def _initialize_images(parent, image, load_3D=False):
    """ format image for GUI
    assumes image is Z x Ly x Lx x C for 3D (Ly x Lx x C for 2D; a Z axis of
    length 1 is added)

    Stores the image on ``parent.stack`` as float32 rescaled to 0-255 using
    the global min/max, resets the size attributes (Ly/Lx, Ly0/Lx0, Lyr/Lxr)
    and the drawing layer, clears existing masks, optionally recomputes the
    restored image and saturation levels, and sets the current Z plane
    (middle plane for 3D, 0 for 2D).

    Parameters
    ----------
    parent : GUI main window
    image : array
        Image as returned by imread_2D / imread_3D; ``shape[-3:-1]`` is taken
        as (Ly, Lx), i.e. channels are presumably last -- TODO confirm for all
        readers.
    load_3D : bool
        Treat ``image`` as a Z stack. ``False`` falls back to ``parent.load_3D``.
    """
    # an explicit True wins; otherwise use the GUI's 3D setting
    load_3D = parent.load_3D if load_3D is False else load_3D
    parent.stack = image
    print(f"GUI_INFO: image shape: {image.shape}")
    if load_3D:
        parent.NZ = len(parent.stack)
        parent.scroll.setMaximum(parent.NZ - 1)
    else:
        parent.NZ = 1
        # add a leading Z axis so 2D and 3D stacks share one layout
        parent.stack = parent.stack[np.newaxis, ...]

    # min-max rescale to float32 in [0, 255]; a (near-)constant image is only
    # shifted, not divided, to avoid dividing by ~0
    img_min = image.min()
    img_max = image.max()
    parent.stack = parent.stack.astype(np.float32)
    parent.stack -= img_min
    if img_max > img_min + 1e-3:
        parent.stack /= (img_max - img_min)
    parent.stack *= 255
    if load_3D:
        print("GUI_INFO: converted to float and normalized values to 0.0->255.0")

    # drops only this function's reference; the caller may still hold `image`
    del image
    gc.collect()

    parent.imask = 0
    parent.Ly, parent.Lx = parent.stack.shape[-3:-1]
    parent.Ly0, parent.Lx0 = parent.stack.shape[-3:-1]
    parent.layerz = 255 * np.ones((parent.Ly, parent.Lx, 4), "uint8")
    # Lyr/Lxr: size of the restored view -- the filtered stack if present,
    # the ratio-scaled size when the restore mode contains "upsample",
    # otherwise the raw image size
    if hasattr(parent, "stack_filtered"):
        parent.Lyr, parent.Lxr = parent.stack_filtered.shape[-3:-1]
    elif parent.restore and "upsample" in parent.restore:
        parent.Lyr, parent.Lxr = int(parent.Ly * parent.ratio), int(parent.Lx * parent.ratio)
    else:
        parent.Lyr, parent.Lxr = parent.Ly, parent.Lx
    parent.clear_all()

    # a restore mode is set but no restored image is stored: compute it now
    if not hasattr(parent, "stack_filtered") and parent.restore:
        print("GUI_INFO: no 'img_restore' found, applying current settings")
        parent.compute_restore()

    if parent.autobtn.isChecked():
        if parent.restore is None or parent.restore != "filter":
            print(
                "GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)"
            )
            parent.compute_saturation()
    # elif len(parent.saturation) != parent.NZ:
    #     parent.saturation = []
    #     for r in range(3):
    #         parent.saturation.append([])
    #         for n in range(parent.NZ):
    #             parent.saturation[-1].append([0, 255])
    #         parent.sliders[r].setValue([0, 255])
    parent.compute_scale()
    parent.track_changes = []

    # start 3D stacks on the middle plane
    if load_3D:
        parent.currentZ = int(np.floor(parent.NZ / 2))
        parent.scroll.setValue(parent.currentZ)
        parent.zpos.setText(str(parent.currentZ))
    else:
        parent.currentZ = 0
def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False):
    """ load *_seg.npy with filename; if None, open QFileDialog

    Parameters
    ----------
    parent : GUI main window
    filename : str, optional
        Path to the ``*_seg.npy`` file (must contain an "outlines" entry).
    image : array, optional
        Already-loaded image. If None, the image is looked up at
        ``dat["filename"]``, then under the same name next to the npy file,
        and finally taken from ``dat["img"]``.
    image_file : str, optional
        Path of ``image``; stored as ``parent.filename`` when ``image`` is given.
    load_3D : bool
        Read the image as a Z stack.
    """
    if filename is None:
        name = QFileDialog.getOpenFileName(parent, "Load labelled data", filter="*.npy")
        filename = name[0]
    try:
        # NOTE: allow_pickle=True unpickles arbitrary objects; only open trusted files
        dat = np.load(filename, allow_pickle=True).item()
        # check if there are keys in filename
        dat["outlines"]
        parent.loaded = True
    except Exception:
        parent.loaded = False
        print("ERROR: not NPY")
        return

    parent.reset()
    if image is None:
        found_image = False
        if "filename" in dat:
            parent.filename = dat["filename"]
            if os.path.isfile(parent.filename):
                found_image = True
            else:
                # image may have moved together with the npy: look next to it
                # (os.path.join also handles an npy path without a directory)
                imgname = os.path.split(parent.filename)[1]
                root = os.path.split(filename)[0]
                parent.filename = os.path.join(root, imgname)
                if os.path.isfile(parent.filename):
                    found_image = True
        if found_image:
            try:
                print(parent.filename)
                image = (imread_2D(parent.filename) if not load_3D else
                         imread_3D(parent.filename))
            except Exception:
                parent.loaded = False
                found_image = False
                print("ERROR: cannot find image file, loading from npy")
        if not found_image:
            # strip the "_seg.npy" suffix (8 characters)
            parent.filename = filename[:-8]
            print(parent.filename)
            if "img" in dat:
                image = dat["img"]
            else:
                print("ERROR: no image file found and no image in npy")
                return
    else:
        parent.filename = image_file

    parent.restore = None
    parent.ratio = 1.

    if "normalize_params" in dat:
        parent.set_normalize_params(dat["normalize_params"])

    _initialize_images(parent, image, load_3D=load_3D)
    print(parent.stack.shape)

    # always true here: the loader above requires the "outlines" key
    if "outlines" in dat:
        if isinstance(dat["outlines"], list):
            # old way of saving files: one outline (point list) per cell
            dat["outlines"] = dat["outlines"][::-1]
            for k, outline in enumerate(dat["outlines"]):
                if "colors" in dat:
                    color = dat["colors"][k]
                else:
                    col_rand = np.random.randint(1000)
                    color = parent.colormap[col_rand, :3]
                median = parent.add_mask(points=outline, color=color)
                if median is not None:
                    parent.cellcolors = np.append(parent.cellcolors,
                                                  color[np.newaxis, :], axis=0)
                    parent.ncells += 1
        else:
            # shift -1-based labels so that 0 is background
            if dat["masks"].min() == -1:
                dat["masks"] += 1
                dat["outlines"] += 1
            parent.ncells.set(dat["masks"].max())
            if "colors" in dat and len(dat["colors"]) == dat["masks"].max():
                colors = dat["colors"]
            else:
                colors = parent.colormap[:parent.ncells.get(), :3]

            _masks_to_gui(parent, dat["masks"], outlines=dat["outlines"], colors=colors)

            parent.draw_layer()

        if "manual_changes" in dat:
            parent.track_changes = dat["manual_changes"]
            print("GUI_INFO: loaded in previous changes")
        if "zdraw" in dat:
            parent.zdraw = dat["zdraw"]
        else:
            parent.zdraw = [None] * parent.ncells.get()
        parent.loaded = True
    else:
        parent.clear_all()

    parent.ismanual = np.zeros(parent.ncells.get(), bool)
    if "ismanual" in dat:
        if len(dat["ismanual"]) == parent.ncells.get():
            parent.ismanual = dat["ismanual"]

    if "current_channel" in dat:
        parent.color = (dat["current_channel"] + 2) % 5
        parent.RGBDropDown.setCurrentIndex(parent.color)

    if "flows" in dat:
        parent.flows = dat["flows"]
        try:
            # flows saved at a different size than the masks are resized
            # (nearest neighbour) to the mask size
            if parent.flows[0].shape[-3] != dat["masks"].shape[-2]:
                Ly, Lx = dat["masks"].shape[-2:]
                for i in range(len(parent.flows)):
                    parent.flows[i] = cv2.resize(
                        parent.flows[i].squeeze(), (Lx, Ly),
                        interpolation=cv2.INTER_NEAREST)[np.newaxis, ...]
            parent.recompute_masks = bool(parent.NZ == 1)
        except Exception:
            # unexpected flows layout: unwrap one level if possible, otherwise
            # fall back to empty placeholders
            try:
                if len(parent.flows[0]) > 0:
                    parent.flows = parent.flows[0]
            except Exception:
                parent.flows = [[], [], [], [], [[]]]
            parent.recompute_masks = False

    parent.enable_buttons()
    parent.update_layer()
    del dat
    gc.collect()
def _load_masks(parent, filename=None):
    """ load zeros-based masks (0=no cell, 1=cell 1, ...)

    Accepted array layouts (after imread):
      * Ly x Lx                  -- 2D masks
      * Ly x Lx x C, C < 5       -- masks in channel 0
      * Z x Ly x Lx              -- 3D masks (3D array whose last dim is >= 5)
      * Z x Ly x Lx x C          -- masks in channel 0; if C > 5, channel 1 is
        used as outlines and channels 2: as flows, otherwise channels 1: are
        used as flows
    The masks must have as many planes as the loaded image stack; otherwise
    an error is printed and nothing is loaded.
    """
    if filename is None:
        name = QFileDialog.getOpenFileName(parent, "Load masks (PNG or TIFF)")
        filename = name[0]
    if not filename:
        # dialog cancelled: nothing to load
        return
    print(f"GUI_INFO: loading masks: {filename}")
    masks = imread(filename)
    if masks is None:
        # defensive: treat a None result as a failed read
        print(f"ERROR: could not read masks from {filename}")
        return
    outlines = None
    if masks.ndim > 3:
        # Z x Ly x Lx x nchannels
        if masks.shape[-1] > 5:
            parent.flows = list(np.transpose(masks[:, :, :, 2:], (3, 0, 1, 2)))
            outlines = masks[..., 1]
            masks = masks[..., 0]
        else:
            parent.flows = list(np.transpose(masks[:, :, :, 1:], (3, 0, 1, 2)))
            masks = masks[..., 0]
    elif masks.ndim == 3:
        # few trailing channels: treat as a single 2D plane with channels
        if masks.shape[-1] < 5:
            masks = masks[np.newaxis, :, :, 0]
    elif masks.ndim < 3:
        masks = masks[np.newaxis, :, :]
    # masks should be Z x Ly x Lx
    if masks.shape[0] != parent.NZ:
        print("ERROR: masks are not same depth (number of planes) as image stack")
        return

    _masks_to_gui(parent, masks, outlines)
    if parent.ncells > 0:
        parent.draw_layer()
        parent.toggle_mask_ops()
    del masks
    gc.collect()
    parent.update_layer()
    parent.update_plot()
def _masks_to_gui(parent, masks, outlines=None, colors=None):
    """Load a label image into the GUI's mask/outline buffers and redraw.

    Parameters
    ----------
    parent : GUI main window
    masks : integer array, Ly x Lx or Z x Ly x Lx
        Label image. Renumbered in place (``fastremap.renumber``) when its
        values are not exactly 0..max; then cast to uint16 (uint32 for large
        label counts).
    outlines : array, optional
        Precomputed outline image. Ignored and recomputed when ``masks`` is 2D
        or had to be renumbered.
    colors : array, optional
        One RGB row per cell; defaults to the first ``ncells`` colormap rows.
    """
    original_shape = masks.shape
    label_count = len(fastremap.unique(masks))
    if label_count != masks.max() + 1:
        print("GUI_INFO: renumbering masks")
        fastremap.renumber(masks, in_place=True)
        outlines = None
        masks = masks.reshape(original_shape)
    if masks.ndim == 2:
        outlines = None
    target_dtype = np.uint16 if masks.max() < 2**16 - 1 else np.uint32
    masks = masks.astype(target_dtype)

    # restore modes containing "upsample" keep three copies of the labels:
    # *_resize (upsampled working copy), cellpix, and *_orig (original size)
    upsample = bool(parent.restore and "upsample" in parent.restore)
    if not upsample:
        parent.cellpix = masks
    else:
        parent.cellpix_resize = masks.copy()
        parent.cellpix = parent.cellpix_resize.copy()
        small = cv2.resize(masks.squeeze(), (parent.Lx0, parent.Ly0),
                           interpolation=cv2.INTER_NEAREST)
        parent.cellpix_orig = small[np.newaxis, :, :]
        parent.resize = True

    # all label buffers are stored as Z x Ly x Lx
    if parent.cellpix.ndim == 2:
        parent.cellpix = parent.cellpix[np.newaxis, :, :]
    if upsample:
        if parent.cellpix_resize.ndim == 2:
            parent.cellpix_resize = parent.cellpix_resize[np.newaxis, :, :]
        if parent.cellpix_orig.ndim == 2:
            parent.cellpix_orig = parent.cellpix_orig[np.newaxis, :, :]

    print(f"GUI_INFO: {masks.max()} masks found")

    # per-plane outline image: masks_to_outlines(labels) * labels
    n_planes = parent.NZ
    if outlines is not None:
        parent.outpix = outlines
        if upsample:
            parent.outpix_resize = parent.outpix.copy()
            parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
            for plane in range(n_planes):
                labels_orig = parent.cellpix_orig[plane]
                parent.outpix_orig[plane] = masks_to_outlines(labels_orig) * labels_orig
                if plane % 50 == 0 and n_planes > 1:
                    print("GUI_INFO: plane %d outlines processed" % plane)
    else:
        parent.outpix = np.zeros_like(parent.cellpix)
        if upsample:
            parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
        for plane in range(n_planes):
            labels = parent.cellpix[plane]
            parent.outpix[plane] = masks_to_outlines(labels) * labels
            if upsample:
                labels_orig = parent.cellpix_orig[plane]
                parent.outpix_orig[plane] = masks_to_outlines(labels_orig) * labels_orig
            if plane % 50 == 0 and n_planes > 1:
                print("GUI_INFO: plane %d outlines processed" % plane)
        if upsample:
            parent.outpix_resize = parent.outpix.copy()

    if parent.outpix.ndim == 2:
        parent.outpix = parent.outpix[np.newaxis, :, :]
    if upsample:
        if parent.outpix_resize.ndim == 2:
            parent.outpix_resize = parent.outpix_resize[np.newaxis, :, :]
        if parent.outpix_orig.ndim == 2:
            parent.outpix_orig = parent.outpix_orig[np.newaxis, :, :]

    parent.ncells.set(parent.cellpix.max())
    if colors is None:
        colors = parent.colormap[:parent.ncells.get(), :3]
    print("GUI_INFO: creating cellcolors and drawing masks")
    # row 0 (white) is the background entry; row k is the colour of label k
    background = np.array([[255, 255, 255]])
    parent.cellcolors = np.concatenate((background, colors), axis=0).astype(np.uint8)
    if parent.ncells > 0:
        parent.draw_layer()
        parent.toggle_mask_ops()
    parent.ismanual = np.zeros(parent.ncells.get(), bool)
    parent.zdraw = list(np.full(parent.ncells.get(), -1, dtype=np.int16))

    if not hasattr(parent, "stack_filtered"):
        parent.ViewDropDown.setCurrentIndex(0)
    else:
        parent.ViewDropDown.setCurrentIndex(parent.ViewDropDown.count() - 1)
        print("set denoised/filtered view")
def _save_png(parent):
    """Write the current masks next to the image as ``<base>_cp_masks.*``.

    2D masks are saved as uint16 PNG unless the largest label exceeds 65534,
    in which case TIFF is used; 3D masks are always saved as TIFF.
    """
    base = os.path.splitext(parent.filename)[0]
    if parent.NZ != 1:
        print("GUI_INFO: saving 3D masks to tiff")
        imsave(base + "_cp_masks.tif", parent.cellpix)
        return
    plane = parent.cellpix[0]
    if plane.max() <= 65534:
        print("GUI_INFO: saving 2D masks to png")
        imsave(base + "_cp_masks.png", plane.astype(np.uint16))
    else:
        print("GUI_INFO: saving 2D masks to tif (too many masks for PNG)")
        imsave(base + "_cp_masks.tif", plane)
def _save_flows(parent):
    """ save flows and cellprob to tiff

    Writes ``<base>_cp_cellprob.tif`` (``flows[1]``), ``<base>_cp_flows_{i}.tif``
    for the first three last-axis channels of ``flows[0]`` and, when present,
    ``<base>_cp_flows.tif`` (``flows[2]``). If no usable flows are stored, an
    error is printed and nothing is written.
    """
    base = os.path.splitext(parent.filename)[0]
    print("GUI_INFO: saving flows and cellprob to tiff")
    flows = parent.flows
    # parent.flows may be empty or hold empty placeholders (e.g. the
    # [[], [], [], [], [[]]] fallback set by _load_seg). Require a real flow
    # image in flows[0] (sliced along its last axis below) and a non-empty
    # cellprob in flows[1].
    flow_img = flows[0] if len(flows) > 0 else None
    has_flows = (
        len(flows) > 1
        and isinstance(flow_img, np.ndarray)
        and flow_img.size > 0
        and flow_img.ndim >= 1
        and flow_img.shape[-1] >= 3
        and np.size(flows[1]) > 0
    )
    if not has_flows:
        print("ERROR: no flows or cellprob found")
        return
    imsave(base + "_cp_cellprob.tif", flows[1])
    for i in range(3):
        imsave(base + f"_cp_flows_{i}.tif", flow_img[..., i])
    if len(flows) > 2:
        imsave(base + "_cp_flows.tif", flows[2])
    print("GUI_INFO: saved flows and cellprob")
def _save_rois(parent):
    """ save masks as rois in .zip file for ImageJ

    Only 2D masks (``parent.NZ == 1``) are supported; for 3D stacks an error
    is printed and nothing is written.
    """
    if parent.NZ != 1:
        print("ERROR: cannot save 3D ROIs")
        return
    print(f"GUI_INFO: saving {parent.cellpix[0].max()} ImageJ ROIs to .zip archive.")
    save_rois(parent.cellpix[0], parent.filename)
def _save_outlines(parent):
    """Export 2D mask outlines as text (via ``outlines_to_text``) for ImageJ.

    The outlines are computed from ``parent.cellpix[0]`` with
    ``outlines_list`` and written using the image path without extension as
    base name. 3D stacks are not supported: an error is printed instead.
    """
    base = os.path.splitext(parent.filename)[0]
    if parent.NZ != 1:
        print("ERROR: cannot save 3D outlines")
        return
    print("GUI_INFO: saving 2D outlines to text file, see docs for info to load into ImageJ")
    outlines_to_text(base, outlines_list(parent.cellpix[0]))
def _save_sets_with_check(parent):
    """Autosave variant of ``_save_sets``.

    Writes the ``*_seg.npy`` file via ``_save_sets`` unless the GUI's
    ``disableAutosave`` checkbox is checked. Call ``_save_sets`` directly when
    the save must always happen (e.g. an explicit save action).
    """
    autosave_disabled = parent.disableAutosave.isChecked()
    if autosave_disabled:
        return
    _save_sets(parent)
def _save_sets(parent):
    """Write the current segmentation to ``<image base>_seg.npy`` (forced save).

    Use this when saving must happen regardless of the autosave setting (e.g.
    the save button); ``_save_sets_with_check`` is the autosave variant.

    3D stacks (``NZ > 1``) store the full Z arrays plus ``current_channel`` and
    ``zdraw``. 2D images store squeezed arrays (the ``*_resize`` versions when
    the restore mode contains "upsample") plus ``ismanual`` and
    ``manual_changes``. Both add the model/threshold/normalization settings,
    and ``img_restore`` when a restore mode is set. Errors raised by
    ``np.save`` are printed, not propagated.
    """
    seg_file = os.path.splitext(parent.filename)[0] + "_seg.npy"
    settings = parent.segmentation_settings
    flow_threshold = settings.flow_threshold
    cellprob_threshold = settings.cellprob_threshold

    if parent.NZ > 1:
        dat = {
            "outlines": parent.outpix,
            "colors": parent.cellcolors[1:],
            "masks": parent.cellpix,
            "current_channel": (parent.color - 2) % 5,
            "filename": parent.filename,
            "flows": parent.flows,
            "zdraw": parent.zdraw,
        }
    else:
        upsampled = parent.restore is not None and "upsample" in parent.restore
        out_img = parent.outpix_resize if upsampled else parent.outpix
        label_img = parent.cellpix_resize if upsampled else parent.cellpix
        dat = {
            "outlines": out_img.squeeze(),
            "colors": parent.cellcolors[1:],
            "masks": label_img.squeeze(),
            "filename": parent.filename,
            "flows": parent.flows,
            "ismanual": parent.ismanual,
            "manual_changes": parent.track_changes,
        }

    # fields shared by 2D and 3D saves (inserted in the same order for both)
    dat["model_path"] = getattr(parent, "current_model_path", 0)
    dat["flow_threshold"] = flow_threshold
    dat["cellprob_threshold"] = cellprob_threshold
    dat["normalize_params"] = parent.get_normalize_params()
    dat["restore"] = parent.restore
    dat["ratio"] = parent.ratio
    dat["diameter"] = settings.diameter
    if parent.restore is not None:
        dat["img_restore"] = parent.stack_filtered

    try:
        np.save(seg_file, dat)
        print("GUI_INFO: %d ROIs saved to %s" % (parent.ncells.get(), seg_file))
    except Exception as e:
        print(f"ERROR: {e}")