# Hugging Face upload artifact (page residue, kept for provenance):
#   YPan0's picture / "Upload folder using huggingface_hub" / commit b6deff2 verified
# -*- coding: utf-8 -*-
import sys
import time
from PyQt5.QtGui import (
QBrush,
QPainter,
QPen,
QPixmap,
QKeySequence,
QPen,
QBrush,
QColor,
QImage,
)
from PyQt5.QtWidgets import (
QFileDialog,
QApplication,
QGraphicsEllipseItem,
QGraphicsItem,
QGraphicsRectItem,
QGraphicsScene,
QGraphicsView,
QGraphicsPixmapItem,
QHBoxLayout,
QPushButton,
QSlider,
QVBoxLayout,
QWidget,
QShortcut,
)
import numpy as np
from skimage import transform, io
import torch
import torch.nn as nn
from torch.nn import functional as F
from PIL import Image
from segment_anything import sam_model_registry
# freeze seeds so torch and numpy sampling is reproducible across runs
torch.manual_seed(2023)
torch.cuda.empty_cache()  # no-op on CPU-only machines
torch.cuda.manual_seed(2023)
np.random.seed(2023)
# model / checkpoint configuration
SAM_MODEL_TYPE = "vit_b"  # NOTE(review): registry lookup below hard-codes "vit_b" instead of this constant
MedSAM_CKPT_PATH = "work_dir/MedSAM/medsam_vit_b.pth"
MEDSAM_IMG_INPUT_SIZE = 1024  # SAM image encoder expects 1024x1024 input
# device preference: Apple MPS, then CUDA, else CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, height, width):
    """Decode one box prompt with MedSAM and return a binary mask.

    Args:
        medsam_model: loaded SAM model providing ``prompt_encoder`` and
            ``mask_decoder``.
        img_embed: precomputed image embedding, (B, 256, 64, 64).
        box_1024: prompt box(es) in 1024x1024 coordinates, shape (B, 4).
        height: target mask height (original image size).
        width: target mask width (original image size).

    Returns:
        ``np.uint8`` binary mask of shape (height, width).
    """
    boxes = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if boxes.ndim == 2:
        # prompt encoder expects (B, num_boxes_per_image, 4)
        boxes = boxes.unsqueeze(1)

    sparse_emb, dense_emb = medsam_model.prompt_encoder(
        points=None,
        boxes=boxes,
        masks=None,
    )
    logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_emb,  # (B, 2, 256)
        dense_prompt_embeddings=dense_emb,  # (B, 256, 64, 64)
        multimask_output=False,
    )

    probs = torch.sigmoid(logits)  # (1, 1, 256, 256)
    # upsample the low-resolution probability map to the original image size
    probs = F.interpolate(
        probs,
        size=(height, width),
        mode="bilinear",
        align_corners=False,
    )
    probs = probs.squeeze().cpu().numpy()  # (height, width)
    return (probs > 0.5).astype(np.uint8)
# Load the MedSAM checkpoint once at import time; the GUI below reuses it.
print("Loading MedSAM model, a sec.")
tic = time.perf_counter()
# set up model (weights loaded from MedSAM_CKPT_PATH, moved to the chosen device)
medsam_model = sam_model_registry["vit_b"](checkpoint=MedSAM_CKPT_PATH).to(device)
medsam_model.eval()
print(f"Done, took {time.perf_counter() - tic}")
def np2pixmap(np_img):
    """Convert an (H, W, 3) uint8 RGB numpy array into a QPixmap.

    QImage wraps the numpy buffer without copying it, so the array must be
    C-contiguous and must outlive the QImage. We force contiguity (slices
    and PIL round-trips can produce non-contiguous arrays) and deep-copy
    the pixel data so the resulting pixmap owns its memory and cannot show
    garbage after the numpy array is garbage-collected.
    """
    np_img = np.ascontiguousarray(np_img)
    height, width, _ = np_img.shape
    bytes_per_line = 3 * width
    q_img = QImage(
        np_img.data, width, height, bytes_per_line, QImage.Format_RGB888
    ).copy()  # .copy() detaches the image from the numpy buffer
    return QPixmap.fromImage(q_img)
