import torch
from PIL.Image import Image
from diffusers import StableDiffusionXLPipeline
from pipelines.models import TextToImageRequest
from diffusers import DDIMScheduler
from torch import Generator
from loss import SchedulerWrapper, get_instance
import time
from onediffx import compile_pipe, save_pipe, load_pipe
from torch.cuda.amp import autocast, GradScaler
import copy
import collections

# Module-level handles populated by load_pipeline() and shared with infer().
instance = None   # lightweight (TAESD-style) latent decoder from loss.get_instance
loss_fn = None    # MSE loss used when fine-tuning the lightweight decoder
optimizer = None  # Adam over the decoder's trailing layers


def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    """Per-step diffusers callback that disables CFG late in sampling.

    At ~78% of the denoising schedule the negative-prompt half of each
    conditioning tensor is dropped (``chunk(2)[-1]`` keeps only the
    conditional chunk) and the guidance scale is lowered to 0.1, roughly
    halving UNet work for the remaining steps.

    Returns the (possibly modified) ``callback_kwargs`` as diffusers requires.
    """
    if step_index == int(pipe.num_timesteps * 0.78):
        callback_kwargs['prompt_embeds'] = callback_kwargs['prompt_embeds'].chunk(2)[-1]
        callback_kwargs['add_text_embeds'] = callback_kwargs['add_text_embeds'].chunk(2)[-1]
        callback_kwargs['add_time_ids'] = callback_kwargs['add_time_ids'].chunk(2)[-1]
        pipe._guidance_scale = 0.1
    return callback_kwargs


