|
|
import time |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import cv2 |
|
|
from PIL import Image, ImageDraw, ImageOps |
|
|
import numpy as np |
|
|
from typing import Union |
|
|
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator |
|
|
from segment_anything.modeling.image_encoder import window_partition, window_unpartition, get_rel_pos, Block as image_encoder_block |
|
|
import matplotlib.pyplot as plt |
|
|
import PIL |
|
|
from .mask_painter import mask_painter |
|
|
from shared.utils import files_locator as fl |
|
|
|
|
|
|
|
|
# True only when a CUDA device is present AND it reports bfloat16 support;
# the short-circuit keeps the bf16 query from running on CUDA-less hosts.
_bfloat16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
|
|
|
|
|
|
|
|
def _patched_forward(self, x: torch.Tensor) -> torch.Tensor:
    """VRAM-optimized forward pass for SAM image encoder blocks.

    Optimizations made by DeepBeepMeep

    Peak memory is kept low by running attention over chunks of windows (or
    chunks of heads), running the MLP over chunks of rows, and writing
    intermediate results back into already-allocated buffers.

    Args:
        self: the encoder block being patched; ``norm1``, ``norm2``, ``attn``,
            ``mlp`` and ``window_size`` are read from it.
        x: block input, indexed as ``(B, H, W, C)``.
            WARNING: this tensor is overwritten in place (it is reused as the
            buffer for ``norm2`` and for the MLP output), so the caller must
            not rely on its contents after the call.

    Returns:
        The block output: attention residual plus MLP residual.
    """

    def split_mlp(mlp, x, divide=4):
        """Apply ``mlp`` row-chunk by row-chunk, overwriting ``x`` with the result."""
        x_shape = x.shape
        # ``view`` shares storage with the caller's tensor, so the per-chunk
        # writes below land in ``x`` itself (assumes ``x`` is contiguous).
        x = x.view(-1, x.shape[-1])
        # Guard against a zero chunk size when there are fewer rows than
        # ``divide``: torch.split rejects split_size == 0 for non-empty dims.
        chunk_size = max(1, int(x.shape[0] / divide))
        for x_chunk in torch.split(x, chunk_size):
            mlp_chunk = mlp.lin1(x_chunk)
            mlp_chunk = mlp.act(mlp_chunk)
            x_chunk[...] = mlp.lin2(mlp_chunk)
        return x.reshape(x_shape)

    def get_decomposed_rel_pos(q, rel_pos_h, rel_pos_w, q_size, k_size) -> torch.Tensor:
        """Build the additive relative-position bias of shape (B, q_h*q_w, k_h*k_w)."""
        q_h, q_w = q_size
        k_h, k_w = k_size
        Rh = get_rel_pos(q_h, k_h, rel_pos_h)
        Rw = get_rel_pos(q_w, k_w, rel_pos_w)
        B, _, dim = q.shape
        r_q = q.reshape(B, q_h, q_w, dim)
        rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
        rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
        # Accumulate both terms into a single buffer instead of materialising
        # two broadcast temporaries.
        attn = torch.zeros(B, q_h, q_w, k_h, k_w, dtype=q.dtype, device=q.device)
        attn += rel_h[:, :, :, :, None]
        attn += rel_w[:, :, :, None, :]
        return attn.view(B, q_h * q_w, k_h * k_w)

    def pay_attention(self, x: torch.Tensor, split_heads=1) -> torch.Tensor:
        """Attention over ``x`` (B, H, W, C); ``self`` is the block's attention module.

        With ``split_heads > 1`` the (batch * heads) axis is processed in
        groups of ``num_heads // split_heads`` rows to bound peak memory.
        """
        B, H, W, _ = x.shape

        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # Remember the dtype produced by the projection so the result can be
        # restored to it; a hard-coded bfloat16 breaks ``self.proj`` whenever
        # the weights are not bfloat16 (e.g. an fp32 model without autocast).
        out_dtype = qkv.dtype

        if not _bfloat16_supported:
            # Without bf16 hardware support, run the attention in float16.
            qkv = qkv.to(torch.float16)

        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
        if split_heads == 1:
            attn_mask = None
            if self.use_rel_pos:
                attn_mask = get_decomposed_rel_pos(q, self.rel_pos_h.to(q), self.rel_pos_w.to(q), (H, W), (H, W))
            x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=self.scale)
        else:
            chunk_size = self.num_heads // split_heads
            x = torch.empty_like(q)
            q_chunks = torch.split(q, chunk_size)
            k_chunks = torch.split(k, chunk_size)
            v_chunks = torch.split(v, chunk_size)
            x_chunks = torch.split(x, chunk_size)
            for x_chunk, q_chunk, k_chunk, v_chunk in zip(x_chunks, q_chunks, k_chunks, v_chunks):
                attn_mask = None
                if self.use_rel_pos:
                    attn_mask = get_decomposed_rel_pos(q_chunk, self.rel_pos_h.to(q), self.rel_pos_w.to(q), (H, W), (H, W))
                # x_chunk is a view of x, so this fills the output buffer.
                x_chunk[...] = F.scaled_dot_product_attention(q_chunk, k_chunk, v_chunk, attn_mask=attn_mask, scale=self.scale)
            # Drop the last chunk references so the buffers can be released.
            del x_chunk, q_chunk, k_chunk, v_chunk
        del q, k, v, attn_mask
        x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        if not _bfloat16_supported:
            # Undo the float16 fallback before the output projection.
            x = x.to(out_dtype)

        return self.proj(x)

    # ``shortcut`` aliases the caller's tensor; it is first the residual and
    # later reused as scratch for norm2 / the MLP.
    shortcut = x
    x = self.norm1(x)

    if self.window_size > 0:
        H, W = x.shape[1], x.shape[2]
        x, pad_hw = window_partition(x, self.window_size)
    x_shape = x.shape

    if x_shape[0] > 10:
        # Many windows: attend over roughly four groups of windows, writing
        # each result back into the matching slice of x.
        chunk_size = int(x.shape[0] / 4) + 1
        for x_chunk in torch.split(x, chunk_size):
            x_chunk[...] = pay_attention(self.attn, x_chunk)
    else:
        # Few (or no) windows: split along the heads axis instead.
        x = pay_attention(self.attn, x, 4)

    if self.window_size > 0:
        x = window_unpartition(x, self.window_size, pad_hw, (H, W))
    x += shortcut
    # The residual is no longer needed: overwrite it with the normalised
    # activations and let split_mlp transform it in place.
    shortcut[...] = self.norm2(x)
    x += split_mlp(self.mlp, shortcut)

    return x
