#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
@Author  :   Peike Li
@Contact :   peike.li@yahoo.com
@File    :   simple_extractor.py
@Time    :   8/30/19 8:59 PM
@Desc    :   Simple Extractor
@License :   This source code is licensed under the license found in the
             LICENSE file in the root directory of this source tree.
"""

import argparse
import os
import sys
from collections import OrderedDict

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm

# Make sibling modules importable when this file is imported/run from another
# working directory (e.g. called from a web service entry point).
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))  # .../DEMO/preprocess
if _THIS_DIR not in sys.path:
    sys.path.insert(0, _THIS_DIR)

import networks
from utils.transforms import transform_logits
from datasets.simple_extractor_dataset import SimpleFolderDataset

# Per-dataset inference configuration: network input resolution, number of
# parsing classes, and the human-readable label for each class index.
dataset_settings = {
    'lip': {
        'input_size': [473, 473],
        'num_classes': 20,
        'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes',
                  'Dress', 'Coat', 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt',
                  'Face', 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg',
                  'Left-shoe', 'Right-shoe']
    },
    'atr': {
        'input_size': [512, 512],
        'num_classes': 18,
        'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt',
                  'Pants', 'Dress', 'Belt', 'Left-shoe', 'Right-shoe', 'Face',
                  'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
    },
    'pascal': {
        'input_size': [512, 512],
        'num_classes': 7,
        'label': ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms',
                  'Upper Legs', 'Lower Legs'],
    }
}


def get_arguments():
    """Parse all the arguments provided from the CLI.

    Returns:
        argparse.Namespace with the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")

    parser.add_argument("--dataset", type=str, default='atr',
                        choices=['lip', 'atr', 'pascal'])
    parser.add_argument("--model-restore", type=str, default='',
                        help="restore pretrained model parameters.")
    parser.add_argument("--gpu", type=str, default='0',
                        help="choose gpu device.")
    parser.add_argument("--category", type=str, default='Upper-clothes',
                        help="category name (optional).")
    parser.add_argument("--input-dir", type=str, default='',
                        help="path of input image folder.")
    parser.add_argument("--output-dir", type=str, default='',
                        help="path of output image folder.")
    parser.add_argument("--logits", action='store_true', default=False,
                        help="whether to save the logits.")

    return parser.parse_args()


def get_palette(num_cls):
    """Return a flat [R, G, B, R, G, B, ...] palette for ``num_cls`` labels.

    Uses the standard PASCAL VOC bit-interleaving scheme, so every label id
    maps to a distinct, stable color.

    Args:
        num_cls: Number of classes (palette entries).

    Returns:
        list[int] of length ``num_cls * 3`` suitable for ``Image.putpalette``.
    """
    # BUGFIX: the previous version used a fixed n=18, wrote only at index
    # num_cls, and exited after a single loop iteration, yielding an
    # (almost) all-black palette. Restored the canonical VOC algorithm.
    palette = [0] * (num_cls * 3)
    for j in range(num_cls):
        lab = j
        i = 0
        while lab:
            # Spread the 3 lowest bits of `lab` across the R/G/B channels,
            # from the most significant bit position downward.
            palette[j * 3 + 0] |= ((lab >> 0) & 1) << (7 - i)
            palette[j * 3 + 1] |= ((lab >> 1) & 1) << (7 - i)
            palette[j * 3 + 2] |= ((lab >> 2) & 1) << (7 - i)
            i += 1
            lab >>= 3
    return palette


