VISTA / preprocess /simple_extractor.py
ssoxye's picture
Clean Space repo (code only) + gradio app
689a987
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Peike Li
@Contact : peike.li@yahoo.com
@File : simple_extractor.py
@Time : 8/30/19 8:59 PM
@Desc : Simple Extractor
@License : This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import os
import torch
import argparse
import numpy as np
from PIL import Image
from tqdm import tqdm
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import os
import sys
# Directory holding this file (.../preprocess). It is prepended to sys.path so
# the sibling packages imported right after this (networks, utils, datasets)
# resolve no matter what the current working directory is.
_THIS_DIR = os.path.split(os.path.abspath(__file__))[0]
if _THIS_DIR not in sys.path:
    # insert at position 0 so these local packages are searched first
    sys.path.insert(0, _THIS_DIR)
import networks
from utils.transforms import transform_logits
from datasets.simple_extractor_dataset import SimpleFolderDataset
# Per-dataset configuration, keyed by the --dataset / `dataset` value.
#   input_size  : size passed to SimpleFolderDataset and used as the
#                 Upsample target in run().
#   num_classes : number of output classes of the parsing network
#                 (equals len(label) for every entry).
#   label       : class names, indexed by the predicted class id.
dataset_settings = dict(
    lip=dict(
        input_size=[473, 473],
        num_classes=20,
        label=[
            'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat',
            'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm',
            'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe',
        ],
    ),
    atr=dict(
        input_size=[512, 512],
        num_classes=18,
        label=[
            'Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
            'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf',
        ],
    ),
    pascal=dict(
        input_size=[512, 512],
        num_classes=7,
        label=['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'],
    ),
)
def get_arguments(argv=None):
    """Parse the command-line arguments of the extractor CLI.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads ``sys.argv[1:]`` as before; passing a
            list allows programmatic use and testing.

    Returns:
        argparse.Namespace with the attributes dataset, model_restore, gpu,
        category, input_path, input_dir, output_dir and logits.
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
    parser.add_argument("--dataset", type=str, default='atr', choices=['lip', 'atr', 'pascal'])
    parser.add_argument("--model-restore", type=str, default='', help="restore pretrained model parameters.")
    parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
    parser.add_argument("--category", type=str, default='Upper-clothes', help="category name (optional).")
    # Single-image alternative to --input-dir; run() accepts exactly one of the two.
    parser.add_argument("--input-path", type=str, default='',
                        help="path of a single input image (alternative to --input-dir).")
    parser.add_argument("--input-dir", type=str, default='', help="path of input image folder.")
    parser.add_argument("--output-dir", type=str, default='', help="path of output image folder.")
    parser.add_argument("--logits", action='store_true', default=False, help="whether to save the logits.")
    return parser.parse_args(argv)
def get_palette(num_cls, num_entries=18):
    """Build a flat RGB palette that paints exactly one class white.

    Every palette entry is black ``(0, 0, 0)`` except entry ``num_cls``,
    which is white ``(255, 255, 255)``. Applied with ``Image.putpalette`` to
    a label-index image, this renders a binary mask of that single class.
    Class 0 is never highlighted: ``get_palette(0)`` is all black (this
    matches the previous implementation).

    Args:
        num_cls: Index of the class to highlight.
        num_entries: Number of RGB entries in the palette. Defaults to 18,
            the size previously hard-coded here.

    Returns:
        list[int] of length ``3 * num_entries``.

    Raises:
        ValueError: if ``num_cls`` is outside ``[0, num_entries)``. (Before,
            an index >= 18 raised IndexError and a negative index looped
            forever.)
    """
    if not 0 <= num_cls < num_entries:
        raise ValueError(
            f"num_cls must be in [0, {num_entries}), got {num_cls!r}"
        )
    palette = [0] * (num_entries * 3)
    if num_cls:
        start = num_cls * 3
        palette[start:start + 3] = [255, 255, 255]
    return palette
# def run(
# *,
# category: str,
# input_dir: str,
# output_dir: str,
# dataset: str = "atr",
# model_restore: str = "",
# gpu: str = "0",
# logits: bool = False,
# ):
# """
# ✅ 외부(다른 파이썬 코드)에서 import 해서 호출하기 위한 엔트리 함수.
# - 기존 main()의 내용을 거의 그대로 옮김
# - CLI 인자 대신 파라미터로 받음
# """
# # (원 코드 유지) single GPU만 허용
# gpus = [int(i) for i in gpu.split(',')]
# assert len(gpus) == 1
# if gpu != 'None':
# os.environ["CUDA_VISIBLE_DEVICES"] = gpu
# num_classes = dataset_settings[dataset]['num_classes']
# input_size = dataset_settings[dataset]['input_size']
# label = dataset_settings[dataset]['label']
# print("Evaluating total class number {} with {}".format(num_classes, label))
# model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)
# if not model_restore:
# print("[simple_extractor] model_restore not provided → skip extractor.")
# return False
# state_dict = torch.load(model_restore)['state_dict']
# # print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ args.model_restore: ", state_dict)
# from collections import OrderedDict
# new_state_dict = OrderedDict()
# for k, v in state_dict.items():
# name = k[7:] # remove `module.`
# new_state_dict[name] = v
# model.load_state_dict(new_state_dict)
# model.cuda()
# model.eval()
# transform = transforms.Compose([
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
# ])
# # -----------------------------
# # 입력 폴더 이미지 로드
# # -----------------------------
# if not input_dir:
# raise ValueError("--input-dir (input_dir) is required.")
# if not output_dir:
# raise ValueError("--output-dir (output_dir) is required.")
# all_files = sorted([f for f in os.listdir(input_dir)
# if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
# selected_files = all_files[:]
# print(f"Total images found: {len(all_files)} → Using first {len(selected_files)} images")
# dataset_obj = SimpleFolderDataset(
# root=input_dir,
# input_size=input_size,
# transform=transform,
# file_list=selected_files
# )
# dataloader = DataLoader(dataset_obj)
# os.makedirs(output_dir, exist_ok=True)
# # NOTE: 기존 코드가 palette = get_palette(4)로 고정인데,
# # 지금도 그대로 유지 (필요하면 category 기반으로 바꾸는 것도 가능)
# palette = get_palette(4)
# with torch.no_grad():
# for idx, batch in enumerate(tqdm(dataloader)):
# print("--: ", idx)
# image, meta = batch
# img_name = meta['name'][0]
# c = meta['center'].numpy()[0]
# s = meta['scale'].numpy()[0]
# w = meta['width'].numpy()[0]
# h = meta['height'].numpy()[0]
# output = model(image.cuda())
# upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
# upsample_output = upsample(output[0][-1][0].unsqueeze(0))
# upsample_output = upsample_output.squeeze()
# upsample_output = upsample_output.permute(1, 2, 0) # CHW -> HWC
# logits_result = transform_logits(
# upsample_output.data.cpu().numpy(),
# c, s, w, h,
# input_size=input_size
# )
# parsing_result = np.argmax(logits_result, axis=2)
# parsing_result_path = os.path.join(output_dir, img_name[:-4] + '.png')
# output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
# output_img.putpalette(palette)
# output_img.save(parsing_result_path)
# if logits:
# logits_result_path = os.path.join(output_dir, img_name[:-4] + '.npy')
# np.save(logits_result_path, logits_result)
# return
def _collect_inputs(input_path, input_dir):
    """Validate the input arguments and return ``(root, file_names)``.

    Exactly one of ``input_path`` / ``input_dir`` must be non-empty. For a
    single file, ``root`` is its parent directory and ``file_names`` holds its
    base name. For a folder, ``file_names`` is the sorted list of entries
    ending in .png/.jpg/.jpeg (case-insensitive).
    """
    if bool(input_path) == bool(input_dir):
        raise ValueError("Provide exactly one of input_path or input_dir.")
    if input_path:
        if not os.path.isfile(input_path):
            raise FileNotFoundError(f"input_path not found or not a file: {input_path}")
        return os.path.dirname(input_path), [os.path.basename(input_path)]
    if not os.path.isdir(input_dir):
        raise NotADirectoryError(f"input_dir not found or not a directory: {input_dir}")
    file_names = sorted(
        f for f in os.listdir(input_dir)
        if f.lower().endswith(('.png', '.jpg', '.jpeg'))
    )
    return input_dir, file_names