|
|
|
|
|
|
|
|
def set_image_encoder_patch():
    """Install the VRAM-optimized forward on SAM image encoder blocks.

    The ``patched`` attribute on the block class acts as a marker, so calling
    this function more than once has no further effect.
    """
    already_patched = hasattr(image_encoder_block, "patched")
    if already_patched:
        return
    image_encoder_block.forward = _patched_forward
    image_encoder_block.patched = True
|
|
|
|
|
|
|
|
class BaseSegmenter:
    """Wrapper that loads a SAM model and exposes embed / reset / predict."""

    def __init__(self, SAM_checkpoint, model_type, device='cuda:0'):
        """
        device: model device
        SAM_checkpoint: path of SAM checkpoint (forwarded to the SAM model builder)
        model_type: vit_b, vit_l, vit_h
        """
        print(f"Initializing BaseSegmenter to {device}")
        assert model_type in ['vit_b', 'vit_l', 'vit_h'], 'model_type must be vit_b, vit_l, or vit_h'

        # Make sure the encoder blocks use the VRAM-optimized forward.
        set_image_encoder_patch()

        self.device = device

        self.torch_dtype = torch.float16 if 'cuda' in device else torch.float32
        from accelerate import init_empty_weights

        # Build the module structure first; the weights are filled in below.
        with init_empty_weights():
            self.model = sam_model_registry[model_type](checkpoint=SAM_checkpoint)
        from mmgp import offload

        # NOTE(review): the weights always come from the vit_h fp16 file,
        # whatever ``model_type`` is -- confirm this is intended.
        offload.load_model_data(self.model, fl.locate_file("mask/sam_vit_h_4b8939_fp16.safetensors"))
        self.model.to(torch.float32)
        self.model.to(device=self.device)
        self.predictor = SamPredictor(self.model)
        # True once an image has been embedded by ``set_image``.
        self.embedded = False

    @torch.no_grad()
    def set_image(self, image: np.ndarray):
        """Embed ``image`` with the predictor unless an image is already embedded.

        The stored image reference is updated on every call, even when the
        embedding step is skipped; call ``reset_image`` first to embed a new
        image.
        """
        self.orignal_image = image
        if self.embedded:
            print('repeat embedding, please reset_image.')
            return
        self.predictor.set_image(image)
        self.embedded = True
        return

    @torch.no_grad()
    def reset_image(self):
        """Drop the predictor's current embedding so a new image can be set."""
        self.predictor.reset_image()
        self.embedded = False

    def predict(self, prompts, mode, multimask=True):
        """
        Run the predictor on the currently embedded image.

        prompts: dictionary, up to 3 keys: 'point_coords', 'point_labels', 'mask_input'
        prompts['point_coords']: numpy array [N,2]
        prompts['point_labels']: numpy array of N labels, one per point
        prompts['mask_input']: numpy array [1,256,256]
        mode: 'point' (points only), 'mask' (mask only), 'both' (consider both)
        multimask: forwarded to the predictor as ``multimask_output``
            (True asks for multiple masks, False for a single one)
        to chain predictions, feed back
            mask_input = logits[np.argmax(scores), :, :][None, :, :]

        returns: the (masks, scores, logits) tuple produced by the predictor
        """
        assert self.embedded, 'prediction is called before set_image (feature embedding).'
        assert mode in ['point', 'mask', 'both'], 'mode must be point, mask, or both'

        with torch.autocast(device_type='cuda', dtype=torch.float16):
            if mode == 'point':
                masks, scores, logits = self.predictor.predict(
                    point_coords=prompts['point_coords'],
                    point_labels=prompts['point_labels'],
                    multimask_output=multimask,
                )
            elif mode == 'mask':
                masks, scores, logits = self.predictor.predict(
                    mask_input=prompts['mask_input'],
                    multimask_output=multimask,
                )
            elif mode == 'both':
                masks, scores, logits = self.predictor.predict(
                    point_coords=prompts['point_coords'],
                    point_labels=prompts['point_labels'],
                    mask_input=prompts['mask_input'],
                    multimask_output=multimask,
                )
            else:
                # Only reachable when asserts are stripped (python -O). A bare
                # string cannot be raised, so use a real exception type.
                raise NotImplementedError("Not implement now!")

        return masks, scores, logits
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: segment a sample image with point prompts, then with points plus
    # the previous mask logits, then with the mask logits alone, writing one
    # painted image per mode.

    # load and show an image
    image_path = '/hhd3/gaoshang/truck.jpg'
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # numpy array (h, w, 3)

    # initialise BaseSegmenter
    SAM_checkpoint = '/ssd1/gaomingqi/checkpoints/sam_vit_h_4b8939.pth'
    model_type = 'vit_h'
    device = "cuda:4"
    base_segmenter = BaseSegmenter(SAM_checkpoint=SAM_checkpoint, model_type=model_type, device=device)

    # image embedding (once embedded, multiple prompts can be applied)
    base_segmenter.set_image(image)

    # examples
    # point only ------------------------
    mode = 'point'
    prompts = {
        'point_coords': np.array([[500, 375], [1125, 625]]),
        'point_labels': np.array([1, 1]),
    }
    masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=False)
    painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)  # back to BGR for cv2.imwrite
    cv2.imwrite('/hhd3/gaoshang/truck_point.jpg', painted_image)

    # both ------------------------
    mode = 'both'
    # Feed the best-scoring logits of the previous prediction back as the mask prompt.
    mask_input = logits[np.argmax(scores), :, :]
    prompts = {
        'point_coords': np.array([[500, 375], [1125, 625]]),
        'point_labels': np.array([1, 0]),
        'mask_input': mask_input[None, :, :],
    }
    masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True)
    painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite('/hhd3/gaoshang/truck_both.jpg', painted_image)

    # mask only ------------------------
    mode = 'mask'
    mask_input = logits[np.argmax(scores), :, :]

    prompts = {'mask_input': mask_input[None, :, :]}

    masks, scores, logits = base_segmenter.predict(prompts, mode, multimask=True)
    painted_image = mask_painter(image, masks[np.argmax(scores)].astype('uint8'), background_alpha=0.8)
    painted_image = cv2.cvtColor(painted_image, cv2.COLOR_RGB2BGR)
    cv2.imwrite('/hhd3/gaoshang/truck_mask.jpg', painted_image)
|
|
|