Use SGD and text encoder/tokenizer
- rct_diffusion_pipeline.py +70 -7
- test_pipeline.py +21 -3
- train_model.py +82 -46
rct_diffusion_pipeline.py
CHANGED
@@ -12,7 +12,7 @@ import pandas as pd
 from tqdm.auto import tqdm
 
 class RCTDiffusionPipeline(DiffusionPipeline):
-    def __init__(self, unet, scheduler, vae, latent_size=32, sample_size=256):
+    def __init__(self, unet, scheduler, vae, text_tokenizer, text_encoder, latent_size=32, sample_size=256):
         super().__init__()
 
         # dictionnary that keeps the different classes of object description, color1, color2 and color3
@@ -26,11 +26,13 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         self.vae = vae
         self.latent_size = latent_size
         self.sample_size = sample_size
+        self.text_encoder = text_encoder
+        self.text_tokenizer = text_tokenizer
 
         # channels for 1 image
         self.num_channels = int(self.unet.config.in_channels / 4)
         self.load_dictionaries_from_dataset()
-        self.register_modules(unet=unet, scheduler=scheduler, vae=vae)
+        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, text_tokenizer=text_tokenizer, text_encoder=text_encoder)
 
     def load_dictionaries_from_dataset(self):
         dataset = load_dataset('frutiemax/rct_dataset')
@@ -177,13 +179,72 @@ class RCTDiffusionPipeline(DiffusionPipeline):
 
         return torch.reshape(noise_batches, (batch_size, 1, self.num_channels*4, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
 
-    def …
-        …
+    def test_generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
+        batch_size = len(object_description)
+
+        embeddings = torch.Tensor(size=(batch_size, 77, 768))
+        for batch_index in range(batch_size):
+            prompt = f'{object_description[batch_index]},{color1[batch_index]},{color2[batch_index]}, {color3[batch_index]}'
+            tokens = self.text_tokenizer(prompt, \
+                padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            with torch.no_grad():
+                embeddings[batch_index] = self.text_encoder(tokens.input_ids.to('cuda'))[0]
+
+        return embeddings.to(dtype=torch.float16)
+
+    def generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
+        batch_size = len(object_description)
+
+        embeddings = torch.Tensor(size=(batch_size, 77, 768 * 4))
+        for batch_index in range(batch_size):
+            object_description_tokens = self.text_tokenizer(object_description[batch_index], \
+                padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            color1_tokens = self.text_tokenizer(color1[batch_index], \
+                padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            color2_tokens = self.text_tokenizer(color2[batch_index], \
+                padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            color3_tokens = self.text_tokenizer(color3[batch_index], \
+                padding="max_length", max_length=self.text_tokenizer.model_max_length, truncation=True, return_tensors="pt")
+            with torch.no_grad():
+                object_description_embeddings = self.text_encoder(object_description_tokens.input_ids.to('cuda'))[0]
+                color1_embeddings = self.text_encoder(color1_tokens.input_ids.to('cuda'))[0]
+                color2_embeddings = self.text_encoder(color2_tokens.input_ids.to('cuda'))[0]
+                color3_embeddings = self.text_encoder(color3_tokens.input_ids.to('cuda'))[0]
+
+            emb = torch.cat([object_description_embeddings, color1_embeddings, color2_embeddings, color3_embeddings], dim=2)
+            embeddings[batch_index] = emb
+
+        return embeddings.to(dtype=torch.float16)
+
+    def validate_inputs(self, object_description : list[str], color1 : list[str], \
+        color2 : list[str], color3 : list[str], batch_size) -> tuple[bool, list[str], list[str], list[str], list[str]]:
+        # check if the labels sizes are correct
+        if len(object_description) != batch_size:
+            return False
+
+        if len(color1) != batch_size:
+            return False
+
+        if color2 == None:
+            color2 = ['none'] * batch_size
+        elif len(color2) != batch_size:
+            return False
+
+        if color3 == None:
+            color3 = ['none'] * batch_size
+        elif len(color3) != batch_size:
+            return False
+        return True, object_description, color1, color2, color3
+
+    def __call__(self, object_description : list[str], color1 : list[str], \
+        color2 : list[str] = None, color3 : list[str] = None, \
         batch_size=1, num_inference_steps=20, generator=torch.manual_seed(torch.random.seed())):
 
-        …
-        if …
+        res, object_description, color1, color2, color3 = self.validate_inputs(object_description, color1, color2, color3, batch_size)
+        if res == False:
             return None
+        embeddings = self.test_generate_embeddings(object_description, color1, color2, color3)
+        embeddings = embeddings.to('cuda')
 
         # set the inference steps
         self.scheduler.set_timesteps(num_inference_steps)
@@ -196,8 +257,9 @@ class RCTDiffusionPipeline(DiffusionPipeline):
             progress_bar.set_description(f'Inference step {epoch}')
 
             for batch_index in range(batch_size):
+                noise_batches[batch_index] = self.scheduler.scale_model_input(noise_batches[batch_index], timestep=t)
                 with torch.no_grad():
-                    noise_residual = self.unet(noise_batches[batch_index], t, encoder_hidden_states=…
+                    noise_residual = self.unet(noise_batches[batch_index], t, encoder_hidden_states=embeddings).sample
                     previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batches[batch_index]).prev_sample
                     noise_batches[batch_index] = previous_noisy_sample
             progress_bar.update(1)
@@ -223,6 +285,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
             image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
             image = (image * 255).round().astype("uint8")
             image = Image.fromarray(image)
+            image.save(f'test{image_index}.png')
            output_images.append(image)
 
        # for now just return the images
test_pipeline.py
CHANGED
@@ -1,20 +1,38 @@
 from rct_diffusion_pipeline import RCTDiffusionPipeline
 from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
 import torch
+from transformers import CLIPTextModel, CLIPTokenizer
 
 torch_device = "cuda"
 
+# test of text tokenizers
+tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(
+    "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
+).to('cuda')
+
+test1 = tokenizer(['aleppo pine tree, common oak tree'], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+#test3 = tokenizer([1.0, 0.0, .05], is_split_into_words=True, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+
+with torch.no_grad():
+    test1 = text_encoder(test1.input_ids.to('cuda'))[0]
+
+test2 = tokenizer('dark green', padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
+
+with torch.no_grad():
+    test2 = text_encoder(test2.input_ids.to('cuda'))[0]
+
 unet = UNet2DConditionModel(sample_size=32, in_channels=16, out_channels=16, \
     down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
-    up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=…
+    up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=768*4,
     block_out_channels=(64, 128, 256), norm_num_groups=32)
 unet = unet.to('cuda', dtype=torch.float16)
 scheduler = DDPMScheduler(num_train_timesteps=20)
 vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
 vae = vae.to('cuda', dtype=torch.float16)
 
-pipeline = RCTDiffusionPipeline(unet, scheduler, vae)
-output = pipeline([…
+pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
+output = pipeline(['aleppo pine tree'], ['dark green'])
 pipeline.save_pretrained('test')
 
 # from PIL import Image
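
Note that this test builds the UNet with cross_attention_dim=768*4, the width of the concatenated generate_embeddings layout, while __call__ currently conditions on the 768-wide test_generate_embeddings output; the two values must agree for the cross-attention projections to accept the tensor. A quick shape check under that assumption, using only diffusers calls already shown in the diff (toy tensors, CPU, float32):

import torch
from diffusers import UNet2DConditionModel

# same architecture as the test above, kept on CPU/float32 for a fast check
unet = UNet2DConditionModel(sample_size=32, in_channels=16, out_channels=16,
    down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),
    up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'),
    cross_attention_dim=768 * 4, block_out_channels=(64, 128, 256), norm_num_groups=32)

sample = torch.randn(1, 16, 32, 32)    # (batch, in_channels, sample_size, sample_size)
cond = torch.randn(1, 77, 768 * 4)     # last dim must equal cross_attention_dim
out = unet(sample, timestep=0, encoder_hidden_states=cond).sample
assert out.shape == sample.shape       # the UNet preserves the latent shape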
train_model.py
CHANGED
@@ -11,6 +11,7 @@ from diffusers.optimization import get_cosine_schedule_with_warmup
 from tqdm.auto import tqdm
 from accelerate import Accelerator
 from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
+from transformers import CLIPTextModel, CLIPTokenizer
 
 SAMPLE_SIZE = 256
 LATENT_SIZE = 32
@@ -18,12 +19,12 @@ SAMPLE_NUM_CHANNELS = 3
 LATENT_NUM_CHANNELS = 4
 
 def save_and_test(pipeline, epoch):
-    outputs = pipeline([…
+    outputs = pipeline(['aleppo pine tree'], ['dark green'])
     for image_index in range(len(outputs)):
         file_name = f'out{image_index}_{epoch}.png'
         outputs[image_index].save(file_name)
 
-    model_file = f'rct_foliage_{epoch}…
+    model_file = f'rct_foliage_{epoch}'
     pipeline.save_pretrained(model_file)
 
 def convert_images(dataset):
@@ -42,18 +43,18 @@ def convert_images(dataset):
         for entry in views[view_index]:
             image = entry['image']
 
-            scale_factor = …
-            image = Image.resize(image, size=(scale_factor * image.width, scale_factor * image.height), resample=Resampling.NEAREST)
+            scale_factor = np.minimum(LATENT_SIZE / image.width, LATENT_SIZE / image.height)
+            image = Image.resize(image, size=(int(scale_factor * image.width), int(scale_factor * image.height)), resample=Resampling.NEAREST)
 
-            new_image = PIL.Image.new('…
-            new_image.paste(image, box=(int((…
+            new_image = PIL.Image.new('RGBA', (LATENT_SIZE, LATENT_SIZE))
+            new_image.paste(image, box=(int((LATENT_SIZE - image.width)/2), int((LATENT_SIZE - image.height)/2)))
             images.append(new_image)
         image_views.append(images)
 
     del views
 
     # convert those views in tensors
-    targets = torch.Tensor(size=(num_images, 4, …
+    targets = torch.Tensor(size=(num_images, 4, LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE)).to(dtype=torch.float16)
     pillow_to_tensor = T.ToTensor()
 
     for image_index in range(num_images):
@@ -62,7 +63,7 @@ def convert_images(dataset):
     del image_views
     del entries
 
-    return torch.reshape(targets, (num_images, 4 * …
+    return torch.reshape(targets, (num_images, 4 * LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE))
 
 def convert_labels(dataset, model, num_images):
     # get the labels
@@ -96,80 +97,115 @@ def convert_labels(dataset, model, num_images):
     del dataset
     return class_labels.to(dtype=torch.float16, device='cuda')
 
-def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=…
+def train_model(batch_size=4, total_images=None, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=1):
     dataset = load_dataset('frutiemax/rct_dataset')
     dataset = dataset['train']
 
     targets = convert_images(dataset)
-    num_images = int(dataset.num_rows / 4)
+    num_images = int(dataset.num_rows / 4) if total_images == None else int(total_images / 4)
 
-    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS…
-        down_block_types=(…
-        up_block_types=(…
-        block_out_channels=(…
+    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS*4, out_channels=LATENT_NUM_CHANNELS*4, \
+        down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
+        up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
+        block_out_channels=(320, 640, 1280, 1280), norm_num_groups=32)
     unet = unet.to(dtype=torch.float16)
-    scheduler = DDPMScheduler(num_train_timesteps=…
+    scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)
+    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
+    text_encoder = CLIPTextModel.from_pretrained(
+        "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
+    ).to('cuda')
     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
     vae = vae.to(dtype=torch.float16)
 
-    optimizer = torch.optim.…
+    optimizer = torch.optim.SGD(unet.parameters(), lr=start_learning_rate)
     lr_scheduler = get_cosine_schedule_with_warmup(
         optimizer=optimizer,
         num_warmup_steps=lr_warmup_steps,
         num_training_steps=num_images * epochs
     )
-    model = RCTDiffusionPipeline(unet, scheduler, vae)
+    model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
+
+    # get all the object descriptions, color1, color2, color3
+    object_descriptions = dataset['object_description']
+    colors1 = dataset['color1']
+    colors2 = dataset['color2']
+    colors3 = dataset['color3']
+
+    # we only need 1 of the 4 views
+    object_descriptions = [object_descriptions[desc_index] for desc_index in range(0, len(object_descriptions), 4)]
+    colors1 = [colors1[desc_index] for desc_index in range(0, len(colors1), 4)]
+    colors2 = [colors2[desc_index] for desc_index in range(0, len(colors2), 4)]
+    colors3 = [colors3[desc_index] for desc_index in range(0, len(colors3), 4)]
+    #embeddings = model.generate_embeddings(object_descriptions, colors1, colors2, colors3)
+    embeddings = model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
+
     labels = convert_labels(dataset, model, num_images)
     del model
 
+    if total_images != None:
+        targets = targets[:int(total_images/4)]
+        label_indices = [index for index in range(0, total_images, 4)]
+        labels = labels[label_indices]
+
     # lets train for 100 epoch for each sprite in the dataset with a random noise level
     progress_bar = tqdm(total=epochs)
     accelerator = Accelerator(mixed_precision='fp16')
+    accelerator.clip_grad_norm_(unet.parameters(), 1.0)
     unet, scheduler, lr_scheduler, vae = accelerator.prepare(unet, scheduler, lr_scheduler, vae)
 
+    loss_fn = torch.nn.MSELoss()
+
+    tensor_to_pillow = T.ToPILImage()
     for epoch in range(epochs):
         # create a noisy version of each sprite
         for batch_index in range(0, num_images, batch_size):
-            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
             batch_end = np.minimum(num_images, batch_index + batch_size)
             clean_images = targets[batch_index:batch_end]
-            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), …
+            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), LATENT_NUM_CHANNELS * 4, LATENT_SIZE, LATENT_SIZE)).\
+                to(device='cuda', dtype=torch.float16)
 
             noise = torch.randn(clean_images.shape, dtype=torch.float16, device='cuda')
             timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, )).to(device='cuda')
+
             #timesteps = timesteps.to(dtype=torch.int, device='cuda')
             noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
-            …
+
+            # with accelerator.accumulate(unet):
+            #     assert not torch.any(torch.isnan(timesteps))
+
+            #     batch_embeddings = embeddings[batch_index:batch_end]
+            #     batch_embeddings = batch_embeddings.to('cuda')
+
+            #     optimizer.zero_grad()
+            #     unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
+            #     unet_results = unet_results.to(dtype=torch.float16)
+
+            #     loss = loss_fn(unet_results, noise)
+            #     accelerator.backward(loss)
+
+            #     optimizer.step()
+            #     lr_scheduler.step()
+            #     optimizer.zero_grad()
+
+            batch_embeddings = embeddings[batch_index:batch_end]
+            batch_embeddings = batch_embeddings.to('cuda')
+
+            optimizer.zero_grad()
+            unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
+            unet_results = unet_results.to(dtype=torch.float16)
+            loss = loss_fn(unet_results, noise)
+            loss.backward()
+            optimizer.step()
+            lr_scheduler.step()
+            optimizer.zero_grad()
+
+            progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}, last_loss={loss.item()}')
 
         if (epoch + 1) % save_model_interval == 0:
-            model = RCTDiffusionPipeline(accelerator.unwrap_model(unet), scheduler, vae)
+            model = RCTDiffusionPipeline(accelerator.unwrap_model(unet), scheduler, vae, tokenizer, text_encoder)
             save_and_test(model, epoch)
         progress_bar.update(1)
 
 
 if __name__ == '__main__':
-    train_model(1, save_model_interval=1)
+    train_model(1, total_images=4, save_model_interval=1)
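
Stripped of the dataset plumbing, the new inner step is plain SGD on a noise-prediction MSE: noise the clean latents at random timesteps, ask the UNet for the noise back, and step both the optimizer and the cosine warmup schedule. A condensed sketch with toy tensors in place of the dataset (it reuses the smaller UNet config from test_pipeline.py to stay light; the names follow the diff):

import torch
from diffusers import DDPMScheduler, UNet2DConditionModel
from diffusers.optimization import get_cosine_schedule_with_warmup

unet = UNet2DConditionModel(sample_size=32, in_channels=16, out_channels=16,
    down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),
    up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'),
    cross_attention_dim=768, block_out_channels=(64, 128, 256), norm_num_groups=32)
scheduler = DDPMScheduler(num_train_timesteps=20)
optimizer = torch.optim.SGD(unet.parameters(), lr=1e-3)
lr_scheduler = get_cosine_schedule_with_warmup(optimizer=optimizer,
    num_warmup_steps=1, num_training_steps=100)
loss_fn = torch.nn.MSELoss()

clean_images = torch.randn(1, 16, 32, 32)    # one 4-view latent stack (4 * 4 channels)
embeddings = torch.randn(1, 77, 768)         # stand-in for the CLIP embeddings

# one training step: noise the latents, predict the noise, regress with MSE
noise = torch.randn_like(clean_images)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (1,))
noisy_images = scheduler.add_noise(clean_images, noise, timesteps)

optimizer.zero_grad()
noise_pred = unet(noisy_images, timesteps, encoder_hidden_states=embeddings).sample
loss = loss_fn(noise_pred, noise)
loss.backward()
optimizer.step()
lr_scheduler.step()

With the fused embedding path, cross_attention_dim=768 in train_model matches the (batch, 77, 768) tensors that test_generate_embeddings returns.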