Use VaeImageProcessor
Browse files- rct_diffusion_pipeline.py +23 -13
- test_pipeline.py +2 -2
- train_model.py +7 -14
rct_diffusion_pipeline.py
CHANGED
|
@@ -10,9 +10,10 @@ from datasets import load_dataset
|
|
| 10 |
import numpy as np
|
| 11 |
import pandas as pd
|
| 12 |
from tqdm.auto import tqdm
|
|
|
|
| 13 |
|
| 14 |
class RCTDiffusionPipeline(DiffusionPipeline):
|
| 15 |
-
def __init__(self, unet, scheduler, vae, text_tokenizer, text_encoder, latent_size=32, sample_size=256):
|
| 16 |
super().__init__()
|
| 17 |
|
| 18 |
# dictionary that keeps the different classes of object description, color1, color2 and color3
|
|
@@ -29,6 +30,9 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
| 29 |
self.text_encoder = text_encoder
|
| 30 |
self.text_tokenizer = text_tokenizer
|
| 31 |
|
|
|
|
|
|
|
|
|
|
| 32 |
# channels for 1 image
|
| 33 |
self.num_channels = int(self.unet.config.in_channels)
|
| 34 |
self.load_dictionaries_from_dataset()
|
|
@@ -172,8 +176,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
| 172 |
|
| 173 |
def generate_noise_batches(self, batch_size):
|
| 174 |
noise_batches = torch.Tensor(size=(batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
|
| 175 |
-
seed =
|
| 176 |
-
np.random.seed(seed)
|
| 177 |
torch.manual_seed(seed)
|
| 178 |
torch.cuda.manual_seed(seed)
|
| 179 |
for batch_index in range(batch_size):
|
|
@@ -260,6 +263,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
| 260 |
# now call the model for the n iterations
|
| 261 |
progress_bar = tqdm(total=num_inference_steps)
|
| 262 |
epoch = 0
|
|
|
|
| 263 |
for t in self.scheduler.timesteps:
|
| 264 |
progress_bar.set_description(f'Inference step {epoch}')
|
| 265 |
|
|
@@ -269,8 +273,14 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
| 269 |
noise_residual = self.unet(noise_batch, t, encoder_hidden_states=embeddings).sample
|
| 270 |
previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batch).prev_sample
|
| 271 |
noise_batches[batch_index] = previous_noisy_sample
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
progress_bar.update(1)
|
| 273 |
epoch = epoch + 1
|
|
|
|
| 274 |
|
| 275 |
# reshape the data so we get back 4 RGB images
|
| 276 |
noise_batches = torch.reshape(noise_batches, (batch_size, self.num_channels, self.latent_size, self.latent_size))
|
|
@@ -280,22 +290,22 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
| 280 |
|
| 281 |
with torch.no_grad():
|
| 282 |
image = noise_batches
|
| 283 |
-
result = self.vae.decode(image).sample
|
| 284 |
-
|
| 285 |
-
images =
|
| 286 |
|
| 287 |
# convert those tensors to PIL images
|
| 288 |
tensor_to_pil = T.ToPILImage()
|
| 289 |
output_images = []
|
| 290 |
for batch_index in range(batch_size):
|
| 291 |
image = images[batch_index]
|
| 292 |
-
image = (image / 2 + 0.5).clamp(0, 1)
|
| 293 |
-
#image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
|
| 294 |
-
#image = (image * 255).round().astype("uint8")
|
| 295 |
-
#image = Image.fromarray(image)
|
| 296 |
-
image = tensor_to_pil(image)
|
| 297 |
-
image.save(f'test{batch_index}.png')
|
| 298 |
output_images.append(image)
|
| 299 |
|
| 300 |
# for now just return the images
|
| 301 |
-
return output_images
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
import numpy as np
|
| 11 |
import pandas as pd
|
| 12 |
from tqdm.auto import tqdm
|
| 13 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 14 |
|
| 15 |
class RCTDiffusionPipeline(DiffusionPipeline):
|
| 16 |
+
def __init__(self, unet, scheduler, vae, text_tokenizer, text_encoder, vae_image_processor : VaeImageProcessor, latent_size=32, sample_size=256):
|
| 17 |
super().__init__()
|
| 18 |
|
| 19 |
# dictionary that keeps the different classes of object description, color1, color2 and color3
|
|
|
|
| 30 |
self.text_encoder = text_encoder
|
| 31 |
self.text_tokenizer = text_tokenizer
|
| 32 |
|
| 33 |
+
# use vae image processor
|
| 34 |
+
self.vae_image_processor = vae_image_processor
|
| 35 |
+
|
| 36 |
# channels for 1 image
|
| 37 |
self.num_channels = int(self.unet.config.in_channels)
|
| 38 |
self.load_dictionaries_from_dataset()
|
|
|
|
| 176 |
|
| 177 |
def generate_noise_batches(self, batch_size):
|
| 178 |
noise_batches = torch.Tensor(size=(batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
|
| 179 |
+
seed = torch.seed()
|
|
|
|
| 180 |
torch.manual_seed(seed)
|
| 181 |
torch.cuda.manual_seed(seed)
|
| 182 |
for batch_index in range(batch_size):
|
|
|
|
| 263 |
# now call the model for the n iterations
|
| 264 |
progress_bar = tqdm(total=num_inference_steps)
|
| 265 |
epoch = 0
|
| 266 |
+
test_image = None
|
| 267 |
for t in self.scheduler.timesteps:
|
| 268 |
progress_bar.set_description(f'Inference step {epoch}')
|
| 269 |
|
|
|
|
| 273 |
noise_residual = self.unet(noise_batch, t, encoder_hidden_states=embeddings).sample
|
| 274 |
previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batch).prev_sample
|
| 275 |
noise_batches[batch_index] = previous_noisy_sample
|
| 276 |
+
|
| 277 |
+
# test
|
| 278 |
+
test_image = self.decode_latent(noise_batches[batch_index], self.vae.config.scaling_factor)
|
| 279 |
+
|
| 280 |
+
|
| 281 |
progress_bar.update(1)
|
| 282 |
epoch = epoch + 1
|
| 283 |
+
test_image.show()
|
| 284 |
|
| 285 |
# reshape the data so we get back 4 RGB images
|
| 286 |
noise_batches = torch.reshape(noise_batches, (batch_size, self.num_channels, self.latent_size, self.latent_size))
|
|
|
|
| 290 |
|
| 291 |
with torch.no_grad():
|
| 292 |
image = noise_batches
|
| 293 |
+
result = self.vae.decode(image / self.vae.config.scaling_factor).sample
|
| 294 |
+
image = self.vae_image_processor.denormalize(result)
|
| 295 |
+
images = image
|
| 296 |
|
| 297 |
# convert those tensors to PIL images
|
| 298 |
tensor_to_pil = T.ToPILImage()
|
| 299 |
output_images = []
|
| 300 |
for batch_index in range(batch_size):
|
| 301 |
image = images[batch_index]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 302 |
output_images.append(image)
|
| 303 |
|
| 304 |
# for now just return the images
|
| 305 |
+
return [tensor_to_pil(image) for image in output_images]
|
| 306 |
+
|
| 307 |
+
def decode_latent(self, image, vae_scaling_factor) -> "PIL.Image.Image":
    """Turn a tensor with values in [-1, 1] into a PIL image for preview.

    Despite its name, this method does not run the VAE decoder: it only
    rescales ``image`` from [-1, 1] to [0, 1], clamps it to that range and
    hands it to ``T.ToPILImage()``.

    Args:
        image: Tensor to convert. NOTE(review): the visible caller passes a
            single latent slice (``noise_batches[batch_index]``); whether
            ``ToPILImage`` accepts its channel count and dtype should be
            confirmed.
        vae_scaling_factor: Currently unused. Kept in the signature so that
            existing callers keep working.

    Returns:
        The PIL image built by ``T.ToPILImage()``. The previous annotation
        claimed ``torch.Tensor``, which did not match the returned object
        (the caller invokes ``.show()`` on it).
    """
    # Map [-1, 1] -> [0, 1]; clamp guards against values outside that range.
    denormalized = (image / 2 + 0.5).clamp(0, 1)
    return T.ToPILImage()(denormalized)
|
test_pipeline.py
CHANGED
|
@@ -39,8 +39,8 @@ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors
|
|
| 39 |
vae = vae.to('cuda', dtype=torch.float16)
|
| 40 |
|
| 41 |
#pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
|
| 42 |
-
pipeline = RCTDiffusionPipeline.from_pretrained('
|
| 43 |
-
output = pipeline(['pagoda
|
| 44 |
output[0].save('out.png')
|
| 45 |
pipeline.save_pretrained('test')
|
| 46 |
print('test')
|
|
|
|
| 39 |
vae = vae.to('cuda', dtype=torch.float16)
|
| 40 |
|
| 41 |
#pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
|
| 42 |
+
pipeline = RCTDiffusionPipeline.from_pretrained('rct_foliage_249')
|
| 43 |
+
output = pipeline(['(cabbage) pagoda tree'], ['(dark) green'], ['brown'])
|
| 44 |
output[0].save('out.png')
|
| 45 |
pipeline.save_pretrained('test')
|
| 46 |
print('test')
|
train_model.py
CHANGED
|
@@ -14,6 +14,7 @@ from accelerate import Accelerator
|
|
| 14 |
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
|
| 15 |
from transformers import CLIPTextModel, CLIPTokenizer
|
| 16 |
import torch.nn as nn
|
|
|
|
| 17 |
|
| 18 |
SAMPLE_SIZE = 256
|
| 19 |
LATENT_SIZE = 32
|
|
@@ -31,24 +32,13 @@ def save_and_test(pipeline, epoch):
|
|
| 31 |
pipeline.save_pretrained(model_file)
|
| 32 |
|
| 33 |
def transform_images(image):
|
| 34 |
-
res = torch.Tensor((SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE))
|
| 35 |
pil_to_tensor = T.PILToTensor()
|
| 36 |
-
tensor_to_pil = T.ToPILImage()
|
| 37 |
-
|
| 38 |
-
res_index = 0
|
| 39 |
scale_factor = np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
|
| 40 |
image = Image.resize(image, size=(int(scale_factor * image.width), int(scale_factor * image.height)), resample=Resampling.NEAREST)
|
| 41 |
|
| 42 |
new_image = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
|
| 43 |
new_image.paste(image, box=(int((SAMPLE_SIZE - image.width)/2), int((SAMPLE_SIZE - image.height)/2)))
|
| 44 |
-
|
| 45 |
-
#data = np.array(new_image, dtype=np.float32)
|
| 46 |
-
#data = (data / 128.0 - 1.0)
|
| 47 |
-
#res = torch.from_numpy(data)
|
| 48 |
-
res = pil_to_tensor(new_image)
|
| 49 |
-
res.to(dtype=torch.float32)
|
| 50 |
-
res = res / torch.Tensor([128.0]) - torch.Tensor([1.0])
|
| 51 |
-
return res
|
| 52 |
|
| 53 |
def convert_images(dataset):
|
| 54 |
images = [transform_images(image) for image in dataset["image"]]
|
|
@@ -101,6 +91,8 @@ def create_embeddings(dataset, model):
|
|
| 101 |
|
| 102 |
|
| 103 |
def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=100, save_model_interval=10, start_learning_rate=1e-4, lr_warmup_steps=500):
|
|
|
|
|
|
|
| 104 |
dataset = load_dataset('frutiemax/rct_dataset', split=f'train[0:{total_images}]')
|
| 105 |
dataset.set_transform(convert_images)
|
| 106 |
num_images = dataset.num_rows
|
|
@@ -133,7 +125,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
|
|
| 133 |
num_warmup_steps=lr_warmup_steps,
|
| 134 |
num_training_steps=num_images * epochs
|
| 135 |
)
|
| 136 |
-
model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
|
| 137 |
unet = unet.to('cuda')
|
| 138 |
|
| 139 |
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
@@ -154,6 +146,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
|
|
| 154 |
to(device='cuda')
|
| 155 |
|
| 156 |
# use the vae to get the latent images
|
|
|
|
| 157 |
latent_images = vae.encode(clean_images).latent_dist.sample()
|
| 158 |
latent_images = latent_images * vae.config.scaling_factor
|
| 159 |
|
|
@@ -192,4 +185,4 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timeste
|
|
| 192 |
|
| 193 |
|
| 194 |
if __name__ == '__main__':
|
| 195 |
-
train_model(batch_size=
|
|
|
|
| 14 |
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
|
| 15 |
from transformers import CLIPTextModel, CLIPTokenizer
|
| 16 |
import torch.nn as nn
|
| 17 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 18 |
|
| 19 |
SAMPLE_SIZE = 256
|
| 20 |
LATENT_SIZE = 32
|
|
|
|
| 32 |
pipeline.save_pretrained(model_file)
|
| 33 |
|
| 34 |
def transform_images(image):
    """Letterbox ``image`` into a SAMPLE_SIZE x SAMPLE_SIZE RGB canvas.

    The image is resized with nearest-neighbour resampling by the largest
    uniform factor at which both sides still fit inside the square, then
    pasted centred onto a freshly created RGB canvas.

    Args:
        image: Image exposing ``width`` and ``height`` that ``Image.resize``
            accepts (presumably a PIL image -- confirm against the dataset).

    Returns:
        The canvas converted with ``T.PILToTensor()``; no scaling of the
        pixel values is applied here.
    """
    # Largest uniform scale that keeps both dimensions within SAMPLE_SIZE.
    fit_ratio = np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
    target_size = (int(fit_ratio * image.width), int(fit_ratio * image.height))
    resized = Image.resize(image, size=target_size, resample=Resampling.NEAREST)

    # Centre the resized image on a square canvas.
    canvas = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
    left = int((SAMPLE_SIZE - resized.width) / 2)
    top = int((SAMPLE_SIZE - resized.height) / 2)
    canvas.paste(resized, box=(left, top))

    return T.PILToTensor()(canvas)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
def convert_images(dataset):
|
| 44 |
images = [transform_images(image) for image in dataset["image"]]
|
|
|
|
| 91 |
|
| 92 |
|
| 93 |
def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=100, save_model_interval=10, start_learning_rate=1e-4, lr_warmup_steps=500):
|
| 94 |
+
vae_image_processor = VaeImageProcessor()
|
| 95 |
+
|
| 96 |
dataset = load_dataset('frutiemax/rct_dataset', split=f'train[0:{total_images}]')
|
| 97 |
dataset.set_transform(convert_images)
|
| 98 |
num_images = dataset.num_rows
|
|
|
|
| 125 |
num_warmup_steps=lr_warmup_steps,
|
| 126 |
num_training_steps=num_images * epochs
|
| 127 |
)
|
| 128 |
+
model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder, vae_image_processor)
|
| 129 |
unet = unet.to('cuda')
|
| 130 |
|
| 131 |
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
|
|
| 146 |
to(device='cuda')
|
| 147 |
|
| 148 |
# use the vae to get the latent images
|
| 149 |
+
clean_images = vae_image_processor.preprocess(clean_images)
|
| 150 |
latent_images = vae.encode(clean_images).latent_dist.sample()
|
| 151 |
latent_images = latent_images * vae.config.scaling_factor
|
| 152 |
|
|
|
|
| 185 |
|
| 186 |
|
| 187 |
if __name__ == '__main__':
|
| 188 |
+
train_model(batch_size=1, total_images=4, save_model_interval=25, epochs=500, start_learning_rate=1e-5)
|