# Per-object overlay palette (RGB); mouse_release cycles through it via
# color_idx so each new segmented object gets a distinct mask color.
# NOTE(review): some entries repeat — (255, 0, 255), (0, 255, 255) and
# (255, 255, 0) each appear twice — presumably unintentional; confirm
# before deduplicating.
colors = [
    (255, 0, 0),
    (0, 255, 0),
    (0, 0, 255),
    (255, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (128, 0, 0),
    (0, 128, 0),
    (0, 0, 128),
    (128, 128, 0),
    (128, 0, 128),
    (0, 128, 128),
    (255, 255, 255),
    (192, 192, 192),
    (64, 64, 64),
    (255, 0, 255),
    (0, 255, 255),
    (255, 255, 0),
    (0, 0, 127),
    (192, 0, 192),
]
class Window(QWidget):
    """Interactive MedSAM segmentation window.

    Workflow: an image is chosen at startup (its encoder embedding is
    computed once per image), the user drags a bounding box to segment,
    Ctrl+Z undoes the most recent mask, and "Save Mask" writes the
    accumulated color mask next to the source image.
    """

    def __init__(self):
        super().__init__()
        # configs
        self.half_point_size = 5  # radius of bbox starting and ending points

        # app state
        self.image_path = None
        self.color_idx = 0
        self.bg_img = None
        self.is_mouse_down = False
        self.rect = None
        self.point_size = self.half_point_size * 2
        self.start_point = None
        self.end_point = None
        self.start_pos = (None, None)
        self.embedding = None
        self.prev_mask = None

        self.view = QGraphicsView()
        self.view.setRenderHint(QPainter.Antialiasing)

        # builds the scene and computes the first embedding (exits if the
        # user cancels the very first file dialog)
        self.load_image()

        vbox = QVBoxLayout(self)
        vbox.addWidget(self.view)

        load_button = QPushButton("Load Image")
        save_button = QPushButton("Save Mask")

        # the child layout must NOT be given a parent widget here: `self`
        # already owns vbox, and Qt warns "already has a layout" and
        # ignores a second top-level layout on the same widget
        hbox = QHBoxLayout()
        hbox.addWidget(load_button)
        hbox.addWidget(save_button)

        vbox.addLayout(hbox)
        self.setLayout(vbox)

        # keyboard shortcuts
        self.quit_shortcut = QShortcut(QKeySequence("Ctrl+Q"), self)
        self.quit_shortcut.activated.connect(lambda: quit())
        self.undo_shortcut = QShortcut(QKeySequence("Ctrl+Z"), self)
        self.undo_shortcut.activated.connect(self.undo)

        load_button.clicked.connect(self.load_image)
        save_button.clicked.connect(self.save_mask)

    def undo(self):
        """Revert the mask to the state before the last segmentation."""
        if self.prev_mask is None:
            print("No previous mask record")
            return
        self.color_idx -= 1
        bg = Image.fromarray(self.img_3c.astype("uint8"), "RGB")
        mask = Image.fromarray(self.prev_mask.astype("uint8"), "RGB")
        img = Image.blend(bg, mask, 0.2)
        self.scene.removeItem(self.bg_img)
        self.bg_img = self.scene.addPixmap(np2pixmap(np.array(img)))
        self.mask_c = self.prev_mask
        self.prev_mask = None  # only a single level of undo is kept

    def load_image(self):
        """Pick an image file, compute its embedding, and (re)build the scene."""
        file_path, file_type = QFileDialog.getOpenFileName(
            self, "Choose Image to Segment", ".", "Image Files (*.png *.jpg *.bmp)"
        )
        if file_path is None or len(file_path) == 0:
            print("No image path specified, plz select an image")
            if self.image_path is None:
                # nothing was ever loaded, the app cannot continue
                exit()
            # a later "Load Image" was cancelled: keep the current image
            # instead of killing the whole session
            return

        img_np = io.imread(file_path)
        if len(img_np.shape) == 2:
            # grayscale -> 3 channels by replication
            img_3c = np.repeat(img_np[:, :, None], 3, axis=-1)
        else:
            img_3c = img_np

        self.img_3c = img_3c
        self.image_path = file_path
        self.get_embeddings()  # heavy: runs the SAM image encoder once
        pixmap = np2pixmap(self.img_3c)
        H, W, _ = self.img_3c.shape
        self.scene = QGraphicsScene(0, 0, W, H)
        self.end_point = None
        self.rect = None
        self.bg_img = self.scene.addPixmap(pixmap)
        self.bg_img.setPos(0, 0)
        # accumulated RGB mask; one palette color per segmented object
        self.mask_c = np.zeros((*self.img_3c.shape[:2], 3), dtype="uint8")
        self.view.setScene(self.scene)

        # events
        self.scene.mousePressEvent = self.mouse_press
        self.scene.mouseMoveEvent = self.mouse_move
        self.scene.mouseReleaseEvent = self.mouse_release

    def mouse_press(self, ev):
        """Start a bounding-box drag and draw its anchor point."""
        x, y = ev.scenePos().x(), ev.scenePos().y()
        self.is_mouse_down = True
        self.start_pos = (x, y)
        self.start_point = self.scene.addEllipse(
            x - self.half_point_size,
            y - self.half_point_size,
            self.point_size,
            self.point_size,
            pen=QPen(QColor("red")),
            brush=QBrush(QColor("red")),
        )

    def mouse_move(self, ev):
        """Rubber-band the bounding box while the mouse button is held."""
        if not self.is_mouse_down:
            return

        x, y = ev.scenePos().x(), ev.scenePos().y()

        # replace the previous end-point marker and rectangle rather than
        # stacking new items on every move event
        if self.end_point is not None:
            self.scene.removeItem(self.end_point)
        self.end_point = self.scene.addEllipse(
            x - self.half_point_size,
            y - self.half_point_size,
            self.point_size,
            self.point_size,
            pen=QPen(QColor("red")),
            brush=QBrush(QColor("red")),
        )

        if self.rect is not None:
            self.scene.removeItem(self.rect)
        sx, sy = self.start_pos
        xmin = min(x, sx)
        xmax = max(x, sx)
        ymin = min(y, sy)
        ymax = max(y, sy)
        self.rect = self.scene.addRect(
            xmin, ymin, xmax - xmin, ymax - ymin, pen=QPen(QColor("red"))
        )

    def mouse_release(self, ev):
        """Finish the drag, run MedSAM on the box, and blend the new mask in."""
        x, y = ev.scenePos().x(), ev.scenePos().y()
        sx, sy = self.start_pos
        xmin = min(x, sx)
        xmax = max(x, sx)
        ymin = min(y, sy)
        ymax = max(y, sy)

        self.is_mouse_down = False

        H, W, _ = self.img_3c.shape
        box_np = np.array([[xmin, ymin, xmax, ymax]])
        # print("bounding box:", box_np)
        # scale the box from scene coordinates into the encoder's 1024 grid
        box_1024 = box_np / np.array([W, H, W, H]) * 1024

        sam_mask = medsam_inference(medsam_model, self.embedding, box_1024, H, W)
        self.prev_mask = self.mask_c.copy()  # snapshot for single-level undo
        self.mask_c[sam_mask != 0] = colors[self.color_idx % len(colors)]
        self.color_idx += 1

        bg = Image.fromarray(self.img_3c.astype("uint8"), "RGB")
        mask = Image.fromarray(self.mask_c.astype("uint8"), "RGB")
        img = Image.blend(bg, mask, 0.2)

        self.scene.removeItem(self.bg_img)
        self.bg_img = self.scene.addPixmap(np2pixmap(np.array(img)))

    def save_mask(self):
        """Write the accumulated color mask next to the source image."""
        # rsplit on the LAST dot: split('.')[0] truncated at the first dot,
        # which broke paths like "./dir/img.png" or dotted directory names
        out_path = f"{self.image_path.rsplit('.', 1)[0]}_mask.png"
        io.imsave(out_path, self.mask_c)

    @torch.no_grad()
    def get_embeddings(self):
        """Run the SAM image encoder once and cache the embedding.

        Blocks the GUI thread; the embedding is reused for every box
        prompt until the next image is loaded.
        """
        print("Calculating embedding, gui may be unresponsive.")
        img_1024 = transform.resize(
            self.img_3c, (1024, 1024), order=3, preserve_range=True, anti_aliasing=True
        ).astype(np.uint8)
        img_1024 = (img_1024 - img_1024.min()) / np.clip(
            img_1024.max() - img_1024.min(), a_min=1e-8, a_max=None
        )  # normalize to [0, 1], (H, W, 3)
        # convert the shape to (1, 3, H, W)
        img_1024_tensor = (
            torch.tensor(img_1024).float().permute(2, 0, 1).unsqueeze(0).to(device)
        )
        # the @torch.no_grad() decorator already disables gradients here
        self.embedding = medsam_model.image_encoder(
            img_1024_tensor
        )  # (1, 256, 64, 64)
        print("Done.")
# Application entry point: build the Qt app and show the main window.
app = QApplication(sys.argv)
w = Window()
w.show()
app.exec()