import sys

sys.path.append("./")
import copy
import os
import tempfile
from bisect import bisect_left
from dataclasses import dataclass

import cv2
import numpy as np
import PIL.Image
import torch
from pytorch3d.ops import sample_farthest_points
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from torchvision import transforms

from engine.BiRefNet.models.birefnet import BiRefNet
from engine.ouputs import BaseOutput
from engine.SegmentAPI.base import BaseSeg, Bbox
from engine.SegmentAPI.img_utils import load_image_file

SAM2_WEIGHT = "pretrained_models/sam2/sam2.1_hiera_large.pt"
BIREFNET_WEIGHT = "pretrained_models/BiRefNet-general-epoch_244.pth"

def avaliable_device():
    """Return the current CUDA device string if available, else "cpu"."""
    if torch.cuda.is_available():
        current_device_id = torch.cuda.current_device()
        device = f"cuda:{current_device_id}"
    else:
        device = "cpu"

    return device


@dataclass
class SegmentOut(BaseOutput):
    masks: np.ndarray
    processed_img: np.ndarray
    alpha_img: np.ndarray


def distance(p1, p2):
    return np.sqrt(np.sum((p1 - p2) ** 2))


def FPS(sample, num):
    """Greedy farthest point sampling: return `num` indices and the points."""
    n = sample.shape[0]
    center = np.mean(sample, axis=0)
    select_p = []
    # Seed with the point farthest from the centroid.
    L = []
    for i in range(n):
        L.append(distance(sample[i], center))
    p0 = np.argmax(L)
    select_p.append(p0)
    # Second point: the one farthest from the seed.
    L = []
    for i in range(n):
        L.append(distance(sample[p0], sample[i]))
    select_p.append(np.argmax(L))
    # Remaining points: greedily maximize the minimum distance to the selected set.
    for i in range(num - 2):
        for p in range(n):
            d = distance(sample[select_p[-1]], sample[p])
            if d <= L[p]:
                L[p] = d
        select_p.append(np.argmax(L))
    return select_p, sample[select_p]
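
# A minimal sanity check for FPS, assuming a small random 2-D point set
# (the sizes below are illustrative, not from the original):
def _fps_demo():
    pts = np.random.rand(100, 2).astype(np.float32)
    idx, picked = FPS(pts, 5)
    assert len(idx) == 5 and picked.shape == (5, 2)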
def fill_mask(alpha):
    # Flood-fill from the (0, 0) corner, then OR the inverted fill back in
    # to close interior holes in the mask.
    h, w = alpha.shape[:2]

    mask = np.zeros((h + 2, w + 2), np.uint8)
    alpha = (alpha * 255).astype(np.uint8)
    im_floodfill = alpha.copy()
    retval, image, mask, rect = cv2.floodFill(im_floodfill, mask, (0, 0), 255)
    im_floodfill_inv = cv2.bitwise_not(im_floodfill)

    alpha = alpha | im_floodfill_inv
    alpha = alpha.astype(np.float32) / 255.0

    return alpha
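
# A minimal sanity check for fill_mask, assuming a synthetic ring mask whose
# interior hole should be filled (sizes below are illustrative):
def _fill_mask_demo():
    ring = np.zeros((64, 64), np.float32)
    cv2.circle(ring, (32, 32), 20, 1.0, thickness=5)
    filled = fill_mask(ring)
    assert filled[32, 32] == 1.0  # the hole at the center is now solid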
def erode_and_dialted(mask, kernel_size=3, iterations=1):
    # Morphological opening: erode to remove speckles, then dilate to restore
    # the surviving foreground to its original extent.
    kernel = np.ones((kernel_size, kernel_size), np.uint8)

    eroded_mask = cv2.erode(mask, kernel, iterations=iterations)

    dilated_mask = cv2.dilate(eroded_mask, kernel, iterations=iterations)

    return dilated_mask


def eroded(mask, kernel_size=3, iterations=1):
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    eroded_mask = cv2.erode(mask, kernel, iterations=iterations)

    return eroded_mask
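
# erode_and_dialted is equivalent to cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
# with the same kernel and iterations; a tiny sketch of the expected effect:
def _open_demo():
    noisy = np.zeros((32, 32), np.uint8)
    noisy[8:24, 8:24] = 255  # solid square survives opening
    noisy[2, 2] = 255        # isolated pixel is removed
    opened = erode_and_dialted(noisy, kernel_size=3, iterations=1)
    assert opened[2, 2] == 0 and opened[16, 16] == 255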
def model_type(model):
    # Print which device a model's parameters live on.
    print(next(model.parameters()).device)
class SAM2Seg(BaseSeg):
    # [short-side length in pixels, matting downsample ratio] breakpoints,
    # linearly interpolated in ratio_mapping().
    RATIO_MAP = [[512, 1], [1280, 0.6], [1920, 0.4], [3840, 0.2]]
    def tocpu(self):
        self.box_prior.cpu()
        self.image_predictor.model.cpu()
        torch.cuda.empty_cache()

    def tocuda(self):
        self.box_prior.cuda()
        self.image_predictor.model.cuda()
    def __init__(
        self,
        config="sam2.1_hiera_l.yaml",
        matting_config="resnet50",
        background=(1.0, 1.0, 1.0),
        wo_supres=False,
    ):
        super().__init__()

        self.device = avaliable_device()

        try:
            sam2_image_model = build_sam2(config, SAM2_WEIGHT)
        except Exception:
            config = os.path.join("./configs/sam2.1/", config)
            sam2_image_model = build_sam2(config, SAM2_WEIGHT)

        self.image_predictor = SAM2ImagePredictor(sam2_image_model)

        # The BiRefNet box prior is built lazily on first use; see
        # birefnet_predict_bbox().
        self.box_prior = None

        self.background = background
        self.wo_supers = wo_supres
    def clean_up(self):
        self.tmp.cleanup()

    def collect_inputs(self, inputs):
        return dict(
            img_path=inputs["img_path"],
            bbox=inputs["bbox"],
        )
    def _super_resolution(self, input_path):
        low = os.path.abspath(input_path)
        high = self.tmp.name

        super_weights = os.path.abspath("./pretrained_models/RealESRGAN_x4plus.pth")
        # NOTE: SUPRES_PATH is not defined in this module; it is expected to
        # point at a Real-ESRGAN checkout containing inference_realesrgan.py.
        hander = os.path.join(SUPRES_PATH, "inference_realesrgan.py")

        cmd = f"python {hander} -n RealESRGAN_x4plus -i {low} -o {high} --model_path {super_weights} -s 2"

        os.system(cmd)

        return os.path.join(high, os.path.basename(input_path))
    def predict_bbox(self, img, scale=1.0):
        # NOTE: this path requires a video-matting model (e.g. RobustVideoMatting)
        # assigned to self.matting_predictor; it is not initialized in __init__.
        ratio = self.ratio_mapping(img)

        img = np.asarray(img).astype(np.float32) / 255.0
        height, width, _ = img.shape

        img_tensor = torch.from_numpy(img).permute(2, 0, 1)

        rec = [None] * 4

        with torch.no_grad():
            img_tensor = img_tensor.unsqueeze(0).to(self.device)
            fgr, pha, *rec = self.matting_predictor(
                img_tensor.to(self.device),
                *rec,
                downsample_ratio=ratio,
            )

        # Binarize the alpha matte at 0.5.
        pha[pha < 0.5] = 0.0
        pha[pha >= 0.5] = 1.0
        pha = pha[0].permute(1, 2, 0).detach().cpu().numpy()

        # Tight bounding box around the foreground pixels, as [x0, y0, x1, y1].
        _h, _w, _ = np.where(pha == 1)

        whwh = [
            _w.min().item(),
            _h.min().item(),
            _w.max().item(),
            _h.max().item(),
        ]

        box = Bbox(whwh)

        scale_box = box.scale(1.00, width=width, height=height)

        return scale_box, pha[..., 0]
    def birefnet_predict_bbox(self, img, scale=1.0):
        # Lazily build the BiRefNet saliency model on first call.
        if self.box_prior is None:
            from engine.BiRefNet.utils import check_state_dict

            birefnet = BiRefNet(bb_pretrained=False)
            state_dict = torch.load(BIREFNET_WEIGHT, map_location="cpu")
            state_dict = check_state_dict(state_dict)
            birefnet.load_state_dict(state_dict)
            device = avaliable_device()
            torch.set_float32_matmul_precision("high")

            birefnet.to(device)
            self.box_prior = birefnet
            self.box_prior.eval()
            self.box_transform = transforms.Compose(
                [
                    transforms.Resize((1024, 1024)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
                ]
            )
            print("BiRefNet is ready to use.")
        else:
            device = avaliable_device()
            self.box_prior.to(device)

        height, width, _ = img.shape

        image = PIL.Image.fromarray(img)

        input_images = self.box_transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            preds = self.box_prior(input_images)[-1].sigmoid().cpu()
        pha = preds[0].squeeze(0).detach().numpy()

        pha = cv2.resize(pha, (width, height))

        # Binarize a copy of the matte at 0.3 to locate the foreground.
        masks = copy.deepcopy(pha[..., None])

        masks[masks < 0.3] = 0.0
        masks[masks >= 0.3] = 1.0

        _h, _w, _ = np.where(masks == 1)

        whwh = [
            _w.min().item(),
            _h.min().item(),
            _w.max().item(),
            _h.max().item(),
        ]

        box = Bbox(whwh)

        scale_box = box.scale(scale=scale, width=width, height=height)

        return scale_box, pha
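
    # Design note: the BiRefNet prior is cached on self.box_prior after the
    # first call, so repeated predictions only pay the model-loading cost once.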
    def rembg_predict_bbox(self, img, scale=1.0):
        from rembg import remove

        height, width, _ = img.shape

        with torch.no_grad():
            # rembg returns an RGBA image; the alpha channel is the matte.
            img_rmbg = remove(img[..., ::-1])
            pha = copy.deepcopy(img_rmbg[..., -1:])

        masks = copy.deepcopy(pha)

        # Any nonzero alpha counts as foreground.
        masks[masks < 1.0] = 0.0
        masks[masks >= 1.0] = 1.0

        _h, _w, _ = np.where(masks == 1)

        whwh = [
            _w.min().item(),
            _h.min().item(),
            _w.max().item(),
            _h.max().item(),
        ]

        box = Bbox(whwh)

        scale_box = box.scale(scale=scale, width=width, height=height)

        return scale_box, pha[..., 0].astype(np.float32) / 255.0
    def yolo_predict_bbox(self, img, scale=1.0, threshold=0.2):
        # Lazily build a YOLO segmentation model on first call. The checkpoint
        # name below is an assumption; swap in whichever -seg weights you use.
        if getattr(self, "yolo_prior", None) is None:
            from ultralytics import YOLO

            self.yolo_prior = YOLO("yolov8x-seg.pt")
        yolo_seg = self.yolo_prior

        height, width, _ = img.shape

        with torch.no_grad():
            results = yolo_seg(img[..., ::-1])
            for result in results:
                # Keep only person-class (cls == 0) instance masks.
                masks = result.masks.data[result.boxes.cls == 0]
                if masks.shape[0] >= 1:
                    masks[masks >= threshold] = 1
                    masks[masks < threshold] = 0
                    masks = masks.sum(dim=0)

        pha = masks.detach().cpu().numpy()
        pha = cv2.resize(pha, (width, height), interpolation=cv2.INTER_AREA)[..., None]

        pha[pha >= 0.5] = 1
        pha[pha < 0.5] = 0

        masks = copy.deepcopy(pha)

        pha = pha * 255.0

        _h, _w, _ = np.where(masks == 1)

        whwh = [
            _w.min().item(),
            _h.min().item(),
            _w.max().item(),
            _h.max().item(),
        ]

        box = Bbox(whwh)

        scale_box = box.scale(scale=scale, width=width, height=height)

        return scale_box, pha[..., 0].astype(np.float32) / 255.0
    def ratio_mapping(self, img):
        # Piecewise-linear interpolation of the downsample ratio over the
        # image's short side, using the breakpoints in RATIO_MAP.
        my_ratio_map = self.RATIO_MAP

        ratio_landmarks = [v[0] for v in my_ratio_map]

        ratio_v = [v[1] for v in my_ratio_map]
        h, w, _ = img.shape

        max_length = min(h, w)

        low_bound = bisect_left(
            ratio_landmarks, max_length, lo=0, hi=len(ratio_landmarks)
        )

        if 0 == low_bound:
            return 1.0
        elif low_bound == len(ratio_landmarks):
            return ratio_v[-1]
        else:
            _l = ratio_v[low_bound - 1]
            _r = ratio_v[low_bound]

            _l_land = ratio_landmarks[low_bound - 1]
            _r_land = ratio_landmarks[low_bound]
            cur_ratio = _l + (_r - _l) * (max_length - _l_land) / (_r_land - _l_land)

            return cur_ratio
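
    # Worked example (derived from RATIO_MAP): an image whose short side is
    # 1600 px falls between the 1280 -> 0.6 and 1920 -> 0.4 breakpoints, so
    # ratio = 0.6 + (0.4 - 0.6) * (1600 - 1280) / (1920 - 1280) = 0.5.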
    def get_img(self, img_path, sup_res=True):
        img = cv2.imread(img_path)
        img = img[..., ::-1].copy()  # BGR -> RGB

        # The super-resolution path is currently disabled; both branches
        # return the image unchanged.
        if self.wo_supers:
            return img

        return img
    def compute_coords(self, pha, bbox):
        node_prompts = []

        H, W = pha.shape
        y_indices, x_indices = np.indices((H, W))
        coors = np.stack((x_indices, y_indices), axis=-1)

        # First prompt: the matte's center of mass.
        pha_coors = np.repeat(pha[..., None], 2, axis=2)
        coors_points = (coors * pha_coors).sum(axis=0).sum(axis=0) / (pha.sum() + 1e-6)
        node_prompts.append(coors_points.tolist())

        _h, _w = np.where(pha > 0.5)

        sample_ps = torch.from_numpy(np.stack((_w, _h), axis=-1).astype(np.float32)).to(
            avaliable_device()
        )

        # Five extra prompts spread over the foreground via farthest point sampling.
        node_prompts_fps, _ = sample_farthest_points(sample_ps[None], K=5)
        node_prompts_fps = (
            node_prompts_fps[0].detach().cpu().numpy().astype(np.int32).tolist()
        )

        node_prompts.extend(node_prompts_fps)
        # All prompts are positive (foreground) points.
        node_prompts_label = [1 for _ in range(len(node_prompts))]

        return node_prompts, node_prompts_label
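
    # The SAM2 prompt set produced above is: one center-of-mass point plus
    # five farthest-point samples, all labeled 1 (foreground), used together
    # with the detector's bounding box in _forward().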
    def _forward(self, img_path, bbox, sup_res=True):
        img = self.get_img(img_path, sup_res)

        if bbox is None:
            # Estimate the bounding box and a coarse matte with BiRefNet,
            # slightly enlarging the box (scale 1.01). NOTE: `pha` is only
            # defined on this path, so callers currently pass bbox=None.
            bbox, pha = self.birefnet_predict_bbox(img, 1.01)

        box = bbox.to_whwh()
        bbox = box.get_box()

        point_coords, point_coords_label = self.compute_coords(pha, bbox)

        self.image_predictor.set_image(img)

        masks, scores, logits = self.image_predictor.predict(
            point_coords=point_coords,
            point_labels=point_coords_label,
            box=bbox,
            multimask_output=False,
        )

        alpha = masks[0]

        # Composite the binary SAM2 mask over the configured background color.
        img_float = img.astype(np.float32) / 255.0
        process_img = (
            img_float * alpha[..., None] + (1 - alpha[..., None]) * self.background
        )
        process_img = (process_img * 255).astype(np.uint8)

        process_img = process_img.astype(np.float32) / 255.0

        # Composite with the soft BiRefNet matte as well.
        process_pha_img = (
            img_float * pha[..., None] + (1 - pha[..., None]) * self.background
        )

        return SegmentOut(
            masks=alpha, processed_img=process_img, alpha_img=process_pha_img
        )
    @torch.no_grad()
    def __call__(self, **inputs):
        self.tmp = tempfile.TemporaryDirectory()

        self.collect_inputs(inputs)

        out = self._forward(**inputs)

        self.clean_up()
        return out
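
# Example usage (illustrative; the weight files referenced above must exist):
#   model = SAM2Seg(wo_supres=True)
#   out = model(img_path="example.png", bbox=None)
#   binary_mask, composited = out.masks, out.processed_img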
def get_parse():
    import argparse

    parser = argparse.ArgumentParser(description="")
    parser.add_argument("-i", "--input", required=True, help="input path")
    parser.add_argument("-o", "--output", required=True, help="output path")
    parser.add_argument("--mask", action="store_true", help="mask bool")
    parser.add_argument(
        "--wo_super_reso", action="store_true", help="disable super resolution"
    )
    args = parser.parse_args()
    return args
def main():
    opt = get_parse()
    img_list = os.listdir(opt.input)
    img_names = [os.path.join(opt.input, img_name) for img_name in img_list]

    os.makedirs(opt.output, exist_ok=True)

    model = SAM2Seg(wo_supres=opt.wo_super_reso)

    for img in img_names:
        print(f"processing {img}")
        out = model(img_path=img, bbox=None)

        save_path = os.path.join(opt.output, os.path.basename(img))

        # Fill interior holes, then open the mask to drop speckle noise.
        alpha = fill_mask(out.masks)
        alpha = erode_and_dialted(
            (alpha * 255).astype(np.uint8), kernel_size=3, iterations=3
        )
        save_img = alpha
        cv2.imwrite(save_path, save_img)
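
# Example invocation, assuming this module is run as a script
# (the script name and paths are illustrative):
#   python <this_script>.py -i ./inputs -o ./outputs --wo_super_reso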
if __name__ == "__main__":
    main()