def load_pipeline(pipeline=None) -> StableDiffusionXLPipeline:
    """Build (or pass through) the SDXL pipeline; fresh builds run a one-shot
    fine-tune of the lightweight latent decoder and then exit.

    When ``pipeline`` is None this:
      1. loads newdream-sdxl-20 in fp16 on CUDA, wraps its scheduler in
         ``SchedulerWrapper(DDIMScheduler)``, compiles it with onediffx and
         restores a cached compiled pipe from disk;
      2. warms the compiled graph with one generation;
      3. hooks ``vae.decode`` so final latents are collected into ``dataset``,
         runs two generations to collect them, then fine-tunes the last
         layers of ``instance`` to match the full VAE decode, saving
         per-epoch checkpoints and the final weights.

    NOTE(review): the fresh-build path is a development/training script and
    deliberately ends with ``sys.exit(1)`` — it never returns. Only the
    pass-through path (pipeline supplied by the caller) returns a pipeline.
    """
    global instance, loss_fn, optimizer
    if not pipeline:
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "stablediffusionapi/newdream-sdxl-20",
            torch_dtype=torch.float16,
        ).to("cuda")
        pipeline.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipeline.scheduler.config))
        pipeline = compile_pipe(pipeline)
        load_pipe(
            pipeline,
            dir="/home/sandbox/.cache/huggingface/hub/models--RobertML--cached-pipe-02/snapshots/58d70deae87034cce351b780b48841f9746d4ad7",
        )

        # Warm the compiled graph once so later runs hit steady-state timings.
        pipeline(
            prompt="telestereography, unstrengthen, preadministrator, copatroness, hyperpersonal, paramountness, paranoid, guaniferous",
            output_type="pil",
            num_inference_steps=20,
        )
        pipeline.scheduler.prepare_loss()

        instance = get_instance("cuda")

        def get_pred_original_sample(sched, model_output, timestep, sample):
            """DDIM x0 estimate: (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)."""
            alpha_prod_t = sched.scheduler.alphas_cumprod[timestep]
            return (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t ** 0.5

        preview_images_new, preview_images_original = [], []

        def add_taesd_previewing(pipe, compress):
            """Wrap scheduler.step so every step also decodes a preview frame
            with the lightweight decoder ``compress``. Currently not installed
            on the pipeline (kept for debugging)."""
            sched = pipe.scheduler
            if not hasattr(sched, "_step"):
                sched._step = sched.step  # stash the unwrapped step exactly once

            @torch.no_grad()
            def step_and_preview(*args, **kwargs):
                latents = get_pred_original_sample(sched, *args).float()
                # Decoder output is in [0, 1]; postprocess expects [-1, 1].
                output = compress(latents).mul_(2).sub_(1)
                preview_images_new.append(pipe.image_processor.postprocess(output)[0])
                return sched._step(*args, **kwargs)

            sched.step = step_and_preview

        loss_fn = torch.nn.MSELoss()
        # Fine-tune only the tail of the lightweight decoder; everything else
        # stays frozen.
        trainable_layers = list(instance.children())[-4:]
        trainable_params = [p for layer in trainable_layers for p in layer.parameters()]
        for p in trainable_params:
            p.requires_grad = True
        optimizer = torch.optim.Adam(trainable_params, lr=0.001)
        print(len(trainable_params))

        instance.train()
        pipeline.vae.eval()
        scaler = GradScaler()
        checkpoint_path = "../chkpt"
        dataset = collections.defaultdict(list)

        def add_collecting_decode_hook(pipe):
            """Hook ``vae.decode`` so every latent batch that reaches the VAE
            is stashed into ``dataset`` before the real decode runs."""
            if not hasattr(pipe.vae, "_decode_without_taesd_preview"):
                pipe.vae._decode_without_taesd_preview = pipe.vae.decode

            def decode_and_collect(latents, *args, **kwargs):
                dataset[0].append(latents)
                res_sd = pipe.vae._decode_without_taesd_preview(latents, *args, **kwargs)[0]
                return (res_sd,)

            pipe.vae.decode = decode_and_collect

        add_collecting_decode_hook(pipeline)

        # Collect training latents from two full generations.
        for _ in range(2):
            pipeline(
                prompt="telestereography, unstrengthen, preadministrator, copatroness, hyperpersonal, paramountness, paranoid, guaniferous",
                output_type="pt",
                num_inference_steps=20,
                guidance_scale=5.0,
            )
        print("dataset collected")

        epochs = 2
        print(len(dataset))
        for epoch in range(epochs):
            total_loss = 0.0
            for stored_latents in dataset.values():
                for latents in stored_latents:
                    # Target: full (frozen) VAE decode of the same latents.
                    # .float() so MSE compares matching dtypes (VAE runs fp16).
                    with torch.no_grad():
                        target = pipeline.vae._decode_without_taesd_preview(latents)[0].detach().float()
                    pred = instance(
                        latents.detach().float().mul(float(pipeline.vae.config.scaling_factor))
                    ).mul(2.).sub(1.)
                    loss = loss_fn(pred, target)
                    print(loss)
                    optimizer.zero_grad()
                    scaler.scale(loss).backward()
                    # scaler.step (not optimizer.step) so GradScaler can skip
                    # the update on inf/NaN gradients.
                    scaler.step(optimizer)
                    scaler.update()
                    total_loss += loss.item()
            torch.save(instance.state_dict(), f"{checkpoint_path}/compress_epoch{epoch}.pth")
            print(f"epoch {epoch} loss {total_loss}")

        torch.save(instance.state_dict(), "trained_decoder.pth")
        # Deliberate hard stop: the fresh-build path is a one-shot training
        # script, not the serving path.
        import sys
        sys.exit(1)
    return pipeline


def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Generate one image for ``request`` with the tuned fast-inference settings.

    Runs 13 steps with the caching args the pipeline exposes (interval 1,
    layer 1, block 0 — presumably DeepCache; confirm against the pipeline
    implementation) and the dynamic-CFG callback that drops the negative
    prompt late in sampling. Deterministic when ``request.seed`` is set.
    """
    if request.seed is None:
        generator = None
    else:
        generator = Generator(pipeline.device).manual_seed(request.seed)
    return pipeline(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        width=request.width,
        height=request.height,
        generator=generator,
        num_inference_steps=13,
        cache_interval=1,
        cache_layer_id=1,
        cache_block_id=0,
        eta=1.0,
        guidance_scale=5.0,
        guidance_rescale=0.0,
        callback_on_step_end=callback_dynamic_cfg,
        callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
    ).images[0]