Spaces:
Sleeping
Sleeping
| """ | |
| Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu. | |
| """ | |
| import sys, os, pathlib, warnings, datetime, time, copy | |
| from qtpy import QtGui, QtCore | |
| from superqt import QRangeSlider, QCollapsible | |
| from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, \ | |
| QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, \ | |
| QLineEdit, QMessageBox, QGroupBox, QMenu, QAction | |
| import pyqtgraph as pg | |
| import numpy as np | |
| from scipy.stats import mode | |
| import cv2 | |
| from . import guiparts, menus, io | |
| from .. import models, core, dynamics, version, train | |
| from ..utils import download_url_to_file, masks_to_outlines, diameters | |
| from ..io import get_image_files, imsave, imread | |
| from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img | |
| from ..models import normalize_default | |
| from ..plot import disk | |
| try: | |
| import matplotlib.pyplot as plt | |
| MATPLOTLIB = True | |
| except: | |
| MATPLOTLIB = False | |
| Horizontal = QtCore.Qt.Orientation.Horizontal | |
| class Slider(QRangeSlider): | |
| def __init__(self, parent, name, color): | |
| super().__init__(Horizontal) | |
| self.setEnabled(False) | |
| self.valueChanged.connect(lambda: self.levelChanged(parent)) | |
| self.name = name | |
| self.setStyleSheet(""" QSlider{ | |
| background-color: transparent; | |
| } | |
| """) | |
| self.show() | |
| def levelChanged(self, parent): | |
| parent.level_change(self.name) | |
| class QHLine(QFrame): | |
| def __init__(self): | |
| super(QHLine, self).__init__() | |
| self.setFrameShape(QFrame.HLine) | |
| self.setLineWidth(8) | |
| def make_bwr(): | |
| # make a bwr colormap | |
| b = np.append(255 * np.ones(128), np.linspace(0, 255, 128)[::-1])[:, np.newaxis] | |
| r = np.append(np.linspace(0, 255, 128), 255 * np.ones(128))[:, np.newaxis] | |
| g = np.append(np.linspace(0, 255, 128), | |
| np.linspace(0, 255, 128)[::-1])[:, np.newaxis] | |
| color = np.concatenate((r, g, b), axis=-1).astype(np.uint8) | |
| bwr = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) | |
| return bwr | |
| def make_spectral(): | |
| # make spectral colormap | |
| r = np.array([ | |
| 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, | |
| 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 128, 128, 128, 128, 128, | |
| 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 120, 112, 104, 96, 88, | |
| 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 7, 11, 15, 19, 23, | |
| 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103, | |
| 107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167, | |
| 171, 175, 179, 183, 187, 191, 195, 199, 203, 207, 211, 215, 219, 223, 227, 231, | |
| 235, 239, 243, 247, 251, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, | |
| 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, | |
| 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, | |
| 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, | |
| 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, | |
| 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, | |
| 255, 255, 255, 255, 255 | |
| ]) | |
| g = np.array([ | |
| 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 5, 4, 4, 3, 3, | |
| 2, 2, 1, 1, 0, 0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, | |
| 119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, | |
| 247, 255, 247, 239, 231, 223, 215, 207, 199, 191, 183, 175, 167, 159, 151, 143, | |
| 135, 128, 129, 131, 132, 134, 135, 137, 139, 140, 142, 143, 145, 147, 148, 150, | |
| 151, 153, 154, 156, 158, 159, 161, 162, 164, 166, 167, 169, 170, 172, 174, 175, | |
| 177, 178, 180, 181, 183, 185, 186, 188, 189, 191, 193, 194, 196, 197, 199, 201, | |
| 202, 204, 205, 207, 208, 210, 212, 213, 215, 216, 218, 220, 221, 223, 224, 226, | |
| 228, 229, 231, 232, 234, 235, 237, 239, 240, 242, 243, 245, 247, 248, 250, 251, | |
| 253, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, | |
| 195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, | |
| 131, 127, 123, 119, 115, 111, 107, 103, 99, 95, 91, 87, 83, 79, 75, 71, 67, 63, | |
| 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 0, 8, 16, 24, 32, 41, | |
| 49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, | |
| 189, 197, 205, 213, 222, 230, 238, 246, 254 | |
| ]) | |
| b = np.array([ | |
| 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143, | |
| 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 255, 255, | |
| 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, | |
| 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 251, 247, | |
| 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183, | |
| 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 128, 126, 124, | |
| 122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98, 96, 94, 92, 90, | |
| 88, 86, 84, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, | |
| 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, | |
| 8, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, | |
| 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74, | |
| 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205, | |
| 213, 222, 230, 238, 246, 254 | |
| ]) | |
| color = (np.vstack((r, g, b)).T).astype(np.uint8) | |
| spectral = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) | |
| return spectral | |
| def make_cmap(cm=0): | |
| # make a single channel colormap | |
| r = np.arange(0, 256) | |
| color = np.zeros((256, 3)) | |
| color[:, cm] = r | |
| color = color.astype(np.uint8) | |
| cmap = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) | |
| return cmap | |
| def run(image=None): | |
| from ..io import logger_setup | |
| logger, log_file = logger_setup() | |
| # Always start by initializing Qt (only once per application) | |
| warnings.filterwarnings("ignore") | |
| app = QApplication(sys.argv) | |
| icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png") | |
| guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png") | |
| if not icon_path.is_file(): | |
| cp_dir = pathlib.Path.home().joinpath(".cellpose") | |
| cp_dir.mkdir(exist_ok=True) | |
| print("downloading logo") | |
| download_url_to_file( | |
| "https://www.cellpose.org/static/images/cellpose_transparent.png", | |
| icon_path, progress=True) | |
| if not guip_path.is_file(): | |
| print("downloading help window image") | |
| download_url_to_file("https://www.cellpose.org/static/images/cellposeSAM_gui.png", | |
| guip_path, progress=True) | |
| icon_path = str(icon_path.resolve()) | |
| app_icon = QtGui.QIcon() | |
| app_icon.addFile(icon_path, QtCore.QSize(16, 16)) | |
| app_icon.addFile(icon_path, QtCore.QSize(24, 24)) | |
| app_icon.addFile(icon_path, QtCore.QSize(32, 32)) | |
| app_icon.addFile(icon_path, QtCore.QSize(48, 48)) | |
| app_icon.addFile(icon_path, QtCore.QSize(64, 64)) | |
| app_icon.addFile(icon_path, QtCore.QSize(256, 256)) | |
| app.setWindowIcon(app_icon) | |
| app.setStyle("Fusion") | |
| app.setPalette(guiparts.DarkPalette()) | |
| MainW(image=image, logger=logger) | |
| ret = app.exec_() | |
| sys.exit(ret) | |
| class MainW(QMainWindow): | |
| def __init__(self, image=None, logger=None): | |
| super(MainW, self).__init__() | |
| self.logger = logger | |
| pg.setConfigOptions(imageAxisOrder="row-major") | |
| self.setGeometry(50, 50, 1200, 1000) | |
| self.setWindowTitle(f"cellpose v{version}") | |
| self.cp_path = os.path.dirname(os.path.realpath(__file__)) | |
| app_icon = QtGui.QIcon() | |
| icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png") | |
| icon_path = str(icon_path.resolve()) | |
| app_icon.addFile(icon_path, QtCore.QSize(16, 16)) | |
| app_icon.addFile(icon_path, QtCore.QSize(24, 24)) | |
| app_icon.addFile(icon_path, QtCore.QSize(32, 32)) | |
| app_icon.addFile(icon_path, QtCore.QSize(48, 48)) | |
| app_icon.addFile(icon_path, QtCore.QSize(64, 64)) | |
| app_icon.addFile(icon_path, QtCore.QSize(256, 256)) | |
| self.setWindowIcon(app_icon) | |
| # rgb(150,255,150) | |
| self.setStyleSheet(guiparts.stylesheet()) | |
| menus.mainmenu(self) | |
| menus.editmenu(self) | |
| menus.modelmenu(self) | |
| menus.helpmenu(self) | |
| self.stylePressed = """QPushButton {Text-align: center; | |
| background-color: rgb(150,50,150); | |
| border-color: white; | |
| color:white;} | |
| QToolTip { | |
| background-color: black; | |
| color: white; | |
| border: black solid 1px | |
| }""" | |
| self.styleUnpressed = """QPushButton {Text-align: center; | |
| background-color: rgb(50,50,50); | |
| border-color: white; | |
| color:white;} | |
| QToolTip { | |
| background-color: black; | |
| color: white; | |
| border: black solid 1px | |
| }""" | |
| self.loaded = False | |
| # ---- MAIN WIDGET LAYOUT ---- # | |
| self.cwidget = QWidget(self) | |
| self.lmain = QGridLayout() | |
| self.cwidget.setLayout(self.lmain) | |
| self.setCentralWidget(self.cwidget) | |
| self.lmain.setVerticalSpacing(0) | |
| self.lmain.setContentsMargins(0, 0, 0, 10) | |
| self.imask = 0 | |
| self.scrollarea = QScrollArea() | |
| self.scrollarea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn) | |
| self.scrollarea.setStyleSheet("""QScrollArea { border: none }""") | |
| self.scrollarea.setWidgetResizable(True) | |
| self.swidget = QWidget(self) | |
| self.scrollarea.setWidget(self.swidget) | |
| self.l0 = QGridLayout() | |
| self.swidget.setLayout(self.l0) | |
| b = self.make_buttons() | |
| self.lmain.addWidget(self.scrollarea, 0, 0, 39, 9) | |
| # ---- drawing area ---- # | |
| self.win = pg.GraphicsLayoutWidget() | |
| self.lmain.addWidget(self.win, 0, 9, 40, 30) | |
| self.win.scene().sigMouseClicked.connect(self.plot_clicked) | |
| self.win.scene().sigMouseMoved.connect(self.mouse_moved) | |
| self.make_viewbox() | |
| self.lmain.setColumnStretch(10, 1) | |
| bwrmap = make_bwr() | |
| self.bwr = bwrmap.getLookupTable(start=0.0, stop=255.0, alpha=False) | |
| self.cmap = [] | |
| # spectral colormap | |
| self.cmap.append(make_spectral().getLookupTable(start=0.0, stop=255.0, | |
| alpha=False)) | |
| # single channel colormaps | |
| for i in range(3): | |
| self.cmap.append( | |
| make_cmap(i).getLookupTable(start=0.0, stop=255.0, alpha=False)) | |
| if MATPLOTLIB: | |
| self.colormap = (plt.get_cmap("gist_ncar")(np.linspace(0.0, .9, 1000000)) * | |
| 255).astype(np.uint8) | |
| np.random.seed(42) # make colors stable | |
| self.colormap = self.colormap[np.random.permutation(1000000)] | |
| else: | |
| np.random.seed(42) # make colors stable | |
| self.colormap = ((np.random.rand(1000000, 3) * 0.8 + 0.1) * 255).astype( | |
| np.uint8) | |
| self.NZ = 1 | |
| self.restore = None | |
| self.ratio = 1. | |
| self.reset() | |
| # This needs to go after .reset() is called to get state fully set up: | |
| self.autobtn.checkStateChanged.connect(self.compute_saturation_if_checked) | |
| self.load_3D = False | |
| # if called with image, load it | |
| if image is not None: | |
| self.filename = image | |
| io._load_image(self, self.filename) | |
| # training settings | |
| d = datetime.datetime.now() | |
| self.training_params = { | |
| "model_index": 0, | |
| "learning_rate": 1e-5, | |
| "weight_decay": 0.1, | |
| "n_epochs": 100, | |
| "model_name": "cpsam" + d.strftime("_%Y%m%d_%H%M%S"), | |
| } | |
| self.stitch_threshold = 0. | |
| self.flow3D_smooth = 0. | |
| self.anisotropy = 1. | |
| self.min_size = 15 | |
| self.setAcceptDrops(True) | |
| self.win.show() | |
| self.show() | |
| def help_window(self): | |
| HW = guiparts.HelpWindow(self) | |
| HW.show() | |
| def train_help_window(self): | |
| THW = guiparts.TrainHelpWindow(self) | |
| THW.show() | |
| def gui_window(self): | |
| EG = guiparts.ExampleGUI(self) | |
| EG.show() | |
| def make_buttons(self): | |
| self.boldfont = QtGui.QFont("Arial", 11, QtGui.QFont.Bold) | |
| self.boldmedfont = QtGui.QFont("Arial", 9, QtGui.QFont.Bold) | |
| self.medfont = QtGui.QFont("Arial", 9) | |
| self.smallfont = QtGui.QFont("Arial", 8) | |
| b = 0 | |
| self.satBox = QGroupBox("Views") | |
| self.satBox.setFont(self.boldfont) | |
| self.satBoxG = QGridLayout() | |
| self.satBox.setLayout(self.satBoxG) | |
| self.l0.addWidget(self.satBox, b, 0, 1, 9) | |
| widget_row = 0 | |
| self.view = 0 # 0=image, 1=flowsXY, 2=flowsZ, 3=cellprob | |
| self.color = 0 # 0=RGB, 1=gray, 2=R, 3=G, 4=B | |
| self.RGBDropDown = QComboBox() | |
| self.RGBDropDown.addItems( | |
| ["RGB", "red=R", "green=G", "blue=B", "gray", "spectral"]) | |
| self.RGBDropDown.setFont(self.medfont) | |
| self.RGBDropDown.currentIndexChanged.connect(self.color_choose) | |
| self.satBoxG.addWidget(self.RGBDropDown, widget_row, 0, 1, 3) | |
| label = QLabel("<p>[↑ / ↓ or W/S]</p>") | |
| label.setFont(self.smallfont) | |
| self.satBoxG.addWidget(label, widget_row, 3, 1, 3) | |
| label = QLabel("[R / G / B \n toggles color ]") | |
| label.setFont(self.smallfont) | |
| self.satBoxG.addWidget(label, widget_row, 6, 1, 3) | |
| widget_row += 1 | |
| self.ViewDropDown = QComboBox() | |
| self.ViewDropDown.addItems(["image", "gradXY", "cellprob", "restored"]) | |
| self.ViewDropDown.setFont(self.medfont) | |
| self.ViewDropDown.model().item(3).setEnabled(False) | |
| self.ViewDropDown.currentIndexChanged.connect(self.update_plot) | |
| self.satBoxG.addWidget(self.ViewDropDown, widget_row, 0, 2, 3) | |
| label = QLabel("[pageup / pagedown]") | |
| label.setFont(self.smallfont) | |
| self.satBoxG.addWidget(label, widget_row, 3, 1, 5) | |
| widget_row += 2 | |
| label = QLabel("") | |
| label.setToolTip( | |
| "NOTE: manually changing the saturation bars does not affect normalization in segmentation" | |
| ) | |
| self.satBoxG.addWidget(label, widget_row, 0, 1, 5) | |
| self.autobtn = QCheckBox("auto-adjust saturation") | |
| self.autobtn.setToolTip("sets scale-bars as normalized for segmentation") | |
| self.autobtn.setFont(self.medfont) | |
| self.autobtn.setChecked(True) | |
| self.satBoxG.addWidget(self.autobtn, widget_row, 1, 1, 8) | |
| widget_row += 1 | |
| self.sliders = [] | |
| colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [100, 100, 100]] | |
| colornames = ["red", "Chartreuse", "DodgerBlue"] | |
| names = ["red", "green", "blue"] | |
| for r in range(3): | |
| widget_row += 1 | |
| if r == 0: | |
| label = QLabel('<font color="gray">gray/</font><br>red') | |
| else: | |
| label = QLabel(names[r] + ":") | |
| label.setStyleSheet(f"color: {colornames[r]}") | |
| label.setFont(self.boldmedfont) | |
| self.satBoxG.addWidget(label, widget_row, 0, 1, 2) | |
| self.sliders.append(Slider(self, names[r], colors[r])) | |
| self.sliders[-1].setMinimum(-.1) | |
| self.sliders[-1].setMaximum(255.1) | |
| self.sliders[-1].setValue([0, 255]) | |
| self.sliders[-1].setToolTip( | |
| "NOTE: manually changing the saturation bars does not affect normalization in segmentation" | |
| ) | |
| self.satBoxG.addWidget(self.sliders[-1], widget_row, 2, 1, 7) | |
| b += 1 | |
| self.drawBox = QGroupBox("Drawing") | |
| self.drawBox.setFont(self.boldfont) | |
| self.drawBoxG = QGridLayout() | |
| self.drawBox.setLayout(self.drawBoxG) | |
| self.l0.addWidget(self.drawBox, b, 0, 1, 9) | |
| self.autosave = True | |
| widget_row = 0 | |
| self.brush_size = 3 | |
| self.BrushChoose = QComboBox() | |
| self.BrushChoose.addItems(["1", "3", "5", "7", "9"]) | |
| self.BrushChoose.currentIndexChanged.connect(self.brush_choose) | |
| self.BrushChoose.setFixedWidth(40) | |
| self.BrushChoose.setFont(self.medfont) | |
| self.drawBoxG.addWidget(self.BrushChoose, widget_row, 3, 1, 2) | |
| label = QLabel("brush size:") | |
| label.setFont(self.medfont) | |
| self.drawBoxG.addWidget(label, widget_row, 0, 1, 3) | |
| widget_row += 1 | |
| # turn off masks | |
| self.layer_off = False | |
| self.masksOn = True | |
| self.MCheckBox = QCheckBox("MASKS ON [X]") | |
| self.MCheckBox.setFont(self.medfont) | |
| self.MCheckBox.setChecked(True) | |
| self.MCheckBox.toggled.connect(self.toggle_masks) | |
| self.drawBoxG.addWidget(self.MCheckBox, widget_row, 0, 1, 5) | |
| widget_row += 1 | |
| # turn off outlines | |
| self.outlinesOn = False # turn off by default | |
| self.OCheckBox = QCheckBox("outlines on [Z]") | |
| self.OCheckBox.setFont(self.medfont) | |
| self.drawBoxG.addWidget(self.OCheckBox, widget_row, 0, 1, 5) | |
| self.OCheckBox.setChecked(False) | |
| self.OCheckBox.toggled.connect(self.toggle_masks) | |
| widget_row += 1 | |
| self.SCheckBox = QCheckBox("single stroke") | |
| self.SCheckBox.setFont(self.medfont) | |
| self.SCheckBox.setChecked(True) | |
| self.SCheckBox.toggled.connect(self.autosave_on) | |
| self.SCheckBox.setEnabled(True) | |
| self.drawBoxG.addWidget(self.SCheckBox, widget_row, 0, 1, 5) | |
| # buttons for deleting multiple cells | |
| self.deleteBox = QGroupBox("delete multiple ROIs") | |
| self.deleteBox.setStyleSheet("color: rgb(200, 200, 200)") | |
| self.deleteBox.setFont(self.medfont) | |
| self.deleteBoxG = QGridLayout() | |
| self.deleteBox.setLayout(self.deleteBoxG) | |
| self.drawBoxG.addWidget(self.deleteBox, 0, 5, 4, 4) | |
| self.MakeDeletionRegionButton = QPushButton("region-select") | |
| self.MakeDeletionRegionButton.clicked.connect(self.remove_region_cells) | |
| self.deleteBoxG.addWidget(self.MakeDeletionRegionButton, 0, 0, 1, 4) | |
| self.MakeDeletionRegionButton.setFont(self.smallfont) | |
| self.MakeDeletionRegionButton.setFixedWidth(70) | |
| self.DeleteMultipleROIButton = QPushButton("click-select") | |
| self.DeleteMultipleROIButton.clicked.connect(self.delete_multiple_cells) | |
| self.deleteBoxG.addWidget(self.DeleteMultipleROIButton, 1, 0, 1, 4) | |
| self.DeleteMultipleROIButton.setFont(self.smallfont) | |
| self.DeleteMultipleROIButton.setFixedWidth(70) | |
| self.DoneDeleteMultipleROIButton = QPushButton("done") | |
| self.DoneDeleteMultipleROIButton.clicked.connect( | |
| self.done_remove_multiple_cells) | |
| self.deleteBoxG.addWidget(self.DoneDeleteMultipleROIButton, 2, 0, 1, 2) | |
| self.DoneDeleteMultipleROIButton.setFont(self.smallfont) | |
| self.DoneDeleteMultipleROIButton.setFixedWidth(35) | |
| self.CancelDeleteMultipleROIButton = QPushButton("cancel") | |
| self.CancelDeleteMultipleROIButton.clicked.connect(self.cancel_remove_multiple) | |
| self.deleteBoxG.addWidget(self.CancelDeleteMultipleROIButton, 2, 2, 1, 2) | |
| self.CancelDeleteMultipleROIButton.setFont(self.smallfont) | |
| self.CancelDeleteMultipleROIButton.setFixedWidth(35) | |
| b += 1 | |
| widget_row = 0 | |
| self.segBox = QGroupBox("Segmentation") | |
| self.segBoxG = QGridLayout() | |
| self.segBox.setLayout(self.segBoxG) | |
| self.l0.addWidget(self.segBox, b, 0, 1, 9) | |
| self.segBox.setFont(self.boldfont) | |
| widget_row += 1 | |
| # use GPU | |
| self.useGPU = QCheckBox("use GPU") | |
| self.useGPU.setToolTip( | |
| "if you have specially installed the <i>cuda</i> version of torch, then you can activate this" | |
| ) | |
| self.useGPU.setFont(self.medfont) | |
| self.check_gpu() | |
| self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3) | |
| # compute segmentation with general models | |
| self.net_text = ["run CPSAM"] | |
| nett = ["cellpose super-generalist model"] | |
| self.StyleButtons = [] | |
| jj = 4 | |
| for j in range(len(self.net_text)): | |
| self.StyleButtons.append( | |
| guiparts.ModelButton(self, self.net_text[j], self.net_text[j])) | |
| w = 5 | |
| self.segBoxG.addWidget(self.StyleButtons[-1], widget_row, jj, 1, w) | |
| jj += w | |
| self.StyleButtons[-1].setToolTip(nett[j]) | |
| widget_row += 1 | |
| self.ncells = guiparts.ObservableVariable(0) | |
| self.roi_count = QLabel() | |
| self.roi_count.setFont(self.boldfont) | |
| self.roi_count.setAlignment(QtCore.Qt.AlignLeft) | |
| self.ncells.valueChanged.connect( | |
| lambda n: self.roi_count.setText(f'{str(n)} ROIs') | |
| ) | |
| self.segBoxG.addWidget(self.roi_count, widget_row, 0, 1, 4) | |
| self.progress = QProgressBar(self) | |
| self.segBoxG.addWidget(self.progress, widget_row, 4, 1, 5) | |
| widget_row += 1 | |
| ############################### Segmentation settings ############################### | |
| self.additional_seg_settings_qcollapsible = QCollapsible("additional settings") | |
| self.additional_seg_settings_qcollapsible.setFont(self.medfont) | |
| self.additional_seg_settings_qcollapsible._toggle_btn.setFont(self.medfont) | |
| self.segmentation_settings = guiparts.SegmentationSettings(self.medfont) | |
| self.additional_seg_settings_qcollapsible.setContent(self.segmentation_settings) | |
| self.segBoxG.addWidget(self.additional_seg_settings_qcollapsible, widget_row, 0, 1, 9) | |
| # connect edits to image processing steps: | |
| self.segmentation_settings.diameter_box.editingFinished.connect(self.update_scale) | |
| self.segmentation_settings.flow_threshold_box.returnPressed.connect(self.compute_cprob) | |
| self.segmentation_settings.cellprob_threshold_box.returnPressed.connect(self.compute_cprob) | |
| self.segmentation_settings.niter_box.returnPressed.connect(self.compute_cprob) | |
| # Needed to do this for the drop down to not be open on startup | |
| self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(True) | |
| self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(False) | |
| b += 1 | |
| self.modelBox = QGroupBox("user-trained models") | |
| self.modelBoxG = QGridLayout() | |
| self.modelBox.setLayout(self.modelBoxG) | |
| self.l0.addWidget(self.modelBox, b, 0, 1, 9) | |
| self.modelBox.setFont(self.boldfont) | |
| # choose models | |
| self.ModelChooseC = QComboBox() | |
| self.ModelChooseC.setFont(self.medfont) | |
| current_index = 0 | |
| self.ModelChooseC.addItems(["custom models"]) | |
| if len(self.model_strings) > 0: | |
| self.ModelChooseC.addItems(self.model_strings) | |
| self.ModelChooseC.setFixedWidth(175) | |
| self.ModelChooseC.setCurrentIndex(current_index) | |
| tipstr = 'add or train your own models in the "Models" file menu and choose model here' | |
| self.ModelChooseC.setToolTip(tipstr) | |
| self.ModelChooseC.activated.connect(lambda: self.model_choose(custom=True)) | |
| self.modelBoxG.addWidget(self.ModelChooseC, widget_row, 0, 1, 8) | |
| # compute segmentation w/ custom model | |
| self.ModelButtonC = QPushButton(u"run") | |
| self.ModelButtonC.setFont(self.medfont) | |
| self.ModelButtonC.setFixedWidth(35) | |
| self.ModelButtonC.clicked.connect( | |
| lambda: self.compute_segmentation(custom=True)) | |
| self.modelBoxG.addWidget(self.ModelButtonC, widget_row, 8, 1, 1) | |
| self.ModelButtonC.setEnabled(False) | |
| b += 1 | |
| self.filterBox = QGroupBox("Image filtering") | |
| self.filterBox.setFont(self.boldfont) | |
| self.filterBox_grid_layout = QGridLayout() | |
| self.filterBox.setLayout(self.filterBox_grid_layout) | |
| self.l0.addWidget(self.filterBox, b, 0, 1, 9) | |
| widget_row = 0 | |
| # Filtering | |
| self.FilterButtons = [] | |
| nett = [ | |
| "clear restore/filter", | |
| "filter image (settings below)", | |
| ] | |
| self.filter_text = ["none", | |
| "filter", | |
| ] | |
| self.restore = None | |
| self.ratio = 1. | |
| jj = 0 | |
| w = 3 | |
| for j in range(len(self.filter_text)): | |
| self.FilterButtons.append( | |
| guiparts.FilterButton(self, self.filter_text[j])) | |
| self.filterBox_grid_layout.addWidget(self.FilterButtons[-1], widget_row, jj, 1, w) | |
| self.FilterButtons[-1].setFixedWidth(75) | |
| self.FilterButtons[-1].setToolTip(nett[j]) | |
| self.FilterButtons[-1].setFont(self.medfont) | |
| widget_row += 1 if j%2==1 else 0 | |
| jj = 0 if j%2==1 else jj + w | |
| self.save_norm = QCheckBox("save restored/filtered image") | |
| self.save_norm.setFont(self.medfont) | |
| self.save_norm.setToolTip("save restored/filtered image in _seg.npy file") | |
| self.save_norm.setChecked(True) | |
| widget_row += 2 | |
| self.filtBox = QCollapsible("custom filter settings") | |
| self.filtBox._toggle_btn.setFont(self.medfont) | |
| self.filtBoxG = QGridLayout() | |
| _content = QWidget() | |
| _content.setLayout(self.filtBoxG) | |
| _content.setMaximumHeight(0) | |
| _content.setMinimumHeight(0) | |
| self.filtBox.setContent(_content) | |
| self.filterBox_grid_layout.addWidget(self.filtBox, widget_row, 0, 1, 9) | |
| self.filt_vals = [0., 0., 0., 0.] | |
| self.filt_edits = [] | |
| labels = [ | |
| "sharpen\nradius", "smooth\nradius", "tile_norm\nblocksize", | |
| "tile_norm\nsmooth3D" | |
| ] | |
| tooltips = [ | |
| "set size of surround-subtraction filter for sharpening image", | |
| "set size of gaussian filter for smoothing image", | |
| "set size of tiles to use to normalize image", | |
| "set amount of smoothing of normalization values across planes" | |
| ] | |
| for p in range(4): | |
| label = QLabel(f"{labels[p]}:") | |
| label.setToolTip(tooltips[p]) | |
| label.setFont(self.medfont) | |
| self.filtBoxG.addWidget(label, widget_row + p // 2, 4 * (p % 2), 1, 2) | |
| self.filt_edits.append(QLineEdit()) | |
| self.filt_edits[p].setText(str(self.filt_vals[p])) | |
| self.filt_edits[p].setFixedWidth(40) | |
| self.filt_edits[p].setFont(self.medfont) | |
| self.filtBoxG.addWidget(self.filt_edits[p], widget_row + p // 2, 4 * (p % 2) + 2, 1, | |
| 2) | |
| self.filt_edits[p].setToolTip(tooltips[p]) | |
| widget_row += 3 | |
| self.norm3D_cb = QCheckBox("norm3D") | |
| self.norm3D_cb.setFont(self.medfont) | |
| self.norm3D_cb.setChecked(True) | |
| self.norm3D_cb.setToolTip("run same normalization across planes") | |
| self.filtBoxG.addWidget(self.norm3D_cb, widget_row, 0, 1, 3) | |
| return b | |
| def level_change(self, r): | |
| r = ["red", "green", "blue"].index(r) | |
| if self.loaded: | |
| sval = self.sliders[r].value() | |
| self.saturation[r][self.currentZ] = sval | |
| if not self.autobtn.isChecked(): | |
| for r in range(3): | |
| for i in range(len(self.saturation[r])): | |
| self.saturation[r][i] = self.saturation[r][self.currentZ] | |
| self.update_plot() | |
| def keyPressEvent(self, event): | |
| if self.loaded: | |
| if not (event.modifiers() & | |
| (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier | | |
| QtCore.Qt.AltModifier) or self.in_stroke): | |
| updated = False | |
| if len(self.current_point_set) > 0: | |
| if event.key() == QtCore.Qt.Key_Return: | |
| self.add_set() | |
| else: | |
| nviews = self.ViewDropDown.count() - 1 | |
| nviews += int( | |
| self.ViewDropDown.model().item(self.ViewDropDown.count() - | |
| 1).isEnabled()) | |
| if event.key() == QtCore.Qt.Key_X: | |
| self.MCheckBox.toggle() | |
| if event.key() == QtCore.Qt.Key_Z: | |
| self.OCheckBox.toggle() | |
| if event.key() == QtCore.Qt.Key_Left or event.key( | |
| ) == QtCore.Qt.Key_A: | |
| self.get_prev_image() | |
| elif event.key() == QtCore.Qt.Key_Right or event.key( | |
| ) == QtCore.Qt.Key_D: | |
| self.get_next_image() | |
| elif event.key() == QtCore.Qt.Key_PageDown: | |
| self.view = (self.view + 1) % (nviews) | |
| self.ViewDropDown.setCurrentIndex(self.view) | |
| elif event.key() == QtCore.Qt.Key_PageUp: | |
| self.view = (self.view - 1) % (nviews) | |
| self.ViewDropDown.setCurrentIndex(self.view) | |
| # can change background or stroke size if cell not finished | |
| if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W: | |
| self.color = (self.color - 1) % (6) | |
| self.RGBDropDown.setCurrentIndex(self.color) | |
| elif event.key() == QtCore.Qt.Key_Down or event.key( | |
| ) == QtCore.Qt.Key_S: | |
| self.color = (self.color + 1) % (6) | |
| self.RGBDropDown.setCurrentIndex(self.color) | |
| elif event.key() == QtCore.Qt.Key_R: | |
| if self.color != 1: | |
| self.color = 1 | |
| else: | |
| self.color = 0 | |
| self.RGBDropDown.setCurrentIndex(self.color) | |
| elif event.key() == QtCore.Qt.Key_G: | |
| if self.color != 2: | |
| self.color = 2 | |
| else: | |
| self.color = 0 | |
| self.RGBDropDown.setCurrentIndex(self.color) | |
| elif event.key() == QtCore.Qt.Key_B: | |
| if self.color != 3: | |
| self.color = 3 | |
| else: | |
| self.color = 0 | |
| self.RGBDropDown.setCurrentIndex(self.color) | |
| elif (event.key() == QtCore.Qt.Key_Comma or | |
| event.key() == QtCore.Qt.Key_Period): | |
| count = self.BrushChoose.count() | |
| gci = self.BrushChoose.currentIndex() | |
| if event.key() == QtCore.Qt.Key_Comma: | |
| gci = max(0, gci - 1) | |
| else: | |
| gci = min(count - 1, gci + 1) | |
| self.BrushChoose.setCurrentIndex(gci) | |
| self.brush_choose() | |
| if not updated: | |
| self.update_plot() | |
| if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal: | |
| self.p0.keyPressEvent(event) | |
| def autosave_on(self): | |
| if self.SCheckBox.isChecked(): | |
| self.autosave = True | |
| else: | |
| self.autosave = False | |
| def check_gpu(self, torch=True): | |
| # also decide whether or not to use torch | |
| self.useGPU.setChecked(False) | |
| self.useGPU.setEnabled(False) | |
| if core.use_gpu(use_torch=True): | |
| self.useGPU.setEnabled(True) | |
| self.useGPU.setChecked(True) | |
| else: | |
| self.useGPU.setStyleSheet("color: rgb(80,80,80);") | |
| def model_choose(self, custom=False): | |
| index = self.ModelChooseC.currentIndex( | |
| ) if custom else self.ModelChooseB.currentIndex() | |
| if index > 0: | |
| if custom: | |
| model_name = self.ModelChooseC.currentText() | |
| else: | |
| model_name = self.net_names[index - 1] | |
| print(f"GUI_INFO: selected model {model_name}, loading now") | |
| self.initialize_model(model_name=model_name, custom=custom) | |
| def toggle_scale(self): | |
| if self.scale_on: | |
| self.p0.removeItem(self.scale) | |
| self.scale_on = False | |
| else: | |
| self.p0.addItem(self.scale) | |
| self.scale_on = True | |
| def enable_buttons(self): | |
| if len(self.model_strings) > 0: | |
| self.ModelButtonC.setEnabled(True) | |
| for i in range(len(self.StyleButtons)): | |
| self.StyleButtons[i].setEnabled(True) | |
| for i in range(len(self.FilterButtons)): | |
| self.FilterButtons[i].setEnabled(True) | |
| if self.load_3D: | |
| self.FilterButtons[-2].setEnabled(False) | |
| self.newmodel.setEnabled(True) | |
| self.loadMasks.setEnabled(True) | |
| for n in range(self.nchan): | |
| self.sliders[n].setEnabled(True) | |
| for n in range(self.nchan, 3): | |
| self.sliders[n].setEnabled(True) | |
| self.toggle_mask_ops() | |
| self.update_plot() | |
| self.setWindowTitle(self.filename) | |
| def disable_buttons_removeROIs(self): | |
| if len(self.model_strings) > 0: | |
| self.ModelButtonC.setEnabled(False) | |
| for i in range(len(self.StyleButtons)): | |
| self.StyleButtons[i].setEnabled(False) | |
| self.newmodel.setEnabled(False) | |
| self.loadMasks.setEnabled(False) | |
| self.saveSet.setEnabled(False) | |
| self.savePNG.setEnabled(False) | |
| self.saveFlows.setEnabled(False) | |
| self.saveOutlines.setEnabled(False) | |
| self.saveROIs.setEnabled(False) | |
| self.MakeDeletionRegionButton.setEnabled(False) | |
| self.DeleteMultipleROIButton.setEnabled(False) | |
| self.DoneDeleteMultipleROIButton.setEnabled(True) | |
| self.CancelDeleteMultipleROIButton.setEnabled(True) | |
| def toggle_mask_ops(self): | |
| self.update_layer() | |
| self.toggle_saving() | |
| self.toggle_removals() | |
| def toggle_saving(self): | |
| if self.ncells > 0: | |
| self.saveSet.setEnabled(True) | |
| self.savePNG.setEnabled(True) | |
| self.saveFlows.setEnabled(True) | |
| self.saveOutlines.setEnabled(True) | |
| self.saveROIs.setEnabled(True) | |
| else: | |
| self.saveSet.setEnabled(False) | |
| self.savePNG.setEnabled(False) | |
| self.saveFlows.setEnabled(False) | |
| self.saveOutlines.setEnabled(False) | |
| self.saveROIs.setEnabled(False) | |
| def toggle_removals(self): | |
| if self.ncells > 0: | |
| self.ClearButton.setEnabled(True) | |
| self.remcell.setEnabled(True) | |
| self.undo.setEnabled(True) | |
| self.MakeDeletionRegionButton.setEnabled(True) | |
| self.DeleteMultipleROIButton.setEnabled(True) | |
| self.DoneDeleteMultipleROIButton.setEnabled(False) | |
| self.CancelDeleteMultipleROIButton.setEnabled(False) | |
| else: | |
| self.ClearButton.setEnabled(False) | |
| self.remcell.setEnabled(False) | |
| self.undo.setEnabled(False) | |
| self.MakeDeletionRegionButton.setEnabled(False) | |
| self.DeleteMultipleROIButton.setEnabled(False) | |
| self.DoneDeleteMultipleROIButton.setEnabled(False) | |
| self.CancelDeleteMultipleROIButton.setEnabled(False) | |
| def remove_action(self): | |
| if self.selected > 0: | |
| self.remove_cell(self.selected) | |
| def undo_action(self): | |
| if (len(self.strokes) > 0 and self.strokes[-1][0][0] == self.currentZ): | |
| self.remove_stroke() | |
| else: | |
| # remove previous cell | |
| if self.ncells > 0: | |
| self.remove_cell(self.ncells.get()) | |
| def undo_remove_action(self): | |
| self.undo_remove_cell() | |
| def get_files(self): | |
| folder = os.path.dirname(self.filename) | |
| mask_filter = "_masks" | |
| images = get_image_files(folder, mask_filter) | |
| fnames = [os.path.split(images[k])[-1] for k in range(len(images))] | |
| f0 = os.path.split(self.filename)[-1] | |
| idx = np.nonzero(np.array(fnames) == f0)[0][0] | |
| return images, idx | |
| def get_prev_image(self): | |
| images, idx = self.get_files() | |
| idx = (idx - 1) % len(images) | |
| io._load_image(self, filename=images[idx]) | |
| def get_next_image(self, load_seg=True): | |
| images, idx = self.get_files() | |
| idx = (idx + 1) % len(images) | |
| io._load_image(self, filename=images[idx], load_seg=load_seg) | |
| def dragEnterEvent(self, event): | |
| if event.mimeData().hasUrls(): | |
| event.accept() | |
| else: | |
| event.ignore() | |
| def dropEvent(self, event): | |
| files = [u.toLocalFile() for u in event.mimeData().urls()] | |
| if os.path.splitext(files[0])[-1] == ".npy": | |
| io._load_seg(self, filename=files[0], load_3D=self.load_3D) | |
| else: | |
| io._load_image(self, filename=files[0], load_seg=True, load_3D=self.load_3D) | |
| def toggle_masks(self): | |
| if self.MCheckBox.isChecked(): | |
| self.masksOn = True | |
| else: | |
| self.masksOn = False | |
| if self.OCheckBox.isChecked(): | |
| self.outlinesOn = True | |
| else: | |
| self.outlinesOn = False | |
| if not self.masksOn and not self.outlinesOn: | |
| self.p0.removeItem(self.layer) | |
| self.layer_off = True | |
| else: | |
| if self.layer_off: | |
| self.p0.addItem(self.layer) | |
| self.draw_layer() | |
| self.update_layer() | |
| if self.loaded: | |
| self.update_plot() | |
| self.update_layer() | |
| def make_viewbox(self): | |
| self.p0 = guiparts.ViewBoxNoRightDrag(parent=self, lockAspect=True, | |
| name="plot1", border=[100, 100, | |
| 100], invertY=True) | |
| self.p0.setCursor(QtCore.Qt.CrossCursor) | |
| self.brush_size = 3 | |
| self.win.addItem(self.p0, 0, 0, rowspan=1, colspan=1) | |
| self.p0.setMenuEnabled(False) | |
| self.p0.setMouseEnabled(x=True, y=True) | |
| self.img = pg.ImageItem(viewbox=self.p0, parent=self) | |
| self.img.autoDownsample = False | |
| self.layer = guiparts.ImageDraw(viewbox=self.p0, parent=self) | |
| self.layer.setLevels([0, 255]) | |
| self.scale = pg.ImageItem(viewbox=self.p0, parent=self) | |
| self.scale.setLevels([0, 255]) | |
| self.p0.scene().contextMenuItem = self.p0 | |
| self.Ly, self.Lx = 512, 512 | |
| self.p0.addItem(self.img) | |
| self.p0.addItem(self.layer) | |
| self.p0.addItem(self.scale) | |
| def reset(self): | |
| # ---- start sets of points ---- # | |
| self.selected = 0 | |
| self.nchan = 3 | |
| self.loaded = False | |
| self.channel = [0, 1] | |
| self.current_point_set = [] | |
| self.in_stroke = False | |
| self.strokes = [] | |
| self.stroke_appended = True | |
| self.resize = False | |
| self.ncells.reset() | |
| self.zdraw = [] | |
| self.removed_cell = [] | |
| self.cellcolors = np.array([255, 255, 255])[np.newaxis, :] | |
| # -- zero out image stack -- # | |
| self.opacity = 128 # how opaque masks should be | |
| self.outcolor = [200, 200, 255, 200] | |
| self.NZ, self.Ly, self.Lx = 1, 256, 256 | |
| self.saturation = self.saturation if hasattr(self, 'saturation') else [] | |
| # only adjust the saturation if auto-adjust is on: | |
| if self.autobtn.isChecked(): | |
| for r in range(3): | |
| self.saturation.append([[0, 255] for n in range(self.NZ)]) | |
| self.sliders[r].setValue([0, 255]) | |
| self.sliders[r].setEnabled(False) | |
| self.sliders[r].show() | |
| self.currentZ = 0 | |
| self.flows = [[], [], [], [], [[]]] | |
| # masks matrix | |
| # image matrix with a scale disk | |
| self.stack = np.zeros((1, self.Ly, self.Lx, 3)) | |
| self.Lyr, self.Lxr = self.Ly, self.Lx | |
| self.Ly0, self.Lx0 = self.Ly, self.Lx | |
| self.radii = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) | |
| self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) | |
| self.cellpix = np.zeros((1, self.Ly, self.Lx), np.uint16) | |
| self.outpix = np.zeros((1, self.Ly, self.Lx), np.uint16) | |
| self.ismanual = np.zeros(0, "bool") | |
| # -- set menus to default -- # | |
| self.color = 0 | |
| self.RGBDropDown.setCurrentIndex(self.color) | |
| self.view = 0 | |
| self.ViewDropDown.setCurrentIndex(0) | |
| self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False) | |
| self.delete_restore() | |
| self.clear_all() | |
| self.filename = [] | |
| self.loaded = False | |
| self.recompute_masks = False | |
| self.deleting_multiple = False | |
| self.removing_cells_list = [] | |
| self.removing_region = False | |
| self.remove_roi_obj = None | |
| def delete_restore(self): | |
| """ delete restored imgs but don't reset settings """ | |
| if hasattr(self, "stack_filtered"): | |
| del self.stack_filtered | |
| if hasattr(self, "cellpix_orig"): | |
| self.cellpix = self.cellpix_orig.copy() | |
| self.outpix = self.outpix_orig.copy() | |
| del self.outpix_orig, self.outpix_resize | |
| del self.cellpix_orig, self.cellpix_resize | |
| def clear_restore(self): | |
| """ delete restored imgs and reset settings """ | |
| print("GUI_INFO: clearing restored image") | |
| self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False) | |
| if self.ViewDropDown.currentIndex() == self.ViewDropDown.count() - 1: | |
| self.ViewDropDown.setCurrentIndex(0) | |
| self.delete_restore() | |
| self.restore = None | |
| self.ratio = 1. | |
| self.set_normalize_params(self.get_normalize_params()) | |
| def brush_choose(self): | |
| self.brush_size = self.BrushChoose.currentIndex() * 2 + 1 | |
| if self.loaded: | |
| self.layer.setDrawKernel(kernel_size=self.brush_size) | |
| self.update_layer() | |
| def clear_all(self): | |
| self.prev_selected = 0 | |
| self.selected = 0 | |
| if self.restore and "upsample" in self.restore: | |
| self.layerz = 0 * np.ones((self.Lyr, self.Lxr, 4), np.uint8) | |
| self.cellpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16) | |
| self.outpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16) | |
| self.cellpix_resize = self.cellpix.copy() | |
| self.outpix_resize = self.outpix.copy() | |
| self.cellpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16) | |
| self.outpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16) | |
| else: | |
| self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) | |
| self.cellpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16) | |
| self.outpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16) | |
| self.cellcolors = np.array([255, 255, 255])[np.newaxis, :] | |
| self.ncells.reset() | |
| self.toggle_removals() | |
| self.update_scale() | |
| self.update_layer() | |
| def select_cell(self, idx): | |
| self.prev_selected = self.selected | |
| self.selected = idx | |
| if self.selected > 0: | |
| z = self.currentZ | |
| self.layerz[self.cellpix[z] == idx] = np.array( | |
| [255, 255, 255, self.opacity]) | |
| self.update_layer() | |
| def select_cell_multi(self, idx): | |
| if idx > 0: | |
| z = self.currentZ | |
| self.layerz[self.cellpix[z] == idx] = np.array( | |
| [255, 255, 255, self.opacity]) | |
| self.update_layer() | |
| def unselect_cell(self): | |
| if self.selected > 0: | |
| idx = self.selected | |
| if idx < (self.ncells.get() + 1): | |
| z = self.currentZ | |
| self.layerz[self.cellpix[z] == idx] = np.append( | |
| self.cellcolors[idx], self.opacity) | |
| if self.outlinesOn: | |
| self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype( | |
| np.uint8) | |
| #[0,0,0,self.opacity]) | |
| self.update_layer() | |
| self.selected = 0 | |
| def unselect_cell_multi(self, idx): | |
| z = self.currentZ | |
| self.layerz[self.cellpix[z] == idx] = np.append(self.cellcolors[idx], | |
| self.opacity) | |
| if self.outlinesOn: | |
| self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype( | |
| np.uint8) | |
| # [0,0,0,self.opacity]) | |
| self.update_layer() | |
| def remove_cell(self, idx): | |
| if isinstance(idx, (int, np.integer)): | |
| idx = [idx] | |
| # because the function remove_single_cell updates the state of the cellpix and outpix arrays | |
| # by reindexing cells to avoid gaps in the indices, we need to remove the cells in reverse order | |
| # so that the indices are correct | |
| idx.sort(reverse=True) | |
| for i in idx: | |
| self.remove_single_cell(i) | |
| self.ncells -= len(idx) # _save_sets uses ncells | |
| self.update_layer() | |
| if self.ncells == 0: | |
| self.ClearButton.setEnabled(False) | |
| if self.NZ == 1: | |
| io._save_sets_with_check(self) | |
| def remove_single_cell(self, idx): | |
| # remove from manual array | |
| self.selected = 0 | |
| if self.NZ > 1: | |
| zextent = ((self.cellpix == idx).sum(axis=(1, 2)) > 0).nonzero()[0] | |
| else: | |
| zextent = [0] | |
| for z in zextent: | |
| cp = self.cellpix[z] == idx | |
| op = self.outpix[z] == idx | |
| # remove from self.cellpix and self.outpix | |
| self.cellpix[z, cp] = 0 | |
| self.outpix[z, op] = 0 | |
| if z == self.currentZ: | |
| # remove from mask layer | |
| self.layerz[cp] = np.array([0, 0, 0, 0]) | |
| # reduce other pixels by -1 | |
| self.cellpix[self.cellpix > idx] -= 1 | |
| self.outpix[self.outpix > idx] -= 1 | |
| if self.NZ == 1: | |
| self.removed_cell = [ | |
| self.ismanual[idx - 1], self.cellcolors[idx], | |
| np.nonzero(cp), | |
| np.nonzero(op) | |
| ] | |
| self.redo.setEnabled(True) | |
| ar, ac = self.removed_cell[2] | |
| d = datetime.datetime.now() | |
| self.track_changes.append( | |
| [d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]]) | |
| # remove cell from lists | |
| self.ismanual = np.delete(self.ismanual, idx - 1) | |
| self.cellcolors = np.delete(self.cellcolors, [idx], axis=0) | |
| del self.zdraw[idx - 1] | |
| print("GUI_INFO: removed cell %d" % (idx - 1)) | |
| def remove_region_cells(self): | |
| if self.removing_cells_list: | |
| for idx in self.removing_cells_list: | |
| self.unselect_cell_multi(idx) | |
| self.removing_cells_list.clear() | |
| self.disable_buttons_removeROIs() | |
| self.removing_region = True | |
| self.clear_multi_selected_cells() | |
| # make roi region here in center of view, making ROI half the size of the view | |
| roi_width = self.p0.viewRect().width() / 2 | |
| x_loc = self.p0.viewRect().x() + (roi_width / 2) | |
| roi_height = self.p0.viewRect().height() / 2 | |
| y_loc = self.p0.viewRect().y() + (roi_height / 2) | |
| pos = [x_loc, y_loc] | |
| roi = pg.RectROI(pos, [roi_width, roi_height], pen=pg.mkPen("y", width=2), | |
| removable=True) | |
| roi.sigRemoveRequested.connect(self.remove_roi) | |
| roi.sigRegionChangeFinished.connect(self.roi_changed) | |
| self.p0.addItem(roi) | |
| self.remove_roi_obj = roi | |
| self.roi_changed(roi) | |
| def delete_multiple_cells(self): | |
| self.unselect_cell() | |
| self.disable_buttons_removeROIs() | |
| self.DoneDeleteMultipleROIButton.setEnabled(True) | |
| self.MakeDeletionRegionButton.setEnabled(True) | |
| self.CancelDeleteMultipleROIButton.setEnabled(True) | |
| self.deleting_multiple = True | |
| def done_remove_multiple_cells(self): | |
| self.deleting_multiple = False | |
| self.removing_region = False | |
| self.DoneDeleteMultipleROIButton.setEnabled(False) | |
| self.MakeDeletionRegionButton.setEnabled(False) | |
| self.CancelDeleteMultipleROIButton.setEnabled(False) | |
| if self.removing_cells_list: | |
| self.removing_cells_list = list(set(self.removing_cells_list)) | |
| display_remove_list = [i - 1 for i in self.removing_cells_list] | |
| print(f"GUI_INFO: removing cells: {display_remove_list}") | |
| self.remove_cell(self.removing_cells_list) | |
| self.removing_cells_list.clear() | |
| self.unselect_cell() | |
| self.enable_buttons() | |
| if self.remove_roi_obj is not None: | |
| self.remove_roi(self.remove_roi_obj) | |
| def merge_cells(self, idx): | |
| self.prev_selected = self.selected | |
| self.selected = idx | |
| if self.selected != self.prev_selected: | |
| for z in range(self.NZ): | |
| ar0, ac0 = np.nonzero(self.cellpix[z] == self.prev_selected) | |
| ar1, ac1 = np.nonzero(self.cellpix[z] == self.selected) | |
| touching = np.logical_and((ar0[:, np.newaxis] - ar1) < 3, | |
| (ac0[:, np.newaxis] - ac1) < 3).sum() | |
| ar = np.hstack((ar0, ar1)) | |
| ac = np.hstack((ac0, ac1)) | |
| vr0, vc0 = np.nonzero(self.outpix[z] == self.prev_selected) | |
| vr1, vc1 = np.nonzero(self.outpix[z] == self.selected) | |
| self.outpix[z, vr0, vc0] = 0 | |
| self.outpix[z, vr1, vc1] = 0 | |
| if touching > 0: | |
| mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8) | |
| mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1 | |
| contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, | |
| cv2.CHAIN_APPROX_NONE) | |
| pvc, pvr = contours[-2][0].squeeze().T | |
| vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2 | |
| else: | |
| vr = np.hstack((vr0, vr1)) | |
| vc = np.hstack((vc0, vc1)) | |
| color = self.cellcolors[self.prev_selected] | |
| self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected) | |
| self.remove_cell(self.selected) | |
| print("GUI_INFO: merged two cells") | |
| self.update_layer() | |
| io._save_sets_with_check(self) | |
| self.undo.setEnabled(False) | |
| self.redo.setEnabled(False) | |
| def undo_remove_cell(self): | |
| if len(self.removed_cell) > 0: | |
| z = 0 | |
| ar, ac = self.removed_cell[2] | |
| vr, vc = self.removed_cell[3] | |
| color = self.removed_cell[1] | |
| self.draw_mask(z, ar, ac, vr, vc, color) | |
| self.toggle_mask_ops() | |
| self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], axis=0) | |
| self.ncells += 1 | |
| self.ismanual = np.append(self.ismanual, self.removed_cell[0]) | |
| self.zdraw.append([]) | |
| print(">>> added back removed cell") | |
| self.update_layer() | |
| io._save_sets_with_check(self) | |
| self.removed_cell = [] | |
| self.redo.setEnabled(False) | |
| def remove_stroke(self, delete_points=True, stroke_ind=-1): | |
| stroke = np.array(self.strokes[stroke_ind]) | |
| cZ = self.currentZ | |
| inZ = stroke[0, 0] == cZ | |
| if inZ: | |
| outpix = self.outpix[cZ, stroke[:, 1], stroke[:, 2]] > 0 | |
| self.layerz[stroke[~outpix, 1], stroke[~outpix, 2]] = np.array([0, 0, 0, 0]) | |
| cellpix = self.cellpix[cZ, stroke[:, 1], stroke[:, 2]] | |
| ccol = self.cellcolors.copy() | |
| if self.selected > 0: | |
| ccol[self.selected] = np.array([255, 255, 255]) | |
| col2mask = ccol[cellpix] | |
| if self.masksOn: | |
| col2mask = np.concatenate( | |
| (col2mask, self.opacity * (cellpix[:, np.newaxis] > 0)), axis=-1) | |
| else: | |
| col2mask = np.concatenate((col2mask, 0 * (cellpix[:, np.newaxis] > 0)), | |
| axis=-1) | |
| self.layerz[stroke[:, 1], stroke[:, 2], :] = col2mask | |
| if self.outlinesOn: | |
| self.layerz[stroke[outpix, 1], stroke[outpix, | |
| 2]] = np.array(self.outcolor) | |
| if delete_points: | |
| del self.current_point_set[stroke_ind] | |
| self.update_layer() | |
| del self.strokes[stroke_ind] | |
| def plot_clicked(self, event): | |
| if event.button()==QtCore.Qt.LeftButton \ | |
| and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\ | |
| and not self.removing_region: | |
| if event.double(): | |
| try: | |
| self.p0.setYRange(0, self.Ly + self.pr) | |
| except: | |
| self.p0.setYRange(0, self.Ly) | |
| self.p0.setXRange(0, self.Lx) | |
| def cancel_remove_multiple(self): | |
| self.clear_multi_selected_cells() | |
| self.done_remove_multiple_cells() | |
| def clear_multi_selected_cells(self): | |
| # unselect all previously selected cells: | |
| for idx in self.removing_cells_list: | |
| self.unselect_cell_multi(idx) | |
| self.removing_cells_list.clear() | |
| def add_roi(self, roi): | |
| self.p0.addItem(roi) | |
| self.remove_roi_obj = roi | |
| def remove_roi(self, roi): | |
| self.clear_multi_selected_cells() | |
| assert roi == self.remove_roi_obj | |
| self.remove_roi_obj = None | |
| self.p0.removeItem(roi) | |
| self.removing_region = False | |
| def roi_changed(self, roi): | |
| # find the overlapping cells and make them selected | |
| pos = roi.pos() | |
| size = roi.size() | |
| x0 = int(pos.x()) | |
| y0 = int(pos.y()) | |
| x1 = int(pos.x() + size.x()) | |
| y1 = int(pos.y() + size.y()) | |
| if x0 < 0: | |
| x0 = 0 | |
| if y0 < 0: | |
| y0 = 0 | |
| if x1 > self.Lx: | |
| x1 = self.Lx | |
| if y1 > self.Ly: | |
| y1 = self.Ly | |
| # find cells in that region | |
| cell_idxs = np.unique(self.cellpix[self.currentZ, y0:y1, x0:x1]) | |
| cell_idxs = np.trim_zeros(cell_idxs) | |
| # deselect cells not in region by deselecting all and then selecting the ones in the region | |
| self.clear_multi_selected_cells() | |
| for idx in cell_idxs: | |
| self.select_cell_multi(idx) | |
| self.removing_cells_list.append(idx) | |
| self.update_layer() | |
| def mouse_moved(self, pos): | |
| items = self.win.scene().items(pos) | |
| def color_choose(self): | |
| self.color = self.RGBDropDown.currentIndex() | |
| self.view = 0 | |
| self.ViewDropDown.setCurrentIndex(self.view) | |
| self.update_plot() | |
| def update_plot(self): | |
| self.view = self.ViewDropDown.currentIndex() | |
| self.Ly, self.Lx, _ = self.stack[self.currentZ].shape | |
| if self.view == 0 or self.view == self.ViewDropDown.count() - 1: | |
| image = self.stack[ | |
| self.currentZ] if self.view == 0 else self.stack_filtered[self.currentZ] | |
| if self.color == 0: | |
| self.img.setImage(image, autoLevels=False, lut=None) | |
| if self.nchan > 1: | |
| levels = np.array([ | |
| self.saturation[0][self.currentZ], | |
| self.saturation[1][self.currentZ], | |
| self.saturation[2][self.currentZ] | |
| ]) | |
| self.img.setLevels(levels) | |
| else: | |
| self.img.setLevels(self.saturation[0][self.currentZ]) | |
| elif self.color > 0 and self.color < 4: | |
| if self.nchan > 1: | |
| image = image[:, :, self.color - 1] | |
| self.img.setImage(image, autoLevels=False, lut=self.cmap[self.color]) | |
| if self.nchan > 1: | |
| self.img.setLevels(self.saturation[self.color - 1][self.currentZ]) | |
| else: | |
| self.img.setLevels(self.saturation[0][self.currentZ]) | |
| elif self.color == 4: | |
| if self.nchan > 1: | |
| image = image.mean(axis=-1) | |
| self.img.setImage(image, autoLevels=False, lut=None) | |
| self.img.setLevels(self.saturation[0][self.currentZ]) | |
| elif self.color == 5: | |
| if self.nchan > 1: | |
| image = image.mean(axis=-1) | |
| self.img.setImage(image, autoLevels=False, lut=self.cmap[0]) | |
| self.img.setLevels(self.saturation[0][self.currentZ]) | |
| else: | |
| image = np.zeros((self.Ly, self.Lx), np.uint8) | |
| if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0: | |
| image = self.flows[self.view - 1][self.currentZ] | |
| if self.view > 1: | |
| self.img.setImage(image, autoLevels=False, lut=self.bwr) | |
| else: | |
| self.img.setImage(image, autoLevels=False, lut=None) | |
| self.img.setLevels([0.0, 255.0]) | |
| for r in range(3): | |
| self.sliders[r].setValue([ | |
| self.saturation[r][self.currentZ][0], | |
| self.saturation[r][self.currentZ][1] | |
| ]) | |
| self.win.show() | |
| self.show() | |
| def update_layer(self): | |
| if self.masksOn or self.outlinesOn: | |
| self.layer.setImage(self.layerz, autoLevels=False) | |
| self.win.show() | |
| self.show() | |
| def add_set(self): | |
| if len(self.current_point_set) > 0: | |
| while len(self.strokes) > 0: | |
| self.remove_stroke(delete_points=False) | |
| if len(self.current_point_set[0]) > 8: | |
| color = self.colormap[self.ncells.get(), :3] | |
| median = self.add_mask(points=self.current_point_set, color=color) | |
| if median is not None: | |
| self.removed_cell = [] | |
| self.toggle_mask_ops() | |
| self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], | |
| axis=0) | |
| self.ncells += 1 | |
| self.ismanual = np.append(self.ismanual, True) | |
| if self.NZ == 1: | |
| # only save after each cell if single image | |
| io._save_sets_with_check(self) | |
| else: | |
| print("GUI_ERROR: cell too small, not drawn") | |
| self.current_stroke = [] | |
| self.strokes = [] | |
| self.current_point_set = [] | |
| self.update_layer() | |
| def add_mask(self, points=None, color=(100, 200, 50), dense=True): | |
| # points is list of strokes | |
| points_all = np.concatenate(points, axis=0) | |
| # loop over z values | |
| median = [] | |
| zdraw = np.unique(points_all[:, 0]) | |
| z = 0 | |
| ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros( | |
| 0, "int"), np.zeros(0, "int") | |
| for stroke in points: | |
| stroke = np.concatenate(stroke, axis=0).reshape(-1, 4) | |
| vr = stroke[:, 1] | |
| vc = stroke[:, 2] | |
| # get points inside drawn points | |
| mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8) | |
| pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2), | |
| axis=-1)[:, np.newaxis, :] | |
| mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) | |
| ar, ac = np.nonzero(mask) | |
| ar, ac = ar + vr.min() - 2, ac + vc.min() - 2 | |
| # get dense outline | |
| contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) | |
| pvc, pvr = contours[-2][0][:,0].T | |
| vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2 | |
| # concatenate all points | |
| ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac)))) | |
| # if these pixels are overlapping with another cell, reassign them | |
| ioverlap = self.cellpix[z][ar, ac] > 0 | |
| if (~ioverlap).sum() < 10: | |
| print("GUI_ERROR: cell < 10 pixels without overlaps, not drawn") | |
| return None | |
| elif ioverlap.sum() > 0: | |
| ar, ac = ar[~ioverlap], ac[~ioverlap] | |
| # compute outline of new mask | |
| mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8) | |
| mask[ar - vr.min() + 2, ac - vc.min() + 2] = 1 | |
| contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, | |
| cv2.CHAIN_APPROX_NONE) | |
| pvc, pvr = contours[-2][0][:,0].T | |
| vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2 | |
| ars = np.concatenate((ars, ar), axis=0) | |
| acs = np.concatenate((acs, ac), axis=0) | |
| vrs = np.concatenate((vrs, vr), axis=0) | |
| vcs = np.concatenate((vcs, vc), axis=0) | |
| self.draw_mask(z, ars, acs, vrs, vcs, color) | |
| median.append(np.array([np.median(ars), np.median(acs)])) | |
| self.zdraw.append(zdraw) | |
| d = datetime.datetime.now() | |
| self.track_changes.append( | |
| [d.strftime("%m/%d/%Y, %H:%M:%S"), "added mask", [ar, ac]]) | |
| return median | |
| def draw_mask(self, z, ar, ac, vr, vc, color, idx=None): | |
| """ draw single mask using outlines and area """ | |
| if idx is None: | |
| idx = self.ncells + 1 | |
| self.cellpix[z, vr, vc] = idx | |
| self.cellpix[z, ar, ac] = idx | |
| self.outpix[z, vr, vc] = idx | |
| if self.restore and "upsample" in self.restore: | |
| if self.resize: | |
| self.cellpix_resize[z, vr, vc] = idx | |
| self.cellpix_resize[z, ar, ac] = idx | |
| self.outpix_resize[z, vr, vc] = idx | |
| self.cellpix_orig[z, (vr / self.ratio).astype(int), | |
| (vc / self.ratio).astype(int)] = idx | |
| self.cellpix_orig[z, (ar / self.ratio).astype(int), | |
| (ac / self.ratio).astype(int)] = idx | |
| self.outpix_orig[z, (vr / self.ratio).astype(int), | |
| (vc / self.ratio).astype(int)] = idx | |
| else: | |
| self.cellpix_orig[z, vr, vc] = idx | |
| self.cellpix_orig[z, ar, ac] = idx | |
| self.outpix_orig[z, vr, vc] = idx | |
| # get upsampled mask | |
| vrr = (vr.copy() * self.ratio).astype(int) | |
| vcr = (vc.copy() * self.ratio).astype(int) | |
| mask = np.zeros((np.ptp(vrr) + 4, np.ptp(vcr) + 4), np.uint8) | |
| pts = np.stack((vcr - vcr.min() + 2, vrr - vrr.min() + 2), | |
| axis=-1)[:, np.newaxis, :] | |
| mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) | |
| arr, acr = np.nonzero(mask) | |
| arr, acr = arr + vrr.min() - 2, acr + vcr.min() - 2 | |
| # get dense outline | |
| contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, | |
| cv2.CHAIN_APPROX_NONE) | |
| pvc, pvr = contours[-2][0].squeeze().T | |
| vrr, vcr = pvr + vrr.min() - 2, pvc + vcr.min() - 2 | |
| # concatenate all points | |
| arr, acr = np.hstack((np.vstack((vrr, vcr)), np.vstack((arr, acr)))) | |
| self.cellpix_resize[z, vrr, vcr] = idx | |
| self.cellpix_resize[z, arr, acr] = idx | |
| self.outpix_resize[z, vrr, vcr] = idx | |
| if z == self.currentZ: | |
| self.layerz[ar, ac, :3] = color | |
| if self.masksOn: | |
| self.layerz[ar, ac, -1] = self.opacity | |
| if self.outlinesOn: | |
| self.layerz[vr, vc] = np.array(self.outcolor) | |
| def compute_scale(self): | |
| # get diameter from gui | |
| diameter = self.segmentation_settings.diameter | |
| if not diameter: | |
| diameter = 30 | |
| self.pr = int(diameter) | |
| self.radii_padding = int(self.pr * 1.25) | |
| self.radii = np.zeros((self.Ly + self.radii_padding, self.Lx, 4), np.uint8) | |
| yy, xx = disk([self.Ly + self.radii_padding / 2 - 1, self.pr / 2 + 1], | |
| self.pr / 2, self.Ly + self.radii_padding, self.Lx) | |
| # rgb(150,50,150) | |
| self.radii[yy, xx, 0] = 150 | |
| self.radii[yy, xx, 1] = 50 | |
| self.radii[yy, xx, 2] = 150 | |
| self.radii[yy, xx, 3] = 255 | |
| self.p0.setYRange(0, self.Ly + self.radii_padding) | |
| self.p0.setXRange(0, self.Lx) | |
| def update_scale(self): | |
| self.compute_scale() | |
| self.scale.setImage(self.radii, autoLevels=False) | |
| self.scale.setLevels([0.0, 255.0]) | |
| self.win.show() | |
| self.show() | |
| def draw_layer(self): | |
| if self.resize: | |
| self.Ly, self.Lx = self.Lyr, self.Lxr | |
| else: | |
| self.Ly, self.Lx = self.Ly0, self.Lx0 | |
| if self.masksOn or self.outlinesOn: | |
| if self.restore and "upsample" in self.restore: | |
| if self.resize: | |
| self.cellpix = self.cellpix_resize.copy() | |
| self.outpix = self.outpix_resize.copy() | |
| else: | |
| self.cellpix = self.cellpix_orig.copy() | |
| self.outpix = self.outpix_orig.copy() | |
| self.layerz = np.zeros((self.Ly, self.Lx, 4), np.uint8) | |
| if self.masksOn: | |
| self.layerz[..., :3] = self.cellcolors[self.cellpix[self.currentZ], :] | |
| self.layerz[..., 3] = self.opacity * (self.cellpix[self.currentZ] | |
| > 0).astype(np.uint8) | |
| if self.selected > 0: | |
| self.layerz[self.cellpix[self.currentZ] == self.selected] = np.array( | |
| [255, 255, 255, self.opacity]) | |
| cZ = self.currentZ | |
| stroke_z = np.array([s[0][0] for s in self.strokes]) | |
| inZ = np.nonzero(stroke_z == cZ)[0] | |
| if len(inZ) > 0: | |
| for i in inZ: | |
| stroke = np.array(self.strokes[i]) | |
| self.layerz[stroke[:, 1], stroke[:, | |
| 2]] = np.array([255, 0, 255, 100]) | |
| else: | |
| self.layerz[..., 3] = 0 | |
| if self.outlinesOn: | |
| self.layerz[self.outpix[self.currentZ] > 0] = np.array( | |
| self.outcolor).astype(np.uint8) | |
| def set_normalize_params(self, normalize_params): | |
| from cellpose.models import normalize_default | |
| if self.restore != "filter": | |
| keys = list(normalize_params.keys()).copy() | |
| for key in keys: | |
| if key != "percentile": | |
| normalize_params[key] = normalize_default[key] | |
| normalize_params = {**normalize_default, **normalize_params} | |
| out = self.check_filter_params(normalize_params["sharpen_radius"], | |
| normalize_params["smooth_radius"], | |
| normalize_params["tile_norm_blocksize"], | |
| normalize_params["tile_norm_smooth3D"], | |
| normalize_params["norm3D"], | |
| normalize_params["invert"]) | |
| def check_filter_params(self, sharpen, smooth, tile_norm, smooth3D, norm3D, invert): | |
| tile_norm = 0 if tile_norm < 0 else tile_norm | |
| sharpen = 0 if sharpen < 0 else sharpen | |
| smooth = 0 if smooth < 0 else smooth | |
| smooth3D = 0 if smooth3D < 0 else smooth3D | |
| norm3D = bool(norm3D) | |
| invert = bool(invert) | |
| if tile_norm > self.Ly and tile_norm > self.Lx: | |
| print( | |
| "GUI_ERROR: tile size (tile_norm) bigger than both image dimensions, disabling" | |
| ) | |
| tile_norm = 0 | |
| self.filt_edits[0].setText(str(sharpen)) | |
| self.filt_edits[1].setText(str(smooth)) | |
| self.filt_edits[2].setText(str(tile_norm)) | |
| self.filt_edits[3].setText(str(smooth3D)) | |
| self.norm3D_cb.setChecked(norm3D) | |
| return sharpen, smooth, tile_norm, smooth3D, norm3D, invert | |
| def get_normalize_params(self): | |
| percentile = [ | |
| self.segmentation_settings.low_percentile, | |
| self.segmentation_settings.high_percentile, | |
| ] | |
| normalize_params = {"percentile": percentile} | |
| norm3D = self.norm3D_cb.isChecked() | |
| normalize_params["norm3D"] = norm3D | |
| sharpen = float(self.filt_edits[0].text()) | |
| smooth = float(self.filt_edits[1].text()) | |
| tile_norm = float(self.filt_edits[2].text()) | |
| smooth3D = float(self.filt_edits[3].text()) | |
| invert = False | |
| out = self.check_filter_params(sharpen, smooth, tile_norm, smooth3D, norm3D, | |
| invert) | |
| sharpen, smooth, tile_norm, smooth3D, norm3D, invert = out | |
| normalize_params["sharpen_radius"] = sharpen | |
| normalize_params["smooth_radius"] = smooth | |
| normalize_params["tile_norm_blocksize"] = tile_norm | |
| normalize_params["tile_norm_smooth3D"] = smooth3D | |
| normalize_params["invert"] = invert | |
| from cellpose.models import normalize_default | |
| normalize_params = {**normalize_default, **normalize_params} | |
| return normalize_params | |
| def compute_saturation_if_checked(self): | |
| if self.autobtn.isChecked(): | |
| self.compute_saturation() | |
| def compute_saturation(self, return_img=False): | |
| norm = self.get_normalize_params() | |
| print(norm) | |
| sharpen, smooth = norm["sharpen_radius"], norm["smooth_radius"] | |
| percentile = norm["percentile"] | |
| tile_norm = norm["tile_norm_blocksize"] | |
| invert = norm["invert"] | |
| norm3D = norm["norm3D"] | |
| smooth3D = norm["tile_norm_smooth3D"] | |
| tile_norm = norm["tile_norm_blocksize"] | |
| if sharpen > 0 or smooth > 0 or tile_norm > 0: | |
| img_norm = self.stack.copy() | |
| else: | |
| img_norm = self.stack | |
| if sharpen > 0 or smooth > 0 or tile_norm > 0: | |
| self.restore = "filter" | |
| print( | |
| "GUI_INFO: computing filtered image because sharpen > 0 or tile_norm > 0" | |
| ) | |
| print( | |
| "GUI_WARNING: will use memory to create filtered image -- make sure to have RAM for this" | |
| ) | |
| img_norm = self.stack.copy() | |
| if sharpen > 0 or smooth > 0: | |
| img_norm = smooth_sharpen_img(self.stack, sharpen_radius=sharpen, | |
| smooth_radius=smooth) | |
| if tile_norm > 0: | |
| img_norm = normalize99_tile(img_norm, blocksize=tile_norm, | |
| lower=percentile[0], upper=percentile[1], | |
| smooth3D=smooth3D, norm3D=norm3D) | |
| # convert to 0->255 | |
| img_norm_min = img_norm.min() | |
| img_norm_max = img_norm.max() | |
| for c in range(img_norm.shape[-1]): | |
| if np.ptp(img_norm[..., c]) > 1e-3: | |
| img_norm[..., c] -= img_norm_min | |
| img_norm[..., c] /= (img_norm_max - img_norm_min) | |
| img_norm *= 255 | |
| self.stack_filtered = img_norm | |
| self.ViewDropDown.model().item(self.ViewDropDown.count() - | |
| 1).setEnabled(True) | |
| self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1) | |
| else: | |
| img_norm = self.stack if self.restore is None or self.restore == "filter" else self.stack_filtered | |
| if self.autobtn.isChecked(): | |
| self.saturation = [] | |
| for c in range(img_norm.shape[-1]): | |
| self.saturation.append([]) | |
| if np.ptp(img_norm[..., c]) > 1e-3: | |
| if norm3D: | |
| x01 = np.percentile(img_norm[..., c], percentile[0]) | |
| x99 = np.percentile(img_norm[..., c], percentile[1]) | |
| if invert: | |
| x01i = 255. - x99 | |
| x99i = 255. - x01 | |
| x01, x99 = x01i, x99i | |
| for n in range(self.NZ): | |
| self.saturation[-1].append([x01, x99]) | |
| else: | |
| for z in range(self.NZ): | |
| if self.NZ > 1: | |
| x01 = np.percentile(img_norm[z, :, :, c], percentile[0]) | |
| x99 = np.percentile(img_norm[z, :, :, c], percentile[1]) | |
| else: | |
| x01 = np.percentile(img_norm[..., c], percentile[0]) | |
| x99 = np.percentile(img_norm[..., c], percentile[1]) | |
| if invert: | |
| x01i = 255. - x99 | |
| x99i = 255. - x01 | |
| x01, x99 = x01i, x99i | |
| self.saturation[-1].append([x01, x99]) | |
| else: | |
| for n in range(self.NZ): | |
| self.saturation[-1].append([0, 255.]) | |
| print(self.saturation[2][self.currentZ]) | |
| if img_norm.shape[-1] == 1: | |
| self.saturation.append(self.saturation[0]) | |
| self.saturation.append(self.saturation[0]) | |
| # self.autobtn.setChecked(True) | |
| self.update_plot() | |
| def get_model_path(self, custom=False): | |
| if custom: | |
| self.current_model = self.ModelChooseC.currentText() | |
| self.current_model_path = os.fspath( | |
| models.MODEL_DIR.joinpath(self.current_model)) | |
| else: | |
| self.current_model = "cpsam" | |
| self.current_model_path = models.model_path(self.current_model) | |
| def initialize_model(self, model_name=None, custom=False): | |
| if model_name is None or custom: | |
| self.get_model_path(custom=custom) | |
| if not os.path.exists(self.current_model_path): | |
| raise ValueError("need to specify model (use dropdown)") | |
| if model_name is None or not isinstance(model_name, str): | |
| self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), | |
| pretrained_model=self.current_model_path) | |
| else: | |
| self.current_model = model_name | |
| self.current_model_path = os.fspath( | |
| models.MODEL_DIR.joinpath(self.current_model)) | |
| self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), | |
| pretrained_model=self.current_model) | |
| def add_model(self): | |
| io._add_model(self) | |
| return | |
| def remove_model(self): | |
| io._remove_model(self) | |
| return | |
| def new_model(self): | |
| if self.NZ != 1: | |
| print("ERROR: cannot train model on 3D data") | |
| return | |
| # train model | |
| image_names = self.get_files()[0] | |
| self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set( | |
| image_names) | |
| TW = guiparts.TrainWindow(self, models.MODEL_NAMES) | |
| train = TW.exec_() | |
| if train: | |
| self.logger.info( | |
| f"training with {[os.path.split(f)[1] for f in self.train_files]}") | |
| self.train_model(restore=restore, normalize_params=normalize_params) | |
| else: | |
| print("GUI_INFO: training cancelled") | |
| def train_model(self, restore=None, normalize_params=None): | |
| from cellpose.models import normalize_default | |
| if normalize_params is None: | |
| normalize_params = copy.deepcopy(normalize_default) | |
| model_type = models.MODEL_NAMES[self.training_params["model_index"]] | |
| self.logger.info(f"training new model starting at model {model_type}") | |
| self.current_model = model_type | |
| self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), | |
| model_type=model_type) | |
| save_path = os.path.dirname(self.filename) | |
| print("GUI_INFO: name of new model: " + self.training_params["model_name"]) | |
| self.new_model_path, train_losses = train.train_seg( | |
| self.model.net, train_data=self.train_data, train_labels=self.train_labels, | |
| normalize=normalize_params, min_train_masks=0, | |
| save_path=save_path, nimg_per_epoch=max(2, len(self.train_data)), | |
| learning_rate=self.training_params["learning_rate"], | |
| weight_decay=self.training_params["weight_decay"], | |
| n_epochs=self.training_params["n_epochs"], | |
| model_name=self.training_params["model_name"])[:2] | |
| # save train losses | |
| np.save(str(self.new_model_path) + "_train_losses.npy", train_losses) | |
| # run model on next image | |
| io._add_model(self, self.new_model_path) | |
| diam_labels = self.model.net.diam_labels.item() #.copy() | |
| self.new_model_ind = len(self.model_strings) | |
| self.autorun = True | |
| self.clear_all() | |
| self.restore = restore | |
| self.set_normalize_params(normalize_params) | |
| self.get_next_image(load_seg=False) | |
| self.compute_segmentation(custom=True) | |
| self.logger.info( | |
| f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!" | |
| ) | |
| def compute_cprob(self): | |
| if self.recompute_masks: | |
| flow_threshold = self.segmentation_settings.flow_threshold | |
| cellprob_threshold = self.segmentation_settings.cellprob_threshold | |
| niter = self.segmentation_settings.niter | |
| min_size = int(self.min_size.text()) if not isinstance( | |
| self.min_size, int) else self.min_size | |
| self.logger.info( | |
| "computing masks with cell prob=%0.3f, flow error threshold=%0.3f" % | |
| (cellprob_threshold, flow_threshold)) | |
| try: | |
| dP = self.flows[2].squeeze() | |
| cellprob = self.flows[3].squeeze() | |
| except IndexError: | |
| self.logger.error("Flows don't exist, try running model again.") | |
| return | |
| maski = dynamics.resize_and_compute_masks( | |
| dP=dP, | |
| cellprob=cellprob, | |
| niter=niter, | |
| do_3D=self.load_3D, | |
| min_size=min_size, | |
| # max_size_fraction=min_size_fraction, # Leave as default | |
| cellprob_threshold=cellprob_threshold, | |
| flow_threshold=flow_threshold) | |
| self.masksOn = True | |
| if not self.OCheckBox.isChecked(): | |
| self.MCheckBox.setChecked(True) | |
| if maski.ndim < 3: | |
| maski = maski[np.newaxis, ...] | |
| self.logger.info("%d cells found" % (len(np.unique(maski)[1:]))) | |
| io._masks_to_gui(self, maski, outlines=None) | |
| self.show() | |
| def compute_segmentation(self, custom=False, model_name=None, load_model=True): | |
| self.progress.setValue(0) | |
| try: | |
| tic = time.time() | |
| self.clear_all() | |
| self.flows = [[], [], []] | |
| if load_model: | |
| self.initialize_model(model_name=model_name, custom=custom) | |
| self.progress.setValue(10) | |
| do_3D = self.load_3D | |
| stitch_threshold = float(self.stitch_threshold.text()) if not isinstance( | |
| self.stitch_threshold, float) else self.stitch_threshold | |
| anisotropy = float(self.anisotropy.text()) if not isinstance( | |
| self.anisotropy, float) else self.anisotropy | |
| flow3D_smooth = float(self.flow3D_smooth.text()) if not isinstance( | |
| self.flow3D_smooth, float) else self.flow3D_smooth | |
| min_size = int(self.min_size.text()) if not isinstance( | |
| self.min_size, int) else self.min_size | |
| do_3D = False if stitch_threshold > 0. else do_3D | |
| if self.restore == "filter": | |
| data = self.stack_filtered.copy().squeeze() | |
| else: | |
| data = self.stack.copy().squeeze() | |
| flow_threshold = self.segmentation_settings.flow_threshold | |
| cellprob_threshold = self.segmentation_settings.cellprob_threshold | |
| diameter = self.segmentation_settings.diameter | |
| niter = self.segmentation_settings.niter | |
| normalize_params = self.get_normalize_params() | |
| print(normalize_params) | |
| try: | |
| masks, flows = self.model.eval( | |
| data, | |
| diameter=diameter, | |
| cellprob_threshold=cellprob_threshold, | |
| flow_threshold=flow_threshold, do_3D=do_3D, niter=niter, | |
| normalize=normalize_params, stitch_threshold=stitch_threshold, | |
| anisotropy=anisotropy, flow3D_smooth=flow3D_smooth, | |
| min_size=min_size, channel_axis=-1, | |
| progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2] | |
| except Exception as e: | |
| print("NET ERROR: %s" % e) | |
| self.progress.setValue(0) | |
| return | |
| self.progress.setValue(75) | |
| # convert flows to uint8 and resize to original image size | |
| flows_new = [] | |
| flows_new.append(flows[0].copy()) # RGB flow | |
| flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) * | |
| 255).astype("uint8")) # cellprob | |
| flows_new.append(flows[1].copy()) # XY flows | |
| flows_new.append(flows[2].copy()) # original cellprob | |
| if self.load_3D: | |
| if stitch_threshold == 0.: | |
| flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8")) | |
| else: | |
| flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8")) | |
| if not self.load_3D: | |
| if self.restore and "upsample" in self.restore: | |
| self.Ly, self.Lx = self.Lyr, self.Lxr | |
| if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx): | |
| self.flows = [] | |
| for j in range(len(flows_new)): | |
| self.flows.append( | |
| resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx, | |
| interpolation=cv2.INTER_NEAREST)) | |
| else: | |
| self.flows = flows_new | |
| else: | |
| self.flows = [] | |
| Lz, Ly, Lx = self.NZ, self.Ly, self.Lx | |
| Lz0, Ly0, Lx0 = flows_new[0].shape[:3] | |
| print("GUI_INFO: resizing flows to original image size") | |
| for j in range(len(flows_new)): | |
| flow0 = flows_new[j] | |
| if Ly0 != Ly: | |
| flow0 = resize_image(flow0, Ly=Ly, Lx=Lx, | |
| no_channels=flow0.ndim==3, | |
| interpolation=cv2.INTER_NEAREST) | |
| if Lz0 != Lz: | |
| flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1), | |
| Ly=Lz, Lx=Lx, | |
| no_channels=flow0.ndim==3, | |
| interpolation=cv2.INTER_NEAREST), 0, 1) | |
| self.flows.append(flow0) | |
| # add first axis | |
| if self.NZ == 1: | |
| masks = masks[np.newaxis, ...] | |
| self.flows = [ | |
| self.flows[n][np.newaxis, ...] for n in range(len(self.flows)) | |
| ] | |
| self.logger.info("%d cells found with model in %0.3f sec" % | |
| (len(np.unique(masks)[1:]), time.time() - tic)) | |
| self.progress.setValue(80) | |
| z = 0 | |
| io._masks_to_gui(self, masks, outlines=None) | |
| self.masksOn = True | |
| self.MCheckBox.setChecked(True) | |
| self.progress.setValue(100) | |
| if self.restore != "filter" and self.restore is not None and self.autobtn.isChecked(): | |
| self.compute_saturation() | |
| if not do_3D and not stitch_threshold > 0: | |
| self.recompute_masks = True | |
| else: | |
| self.recompute_masks = False | |
| except Exception as e: | |
| print("ERROR: %s" % e) | |