ssoxye commited on
Commit
64933a5
·
1 Parent(s): 7a4d0c5

update preprocess

Browse files
Files changed (1) hide show
  1. preprocess/simple_extractor.py +340 -169
preprocess/simple_extractor.py CHANGED
@@ -1,3 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  #!/usr/bin/env python
2
  # -*- encoding: utf-8 -*-
3
 
@@ -5,13 +265,11 @@
5
  @Author : Peike Li
6
  @Contact : peike.li@yahoo.com
7
  @File : simple_extractor.py
8
- @Time : 8/30/19 8:59 PM
9
- @Desc : Simple Extractor
10
- @License : This source code is licensed under the license found in the
11
- LICENSE file in the root directory of this source tree.
12
  """
13
 
14
  import os
 
15
  import torch
16
  import argparse
17
  import numpy as np
@@ -21,56 +279,60 @@ from tqdm import tqdm
21
  from torch.utils.data import DataLoader
22
  import torchvision.transforms as transforms
23
 
24
- import os
25
- import sys
26
-
27
- _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) # .../DEMO/preprocess
28
  if _THIS_DIR not in sys.path:
29
  sys.path.insert(0, _THIS_DIR)
30
 
31
-
32
  import networks
33
  from utils.transforms import transform_logits
34
  from datasets.simple_extractor_dataset import SimpleFolderDataset
35
 
36
 
37
-
38
  dataset_settings = {
39
  'lip': {
40
  'input_size': [473, 473],
41
  'num_classes': 20,
42
- 'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat',
43
- 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm',
44
- 'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe']
 
 
 
 
45
  },
46
  'atr': {
47
  'input_size': [512, 512],
48
  'num_classes': 18,
49
- 'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
50
- 'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
 
 
 
 
 
51
  },
52
  'pascal': {
53
  'input_size': [512, 512],
54
  'num_classes': 7,
55
- 'label': ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'],
 
 
 
 
56
  }
57
  }
58
 
59
 
60
  def get_arguments():
61
- """Parse all the arguments provided from the CLI.
62
- Returns:
63
- A list of parsed arguments.
64
- """
65
  parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
66
 
67
  parser.add_argument("--dataset", type=str, default='atr', choices=['lip', 'atr', 'pascal'])
68
  parser.add_argument("--model-restore", type=str, default='', help="restore pretrained model parameters.")
69
  parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
70
- parser.add_argument("--category", type=str, default='Upper-clothes', help="category name (optional).")
71
  parser.add_argument("--input-dir", type=str, default='', help="path of input image folder.")
72
- parser.add_argument("--output-dir", type=str, default='', help="path of output image folder.")
73
- parser.add_argument("--logits", action='store_true', default=False, help="whether to save the logits.")
74
 
75
  return parser.parse_args()
76
 
@@ -83,126 +345,43 @@ def get_palette(num_cls):
83
  palette[j * 3 + 0] = 0
84
  palette[j * 3 + 1] = 0
85
  palette[j * 3 + 2] = 0
86
- i = 0
87
  while lab:
88
  palette[j * 3 + 0] = 255
89
  palette[j * 3 + 1] = 255
90
  palette[j * 3 + 2] = 255
91
- i += 1
92
  lab >>= 3
93
  return palette
94
 
95
 
96
- # def run(
97
- # *,
98
- # category: str,
99
- # input_dir: str,
100
- # output_dir: str,
101
- # dataset: str = "atr",
102
- # model_restore: str = "",
103
- # gpu: str = "0",
104
- # logits: bool = False,
105
- # ):
106
- # """
107
- # ✅ 외부(다른 파이썬 코드)에서 import 해서 호출하기 위한 엔트리 함수.
108
- # - 기존 main()의 내용을 거의 그대로 옮김
109
- # - CLI 인자 대신 파라미터로 받음
110
- # """
111
- # # (원 코드 유지) single GPU만 허용
112
- # gpus = [int(i) for i in gpu.split(',')]
113
- # assert len(gpus) == 1
114
- # if gpu != 'None':
115
- # os.environ["CUDA_VISIBLE_DEVICES"] = gpu
116
-
117
- # num_classes = dataset_settings[dataset]['num_classes']
118
- # input_size = dataset_settings[dataset]['input_size']
119
- # label = dataset_settings[dataset]['label']
120
- # print("Evaluating total class number {} with {}".format(num_classes, label))
121
-
122
- # model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)
123
-
124
- # if not model_restore:
125
- # print("[simple_extractor] model_restore not provided → skip extractor.")
126
- # return False
127
-
128
-
129
- # state_dict = torch.load(model_restore)['state_dict']
130
-
131
- # # print("@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@ args.model_restore: ", state_dict)
132
- # from collections import OrderedDict
133
- # new_state_dict = OrderedDict()
134
- # for k, v in state_dict.items():
135
- # name = k[7:] # remove `module.`
136
- # new_state_dict[name] = v
137
- # model.load_state_dict(new_state_dict)
138
- # model.cuda()
139
- # model.eval()
140
-
141
- # transform = transforms.Compose([
142
- # transforms.ToTensor(),
143
- # transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
144
- # ])
145
-
146
- # # -----------------------------
147
- # # 입력 폴더 이미지 로드
148
- # # -----------------------------
149
- # if not input_dir:
150
- # raise ValueError("--input-dir (input_dir) is required.")
151
- # if not output_dir:
152
- # raise ValueError("--output-dir (output_dir) is required.")
153
-
154
- # all_files = sorted([f for f in os.listdir(input_dir)
155
- # if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
156
- # selected_files = all_files[:]
157
- # print(f"Total images found: {len(all_files)} → Using first {len(selected_files)} images")
158
-
159
- # dataset_obj = SimpleFolderDataset(
160
- # root=input_dir,
161
- # input_size=input_size,
162
- # transform=transform,
163
- # file_list=selected_files
164
- # )
165
- # dataloader = DataLoader(dataset_obj)
166
-
167
- # os.makedirs(output_dir, exist_ok=True)
168
-
169
- # # NOTE: 기존 코드가 palette = get_palette(4)로 고정인데,
170
- # # 지금도 그대로 유지 (필요하면 category 기반으로 바꾸는 것도 가능)
171
- # palette = get_palette(4)
172
-
173
- # with torch.no_grad():
174
- # for idx, batch in enumerate(tqdm(dataloader)):
175
- # print("--: ", idx)
176
- # image, meta = batch
177
- # img_name = meta['name'][0]
178
- # c = meta['center'].numpy()[0]
179
- # s = meta['scale'].numpy()[0]
180
- # w = meta['width'].numpy()[0]
181
- # h = meta['height'].numpy()[0]
182
-
183
- # output = model(image.cuda())
184
- # upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
185
- # upsample_output = upsample(output[0][-1][0].unsqueeze(0))
186
- # upsample_output = upsample_output.squeeze()
187
- # upsample_output = upsample_output.permute(1, 2, 0) # CHW -> HWC
188
-
189
- # logits_result = transform_logits(
190
- # upsample_output.data.cpu().numpy(),
191
- # c, s, w, h,
192
- # input_size=input_size
193
- # )
194
- # parsing_result = np.argmax(logits_result, axis=2)
195
-
196
- # parsing_result_path = os.path.join(output_dir, img_name[:-4] + '.png')
197
- # output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
198
- # output_img.putpalette(palette)
199
- # output_img.save(parsing_result_path)
200
 