def run(
    *,
    category: str,
    input_path: str = "",
    input_dir: str = "",
    dataset: str = "atr",
    model_restore: str = "",
    gpu: str = "0",
    logits: bool = False,
):
    """Run human-parsing inference and return the results in memory.

    Entry point intended to be imported and called from other Python code.
    Accepts exactly one of ``input_path`` (a single image file) or
    ``input_dir`` (a folder of images). Nothing is written to disk.

    Args:
        category: Category name, kept for API compatibility (unused here).
        input_path: Path to one input image (.png/.jpg/.jpeg).
        input_dir: Path to a folder of input images.
        dataset: One of 'lip', 'atr', 'pascal'; selects classes/input size.
        model_restore: Checkpoint path. When empty the extractor is skipped
            and an empty result dict is returned.
        gpu: CUDA device id string; exactly one device must be named.
        logits: When True, the raw per-pixel logits are also returned.

    Returns:
        dict with keys:
            "images": list[PIL.Image] — parsing masks with palette applied
            "logits": list[np.ndarray] | None — raw logits when requested
            "names":  list[str] — corresponding input file names

    Raises:
        ValueError: bad GPU spec, or not exactly one of input_path/input_dir.
        FileNotFoundError / NotADirectoryError: missing input path.
    """
    # Only a single GPU is supported.
    gpus = [int(i) for i in gpu.split(',')]
    if len(gpus) != 1:
        raise ValueError(f"Exactly one GPU must be specified, got: {gpu!r}")
    if gpu != 'None':
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    if not model_restore:
        print("[simple_extractor] model_restore not provided → skip extractor.")
        return {"images": [], "logits": [] if logits else None, "names": []}

    # Input validation: exactly one of input_path / input_dir is required.
    if bool(input_path) == bool(input_dir):
        raise ValueError("Provide exactly one of input_path or input_dir.")
    if input_path and not os.path.isfile(input_path):
        raise FileNotFoundError(f"input_path not found or not a file: {input_path}")
    if input_dir and not os.path.isdir(input_dir):
        raise NotADirectoryError(f"input_dir not found or not a directory: {input_dir}")

    settings = dataset_settings[dataset]
    num_classes = settings['num_classes']
    input_size = settings['input_size']
    label = settings['label']
    print(f"Evaluating total class number {num_classes} with {label}")

    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)

    state_dict = torch.load(model_restore)['state_dict']
    # Checkpoints saved through DataParallel prefix every key with "module.".
    # Strip the prefix only when it is present so plain checkpoints also load.
    new_state_dict = OrderedDict(
        (k[7:] if k.startswith('module.') else k, v) for k, v in state_dict.items()
    )
    model.load_state_dict(new_state_dict)
    model.cuda()
    model.eval()

    # NOTE(review): channel order looks BGR (0.406 first) — presumably the
    # dataset loads images via OpenCV; keep the original statistics.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])

    # ---- Build the file list (single file and folder inputs both work) ----
    if input_path:
        # root is the file's parent directory; file_list holds one name.
        root = os.path.dirname(input_path)
        file_list = [os.path.basename(input_path)]
    else:
        root = input_dir
        file_list = sorted(
            f for f in os.listdir(root)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    dataset_obj = SimpleFolderDataset(
        root=root,
        input_size=input_size,
        transform=transform,
        file_list=file_list
    )
    dataloader = DataLoader(dataset_obj)

    # One palette entry per class so every label id renders a distinct color.
    # (Previously hard-coded to get_palette(4), which — combined with the
    # broken get_palette — left virtually every class black.)
    palette = get_palette(num_classes)

    # Hoisted out of the loop: the upsampler only depends on input_size.
    upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)

    results_img = []
    results_logits = [] if logits else None
    names = []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            image, meta = batch
            img_name = meta['name'][0]
            names.append(img_name)
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = model(image.cuda())
            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
            upsample_output = upsample_output.squeeze().permute(1, 2, 0)  # CHW -> HWC

            # Undo the crop/scale applied by the dataset so the logits align
            # with the original image geometry.
            logits_result = transform_logits(
                upsample_output.data.cpu().numpy(),
                c, s, w, h,
                input_size=input_size
            )
            parsing_result = np.argmax(logits_result, axis=2)

            out_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            out_img.putpalette(palette)
            results_img.append(out_img)

            if logits:
                results_logits.append(logits_result)

    return {"images": results_img, "logits": results_logits, "names": names}


def main():
    """CLI entry point kept for backward compatibility with the old script."""
    args = get_arguments()
    # BUGFIX: the previous main() passed output_dir=..., which the new run()
    # signature does not accept → TypeError on every CLI invocation.
    result = run(
        category=args.category,
        input_dir=args.input_dir,
        dataset=args.dataset,
        model_restore=args.model_restore,
        gpu=args.gpu,
        logits=args.logits,
    )

    # Persist results when --output-dir is given (restores the old CLI
    # behavior of writing one palette PNG — and optionally one .npy — per
    # input image). splitext handles .jpeg correctly, unlike name[:-4].
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
        for name, img in zip(result["names"], result["images"]):
            stem = os.path.splitext(name)[0]
            img.save(os.path.join(args.output_dir, stem + '.png'))
        if args.logits and result["logits"] is not None:
            for name, logit in zip(result["names"], result["logits"]):
                stem = os.path.splitext(name)[0]
                np.save(os.path.join(args.output_dir, stem + '.npy'), logit)


if __name__ == '__main__':
    main()