viton-hd / dataset_loader.py
known57's picture
Add files using upload-large-folder tool
24870a9 verified
"""
Dataset loader for pre-computed VITON-HD test data.
It reads the dataset's original preprocessing outputs (parsing maps, OpenPose
results and cloth masks) from disk instead of regenerating them.
"""
import os
import json
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
class DatasetLoader:
    """Load VITON-HD test pairs and build the tensors fed to the model.

    The loader reads the person image, cloth image, cloth mask, rendered pose
    image, pose keypoints and parsing map from sub-directories of
    ``dataset_dir`` and derives the clothing-agnostic person image and
    parsing map from them.
    """

    # Number of raw label ids expected in the parsing map (ids 0..19).
    _NUM_RAW_LABELS = 20

    # Output channel order and the raw parsing label ids merged into each
    # channel. The channel index is the position in this tuple.
    _PARSE_GROUPS = (
        ('background', (0, 10)),
        ('hair', (1, 2)),
        ('face', (4, 13)),
        ('upper', (5, 6, 7)),
        ('bottom', (9, 12)),
        ('left_arm', (14,)),
        ('right_arm', (15,)),
        ('left_leg', (16,)),
        ('right_leg', (17,)),
        ('left_shoe', (18,)),
        ('right_shoe', (19,)),
        ('socks', (8,)),
        ('noise', (3, 11)),
    )

    def __init__(self, dataset_dir='./datasets/test'):
        """Remember the dataset location and read the list of test pairs.

        Args:
            dataset_dir: Directory holding the ``image``, ``cloth``,
                ``cloth-mask``, ``openpose-img``, ``openpose-json`` and
                ``image-parse`` sub-directories. ``test_pairs.txt`` is looked
                up in the parent directory of ``dataset_dir``.

        Raises:
            ValueError: If a non-blank line of the pairs file does not hold
                exactly two whitespace-separated names.
        """
        self.dataset_dir = dataset_dir
        # Scale to [0, 1] and then shift/scale each channel to [-1, 1].
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        # Load test pairs
        pairs_file = os.path.join(os.path.dirname(dataset_dir), 'test_pairs.txt')
        self.pairs = self._read_pairs(pairs_file)

    @staticmethod
    def _read_pairs(pairs_file):
        """Parse ``pairs_file`` into a list of (person_name, cloth_name)."""
        pairs = []
        with open(pairs_file, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # Blank or whitespace-only lines (e.g. a trailing newline
                    # at the end of the file) carry no pair.
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        '{}:{}: expected "<person> <cloth>", got {!r}'.format(
                            pairs_file, line_no, line.rstrip('\n')))
                pairs.append((fields[0], fields[1]))
        return pairs

    @staticmethod
    def _open_image(path, mode=None):
        """Read an image fully into memory and close the underlying file.

        When ``mode`` is given the image is converted to that PIL mode;
        otherwise the stored mode (and palette, if any) is kept.
        """
        with Image.open(path) as im:
            return im.convert(mode) if mode else im.copy()

    def get_pair_names(self):
        """Return list of (person_name, cloth_name) tuples."""
        return self.pairs

    def load_pair(self, person_name, cloth_name):
        """
        Load a person-cloth pair with all pre-computed data.

        Args:
            person_name: File name of the person image inside ``image/``.
                The pose and parsing files are located from its stem
                (``<stem>_rendered.png``, ``<stem>_keypoints.json``,
                ``<stem>.png``).
            cloth_name: File name shared by the cloth image and its mask.

        Returns dict with:
        - img_agnostic: [1, 3, H, W]
        - parse_agnostic: [1, 13, H, W]
        - pose: [1, 3, H, W]
        - c: [1, 3, H, W]
        - cm: [1, 1, H, W]

        Raises:
            IndexError: If the keypoint file lists no detected person.
        """
        # PIL resampling codes: 2 = bilinear (photos), 0 = nearest (label and
        # mask images, whose values must not be blended).
        resize_bilinear = transforms.Resize(768, interpolation=2)
        resize_nearest = transforms.Resize(768, interpolation=0)

        # Strip only the final extension so names that are not '.jpg', or
        # that contain '.jpg' elsewhere, still map to the right files.
        stem = os.path.splitext(person_name)[0]

        # Load person image
        img_path = os.path.join(self.dataset_dir, 'image', person_name)
        img = resize_bilinear(self._open_image(img_path, 'RGB'))

        # Load cloth image
        cloth_path = os.path.join(self.dataset_dir, 'cloth', cloth_name)
        c = resize_bilinear(self._open_image(cloth_path, 'RGB'))
        c = self.transform(c).unsqueeze(0)

        # Load cloth mask. Force single-channel 8-bit so the >= 128 threshold
        # and the resulting [1, 1, H, W] shape hold for any stored mode.
        cm_path = os.path.join(self.dataset_dir, 'cloth-mask', cloth_name)
        cm = resize_nearest(self._open_image(cm_path, 'L'))
        cm_array = (np.array(cm) >= 128).astype(np.float32)
        cm = torch.from_numpy(cm_array).unsqueeze(0).unsqueeze(0)

        # Load pose RGB. Converted to RGB because the 3-channel Normalize in
        # self.transform would reject an image with an alpha channel.
        pose_path = os.path.join(self.dataset_dir, 'openpose-img',
                                 stem + '_rendered.png')
        pose_rgb = resize_bilinear(self._open_image(pose_path, 'RGB'))
        pose = self.transform(pose_rgb).unsqueeze(0)

        # Load pose keypoints
        pose_json_path = os.path.join(self.dataset_dir, 'openpose-json',
                                      stem + '_keypoints.json')
        with open(pose_json_path, 'r') as f:
            pose_label = json.load(f)
        people = pose_label['people']
        if not people:
            raise IndexError(
                'no person detected in {}'.format(pose_json_path))
        # Keypoints are stored flat as (x, y, confidence) triples; keep x, y.
        # NOTE(review): the coordinates are used unscaled, which presumably
        # assumes the source images are already 768 px wide - confirm.
        pose_data = np.array(people[0]['pose_keypoints_2d'], dtype=np.float64)
        pose_data = pose_data.reshape((-1, 3))[:, :2]

        # Load parsing (kept in its stored mode so pixel values stay label ids)
        parse_path = os.path.join(self.dataset_dir, 'image-parse', stem + '.png')
        parse = resize_nearest(self._open_image(parse_path))

        # Get parse_agnostic and img_agnostic
        parse_agnostic = self._get_parse_agnostic(parse, pose_data)
        img_agnostic = self._get_img_agnostic(img, parse, pose_data)

        return {
            'img_agnostic': img_agnostic,
            'parse_agnostic': parse_agnostic,
            'pose': pose,
            'c': c,
            'cm': cm
        }

    @classmethod
    def _labels_to_channels(cls, label_array):
        """Turn a 2-D array of raw label ids into a [1, 13, H, W] tensor.

        Each pixel is one-hot encoded over the raw ids and the raw channels
        are then summed into the groups listed in ``_PARSE_GROUPS``. The
        spatial size is taken from ``label_array``.
        """
        label_idx = torch.from_numpy(np.ascontiguousarray(label_array))
        label_idx = label_idx.long().unsqueeze(0)
        height, width = label_idx.shape[-2:]

        one_hot = torch.zeros(cls._NUM_RAW_LABELS, height, width,
                              dtype=torch.float)
        one_hot.scatter_(0, label_idx, 1.0)

        grouped = torch.zeros(len(cls._PARSE_GROUPS), height, width,
                              dtype=torch.float)
        for channel, (_, label_ids) in enumerate(cls._PARSE_GROUPS):
            for label_id in label_ids:
                grouped[channel] += one_hot[label_id]
        return grouped.unsqueeze(0)

    @staticmethod
    def _rescale_hips(pose_data):
        """Return a float copy of ``pose_data`` with the hip points rescaled.

        Keypoints 9 and 12 are moved along their connecting line, about their
        midpoint, so that their distance equals the distance between
        keypoints 2 and 5. The input array is left untouched. When keypoints
        9 and 12 coincide there is no direction to scale along, so they are
        returned unchanged rather than dividing by zero.
        """
        pose = np.array(pose_data, dtype=np.float64)  # always a fresh copy
        length_a = np.linalg.norm(pose[5] - pose[2])
        length_b = np.linalg.norm(pose[12] - pose[9])
        if length_b > 0:
            point = (pose[9] + pose[12]) / 2
            pose[9] = point + (pose[9] - point) / length_b * length_a
            pose[12] = point + (pose[12] - point) / length_b * length_a
        return pose

    def _get_parse_agnostic(self, parse, pose_data):
        """Generate parse_agnostic from parse map.

        Arm pixels lying near the arm keypoints, plus all upper-clothing and
        neck pixels, are reset to label 0; the result is returned as a
        [1, 13, H, W] tensor.
        """
        from PIL import ImageDraw

        parse_array = np.array(parse)
        parse_upper = np.isin(parse_array, (5, 6, 7)).astype(np.float32)
        parse_neck = (parse_array == 10).astype(np.float32)

        r = 10
        agnostic = parse.copy()

        # Mask arms (simplified from datasets.py)
        for parse_id, pose_ids in [(14, [2, 5, 6, 7]), (15, [5, 2, 3, 4])]:
            # Scratch mask sized like the parse map rather than hard-coded.
            mask_arm = Image.new('L', parse.size, 'black')
            mask_arm_draw = ImageDraw.Draw(mask_arm)
            i_prev = pose_ids[0]
            for i in pose_ids[1:]:
                # A keypoint at (0, 0) is treated as missing: skip the segment
                # and keep i_prev where it was.
                if (pose_data[i_prev, 0] == 0.0 and pose_data[i_prev, 1] == 0.0) or \
                        (pose_data[i, 0] == 0.0 and pose_data[i, 1] == 0.0):
                    continue
                mask_arm_draw.line([tuple(pose_data[j]) for j in [i_prev, i]],
                                   'white', width=r*10)
                pointx, pointy = pose_data[i]
                # Smaller blob on the last keypoint of the chain.
                radius = r*4 if i == pose_ids[-1] else r*15
                mask_arm_draw.ellipse(
                    (pointx-radius, pointy-radius, pointx+radius, pointy+radius),
                    'white', 'white')
                i_prev = i
            # Only erase pixels that are inside the drawn region AND carry
            # this arm's label.
            parse_arm = (np.array(mask_arm) / 255) * \
                (parse_array == parse_id).astype(np.float32)
            agnostic.paste(0, None, Image.fromarray(np.uint8(parse_arm * 255)))

        # Mask torso & neck
        agnostic.paste(0, None, Image.fromarray(np.uint8(parse_upper * 255)))
        agnostic.paste(0, None, Image.fromarray(np.uint8(parse_neck * 255)))

        # Convert to one-hot and map to 13 channels
        return self._labels_to_channels(np.array(agnostic))

    def _get_img_agnostic(self, img, parse, pose_data):
        """Generate img_agnostic from image.

        Arms, torso and neck are painted gray using the pose keypoints, then
        the head and lower-body pixels (per the parse map) are restored from
        the original image. ``pose_data`` is not modified. Returns a
        [1, 3, H, W] tensor produced by ``self.transform``.
        """
        from PIL import ImageDraw

        parse_array = np.array(parse)
        parse_head = np.isin(parse_array, (4, 13)).astype(np.float32)
        parse_lower = np.isin(
            parse_array, (9, 12, 16, 17, 18, 19)).astype(np.float32)

        r = 20
        agnostic = img.copy()
        agnostic_draw = ImageDraw.Draw(agnostic)

        # Work on a private copy so the caller's keypoints stay intact.
        pose = self._rescale_hips(pose_data)

        # Mask arms
        agnostic_draw.line([tuple(pose[i]) for i in [2, 5]], 'gray', width=r*10)
        for i in [2, 5]:
            pointx, pointy = pose[i]
            agnostic_draw.ellipse(
                (pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'gray', 'gray')
        for i in [3, 4, 6, 7]:
            # Skip segments with an end point at (0, 0), treated as missing.
            if (pose[i - 1, 0] == 0.0 and pose[i - 1, 1] == 0.0) or \
                    (pose[i, 0] == 0.0 and pose[i, 1] == 0.0):
                continue
            agnostic_draw.line([tuple(pose[j]) for j in [i - 1, i]],
                               'gray', width=r*10)
            pointx, pointy = pose[i]
            agnostic_draw.ellipse(
                (pointx-r*5, pointy-r*5, pointx+r*5, pointy+r*5), 'gray', 'gray')

        # Mask torso
        for i in [9, 12]:
            pointx, pointy = pose[i]
            agnostic_draw.ellipse(
                (pointx-r*3, pointy-r*6, pointx+r*3, pointy+r*6), 'gray', 'gray')
        agnostic_draw.line([tuple(pose[i]) for i in [2, 9]], 'gray', width=r*6)
        agnostic_draw.line([tuple(pose[i]) for i in [5, 12]], 'gray', width=r*6)
        agnostic_draw.line([tuple(pose[i]) for i in [9, 12]], 'gray', width=r*12)
        agnostic_draw.polygon([tuple(pose[i]) for i in [2, 5, 12, 9]],
                              'gray', 'gray')

        # Mask neck
        pointx, pointy = pose[1]
        agnostic_draw.rectangle(
            (pointx-r*7, pointy-r*7, pointx+r*7, pointy+r*7), 'gray', 'gray')

        # Restore head and lower body from the untouched image.
        agnostic.paste(img, None, Image.fromarray(np.uint8(parse_head * 255)))
        agnostic.paste(img, None, Image.fromarray(np.uint8(parse_lower * 255)))

        return self.transform(agnostic).unsqueeze(0)