# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file provides the inference script.
# ------------------------------------------
import os
import cv2
import random
import logging
import argparse
import numpy as np
from pathlib import Path
from tqdm.auto import tqdm
from typing import Optional
from packaging import version
from termcolor import colored
from PIL import Image, ImageDraw, ImageFont, ImageOps, ImageEnhance  # imported for visualization

from huggingface_hub import HfFolder, Repository, create_repo, whoami

import datasets
from datasets import load_dataset
from datasets import disable_caching

import torch
import torch.utils.checkpoint
import torch.nn.functional as F
from torchvision import transforms

import accelerate
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate
from diffusers.utils.import_utils import is_xformers_available

import transformers
from transformers import CLIPTextModel, CLIPTokenizer

from util import segmentation_mask_visualization, make_caption_pil, combine_image, transform_mask, filter_segmentation_mask, inpainting_merge_image
from model.layout_generator import get_layout_from_prompt
from model.text_segmenter.unet import UNet

import torchsnooper  # only needed for the commented-out @torchsnooper.snoop() debug decorator below

disable_caching()
check_min_version("0.15.0.dev0")  # errors out if the installed diffusers version is older than this

logger = get_logger(__name__, log_level="INFO")
def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of an inference script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default='runwayml/stable-diffusion-v1-5',  # no need to modify this
        help="Path to pretrained model or model identifier from huggingface.co/models. Please do not modify this.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--mode",
        type=str,
        default=None,
        required=True,
        choices=["text-to-image", "text-to-image-with-template", "text-inpainting"],
        help="Three modes are supported.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="",
        required=True,
        help="The text prompt provided by the user.",
    )
    parser.add_argument(
        "--template_image",
        type=str,
        default="",
        help='The template image must be given when using the "text-to-image-with-template" mode.',
    )
    parser.add_argument(
        "--original_image",
        type=str,
        default="",
        help='The original image must be given when using the "text-inpainting" mode.',
    )
    parser.add_argument(
        "--text_mask",
        type=str,
        default="",
        help='The text mask must be given when using the "text-inpainting" mode.',
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible sampling.",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images. All the images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--classifier_free_scale",
        type=float,
        default=7.5,  # following stable diffusion (https://github.com/CompVis/stable-diffusion)
        help="Classifier-free guidance scale following https://arxiv.org/abs/2207.12598.",
    )
    parser.add_argument(
        "--drop_caption",
        action="store_true",
        help="Whether to drop captions during training following https://arxiv.org/abs/2207.12598.",
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether or not to push the model to the Hub.",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token to use to push to the Model Hub.",
    )
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default='fp16',
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="For distributed training: local_rank",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=5,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs."
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,  # must be specified during inference
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention",
        action="store_true",
        help="Whether or not to use xformers.",
    )
    parser.add_argument(
        "--font_path",
        type=str,
        default='assets/font/Arial.ttf',
        help="The path of the font used for visualization.",
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=50,  # following stable diffusion (https://github.com/CompVis/stable-diffusion)
        help="Number of diffusion steps for sampling.",
    )
    parser.add_argument(
        "--vis_num",
        type=int,
        default=9,  # please decrease this number if an out-of-memory error occurs
        help="Number of images to sample. Please decrease it when encountering an out-of-memory error.",
    )
    parser.add_argument(
        "--binarization",
        action="store_true",
        help="Whether to binarize the template image.",
    )
    parser.add_argument(
        "--use_pillow_segmentation_mask",
        type=bool,  # note: argparse parses any non-empty string as True for type=bool
        default=True,
        help='In the "text-to-image" mode, specifies whether to use the segmentation masks rendered by Pillow.',
    )
    parser.add_argument(
        "--character_segmenter_path",
        type=str,
        default='textdiffuser-ckpt/text_segmenter.pth',
        help="Checkpoint of the character-level segmenter.",
    )
    args = parser.parse_args()

    print(f'{colored("[√]", "green")} Arguments are loaded.')
    print(args)

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args
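
# Illustrative invocation (a sketch, not taken from the repo docs; the checkpoint path is
# whatever directory holds the fine-tuned TextDiffuser weights with a `unet` subfolder --
# adjust it to your setup):
#   python inference.py --mode="text-to-image" --prompt="a sign that says 'hello'" \
#       --resume_from_checkpoint="path/to/textdiffuser-ckpt" --vis_num=4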

def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"


# @torchsnooper.snoop()
def main():
    args = parse_args()

    # If passed along, set the seed now.
    seed = args.seed if args.seed is not None else random.randint(0, 1000000)
    set_seed(seed)
    print(f'{colored("[√]", "green")} Seed is set to {seed}.')

    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    sub_output_dir = f"{args.prompt}_[{args.mode.upper()}]_[SEED-{seed}]"
    print(f'{colored("[√]", "green")} Logging dir is set to {logging_dir}.')

    accelerator_project_config = ProjectConfiguration(total_limit=args.checkpoints_total_limit)
    accelerator = Accelerator(
        gradient_accumulation_steps=1,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id
            create_repo(repo_name, exist_ok=True, token=args.hub_token)
            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
            print(args.output_dir)
    # Load scheduler, tokenizer and models.
    tokenizer = CLIPTokenizer.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
    )
    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision).cuda()
    unet = UNet2DConditionModel.from_pretrained(
        args.resume_from_checkpoint, subfolder="unet", revision=None
    ).cuda()

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
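
    # Note: only the UNet is loaded from the fine-tuned TextDiffuser checkpoint
    # (--resume_from_checkpoint); the VAE and text encoder are the unmodified Stable
    # Diffusion v1.5 components, so no gradients are needed for them at inference time.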

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            import xformers

            xformers_version = version.parse(xformers.__version__)
            if xformers_version == version.parse("0.0.16"):
                logger.warn(
                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                )
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")
    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            for i, model in enumerate(models):
                model.save_pretrained(os.path.join(output_dir, "unet"))
                # make sure to pop weight so that corresponding model is not saved again
                weights.pop()

        def load_model_hook(models, input_dir):
            for i in range(len(models)):
                # pop models so that they are not loaded again
                model = models.pop()
                # load diffusers style into model
                load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
                model.register_to_config(**load_model.config)
                model.load_state_dict(load_model.state_dict())
                del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)
    # set up the scheduler
    scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    scheduler.set_timesteps(args.sample_steps)
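
    # set_timesteps selects the args.sample_steps timesteps that the denoising loop below
    # iterates over; each scheduler.step call performs one reverse-diffusion update.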

    sample_num = args.vis_num
    noise = torch.randn((sample_num, 4, 64, 64)).to("cuda")  # (b, 4, 64, 64)
    input = noise  # (b, 4, 64, 64)

    captions = [args.prompt] * sample_num
    captions_nocond = [""] * sample_num
    print(f'{colored("[√]", "green")} Prompt is loaded: {args.prompt}.')

    # encode text prompts
    inputs = tokenizer(
        captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    ).input_ids  # (b, 77)
    encoder_hidden_states = text_encoder(inputs)[0].cuda()  # (b, 77, 768)
    print(f'{colored("[√]", "green")} encoder_hidden_states: {encoder_hidden_states.shape}.')
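
    # Empty prompts yield the unconditional embeddings required for classifier-free guidance
    # in the sampling loop below.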
    inputs_nocond = tokenizer(
        captions_nocond, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
    ).input_ids  # (b, 77)
    encoder_hidden_states_nocond = text_encoder(inputs_nocond)[0].cuda()  # (b, 77, 768)
    print(f'{colored("[√]", "green")} encoder_hidden_states_nocond: {encoder_hidden_states_nocond.shape}.')

    # load the character-level segmenter
    segmenter = UNet(3, 96, True).cuda()
    segmenter = torch.nn.DataParallel(segmenter)
    segmenter.load_state_dict(torch.load(args.character_segmenter_path))
    segmenter.eval()
    print(f'{colored("[√]", "green")} Text segmenter is successfully loaded.')

    #### text-to-image ####
    if args.mode == 'text-to-image':
        render_image, segmentation_mask_from_pillow = get_layout_from_prompt(args)

        if args.use_pillow_segmentation_mask:
            segmentation_mask = torch.Tensor(np.array(segmentation_mask_from_pillow)).cuda()  # (512, 512)
        else:
            to_tensor = transforms.ToTensor()
            image_tensor = to_tensor(render_image).unsqueeze(0).cuda().sub_(0.5).div_(0.5)
            with torch.no_grad():
                segmentation_mask = segmenter(image_tensor)
            segmentation_mask = segmentation_mask.max(1)[1].squeeze(0)

        segmentation_mask = filter_segmentation_mask(segmentation_mask)
        segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(256, 256), mode='nearest')
        segmentation_mask = segmentation_mask.squeeze(1).repeat(sample_num, 1, 1).long().to('cuda')  # (b, 256, 256)
        print(f'{colored("[√]", "green")} character-level segmentation_mask: {segmentation_mask.shape}.')

        # nothing is preserved in full-image generation: the feature mask covers the whole
        # latent and the masked image is all zeros
        feature_mask = torch.ones(sample_num, 1, 64, 64).to('cuda')  # (b, 1, 64, 64)
        masked_image = torch.zeros(sample_num, 3, 512, 512).to('cuda')  # (b, 3, 512, 512)
        masked_feature = vae.encode(masked_image).latent_dist.sample()  # (b, 4, 64, 64)
        masked_feature = masked_feature * vae.config.scaling_factor
        print(f'{colored("[√]", "green")} feature_mask: {feature_mask.shape}.')
        print(f'{colored("[√]", "green")} masked_feature: {masked_feature.shape}.')
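
    # In "text-to-image-with-template" mode, the character layout comes from a user-provided
    # template image rather than the layout generator; the segmenter extracts it below.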
    #### text-to-image-with-template ####
    if args.mode == 'text-to-image-with-template':
        template_image = Image.open(args.template_image).resize((256, 256)).convert('RGB')

        # binarize the template if requested
        print(f'{colored("[Warning]", "red")} args.binarization is set to {args.binarization}. You may need it when using handwritten images as templates.')
        if args.binarization:
            gray = ImageOps.grayscale(template_image)
            binary = gray.point(lambda x: 255 if x > 96 else 0, '1')
            template_image = binary.convert('RGB')

        to_tensor = transforms.ToTensor()
        image_tensor = to_tensor(template_image).unsqueeze(0).cuda().sub_(0.5).div_(0.5)  # (1, 3, 256, 256)

        with torch.no_grad():
            segmentation_mask = segmenter(image_tensor)  # (1, 96, 256, 256)
        segmentation_mask = segmentation_mask.max(1)[1].squeeze(0)  # (256, 256)
        segmentation_mask = filter_segmentation_mask(segmentation_mask)  # (256, 256)
        segmentation_mask_pil = Image.fromarray(segmentation_mask.type(torch.uint8).cpu().numpy()).convert('RGB')

        segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(256, 256), mode='nearest')  # (1, 1, 256, 256)
        segmentation_mask = segmentation_mask.squeeze(1).repeat(sample_num, 1, 1).long().to('cuda')  # (b, 256, 256)
        print(f'{colored("[√]", "green")} Character-level segmentation_mask: {segmentation_mask.shape}.')

        feature_mask = torch.ones(sample_num, 1, 64, 64).to('cuda')  # (b, 1, 64, 64)
        masked_image = torch.zeros(sample_num, 3, 512, 512).to('cuda')  # (b, 3, 512, 512)
        masked_feature = vae.encode(masked_image).latent_dist.sample()  # (b, 4, 64, 64)
        masked_feature = masked_feature * vae.config.scaling_factor  # (b, 4, 64, 64)
        print(f'{colored("[√]", "green")} feature_mask: {feature_mask.shape}.')
        print(f'{colored("[√]", "green")} masked_feature: {masked_feature.shape}.')

        render_image = template_image  # for visualization
    #### text-inpainting ####
    if args.mode == 'text-inpainting':
        # binarize the user-provided text mask and extract its character-level segmentation
        text_mask = cv2.imread(args.text_mask)
        threshold = 128
        _, text_mask = cv2.threshold(text_mask, threshold, 255, cv2.THRESH_BINARY)
        text_mask = Image.fromarray(text_mask).convert('RGB').resize((256, 256))
        text_mask_tensor = transforms.ToTensor()(text_mask).unsqueeze(0).cuda().sub_(0.5).div_(0.5)
        with torch.no_grad():
            segmentation_mask = segmenter(text_mask_tensor)

        segmentation_mask = segmentation_mask.max(1)[1].squeeze(0)
        segmentation_mask = filter_segmentation_mask(segmentation_mask)
        segmentation_mask = torch.nn.functional.interpolate(segmentation_mask.unsqueeze(0).unsqueeze(0).float(), size=(256, 256), mode='nearest')

        # encode the original image with the region to repaint masked out
        image_mask = transform_mask(args.text_mask)
        image_mask = torch.from_numpy(image_mask).cuda().unsqueeze(0).unsqueeze(0)
        image = Image.open(args.original_image).convert('RGB').resize((512, 512))
        image_tensor = transforms.ToTensor()(image).unsqueeze(0).cuda().sub_(0.5).div_(0.5)
        masked_image = image_tensor * (1 - image_mask)
        masked_feature = vae.encode(masked_image).latent_dist.sample().repeat(sample_num, 1, 1, 1)
        masked_feature = masked_feature * vae.config.scaling_factor

        # restrict the segmentation mask to the inpainting region and build the latent-space feature mask
        image_mask = torch.nn.functional.interpolate(image_mask, size=(256, 256), mode='nearest').repeat(sample_num, 1, 1, 1)
        segmentation_mask = segmentation_mask * image_mask
        feature_mask = torch.nn.functional.interpolate(image_mask, size=(64, 64), mode='nearest')
        print(f'{colored("[√]", "green")} feature_mask: {feature_mask.shape}.')
        print(f'{colored("[√]", "green")} segmentation_mask: {segmentation_mask.shape}.')
        print(f'{colored("[√]", "green")} masked_feature: {masked_feature.shape}.')

        render_image = Image.open(args.original_image)
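
    # All three modes converge to the same conditioning interface: segmentation_mask (the
    # character layout), feature_mask (which latent region to generate), and masked_feature
    # (latents of the content to preserve).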
    # diffusion process
    intermediate_images = []
    for t in tqdm(scheduler.timesteps):
        with torch.no_grad():
            noise_pred_cond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states, segmentation_mask=segmentation_mask, feature_mask=feature_mask, masked_feature=masked_feature).sample  # (b, 4, 64, 64)
            noise_pred_uncond = unet(sample=input, timestep=t, encoder_hidden_states=encoder_hidden_states_nocond, segmentation_mask=segmentation_mask, feature_mask=feature_mask, masked_feature=masked_feature).sample  # (b, 4, 64, 64)
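            # classifier-free guidance (https://arxiv.org/abs/2207.12598): push the
            # unconditional prediction toward the conditional one by classifier_free_scale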
            noisy_residual = noise_pred_uncond + args.classifier_free_scale * (noise_pred_cond - noise_pred_uncond)  # (b, 4, 64, 64)
            prev_noisy_sample = scheduler.step(noisy_residual, t, input).prev_sample
            input = prev_noisy_sample
            intermediate_images.append(prev_noisy_sample)

    # decode and visualize
    input = 1 / vae.config.scaling_factor * input  # undo the latent scaling before decoding
    sample_images = vae.decode(input.float(), return_dict=False)[0]  # (b, 3, 512, 512)

    image_pil = render_image.resize((512, 512))
    segmentation_mask = segmentation_mask[0].squeeze().cpu().numpy()
    character_mask_pil = Image.fromarray(((segmentation_mask != 0) * 255).astype('uint8')).resize((512, 512))
    character_mask_highlight_pil = segmentation_mask_visualization(args.font_path, segmentation_mask)
    caption_pil = make_caption_pil(args.font_path, captions)

    # save the predicted images, mapping them from [-1, 1] back to [0, 255]
    pred_image_list = []
    for image in sample_images.float():
        image = (image / 2 + 0.5).clamp(0, 1).unsqueeze(0)
        image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).round().astype("uint8")).convert('RGB')
        pred_image_list.append(image)

    os.makedirs(f'{args.output_dir}/{sub_output_dir}', exist_ok=True)

    # save additional info
    if args.mode == 'text-to-image':
        image_pil.save(os.path.join(args.output_dir, sub_output_dir, 'render_text_image.png'))
        enhancer = ImageEnhance.Brightness(segmentation_mask_from_pillow)
        im_brightness = enhancer.enhance(5)
        im_brightness.save(os.path.join(args.output_dir, sub_output_dir, 'segmentation_mask_from_pillow.png'))
    if args.mode == 'text-to-image-with-template':
        template_image.save(os.path.join(args.output_dir, sub_output_dir, 'template.png'))
        enhancer = ImageEnhance.Brightness(segmentation_mask_pil)
        im_brightness = enhancer.enhance(5)
        im_brightness.save(os.path.join(args.output_dir, sub_output_dir, 'segmentation_mask_from_template.png'))
    if args.mode == 'text-inpainting':
        character_mask_highlight_pil = character_mask_pil
        # background
        background = Image.open(args.original_image).resize((512, 512))
        alpha = Image.new('L', background.size, int(255 * 0.2))
        background.putalpha(alpha)
        # foreground
        foreground = Image.open(args.text_mask).convert('L').resize((512, 512))
        threshold = 200
        alpha = foreground.point(lambda x: 0 if x > threshold else 255, '1')
        foreground.putalpha(alpha)
        character_mask_pil = Image.alpha_composite(foreground.convert('RGBA'), background.convert('RGBA')).convert('RGB')
        # merge the inpainted region back into the original image
        pred_image_list_new = []
        for pred_image in pred_image_list:
            pred_image = inpainting_merge_image(Image.open(args.original_image), Image.open(args.text_mask).convert('L'), pred_image)
            pred_image_list_new.append(pred_image)
        pred_image_list = pred_image_list_new

    combine_image(args, sub_output_dir, pred_image_list, image_pil, character_mask_pil, character_mask_highlight_pil, caption_pil)

    # create a soft link so that `latest` always points at the most recent run
    if os.path.exists(os.path.join(args.output_dir, 'latest')):
        os.unlink(os.path.join(args.output_dir, 'latest'))
    os.symlink(os.path.abspath(os.path.join(args.output_dir, sub_output_dir)), os.path.abspath(os.path.join(args.output_dir, 'latest/')))

    color_sub_output_dir = colored(sub_output_dir, 'green')
    print(f'{colored("[√]", "green")} Saved successfully. Please check the output at {color_sub_output_dir} or the `latest` folder.')


if __name__ == "__main__":
    main()