"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
"""
from qtpy import QtGui, QtCore
from qtpy.QtGui import QPixmap, QDoubleValidator
from qtpy.QtWidgets import QWidget, QDialog, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox, QVBoxLayout
import pyqtgraph as pg
import numpy as np
import pathlib
import os


def stylesheet():
    """Return the Qt style sheet text used for the dark GUI theme.

    The sheet styles tooltips, combo boxes, scroll areas, group boxes and the
    pressed / not-pressed / disabled states of push buttons.
    """
    qss = """
        QToolTip {
            background-color: black;
            color: white;
            border: black solid 1px
        }
        QComboBox {color: white;
            background-color: rgb(40,40,40);}
        QComboBox::item:enabled { color: white;
            background-color: rgb(40,40,40);
            selection-color: white;
            selection-background-color: rgb(50,100,50);}
        QComboBox::item:!enabled {
            background-color: rgb(40,40,40);
            color: rgb(100,100,100);
        }
        QScrollArea > QWidget > QWidget
        {
            background: transparent;
            border: none;
            margin: 0px 0px 0px 0px;
        }
        QGroupBox
        { border: 1px solid white; color: rgb(255,255,255);
            border-radius: 6px;
            margin-top: 8px;
            padding: 0px 0px;}
        QPushButton:pressed {Text-align: center;
            background-color: rgb(150,50,150);
            border-color: white;
            color:white;}
        QToolTip {
            background-color: black;
            color: white;
            border: black solid 1px
        }
        QPushButton:!pressed {Text-align: center;
            background-color: rgb(50,50,50);
            border-color: white;
            color:white;}
        QToolTip {
            background-color: black;
            color: white;
            border: black solid 1px
        }
        QPushButton:disabled {Text-align: center;
            background-color: rgb(30,30,30);
            border-color: white;
            color:rgb(80,80,80);}
        QToolTip {
            background-color: black;
            color: white;
            border: black solid 1px
        }
        """
    return qss


class DarkPalette(QtGui.QPalette):
    """QPalette subclass pre-filled with dark colours for the application.

    (from pykilosort/kilosort4)
    """

    def __init__(self):
        super().__init__()
        self.setup()

    def setup(self):
        """Assign every colour role of the dark theme on this palette."""
        palette = QtGui.QPalette
        white = (255, 255, 255)
        disabled_grey = (128, 128, 128)

        # (role, rgb) pairs for the default colour group, applied in order.
        role_colours = (
            (palette.Window, (40, 40, 40)),
            (palette.WindowText, white),
            (palette.Base, (34, 27, 24)),
            (palette.AlternateBase, (53, 50, 47)),
            (palette.ToolTipBase, white),
            (palette.ToolTipText, white),
            (palette.Text, white),
            (palette.Button, (53, 50, 47)),
            (palette.ButtonText, white),
            (palette.BrightText, (255, 0, 0)),
            (palette.Link, (42, 130, 218)),
            (palette.Highlight, (42, 130, 218)),
            (palette.HighlightedText, (0, 0, 0)),
        )
        for role, rgb in role_colours:
            self.setColor(role, QtGui.QColor(*rgb))

        # Text-like roles are greyed out in the Disabled colour group.
        for role in (palette.Text, palette.ButtonText, palette.WindowText):
            self.setColor(palette.Disabled, role, QtGui.QColor(*disabled_grey))


# def create_channel_choose():
#     # choose channel
#     ChannelChoose = [QComboBox(), QComboBox()]
#     ChannelLabels = []
#     ChannelChoose[0].addItems(["gray", "red", "green", "blue"])
#     ChannelChoose[1].addItems(["none", "red", "green", "blue"])
#     cstr = ["chan to segment:", "chan2 (optional): "]
#     for i in range(2):
#         ChannelLabels.append(QLabel(cstr[i]))
#         if i == 0:
#             ChannelLabels[i].setToolTip(
#                 "this is the channel in which the cytoplasm or nuclei exist \
#             that you want to segment")
#             ChannelChoose[i].setToolTip(
#                 "this is the channel in which the cytoplasm or nuclei exist \
#             that you want to segment")
#         else:
#             ChannelLabels[i].setToolTip(
#                 "if cytoplasm model is chosen, and you also have a \
#             nuclear channel, then
#             choose the nuclear channel for this option")
#             ChannelChoose[i].setToolTip(
#                 "if cytoplasm model is chosen, and you also have a \
#             nuclear channel, then choose the nuclear channel for this option")
#     return ChannelChoose, ChannelLabels


class ModelButton(QPushButton):
    """Push button that asks the parent window to run segmentation.

    The button starts disabled. NOTE(review): the ``model_name`` argument is
    accepted but not used; both the stored attribute and the call made in
    ``press`` are fixed to ``"cpsam"``.
    """

    def __init__(self, parent, model_name, text):
        super().__init__()
        self.setEnabled(False)
        self.setText(text)
        self.setFont(parent.boldfont)
        self.clicked.connect(lambda: self.press(parent))
        self.model_name = "cpsam"

    def press(self, parent):
        """Run ``parent.compute_segmentation`` with the "cpsam" model."""
        parent.compute_segmentation(model_name="cpsam")


class FilterButton(QPushButton):
    """Push button that switches the parent's image filtering on or off.

    ``text`` is used both as the button label and as ``model_type``.
    """

    def __init__(self, parent, text):
        super().__init__()
        self.setEnabled(False)
        self.model_type = text
        self.setText(text)
        self.setFont(parent.medfont)
        self.clicked.connect(lambda: self.press(parent))

    def press(self, parent):
        """Apply filtering when this is the "filter" button, else clear it.

        For "filter", the parent's normalization parameters must have at
        least one of sharpen_radius / smooth_radius / tile_norm_blocksize
        non-zero; otherwise an error is printed and ``parent.restore`` is
        reset to None.
        """
        if self.model_type == "filter":
            parent.restore = "filter"
            normalize_params = parent.get_normalize_params()
            if (normalize_params["sharpen_radius"] == 0 and
                    normalize_params["smooth_radius"] == 0 and
                    normalize_params["tile_norm_blocksize"] == 0):
                print(
                    "GUI_ERROR: no filtering settings on (use custom filter settings)")
                parent.restore = None
                return
            parent.restore = self.model_type
            parent.compute_saturation()
        # elif self.model_type != "none":
        #     parent.compute_denoise_model(model_type=self.model_type)
        else:
            parent.clear_restore()
        # parent.set_restore_button()


class ObservableVariable(QtCore.QObject):
    """Value holder that emits ``valueChanged`` whenever the value changes.

    Arithmetic and comparison operators act on the wrapped value. Because
    ``__eq__`` is defined without ``__hash__``, instances are not hashable.
    """

    valueChanged = QtCore.Signal(object)

    def __init__(self, initial=None):
        super().__init__()
        self._value = initial

    def set(self, new_value):
        """ Use this method to emit the value change and update the ROI count"""
        # Only emit when the value actually differs from the stored one.
        if new_value != self._value:
            self._value = new_value
            self.valueChanged.emit(new_value)

    def get(self):
        """Return the wrapped value."""
        return self._value

    def __call__(self):
        return self._value

    def reset(self):
        """Set the value to 0 (emitting ``valueChanged`` if it was not 0)."""
        self.set(0)

    def __iadd__(self, amount):
        if not isinstance(amount, (int, float)):
            raise TypeError("Value must be numeric.")
        self.set(self._value + amount)
        return self

    def __radd__(self, other):
        return other + self._value

    def __add__(self, other):
        return other + self._value

    def __isub__(self, amount):
        if not isinstance(amount, (int, float)):
            raise TypeError("Value must be numeric.")
        self.set(self._value - amount)
        return self

    def __str__(self):
        return str(self._value)

    def __lt__(self, x):
        return self._value < x

    def __gt__(self, x):
        return self._value > x

    def __eq__(self, x):
        return self._value == x


class NormalizationSettings(QWidget):
    # TODO
    pass


class SegmentationSettings(QWidget):
    """ Container for gui settings.

    The percentile properties fall back to their defaults when the text box
    cannot be parsed, and ``diameter`` returns None in that case. The other
    properties convert the box text directly and raise ``ValueError`` on
    text that is not a number.
    """

    def __init__(self, font):
        super().__init__()

        # Put everything in a grid layout:
        grid_layout = QGridLayout()
        widget_container = QWidget()
        widget_container.setLayout(grid_layout)
        row = 0

        ########################### Diameter ###########################
        # TODO: Validate inputs
        diam_qlabel = QLabel("diameter:")
        diam_qlabel.setToolTip("diameter of cells in pixels. If not 30, image will be resized to this")
        diam_qlabel.setFont(font)
        grid_layout.addWidget(diam_qlabel, row, 0, 1, 2)
        self.diameter_box = QLineEdit()
        self.diameter_box.setToolTip("diameter of cells in pixels. If not blank, image will be resized relative to 30 pixel cell diameters")
        self.diameter_box.setFont(font)
        self.diameter_box.setFixedWidth(40)
        self.diameter_box.setText(' ')
        grid_layout.addWidget(self.diameter_box, row, 2, 1, 2)
        row += 1

        ########################### Flow threshold ###########################
        # TODO: Validate inputs
        flow_tooltip = "threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run"
        flow_threshold_qlabel = QLabel("flow\nthreshold:")
        flow_threshold_qlabel.setToolTip(flow_tooltip)
        flow_threshold_qlabel.setFont(font)
        grid_layout.addWidget(flow_threshold_qlabel, row, 0, 1, 2)
        self.flow_threshold_box = QLineEdit()
        self.flow_threshold_box.setText("0.4")
        self.flow_threshold_box.setFixedWidth(40)
        self.flow_threshold_box.setFont(font)
        grid_layout.addWidget(self.flow_threshold_box, row, 2, 1, 2)
        self.flow_threshold_box.setToolTip(flow_tooltip)

        ########################### Cellprob threshold ###########################
        # TODO: Validate inputs
        # Shares the grid row with the flow threshold (columns 4-7).
        cellprob_tooltip = "threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run"
        cellprob_qlabel = QLabel("cellprob\nthreshold:")
        cellprob_qlabel.setToolTip(cellprob_tooltip)
        cellprob_qlabel.setFont(font)
        grid_layout.addWidget(cellprob_qlabel, row, 4, 1, 2)
        self.cellprob_threshold_box = QLineEdit()
        self.cellprob_threshold_box.setText("0.0")
        self.cellprob_threshold_box.setFixedWidth(40)
        self.cellprob_threshold_box.setFont(font)
        self.cellprob_threshold_box.setToolTip(cellprob_tooltip)
        grid_layout.addWidget(self.cellprob_threshold_box, row, 6, 1, 2)
        row += 1

        ########################### Norm percentiles ###########################
        norm_percentiles_qlabel = QLabel("norm percentiles:")
        norm_percentiles_qlabel.setToolTip("sets normalization percentiles for segmentation and denoising\n(pixels at lower percentile set to 0.0 and at upper set to 1.0 for network)")
        norm_percentiles_qlabel.setFont(font)
        grid_layout.addWidget(norm_percentiles_qlabel, row, 0, 1, 8)
        row += 1

        validator = QDoubleValidator(0.0, 100.0, 2)
        validator.setNotation(QDoubleValidator.StandardNotation)

        low_tooltip = "pixels at this percentile set to 0 (default 1.0)"
        low_norm_qlabel = QLabel('lower:')
        low_norm_qlabel.setToolTip(low_tooltip)
        low_norm_qlabel.setFont(font)
        grid_layout.addWidget(low_norm_qlabel, row, 0, 1, 2)
        self.norm_percentile_low_box = QLineEdit()
        self.norm_percentile_low_box.setText("1.0")
        self.norm_percentile_low_box.setFont(font)
        self.norm_percentile_low_box.setFixedWidth(40)
        self.norm_percentile_low_box.setToolTip(low_tooltip)
        self.norm_percentile_low_box.setValidator(validator)
        self.norm_percentile_low_box.editingFinished.connect(self.validate_normalization_range)
        grid_layout.addWidget(self.norm_percentile_low_box, row, 2, 1, 1)

        high_tooltip = "pixels at this percentile set to 1 (default 99.0)"
        high_norm_qlabel = QLabel('upper:')
        high_norm_qlabel.setToolTip(high_tooltip)
        high_norm_qlabel.setFont(font)
        grid_layout.addWidget(high_norm_qlabel, row, 4, 1, 2)
        self.norm_percentile_high_box = QLineEdit()
        self.norm_percentile_high_box.setText("99.0")
        self.norm_percentile_high_box.setFont(font)
        self.norm_percentile_high_box.setFixedWidth(40)
        self.norm_percentile_high_box.setToolTip(high_tooltip)
        self.norm_percentile_high_box.setValidator(validator)
        self.norm_percentile_high_box.editingFinished.connect(self.validate_normalization_range)
        grid_layout.addWidget(self.norm_percentile_high_box, row, 6, 1, 2)
        row += 1

        ########################### niter ###########################
        # TODO: change this to follow the same default logic as 'diameter' above
        # TODO: input validation
        niter_tooltip = "number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria"
        niter_qlabel = QLabel("niter dynamics:")
        niter_qlabel.setFont(font)
        niter_qlabel.setToolTip(niter_tooltip)
        grid_layout.addWidget(niter_qlabel, row, 0, 1, 4)
        self.niter_box = QLineEdit()
        self.niter_box.setText("0")
        self.niter_box.setFixedWidth(40)
        self.niter_box.setFont(font)
        self.niter_box.setToolTip(niter_tooltip)
        grid_layout.addWidget(self.niter_box, row, 4, 1, 2)

        self.setLayout(grid_layout)

    @staticmethod
    def _percentile_or_default(box, default):
        """Return the float typed in ``box``; on blank or non-numeric text,
        write ``default`` back into the box and return it."""
        try:
            return float(box.text())
        except ValueError:
            # float() rejects empty, whitespace-only and malformed text alike.
            box.setText(str(default))
            return default

    def validate_normalization_range(self):
        """Fill in blank percentile boxes and flag an invalid range.

        Each box is defaulted independently (lower -> 1.0, upper -> 99.0).
        Both boxes get a red border while lower >= upper, and the border is
        cleared once the range is valid.
        """
        low = self._percentile_or_default(self.norm_percentile_low_box, 1.0)
        high = self._percentile_or_default(self.norm_percentile_high_box, 99.0)

        if low >= high:
            # Invalid: show error and mark fields
            self.norm_percentile_low_box.setStyleSheet("border: 1px solid red;")
            self.norm_percentile_high_box.setStyleSheet("border: 1px solid red;")
        else:
            # Valid: clear style
            self.norm_percentile_low_box.setStyleSheet("")
            self.norm_percentile_high_box.setStyleSheet("")

    @property
    def low_percentile(self):
        """ Also validate the low input by returning 1.0 if text doesn't work """
        return self._percentile_or_default(self.norm_percentile_low_box, 1.0)

    @property
    def high_percentile(self):
        """ Also validate the high input by returning 99.0 if text doesn't work """
        return self._percentile_or_default(self.norm_percentile_high_box, 99.0)

    @property
    def diameter(self):
        """ Get the diameter from the diameter box, if box isn't a number return None"""
        try:
            d = float(self.diameter_box.text())
        except ValueError:
            d = None
        return d

    @property
    def flow_threshold(self):
        """Flow threshold box text as a float (ValueError if not a number)."""
        return float(self.flow_threshold_box.text())

    @property
    def cellprob_threshold(self):
        """Cellprob threshold box text as a float (ValueError if not a number)."""
        return float(self.cellprob_threshold_box.text())

    @property
    def niter(self):
        """Number of dynamics iterations typed in the niter box.

        Values below 1 are replaced by 200, which is also written back to the
        box. Raises ValueError if the text is not an integer.
        NOTE(review): the tooltip says 0 uses a diameter-based default, but
        this property returns a fixed 200 - confirm which is intended.
        """
        num = int(self.niter_box.text())
        if num < 1:
            self.niter_box.setText('200')
            return 200
        else:
            return num


