import os
import random

import numpy as np
import torch
import torch.nn as nn
from diffusers import StableDiffusionInstructPix2PixPipeline
from PIL import Image

# Global seeding so numpy, python `random`, and torch (CPU + all GPUs) are
# reproducible across runs.
SEED = 42
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)
# Deterministic cuDNN kernels (disables autotuning) for reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Shared RNG handed to the diffusion pipelines so sampling is reproducible.
generator = torch.Generator(
    "cuda" if torch.cuda.is_available() else "cpu"
).manual_seed(SEED)


class FabricDiffusionPipeline:
    """Two InstructPix2Pix pipelines for fabric capture.

    * ``texture_model`` — flattens a photographed fabric patch into a tileable
      texture. Its UNet/VAE convolutions are switched to circular padding so
      generated textures wrap seamlessly.
    * ``print_model`` — extracts a print/logo and produces an RGBA image whose
      alpha channel is derived from pixel brightness.

    Either checkpoint may be falsy, in which case the corresponding pipeline
    is ``None``.
    """

    def __init__(self, device, texture_checkpoint, print_checkpoint):
        """Load the requested checkpoints onto ``device`` (fp16, no safety checker)."""
        self.device = device
        self.texture_checkpoint = texture_checkpoint
        self.print_base_model = print_checkpoint

        if texture_checkpoint:
            self.texture_model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                texture_checkpoint,
                torch_dtype=torch.float16,
                safety_checker=None,
            ).to(device)
            # Circular padding makes every Conv2d wrap around the image border,
            # so the sampled texture tiles seamlessly.
            self._set_circular_padding(self.texture_model.unet)
            self._set_circular_padding(self.texture_model.vae)
        else:
            self.texture_model = None

        if print_checkpoint:
            self.print_model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
                print_checkpoint,
                torch_dtype=torch.float16,
                safety_checker=None,
            ).to(device)
        else:
            self.print_model = None

    @staticmethod
    def _set_circular_padding(module):
        """Switch every Conv2d inside ``module`` to circular padding."""
        for _, layer in module.named_modules():
            if isinstance(layer, nn.Conv2d):
                layer.padding_mode = 'circular'

    def load_real_data_with_mask(self, dataset_path, image_name):
        """Load an image plus its segmentation mask and a 256x256 texture crop.

        Expects ``images/``, ``seg_mask/`` and ``texture_mask/`` subfolders of
        ``dataset_path``, each containing a file named ``image_name``.

        Returns:
            (image HxWx3 uint8 array, seg_mask HxWx1 array, 256x256 PIL image)
        """
        image = np.array(
            Image.open(os.path.join(dataset_path, 'images', image_name)).convert('RGB'))
        seg_mask = np.array(
            Image.open(os.path.join(dataset_path, 'seg_mask', image_name)).convert('L'))[..., None]
        texture_mask = np.array(
            Image.open(os.path.join(dataset_path, 'texture_mask', image_name)).convert('L'))[..., None]

        # Bounding box of the texture mask, computed from a single np.where call
        # (the original recomputed it four times). The max indices are inclusive,
        # so the slice needs +1 to keep the last masked row/column.
        ys, xs, _ = np.where(texture_mask > 0)
        x1, y1, x2, y2 = xs.min(), ys.min(), xs.max(), ys.max()
        texture_patch = image[y1:y2 + 1, x1:x2 + 1]

        # Resize the crop to the pipelines' working resolution.
        texture_patch = Image.fromarray(texture_patch.astype(np.uint8)).resize((256, 256))
        return image, seg_mask, texture_patch

    def load_patch_data(self, patch_path):
        """Load a standalone texture patch and resize it to 256x256 RGB."""
        return Image.open(patch_path).convert('RGB').resize((256, 256))

    def flatten_texture(self, texture_patch, n_samples=3, use_inversion=True):
        """Generate ``n_samples`` flattened (tileable) textures from a patch.

        When ``use_inversion`` is True, sampling starts from a partially
        noised encoding of the input patch instead of pure noise, which keeps
        the result closer to the original texture.

        Returns a list of PIL images.
        """
        num_inference_steps = 20
        self.texture_model.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.texture_model.scheduler.timesteps

        # Convert the PIL patch to the pipeline's normalized tensor format.
        image = self.texture_model.image_processor.preprocess(texture_patch)

        if use_inversion:
            image_latents = self.texture_model.prepare_image_latents(
                image,
                batch_size=1,
                num_images_per_prompt=1,
                device=self.device,
                dtype=torch.float16,
                do_classifier_free_guidance=False,
            )
            # Standardize the latents before noising.
            image_latents = (image_latents - torch.mean(image_latents)) / torch.std(image_latents)
            # Forward noising to the first (largest) timestep, then rescale to
            # the scheduler's expected initial sigma.
            noise = torch.randn_like(image_latents)
            noisy_image_latents = self.texture_model.scheduler.add_noise(
                image_latents, noise, timesteps[0:1])
            noisy_image_latents /= self.texture_model.scheduler.init_noise_sigma
            noisy_image_latents = torch.tile(noisy_image_latents, (n_samples, 1, 1, 1))
        else:
            # None makes the pipeline draw fresh random latents itself.
            noisy_image_latents = None

        image = torch.tile(image, (n_samples, 1, 1, 1))
        gen_imgs = self.texture_model(
            "",
            image=image,
            num_inference_steps=num_inference_steps,
            image_guidance_scale=1.5,
            guidance_scale=7.,
            latents=noisy_image_latents,
            num_images_per_prompt=n_samples,
            generator=generator,
        ).images
        return gen_imgs

    def flatten_print(self, print_patch, n_samples=3):
        """Generate ``n_samples`` RGBA print extractions from a patch.

        The alpha channel is derived from pixel brightness: near-black output
        pixels become transparent, bright pixels opaque. RGB is rescaled from
        [0.1, 1.0] back to the full [0, 1] range.

        Returns a list of RGBA PIL images.
        """
        image = self.print_model.image_processor.preprocess(print_patch)
        gen_imgs = []
        for _ in range(n_samples):
            gen_img = self.print_model(
                "",
                image=image,
                num_inference_steps=20,
                image_guidance_scale=1.5,
                guidance_scale=7.,
                generator=generator,
            ).images[0]
            gen_img = np.asarray(gen_img) / 255.
            # Brightness-based alpha: values ramp from 0 to 1 over a narrow
            # range near black, averaged over RGB.
            alpha_map = np.clip(gen_img / 0.1 * 1.2 - 0.2, 0., 1).mean(axis=-1, keepdims=True)
            # Stretch [0.1, 1.0] back to [0, 1] for the color channels.
            gen_img = np.clip((gen_img - 0.1) / 0.9, 0., 1.)
            gen_img = np.concatenate([gen_img, alpha_map], axis=-1)
            gen_imgs.append(Image.fromarray((gen_img * 255).astype(np.uint8)))
        return gen_imgs