Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| import os | |
| import sys | |
| import json | |
| import pathlib | |
| import argparse | |
| import warnings | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| from tqdm import tqdm | |
| from util import Map | |
| from rich.pretty import install as pretty_install | |
| from rich.traceback import install as traceback_install | |
| from rich.console import Console | |
# Shared rich console: log lines are timestamped as hour:minute:second-microsecond.
console = Console(log_time=True, log_time_format='%H:%M:%S-%f')
pretty_install(console=console)
traceback_install(
    console=console,
    extra_lines=1,
    width=console.width,
    word_wrap=False,
    indent_guides=False,
)

# The `library` package is resolved from ../modules/lora relative to this file,
# so the search path has to be extended before the two imports below.
sys.path.append(os.path.join(os.path.dirname(__file__), '..', 'modules', 'lora'))
import library.model_util as model_util
import library.train_util as train_util

# Silence every Python warning for the rest of the process.
warnings.filterwarnings('ignore')

# Run on CUDA when torch reports it as available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Default settings; create_vae_latents() overlays the caller's parameters on top of these.
options = Map({
    'batch': 1,  # batch size in inference
    'input': '',  # directory for train images
    'json': '',  # metadata file
    'max': 1024,  # maximum resolution for buckets
    'min': 256,  # minimum resolution for buckets
    'noupscale': False,  # make bucket for each image without upscaling
    'precision': 'fp32',  # one of fp32 / fp16 / bf16
    'resolution': '512,512',  # max resolution as "width,height"
    'steps': 64,  # steps of resolution for buckets
    'vae': 'stabilityai/sd-vae-ft-mse',  # model name or path used to encode latents
})

# Module-level VAE cache: filled on first use by create_vae_latents(), reset by unload_vae().
vae = None
def get_latents(local_vae, images, weight_dtype):
    """Encode a batch of images with ``local_vae`` and return the sampled latents.

    Args:
        local_vae: model exposing ``encode(...)`` whose result has a ``latent_dist`` to sample from.
        images: non-empty sequence of images accepted by ``transforms.ToTensor`` that also
            expose ``.shape`` (the caller in this file passes numpy arrays).
        weight_dtype: torch dtype the stacked batch is cast to before encoding.

    Returns:
        A ``(latents, size)`` tuple: ``latents`` is a float32 numpy array on the CPU and
        ``size`` is ``[shape[0], shape[1]]`` of the first image in the batch.
    """
    # ToTensor followed by normalisation with mean 0.5 and std 0.5
    to_normalized_tensor = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
    batch = torch.stack([to_normalized_tensor(img) for img in images]).to(device, weight_dtype)
    with torch.no_grad():
        posterior = local_vae.encode(batch).latent_dist
        encoded = posterior.sample().float().to('cpu').numpy()
    # NOTE(review): for numpy images this is presumably (height, width) - confirm that readers
    # of the saved `original_size` expect that order.
    first = images[0]
    return encoded, [first.shape[0], first.shape[1]]
def get_npz_filename_wo_ext(data_dir, image_key):
    """Return ``data_dir`` joined with the file name of ``image_key`` minus its last extension.

    Any directory part of ``image_key`` is discarded; no ``.npz`` suffix is appended here.
    """
    file_name = os.path.basename(image_key)
    stem, _ext = os.path.splitext(file_name)
    return os.path.join(data_dir, stem)
def _load_rgb_image(image_path):
    """Open ``image_path`` and return an in-memory 3-channel RGB copy of it."""
    # convert() loads the pixel data and returns a new image, so the file can be closed right
    # away; forcing RGB keeps RGBA / palette / grayscale inputs from reaching the VAE with a
    # channel count other than three.
    with Image.open(image_path) as handle:
        return handle.convert('RGB')


def _resize_and_center_crop(image, reso, resized_size):
    """Resize a numpy image to ``resized_size`` and center-crop it to the bucket size ``reso``.

    Both sizes are indexed as (width, height); the image array is indexed as [row, column].
    """
    if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]:
        image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)
    if resized_size[0] > reso[0]:
        left = (resized_size[0] - reso[0]) // 2
        image = image[:, left:left + reso[0]]
    if resized_size[1] > reso[1]:
        top = (resized_size[1] - reso[1]) // 2
        image = image[top:top + reso[1]]
    assert image.shape[0] == reso[1] and image.shape[1] == reso[0], f'internal error, illegal trimmed size: {image.shape}, {reso}'
    return image