class TrainWindow(QDialog):
    """Dialog for choosing training parameters.

    Shows the initial-model chooser, one text field per training parameter,
    and up to ten training file names with their mask counts. On OK the
    chosen values are stored in ``parent.training_params``.
    """

    def __init__(self, parent, model_strings):
        super().__init__(parent)
        self.setGeometry(100, 100, 900, 550)
        self.setWindowTitle("train settings")
        self.win = QWidget(self)
        self.l0 = QGridLayout()
        self.win.setLayout(self.l0)

        yoff = 0
        qlabel = QLabel("train model w/ images + _seg.npy in current folder >>")
        qlabel.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))

        qlabel.setAlignment(QtCore.Qt.AlignVCenter)
        self.l0.addWidget(qlabel, yoff, 0, 1, 2)

        # choose initial model
        yoff += 1
        self.ModelChoose = QComboBox()
        self.ModelChoose.addItems(model_strings)
        self.ModelChoose.setFixedWidth(150)
        self.ModelChoose.setCurrentIndex(parent.training_params["model_index"])
        self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1)
        qlabel = QLabel("initial model: ")
        qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
        self.l0.addWidget(qlabel, yoff, 0, 1, 1)

        # choose parameters
        labels = ["learning_rate", "weight_decay", "n_epochs", "model_name"]
        self.edits = []
        yoff += 1
        for i, label in enumerate(labels):
            qlabel = QLabel(label)
            qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
            self.l0.addWidget(qlabel, i + yoff, 0, 1, 1)
            self.edits.append(QLineEdit())
            self.edits[-1].setText(str(parent.training_params[label]))
            self.edits[-1].setFixedWidth(200)
            self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1)

        yoff += len(labels)

        yoff += 1
        # NOTE(review): this checkbox is created but never added to the
        # layout, and accept() does not read it.
        self.use_norm = QCheckBox("use restored/filtered image")
        self.use_norm.setChecked(True)

        yoff += 2
        qlabel = QLabel(
            "(to remove files, click cancel then remove \nfrom folder and reopen train window)"
        )
        self.l0.addWidget(qlabel, yoff, 0, 2, 4)

        # click button
        yoff += 3
        QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
        self.buttonBox = QDialogButtonBox(QBtn)
        self.buttonBox.accepted.connect(lambda: self.accept(parent))
        self.buttonBox.rejected.connect(self.reject)
        self.l0.addWidget(self.buttonBox, yoff, 0, 1, 4)

        # list files in folder
        qlabel = QLabel("filenames")
        qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
        self.l0.addWidget(qlabel, 0, 4, 1, 1)
        qlabel = QLabel("# of masks")
        qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
        self.l0.addWidget(qlabel, 0, 5, 1, 1)

        # At most ten rows; with more than ten files the tenth row is "...".
        for i in range(10):
            if i > len(parent.train_files) - 1:
                break
            elif i == 9 and len(parent.train_files) > 10:
                label = "..."
                nmasks = "..."
            else:
                label = os.path.split(parent.train_files[i])[-1]
                nmasks = str(parent.train_labels[i].max())
            qlabel = QLabel(label)
            self.l0.addWidget(qlabel, i + 1, 4, 1, 1)
            qlabel = QLabel(nmasks)
            qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
            self.l0.addWidget(qlabel, i + 1, 5, 1, 1)

    def accept(self, parent):
        """Store the dialog values in ``parent.training_params`` and close.

        Raises ValueError if a numeric field does not parse.
        """
        # set training params
        parent.training_params = {
            "model_index": self.ModelChoose.currentIndex(),
            "learning_rate": float(self.edits[0].text()),
            "weight_decay": float(self.edits[1].text()),
            "n_epochs": int(self.edits[2].text()),
            "model_name": self.edits[3].text(),
            #"use_norm": True if self.use_norm.isChecked() else False,
        }
        self.done(1)


