| import torch |
| import os |
| from PIL import Image |
| import numpy as np |
| from diffusers.schedulers import DDIMScheduler, UniPCMultistepScheduler |
| from diffusion_module.utils.Pipline import SDMLDMPipeline |
|
|
def log_validation(vae, unet, noise_scheduler, accelerator, weight_dtype, data_ld,
                   resolution=512, g_step=2, save_dir="cityspace_test"):
    """Run a short qualitative validation pass and save the generated images.

    Builds an SDMLDMPipeline from the (unwrapped) vae/unet, swaps the training
    noise scheduler for UniPC, generates images for the first three batches of
    ``data_ld``, and writes them out via ``merge_images``.

    Args:
        vae: trained VAE (possibly wrapped by the accelerator).
        unet: trained UNet (possibly wrapped by the accelerator).
        noise_scheduler: training scheduler; only its config is reused.
        accelerator: HF Accelerate object providing device/unwrapping/log dir.
        weight_dtype: torch dtype used for inference (e.g. torch.float16).
        data_ld: validation dataloader; each batch is indexed as
            ``batch[1]['label']`` for the segmentation map — TODO confirm
            against the dataset definition.
        resolution: pipeline output resolution.
        g_step: global training step, used only for the output directory name.
        save_dir: NOTE(review): currently unused — output goes to
            ``accelerator.logging_dir``; keep or wire up deliberately.
    """
    scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)
    pipeline = SDMLDMPipeline(
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        scheduler=scheduler,
        torch_dtype=weight_dtype,
        resolution=resolution,
        resolution_type="crack",
    )

    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=False)
    pipeline.enable_xformers_memory_efficient_attention()

    generator = None
    for i, batch in enumerate(data_ld):
        if i > 2:  # only validate on the first three batches
            break
        images = []
        with torch.autocast("cuda"):
            segmap = preprocess_input(batch[1]['label'], num_classes=151)
            # FIX: use the accelerator's device and the configured weight
            # dtype instead of hard-coded "cuda"/float16, so validation also
            # works on non-CUDA devices and honors the precision setting.
            segmap = segmap.to(accelerator.device).to(weight_dtype)

            image = pipeline(segmap=segmap[0][None, :], generator=generator, batch_size=1,
                             num_inference_steps=50, s=1.5).images

        images.extend(image)
        merge_images(images, i, accelerator, g_step)

    # FIX: release the pipeline once, after the loop — deleting it inside the
    # loop body would break every iteration after the first.
    del pipeline
    torch.cuda.empty_cache()
| |
|
|
def merge_images(images, val_step, accelerator, step):
    """Save each image individually, then a horizontally merged strip.

    Singles go to ``<logging_dir>/step_<step>/singles/<val_step>_<k>.png``,
    the merged strip to ``<logging_dir>/step_<step>/merges/<val_step>_merge.png``.

    Args:
        images: list of PIL images (assumed; they expose width/height/save —
            TODO confirm).
        val_step: validation batch index, used in file names.
        accelerator: provides ``logging_dir`` as the output root.
        step: global training step, used in the directory name.
    """
    # FIX: guard the empty case — the original crashed in max() below.
    if not images:
        return

    step_dir = os.path.join(accelerator.logging_dir, "step_{}".format(step))

    # FIX: the singles directory is constant per call; create it once instead
    # of calling os.makedirs on every loop iteration.
    singles_dir = os.path.join(step_dir, "singles")
    os.makedirs(singles_dir, exist_ok=True)
    for k, image in enumerate(images):
        filename = "{}_{}.png".format(val_step, k)
        image.save(os.path.join(singles_dir, filename))

    # Paste all images side by side into one RGB strip.
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)
    combined_image = Image.new('RGB', (total_width, max_height))

    x_offset = 0
    for img in images:
        if img.mode != 'RGB':
            img = img.convert('RGB')
        combined_image.paste(img, (x_offset, 0))
        x_offset += img.width

    merge_filename = "{}_merge.png".format(val_step)
    merges_dir = os.path.join(step_dir, "merges")
    os.makedirs(merges_dir, exist_ok=True)
    combined_image.save(os.path.join(merges_dir, merge_filename))
| |
def preprocess_input(data, num_classes):
    """One-hot encode a segmentation label map.

    Args:
        data: integer label-map tensor of shape (bs, 1, h, w); values must
            lie in ``[0, num_classes)`` (out-of-range values make scatter_
            fail).
        num_classes: size of the one-hot class axis.

    Returns:
        Float32 tensor of shape (bs, num_classes, h, w) with 1.0 at each
        pixel's class channel and 0.0 elsewhere, on ``data``'s device.
    """
    # scatter_ requires int64 indices.
    label_map = data.to(dtype=torch.int64)
    bs, _, h, w = label_map.size()
    # FIX: allocate directly on the source device — the legacy
    # torch.FloatTensor(...).zero_().to(...) built the tensor on CPU first
    # and then copied it over.
    input_label = torch.zeros(bs, num_classes, h, w,
                              dtype=torch.float32, device=data.device)
    return input_label.scatter_(1, label_map, 1.0)
|
|
|
|