import numpy as np
import torch
import torch.nn as nn
import os
from diffusers import StableDiffusionInstructPix2PixPipeline
from PIL import Image
import random
# Single seed shared by every random number generator this module touches.
SEED = 42

# Seed PyTorch (CPU and all CUDA devices), NumPy and Python's `random` module.
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Request deterministic cuDNN behaviour and disable cuDNN benchmarking.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Seeded generator passed to the diffusion pipelines below; it is created on
# CUDA when a CUDA device is available and on the CPU otherwise.
generator = torch.Generator("cuda" if torch.cuda.is_available() else "cpu").manual_seed(SEED)
class FabricDiffusionPipeline:
    """Wrap two InstructPix2Pix pipelines: one for textures, one for prints.

    Either checkpoint may be falsy (``None`` or ``""``), in which case the
    corresponding model attribute is ``None`` and the matching ``flatten_*``
    method must not be called.

    Both ``flatten_*`` methods pass the module-level ``generator`` to the
    underlying pipeline.
    """

    # Sampling settings shared by `flatten_texture` and `flatten_print`.
    _NUM_INFERENCE_STEPS = 20
    _IMAGE_GUIDANCE_SCALE = 1.5
    _GUIDANCE_SCALE = 7.
    # Side length (pixels) every texture patch is resized to.
    _PATCH_SIZE = (256, 256)

    def __init__(self, device, texture_checkpoint, print_checkpoint):
        """Load the requested pipelines and move them to ``device``.

        Args:
            device: Device the pipelines are moved to (passed to ``.to``).
            texture_checkpoint: Checkpoint for the texture model, or a falsy
                value to skip loading it.
            print_checkpoint: Checkpoint for the print model, or a falsy
                value to skip loading it.
        """
        self.device = device
        self.texture_checkpoint = texture_checkpoint
        self.print_base_model = print_checkpoint

        self.texture_model = self._load_model(texture_checkpoint, device)
        # Set circular convolution for the texture model only (UNet and VAE).
        if self.texture_model is not None:
            self._set_circular_padding(self.texture_model.unet)
            self._set_circular_padding(self.texture_model.vae)

        self.print_model = self._load_model(print_checkpoint, device)

    @staticmethod
    def _load_model(checkpoint, device):
        """Return the fp16 pipeline for ``checkpoint`` on ``device``, or None if falsy."""
        if not checkpoint:
            return None
        model = StableDiffusionInstructPix2PixPipeline.from_pretrained(
            checkpoint,
            torch_dtype=torch.float16,
            safety_checker=None,
        )
        return model.to(device)

    @staticmethod
    def _set_circular_padding(module):
        """Set ``padding_mode='circular'`` on every ``nn.Conv2d`` inside ``module``."""
        for layer in module.modules():
            if isinstance(layer, nn.Conv2d):
                layer.padding_mode = 'circular'

    @staticmethod
    def _mask_bounding_box(mask):
        """Return ``(x1, y1, x2, y2)`` bounding the positive entries of ``mask``.

        Axis 0 of ``mask`` is treated as rows (y) and axis 1 as columns (x).
        ``x2`` and ``y2`` are exclusive, so ``arr[y1:y2, x1:x2]`` contains
        every positive pixel, including the last masked row and column.

        Raises:
            ValueError: If ``mask`` has no positive entry.
        """
        coords = np.where(np.asarray(mask) > 0)
        rows, cols = coords[0], coords[1]
        if rows.size == 0:
            raise ValueError("texture mask is empty; cannot locate the texture region")
        return int(cols.min()), int(rows.min()), int(cols.max()) + 1, int(rows.max()) + 1

    @staticmethod
    def _rgb_to_rgba(rgb):
        """Turn a float RGB array in [0, 1] into a uint8 RGBA array.

        Alpha is the channel-mean of a ramp that is 0 for channel values up
        to 1/60 and 1 from 0.1 upwards. The colour channels are rescaled so
        that 0.1 maps to 0 and 1 maps to 1, then clipped to [0, 1].
        """
        alpha_map = np.clip(rgb / 0.1 * 1.2 - 0.2, 0., 1).mean(axis=-1, keepdims=True)
        colour = np.clip((rgb - 0.1) / 0.9, 0., 1.)
        rgba = np.concatenate([colour, alpha_map], axis=-1)
        return (rgba * 255).astype(np.uint8)

    def load_real_data_with_mask(self, dataset_path, image_name):
        """Load an image, its segmentation mask and a cropped texture patch.

        Reads ``image_name`` from the ``images``, ``seg_mask`` and
        ``texture_mask`` sub-directories of ``dataset_path``.

        Returns:
            A tuple ``(image, seg_mask, texture_patch)``: the RGB image as an
            array, the grayscale segmentation mask with a trailing channel
            axis, and the image region bounded by the texture mask as a PIL
            image resized to 256x256.

        Raises:
            ValueError: If the texture mask has no positive pixel.
        """
        def read(subdir, mode):
            path = os.path.join(dataset_path, subdir, image_name)
            with Image.open(path) as img:
                return np.array(img.convert(mode))

        image = read('images', 'RGB')
        seg_mask = read('seg_mask', 'L')[..., None]
        texture_mask = read('texture_mask', 'L')

        # Crop the image to the bounding box of the texture mask.
        x1, y1, x2, y2 = self._mask_bounding_box(texture_mask)
        texture_patch = image[y1:y2, x1:x2]
        # Resize the texture patch to 256x256.
        texture_patch = Image.fromarray(texture_patch.astype(np.uint8)).resize(self._PATCH_SIZE)
        return image, seg_mask, texture_patch

    def load_patch_data(self, patch_path):
        """Open ``patch_path`` and return it as an RGB PIL image resized to 256x256."""
        with Image.open(patch_path) as img:
            return img.convert('RGB').resize(self._PATCH_SIZE)

    def flatten_texture(self, texture_patch, n_samples=3, use_inversion=True):
        """Run the texture model on ``texture_patch`` and return ``n_samples`` images.

        Args:
            texture_patch: Input patch, passed to the pipeline's image processor.
            n_samples: Number of images to generate in one batched call.
            use_inversion: If True, the initial latents passed to the pipeline
                are the standardised image latents with noise added at the
                first scheduler timestep. If False, ``latents=None`` is passed
                and the pipeline picks its own initial latents.

        Returns:
            The ``.images`` attribute of the pipeline output.
        """
        pipe = self.texture_model
        num_inference_steps = self._NUM_INFERENCE_STEPS
        pipe.scheduler.set_timesteps(num_inference_steps)
        timesteps = pipe.scheduler.timesteps

        # Run the patch through the pipeline's image processor.
        image = pipe.image_processor.preprocess(texture_patch)

        if use_inversion:
            image_latents = pipe.prepare_image_latents(
                image,
                batch_size=1,
                num_images_per_prompt=1,
                device=self.device,
                dtype=torch.float16,
                do_classifier_free_guidance=False,
            )
            # Standardise the latents to zero mean and unit standard deviation.
            image_latents = (image_latents - torch.mean(image_latents)) / torch.std(image_latents)
            # Forward noising process. This noise comes from torch's global
            # RNG, not from the module-level `generator`.
            noise = torch.randn_like(image_latents)
            noisy_image_latents = pipe.scheduler.add_noise(image_latents, noise, timesteps[0:1])
            noisy_image_latents /= pipe.scheduler.init_noise_sigma
            # Every sample starts from the same noisy latents.
            noisy_image_latents = torch.tile(noisy_image_latents, (n_samples, 1, 1, 1))
        else:
            noisy_image_latents = None

        image = torch.tile(image, (n_samples, 1, 1, 1))
        gen_imgs = pipe(
            "",
            image=image,
            num_inference_steps=num_inference_steps,
            image_guidance_scale=self._IMAGE_GUIDANCE_SCALE,
            guidance_scale=self._GUIDANCE_SCALE,
            latents=noisy_image_latents,
            num_images_per_prompt=n_samples,
            generator=generator,
        ).images
        return gen_imgs

    def flatten_print(self, print_patch, n_samples=3):
        """Run the print model ``n_samples`` times and return RGBA PIL images.

        Each sample comes from a separate pipeline call. Its first output
        image is scaled to [0, 1] and converted to RGBA by ``_rgb_to_rgba``.

        Args:
            print_patch: Input patch, passed to the pipeline's image processor.
            n_samples: Number of images to generate.

        Returns:
            A list of ``n_samples`` PIL images built from 4-channel uint8 arrays.
        """
        image = self.print_model.image_processor.preprocess(print_patch)
        gen_imgs = []
        for _ in range(n_samples):
            gen_img = self.print_model(
                "",
                image=image,
                num_inference_steps=self._NUM_INFERENCE_STEPS,
                image_guidance_scale=self._IMAGE_GUIDANCE_SCALE,
                guidance_scale=self._GUIDANCE_SCALE,
                generator=generator,
            ).images[0]
            rgba = self._rgb_to_rgba(np.asarray(gen_img) / 255.)
            gen_imgs.append(Image.fromarray(rgba))
        return gen_imgs