| from diffusers.schedulers import UniPCMultistepScheduler |
| from diffusers import AutoencoderKL |
| from diffusion_module.unet import UNetModel |
| import torch |
| from diffusion_module.utils.LSDMPipeline_expandDataset import SDMLDMPipeline |
| from accelerate import Accelerator |
| from evolution import random_walk |
| import cv2 |
| import numpy as np |
|
|
def mask2onehot(data, num_classes):
    """Convert an integer label map to a one-hot encoding.

    Args:
        data: Label-map tensor of shape (bs, 1, h, w) whose values are class
            indices in [0, num_classes). Any dtype castable to int64.
        num_classes: Number of classes C for the one-hot channel axis.

    Returns:
        Float32 tensor of shape (bs, num_classes, h, w) on ``data.device``
        with 1.0 at each pixel's class channel and 0.0 elsewhere.
    """
    # scatter_ requires int64 indices.
    label_map = data.to(dtype=torch.int64)
    bs, _, h, w = label_map.size()
    # torch.zeros(..., device=...) replaces the legacy
    # FloatTensor(...).zero_().to(device) idiom and allocates directly
    # on the target device instead of CPU-then-copy.
    one_hot = torch.zeros(bs, num_classes, h, w,
                          dtype=torch.float32, device=data.device)
    return one_hot.scatter_(1, label_map, 1.0)
|
|
def generate(img, pretrain_weight, seed=None):
    """Generate images from a binary skeleton/mask with a latent SDM pipeline.

    Args:
        img: Grayscale image (numpy array, HxW); nonzero pixels form the
            skeleton/mask the diffusion model is conditioned on.
        pretrain_weight: Path or hub identifier of the pretrained UNet weights.
        seed: Optional int for reproducible sampling; None leaves the RNG
            unseeded.

    Returns:
        The list of images produced by the pipeline (``.images``).
    """
    noise_scheduler = UniPCMultistepScheduler()
    # Reuse the Stable Diffusion 1.5 VAE to map between pixel and latent space.
    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
    latent_size = (64, 64)
    unet = UNetModel(
        image_size=latent_size,
        in_channels=vae.config.latent_channels,
        model_channels=256,
        out_channels=vae.config.latent_channels,
        num_res_blocks=2,
        attention_resolutions=(2, 4, 8),
        dropout=0,
        channel_mult=(1, 2, 3, 4),
        num_heads=8,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=True,
        resblock_updown=True,
        use_new_attention_order=False,
        num_classes=151,
        mask_emb="resize",
        use_checkpoint=True,
        SPADE_type="spade",
    )
    # NOTE(review): from_pretrained is invoked on the instance but returns a
    # fresh model; the constructor above only fixes the architecture config.
    unet = unet.from_pretrained(pretrain_weight)

    device = 'cpu'
    # fp16 only makes sense off-CPU; keep full precision otherwise.
    if device != 'cpu':
        mixed_precision = "fp16"
    else:
        mixed_precision = "no"

    accelerator = Accelerator(
        mixed_precision=mixed_precision,
        # BUG FIX: was `device is 'cpu'` — identity comparison against a
        # string literal is implementation-dependent (and a SyntaxWarning
        # on CPython 3.8+); use equality.
        cpu=(device == 'cpu'),
    )

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16

    unet, vae = accelerator.prepare(unet, vae)
    vae.to(device=accelerator.device, dtype=weight_dtype)
    pipeline = SDMLDMPipeline(
        vae=accelerator.unwrap_model(vae),
        unet=accelerator.unwrap_model(unet),
        scheduler=noise_scheduler,
        torch_dtype=weight_dtype,
        resolution_type="crack",
    )
    # Optional memory-efficient attention when running on an accelerator:
    # if accelerator.device != 'cpu':
    #     pipeline.enable_xformers_memory_efficient_attention()
    pipeline = pipeline.to(accelerator.device)
    pipeline.set_progress_bar_config(disable=False)

    # Seeded generator for reproducible sampling.
    if seed is None:
        generator = None
    else:
        generator = torch.Generator(device=accelerator.device).manual_seed(seed)

    # Downsample to the latent resolution, then binarize: any nonzero -> 255.
    resized_s = cv2.resize(img, (64, 64), interpolation=cv2.INTER_AREA)
    _, binary_s = cv2.threshold(resized_s, 1, 255, cv2.THRESH_BINARY)
    tensor_s = torch.from_numpy(binary_s / 255)
    tensor_s = tensor_s.unsqueeze(0).unsqueeze(0)  # -> (1, 1, 64, 64)

    onehot_skeletons = []
    # 151 classes to match the UNet's num_classes above.
    onehot_s = mask2onehot(tensor_s, 151)
    onehot_skeletons.append(onehot_s)

    onehot_skeletons = torch.stack(onehot_skeletons, dim=1).squeeze(0)
    onehot_skeletons = onehot_skeletons.to(dtype=weight_dtype, device=accelerator.device)

    images = pipeline(
        onehot_skeletons,
        generator=generator,
        batch_size=1,
        num_inference_steps=20,
        s=1.5,
        num_evolution_per_mask=1,
    ).images

    return images