|
|
import os |
|
|
from pathlib import Path |
|
|
|
|
|
import cv2 |
|
|
import torch |
|
|
from model import BiSeNet |
|
|
from PIL import Image |
|
|
from torch.utils.data import Dataset |
|
|
from torchvision import transforms |
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class MaskDataset(Dataset): |
|
|
def __init__(self, img_root, mask_root): |
|
|
img_dir = Path(img_root) |
|
|
self.to_tensor_normalize = transforms.Compose( |
|
|
[ |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), |
|
|
] |
|
|
) |
|
|
self.img_files = list(img_dir.glob(f"**/*.jpg")) |
|
|
self.img_files.sort() |
|
|
self.mask_files = [os.path.join(mask_root, os.path.relpath(img_path, img_root)) for img_path in self.img_files] |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.mask_files) |
|
|
|
|
|
def __getitem__(self, index): |
|
|
img = Image.open(self.img_files[index]).convert("RGB") |
|
|
return {"img": self.to_tensor_normalize(img), "mask_path": self.mask_files[index]} |
|
|
|
|
|
|
|
|
class MaskDataLoader: |
|
|
def __init__(self): |
|
|
"""Initialize this class""" |
|
|
self.dataset = MaskDataset(img_root="/data/dataset/face_1k/alignHQ", mask_root="/data/dataset/face_1k/mask") |
|
|
|
|
|
self.dataloader = torch.utils.data.DataLoader( |
|
|
self.dataset, batch_size=8, shuffle=True, num_workers=8, drop_last=False |
|
|
) |
|
|
|
|
|
def __len__(self): |
|
|
"""Return the number of data in the dataset""" |
|
|
return len(self.dataset) / 8 |
|
|
|
|
|
def __iter__(self): |
|
|
"""Return a batch of data""" |
|
|
for data in self.dataloader: |
|
|
yield data |
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
dataloader = MaskDataLoader() |
|
|
bisenet_path = "/data/useful_ckpt/face_parsing/parsing_model_79999_iter.pth" |
|
|
bisenet = BiSeNet(n_classes=19) |
|
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") |
|
|
bisenet.to(device) |
|
|
state_dict = torch.load(bisenet_path, map_location=device) |
|
|
bisenet.load_state_dict(state_dict) |
|
|
bisenet.eval() |
|
|
|
|
|
for data in tqdm(dataloader): |
|
|
mask, ignore_ids = bisenet.get_mask(data["img"].to(device), 256) |
|
|
mask = (mask * 255).to(torch.uint8).cpu().numpy().transpose(0, 2, 3, 1).repeat(3, 3) |
|
|
|
|
|
for i in range(mask.shape[0]): |
|
|
if ignore_ids[i]: |
|
|
continue |
|
|
path = data["mask_path"][i] |
|
|
dirname = os.path.dirname(path) |
|
|
os.makedirs(dirname, exist_ok=True) |
|
|
cv2.imwrite(path, mask[i]) |
|
|
|