class ExampleGUI(QDialog):
    """Dialog showing the GUI layout image stored under ``~/.cellpose``."""

    def __init__(self, parent=None):
        super(ExampleGUI, self).__init__(parent)
        self.setGeometry(100, 100, 1300, 900)
        self.setWindowTitle("GUI layout")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)
        guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png")
        guip_path = str(guip_path.resolve())
        pixmap = QPixmap(guip_path)
        label = QLabel(self)
        label.setPixmap(pixmap)
        layout.addWidget(label, 0, 0, 1, 1)


class HelpWindow(QDialog):
    """Dialog showing the contents of ``guihelpwindowtext.html`` located
    next to this module; the dialog shows itself on construction."""

    def __init__(self, parent=None):
        super(HelpWindow, self).__init__(parent)
        self.setGeometry(100, 50, 700, 1000)
        self.setWindowTitle("cellpose help")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)

        text_file = pathlib.Path(__file__).parent.joinpath("guihelpwindowtext.html")
        with open(str(text_file.resolve()), "r") as f:
            text = f.read()

        label = QLabel(text)
        label.setFont(QtGui.QFont("Arial", 8))
        label.setWordWrap(True)
        layout.addWidget(label, 0, 0, 1, 1)
        self.show()


class TrainHelpWindow(QDialog):
    """Dialog showing the contents of ``guitrainhelpwindowtext.html``
    located next to this module; the dialog shows itself on construction."""

    def __init__(self, parent=None):
        super(TrainHelpWindow, self).__init__(parent)
        self.setGeometry(100, 50, 700, 300)
        self.setWindowTitle("training instructions")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)

        text_file = pathlib.Path(__file__).parent.joinpath(
            "guitrainhelpwindowtext.html")
        with open(str(text_file.resolve()), "r") as f:
            text = f.read()

        label = QLabel(text)
        label.setFont(QtGui.QFont("Arial", 8))
        label.setWordWrap(True)
        layout.addWidget(label, 0, 0, 1, 1)
        self.show()