def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys.

    The prefix is typically added when weights are saved from a
    ``torch.nn.DataParallel`` wrapper. Keys without the prefix are kept as-is
    instead of having their first 7 characters chopped off.
    """
    prefix = 'module.'
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


def run(
    *,
    category: str,
    input_path: str = "",
    input_dir: str = "",
    dataset: str = "atr",
    model_restore: str = "",
    gpu: str = "0",
    logits: bool = False,
):
    """Run human parsing on one image or a folder and return the results in memory.

    Exactly one of ``input_path`` (a single image file) or ``input_dir`` (a
    folder of .png/.jpg/.jpeg images) must be given. Nothing is written to disk.

    Args:
        category: Accepted for API compatibility but currently unused. The
            output palette always highlights class index 4, which is
            'Upper-clothes' in the 'atr' label list.
            NOTE(review): for 'lip', index 4 is 'Sunglasses'.
        input_path: Path of a single image to parse.
        input_dir: Folder whose images are parsed in sorted file-name order.
        dataset: Key into ``dataset_settings`` ('lip', 'atr' or 'pascal').
        model_restore: Checkpoint path whose ``'state_dict'`` entry holds the
            weights. If empty, the extractor is skipped and empty results
            are returned.
        gpu: A single integer GPU id, written to ``CUDA_VISIBLE_DEVICES``, or
            the string 'None' to leave that variable untouched.
        logits: Also collect the arrays produced by ``transform_logits``.

    Returns:
        dict with keys
            "images": list of PIL images (argmax label masks with the
                      ``get_palette`` palette applied),
            "logits": list of numpy arrays if ``logits`` else None,
            "names":  list of file names in processing order.

    Raises:
        ValueError: if ``gpu`` is not one integer id (or 'None'), or if not
            exactly one of ``input_path`` / ``input_dir`` is given.
        FileNotFoundError: if ``input_path`` is not an existing file.
        NotADirectoryError: if ``input_dir`` is not an existing directory.
    """
    # Single GPU only. 'None' must be handled before integer parsing; the old
    # code ran int('None') first, so that branch could never be reached.
    if gpu != 'None':
        gpu_ids = [int(i) for i in gpu.split(',')]
        if len(gpu_ids) != 1:
            raise ValueError(f"Only a single GPU id is supported, got gpu={gpu!r}")
        # NOTE(review): presumably only takes effect if CUDA has not been
        # initialised yet in this process.
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    if not model_restore:
        print("[simple_extractor] model_restore not provided → skip extractor.")
        return {"images": [], "logits": [] if logits else None, "names": []}

    # Validate inputs before paying for model construction/loading.
    root, file_list = _collect_inputs(input_path, input_dir)

    settings = dataset_settings[dataset]
    num_classes = settings['num_classes']
    input_size = settings['input_size']
    label = settings['label']
    print(f"Evaluating total class number {num_classes} with {label}")

    model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)
    # Load onto CPU first so checkpoints saved on any CUDA device can be
    # restored; the model is moved to the GPU right after.
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(model_restore, map_location='cpu')
    model.load_state_dict(_strip_module_prefix(checkpoint['state_dict']))
    model.cuda()
    model.eval()

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
    ])
    dataset_obj = SimpleFolderDataset(
        root=root,
        input_size=input_size,
        transform=transform,
        file_list=file_list
    )
    dataloader = DataLoader(dataset_obj)

    # `category` is not consulted: class 4 is always painted white.
    palette = get_palette(4)
    # Loop-invariant, so build it once instead of once per image.
    upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)

    results_img = []
    results_logits = [] if logits else None
    names = []
    with torch.no_grad():
        for image, meta in tqdm(dataloader):
            img_name = meta['name'][0]
            names.append(img_name)
            c = meta['center'].numpy()[0]
            s = meta['scale'].numpy()[0]
            w = meta['width'].numpy()[0]
            h = meta['height'].numpy()[0]

            output = model(image.cuda())
            upsample_output = upsample(output[0][-1][0].unsqueeze(0))
            upsample_output = upsample_output.squeeze().permute(1, 2, 0)  # CHW -> HWC
            logits_result = transform_logits(
                upsample_output.data.cpu().numpy(),
                c, s, w, h,
                input_size=input_size
            )
            parsing_result = np.argmax(logits_result, axis=2)
            out_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
            out_img.putpalette(palette)
            results_img.append(out_img)
            if logits:
                results_logits.append(logits_result)

    return {"images": results_img, "logits": results_logits, "names": names}
def main():
    """CLI entry point: parse arguments, run the extractor and save its outputs.

    Keeps the command-line workflow working now that ``run()`` returns results
    in memory. Each mask is written to ``<output-dir>/<stem>.png``. With
    ``--logits``, the logits are also written to ``<output-dir>/<stem>.npy``.

    Raises:
        ValueError: if --output-dir is not given (plus anything ``run`` raises).
    """
    args = get_arguments()
    if not args.output_dir:
        raise ValueError("--output-dir (output_dir) is required.")

    # The old call passed output_dir=..., which run() does not accept, so
    # every CLI invocation failed with TypeError; forward the real options.
    result = run(
        category=args.category,
        # getattr keeps this working if the parser defines no --input-path.
        input_path=getattr(args, 'input_path', ''),
        input_dir=args.input_dir,
        dataset=args.dataset,
        model_restore=args.model_restore,
        gpu=args.gpu,
        logits=args.logits,
    )

    os.makedirs(args.output_dir, exist_ok=True)
    logits_list = result["logits"]
    for idx, (name, mask) in enumerate(zip(result["names"], result["images"])):
        # splitext handles any extension length ('.jpeg' included), unlike name[:-4].
        stem = os.path.splitext(name)[0]
        mask.save(os.path.join(args.output_dir, stem + '.png'))
        if logits_list is not None:
            np.save(os.path.join(args.output_dir, stem + '.npy'), logits_list[idx])


if __name__ == '__main__':
    main()