# #!/usr/bin/env python # # -*- encoding: utf-8 -*- # """ # @Author : Peike Li # @Contact : peike.li@yahoo.com # @File : simple_extractor.py # @Time : 8/30/19 8:59 PM # @Desc : Simple Extractor # @License : This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # """ # import os # import torch # import argparse # import numpy as np # from PIL import Image # from tqdm import tqdm # from torch.utils.data import DataLoader # import torchvision.transforms as transforms # import os # import sys # _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) # .../DEMO/preprocess # if _THIS_DIR not in sys.path: # sys.path.insert(0, _THIS_DIR) # import networks # from utils.transforms import transform_logits # from datasets.simple_extractor_dataset import SimpleFolderDataset # dataset_settings = { # 'lip': { # 'input_size': [473, 473], # 'num_classes': 20, # 'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat', # 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm', # 'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe'] # }, # 'atr': { # 'input_size': [512, 512], # 'num_classes': 18, # 'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt', # 'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf'] # }, # 'pascal': { # 'input_size': [512, 512], # 'num_classes': 7, # 'label': ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'], # } # } # def get_arguments(): # """Parse all the arguments provided from the CLI. # Returns: # A list of parsed arguments. 
# """ # parser = argparse.ArgumentParser(description="Self Correction for Human Parsing") # parser.add_argument("--dataset", type=str, default='atr', choices=['lip', 'atr', 'pascal']) # parser.add_argument("--model-restore", type=str, default='', help="restore pretrained model parameters.") # parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.") # parser.add_argument("--category", type=str, default='Upper-clothes', help="category name (optional).") # parser.add_argument("--input-dir", type=str, default='', help="path of input image folder.") # parser.add_argument("--output-dir", type=str, default='', help="path of output image folder.") # parser.add_argument("--logits", action='store_true', default=False, help="whether to save the logits.") # return parser.parse_args() # def get_palette(num_cls): # n = 18 # palette = [0] * (n * 3) # j = num_cls # lab = num_cls # palette[j * 3 + 0] = 0 # palette[j * 3 + 1] = 0 # palette[j * 3 + 2] = 0 # i = 0 # while lab: # palette[j * 3 + 0] = 255 # palette[j * 3 + 1] = 255 # palette[j * 3 + 2] = 255 # i += 1 # lab >>= 3 # return palette # def get_palette2(num_cls): # """ Returns the color map for visualizing the segmentation mask. # Args: # num_cls: Number of classes # Returns: # The color map # """ # n = 18 # palette = [0] * (n * 3) # for j in range(5, 7): # lab = j # palette[j * 3 + 0] = 0 # palette[j * 3 + 1] = 0 # palette[j * 3 + 2] = 0 # i = 0 # while lab: # palette[j * 3 + 0] = 255 # palette[j * 3 + 1] = 255 # palette[j * 3 + 2] = 255 # i += 1 # lab >>= 3 # return palette # def run( # *, # category: str, # input_path: str = "", # input_dir: str = "", # dataset: str = "atr", # model_restore: str = "", # gpu: str = "0", # logits: bool = False, # ): # """ # - input_path (단일 파일) 또는 input_dir(폴더) 중 하나를 받아 parsing 결과를 메모리로 반환. # - 파일 저장 없음. 
# Returns: # { # "images": List[PIL.Image], # parsing mask (palette 적용됨) # "logits": Optional[List[np.ndarray]], # "names": List[str], # 파일명들 # } # """ # # single GPU만 허용 # gpus = [int(i) for i in gpu.split(',')] # assert len(gpus) == 1 # if gpu != 'None': # os.environ["CUDA_VISIBLE_DEVICES"] = gpu # if not model_restore: # print("[simple_extractor] model_restore not provided → skip extractor.") # return {"images": [], "logits": [] if logits else None, "names": []} # # 입력 검증: 둘 중 하나는 있어야 함 # if bool(input_path) == bool(input_dir): # raise ValueError("Provide exactly one of input_path or input_dir.") # # 파일이면 존재 확인 # if input_path: # if not os.path.isfile(input_path): # raise FileNotFoundError(f"input_path not found or not a file: {input_path}") # # 폴더면 존재 확인 # if input_dir: # if not os.path.isdir(input_dir): # raise NotADirectoryError(f"input_dir not found or not a directory: {input_dir}") # num_classes = dataset_settings[dataset]['num_classes'] # input_size = dataset_settings[dataset]['input_size'] # label = dataset_settings[dataset]['label'] # print(f"Evaluating total class number {num_classes} with {label}") # model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None) # state_dict = torch.load(model_restore)['state_dict'] # from collections import OrderedDict # new_state_dict = OrderedDict() # for k, v in state_dict.items(): # name = k[7:] # remove `module.` # new_state_dict[name] = v # model.load_state_dict(new_state_dict) # model.cuda() # model.eval() # transform = transforms.Compose([ # transforms.ToTensor(), # transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229]) # ]) # # ---- 파일 리스트 만들기 (단일 파일/폴더 모두 대응) ---- # if input_path: # # root는 파일의 부모 디렉터리, file_list는 파일명 1개 # root = os.path.dirname(input_path) # file_list = [os.path.basename(input_path)] # else: # root = input_dir # file_list = sorted([ # f for f in os.listdir(root) # if f.lower().endswith(('.png', '.jpg', '.jpeg')) # ]) # dataset_obj = 
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author  :   Peike Li
@Contact :   peike.li@yahoo.com
@File    :   simple_extractor.py
@Desc    :   Simple Extractor (category-aware palette selection)
"""

import os
import sys
import argparse
from collections import OrderedDict

import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

# Make the sibling project modules importable when this file is run directly.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
if _THIS_DIR not in sys.path:
    sys.path.insert(0, _THIS_DIR)

import networks
from utils.transforms import transform_logits
from datasets.simple_extractor_dataset import SimpleFolderDataset

# Per-dataset network input size, class count and class names.
dataset_settings = {
    'lip': {
        'input_size': [473, 473],
        'num_classes': 20,
        'label': [
            'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes',
            'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt',
            'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg',
            'Left-shoe', 'Right-shoe'
        ]
    },
    'atr': {
        'input_size': [512, 512],
        'num_classes': 18,
        'label': [
            'Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt',
            'Pants', 'Dress', 'Belt', 'Left-shoe', 'Right-shoe', 'Face',
            'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf'
        ]
    },
    'pascal': {
        'input_size': [512, 512],
        'num_classes': 7,
        'label': [
            'Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms',
            'Upper Legs', 'Lower Legs'
        ],
    }
}


def get_arguments():
    """Parse CLI arguments.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
    parser.add_argument("--dataset", type=str, default='atr', choices=['lip', 'atr', 'pascal'])
    parser.add_argument("--model-restore", type=str, default='', help="restore pretrained model parameters.")
    parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
    parser.add_argument("--category", type=str, default='Upper-cloth', help="category name.")
    parser.add_argument("--input-dir", type=str, default='', help="path of input image folder.")
    parser.add_argument("--output-dir", type=str, default='', help="(unused, kept for CLI compatibility)")
    parser.add_argument("--logits", action='store_true', default=False)
    return parser.parse_args()


