| | |
| | import pdb |
| | import time |
| |
|
| | import torch |
| | import tqlt.utils as tu |
| | from models.birefnet import BiRefNet |
| | from PIL import Image |
| | from torchvision import transforms |
| |
|
| | |
| | from transformers import AutoModelForImageSegmentation |
| |
|
| | |
| | from utils import check_state_dict |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
# Input images to run through the model.
# NOTE(review): `tu.next_files` comes from the project-local `tqlt.utils`;
# presumably it returns the paths of the ".png" files found under
# "./in_the_wild" -- confirm against that helper.
imgs = tu.next_files("./in_the_wild", ".png")

# Build the network and fill it from the local checkpoint.
# NOTE(review): bb_pretrained=False presumably skips loading separate backbone
# weights because the checkpoint below supplies every parameter -- confirm.
birefnet = BiRefNet(bb_pretrained=False)
# NOTE(review): torch.load unpickles the file, so only point this at a
# checkpoint from a trusted source. `check_state_dict` is a project helper
# that post-processes the loaded mapping before it is handed to the model.
state_dict = check_state_dict(
    torch.load("./BiRefNet-general-epoch_244.pth", map_location="cpu")
)
birefnet.load_state_dict(state_dict)

# Inference device and float32 matmul precision ("highest" is the alternative
# setting to "high").
device = "cuda"
torch.set_float32_matmul_precision("high")

# Move the weights to the device and switch to evaluation mode.
birefnet.to(device).eval()
print("BiRefNet is ready to use.")
| |
|
| | |
# Preprocessing applied to every input image before it is fed to the model:
# resize to a fixed 1024x1024, convert to a float tensor, then normalize each
# channel with the per-channel mean / standard deviation given below.
_preprocess_steps = [
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform_image = transforms.Compose(_preprocess_steps)
| |
|
| |
|
| | import os |
| | from glob import glob |
| |
|
| | from image_proc import refine_foreground |
| |
|
def _prediction_path(image_path, out_dir, suffix):
    """Return the output path ``<out_dir>/<stem of image_path><suffix>``.

    Only the file name of ``image_path`` is used, so the result always lands
    in ``out_dir`` no matter which directory the input came from. The
    extension is stripped with ``os.path.splitext`` rather than a substring
    replace, so a path that merely contains the extension text elsewhere (or
    has no extension at all) is handled correctly.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    return os.path.join(out_dir, stem + suffix)


src_dir = "./images_todo"
# NOTE: `image_paths` is not consumed by the loop below (it iterates `imgs`);
# it is kept only in case code outside this section reads it.
image_paths = glob(os.path.join(src_dir, "*"))
dst_dir = "./predictions"
os.makedirs(dst_dir, exist_ok=True)

# The tensor -> PIL converter is stateless, so build it once for all images.
_to_pil = transforms.ToPILImage()

for image_path in imgs:
    print(f"Processing {image_path} ...")

    # Close the file handle promptly and force 3 channels: the normalization
    # step in `transform_image` is configured for exactly three channels, so
    # RGBA / palette / grayscale PNGs would otherwise fail.
    with Image.open(image_path) as opened:
        image = opened.convert("RGB")
    input_images = transform_image(image).unsqueeze(0).to(device)

    # Prediction; the elapsed time printed below includes the copy back to
    # the CPU.
    start = time.time()
    with torch.no_grad():
        preds = birefnet(input_images)[-1].sigmoid().cpu()
    print(time.time() - start)
    pred = preds[0].squeeze()

    # Save the mask, resized back to the source image's resolution.
    pred_pil = _to_pil(pred).resize(image.size)
    pred_pil.save(_prediction_path(image_path, dst_dir, "-mask.png"))

    # Save the extracted subject, with the mask as its alpha channel.
    image_masked = refine_foreground(image, pred_pil)
    image_masked.putalpha(pred_pil)
    image_masked.save(_prediction_path(image_path, dst_dir, "-subject.png"))
| |
|