import json
import os
import random

import numpy as np
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import Dataset


class MMCelebAHQ(Dataset):
    """Multi-Modal CelebA-HQ: yields a face image, a stack of segmentation
    masks as the condition, and a randomly sampled text description."""

    def __init__(
        self,
        root="data/mmcelebahq",
        condition_size: int = 512,
        target_size: int = 512,
        condition_type: str = "depth",
        drop_text_prob: float = 0.1,
        drop_image_prob: float = 0.1,
        return_pil_image: bool = False,
        position_scale: float = 1.0,
    ):
        self.root = root
        self.face_paths, self.mask_paths, self.prompts = self.get_face_mask_prompt()
        self.condition_size = condition_size
        self.target_size = target_size
        self.condition_type = condition_type
        self.drop_text_prob = drop_text_prob
        self.drop_image_prob = drop_image_prob
        self.return_pil_image = return_pil_image
        self.position_scale = position_scale
        self.to_tensor = T.ToTensor()

    def get_face_mask_prompt(self):
        # 27,000 indexed face/mask pairs; text.json maps "{i}.jpg" to a
        # list of caption strings for that face.
        face_paths = [
            os.path.join(self.root, "face", f"{i}.jpg") for i in range(27000)
        ]
        mask_paths = [
            os.path.join(self.root, "mask", f"{i}.png") for i in range(27000)
        ]
        with open(os.path.join(self.root, "text.json"), mode="r") as f:
            prompts = json.load(f)
        return face_paths, mask_paths, prompts

    def __len__(self):
        return len(self.face_paths)

    def __getitem__(self, idx):
        image = Image.open(self.face_paths[idx]).convert("RGB")
        prompts = self.prompts[f"{idx}.jpg"]
        description = random.choice(prompts).strip()

        # NOTE: random.random() < 1 is always True, so the rescaled branch
        # below is never taken and condition_size is currently unused.
        enable_scale = random.random() < 1
        if not enable_scale:
            condition_size = int(self.condition_size * self.position_scale)
            position_scale = 1.0
        else:
            condition_size = self.condition_size
            position_scale = self.position_scale

        # Build the condition: the full segmentation map plus one binary mask
        # per class (CelebAMask-HQ defines 19 semantic classes).
        position_delta = np.array([0, 0])
        mask = np.array(Image.open(self.mask_paths[idx]))
        mask_list = [self.to_tensor(Image.open(self.mask_paths[idx]).convert("RGB"))]
        for i in range(19):
            local_mask = np.zeros_like(mask)
            local_mask[mask == i] = 255
            # Each class channel is dropped (zeroed) independently with
            # probability drop_image_prob; the full map is never dropped.
            drop_image = random.random() < self.drop_image_prob
            if drop_image:
                local_mask = np.zeros_like(mask)
            local_mask_rgb = Image.fromarray(local_mask).convert("RGB")
            local_mask_tensor = self.to_tensor(local_mask_rgb)
            mask_list.append(local_mask_tensor)
        condition_img = torch.stack(mask_list, dim=0)  # (20, 3, H, W)

        # Randomly drop the text prompt, e.g. for classifier-free guidance.
        drop_text = random.random() < self.drop_text_prob
        if drop_text:
            description = ""

        return {
            "image": self.to_tensor(image),
            "condition": condition_img,
            "condition_type": self.condition_type,
            "description": description,
            "position_delta": position_delta,
            # Note: condition_img is a tensor here, despite the "pil_image" key.
            **({"pil_image": [image, condition_img]} if self.return_pil_image else {}),
            **({"position_scale": position_scale} if position_scale != 1.0 else {}),
        }
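

# A minimal usage sketch, assuming the on-disk layout the class expects:
# root/face/{i}.jpg, root/mask/{i}.png, and root/text.json mapping "{i}.jpg"
# to a list of caption strings. The root path is the class default and
# purely illustrative.
if __name__ == "__main__":
    dataset = MMCelebAHQ(root="data/mmcelebahq")
    sample = dataset[0]
    print(sample["image"].shape)      # (3, H, W) face tensor
    print(sample["condition"].shape)  # (20, 3, H, W): full mask + 19 class masks
    print(sample["description"])      # one caption, or "" if the text was dropped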