201
- # if logits:
202
- # logits_result_path = os.path.join(output_dir, img_name[:-4] + '.npy')
203
- # np.save(logits_result_path, logits_result)
204
 
205
- # return
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
 
208
  def run(
@@ -216,17 +395,14 @@ def run(
216
  logits: bool = False,
217
  ):
218
  """
219
- - input_path (단일 파일) 또는 input_dir(폴더) 중 하나를 받아 parsing 결과를 메모리로 반환.
220
- - 파일 저장 없음.
221
-
222
  Returns:
223
  {
224
- "images": List[PIL.Image], # parsing mask (palette 적용됨)
225
  "logits": Optional[List[np.ndarray]],
226
- "names": List[str], # 파일명들
227
  }
228
  """
229
- # single GPU만 허용
230
  gpus = [int(i) for i in gpu.split(',')]
231
  assert len(gpus) == 1
232
  if gpu != 'None':
@@ -236,46 +412,36 @@ def run(
236
  print("[simple_extractor] model_restore not provided → skip extractor.")
237
  return {"images": [], "logits": [] if logits else None, "names": []}
238
 
239
- # 입력 검증: 둘 중 하나는 있어야 함
240
  if bool(input_path) == bool(input_dir):
241
  raise ValueError("Provide exactly one of input_path or input_dir.")
242
 
243
- # 파일이면 존재 확인
244
- if input_path:
245
- if not os.path.isfile(input_path):
246
- raise FileNotFoundError(f"input_path not found or not a file: {input_path}")
247
-
248
- # 폴더면 존재 확인
249
- if input_dir:
250
- if not os.path.isdir(input_dir):
251
- raise NotADirectoryError(f"input_dir not found or not a directory: {input_dir}")
252
 
253
  num_classes = dataset_settings[dataset]['num_classes']
254
  input_size = dataset_settings[dataset]['input_size']
255
- label = dataset_settings[dataset]['label']
256
- print(f"Evaluating total class number {num_classes} with {label}")
257
 
258
  model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)
259
-
260
  state_dict = torch.load(model_restore)['state_dict']
 
261
  from collections import OrderedDict
262
  new_state_dict = OrderedDict()
263
  for k, v in state_dict.items():
264
- name = k[7:] # remove `module.`
265
- new_state_dict[name] = v
266
-
267
  model.load_state_dict(new_state_dict)
 
268
  model.cuda()
269
  model.eval()
270
 
271
  transform = transforms.Compose([
272
  transforms.ToTensor(),
273
- transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
 
274
  ])
275
 
276
- # ---- 파일 리스트 만들기 (단일 파일/폴더 모두 대응) ----
277
  if input_path:
278
- # root는 파일의 부모 디렉터리, file_list는 파일명 1개
279
  root = os.path.dirname(input_path)
280
  file_list = [os.path.basename(input_path)]
281
  else:
@@ -293,7 +459,8 @@ def run(
293
  )
294
  dataloader = DataLoader(dataset_obj)
295
 
296
- palette = get_palette(4)
 
297
 
298
  results_img = []
299
  results_logits = [] if logits else None
@@ -311,10 +478,11 @@ def run(
311
  h = meta['height'].numpy()[0]
312
 
313
  output = model(image.cuda())
314
- upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
 
 
315
  upsample_output = upsample(output[0][-1][0].unsqueeze(0))
316
- upsample_output = upsample_output.squeeze()
317
- upsample_output = upsample_output.permute(1, 2, 0)
318
 
319
  logits_result = transform_logits(
320
  upsample_output.data.cpu().numpy(),
@@ -323,28 +491,31 @@ def run(
323
  )
324
  parsing_result = np.argmax(logits_result, axis=2)
325
 
326
- out_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
327
  out_img.putpalette(palette)
328
  results_img.append(out_img)
329
 
330
  if logits:
331
  results_logits.append(logits_result)
332
 
333
- return {"images": results_img, "logits": results_logits, "names": names}
334
-
335
-
 
 
336
 
337
 
338
  def main():
339
- # ✅ CLI 호환 유지
340
  args = get_arguments()
341
  run(
342
  category=args.category,
343
  input_dir=args.input_dir,
344
- output_dir=args.output_dir,
 
 
 
345
  )
346
 
347
 
348
  if __name__ == '__main__':
349
  main()
350
-
 
1
+ # #!/usr/bin/env python
2
+ # # -*- encoding: utf-8 -*-
3
+
4
+ # """
5
+ # @Author : Peike Li
6
+ # @Contact : peike.li@yahoo.com
7
+ # @File : simple_extractor.py
8
+ # @Time : 8/30/19 8:59 PM
9
+ # @Desc : Simple Extractor
10
+ # @License : This source code is licensed under the license found in the
11
+ # LICENSE file in the root directory of this source tree.
12
+ # """
13
+
14
+ # import os
15
+ # import torch
16
+ # import argparse
17
+ # import numpy as np
18
+ # from PIL import Image
19
+ # from tqdm import tqdm
20
+
21
+ # from torch.utils.data import DataLoader
22
+ # import torchvision.transforms as transforms
23
+
24
+ # import os
25
+ # import sys
26
+
27
+ # _THIS_DIR = os.path.dirname(os.path.abspath(__file__)) # .../DEMO/preprocess
28
+ # if _THIS_DIR not in sys.path:
29
+ # sys.path.insert(0, _THIS_DIR)
30
+
31
+
32
+ # import networks
33
+ # from utils.transforms import transform_logits
34
+ # from datasets.simple_extractor_dataset import SimpleFolderDataset
35
+
36
+
37
+
38
+ # dataset_settings = {
39
+ # 'lip': {
40
+ # 'input_size': [473, 473],
41
+ # 'num_classes': 20,
42
+ # 'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat',
43
+ # 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm',
44
+ # 'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe']
45
+ # },
46
+ # 'atr': {
47
+ # 'input_size': [512, 512],
48
+ # 'num_classes': 18,
49
+ # 'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
50
+ # 'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
51
+ # },
52
+ # 'pascal': {
53
+ # 'input_size': [512, 512],
54
+ # 'num_classes': 7,
55
+ # 'label': ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'],
56
+ # }
57
+ # }
58
+
59
+
60
+ # def get_arguments():
61
+ # """Parse all the arguments provided from the CLI.
62
+ # Returns:
63
+ # A list of parsed arguments.
64
+ # """
65
+ # parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
66
+
67
+ # parser.add_argument("--dataset", type=str, default='atr', choices=['lip', 'atr', 'pascal'])
68
+ # parser.add_argument("--model-restore", type=str, default='', help="restore pretrained model parameters.")
69
+ # parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
70
+ # parser.add_argument("--category", type=str, default='Upper-clothes', help="category name (optional).")
71
+ # parser.add_argument("--input-dir", type=str, default='', help="path of input image folder.")
72
+ # parser.add_argument("--output-dir", type=str, default='', help="path of output image folder.")
73
+ # parser.add_argument("--logits", action='store_true', default=False, help="whether to save the logits.")
74
+
75
+ # return parser.parse_args()
76
+
77
+
78
+ # def get_palette(num_cls):
79
+ # n = 18
80
+ # palette = [0] * (n * 3)
81
+ # j = num_cls
82
+ # lab = num_cls
83
+ # palette[j * 3 + 0] = 0
84
+ # palette[j * 3 + 1] = 0
85
+ # palette[j * 3 + 2] = 0
86
+ # i = 0
87
+ # while lab:
88
+ # palette[j * 3 + 0] = 255
89
+ # palette[j * 3 + 1] = 255
90
+ # palette[j * 3 + 2] = 255
91
+ # i += 1
92
+ # lab >>= 3
93
+ # return palette
94
+
95
+ # def get_palette2(num_cls):
96
+ # """ Returns the color map for visualizing the segmentation mask.
97
+ # Args:
98
+ # num_cls: Number of classes
99
+ # Returns:
100
+ # The color map
101
+ # """
102
+ # n = 18
103
+ # palette = [0] * (n * 3)
104
+ # for j in range(5, 7):
105
+ # lab = j
106
+ # palette[j * 3 + 0] = 0
107
+ # palette[j * 3 + 1] = 0
108
+ # palette[j * 3 + 2] = 0
109
+ # i = 0
110
+ # while lab:
111
+ # palette[j * 3 + 0] = 255
112
+ # palette[j * 3 + 1] = 255
113
+ # palette[j * 3 + 2] = 255
114
+ # i += 1
115
+ # lab >>= 3
116
+ # return palette
117
+
118
+ # def run(
119
+ # *,
120
+ # category: str,
121
+ # input_path: str = "",
122
+ # input_dir: str = "",
123
+ # dataset: str = "atr",
124
+ # model_restore: str = "",
125
+ # gpu: str = "0",
126
+ # logits: bool = False,
127
+ # ):
128
+ # """
129
+ # - input_path (단일 파일) 또는 input_dir(폴더) 중 하나를 받아 parsing 결과를 메모리로 반환.
130
+ # - 파일 저장 없음.
131
+
132
+ # Returns:
133
+ # {
134
+ # "images": List[PIL.Image], # parsing mask (palette 적용됨)
135
+ # "logits": Optional[List[np.ndarray]],
136
+ # "names": List[str], # 파일명들
137
+ # }
138
+ # """
139
+ # # single GPU만 허용
140
+ # gpus = [int(i) for i in gpu.split(',')]
141
+ # assert len(gpus) == 1
142
+ # if gpu != 'None':
143
+ # os.environ["CUDA_VISIBLE_DEVICES"] = gpu
144
+
145
+ # if not model_restore:
146
+ # print("[simple_extractor] model_restore not provided → skip extractor.")
147
+ # return {"images": [], "logits": [] if logits else None, "names": []}
148
+
149
+ # # 입력 검증: 둘 중 하나는 있어야 함
150
+ # if bool(input_path) == bool(input_dir):
151
+ # raise ValueError("Provide exactly one of input_path or input_dir.")
152
+
153
+ # # 파일이면 존재 확인
154
+ # if input_path:
155
+ # if not os.path.isfile(input_path):
156
+ # raise FileNotFoundError(f"input_path not found or not a file: {input_path}")
157
+
158
+ # # 폴더면 존재 확인
159
+ # if input_dir:
160
+ # if not os.path.isdir(input_dir):
161
+ # raise NotADirectoryError(f"input_dir not found or not a directory: {input_dir}")
162
+
163
+ # num_classes = dataset_settings[dataset]['num_classes']
164
+ # input_size = dataset_settings[dataset]['input_size']
165
+ # label = dataset_settings[dataset]['label']
166
+ # print(f"Evaluating total class number {num_classes} with {label}")
167
+
168
+ # model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)
169
+
170
+ # state_dict = torch.load(model_restore)['state_dict']
171
+ # from collections import OrderedDict
172
+ # new_state_dict = OrderedDict()
173
+ # for k, v in state_dict.items():
174
+ # name = k[7:] # remove `module.`
175
+ # new_state_dict[name] = v
176
+
177
+ # model.load_state_dict(new_state_dict)
178
+ # model.cuda()
179
+ # model.eval()
180
+
181
+ # transform = transforms.Compose([
182
+ # transforms.ToTensor(),
183
+ # transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
184
+ # ])
185
+
186
+ # # ---- 파일 리스트 만들기 (단일 파일/폴더 모두 대응) ----
187
+ # if input_path:
188
+ # # root는 파일의 부모 디렉터리, file_list는 파일명 1개
189
+ # root = os.path.dirname(input_path)
190
+ # file_list = [os.path.basename(input_path)]
191
+ # else:
192
+ # root = input_dir
193
+ # file_list = sorted([
194
+ # f for f in os.listdir(root)
195
+ # if f.lower().endswith(('.png', '.jpg', '.jpeg'))
196
+ # ])
197
+
198
+ # dataset_obj = SimpleFolderDataset(
199
+ # root=root,
200
+ # input_size=input_size,
201
+ # transform=transform,
202
+ # file_list=file_list
203
+ # )
204
+ # dataloader = DataLoader(dataset_obj)
205
+
206
+ # palette = get_palette(4)
207
+
208
+ # results_img = []
209
+ # results_logits = [] if logits else None
210
+ # names = []
211
+
212
+ # with torch.no_grad():
213
+ # for batch in tqdm(dataloader):
214
+ # image, meta = batch
215
+ # img_name = meta['name'][0]
216
+ # names.append(img_name)
217
+
218
+ # c = meta['center'].numpy()[0]
219
+ # s = meta['scale'].numpy()[0]
220
+ # w = meta['width'].numpy()[0]
221
+ # h = meta['height'].numpy()[0]
222
+
223
+ # output = model(image.cuda())
224
+ # upsample = torch.nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
225
+ # upsample_output = upsample(output[0][-1][0].unsqueeze(0))
226
+ # upsample_output = upsample_output.squeeze()
227
+ # upsample_output = upsample_output.permute(1, 2, 0)
228
+
229
+ # logits_result = transform_logits(
230
+ # upsample_output.data.cpu().numpy(),
231
+ # c, s, w, h,
232
+ # input_size=input_size
233
+ # )
234
+ # parsing_result = np.argmax(logits_result, axis=2)
235
+
236
+ # out_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
237
+ # out_img.putpalette(palette)
238
+ # results_img.append(out_img)
239
+
240
+ # if logits:
241
+ # results_logits.append(logits_result)
242
+
243
+ # return {"images": results_img, "logits": results_logits, "names": names}
244
+
245
+
246
+
247
+
248
+ # def main():
249
+ # # ✅ CLI 호환 유지
250
+ # args = get_arguments()
251
+ # run(
252
+ # category=args.category,
253
+ # input_dir=args.input_dir,
254
+ # output_dir=args.output_dir,
255
+ # )
256
+
257
+
258
+ # if __name__ == '__main__':
259
+ # main()
260
+
261
  #!/usr/bin/env python
262
  # -*- encoding: utf-8 -*-
263
 
 
265
  @Author : Peike Li
266
  @Contact : peike.li@yahoo.com
267
  @File : simple_extractor.py
268
+ @Desc : Simple Extractor (category-aware palette selection)
 
 
 
269
  """
270
 
271
  import os
272
+ import sys
273
  import torch
274
  import argparse
275
  import numpy as np
 
279
  from torch.utils.data import DataLoader
280
  import torchvision.transforms as transforms
281
 
282
+ _THIS_DIR = os.path.dirname(os.path.abspath(__file__))
 
 
 
283
  if _THIS_DIR not in sys.path:
284
  sys.path.insert(0, _THIS_DIR)
285
 
 
286
  import networks
287
  from utils.transforms import transform_logits
288
  from datasets.simple_extractor_dataset import SimpleFolderDataset
289
 
290
 
 
291
  dataset_settings = {
292
  'lip': {
293
  'input_size': [473, 473],
294
  'num_classes': 20,
295
+ 'label': [
296
+ 'Background', 'Hat', 'Hair', 'Glove', 'Sunglasses',
297
+ 'Upper-clothes', 'Dress', 'Coat', 'Socks', 'Pants',
298
+ 'Jumpsuits', 'Scarf', 'Skirt', 'Face',
299
+ 'Left-arm', 'Right-arm', 'Left-leg', 'Right-leg',
300
+ 'Left-shoe', 'Right-shoe'
301
+ ]
302
  },
303
  'atr': {
304
  'input_size': [512, 512],
305
  'num_classes': 18,
306
+ 'label': [
307
+ 'Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes',
308
+ 'Skirt', 'Pants', 'Dress', 'Belt',
309
+ 'Left-shoe', 'Right-shoe', 'Face',
310
+ 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm',
311
+ 'Bag', 'Scarf'
312
+ ]
313
  },
314
  'pascal': {
315
  'input_size': [512, 512],
316
  'num_classes': 7,
317
+ 'label': [
318
+ 'Background', 'Head', 'Torso',
319
+ 'Upper Arms', 'Lower Arms',
320
+ 'Upper Legs', 'Lower Legs'
321
+ ],
322
  }
323
  }
324
 
325
 
326
  def get_arguments():
 
 
 
 
327
  parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")
328
 
329
  parser.add_argument("--dataset", type=str, default='atr', choices=['lip', 'atr', 'pascal'])
330
  parser.add_argument("--model-restore", type=str, default='', help="restore pretrained model parameters.")
331
  parser.add_argument("--gpu", type=str, default='0', help="choose gpu device.")
332
+ parser.add_argument("--category", type=str, default='Upper-cloth', help="category name.")
333
  parser.add_argument("--input-dir", type=str, default='', help="path of input image folder.")
334
+ parser.add_argument("--output-dir", type=str, default='', help="(unused, kept for CLI compatibility)")
335
+ parser.add_argument("--logits", action='store_true', default=False)
336
 
337
  return parser.parse_args()
338
 
 
345
  palette[j * 3 + 0] = 0
346
  palette[j * 3 + 1] = 0
347
  palette[j * 3 + 2] = 0
 
348
  while lab:
349
  palette[j * 3 + 0] = 255
350
  palette[j * 3 + 1] = 255
351
  palette[j * 3 + 2] = 255
 
352
  lab >>= 3
353
  return palette
354
 
355
 
356
+ def get_palette2(num_cls):
357
+ n = 18
358
+ palette = [0] * (n * 3)
359
+ for j in range(5, 7):
360
+ lab = j
361
+ palette[j * 3 + 0] = 0
362
+ palette[j * 3 + 1] = 0
363
+ palette[j * 3 + 2] = 0
364
+ while lab:
365
+ palette[j * 3 + 0] = 255
366
+ palette[j * 3 + 1] = 255
367
+ palette[j * 3 + 2] = 255
368
+ lab >>= 3
369
+ return palette
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
370
 
 
 
 
371
 
372
+ def _select_palette_by_category(category: str):
373
+ """
374
+ category별 palette 선택 로직 (명시적 규칙)
375
+ """
376
+ if category == "Upper-body":
377
+ return get_palette(4)
378
+ elif category == "Lower-body":
379
+ return get_palette2(4)
380
+ elif category == "Dress":
381
+ return get_palette(7)
382
+ else:
383
+ # fallback (명시 안 된 카테고리)
384
+ return get_palette(7)
385
 
386
 
387
  def run(
 
395
  logits: bool = False,
396
  ):
397
  """
 
 
 
398
  Returns:
399
  {
400
+ "images": List[PIL.Image],
401
  "logits": Optional[List[np.ndarray]],
402
+ "names": List[str],
403
  }
404
  """
405
+
406
  gpus = [int(i) for i in gpu.split(',')]
407
  assert len(gpus) == 1
408
  if gpu != 'None':
 
412
  print("[simple_extractor] model_restore not provided → skip extractor.")
413
  return {"images": [], "logits": [] if logits else None, "names": []}
414
 
 
415
  if bool(input_path) == bool(input_dir):
416
  raise ValueError("Provide exactly one of input_path or input_dir.")
417
 
418
+ if input_path and not os.path.isfile(input_path):
419
+ raise FileNotFoundError(input_path)
420
+ if input_dir and not os.path.isdir(input_dir):
421
+ raise NotADirectoryError(input_dir)
 
 
 
 
 
422
 
423
  num_classes = dataset_settings[dataset]['num_classes']
424
  input_size = dataset_settings[dataset]['input_size']
 
 
425
 
426
  model = networks.init_model('resnet101', num_classes=num_classes, pretrained=None)
 
427
  state_dict = torch.load(model_restore)['state_dict']
428
+
429
  from collections import OrderedDict
430
  new_state_dict = OrderedDict()
431
  for k, v in state_dict.items():
432
+ new_state_dict[k[7:]] = v
 
 
433
  model.load_state_dict(new_state_dict)
434
+
435
  model.cuda()
436
  model.eval()
437
 
438
  transform = transforms.Compose([
439
  transforms.ToTensor(),
440
+ transforms.Normalize(mean=[0.406, 0.456, 0.485],
441
+ std=[0.225, 0.224, 0.229])
442
  ])
443
 
 
444
  if input_path:
 
445
  root = os.path.dirname(input_path)
446
  file_list = [os.path.basename(input_path)]
447
  else:
 
459
  )
460
  dataloader = DataLoader(dataset_obj)
461
 
462
+ # ✅ 핵심 수정: category 기반 palette 선택
463
+ palette = _select_palette_by_category(category)
464
 
465
  results_img = []
466
  results_logits = [] if logits else None
 
478
  h = meta['height'].numpy()[0]
479
 
480
  output = model(image.cuda())
481
+ upsample = torch.nn.Upsample(
482
+ size=input_size, mode='bilinear', align_corners=True
483
+ )
484
  upsample_output = upsample(output[0][-1][0].unsqueeze(0))
485
+ upsample_output = upsample_output.squeeze().permute(1, 2, 0)
 
486
 
487
  logits_result = transform_logits(
488
  upsample_output.data.cpu().numpy(),
 
491
  )
492
  parsing_result = np.argmax(logits_result, axis=2)
493
 
494
+ out_img = Image.fromarray(parsing_result.astype(np.uint8))
495
  out_img.putpalette(palette)
496
  results_img.append(out_img)
497
 
498
  if logits:
499
  results_logits.append(logits_result)
500
 
501
+ return {
502
+ "images": results_img,
503
+ "logits": results_logits,
504
+ "names": names
505
+ }
506
 
507
 
508
  def main():
 
509
  args = get_arguments()
510
  run(
511
  category=args.category,
512
  input_dir=args.input_dir,
513
+ dataset=args.dataset,
514
+ model_restore=args.model_restore,
515
+ gpu=args.gpu,
516
+ logits=args.logits,
517
  )
518
 
519
 
520
  if __name__ == '__main__':
521
  main()