def create_vae_latents(local_params):
    """Encode the training images found in ``input`` to VAE latents and update the metadata file.

    For every image returned by ``train_util.glob_images`` the image is assigned to a
    resolution bucket, resized and center-cropped to that bucket, encoded with the VAE, and
    the latents are written with ``np.savez`` into the input directory. The chosen bucket
    resolution (rounded down to a multiple of 8) is stored in the metadata under
    ``train_resolution`` and the metadata file is rewritten at the end.

    Args:
        local_params: mapping of settings that override the module-level ``options`` defaults.

    Returns:
        None. If the metadata file named by ``json`` does not exist, this is logged and
        nothing is processed.
    """
    global vae  # pylint: disable=global-statement
    args = Map({**options, **local_params})
    console.log(f'create vae latents args: {args}')
    image_paths = train_util.glob_images(args.input)
    if not os.path.exists(args.json):
        # previously a silent return; say why nothing is being processed
        console.log(f'create vae latents: metadata file not found: {args.json}')
        return
    with open(args.json, 'rt', encoding='utf-8') as f:
        metadata = json.load(f)

    # any precision other than fp16 / bf16 falls back to fp32
    weight_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16}.get(args.precision, torch.float32)

    # the VAE is cached at module level and reused by later calls until unload_vae()
    if vae is None:
        vae = model_util.load_vae(args.vae, weight_dtype)
    vae.eval()
    vae.to(device, dtype=weight_dtype)

    max_reso = tuple(int(t) for t in args.resolution.split(','))
    assert len(max_reso) == 2, f'illegal resolution: {args.resolution}'
    bucket_manager = train_util.BucketManager(args.noupscale, max_reso, args.min, args.max, args.steps)
    if not args.noupscale:
        bucket_manager.make_buckets()

    def process_batch(is_last):
        """Encode and save every bucket that is full, or every non-empty bucket when ``is_last``."""
        for bucket in bucket_manager.buckets:
            if not bucket:
                continue  # nothing to encode, whatever the batch size is
            if not is_last and len(bucket) < args.batch:
                continue  # keep collecting until the bucket holds a full batch
            latents, original_size = get_latents(vae, [img for _, img in bucket], weight_dtype)
            # latent height / width are expected to be 1/8 of the image height / width
            assert latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8, f'latent shape {latents.shape}, {bucket[0][1].shape}'
            for (image_key, _), latent in zip(bucket, latents):
                npz_file_name = get_npz_filename_wo_ext(args.input, image_key)
                np.savez(
                    npz_file_name,
                    latents=latent,
                    original_size=np.array(original_size),
                    crop_ltrb=np.array([0, 0]),
                )
            bucket.clear()

    img_ar_errors = []
    bucket_counts = {}
    for image_path in tqdm(image_paths, smoothing=0.0):
        image = _load_rgb_image(image_path)
        # metadata key is "<parent directory name>/<file name without extension>"
        path = pathlib.Path(image_path)
        image_key = os.path.join(os.path.basename(path.parent), path.stem)
        if image_key not in metadata:
            metadata[image_key] = {}
        reso, resized_size, ar_error = bucket_manager.select_bucket(image.width, image.height)
        img_ar_errors.append(abs(ar_error))
        bucket_counts[reso] = bucket_counts.get(reso, 0) + 1
        metadata[image_key]['train_resolution'] = (reso[0] - reso[0] % 8, reso[1] - reso[1] % 8)
        if not args.noupscale:
            assert resized_size[0] == reso[0] or resized_size[1] == reso[1], f'internal error, resized size not match: {reso}, {resized_size}, {image.width}, {image.height}'
        assert resized_size[0] >= reso[0] and resized_size[1] >= reso[1], f'internal error, resized size too small: {reso}, {resized_size}, {image.width}, {image.height}'
        image = _resize_and_center_crop(np.array(image), reso, resized_size)
        bucket_manager.add_image(reso, (image_key, image))
        process_batch(False)
    process_batch(True)  # flush the partially filled buckets

    vae.to('cpu')
    bucket_manager.sort()
    # the reported error is the mean over ALL images (not per bucket); compute it only once
    mean_ar_error = np.mean(np.array(img_ar_errors)) if img_ar_errors else None
    total_resos = len(bucket_manager.resos)
    for i, reso in enumerate(bucket_manager.resos):
        count = bucket_counts.get(reso, 0)
        if count > 0:
            console.log(f'vae latents bucket: {i+1}/{total_resos} resolution: {reso} images: {count} mean-ar-error: {mean_ar_error}')
    with open(args.json, 'wt', encoding='utf-8') as f:
        json.dump(metadata, f, indent=2)
def unload_vae():
    """Drop the module-level VAE reference so the next create_vae_latents() call reloads it."""
    global vae  # pylint: disable=global-statement
    vae = None
if __name__ == '__main__':
    # Command-line entry point: parse the arguments and hand them over as a plain dict.
    cli = argparse.ArgumentParser()
    cli.add_argument('input', type=str, help='directory for train images')
    cli.add_argument('--json', type=str, required=True, help='metadata file to input')
    cli.add_argument('--vae', type=str, required=True, help='model name or path to encode latents')
    cli.add_argument('--batch', type=int, default=1, help='batch size in inference')
    cli.add_argument(
        '--resolution',
        type=str,
        default='512,512',
        help='max resolution in fine tuning (width,height)',
    )
    cli.add_argument('--min', type=int, default=256, help='minimum resolution for buckets')
    cli.add_argument('--max', type=int, default=1024, help='maximum resolution for buckets')
    cli.add_argument('--steps', type=int, default=64, help='steps of resolution for buckets, divisible by 8')
    cli.add_argument('--noupscale', action='store_true', help='make bucket for each image without upscaling')
    cli.add_argument(
        '--precision',
        type=str,
        default='fp32',
        choices=['fp32', 'fp16', 'bf16'],
        help='use precision',
    )
    cli_args = cli.parse_args()
    create_vae_latents(vars(cli_args))