def get_palette(num_cls):
    """Return a flat RGB palette (18 slots) rendering class index
    ``num_cls`` as white and every other class as black.

    Args:
        num_cls: Class index to highlight; must be in 0..17
            (0 yields an all-black palette).

    Returns:
        A list of 54 ints suitable for ``PIL.Image.putpalette``.
    """
    n = 18
    palette = [0] * (n * 3)
    # The original bit-shift loop rewrote the same three slots on every
    # iteration, so for any non-zero index the net effect is simply "white".
    if num_cls:
        palette[num_cls * 3 + 0] = 255
        palette[num_cls * 3 + 1] = 255
        palette[num_cls * 3 + 2] = 255
    return palette


def get_palette2(num_cls):
    """Return a flat RGB palette rendering class indices 5 and 6
    (ATR 'Skirt' and 'Pants') as white and everything else as black.

    Args:
        num_cls: Unused; kept for signature compatibility with get_palette.

    Returns:
        A list of 54 ints suitable for ``PIL.Image.putpalette``.
    """
    n = 18
    palette = [0] * (n * 3)
    # As above: the original inner while-loop always resolved to "white"
    # for j in {5, 6}, so the direct assignment is equivalent.
    for j in range(5, 7):
        palette[j * 3 + 0] = 255
        palette[j * 3 + 1] = 255
        palette[j * 3 + 2] = 255
    return palette


def _select_palette_by_category(category: str):
    """Map a garment category name to a binary visualization palette.

    Explicit rules (indices refer to the ATR label list):
      * "Upper-cloth" -> highlight class 4 ('Upper-clothes')
      * "Bottom"      -> highlight classes 5/6 ('Skirt'/'Pants')
      * "Dress"       -> highlight class 7 ('Dress')
      * anything else falls back to the 'Dress' palette
    """
    if category == "Upper-cloth":
        return get_palette(4)
    if category == "Bottom":
        return get_palette2(4)
    if category == "Dress":
        return get_palette(7)
    # Fallback for categories without an explicit rule.
    return get_palette(7)


