""" Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu. """ import sys, os, pathlib, warnings, datetime, time, copy from qtpy import QtGui, QtCore from superqt import QRangeSlider, QCollapsible from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, \ QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, \ QLineEdit, QMessageBox, QGroupBox, QMenu, QAction import pyqtgraph as pg import numpy as np from scipy.stats import mode import cv2 from . import guiparts, menus, io from .. import models, core, dynamics, version, train from ..utils import download_url_to_file, masks_to_outlines, diameters from ..io import get_image_files, imsave, imread from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img from ..models import normalize_default from ..plot import disk try: import matplotlib.pyplot as plt MATPLOTLIB = True except: MATPLOTLIB = False Horizontal = QtCore.Qt.Orientation.Horizontal class Slider(QRangeSlider): def __init__(self, parent, name, color): super().__init__(Horizontal) self.setEnabled(False) self.valueChanged.connect(lambda: self.levelChanged(parent)) self.name = name self.setStyleSheet(""" QSlider{ background-color: transparent; } """) self.show() def levelChanged(self, parent): parent.level_change(self.name) class QHLine(QFrame): def __init__(self): super(QHLine, self).__init__() self.setFrameShape(QFrame.HLine) self.setLineWidth(8) def make_bwr(): # make a bwr colormap b = np.append(255 * np.ones(128), np.linspace(0, 255, 128)[::-1])[:, np.newaxis] r = np.append(np.linspace(0, 255, 128), 255 * np.ones(128))[:, np.newaxis] g = np.append(np.linspace(0, 255, 128), np.linspace(0, 255, 128)[::-1])[:, np.newaxis] color = np.concatenate((r, g, b), axis=-1).astype(np.uint8) bwr = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) return bwr def make_spectral(): # make spectral colormap r = np.array([ 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 120, 112, 104, 96, 88, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 7, 11, 15, 19, 23, 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103, 107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167, 171, 175, 179, 183, 187, 191, 195, 199, 203, 207, 211, 215, 219, 223, 227, 231, 235, 239, 243, 247, 251, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 ]) g = np.array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 247, 239, 231, 223, 215, 207, 199, 191, 183, 175, 167, 159, 151, 143, 135, 128, 129, 131, 132, 134, 135, 137, 139, 140, 142, 143, 145, 147, 148, 150, 151, 153, 154, 156, 158, 159, 161, 162, 164, 166, 167, 169, 170, 172, 174, 175, 177, 178, 180, 181, 183, 185, 186, 188, 189, 191, 193, 194, 196, 197, 199, 201, 202, 204, 205, 207, 208, 210, 212, 213, 215, 216, 218, 220, 221, 223, 224, 226, 228, 229, 231, 232, 234, 235, 237, 239, 240, 242, 243, 245, 247, 248, 250, 251, 253, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 127, 123, 119, 115, 111, 107, 103, 99, 95, 91, 87, 83, 79, 75, 71, 67, 63, 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205, 213, 222, 230, 238, 246, 254 ]) b = np.array([ 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 128, 126, 124, 122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98, 96, 94, 92, 90, 88, 86, 84, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205, 213, 222, 230, 238, 246, 254 ]) color = (np.vstack((r, g, b)).T).astype(np.uint8) spectral = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) return spectral def make_cmap(cm=0): # make a single channel colormap r = np.arange(0, 256) color = np.zeros((256, 3)) color[:, cm] = r color = color.astype(np.uint8) cmap = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) return cmap def run(image=None): from ..io import logger_setup logger, log_file = logger_setup() # Always start by initializing Qt (only once per application) warnings.filterwarnings("ignore") app = QApplication(sys.argv) icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png") guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png") if not icon_path.is_file(): cp_dir = pathlib.Path.home().joinpath(".cellpose") cp_dir.mkdir(exist_ok=True) print("downloading logo") download_url_to_file( "https://www.cellpose.org/static/images/cellpose_transparent.png", icon_path, progress=True) if not guip_path.is_file(): print("downloading help window image") download_url_to_file("https://www.cellpose.org/static/images/cellposeSAM_gui.png", guip_path, progress=True) icon_path = str(icon_path.resolve()) app_icon = QtGui.QIcon() app_icon.addFile(icon_path, QtCore.QSize(16, 16)) app_icon.addFile(icon_path, QtCore.QSize(24, 24)) app_icon.addFile(icon_path, QtCore.QSize(32, 32)) app_icon.addFile(icon_path, QtCore.QSize(48, 48)) app_icon.addFile(icon_path, QtCore.QSize(64, 64)) app_icon.addFile(icon_path, QtCore.QSize(256, 256)) app.setWindowIcon(app_icon) app.setStyle("Fusion") app.setPalette(guiparts.DarkPalette()) MainW(image=image, logger=logger) ret = app.exec_() sys.exit(ret) class MainW(QMainWindow): def __init__(self, image=None, logger=None): super(MainW, self).__init__() self.logger = logger pg.setConfigOptions(imageAxisOrder="row-major") self.setGeometry(50, 50, 1200, 1000) self.setWindowTitle(f"cellpose v{version}") self.cp_path = os.path.dirname(os.path.realpath(__file__)) app_icon = QtGui.QIcon() icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png") icon_path = str(icon_path.resolve()) app_icon.addFile(icon_path, QtCore.QSize(16, 16)) app_icon.addFile(icon_path, QtCore.QSize(24, 24)) app_icon.addFile(icon_path, QtCore.QSize(32, 32)) app_icon.addFile(icon_path, QtCore.QSize(48, 48)) app_icon.addFile(icon_path, QtCore.QSize(64, 64)) app_icon.addFile(icon_path, QtCore.QSize(256, 256)) self.setWindowIcon(app_icon) # rgb(150,255,150) self.setStyleSheet(guiparts.stylesheet()) menus.mainmenu(self) menus.editmenu(self) menus.modelmenu(self) menus.helpmenu(self) self.stylePressed = """QPushButton {Text-align: center; background-color: rgb(150,50,150); border-color: white; color:white;} QToolTip { background-color: black; color: white; border: black solid 1px }""" self.styleUnpressed = """QPushButton {Text-align: center; background-color: rgb(50,50,50); border-color: white; color:white;} QToolTip { background-color: black; color: white; border: black solid 1px }""" self.loaded = False # ---- MAIN WIDGET LAYOUT ---- # self.cwidget = QWidget(self) self.lmain = QGridLayout() self.cwidget.setLayout(self.lmain) self.setCentralWidget(self.cwidget) self.lmain.setVerticalSpacing(0) self.lmain.setContentsMargins(0, 0, 0, 10) self.imask = 0 self.scrollarea = QScrollArea() self.scrollarea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn) self.scrollarea.setStyleSheet("""QScrollArea { border: none }""") self.scrollarea.setWidgetResizable(True) self.swidget = QWidget(self) self.scrollarea.setWidget(self.swidget) self.l0 = QGridLayout() self.swidget.setLayout(self.l0) b = self.make_buttons() self.lmain.addWidget(self.scrollarea, 0, 0, 39, 9) # ---- drawing area ---- # self.win = pg.GraphicsLayoutWidget() self.lmain.addWidget(self.win, 0, 9, 40, 30) self.win.scene().sigMouseClicked.connect(self.plot_clicked) self.win.scene().sigMouseMoved.connect(self.mouse_moved) self.make_viewbox() self.lmain.setColumnStretch(10, 1) bwrmap = make_bwr() self.bwr = bwrmap.getLookupTable(start=0.0, stop=255.0, alpha=False) self.cmap = [] # spectral colormap self.cmap.append(make_spectral().getLookupTable(start=0.0, stop=255.0, alpha=False)) # single channel colormaps for i in range(3): self.cmap.append( make_cmap(i).getLookupTable(start=0.0, stop=255.0, alpha=False)) if MATPLOTLIB: self.colormap = (plt.get_cmap("gist_ncar")(np.linspace(0.0, .9, 1000000)) * 255).astype(np.uint8) np.random.seed(42) # make colors stable self.colormap = self.colormap[np.random.permutation(1000000)] else: np.random.seed(42) # make colors stable self.colormap = ((np.random.rand(1000000, 3) * 0.8 + 0.1) * 255).astype( np.uint8) self.NZ = 1 self.restore = None self.ratio = 1. self.reset() # This needs to go after .reset() is called to get state fully set up: self.autobtn.checkStateChanged.connect(self.compute_saturation_if_checked) self.load_3D = False # if called with image, load it if image is not None: self.filename = image io._load_image(self, self.filename) # training settings d = datetime.datetime.now() self.training_params = { "model_index": 0, "learning_rate": 1e-5, "weight_decay": 0.1, "n_epochs": 100, "model_name": "cpsam" + d.strftime("_%Y%m%d_%H%M%S"), } self.stitch_threshold = 0. self.flow3D_smooth = 0. self.anisotropy = 1. self.min_size = 15 self.setAcceptDrops(True) self.win.show() self.show() def help_window(self): HW = guiparts.HelpWindow(self) HW.show() def train_help_window(self): THW = guiparts.TrainHelpWindow(self) THW.show() def gui_window(self): EG = guiparts.ExampleGUI(self) EG.show() def make_buttons(self): self.boldfont = QtGui.QFont("Arial", 11, QtGui.QFont.Bold) self.boldmedfont = QtGui.QFont("Arial", 9, QtGui.QFont.Bold) self.medfont = QtGui.QFont("Arial", 9) self.smallfont = QtGui.QFont("Arial", 8) b = 0 self.satBox = QGroupBox("Views") self.satBox.setFont(self.boldfont) self.satBoxG = QGridLayout() self.satBox.setLayout(self.satBoxG) self.l0.addWidget(self.satBox, b, 0, 1, 9) widget_row = 0 self.view = 0 # 0=image, 1=flowsXY, 2=flowsZ, 3=cellprob self.color = 0 # 0=RGB, 1=gray, 2=R, 3=G, 4=B self.RGBDropDown = QComboBox() self.RGBDropDown.addItems( ["RGB", "red=R", "green=G", "blue=B", "gray", "spectral"]) self.RGBDropDown.setFont(self.medfont) self.RGBDropDown.currentIndexChanged.connect(self.color_choose) self.satBoxG.addWidget(self.RGBDropDown, widget_row, 0, 1, 3) label = QLabel("

