Spaces:
No application file
No application file
| # -*- coding: utf-8 -*- | |
| import sys | |
| import time | |
| from PyQt5.QtGui import ( | |
| QBrush, | |
| QPainter, | |
| QPen, | |
| QPixmap, | |
| QKeySequence, | |
| QPen, | |
| QBrush, | |
| QColor, | |
| QImage, | |
| ) | |
| from PyQt5.QtWidgets import ( | |
| QFileDialog, | |
| QApplication, | |
| QGraphicsEllipseItem, | |
| QGraphicsItem, | |
| QGraphicsRectItem, | |
| QGraphicsScene, | |
| QGraphicsView, | |
| QGraphicsPixmapItem, | |
| QHBoxLayout, | |
| QPushButton, | |
| QSlider, | |
| QVBoxLayout, | |
| QWidget, | |
| QShortcut, | |
| ) | |
| import numpy as np | |
| from skimage import transform, io | |
| import torch | |
| import torch.nn as nn | |
| from torch.nn import functional as F | |
| from PIL import Image | |
| from segment_anything import sam_model_registry | |
# Pin every random number generator used here so repeated runs match.
np.random.seed(2023)
torch.manual_seed(2023)
torch.cuda.empty_cache()
torch.cuda.manual_seed(2023)

# Model variant, checkpoint location and the square input side length.
SAM_MODEL_TYPE = "vit_b"
MedSAM_CKPT_PATH = "work_dir/MedSAM/medsam_vit_b.pth"
MEDSAM_IMG_INPUT_SIZE = 1024