def run(
    *,
    category: str,
    input_path: str = "",
    input_dir: str = "",
    dataset: str = "atr",
    model_restore: str = "",
    gpu: str = "0",
    logits: bool = False,
):
    """Run human-parsing inference and return the results in memory.

    Exactly one of ``input_path`` (single image file) or ``input_dir``
    (folder of images) must be provided. Nothing is written to disk.

    Args:
        category: Garment category used to pick the visualization palette.
        input_path: Path to a single input image.
        input_dir: Path to a folder of input images (.png/.jpg/.jpeg).
        dataset: One of 'lip', 'atr', 'pascal'.
        model_restore: Checkpoint path; when empty the extractor is skipped.
        gpu: Single GPU id as a string (e.g. "0"), or 'None' to leave
            CUDA_VISIBLE_DEVICES untouched.
        logits: When True, also return the per-pixel logits.

    Returns:
        {
            "images": List[PIL.Image],            # palette-mapped parsing masks
            "logits": Optional[List[np.ndarray]], # only when logits=True
            "names":  List[str],                  # input file names
        }

    Raises:
        ValueError: if neither or both of input_path/input_dir are given.
        FileNotFoundError / NotADirectoryError: if the given input is missing.
    """
    # Only a single GPU is supported.
    gpus = [int(i) for i in gpu.split(',')]
    assert len(gpus) == 1
    if gpu != 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    # Graceful no-op when no checkpoint is configured.
    if not model_restore:
        print("[simple_extractor] model_restore not provided → skip extractor.")
        return {"images": [], "logits": [] if logits else None, "names": []}

    # Exactly one input source must be given.
    if bool(input_path) == bool(input_dir):
        raise ValueError("Provide exactly one of input_path or input_dir.")
    if input_path and not os.path.isfile(input_path):
        raise FileNotFoundError(input_path)
    if input_dir and not os.path.isdir(input_dir):
        raise NotADirectoryError(input_dir)

    num_classes = dataset_settings[dataset]['num_classes']
    input_size = dataset_settings[dataset]['input_size']

    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)
    # map_location='cpu' lets checkpoints saved on any device be loaded;
    # the model is moved to the GPU right afterwards.
    state_dict = torch.load(model_restore, map_location='cpu')['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Strip the 'module.' prefix only when actually present — checkpoints
        # saved without DataParallel have bare keys and must not be truncated.
        new_state_dict[k[7:] if k.startswith('module.') else k] = v
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()

    # NOTE(review): channel order/statistics follow the upstream SCHP
    # convention (BGR-ordered mean/std) — keep in sync with training.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])

    # Build the file list for either a single file or a whole folder.
    if input_path:
        root = os.path.dirname(input_path)
        file_list = [os.path.basename(input_path)]
    else:
        root = input_dir
        file_list = sorted(
            f for f in os.listdir(root)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    dataset_obj = SimpleFolderDataset(
        root=root,
        input_size=input_size,
        transform=transform,
        file_list=file_list
    )
    dataloader = DataLoader(dataset_obj)

    # Category-driven palette selection.
    palette = _select_palette_by_category(category)

    # Loop-invariant: build the upsampler once instead of per image.
    upsample = torch.nn.Upsample(
        size=input_size, mode='bilinear', align_corners=True
    )

    results_img = []
    results_logits = [] if logits else None
    names = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            image, meta = batch
            img_name = meta['name'][0]
            names.append(img_name)
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = model(image.cuda())
            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
            upsample_output = upsample_output.squeeze().permute(1, 2, 0)

            # Undo the dataset's affine crop so logits align with the
            # original image geometry.
            logits_result = transform_logits(
                upsample_output.data.cpu().numpy(),
                c, s, w, h,
                input_size=input_size
            )
            parsing_result = np.argmax(logits_result, axis=2)
            out_img = Image.fromarray(parsing_result.astype(np.uint8))
            out_img.putpalette(palette)
            results_img.append(out_img)
            if logits:
                results_logits.append(logits_result)

    return {
        "images": results_img,
        "logits": results_logits,
        "names": names
    }


def main():
    """CLI entry point; mirrors the original script interface."""
    args = get_arguments()
    run(
        category=args.category,
        input_dir=args.input_dir,
        dataset=args.dataset,
        model_restore=args.model_restore,
        gpu=args.gpu,
        logits=args.logits,
    )


if __name__ == '__main__':
    main()