[↑ / ↓ or W/S]

") label.setFont(self.smallfont) self.satBoxG.addWidget(label, widget_row, 3, 1, 3) label = QLabel("[R / G / B \n toggles color ]") label.setFont(self.smallfont) self.satBoxG.addWidget(label, widget_row, 6, 1, 3) widget_row += 1 self.ViewDropDown = QComboBox() self.ViewDropDown.addItems(["image", "gradXY", "cellprob", "restored"]) self.ViewDropDown.setFont(self.medfont) self.ViewDropDown.model().item(3).setEnabled(False) self.ViewDropDown.currentIndexChanged.connect(self.update_plot) self.satBoxG.addWidget(self.ViewDropDown, widget_row, 0, 2, 3) label = QLabel("[pageup / pagedown]") label.setFont(self.smallfont) self.satBoxG.addWidget(label, widget_row, 3, 1, 5) widget_row += 2 label = QLabel("") label.setToolTip( "NOTE: manually changing the saturation bars does not affect normalization in segmentation" ) self.satBoxG.addWidget(label, widget_row, 0, 1, 5) self.autobtn = QCheckBox("auto-adjust saturation") self.autobtn.setToolTip("sets scale-bars as normalized for segmentation") self.autobtn.setFont(self.medfont) self.autobtn.setChecked(True) self.satBoxG.addWidget(self.autobtn, widget_row, 1, 1, 8) widget_row += 1 self.sliders = [] colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [100, 100, 100]] colornames = ["red", "Chartreuse", "DodgerBlue"] names = ["red", "green", "blue"] for r in range(3): widget_row += 1 if r == 0: label = QLabel('gray/
red') else: label = QLabel(names[r] + ":") label.setStyleSheet(f"color: {colornames[r]}") label.setFont(self.boldmedfont) self.satBoxG.addWidget(label, widget_row, 0, 1, 2) self.sliders.append(Slider(self, names[r], colors[r])) self.sliders[-1].setMinimum(-.1) self.sliders[-1].setMaximum(255.1) self.sliders[-1].setValue([0, 255]) self.sliders[-1].setToolTip( "NOTE: manually changing the saturation bars does not affect normalization in segmentation" ) self.satBoxG.addWidget(self.sliders[-1], widget_row, 2, 1, 7) b += 1 self.drawBox = QGroupBox("Drawing") self.drawBox.setFont(self.boldfont) self.drawBoxG = QGridLayout() self.drawBox.setLayout(self.drawBoxG) self.l0.addWidget(self.drawBox, b, 0, 1, 9) self.autosave = True widget_row = 0 self.brush_size = 3 self.BrushChoose = QComboBox() self.BrushChoose.addItems(["1", "3", "5", "7", "9"]) self.BrushChoose.currentIndexChanged.connect(self.brush_choose) self.BrushChoose.setFixedWidth(40) self.BrushChoose.setFont(self.medfont) self.drawBoxG.addWidget(self.BrushChoose, widget_row, 3, 1, 2) label = QLabel("brush size:") label.setFont(self.medfont) self.drawBoxG.addWidget(label, widget_row, 0, 1, 3) widget_row += 1 # turn off masks self.layer_off = False self.masksOn = True self.MCheckBox = QCheckBox("MASKS ON [X]") self.MCheckBox.setFont(self.medfont) self.MCheckBox.setChecked(True) self.MCheckBox.toggled.connect(self.toggle_masks) self.drawBoxG.addWidget(self.MCheckBox, widget_row, 0, 1, 5) widget_row += 1 # turn off outlines self.outlinesOn = False # turn off by default self.OCheckBox = QCheckBox("outlines on [Z]") self.OCheckBox.setFont(self.medfont) self.drawBoxG.addWidget(self.OCheckBox, widget_row, 0, 1, 5) self.OCheckBox.setChecked(False) self.OCheckBox.toggled.connect(self.toggle_masks) widget_row += 1 self.SCheckBox = QCheckBox("single stroke") self.SCheckBox.setFont(self.medfont) self.SCheckBox.setChecked(True) self.SCheckBox.toggled.connect(self.autosave_on) self.SCheckBox.setEnabled(True) self.drawBoxG.addWidget(self.SCheckBox, widget_row, 0, 1, 5) # buttons for deleting multiple cells self.deleteBox = QGroupBox("delete multiple ROIs") self.deleteBox.setStyleSheet("color: rgb(200, 200, 200)") self.deleteBox.setFont(self.medfont) self.deleteBoxG = QGridLayout() self.deleteBox.setLayout(self.deleteBoxG) self.drawBoxG.addWidget(self.deleteBox, 0, 5, 4, 4) self.MakeDeletionRegionButton = QPushButton("region-select") self.MakeDeletionRegionButton.clicked.connect(self.remove_region_cells) self.deleteBoxG.addWidget(self.MakeDeletionRegionButton, 0, 0, 1, 4) self.MakeDeletionRegionButton.setFont(self.smallfont) self.MakeDeletionRegionButton.setFixedWidth(70) self.DeleteMultipleROIButton = QPushButton("click-select") self.DeleteMultipleROIButton.clicked.connect(self.delete_multiple_cells) self.deleteBoxG.addWidget(self.DeleteMultipleROIButton, 1, 0, 1, 4) self.DeleteMultipleROIButton.setFont(self.smallfont) self.DeleteMultipleROIButton.setFixedWidth(70) self.DoneDeleteMultipleROIButton = QPushButton("done") self.DoneDeleteMultipleROIButton.clicked.connect( self.done_remove_multiple_cells) self.deleteBoxG.addWidget(self.DoneDeleteMultipleROIButton, 2, 0, 1, 2) self.DoneDeleteMultipleROIButton.setFont(self.smallfont) self.DoneDeleteMultipleROIButton.setFixedWidth(35) self.CancelDeleteMultipleROIButton = QPushButton("cancel") self.CancelDeleteMultipleROIButton.clicked.connect(self.cancel_remove_multiple) self.deleteBoxG.addWidget(self.CancelDeleteMultipleROIButton, 2, 2, 1, 2) self.CancelDeleteMultipleROIButton.setFont(self.smallfont) self.CancelDeleteMultipleROIButton.setFixedWidth(35) b += 1 widget_row = 0 self.segBox = QGroupBox("Segmentation") self.segBoxG = QGridLayout() self.segBox.setLayout(self.segBoxG) self.l0.addWidget(self.segBox, b, 0, 1, 9) self.segBox.setFont(self.boldfont) widget_row += 1 # use GPU self.useGPU = QCheckBox("use GPU") self.useGPU.setToolTip( "if you have specially installed the cuda version of torch, then you can activate this" ) self.useGPU.setFont(self.medfont) self.check_gpu() self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3) # compute segmentation with general models self.net_text = ["run CPSAM"] nett = ["cellpose super-generalist model"] self.StyleButtons = [] jj = 4 for j in range(len(self.net_text)): self.StyleButtons.append( guiparts.ModelButton(self, self.net_text[j], self.net_text[j])) w = 5 self.segBoxG.addWidget(self.StyleButtons[-1], widget_row, jj, 1, w) jj += w self.StyleButtons[-1].setToolTip(nett[j]) widget_row += 1 self.ncells = guiparts.ObservableVariable(0) self.roi_count = QLabel() self.roi_count.setFont(self.boldfont) self.roi_count.setAlignment(QtCore.Qt.AlignLeft) self.ncells.valueChanged.connect( lambda n: self.roi_count.setText(f'{str(n)} ROIs') ) self.segBoxG.addWidget(self.roi_count, widget_row, 0, 1, 4) self.progress = QProgressBar(self) self.segBoxG.addWidget(self.progress, widget_row, 4, 1, 5) widget_row += 1 ############################### Segmentation settings ############################### self.additional_seg_settings_qcollapsible = QCollapsible("additional settings") self.additional_seg_settings_qcollapsible.setFont(self.medfont) self.additional_seg_settings_qcollapsible._toggle_btn.setFont(self.medfont) self.segmentation_settings = guiparts.SegmentationSettings(self.medfont) self.additional_seg_settings_qcollapsible.setContent(self.segmentation_settings) self.segBoxG.addWidget(self.additional_seg_settings_qcollapsible, widget_row, 0, 1, 9) # connect edits to image processing steps: self.segmentation_settings.diameter_box.editingFinished.connect(self.update_scale) self.segmentation_settings.flow_threshold_box.returnPressed.connect(self.compute_cprob) self.segmentation_settings.cellprob_threshold_box.returnPressed.connect(self.compute_cprob) self.segmentation_settings.niter_box.returnPressed.connect(self.compute_cprob) # Needed to do this for the drop down to not be open on startup self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(True) self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(False) b += 1 self.modelBox = QGroupBox("user-trained models") self.modelBoxG = QGridLayout() self.modelBox.setLayout(self.modelBoxG) self.l0.addWidget(self.modelBox, b, 0, 1, 9) self.modelBox.setFont(self.boldfont) # choose models self.ModelChooseC = QComboBox() self.ModelChooseC.setFont(self.medfont) current_index = 0 self.ModelChooseC.addItems(["custom models"]) if len(self.model_strings) > 0: self.ModelChooseC.addItems(self.model_strings) self.ModelChooseC.setFixedWidth(175) self.ModelChooseC.setCurrentIndex(current_index) tipstr = 'add or train your own models in the "Models" file menu and choose model here' self.ModelChooseC.setToolTip(tipstr) self.ModelChooseC.activated.connect(lambda: self.model_choose(custom=True)) self.modelBoxG.addWidget(self.ModelChooseC, widget_row, 0, 1, 8) # compute segmentation w/ custom model self.ModelButtonC = QPushButton(u"run") self.ModelButtonC.setFont(self.medfont) self.ModelButtonC.setFixedWidth(35) self.ModelButtonC.clicked.connect( lambda: self.compute_segmentation(custom=True)) self.modelBoxG.addWidget(self.ModelButtonC, widget_row, 8, 1, 1) self.ModelButtonC.setEnabled(False) b += 1 self.filterBox = QGroupBox("Image filtering") self.filterBox.setFont(self.boldfont) self.filterBox_grid_layout = QGridLayout() self.filterBox.setLayout(self.filterBox_grid_layout) self.l0.addWidget(self.filterBox, b, 0, 1, 9) widget_row = 0 # Filtering self.FilterButtons = [] nett = [ "clear restore/filter", "filter image (settings below)", ] self.filter_text = ["none", "filter", ] self.restore = None self.ratio = 1. jj = 0 w = 3 for j in range(len(self.filter_text)): self.FilterButtons.append( guiparts.FilterButton(self, self.filter_text[j])) self.filterBox_grid_layout.addWidget(self.FilterButtons[-1], widget_row, jj, 1, w) self.FilterButtons[-1].setFixedWidth(75) self.FilterButtons[-1].setToolTip(nett[j]) self.FilterButtons[-1].setFont(self.medfont) widget_row += 1 if j%2==1 else 0 jj = 0 if j%2==1 else jj + w self.save_norm = QCheckBox("save restored/filtered image") self.save_norm.setFont(self.medfont) self.save_norm.setToolTip("save restored/filtered image in _seg.npy file") self.save_norm.setChecked(True) widget_row += 2 self.filtBox = QCollapsible("custom filter settings") self.filtBox._toggle_btn.setFont(self.medfont) self.filtBoxG = QGridLayout() _content = QWidget() _content.setLayout(self.filtBoxG) _content.setMaximumHeight(0) _content.setMinimumHeight(0) self.filtBox.setContent(_content) self.filterBox_grid_layout.addWidget(self.filtBox, widget_row, 0, 1, 9) self.filt_vals = [0., 0., 0., 0.] self.filt_edits = [] labels = [ "sharpen\nradius", "smooth\nradius", "tile_norm\nblocksize", "tile_norm\nsmooth3D" ] tooltips = [ "set size of surround-subtraction filter for sharpening image", "set size of gaussian filter for smoothing image", "set size of tiles to use to normalize image", "set amount of smoothing of normalization values across planes" ] for p in range(4): label = QLabel(f"{labels[p]}:") label.setToolTip(tooltips[p]) label.setFont(self.medfont) self.filtBoxG.addWidget(label, widget_row + p // 2, 4 * (p % 2), 1, 2) self.filt_edits.append(QLineEdit()) self.filt_edits[p].setText(str(self.filt_vals[p])) self.filt_edits[p].setFixedWidth(40) self.filt_edits[p].setFont(self.medfont) self.filtBoxG.addWidget(self.filt_edits[p], widget_row + p // 2, 4 * (p % 2) + 2, 1, 2) self.filt_edits[p].setToolTip(tooltips[p]) widget_row += 3 self.norm3D_cb = QCheckBox("norm3D") self.norm3D_cb.setFont(self.medfont) self.norm3D_cb.setChecked(True) self.norm3D_cb.setToolTip("run same normalization across planes") self.filtBoxG.addWidget(self.norm3D_cb, widget_row, 0, 1, 3) return b def level_change(self, r): r = ["red", "green", "blue"].index(r) if self.loaded: sval = self.sliders[r].value() self.saturation[r][self.currentZ] = sval if not self.autobtn.isChecked(): for r in range(3): for i in range(len(self.saturation[r])): self.saturation[r][i] = self.saturation[r][self.currentZ] self.update_plot() def keyPressEvent(self, event): if self.loaded: if not (event.modifiers() & (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier) or self.in_stroke): updated = False if len(self.current_point_set) > 0: if event.key() == QtCore.Qt.Key_Return: self.add_set() else: nviews = self.ViewDropDown.count() - 1 nviews += int( self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).isEnabled()) if event.key() == QtCore.Qt.Key_X: self.MCheckBox.toggle() if event.key() == QtCore.Qt.Key_Z: self.OCheckBox.toggle() if event.key() == QtCore.Qt.Key_Left or event.key( ) == QtCore.Qt.Key_A: self.get_prev_image() elif event.key() == QtCore.Qt.Key_Right or event.key( ) == QtCore.Qt.Key_D: self.get_next_image() elif event.key() == QtCore.Qt.Key_PageDown: self.view = (self.view + 1) % (nviews) self.ViewDropDown.setCurrentIndex(self.view) elif event.key() == QtCore.Qt.Key_PageUp: self.view = (self.view - 1) % (nviews) self.ViewDropDown.setCurrentIndex(self.view) # can change background or stroke size if cell not finished if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W: self.color = (self.color - 1) % (6) self.RGBDropDown.setCurrentIndex(self.color) elif event.key() == QtCore.Qt.Key_Down or event.key( ) == QtCore.Qt.Key_S: self.color = (self.color + 1) % (6) self.RGBDropDown.setCurrentIndex(self.color) elif event.key() == QtCore.Qt.Key_R: if self.color != 1: self.color = 1 else: self.color = 0 self.RGBDropDown.setCurrentIndex(self.color) elif event.key() == QtCore.Qt.Key_G: if self.color != 2: self.color = 2 else: self.color = 0 self.RGBDropDown.setCurrentIndex(self.color) elif event.key() == QtCore.Qt.Key_B: if self.color != 3: self.color = 3 else: self.color = 0 self.RGBDropDown.setCurrentIndex(self.color) elif (event.key() == QtCore.Qt.Key_Comma or event.key() == QtCore.Qt.Key_Period): count = self.BrushChoose.count() gci = self.BrushChoose.currentIndex() if event.key() == QtCore.Qt.Key_Comma: gci = max(0, gci - 1) else: gci = min(count - 1, gci + 1) self.BrushChoose.setCurrentIndex(gci) self.brush_choose() if not updated: self.update_plot() if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal: self.p0.keyPressEvent(event) def autosave_on(self): if self.SCheckBox.isChecked(): self.autosave = True else: self.autosave = False def check_gpu(self, torch=True): # also decide whether or not to use torch self.useGPU.setChecked(False) self.useGPU.setEnabled(False) if core.use_gpu(use_torch=True): self.useGPU.setEnabled(True) self.useGPU.setChecked(True) else: self.useGPU.setStyleSheet("color: rgb(80,80,80);") def model_choose(self, custom=False): index = self.ModelChooseC.currentIndex( ) if custom else self.ModelChooseB.currentIndex() if index > 0: if custom: model_name = self.ModelChooseC.currentText() else: model_name = self.net_names[index - 1] print(f"GUI_INFO: selected model {model_name}, loading now") self.initialize_model(model_name=model_name, custom=custom) def toggle_scale(self): if self.scale_on: self.p0.removeItem(self.scale) self.scale_on = False else: self.p0.addItem(self.scale) self.scale_on = True def enable_buttons(self): if len(self.model_strings) > 0: self.ModelButtonC.setEnabled(True) for i in range(len(self.StyleButtons)): self.StyleButtons[i].setEnabled(True) for i in range(len(self.FilterButtons)): self.FilterButtons[i].setEnabled(True) if self.load_3D: self.FilterButtons[-2].setEnabled(False) self.newmodel.setEnabled(True) self.loadMasks.setEnabled(True) for n in range(self.nchan): self.sliders[n].setEnabled(True) for n in range(self.nchan, 3): self.sliders[n].setEnabled(True) self.toggle_mask_ops() self.update_plot() self.setWindowTitle(self.filename) def disable_buttons_removeROIs(self): if len(self.model_strings) > 0: self.ModelButtonC.setEnabled(False) for i in range(len(self.StyleButtons)): self.StyleButtons[i].setEnabled(False) self.newmodel.setEnabled(False) self.loadMasks.setEnabled(False) self.saveSet.setEnabled(False) self.savePNG.setEnabled(False) self.saveFlows.setEnabled(False) self.saveOutlines.setEnabled(False) self.saveROIs.setEnabled(False) self.MakeDeletionRegionButton.setEnabled(False) self.DeleteMultipleROIButton.setEnabled(False) self.DoneDeleteMultipleROIButton.setEnabled(True) self.CancelDeleteMultipleROIButton.setEnabled(True) def toggle_mask_ops(self): self.update_layer() self.toggle_saving() self.toggle_removals() def toggle_saving(self): if self.ncells > 0: self.saveSet.setEnabled(True) self.savePNG.setEnabled(True) self.saveFlows.setEnabled(True) self.saveOutlines.setEnabled(True) self.saveROIs.setEnabled(True) else: self.saveSet.setEnabled(False) self.savePNG.setEnabled(False) self.saveFlows.setEnabled(False) self.saveOutlines.setEnabled(False) self.saveROIs.setEnabled(False) def toggle_removals(self): if self.ncells > 0: self.ClearButton.setEnabled(True) self.remcell.setEnabled(True) self.undo.setEnabled(True) self.MakeDeletionRegionButton.setEnabled(True) self.DeleteMultipleROIButton.setEnabled(True) self.DoneDeleteMultipleROIButton.setEnabled(False) self.CancelDeleteMultipleROIButton.setEnabled(False) else: self.ClearButton.setEnabled(False) self.remcell.setEnabled(False) self.undo.setEnabled(False) self.MakeDeletionRegionButton.setEnabled(False) self.DeleteMultipleROIButton.setEnabled(False) self.DoneDeleteMultipleROIButton.setEnabled(False) self.CancelDeleteMultipleROIButton.setEnabled(False) def remove_action(self): if self.selected > 0: self.remove_cell(self.selected) def undo_action(self): if (len(self.strokes) > 0 and self.strokes[-1][0][0] == self.currentZ): self.remove_stroke() else: # remove previous cell if self.ncells > 0: self.remove_cell(self.ncells.get()) def undo_remove_action(self): self.undo_remove_cell() def get_files(self): folder = os.path.dirname(self.filename) mask_filter = "_masks" images = get_image_files(folder, mask_filter) fnames = [os.path.split(images[k])[-1] for k in range(len(images))] f0 = os.path.split(self.filename)[-1] idx = np.nonzero(np.array(fnames) == f0)[0][0] return images, idx def get_prev_image(self): images, idx = self.get_files() idx = (idx - 1) % len(images) io._load_image(self, filename=images[idx]) def get_next_image(self, load_seg=True): images, idx = self.get_files() idx = (idx + 1) % len(images) io._load_image(self, filename=images[idx], load_seg=load_seg) def dragEnterEvent(self, event): if event.mimeData().hasUrls(): event.accept() else: event.ignore() def dropEvent(self, event): files = [u.toLocalFile() for u in event.mimeData().urls()] if os.path.splitext(files[0])[-1] == ".npy": io._load_seg(self, filename=files[0], load_3D=self.load_3D) else: io._load_image(self, filename=files[0], load_seg=True, load_3D=self.load_3D) def toggle_masks(self): if self.MCheckBox.isChecked(): self.masksOn = True else: self.masksOn = False if self.OCheckBox.isChecked(): self.outlinesOn = True else: self.outlinesOn = False if not self.masksOn and not self.outlinesOn: self.p0.removeItem(self.layer) self.layer_off = True else: if self.layer_off: self.p0.addItem(self.layer) self.draw_layer() self.update_layer() if self.loaded: self.update_plot() self.update_layer() def make_viewbox(self): self.p0 = guiparts.ViewBoxNoRightDrag(parent=self, lockAspect=True, name="plot1", border=[100, 100, 100], invertY=True) self.p0.setCursor(QtCore.Qt.CrossCursor) self.brush_size = 3 self.win.addItem(self.p0, 0, 0, rowspan=1, colspan=1) self.p0.setMenuEnabled(False) self.p0.setMouseEnabled(x=True, y=True) self.img = pg.ImageItem(viewbox=self.p0, parent=self) self.img.autoDownsample = False self.layer = guiparts.ImageDraw(viewbox=self.p0, parent=self) self.layer.setLevels([0, 255]) self.scale = pg.ImageItem(viewbox=self.p0, parent=self) self.scale.setLevels([0, 255]) self.p0.scene().contextMenuItem = self.p0 self.Ly, self.Lx = 512, 512 self.p0.addItem(self.img) self.p0.addItem(self.layer) self.p0.addItem(self.scale) def reset(self): # ---- start sets of points ---- # self.selected = 0 self.nchan = 3 self.loaded = False self.channel = [0, 1] self.current_point_set = [] self.in_stroke = False self.strokes = [] self.stroke_appended = True self.resize = False self.ncells.reset() self.zdraw = [] self.removed_cell = [] self.cellcolors = np.array([255, 255, 255])[np.newaxis, :] # -- zero out image stack -- # self.opacity = 128 # how opaque masks should be self.outcolor = [200, 200, 255, 200] self.NZ, self.Ly, self.Lx = 1, 256, 256 self.saturation = self.saturation if hasattr(self, 'saturation') else [] # only adjust the saturation if auto-adjust is on: if self.autobtn.isChecked(): for r in range(3): self.saturation.append([[0, 255] for n in range(self.NZ)]) self.sliders[r].setValue([0, 255]) self.sliders[r].setEnabled(False) self.sliders[r].show() self.currentZ = 0 self.flows = [[], [], [], [], [[]]] # masks matrix # image matrix with a scale disk self.stack = np.zeros((1, self.Ly, self.Lx, 3)) self.Lyr, self.Lxr = self.Ly, self.Lx self.Ly0, self.Lx0 = self.Ly, self.Lx self.radii = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) self.cellpix = np.zeros((1, self.Ly, self.Lx), np.uint16) self.outpix = np.zeros((1, self.Ly, self.Lx), np.uint16) self.ismanual = np.zeros(0, "bool") # -- set menus to default -- # self.color = 0 self.RGBDropDown.setCurrentIndex(self.color) self.view = 0 self.ViewDropDown.setCurrentIndex(0) self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False) self.delete_restore() self.clear_all() self.filename = [] self.loaded = False self.recompute_masks = False self.deleting_multiple = False self.removing_cells_list = [] self.removing_region = False self.remove_roi_obj = None def delete_restore(self): """ delete restored imgs but don't reset settings """ if hasattr(self, "stack_filtered"): del self.stack_filtered if hasattr(self, "cellpix_orig"): self.cellpix = self.cellpix_orig.copy() self.outpix = self.outpix_orig.copy() del self.outpix_orig, self.outpix_resize del self.cellpix_orig, self.cellpix_resize def clear_restore(self): """ delete restored imgs and reset settings """ print("GUI_INFO: clearing restored image") self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False) if self.ViewDropDown.currentIndex() == self.ViewDropDown.count() - 1: self.ViewDropDown.setCurrentIndex(0) self.delete_restore() self.restore = None self.ratio = 1. self.set_normalize_params(self.get_normalize_params()) def brush_choose(self): self.brush_size = self.BrushChoose.currentIndex() * 2 + 1 if self.loaded: self.layer.setDrawKernel(kernel_size=self.brush_size) self.update_layer() def clear_all(self): self.prev_selected = 0 self.selected = 0 if self.restore and "upsample" in self.restore: self.layerz = 0 * np.ones((self.Lyr, self.Lxr, 4), np.uint8) self.cellpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16) self.outpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16) self.cellpix_resize = self.cellpix.copy() self.outpix_resize = self.outpix.copy() self.cellpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16) self.outpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16) else: self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) self.cellpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16) self.outpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16) self.cellcolors = np.array([255, 255, 255])[np.newaxis, :] self.ncells.reset() self.toggle_removals() self.update_scale() self.update_layer() def select_cell(self, idx): self.prev_selected = self.selected self.selected = idx if self.selected > 0: z = self.currentZ self.layerz[self.cellpix[z] == idx] = np.array( [255, 255, 255, self.opacity]) self.update_layer() def select_cell_multi(self, idx): if idx > 0: z = self.currentZ self.layerz[self.cellpix[z] == idx] = np.array( [255, 255, 255, self.opacity]) self.update_layer() def unselect_cell(self): if self.selected > 0: idx = self.selected if idx < (self.ncells.get() + 1): z = self.currentZ self.layerz[self.cellpix[z] == idx] = np.append( self.cellcolors[idx], self.opacity) if self.outlinesOn: self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype( np.uint8) #[0,0,0,self.opacity]) self.update_layer() self.selected = 0 def unselect_cell_multi(self, idx): z = self.currentZ self.layerz[self.cellpix[z] == idx] = np.append(self.cellcolors[idx], self.opacity) if self.outlinesOn: self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype( np.uint8) # [0,0,0,self.opacity]) self.update_layer() def remove_cell(self, idx): if isinstance(idx, (int, np.integer)): idx = [idx] # because the function remove_single_cell updates the state of the cellpix and outpix arrays # by reindexing cells to avoid gaps in the indices, we need to remove the cells in reverse order # so that the indices are correct idx.sort(reverse=True) for i in idx: self.remove_single_cell(i) self.ncells -= len(idx) # _save_sets uses ncells self.update_layer() if self.ncells == 0: self.ClearButton.setEnabled(False) if self.NZ == 1: io._save_sets_with_check(self) def remove_single_cell(self, idx): # remove from manual array self.selected = 0 if self.NZ > 1: zextent = ((self.cellpix == idx).sum(axis=(1, 2)) > 0).nonzero()[0] else: zextent = [0] for z in zextent: cp = self.cellpix[z] == idx op = self.outpix[z] == idx # remove from self.cellpix and self.outpix self.cellpix[z, cp] = 0 self.outpix[z, op] = 0 if z == self.currentZ: # remove from mask layer self.layerz[cp] = np.array([0, 0, 0, 0]) # reduce other pixels by -1 self.cellpix[self.cellpix > idx] -= 1 self.outpix[self.outpix > idx] -= 1 if self.NZ == 1: self.removed_cell = [ self.ismanual[idx - 1], self.cellcolors[idx], np.nonzero(cp), np.nonzero(op) ] self.redo.setEnabled(True) ar, ac = self.removed_cell[2] d = datetime.datetime.now() self.track_changes.append( [d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]]) # remove cell from lists self.ismanual = np.delete(self.ismanual, idx - 1) self.cellcolors = np.delete(self.cellcolors, [idx], axis=0) del self.zdraw[idx - 1] print("GUI_INFO: removed cell %d" % (idx - 1)) def remove_region_cells(self): if self.removing_cells_list: for idx in self.removing_cells_list: self.unselect_cell_multi(idx) self.removing_cells_list.clear() self.disable_buttons_removeROIs() self.removing_region = True self.clear_multi_selected_cells() # make roi region here in center of view, making ROI half the size of the view roi_width = self.p0.viewRect().width() / 2 x_loc = self.p0.viewRect().x() + (roi_width / 2) roi_height = self.p0.viewRect().height() / 2 y_loc = self.p0.viewRect().y() + (roi_height / 2) pos = [x_loc, y_loc] roi = pg.RectROI(pos, [roi_width, roi_height], pen=pg.mkPen("y", width=2), removable=True) roi.sigRemoveRequested.connect(self.remove_roi) roi.sigRegionChangeFinished.connect(self.roi_changed) self.p0.addItem(roi) self.remove_roi_obj = roi self.roi_changed(roi) def delete_multiple_cells(self): self.unselect_cell() self.disable_buttons_removeROIs() self.DoneDeleteMultipleROIButton.setEnabled(True) self.MakeDeletionRegionButton.setEnabled(True) self.CancelDeleteMultipleROIButton.setEnabled(True) self.deleting_multiple = True def done_remove_multiple_cells(self): self.deleting_multiple = False self.removing_region = False self.DoneDeleteMultipleROIButton.setEnabled(False) self.MakeDeletionRegionButton.setEnabled(False) self.CancelDeleteMultipleROIButton.setEnabled(False) if self.removing_cells_list: self.removing_cells_list = list(set(self.removing_cells_list)) display_remove_list = [i - 1 for i in self.removing_cells_list] print(f"GUI_INFO: removing cells: {display_remove_list}") self.remove_cell(self.removing_cells_list) self.removing_cells_list.clear() self.unselect_cell() self.enable_buttons() if self.remove_roi_obj is not None: self.remove_roi(self.remove_roi_obj) def merge_cells(self, idx): self.prev_selected = self.selected self.selected = idx if self.selected != self.prev_selected: for z in range(self.NZ): ar0, ac0 = np.nonzero(self.cellpix[z] == self.prev_selected) ar1, ac1 = np.nonzero(self.cellpix[z] == self.selected) touching = np.logical_and((ar0[:, np.newaxis] - ar1) < 3, (ac0[:, np.newaxis] - ac1) < 3).sum() ar = np.hstack((ar0, ar1)) ac = np.hstack((ac0, ac1)) vr0, vc0 = np.nonzero(self.outpix[z] == self.prev_selected) vr1, vc1 = np.nonzero(self.outpix[z] == self.selected) self.outpix[z, vr0, vc0] = 0 self.outpix[z, vr1, vc1] = 0 if touching > 0: mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8) mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1 contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) pvc, pvr = contours[-2][0].squeeze().T vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2 else: vr = np.hstack((vr0, vr1)) vc = np.hstack((vc0, vc1)) color = self.cellcolors[self.prev_selected] self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected) self.remove_cell(self.selected) print("GUI_INFO: merged two cells") self.update_layer() io._save_sets_with_check(self) self.undo.setEnabled(False) self.redo.setEnabled(False) def undo_remove_cell(self): if len(self.removed_cell) > 0: z = 0 ar, ac = self.removed_cell[2] vr, vc = self.removed_cell[3] color = self.removed_cell[1] self.draw_mask(z, ar, ac, vr, vc, color) self.toggle_mask_ops() self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], axis=0) self.ncells += 1 self.ismanual = np.append(self.ismanual, self.removed_cell[0]) self.zdraw.append([]) print(">>> added back removed cell") self.update_layer() io._save_sets_with_check(self) self.removed_cell = [] self.redo.setEnabled(False) def remove_stroke(self, delete_points=True, stroke_ind=-1): stroke = np.array(self.strokes[stroke_ind]) cZ = self.currentZ inZ = stroke[0, 0] == cZ if inZ: outpix = self.outpix[cZ, stroke[:, 1], stroke[:, 2]] > 0 self.layerz[stroke[~outpix, 1], stroke[~outpix, 2]] = np.array([0, 0, 0, 0]) cellpix = self.cellpix[cZ, stroke[:, 1], stroke[:, 2]] ccol = self.cellcolors.copy() if self.selected > 0: ccol[self.selected] = np.array([255, 255, 255]) col2mask = ccol[cellpix] if self.masksOn: col2mask = np.concatenate( (col2mask, self.opacity * (cellpix[:, np.newaxis] > 0)), axis=-1) else: col2mask = np.concatenate((col2mask, 0 * (cellpix[:, np.newaxis] > 0)), axis=-1) self.layerz[stroke[:, 1], stroke[:, 2], :] = col2mask if self.outlinesOn: self.layerz[stroke[outpix, 1], stroke[outpix, 2]] = np.array(self.outcolor) if delete_points: del self.current_point_set[stroke_ind] self.update_layer() del self.strokes[stroke_ind] def plot_clicked(self, event): if event.button()==QtCore.Qt.LeftButton \ and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\ and not self.removing_region: if event.double(): try: self.p0.setYRange(0, self.Ly + self.pr) except: self.p0.setYRange(0, self.Ly) self.p0.setXRange(0, self.Lx) def cancel_remove_multiple(self): self.clear_multi_selected_cells() self.done_remove_multiple_cells() def clear_multi_selected_cells(self): # unselect all previously selected cells: for idx in self.removing_cells_list: self.unselect_cell_multi(idx) self.removing_cells_list.clear() def add_roi(self, roi): self.p0.addItem(roi) self.remove_roi_obj = roi def remove_roi(self, roi): self.clear_multi_selected_cells() assert roi == self.remove_roi_obj self.remove_roi_obj = None self.p0.removeItem(roi) self.removing_region = False def roi_changed(self, roi): # find the overlapping cells and make them selected pos = roi.pos() size = roi.size() x0 = int(pos.x()) y0 = int(pos.y()) x1 = int(pos.x() + size.x()) y1 = int(pos.y() + size.y()) if x0 < 0: x0 = 0 if y0 < 0: y0 = 0 if x1 > self.Lx: x1 = self.Lx if y1 > self.Ly: y1 = self.Ly # find cells in that region cell_idxs = np.unique(self.cellpix[self.currentZ, y0:y1, x0:x1]) cell_idxs = np.trim_zeros(cell_idxs) # deselect cells not in region by deselecting all and then selecting the ones in the region self.clear_multi_selected_cells() for idx in cell_idxs: self.select_cell_multi(idx) self.removing_cells_list.append(idx) self.update_layer() def mouse_moved(self, pos): items = self.win.scene().items(pos) def color_choose(self): self.color = self.RGBDropDown.currentIndex() self.view = 0 self.ViewDropDown.setCurrentIndex(self.view) self.update_plot() def update_plot(self): self.view = self.ViewDropDown.currentIndex() self.Ly, self.Lx, _ = self.stack[self.currentZ].shape if self.view == 0 or self.view == self.ViewDropDown.count() - 1: image = self.stack[ self.currentZ] if self.view == 0 else self.stack_filtered[self.currentZ] if self.color == 0: self.img.setImage(image, autoLevels=False, lut=None) if self.nchan > 1: levels = np.array([ self.saturation[0][self.currentZ], self.saturation[1][self.currentZ], self.saturation[2][self.currentZ] ]) self.img.setLevels(levels) else: self.img.setLevels(self.saturation[0][self.currentZ]) elif self.color > 0 and self.color < 4: if self.nchan > 1: image = image[:, :, self.color - 1] self.img.setImage(image, autoLevels=False, lut=self.cmap[self.color]) if self.nchan > 1: self.img.setLevels(self.saturation[self.color - 1][self.currentZ]) else: self.img.setLevels(self.saturation[0][self.currentZ]) elif self.color == 4: if self.nchan > 1: image = image.mean(axis=-1) self.img.setImage(image, autoLevels=False, lut=None) self.img.setLevels(self.saturation[0][self.currentZ]) elif self.color == 5: if self.nchan > 1: image = image.mean(axis=-1) self.img.setImage(image, autoLevels=False, lut=self.cmap[0]) self.img.setLevels(self.saturation[0][self.currentZ]) else: image = np.zeros((self.Ly, self.Lx), np.uint8) if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0: image = self.flows[self.view - 1][self.currentZ] if self.view > 1: self.img.setImage(image, autoLevels=False, lut=self.bwr) else: self.img.setImage(image, autoLevels=False, lut=None) self.img.setLevels([0.0, 255.0]) for r in range(3): self.sliders[r].setValue([ self.saturation[r][self.currentZ][0], self.saturation[r][self.currentZ][1] ]) self.win.show() self.show() def update_layer(self): if self.masksOn or self.outlinesOn: self.layer.setImage(self.layerz, autoLevels=False) self.win.show() self.show() def add_set(self): if len(self.current_point_set) > 0: while len(self.strokes) > 0: self.remove_stroke(delete_points=False) if len(self.current_point_set[0]) > 8: color = self.colormap[self.ncells.get(), :3] median = self.add_mask(points=self.current_point_set, color=color) if median is not None: self.removed_cell = [] self.toggle_mask_ops() self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], axis=0) self.ncells += 1 self.ismanual = np.append(self.ismanual, True) if self.NZ == 1: # only save after each cell if single image io._save_sets_with_check(self) else: print("GUI_ERROR: cell too small, not drawn") self.current_stroke = [] self.strokes = [] self.current_point_set = [] self.update_layer() def add_mask(self, points=None, color=(100, 200, 50), dense=True): # points is list of strokes points_all = np.concatenate(points, axis=0) # loop over z values median = [] zdraw = np.unique(points_all[:, 0]) z = 0 ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros( 0, "int"), np.zeros(0, "int") for stroke in points: stroke = np.concatenate(stroke, axis=0).reshape(-1, 4) vr = stroke[:, 1] vc = stroke[:, 2] # get points inside drawn points mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8) pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2), axis=-1)[:, np.newaxis, :] mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) ar, ac = np.nonzero(mask) ar, ac = ar + vr.min() - 2, ac + vc.min() - 2 # get dense outline contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) pvc, pvr = contours[-2][0][:,0].T vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2 # concatenate all points ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac)))) # if these pixels are overlapping with another cell, reassign them ioverlap = self.cellpix[z][ar, ac] > 0 if (~ioverlap).sum() < 10: print("GUI_ERROR: cell < 10 pixels without overlaps, not drawn") return None elif ioverlap.sum() > 0: ar, ac = ar[~ioverlap], ac[~ioverlap] # compute outline of new mask mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8) mask[ar - vr.min() + 2, ac - vc.min() + 2] = 1 contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) pvc, pvr = contours[-2][0][:,0].T vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2 ars = np.concatenate((ars, ar), axis=0) acs = np.concatenate((acs, ac), axis=0) vrs = np.concatenate((vrs, vr), axis=0) vcs = np.concatenate((vcs, vc), axis=0) self.draw_mask(z, ars, acs, vrs, vcs, color) median.append(np.array([np.median(ars), np.median(acs)])) self.zdraw.append(zdraw) d = datetime.datetime.now() self.track_changes.append( [d.strftime("%m/%d/%Y, %H:%M:%S"), "added mask", [ar, ac]]) return median def draw_mask(self, z, ar, ac, vr, vc, color, idx=None): """ draw single mask using outlines and area """ if idx is None: idx = self.ncells + 1 self.cellpix[z, vr, vc] = idx self.cellpix[z, ar, ac] = idx self.outpix[z, vr, vc] = idx if self.restore and "upsample" in self.restore: if self.resize: self.cellpix_resize[z, vr, vc] = idx self.cellpix_resize[z, ar, ac] = idx self.outpix_resize[z, vr, vc] = idx self.cellpix_orig[z, (vr / self.ratio).astype(int), (vc / self.ratio).astype(int)] = idx self.cellpix_orig[z, (ar / self.ratio).astype(int), (ac / self.ratio).astype(int)] = idx self.outpix_orig[z, (vr / self.ratio).astype(int), (vc / self.ratio).astype(int)] = idx else: self.cellpix_orig[z, vr, vc] = idx self.cellpix_orig[z, ar, ac] = idx self.outpix_orig[z, vr, vc] = idx # get upsampled mask vrr = (vr.copy() * self.ratio).astype(int) vcr = (vc.copy() * self.ratio).astype(int) mask = np.zeros((np.ptp(vrr) + 4, np.ptp(vcr) + 4), np.uint8) pts = np.stack((vcr - vcr.min() + 2, vrr - vrr.min() + 2), axis=-1)[:, np.newaxis, :] mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) arr, acr = np.nonzero(mask) arr, acr = arr + vrr.min() - 2, acr + vcr.min() - 2 # get dense outline contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) pvc, pvr = contours[-2][0].squeeze().T vrr, vcr = pvr + vrr.min() - 2, pvc + vcr.min() - 2 # concatenate all points arr, acr = np.hstack((np.vstack((vrr, vcr)), np.vstack((arr, acr)))) self.cellpix_resize[z, vrr, vcr] = idx self.cellpix_resize[z, arr, acr] = idx self.outpix_resize[z, vrr, vcr] = idx if z == self.currentZ: self.layerz[ar, ac, :3] = color if self.masksOn: self.layerz[ar, ac, -1] = self.opacity if self.outlinesOn: self.layerz[vr, vc] = np.array(self.outcolor) def compute_scale(self): # get diameter from gui diameter = self.segmentation_settings.diameter if not diameter: diameter = 30 self.pr = int(diameter) self.radii_padding = int(self.pr * 1.25) self.radii = np.zeros((self.Ly + self.radii_padding, self.Lx, 4), np.uint8) yy, xx = disk([self.Ly + self.radii_padding / 2 - 1, self.pr / 2 + 1], self.pr / 2, self.Ly + self.radii_padding, self.Lx) # rgb(150,50,150) self.radii[yy, xx, 0] = 150 self.radii[yy, xx, 1] = 50 self.radii[yy, xx, 2] = 150 self.radii[yy, xx, 3] = 255 self.p0.setYRange(0, self.Ly + self.radii_padding) self.p0.setXRange(0, self.Lx) def update_scale(self): self.compute_scale() self.scale.setImage(self.radii, autoLevels=False) self.scale.setLevels([0.0, 255.0]) self.win.show() self.show() def draw_layer(self): if self.resize: self.Ly, self.Lx = self.Lyr, self.Lxr else: self.Ly, self.Lx = self.Ly0, self.Lx0 if self.masksOn or self.outlinesOn: if self.restore and "upsample" in self.restore: if self.resize: self.cellpix = self.cellpix_resize.copy() self.outpix = self.outpix_resize.copy() else: self.cellpix = self.cellpix_orig.copy() self.outpix = self.outpix_orig.copy() self.layerz = np.zeros((self.Ly, self.Lx, 4), np.uint8) if self.masksOn: self.layerz[..., :3] = self.cellcolors[self.cellpix[self.currentZ], :] self.layerz[..., 3] = self.opacity * (self.cellpix[self.currentZ] > 0).astype(np.uint8) if self.selected > 0: self.layerz[self.cellpix[self.currentZ] == self.selected] = np.array( [255, 255, 255, self.opacity]) cZ = self.currentZ stroke_z = np.array([s[0][0] for s in self.strokes]) inZ = np.nonzero(stroke_z == cZ)[0] if len(inZ) > 0: for i in inZ: stroke = np.array(self.strokes[i]) self.layerz[stroke[:, 1], stroke[:, 2]] = np.array([255, 0, 255, 100]) else: self.layerz[..., 3] = 0 if self.outlinesOn: self.layerz[self.outpix[self.currentZ] > 0] = np.array( self.outcolor).astype(np.uint8) def set_normalize_params(self, normalize_params): from cellpose.models import normalize_default if self.restore != "filter": keys = list(normalize_params.keys()).copy() for key in keys: if key != "percentile": normalize_params[key] = normalize_default[key] normalize_params = {**normalize_default, **normalize_params} out = self.check_filter_params(normalize_params["sharpen_radius"], normalize_params["smooth_radius"], normalize_params["tile_norm_blocksize"], normalize_params["tile_norm_smooth3D"], normalize_params["norm3D"], normalize_params["invert"]) def check_filter_params(self, sharpen, smooth, tile_norm, smooth3D, norm3D, invert): tile_norm = 0 if tile_norm < 0 else tile_norm sharpen = 0 if sharpen < 0 else sharpen smooth = 0 if smooth < 0 else smooth smooth3D = 0 if smooth3D < 0 else smooth3D norm3D = bool(norm3D) invert = bool(invert) if tile_norm > self.Ly and tile_norm > self.Lx: print( "GUI_ERROR: tile size (tile_norm) bigger than both image dimensions, disabling" ) tile_norm = 0 self.filt_edits[0].setText(str(sharpen)) self.filt_edits[1].setText(str(smooth)) self.filt_edits[2].setText(str(tile_norm)) self.filt_edits[3].setText(str(smooth3D)) self.norm3D_cb.setChecked(norm3D) return sharpen, smooth, tile_norm, smooth3D, norm3D, invert def get_normalize_params(self): percentile = [ self.segmentation_settings.low_percentile, self.segmentation_settings.high_percentile, ] normalize_params = {"percentile": percentile} norm3D = self.norm3D_cb.isChecked() normalize_params["norm3D"] = norm3D sharpen = float(self.filt_edits[0].text()) smooth = float(self.filt_edits[1].text()) tile_norm = float(self.filt_edits[2].text()) smooth3D = float(self.filt_edits[3].text()) invert = False out = self.check_filter_params(sharpen, smooth, tile_norm, smooth3D, norm3D, invert) sharpen, smooth, tile_norm, smooth3D, norm3D, invert = out normalize_params["sharpen_radius"] = sharpen normalize_params["smooth_radius"] = smooth normalize_params["tile_norm_blocksize"] = tile_norm normalize_params["tile_norm_smooth3D"] = smooth3D normalize_params["invert"] = invert from cellpose.models import normalize_default normalize_params = {**normalize_default, **normalize_params} return normalize_params def compute_saturation_if_checked(self): if self.autobtn.isChecked(): self.compute_saturation() def compute_saturation(self, return_img=False): norm = self.get_normalize_params() print(norm) sharpen, smooth = norm["sharpen_radius"], norm["smooth_radius"] percentile = norm["percentile"] tile_norm = norm["tile_norm_blocksize"] invert = norm["invert"] norm3D = norm["norm3D"] smooth3D = norm["tile_norm_smooth3D"] tile_norm = norm["tile_norm_blocksize"] if sharpen > 0 or smooth > 0 or tile_norm > 0: img_norm = self.stack.copy() else: img_norm = self.stack if sharpen > 0 or smooth > 0 or tile_norm > 0: self.restore = "filter" print( "GUI_INFO: computing filtered image because sharpen > 0 or tile_norm > 0" ) print( "GUI_WARNING: will use memory to create filtered image -- make sure to have RAM for this" ) img_norm = self.stack.copy() if sharpen > 0 or smooth > 0: img_norm = smooth_sharpen_img(self.stack, sharpen_radius=sharpen, smooth_radius=smooth) if tile_norm > 0: img_norm = normalize99_tile(img_norm, blocksize=tile_norm, lower=percentile[0], upper=percentile[1], smooth3D=smooth3D, norm3D=norm3D) # convert to 0->255 img_norm_min = img_norm.min() img_norm_max = img_norm.max() for c in range(img_norm.shape[-1]): if np.ptp(img_norm[..., c]) > 1e-3: img_norm[..., c] -= img_norm_min img_norm[..., c] /= (img_norm_max - img_norm_min) img_norm *= 255 self.stack_filtered = img_norm self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(True) self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1) else: img_norm = self.stack if self.restore is None or self.restore == "filter" else self.stack_filtered if self.autobtn.isChecked(): self.saturation = [] for c in range(img_norm.shape[-1]): self.saturation.append([]) if np.ptp(img_norm[..., c]) > 1e-3: if norm3D: x01 = np.percentile(img_norm[..., c], percentile[0]) x99 = np.percentile(img_norm[..., c], percentile[1]) if invert: x01i = 255. - x99 x99i = 255. - x01 x01, x99 = x01i, x99i for n in range(self.NZ): self.saturation[-1].append([x01, x99]) else: for z in range(self.NZ): if self.NZ > 1: x01 = np.percentile(img_norm[z, :, :, c], percentile[0]) x99 = np.percentile(img_norm[z, :, :, c], percentile[1]) else: x01 = np.percentile(img_norm[..., c], percentile[0]) x99 = np.percentile(img_norm[..., c], percentile[1]) if invert: x01i = 255. - x99 x99i = 255. - x01 x01, x99 = x01i, x99i self.saturation[-1].append([x01, x99]) else: for n in range(self.NZ): self.saturation[-1].append([0, 255.]) print(self.saturation[2][self.currentZ]) if img_norm.shape[-1] == 1: self.saturation.append(self.saturation[0]) self.saturation.append(self.saturation[0]) # self.autobtn.setChecked(True) self.update_plot() def get_model_path(self, custom=False): if custom: self.current_model = self.ModelChooseC.currentText() self.current_model_path = os.fspath( models.MODEL_DIR.joinpath(self.current_model)) else: self.current_model = "cpsam" self.current_model_path = models.model_path(self.current_model) def initialize_model(self, model_name=None, custom=False): if model_name is None or custom: self.get_model_path(custom=custom) if not os.path.exists(self.current_model_path): raise ValueError("need to specify model (use dropdown)") if model_name is None or not isinstance(model_name, str): self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), pretrained_model=self.current_model_path) else: self.current_model = model_name self.current_model_path = os.fspath( models.MODEL_DIR.joinpath(self.current_model)) self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), pretrained_model=self.current_model) def add_model(self): io._add_model(self) return def remove_model(self): io._remove_model(self) return def new_model(self): if self.NZ != 1: print("ERROR: cannot train model on 3D data") return # train model image_names = self.get_files()[0] self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set( image_names) TW = guiparts.TrainWindow(self, models.MODEL_NAMES) train = TW.exec_() if train: self.logger.info( f"training with {[os.path.split(f)[1] for f in self.train_files]}") self.train_model(restore=restore, normalize_params=normalize_params) else: print("GUI_INFO: training cancelled") def train_model(self, restore=None, normalize_params=None): from cellpose.models import normalize_default if normalize_params is None: normalize_params = copy.deepcopy(normalize_default) model_type = models.MODEL_NAMES[self.training_params["model_index"]] self.logger.info(f"training new model starting at model {model_type}") self.current_model = model_type self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), model_type=model_type) save_path = os.path.dirname(self.filename) print("GUI_INFO: name of new model: " + self.training_params["model_name"]) self.new_model_path, train_losses = train.train_seg( self.model.net, train_data=self.train_data, train_labels=self.train_labels, normalize=normalize_params, min_train_masks=0, save_path=save_path, nimg_per_epoch=max(2, len(self.train_data)), learning_rate=self.training_params["learning_rate"], weight_decay=self.training_params["weight_decay"], n_epochs=self.training_params["n_epochs"], model_name=self.training_params["model_name"])[:2] # save train losses np.save(str(self.new_model_path) + "_train_losses.npy", train_losses) # run model on next image io._add_model(self, self.new_model_path) diam_labels = self.model.net.diam_labels.item() #.copy() self.new_model_ind = len(self.model_strings) self.autorun = True self.clear_all() self.restore = restore self.set_normalize_params(normalize_params) self.get_next_image(load_seg=False) self.compute_segmentation(custom=True) self.logger.info( f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!" ) def compute_cprob(self): if self.recompute_masks: flow_threshold = self.segmentation_settings.flow_threshold cellprob_threshold = self.segmentation_settings.cellprob_threshold niter = self.segmentation_settings.niter min_size = int(self.min_size.text()) if not isinstance( self.min_size, int) else self.min_size self.logger.info( "computing masks with cell prob=%0.3f, flow error threshold=%0.3f" % (cellprob_threshold, flow_threshold)) try: dP = self.flows[2].squeeze() cellprob = self.flows[3].squeeze() except IndexError: self.logger.error("Flows don't exist, try running model again.") return maski = dynamics.resize_and_compute_masks( dP=dP, cellprob=cellprob, niter=niter, do_3D=self.load_3D, min_size=min_size, # max_size_fraction=min_size_fraction, # Leave as default cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold) self.masksOn = True if not self.OCheckBox.isChecked(): self.MCheckBox.setChecked(True) if maski.ndim < 3: maski = maski[np.newaxis, ...] self.logger.info("%d cells found" % (len(np.unique(maski)[1:]))) io._masks_to_gui(self, maski, outlines=None) self.show() def compute_segmentation(self, custom=False, model_name=None, load_model=True): self.progress.setValue(0) try: tic = time.time() self.clear_all() self.flows = [[], [], []] if load_model: self.initialize_model(model_name=model_name, custom=custom) self.progress.setValue(10) do_3D = self.load_3D stitch_threshold = float(self.stitch_threshold.text()) if not isinstance( self.stitch_threshold, float) else self.stitch_threshold anisotropy = float(self.anisotropy.text()) if not isinstance( self.anisotropy, float) else self.anisotropy flow3D_smooth = float(self.flow3D_smooth.text()) if not isinstance( self.flow3D_smooth, float) else self.flow3D_smooth min_size = int(self.min_size.text()) if not isinstance( self.min_size, int) else self.min_size do_3D = False if stitch_threshold > 0. else do_3D if self.restore == "filter": data = self.stack_filtered.copy().squeeze() else: data = self.stack.copy().squeeze() flow_threshold = self.segmentation_settings.flow_threshold cellprob_threshold = self.segmentation_settings.cellprob_threshold diameter = self.segmentation_settings.diameter niter = self.segmentation_settings.niter normalize_params = self.get_normalize_params() print(normalize_params) try: masks, flows = self.model.eval( data, diameter=diameter, cellprob_threshold=cellprob_threshold, flow_threshold=flow_threshold, do_3D=do_3D, niter=niter, normalize=normalize_params, stitch_threshold=stitch_threshold, anisotropy=anisotropy, flow3D_smooth=flow3D_smooth, min_size=min_size, channel_axis=-1, progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2] except Exception as e: print("NET ERROR: %s" % e) self.progress.setValue(0) return self.progress.setValue(75) # convert flows to uint8 and resize to original image size flows_new = [] flows_new.append(flows[0].copy()) # RGB flow flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) * 255).astype("uint8")) # cellprob flows_new.append(flows[1].copy()) # XY flows flows_new.append(flows[2].copy()) # original cellprob if self.load_3D: if stitch_threshold == 0.: flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8")) else: flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8")) if not self.load_3D: if self.restore and "upsample" in self.restore: self.Ly, self.Lx = self.Lyr, self.Lxr if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx): self.flows = [] for j in range(len(flows_new)): self.flows.append( resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx, interpolation=cv2.INTER_NEAREST)) else: self.flows = flows_new else: self.flows = [] Lz, Ly, Lx = self.NZ, self.Ly, self.Lx Lz0, Ly0, Lx0 = flows_new[0].shape[:3] print("GUI_INFO: resizing flows to original image size") for j in range(len(flows_new)): flow0 = flows_new[j] if Ly0 != Ly: flow0 = resize_image(flow0, Ly=Ly, Lx=Lx, no_channels=flow0.ndim==3, interpolation=cv2.INTER_NEAREST) if Lz0 != Lz: flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1), Ly=Lz, Lx=Lx, no_channels=flow0.ndim==3, interpolation=cv2.INTER_NEAREST), 0, 1) self.flows.append(flow0) # add first axis if self.NZ == 1: masks = masks[np.newaxis, ...] self.flows = [ self.flows[n][np.newaxis, ...] for n in range(len(self.flows)) ] self.logger.info("%d cells found with model in %0.3f sec" % (len(np.unique(masks)[1:]), time.time() - tic)) self.progress.setValue(80) z = 0 io._masks_to_gui(self, masks, outlines=None) self.masksOn = True self.MCheckBox.setChecked(True) self.progress.setValue(100) if self.restore != "filter" and self.restore is not None and self.autobtn.isChecked(): self.compute_saturation() if not do_3D and not stitch_threshold > 0: self.recompute_masks = True else: self.recompute_masks = False except Exception as e: print("ERROR: %s" % e)