# dataloader/creatidesign_dataset_benchmark.py
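"""Benchmark dataloader for the CreatiDesign dataset.

Loads the benchmark split from Hugging Face, normalizes layout annotations
(object boxes, box captions, and text regions), prepares image and condition
tensors, and includes visualization helpers for qualitative inspection.
"""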
import os
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch
import numpy as np
import random
from datasets import load_dataset
from tqdm import tqdm
def find_nearest_bucket_size(input_width, input_height, mode="x64", ratio=1):
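    """Snap (input_width, input_height) to the closest predefined bucket.

    In "x64" mode the bucket with the nearest aspect ratio is chosen; in
    "x8" mode the size is rescaled to roughly one megapixel and rounded down
    to multiples of 8. The result is finally scaled by `ratio`.

    Example (defaults): find_nearest_bucket_size(1080, 1920) -> (768, 1344)
    """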
buckets = [
(512, 2048),
(512, 1984),
(512, 1920),
(512, 1856),
(576, 1792),
(576, 1728),
(576, 1664),
(640, 1600),
(640, 1536),
(704, 1472),
(704, 1408),
(704, 1344),
(768, 1344),
(768, 1280),
(832, 1216),
(832, 1152),
(896, 1152),
(896, 1088),
(960, 1088),
(960, 1024),
(1024, 1024),
(1024, 960),
(1088, 960),
(1088, 896),
(1152, 896),
(1152, 832),
(1216, 832),
(1280, 768),
(1344, 768),
(1408, 704),
(1472, 704),
(1536, 640),
(1600, 640),
(1664, 576),
(1728, 576),
(1792, 576),
(1856, 512),
(1920, 512),
(1984, 512),
(2048, 512)
]
aspect_ratios = [w / h for (w, h) in buckets]
assert mode in ["x64", "x8"]
if mode == "x64":
asp = input_width / input_height
diff = [abs(ar - asp) for ar in aspect_ratios]
bucket_id = int(np.argmin(diff))
gen_width, gen_height = buckets[bucket_id]
    elif mode == "x8":
        # Rescale to ~1 megapixel, then round down to multiples of 8. Use a
        # local name so the `ratio` argument is not clobbered before the
        # final scaling below.
        max_pixels = 1024 * 1024
        scale = (max_pixels / (input_width * input_height)) ** 0.5
        gen_width, gen_height = round(input_width * scale), round(input_height * scale)
        gen_width = gen_width - gen_width % 8
        gen_height = gen_height - gen_height % 8
    else:
        raise NotImplementedError
    return (int(gen_width * ratio), int(gen_height * ratio))
def adjust_and_normalize_bboxes(bboxes, orig_width, orig_height):
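    """Normalize pixel-space [x1, y1, x2, y2] boxes to [0, 1] coordinates
    (rounded to two decimals) relative to the original image size."""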
    normalized_bboxes = []
    for bbox in bboxes:
        x1, y1, x2, y2 = bbox
        x1_norm = round(x1 / orig_width, 2)
        y1_norm = round(y1 / orig_height, 2)
        x2_norm = round(x2 / orig_width, 2)
        y2_norm = round(y2 / orig_height, 2)
        normalized_bboxes.append([x1_norm, y1_norm, x2_norm, y2_norm])
    return normalized_bboxes
def img_transforms(image, height=512, width=512):
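    """Resize to (height, width) with bilinear interpolation, convert to a
    tensor, and normalize from [0, 1] to [-1, 1]."""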
transform = transforms.Compose(
[
transforms.Resize(
(height, width), interpolation=transforms.InterpolationMode.BILINEAR
),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
image_transformed = transform(image)
return image_transformed
def mask_transforms(mask, height=512, width=512):
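    """Resize a mask to (height, width) with nearest-neighbor interpolation
    (keeping hard edges) and convert it to a [0, 1] tensor."""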
transform = transforms.Compose(
[
transforms.Resize(
(height, width),
interpolation=transforms.InterpolationMode.NEAREST
),
transforms.ToTensor(),
]
)
mask_transformed = transform(mask)
return mask_transformed
class DesignDataset(Dataset):
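    """Benchmark dataset wrapping the CreatiDesign test split on Hugging Face.

    Each item provides the target image tensor, padded object/text boxes with
    captions and confidence-filtered masks, per-box binary mask maps, and the
    subject-condition image (plus a configurable negative condition image).
    """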
    def __init__(
        self,
        dataset_name,
        resolution=512,
        condition_resolution=512,
        condition_resolution_scale_ratio=0.5,
        max_boxes_per_image=10,
        neg_condition_image='same',
        background_color='gray',
        use_bucket=True,
        box_confidence_th=0.0,
    ):
print(f"Loading dataset from Hugging Face: {dataset_name}")
self.dataset = load_dataset(dataset_name, split="test")
print(f"Loaded {len(self.dataset)} samples")
        self.max_boxes_per_image = max_boxes_per_image
        self.resolution = resolution
        self.condition_resolution = condition_resolution
        self.neg_condition_image = neg_condition_image
        self.use_bucket = use_bucket
        self.condition_resolution_scale_ratio = condition_resolution_scale_ratio
        self.box_confidence_th = box_confidence_th
if background_color == 'white':
self.background_color = (255, 255, 255)
elif background_color == 'black':
self.background_color = (0, 0, 0)
elif background_color == 'gray':
self.background_color = (128, 128, 128)
else:
raise ValueError("Invalid background color. Use 'white' or 'black'.")
def __len__(self):
return len(self.dataset)
def __getitem__(self, idx):
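        """Return one benchmark sample: the bucket-resized image tensor,
        normalized and padded box annotations with captions, per-box mask
        maps, and subject-condition tensors (positive and negative)."""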
sample = self.dataset[idx]
image_source = sample['original_image']
subject_image = sample['condition_gray_background']
subject_mask = sample['subject_mask']
json_data = json.loads(sample['metadata'])
        # image info
        img_info = json_data['img_info']
        img_id = img_info['img_id']
        orig_width, orig_height = int(img_info["img_width"]), int(img_info["img_height"])
        if self.use_bucket:
            target_width, target_height = find_nearest_bucket_size(orig_width, orig_height)
            condition_width = int(target_width * self.condition_resolution_scale_ratio)
            condition_height = int(target_height * self.condition_resolution_scale_ratio)
        else:
            target_width = target_height = self.resolution
            condition_width = condition_height = self.condition_resolution
        img_tensor = img_transforms(image_source, height=target_height, width=target_width)
# global caption
global_caption = json_data['global_caption']
# object_annotations
object_annotations = json_data['object_annotations']
# object bbox list
objects_bbox = [item['bbox'] for item in object_annotations]
# object bbox caption
objects_caption = [item['bbox_detail_description'] for item in object_annotations]
# object bbox score
objects_bbox_score = [item['score'][0] for item in object_annotations]
# text
        text_list = json_data["text_list"]
        txt_bboxs = [item['bbox'] for item in text_list]
        txt_captions = ["text:" + item['text'] for item in text_list]
        txt_scores = [1.0 for _ in txt_bboxs]
        # combine object and text boxes and descriptions
objects_bbox.extend(txt_bboxs)
objects_caption.extend(txt_captions)
objects_bbox_score.extend(txt_scores)
        objects_bbox = torch.tensor(adjust_and_normalize_bboxes(objects_bbox, orig_width, orig_height))
        objects_bbox_score = torch.tensor(objects_bbox_score)
boxes_mask = objects_bbox_score > self.box_confidence_th
objects_bbox_raw = objects_bbox[boxes_mask]
objects_caption = [object_caption for object_caption, box_mask in zip(objects_caption, boxes_mask) if box_mask]
        # Clamp to the padding budget so slice sizes match when a sample has
        # more than max_boxes_per_image annotations.
        num_boxes = min(objects_bbox_raw.shape[0], self.max_boxes_per_image)
        objects_boxes_padded = torch.zeros((self.max_boxes_per_image, 4))
        objects_masks_padded = torch.zeros(self.max_boxes_per_image)
        objects_caption = objects_caption[:self.max_boxes_per_image]
        objects_boxes_padded[:num_boxes] = objects_bbox_raw[:num_boxes]
        objects_masks_padded[:num_boxes] = 1.
# objects_masks_maps
objects_masks_maps_padded = torch.zeros((self.max_boxes_per_image, target_height, target_width))
        for box_idx in range(num_boxes):
            x1, y1, x2, y2 = objects_boxes_padded[box_idx]
            x1_pixel = int(x1 * target_width)
            y1_pixel = int(y1 * target_height)
            x2_pixel = int(x2 * target_width)
            y2_pixel = int(y2 * target_height)
            x1_pixel = max(0, min(x1_pixel, target_width - 1))
            y1_pixel = max(0, min(y1_pixel, target_height - 1))
            x2_pixel = max(0, min(x2_pixel, target_width - 1))
            y2_pixel = max(0, min(y2_pixel, target_height - 1))
            objects_masks_maps_padded[box_idx, y1_pixel:y2_pixel + 1, x1_pixel:x2_pixel + 1] = 1.0
# subject
        original_size_subject_tensor = img_transforms(subject_image, height=target_height, width=target_width)
        subject_tensor = img_transforms(subject_image, height=condition_height, width=condition_width)
        subject_mask_tensor = mask_transforms(subject_mask, height=condition_height, width=condition_width)
        if self.neg_condition_image == 'black':
            subject_image_black = Image.new('RGB', (orig_width, orig_height), (0, 0, 0))
            subject_image_neg_tensor = img_transforms(subject_image_black, height=condition_height, width=condition_width)
        elif self.neg_condition_image == 'white':
            subject_image_white = Image.new('RGB', (orig_width, orig_height), (255, 255, 255))
            subject_image_neg_tensor = img_transforms(subject_image_white, height=condition_height, width=condition_width)
        elif self.neg_condition_image == 'gray':
            subject_image_gray = Image.new('RGB', (orig_width, orig_height), (128, 128, 128))
            subject_image_neg_tensor = img_transforms(subject_image_gray, height=condition_height, width=condition_width)
        elif self.neg_condition_image == 'same':
            subject_image_neg_tensor = subject_tensor
        else:
            raise ValueError("Invalid neg_condition_image. Use 'black', 'white', 'gray', or 'same'.")
        output = dict(
            id=img_id,
            caption=global_caption,
            objects_boxes=objects_boxes_padded,
            objects_caption=objects_caption,
            objects_masks=objects_masks_padded,
            objects_masks_maps=objects_masks_maps_padded,
            img=img_tensor,
            condition_img_masks_maps=subject_mask_tensor,
            condition_img=subject_tensor,
            original_size_condition_img=original_size_subject_tensor,
            neg_condtion_img=subject_image_neg_tensor,
            img_info=img_info,
            target_width=target_width,
            target_height=target_height,
        )
return output
def collate_fn(examples):
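    """Collate a batch: keep id/caption/metadata fields as Python lists and
    stack all image/box/mask tensors into float batch tensors."""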
collated_examples = {}
    for key in ['id', 'objects_caption', 'caption', 'img_info', 'target_width', 'target_height']:
        collated_examples[key] = [example[key] for example in examples]
    for key in ['img', 'objects_boxes', 'objects_masks', 'condition_img', 'neg_condtion_img',
                'objects_masks_maps', 'condition_img_masks_maps', 'original_size_condition_img']:
        collated_examples[key] = torch.stack([example[key] for example in examples]).float()
return collated_examples
# Visualization helpers (Image, np, and random are already imported above)
from typing import Dict
from PIL import ImageDraw, ImageFont
def draw_mask(mask, draw, random_color=True):
"""Draws a mask with a specified color on an image.
Args:
mask (np.array): Binary mask as a NumPy array.
draw (ImageDraw.Draw): ImageDraw object to draw on the image.
random_color (bool): Whether to use a random color for the mask.
"""
if random_color:
color = (
random.randint(0, 255),
random.randint(0, 255),
random.randint(0, 255),
153,
)
else:
color = (30, 144, 255, 153)
    nonzero_coords = np.transpose(np.nonzero(mask))
    for coord in nonzero_coords:
        # coord is (row, col); reverse to (x, y) for PIL
        draw.point(tuple(coord[::-1]), fill=color)
def visualize_bbox(image_pil: Image.Image,
                   result: Dict,
                   draw_width: float = 6.0,
                   return_mask=True) -> Image.Image:
"""Plot bounding boxes and labels on an image with text wrapping for long descriptions.
Args:
image_pil (PIL.Image): The input image as a PIL Image object.
result (Dict[str, Union[torch.Tensor, List[torch.Tensor]]]): The target dictionary containing
the bounding boxes and labels. The keys are:
- boxes (List[int]): A list of bounding boxes in shape (N, 4), [x1, y1, x2, y2] format.
- labels (List[str]): A list of labels for each object
- masks (List[PIL.Image], optional): A list of masks in the format of PIL.Image
Returns:
PIL.Image: The input image with plotted bounding boxes, labels, and masks.
"""
# Get the bounding boxes and labels from the target dictionary
boxes = result["boxes"]
categorys = result["labels"]
masks = result.get("masks", [])
color_list = [(255, 162, 76), (177, 214, 144),
(13, 146, 244), (249, 84, 84), (54, 186, 152),
(74, 36, 157), (0, 159, 189),
(80, 118, 135), (188, 90, 148), (119, 205, 255)]
# Use smaller font size to allow more text to be displayed
font_size = 30 # Reduce font size
    try:
        font = ImageFont.truetype("dataloader/arial.ttf", font_size)
    except OSError:
        # Fall back to PIL's default font if the bundled TTF is unavailable.
        font = ImageFont.load_default()
    # Convert numpy input to PIL first so image_pil.size is (width, height)
    if isinstance(image_pil, np.ndarray):
        image_pil = Image.fromarray(image_pil)
    # Get image dimensions
    img_width, img_height = image_pil.size
    # Find all unique categories and build a cate2color dictionary
    cate2color = {}
    unique_categorys = sorted(set(categorys))
    for idx, cate in enumerate(unique_categorys):
        cate2color[cate] = color_list[idx % len(color_list)]
    # Create a PIL ImageDraw object to draw on the input image
    draw = ImageDraw.Draw(image_pil)
# Draw boxes, labels, and masks for each box and label in the target dictionary
for box, category in zip(boxes, categorys):
# Extract the box coordinates
x0, y0, x1, y1 = box
x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1)
box_width = x1 - x0
box_height = y1 - y0
color = cate2color.get(category, color_list[0]) # Default color
# Draw the box outline on the input image
draw.rectangle([x0, y0, x1, y1], outline=color, width=int(draw_width))
# Allow text box to be maximum 2 times the bounding box width, but not exceed image boundaries
max_text_width = min(box_width * 2, img_width - x0)
# Determine the maximum height for text background area
max_text_height = min(box_height * 2, 200) # Also allow more text display, but limit height
# Handle long text based on bounding box width, split text into lines
lines = []
        # Split the label into words; guard against empty labels so words[0] is safe
        words = category.split() or [""]
        current_line = words[0]
for word in words[1:]:
# Try to add the next word
test_line = current_line + " " + word
# Use textbbox or textlength to check if width fits the maximum text width
if hasattr(draw, "textbbox"):
# Use textbbox method
bbox = draw.textbbox((0, 0), test_line, font=font)
w = bbox[2] - bbox[0]
elif hasattr(draw, "textlength"):
# Use textlength method
w = draw.textlength(test_line, font=font)
else:
# Fallback - estimate width
w = len(test_line) * (font_size * 0.6) # Estimate average character width
if w <= max_text_width - 20: # Leave some margin
current_line = test_line
else:
lines.append(current_line)
current_line = word
lines.append(current_line) # Add the last line
# Limit number of lines to prevent overflow
        max_lines = max(1, max_text_height // (font_size + 2))  # Line height (font size + spacing)
        if len(lines) > max_lines:
            lines = lines[:max_lines - 1]
            lines.append("...")  # Add ellipsis
# Calculate actual required width for each line
line_widths = []
for line in lines:
if hasattr(draw, "textbbox"):
bbox = draw.textbbox((0, 0), line, font=font)
line_width = bbox[2] - bbox[0]
elif hasattr(draw, "textlength"):
line_width = draw.textlength(line, font=font)
else:
line_width = len(line) * (font_size * 0.6) # Estimate width
line_widths.append(line_width)
# Determine actual required width for text box
if line_widths:
needed_text_width = max(line_widths) + 10 # Add small margin
else:
needed_text_width = 0
# Use bounding box width as minimum, only expand when needed
text_bg_width = max(box_width, min(needed_text_width, max_text_width))
# Ensure it doesn't exceed image boundaries
text_bg_width = min(text_bg_width, img_width - x0)
# Calculate text background height
text_bg_height = len(lines) * (font_size + 2)
# Ensure text background doesn't exceed image bottom
if y0 + text_bg_height > img_height:
# If it would exceed bottom, adjust text position to above the bounding box bottom
text_y0 = max(0, y1 - text_bg_height)
else:
text_y0 = y0
# Draw text background - note RGBA color handling
if image_pil.mode == "RGBA":
# For RGBA mode, we can directly use alpha color
bg_color = (*color, 180) # Semi-transparent background
else:
# For RGB mode, we cannot use alpha
bg_color = color
draw.rectangle([x0, text_y0, x0 + text_bg_width, text_y0 + text_bg_height], fill=bg_color)
# Draw text
for i, line in enumerate(lines):
y_pos = text_y0 + i * (font_size + 2)
draw.text((x0 + 5, y_pos), line, fill="white", font=font)
# Draw the mask on the input image if masks are provided
if len(masks) > 0 and return_mask:
size = image_pil.size
mask_image = Image.new("RGBA", size, color=(0, 0, 0, 0))
mask_draw = ImageDraw.Draw(mask_image)
for mask in masks:
mask = np.array(mask)[:, :, -1]
draw_mask(mask, mask_draw)
image_pil = Image.alpha_composite(image_pil.convert("RGBA"), mask_image).convert("RGB")
return image_pil
import torchvision.transforms as T
def tensor_to_pil(img_tensor):
    """Convert a normalized image tensor back to a PIL image."""
    img_tensor = img_tensor.cpu()
    # Undo the ([0.5], [0.5]) normalization: [-1, 1] -> [0, 1]
    img_tensor = img_tensor * 0.5 + 0.5
    img_tensor = torch.clamp(img_tensor, 0, 1)
    return T.ToPILImage()(img_tensor)
def make_image_grid_RGB(images, rows, cols, resize=None):
"""
Prepares a single grid of images. Useful for visualization purposes.
"""
assert len(images) == rows * cols
if resize is not None:
images = [img.resize((resize, resize)) for img in images]
w, h = images[0].size
grid = Image.new("RGB", size=(cols * w, rows * h))
for i, img in enumerate(images):
grid.paste(img.convert("RGB"), box=(i % cols * w, i // cols * h))
return grid
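# A minimal usage sketch for the helpers above (assumes a `batch` produced by
# the dataloader below; "preview.jpg" is an arbitrary example path):
#   pils = [tensor_to_pil(t) for t in batch["img"]]
#   grid = make_image_grid_RGB(pils, rows=1, cols=len(pils))
#   grid.save("preview.jpg")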
if __name__ == "__main__":
resolution = 1024
condition_resolution = 512
neg_condition_image = 'same'
background_color = 'gray'
use_bucket = True
condition_resolution_scale_ratio=0.5
benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark' # huggingface repo of benchmark
    datasets = DesignDataset(dataset_name=benchmark_repo,
                             resolution=resolution,
                             condition_resolution=condition_resolution,
                             neg_condition_image=neg_condition_image,
                             background_color=background_color,
                             use_bucket=use_bucket,
                             condition_resolution_scale_ratio=condition_resolution_scale_ratio)
    test_dataloader = DataLoader(datasets, batch_size=1, shuffle=False, num_workers=1, collate_fn=collate_fn)
for i, batch in enumerate(tqdm(test_dataloader)):
prompts = batch["caption"]
imgs_id = batch['id']
objects_boxes = batch["objects_boxes"]
objects_caption = batch['objects_caption']
objects_masks = batch['objects_masks']
condition_img = batch['condition_img']
neg_condtion_img = batch['neg_condtion_img']
        objects_masks_maps = batch['objects_masks_maps']
        subject_masks_maps = batch['condition_img_masks_maps']
        target_width = batch['target_width'][0]
        target_height = batch['target_height'][0]
        img_info = batch["img_info"][0]
        filename = img_info["img_id"] + '.jpg'