# Device preference order: Apple MPS, then the first CUDA GPU, then CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, height, width):
    """Predict a binary mask for one box prompt from a precomputed embedding.

    Runs with autograd disabled: the prompt encoder and mask decoder own
    trainable parameters, so without ``no_grad`` the output would require
    grad and the final ``.numpy()`` conversion would raise.

    Args:
        medsam_model: Object exposing ``prompt_encoder`` (callable, with a
            ``get_dense_pe()`` method) and ``mask_decoder`` (callable).
        img_embed: Image embedding tensor; the box tensor is created on the
            same device.
        box_1024: Box prompt(s) as ``[x_min, y_min, x_max, y_max]`` in the
            model's input coordinate frame; a 2-D ``(B, 4)`` array is
            expanded to ``(B, 1, 4)``.
        height: Output mask height in pixels.
        width: Output mask width in pixels.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` holding 0/1 values, resized to
        ``(height, width)`` (after squeezing singleton dimensions).
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if box_torch.dim() == 2:
        box_torch = box_torch[:, None, :]  # (B, 1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )

    # Convert logits to probabilities, then resample to the requested size.
    low_res_pred = torch.sigmoid(low_res_logits)
    low_res_pred = F.interpolate(
        low_res_pred,
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    )
    low_res_pred = low_res_pred.squeeze().cpu().numpy()  # (height, width)

    # Threshold probabilities at 0.5 to obtain the binary mask.
    medsam_seg = (low_res_pred > 0.5).astype(np.uint8)
    return medsam_seg
print("Loading MedSAM model, a sec.")
tic = time.perf_counter()

# Build the variant named by SAM_MODEL_TYPE (instead of repeating the literal)
# from the checkpoint, move it to the selected device and switch to eval mode.
medsam_model = sam_model_registry[SAM_MODEL_TYPE](checkpoint=MedSAM_CKPT_PATH).to(device)
medsam_model.eval()

print(f"Done, took {time.perf_counter() - tic}")
def np2pixmap(np_img):
    """Convert an ``(H, W, C)`` image array into a ``QPixmap``.

    Only the first three channels are used, so an alpha channel is ignored.
    The pixel data is made C-contiguous before being handed to ``QImage``,
    and the row stride is taken from the array itself rather than assumed.

    Args:
        np_img: Three-dimensional array with at least three channels.
            NOTE(review): values are presumably uint8 RGB, as required by
            ``QImage.Format_RGB888`` - confirm at call sites.

    Returns:
        A ``QPixmap`` built from the RGB data.
    """
    height, width, _ = np_img.shape
    # A sliced or transposed view has no contiguous buffer to expose; this is
    # a no-op for arrays that are already contiguous 3-channel data.
    rgb = np.ascontiguousarray(np_img[:, :, :3])
    bytes_per_line = rgb.strides[0]
    q_img = QImage(rgb.data, width, height, bytes_per_line, QImage.Format_RGB888)
    # fromImage copies the pixels, so `rgb` only has to outlive this call.
    return QPixmap.fromImage(q_img)
# RGB triples used to paint successive masks; indexed modulo the list length.
colors = [
    # full-intensity primaries and secondaries
    (255, 0, 0), (0, 255, 0), (0, 0, 255),
    (255, 255, 0), (255, 0, 255), (0, 255, 255),
    # half-intensity variants
    (128, 0, 0), (0, 128, 0), (0, 0, 128),
    (128, 128, 0), (128, 0, 128), (0, 128, 128),
    # white and greys
    (255, 255, 255), (192, 192, 192), (64, 64, 64),
    # NOTE(review): these three repeat earlier entries (indices 4, 5 and 3),
    # so two different masks can share a colour - confirm this is intended.
    (255, 0, 255), (0, 255, 255), (255, 255, 0),
    (0, 0, 127), (192, 0, 192),
]
class Window(QWidget):
    """Main window for interactive box-prompted segmentation.

    The user loads an image, drags a rectangle over it, and on mouse release
    the box is sent to ``medsam_inference``. The resulting mask is painted
    into an RGB mask image with the next palette colour and shown blended
    over the original image. Ctrl+Z reverts the most recent mask (one level),
    Ctrl+Q quits, and "Save Mask" writes the mask image next to the source.
    """

    def __init__(self):
        """Build the widgets, ask for the first image and wire up signals."""
        super().__init__()

        # configs
        self.half_point_size = 5  # radius of bbox starting and ending points

        # app stats
        self.image_path = None
        self.img_3c = None  # current image as an (H, W, 3) array
        self.mask_c = None  # RGB mask image, same height/width as img_3c
        self.scene = None
        self.color_idx = 0
        self.bg_img = None
        self.is_mouse_down = False
        self.rect = None
        self.point_size = self.half_point_size * 2
        self.start_point = None
        self.end_point = None
        self.start_pos = (None, None)
        self.embedding = None
        self.prev_mask = None

        self.view = QGraphicsView()
        self.view.setRenderHint(QPainter.Antialiasing)

        # Prompts for a file; exits the process if nothing is chosen here.
        self.load_image()

        # Passing `self` installs vbox as this widget's layout.
        vbox = QVBoxLayout(self)
        vbox.addWidget(self.view)

        load_button = QPushButton("Load Image")
        save_button = QPushButton("Save Mask")

        # No parent here: the widget already owns vbox, and addLayout below
        # reparents the row. Passing `self` makes Qt emit a layout warning.
        hbox = QHBoxLayout()
        hbox.addWidget(load_button)
        hbox.addWidget(save_button)
        vbox.addLayout(hbox)

        # keyboard shortcuts
        self.quit_shortcut = QShortcut(QKeySequence("Ctrl+Q"), self)
        # Stop the Qt event loop rather than relying on the `quit` builtin.
        self.quit_shortcut.activated.connect(QApplication.instance().quit)

        self.undo_shortcut = QShortcut(QKeySequence("Ctrl+Z"), self)
        self.undo_shortcut.activated.connect(self.undo)

        load_button.clicked.connect(self.load_image)
        save_button.clicked.connect(self.save_mask)

    def _show_blend(self, mask_rgb):
        """Replace the background pixmap with the image blended with *mask_rgb*."""
        bg = Image.fromarray(self.img_3c.astype("uint8"), "RGB")
        mask = Image.fromarray(mask_rgb.astype("uint8"), "RGB")
        img = Image.blend(bg, mask, 0.2)

        self.scene.removeItem(self.bg_img)
        self.bg_img = self.scene.addPixmap(np2pixmap(np.array(img)))

    def _add_marker(self, x, y):
        """Add a filled red dot centred on scene position (x, y); return the item."""
        return self.scene.addEllipse(
            x - self.half_point_size,
            y - self.half_point_size,
            self.point_size,
            self.point_size,
            pen=QPen(QColor("red")),
            brush=QBrush(QColor("red")),
        )

    def undo(self):
        """Restore the mask saved before the most recent segmentation.

        Only one level of history is kept; a second consecutive undo prints a
        message and does nothing.
        """
        if self.prev_mask is None:
            print("No previous mask record")
            return

        self.color_idx -= 1
        self._show_blend(self.prev_mask)

        self.mask_c = self.prev_mask
        self.prev_mask = None

    def load_image(self):
        """Ask for an image file, load it and reset the per-image state.

        If the dialog is cancelled while an image is already open, the
        current session is left untouched. If it is cancelled before any
        image was loaded, the process exits.
        """
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Choose Image to Segment", ".", "Image Files (*.png *.jpg *.bmp)"
        )

        if not file_path:
            if self.image_path is not None:
                # Cancelled a reload: keep working on the current image.
                return
            print("No image path specified, plz select an image")
            sys.exit()

        img_np = io.imread(file_path)
        if img_np.ndim == 2:
            # Grayscale: replicate the single channel three times.
            img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
        else:
            # Drop any alpha channel; the rest of the pipeline expects RGB.
            img_3c = np.ascontiguousarray(img_np[:, :, :3])

        self.img_3c = img_3c
        self.image_path = file_path
        self.get_embeddings()
        pixmap = np2pixmap(self.img_3c)

        H, W, _ = self.img_3c.shape

        self.scene = QGraphicsScene(0, 0, W, H)
        # Items from the previous scene must not be removed from the new one,
        # and mask history from another image must not be undone onto this one.
        self.start_point = None
        self.end_point = None
        self.rect = None
        self.is_mouse_down = False
        self.prev_mask = None
        self.color_idx = 0

        self.bg_img = self.scene.addPixmap(pixmap)
        self.bg_img.setPos(0, 0)
        self.mask_c = np.zeros((*self.img_3c.shape[:2], 3), dtype="uint8")
        self.view.setScene(self.scene)

        # events
        self.scene.mousePressEvent = self.mouse_press
        self.scene.mouseMoveEvent = self.mouse_move
        self.scene.mouseReleaseEvent = self.mouse_release

    def mouse_press(self, ev):
        """Record the drag origin and mark it with a red dot."""
        x, y = ev.scenePos().x(), ev.scenePos().y()
        self.is_mouse_down = True
        self.start_pos = (x, y)

        # Drop the marker from the previous drag so items do not pile up.
        if self.start_point is not None:
            self.scene.removeItem(self.start_point)
        self.start_point = self._add_marker(x, y)

    def mouse_move(self, ev):
        """While dragging, redraw the end marker and the rubber-band box."""
        if not self.is_mouse_down:
            return

        x, y = ev.scenePos().x(), ev.scenePos().y()

        if self.end_point is not None:
            self.scene.removeItem(self.end_point)
        self.end_point = self._add_marker(x, y)

        if self.rect is not None:
            self.scene.removeItem(self.rect)
        sx, sy = self.start_pos
        xmin, xmax = min(x, sx), max(x, sx)
        ymin, ymax = min(y, sy), max(y, sy)
        self.rect = self.scene.addRect(
            xmin, ymin, xmax - xmin, ymax - ymin, pen=QPen(QColor("red"))
        )

    def mouse_release(self, ev):
        """Segment inside the dragged box and overlay the new mask."""
        if not self.is_mouse_down:
            # Release without a matching press: there is no box to use.
            return
        self.is_mouse_down = False

        H, W, _ = self.img_3c.shape

        # Scene positions can fall outside the image; clamp them to it.
        x = min(max(ev.scenePos().x(), 0), W)
        y = min(max(ev.scenePos().y(), 0), H)
        sx = min(max(self.start_pos[0], 0), W)
        sy = min(max(self.start_pos[1], 0), H)

        xmin, xmax = min(x, sx), max(x, sx)
        ymin, ymax = min(y, sy), max(y, sy)

        box_np = np.array([[xmin, ymin, xmax, ymax]])
        # Rescale from image pixels to the model's square input frame.
        box_1024 = box_np / np.array([W, H, W, H]) * MEDSAM_IMG_INPUT_SIZE

        sam_mask = medsam_inference(medsam_model, self.embedding, box_1024, H, W)

        # Keep a copy so this step can be undone.
        self.prev_mask = self.mask_c.copy()
        self.mask_c[sam_mask != 0] = colors[self.color_idx % len(colors)]
        self.color_idx += 1

        self._show_blend(self.mask_c)

    def save_mask(self):
        """Write the RGB mask image to ``<image path without extension>_mask.png``."""
        import os

        if self.image_path is None or self.mask_c is None:
            print("No mask to save")
            return

        # splitext strips only the final extension, so dots elsewhere in the
        # path (for example in directory names) are preserved.
        root, _ = os.path.splitext(self.image_path)
        out_path = f"{root}_mask.png"
        io.imsave(out_path, self.mask_c)

    def get_embeddings(self):
        """Compute and cache the image-encoder embedding of the current image.

        The image is resized to the square model input size, cast to uint8,
        min-max normalised to [0, 1] and passed through
        ``medsam_model.image_encoder`` with autograd disabled.
        """
        print("Calculating embedding, gui may be unresponsive.")
        side = MEDSAM_IMG_INPUT_SIZE
        img_1024 = transform.resize(
            self.img_3c, (side, side), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.uint8)
        img_1024 = (img_1024 - img_1024.min()) / np.clip(
            img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
        )  # normalize to [0, 1], (H, W, 3)

        # convert the shape to (1, 3, H, W)
        img_1024_tensor = (
            torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
        )

        with torch.no_grad():
            self.embedding = medsam_model.image_encoder(img_1024_tensor)
        print("Done.")
# Create the Qt application, show the main window and run the event loop.
app = QApplication(sys.argv)

w = Window()
w.show()

# Hand the event loop's return code to the OS instead of discarding it.
sys.exit(app.exec())