import json
import random
from typing import Dict

import numpy as np
import torch
import torchvision.transforms as T
from datasets import load_dataset
from PIL import Image, ImageDraw, ImageFont
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
def find_nearest_bucket_size(input_width, input_height, mode="x64", ratio=1):
    """Pick an output resolution for the given input size.

    "x64" snaps to the bucket whose aspect ratio is closest to the input's;
    "x8" rescales to ~1 megapixel and rounds each side down to a multiple of 8.
    """
    buckets = [
        (512, 2048), (512, 1984), (512, 1920), (512, 1856),
        (576, 1792), (576, 1728), (576, 1664),
        (640, 1600), (640, 1536),
        (704, 1472), (704, 1408), (704, 1344),
        (768, 1344), (768, 1280),
        (832, 1216), (832, 1152),
        (896, 1152), (896, 1088),
        (960, 1088), (960, 1024),
        (1024, 1024), (1024, 960),
        (1088, 960), (1088, 896),
        (1152, 896), (1152, 832),
        (1216, 832),
        (1280, 768), (1344, 768),
        (1408, 704), (1472, 704),
        (1536, 640), (1600, 640),
        (1664, 576), (1728, 576), (1792, 576),
        (1856, 512), (1920, 512), (1984, 512), (2048, 512),
    ]
    aspect_ratios = [w / h for (w, h) in buckets]
    assert mode in ["x64", "x8"]
    if mode == "x64":
        asp = input_width / input_height
        diff = [abs(ar - asp) for ar in aspect_ratios]
        bucket_id = int(np.argmin(diff))
        gen_width, gen_height = buckets[bucket_id]
    elif mode == "x8":
        max_pixels = 1024 * 1024
        # Use a local name so the `ratio` argument is not shadowed; otherwise the
        # final scaling below would apply the megapixel factor a second time.
        scale = (max_pixels / (input_width * input_height)) ** 0.5
        gen_width, gen_height = round(input_width * scale), round(input_height * scale)
        gen_width = gen_width - gen_width % 8
        gen_height = gen_height - gen_height % 8
    else:
        raise NotImplementedError
    return (int(gen_width * ratio), int(gen_height * ratio))
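# Quick sanity check (illustrative, not part of the pipeline): a 1920x1080 input
# has aspect ratio ~1.78, so the nearest "x64" bucket is (1344, 768), ratio 1.75:
#   find_nearest_bucket_size(1920, 1080)  # -> (1344, 768)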
def adjust_and_normalize_bboxes(bboxes, orig_width, orig_height):
    # Normalize pixel-space [x1, y1, x2, y2] boxes to [0, 1] coordinates.
    normalized_bboxes = []
    for bbox in bboxes:
        x1, y1, x2, y2 = bbox
        x1_norm = round(x1 / orig_width, 2)
        y1_norm = round(y1 / orig_height, 2)
        x2_norm = round(x2 / orig_width, 2)
        y2_norm = round(y2 / orig_height, 2)
        normalized_bboxes.append([x1_norm, y1_norm, x2_norm, y2_norm])
    return normalized_bboxes
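# Example (illustrative): a box [100, 50, 300, 200] in a 400x400 image becomes
#   adjust_and_normalize_bboxes([[100, 50, 300, 200]], 400, 400)
#   # -> [[0.25, 0.12, 0.75, 0.5]]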
def img_transforms(image, height=512, width=512):
    transform = transforms.Compose(
        [
            transforms.Resize(
                (height, width), interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )
    return transform(image)
def mask_transforms(mask, height=512, width=512):
    transform = transforms.Compose(
        [
            transforms.Resize(
                (height, width),
                interpolation=transforms.InterpolationMode.NEAREST
            ),
            transforms.ToTensor(),
        ]
    )
    return transform(mask)
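# Note the differing value ranges: img_transforms maps pixels to [-1, 1] via
# Normalize([0.5], [0.5]), while mask_transforms keeps masks in [0, 1] and uses
# nearest-neighbor resizing so binary masks stay binary.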
class DesignDataset(Dataset):
    def __init__(
        self,
        dataset_name,
        resolution=512,
        condition_resolution=512,
        condition_resolution_scale_ratio=0.5,
        max_boxes_per_image=10,
        neg_condition_image='same',
        background_color='gray',
        use_bucket=True,
        box_confidence_th=0.0,
    ):
        print(f"Loading dataset from Hugging Face: {dataset_name}")
        self.dataset = load_dataset(dataset_name, split="test")
        print(f"Loaded {len(self.dataset)} samples")
        self.max_boxes_per_image = max_boxes_per_image
        self.resolution = resolution
        self.condition_resolution = condition_resolution
        self.neg_condition_image = neg_condition_image
        self.use_bucket = use_bucket
        self.condition_resolution_scale_ratio = condition_resolution_scale_ratio
        self.box_confidence_th = box_confidence_th
        if background_color == 'white':
            self.background_color = (255, 255, 255)
        elif background_color == 'black':
            self.background_color = (0, 0, 0)
        elif background_color == 'gray':
            self.background_color = (128, 128, 128)
        else:
            raise ValueError("Invalid background color. Use 'white', 'black', or 'gray'.")

    def __len__(self):
        return len(self.dataset)
    def __getitem__(self, idx):
        sample = self.dataset[idx]
        image_source = sample['original_image']
        subject_image = sample['condition_gray_background']
        subject_mask = sample['subject_mask']
        json_data = json.loads(sample['metadata'])
        # image info
        img_info = json_data['img_info']
        img_id = img_info['img_id']
        orig_width, orig_height = int(img_info["img_width"]), int(img_info["img_height"])
        if self.use_bucket:
            target_width, target_height = find_nearest_bucket_size(orig_width, orig_height)
            condition_width = int(target_width * self.condition_resolution_scale_ratio)
            condition_height = int(target_height * self.condition_resolution_scale_ratio)
        else:
            target_width = target_height = self.resolution
            condition_width = condition_height = self.condition_resolution
        img_tensor = img_transforms(image_source, height=target_height, width=target_width)
        # global caption
        global_caption = json_data['global_caption']
        # object annotations: boxes, per-box captions, and confidence scores
        object_annotations = json_data['object_annotations']
        objects_bbox = [item['bbox'] for item in object_annotations]
        objects_caption = [item['bbox_detail_description'] for item in object_annotations]
        objects_bbox_score = [item['score'][0] for item in object_annotations]
        # text regions: treated as extra boxes with a "text:" caption prefix and score 1.0
        text_list = json_data["text_list"]
        txt_bboxs = [item['bbox'] for item in text_list]
        txt_captions = ["text:" + item['text'] for item in text_list]
        txt_scores = [1.0 for _ in txt_bboxs]
        # combine object and text bboxes/descriptions
        objects_bbox.extend(txt_bboxs)
        objects_caption.extend(txt_captions)
        objects_bbox_score.extend(txt_scores)
        objects_bbox = torch.tensor(adjust_and_normalize_bboxes(objects_bbox, orig_width, orig_height))
        objects_bbox_score = torch.tensor(objects_bbox_score)
        boxes_mask = objects_bbox_score > self.box_confidence_th
        objects_bbox_raw = objects_bbox[boxes_mask]
        objects_caption = [caption for caption, keep in zip(objects_caption, boxes_mask) if keep]
        # Cap at max_boxes_per_image so the padded assignments and the mask-map
        # loop below cannot index past the padded tensors.
        num_boxes = min(objects_bbox_raw.shape[0], self.max_boxes_per_image)
        objects_boxes_padded = torch.zeros((self.max_boxes_per_image, 4))
        objects_masks_padded = torch.zeros(self.max_boxes_per_image)
        objects_caption = objects_caption[:self.max_boxes_per_image]
        objects_boxes_padded[:num_boxes] = objects_bbox_raw[:self.max_boxes_per_image]
        objects_masks_padded[:num_boxes] = 1.
        # objects_masks_maps: one binary spatial mask per (padded) box
        objects_masks_maps_padded = torch.zeros((self.max_boxes_per_image, target_height, target_width))
        for box_idx in range(num_boxes):
            x1, y1, x2, y2 = objects_boxes_padded[box_idx]
            x1_pixel = int(x1 * target_width)
            y1_pixel = int(y1 * target_height)
            x2_pixel = int(x2 * target_width)
            y2_pixel = int(y2 * target_height)
            x1_pixel = max(0, min(x1_pixel, target_width - 1))
            y1_pixel = max(0, min(y1_pixel, target_height - 1))
            x2_pixel = max(0, min(x2_pixel, target_width - 1))
            y2_pixel = max(0, min(y2_pixel, target_height - 1))
            objects_masks_maps_padded[box_idx, y1_pixel:y2_pixel + 1, x1_pixel:x2_pixel + 1] = 1.0
        # subject condition images (full-size and condition-resolution variants)
        original_size_subject_tensor = img_transforms(subject_image, height=target_height, width=target_width)
        subject_tensor = img_transforms(subject_image, height=condition_height, width=condition_width)
        subject_mask_tensor = mask_transforms(subject_mask, height=condition_height, width=condition_width)
        if self.neg_condition_image == 'black':
            subject_image_black = Image.new('RGB', (orig_width, orig_height), (0, 0, 0))
            subject_image_neg_tensor = img_transforms(subject_image_black, height=condition_height, width=condition_width)
        elif self.neg_condition_image == 'white':
            subject_image_white = Image.new('RGB', (orig_width, orig_height), (255, 255, 255))
            subject_image_neg_tensor = img_transforms(subject_image_white, height=condition_height, width=condition_width)
        elif self.neg_condition_image == 'gray':
            subject_image_gray = Image.new('RGB', (orig_width, orig_height), (128, 128, 128))
            subject_image_neg_tensor = img_transforms(subject_image_gray, height=condition_height, width=condition_width)
        elif self.neg_condition_image == 'same':
            subject_image_neg_tensor = subject_tensor
        else:
            raise ValueError(f"Invalid neg_condition_image: {self.neg_condition_image}")
        output = dict(
            id=img_id,
            caption=global_caption,
            objects_boxes=objects_boxes_padded,
            objects_caption=objects_caption,
            objects_masks=objects_masks_padded,
            objects_masks_maps=objects_masks_maps_padded,
            img=img_tensor,
            condition_img_masks_maps=subject_mask_tensor,
            condition_img=subject_tensor,
            original_size_condition_img=original_size_subject_tensor,
            neg_condition_img=subject_image_neg_tensor,
            img_info=img_info,
            target_width=target_width,
            target_height=target_height,
        )
        return output
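# Per-sample shapes (with H = target_height, W = target_width, and the default
# condition_resolution_scale_ratio of 0.5): img is (3, H, W), objects_boxes is
# (max_boxes_per_image, 4), objects_masks_maps is (max_boxes_per_image, H, W),
# and condition_img is (3, H // 2, W // 2).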
def collate_fn(examples):
    collated_examples = {}
    # Keys kept as Python lists (strings, dicts, ints)
    for key in ['id', 'objects_caption', 'caption', 'img_info', 'target_width', 'target_height']:
        collated_examples[key] = [example[key] for example in examples]
    # Keys stacked into float tensors
    for key in ['img', 'objects_boxes', 'objects_masks', 'condition_img', 'neg_condition_img',
                'objects_masks_maps', 'condition_img_masks_maps', 'original_size_condition_img']:
        collated_examples[key] = torch.stack([example[key] for example in examples]).float()
    return collated_examples
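# A custom collate_fn is needed because bucketed samples can differ in spatial
# size across the dataset; with batch_size=1 (as in __main__ below) stacking is
# always safe, while larger batches would require grouping same-bucket samples.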
def draw_mask(mask, draw, random_color=True):
    """Draw a binary mask onto an image with a semi-transparent color.

    Args:
        mask (np.ndarray): Binary mask as a NumPy array.
        draw (ImageDraw.ImageDraw): Draw object for the target image.
        random_color (bool): Whether to use a random color for the mask.
    """
    if random_color:
        color = (
            random.randint(0, 255),
            random.randint(0, 255),
            random.randint(0, 255),
            153,
        )
    else:
        color = (30, 144, 255, 153)
    nonzero_coords = np.transpose(np.nonzero(mask))
    for coord in nonzero_coords:
        draw.point(coord[::-1], fill=color)
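# Drawing point-by-point is O(number of mask pixels) in Python. A vectorized
# alternative (a sketch, not from the source) writes the RGBA overlay with NumPy:
#   overlay = np.zeros((*mask.shape, 4), dtype=np.uint8)
#   overlay[mask > 0] = color
#   mask_image = Image.alpha_composite(mask_image, Image.fromarray(overlay))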
def visualize_bbox(image_pil: Image.Image,
                   result: Dict,
                   draw_width: float = 6.0,
                   return_mask=True) -> Image.Image:
    """Plot bounding boxes and labels on an image, wrapping long descriptions.

    Args:
        image_pil (PIL.Image.Image): The input image.
        result (Dict): Target dictionary containing the boxes and labels:
            - boxes (List[int]): Bounding boxes of shape (N, 4) in [x1, y1, x2, y2] format.
            - labels (List[str]): A label for each box.
            - masks (List[PIL.Image.Image], optional): Masks as PIL images.

    Returns:
        PIL.Image.Image: The input image with boxes, labels, and masks drawn on it.
    """
    # Get the bounding boxes and labels from the target dictionary
    boxes = result["boxes"]
    categorys = result["labels"]
    masks = result.get("masks", [])
    color_list = [(255, 162, 76), (177, 214, 144),
                  (13, 146, 244), (249, 84, 84), (54, 186, 152),
                  (74, 36, 157), (0, 159, 189),
                  (80, 118, 135), (188, 90, 148), (119, 205, 255)]
    # Use a smaller font size so more text fits
    font_size = 30
    try:
        font = ImageFont.truetype("dataloader/arial.ttf", font_size)
    except OSError:
        # Fall back to PIL's default font if the bundled TTF is missing
        font = ImageFont.load_default()
    # Convert to PIL first: the ndarray check must precede the .size access
    if isinstance(image_pil, np.ndarray):
        image_pil = Image.fromarray(image_pil)
    # Get image dimensions
    img_width, img_height = image_pil.size
    # Map each unique category to a stable color
    cate2color = {}
    unique_categorys = sorted(set(categorys))
    for idx, cate in enumerate(unique_categorys):
        cate2color[cate] = color_list[idx % len(color_list)]
    # Create a PIL ImageDraw object to draw on the input image
    draw = ImageDraw.Draw(image_pil)
    # Draw boxes, labels, and masks for each entry
    for box, category in zip(boxes, categorys):
        # Extract the box coordinates
        x0, y0, x1, y1 = (int(v) for v in box)
        box_width = x1 - x0
        box_height = y1 - y0
        color = cate2color.get(category, color_list[0])  # default color
        # Draw the box outline on the input image
        draw.rectangle([x0, y0, x1, y1], outline=color, width=int(draw_width))
        # The text box may grow to twice the bounding-box width, but not past the image edge
        max_text_width = min(box_width * 2, img_width - x0)
        # Maximum height of the text background area
        max_text_height = min(box_height * 2, 200)
        # Wrap long descriptions into lines that fit max_text_width
        lines = []
        words = category.split() or [""]  # guard against empty labels
        current_line = words[0]
        for word in words[1:]:
            # Try appending the next word and measure the resulting width
            test_line = current_line + " " + word
            if hasattr(draw, "textbbox"):
                bbox = draw.textbbox((0, 0), test_line, font=font)
                w = bbox[2] - bbox[0]
            elif hasattr(draw, "textlength"):
                w = draw.textlength(test_line, font=font)
            else:
                # Fallback: estimate from an average character width
                w = len(test_line) * (font_size * 0.6)
            if w <= max_text_width - 20:  # leave some margin
                current_line = test_line
            else:
                lines.append(current_line)
                current_line = word
        lines.append(current_line)  # add the last line
        # Limit the number of lines to prevent overflow (line height = font size + spacing)
        max_lines = max(1, max_text_height // (font_size + 2))
        if len(lines) > max_lines:
            lines = lines[:max_lines - 1]
            lines.append("...")  # add ellipsis
        # Measure the actual width of each wrapped line
        line_widths = []
        for line in lines:
            if hasattr(draw, "textbbox"):
                bbox = draw.textbbox((0, 0), line, font=font)
                line_width = bbox[2] - bbox[0]
            elif hasattr(draw, "textlength"):
                line_width = draw.textlength(line, font=font)
            else:
                line_width = len(line) * (font_size * 0.6)  # estimate
            line_widths.append(line_width)
        # Width actually required by the text box
        if line_widths:
            needed_text_width = max(line_widths) + 10  # small margin
        else:
            needed_text_width = 0
        # Use the bounding-box width as a minimum; expand only when needed
        text_bg_width = max(box_width, min(needed_text_width, max_text_width))
        # Keep within the image boundary
        text_bg_width = min(text_bg_width, img_width - x0)
        # Text background height
        text_bg_height = len(lines) * (font_size + 2)
        # Keep the text background above the image bottom
        if y0 + text_bg_height > img_height:
            # If it would overflow, anchor the text above the box's bottom edge
            text_y0 = max(0, y1 - text_bg_height)
        else:
            text_y0 = y0
        # Draw the text background; an alpha channel is only valid in RGBA mode
        if image_pil.mode == "RGBA":
            bg_color = (*color, 180)  # semi-transparent background
        else:
            bg_color = color
        draw.rectangle([x0, text_y0, x0 + text_bg_width, text_y0 + text_bg_height], fill=bg_color)
        # Draw the text lines
        for i, line in enumerate(lines):
            y_pos = text_y0 + i * (font_size + 2)
            draw.text((x0 + 5, y_pos), line, fill="white", font=font)
    # Composite the masks onto the image if provided
    if len(masks) > 0 and return_mask:
        size = image_pil.size
        mask_image = Image.new("RGBA", size, color=(0, 0, 0, 0))
        mask_draw = ImageDraw.Draw(mask_image)
        for mask in masks:
            mask = np.array(mask)[:, :, -1]  # take the last (alpha) channel
            draw_mask(mask, mask_draw)
        image_pil = Image.alpha_composite(image_pil.convert("RGBA"), mask_image).convert("RGB")
    return image_pil
def tensor_to_pil(img_tensor):
    """Convert a normalized tensor back to a PIL image."""
    img_tensor = img_tensor.cpu()
    # Undo the Normalize([0.5], [0.5]) applied in img_transforms
    img_tensor = img_tensor * 0.5 + 0.5
    img_tensor = torch.clamp(img_tensor, 0, 1)
    return T.ToPILImage()(img_tensor)
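# Round trip (illustrative): tensor_to_pil(img_transforms(pil_img)) recovers the
# resized image, since x * 0.5 + 0.5 exactly inverts (x - 0.5) / 0.5.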
def make_image_grid_RGB(images, rows, cols, resize=None):
    """Arrange a list of PIL images into a single rows x cols grid. Useful for visualization."""
    assert len(images) == rows * cols
    if resize is not None:
        images = [img.resize((resize, resize)) for img in images]
    w, h = images[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(images):
        grid.paste(img.convert("RGB"), box=(i % cols * w, i // cols * h))
    return grid
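# Usage (illustrative): tile four 256px thumbnails into a 2x2 sheet.
#   sheet = make_image_grid_RGB(images[:4], rows=2, cols=2, resize=256)
#   sheet.save("grid.jpg")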
if __name__ == "__main__":
    resolution = 1024
    condition_resolution = 512
    neg_condition_image = 'same'
    background_color = 'gray'
    use_bucket = True
    condition_resolution_scale_ratio = 0.5
    benchmark_repo = 'HuiZhang0812/CreatiDesign_benchmark'  # Hugging Face repo of the benchmark
    datasets = DesignDataset(
        dataset_name=benchmark_repo,
        resolution=resolution,
        condition_resolution=condition_resolution,
        neg_condition_image=neg_condition_image,
        background_color=background_color,
        use_bucket=use_bucket,
        condition_resolution_scale_ratio=condition_resolution_scale_ratio,
    )
    test_dataloader = DataLoader(datasets, batch_size=1, shuffle=False, num_workers=1, collate_fn=collate_fn)
    for i, batch in enumerate(tqdm(test_dataloader)):
        prompts = batch["caption"]
        imgs_id = batch['id']
        objects_boxes = batch["objects_boxes"]
        objects_caption = batch['objects_caption']
        objects_masks = batch['objects_masks']
        condition_img = batch['condition_img']
        neg_condition_img = batch['neg_condition_img']
        objects_masks_maps = batch['objects_masks_maps']
        subject_masks_maps = batch['condition_img_masks_maps']
        target_width = batch['target_width'][0]
        target_height = batch['target_height'][0]
        img_info = batch["img_info"][0]
        filename = img_info["img_id"] + '.jpg'
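        # A possible continuation (hedged sketch, not from the source): round-trip
        # the tensors to PIL and save a side-by-side check for each sample.
        #   img_pil = tensor_to_pil(batch['img'][0])
        #   cond_pil = tensor_to_pil(condition_img[0]).resize(img_pil.size)
        #   make_image_grid_RGB([img_pil, cond_pil], rows=1, cols=2).save(filename)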