class ViewBoxNoRightDrag(pg.ViewBox):
    """ViewBox with keyboard zoom on the ``-``, ``+`` and ``=`` keys.

    ``parent`` is stored on the instance; ``None`` is what gets forwarded as
    the first argument of ``pg.ViewBox.__init__``.
    """

    def __init__(self, parent=None, border=None, lockAspect=False, enableMouse=True,
                 invertY=False, enableMenu=True, name=None, invertX=False):
        pg.ViewBox.__init__(self, None, border, lockAspect, enableMouse, invertY,
                            enableMenu, name, invertX)
        self.parent = parent
        self.axHistoryPointer = -1

    def keyPressEvent(self, ev):
        """
        This routine should capture key presses in the current view box.
        The following events are implemented:
        +/= : calls scaleBy([0.9, 0.9])
        - : calls scaleBy([1.1, 1.1])
        Any other key is ignored so it can be handled elsewhere.
        """
        ev.accept()
        if ev.text() == "-":
            self.scaleBy([1.1, 1.1])
        elif ev.text() in ["+", "="]:
            self.scaleBy([0.9, 0.9])
        else:
            ev.ignore()


class ImageDraw(pg.ImageItem):
    """ImageItem used for drawing strokes and selecting masks with the mouse.

    All GUI state is read from and written to ``parent`` (e.g.
    ``current_stroke``, ``in_stroke``, ``brush_size``, ``cellpix``).
    Right-click or shift-click starts or ends a stroke; left-click selects,
    deletes (ctrl) or merges (alt) the mask under the cursor.
    """

    sigImageChanged = QtCore.Signal()

    def __init__(self, image=None, viewbox=None, parent=None, **kargs):
        super(ImageDraw, self).__init__()
        self.levels = np.array([0, 255])
        self.lut = None
        self.autoDownsample = False
        self.axisOrder = "row-major"
        self.removable = False

        self.parent = parent
        self.setDrawKernel(kernel_size=self.parent.brush_size)
        self.parent.current_stroke = []
        self.parent.in_stroke = False

    def mouseClickEvent(self, ev):
        """Start/end a stroke or act on the clicked mask.

        Nothing happens unless masks or outlines are shown and no region is
        being removed.
        """
        if (self.parent.masksOn or
                self.parent.outlinesOn) and not self.parent.removing_region:
            is_right_click = ev.button() == QtCore.Qt.RightButton
            if (self.parent.loaded
                    and (is_right_click or ev.modifiers() & QtCore.Qt.ShiftModifier and not ev.double())
                    and not self.parent.deleting_multiple):
                if not self.parent.in_stroke:
                    ev.accept()
                    self.create_start(ev.pos())
                    self.parent.stroke_appended = False
                    self.parent.in_stroke = True
                    self.drawAt(ev.pos(), ev)
                else:
                    ev.accept()
                    self.end_stroke()
                    self.parent.in_stroke = False
            elif not self.parent.in_stroke:
                y, x = int(ev.pos().y()), int(ev.pos().x())
                if y >= 0 and y < self.parent.Ly and x >= 0 and x < self.parent.Lx:
                    if ev.button() == QtCore.Qt.LeftButton and not ev.double():
                        idx = self.parent.cellpix[self.parent.currentZ][y, x]
                        if idx > 0:
                            if ev.modifiers() & QtCore.Qt.ControlModifier:
                                # delete mask selected
                                self.parent.remove_cell(idx)
                            elif ev.modifiers() & QtCore.Qt.AltModifier:
                                self.parent.merge_cells(idx)
                            elif self.parent.masksOn and not self.parent.deleting_multiple:
                                self.parent.unselect_cell()
                                self.parent.select_cell(idx)
                            elif self.parent.deleting_multiple:
                                # toggle this mask in the multi-delete selection
                                if idx in self.parent.removing_cells_list:
                                    self.parent.unselect_cell_multi(idx)
                                    self.parent.removing_cells_list.remove(idx)
                                else:
                                    self.parent.select_cell_multi(idx)
                                    self.parent.removing_cells_list.append(idx)
                        elif self.parent.masksOn and not self.parent.deleting_multiple:
                            # click landed where cellpix is not positive
                            self.parent.unselect_cell()

    def mouseDragEvent(self, ev):
        """Ignore drag events."""
        ev.ignore()
        return

    def hoverEvent(self, ev):
        """Extend the active stroke while hovering; end it at the start point.

        When no stroke is active, right-button clicks are accepted on the
        hover event instead.
        """
        if self.parent.in_stroke:
            # continue stroke if not at start
            self.drawAt(ev.pos())
            if self.is_at_start(ev.pos()):
                self.end_stroke()
        else:
            ev.acceptClicks(QtCore.Qt.RightButton)

    def create_start(self, pos):
        """Add a red scatter marker at ``pos`` to ``parent.p0``."""
        self.scatter = pg.ScatterPlotItem([pos.x()], [pos.y()], pxMode=False,
                                          pen=pg.mkPen(color=(255, 0, 0),
                                                       width=self.parent.brush_size),
                                          size=max(3 * 2,
                                                   self.parent.brush_size * 1.8 * 2),
                                          brush=None)
        self.parent.p0.addItem(self.scatter)

    def is_at_start(self, pos):
        """Return True once the stroke has left its first entry and come back.

        Distances are measured from the first entry of
        ``parent.current_stroke`` to every later entry, using columns 1
        onward of each ``[z, x, y, iscent]`` entry. NOTE(review): that
        includes the centre-flag column in the distance - confirm intended.
        ``pos`` is not used. Strokes with three or fewer entries return
        False.
        """
        thresh_out = max(6, self.parent.brush_size * 3)
        thresh_in = max(3, self.parent.brush_size * 1.8)
        # first check if you ever left the start
        if len(self.parent.current_stroke) <= 3:
            return False
        stroke = np.array(self.parent.current_stroke)
        dist = (((stroke[1:, 1:] -
                  stroke[:1, 1:][np.newaxis, :, :])**2).sum(axis=-1))**0.5
        dist = dist.flatten()
        has_left = (dist > thresh_out).nonzero()[0]
        if len(has_left) == 0:
            return False
        first_left = np.sort(has_left)[0]
        # only entries after the first departure (and past index 4) count
        has_returned = (dist[max(4, first_left + 1):] < thresh_in).sum()
        return bool(has_returned > 0)

    def end_stroke(self):
        """Finish the current stroke and hand its centre points to the parent.

        Removes the start marker. If the stroke has not been appended yet,
        it is stored in ``parent.strokes`` and the entries whose column 3
        equals 1 are appended to ``parent.current_point_set``.
        ``parent.add_set()`` is called when autosave is on.
        """
        self.parent.p0.removeItem(self.scatter)
        if not self.parent.stroke_appended:
            self.parent.strokes.append(self.parent.current_stroke)
            self.parent.stroke_appended = True
            self.parent.current_stroke = np.array(self.parent.current_stroke)
            ioutline = self.parent.current_stroke[:, 3] == 1
            self.parent.current_point_set.append(
                list(self.parent.current_stroke[ioutline]))
            self.parent.current_stroke = []
            if self.parent.autosave:
                self.parent.add_set()
        if len(self.parent.current_point_set) and len(
                self.parent.current_point_set[0]) > 0 and self.parent.autosave:
            self.parent.add_set()
        self.parent.in_stroke = False

    def tabletEvent(self, ev):
        pass

    def drawAt(self, pos, ev=None):
        """Stamp the stroke mask at ``pos`` and record the touched pixels.

        The draw kernel is clipped against the image borders
        (``parent.Ly`` / ``parent.Lx``). Every touched pixel is appended to
        ``parent.current_stroke`` as ``[currentZ, x, y, iscent]``, where
        ``iscent`` marks the kernel-centre pixel.
        """
        mask = self.strokemask
        stroke = self.parent.current_stroke
        pos = [int(pos.y()), int(pos.x())]
        dk = self.drawKernel
        kc = self.drawKernelCenter
        sx = [0, dk.shape[0]]
        sy = [0, dk.shape[1]]
        tx = [pos[0] - kc[0], pos[0] - kc[0] + dk.shape[0]]
        ty = [pos[1] - kc[1], pos[1] - kc[1] + dk.shape[1]]
        kcent = kc.copy()
        # NOTE(review): ``tx = sx`` / ``ty = sy`` bind both names to the same
        # list, so later writes through tx/ty in the far-edge branches also
        # change sx/sy. Left as-is to preserve behaviour.
        if tx[0] <= 0:
            sx[0] = 0
            sx[1] = kc[0] + 1
            tx = sx
            kcent[0] = 0
        if ty[0] <= 0:
            sy[0] = 0
            sy[1] = kc[1] + 1
            ty = sy
            kcent[1] = 0
        if tx[1] >= self.parent.Ly - 1:
            sx[0] = dk.shape[0] - kc[0] - 1
            sx[1] = dk.shape[0]
            tx[0] = self.parent.Ly - kc[0] - 1
            tx[1] = self.parent.Ly
            kcent[0] = tx[1] - tx[0] - 1
        if ty[1] >= self.parent.Lx - 1:
            sy[0] = dk.shape[1] - kc[1] - 1
            sy[1] = dk.shape[1]
            ty[0] = self.parent.Lx - kc[1] - 1
            ty[1] = self.parent.Lx
            kcent[1] = ty[1] - ty[0] - 1

        ts = (slice(tx[0], tx[1]), slice(ty[0], ty[1]))
        ss = (slice(sx[0], sx[1]), slice(sy[0], sy[1]))
        self.image[ts] = mask[ss]

        for ky, y in enumerate(np.arange(ty[0], ty[1], 1, int)):
            for kx, x in enumerate(np.arange(tx[0], tx[1], 1, int)):
                iscent = np.logical_and(kx == kcent[0], ky == kcent[1])
                stroke.append([self.parent.currentZ, x, y, iscent])
        self.updateImage()

    def setDrawKernel(self, kernel_size=3):
        """Build a square all-ones brush kernel and its RGBA masks.

        Sets ``drawKernel``, ``drawKernelCenter``, ``redmask`` and
        ``strokemask``; the masks have shape (kernel_size, kernel_size, 4).
        """
        bs = kernel_size
        kernel = np.ones((bs, bs), np.uint8)
        self.drawKernel = kernel
        self.drawKernelCenter = [
            int(np.floor(kernel.shape[0] / 2)),
            int(np.floor(kernel.shape[1] / 2))
        ]
        onmask = 255 * kernel[:, :, np.newaxis]
        offmask = np.zeros((bs, bs, 1))
        opamask = 100 * kernel[:, :, np.newaxis]
        self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1)
        self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1)