| import torch |
| from PIL.Image import Image |
| from diffusers import StableDiffusionXLPipeline |
|
|
| from pipelines.models import TextToImageRequest |
| from diffusers import DDIMScheduler |
| from torch import Generator |
| from loss import SchedulerWrapper, get_instance |
| import time |
| from onediffx import compile_pipe, save_pipe, load_pipe |
| from torch.cuda.amp import autocast, GradScaler |
| import copy |
# Module-level training state, populated lazily by load_pipeline().
instance = None   # compact decoder being fine-tuned (from loss.get_instance("cuda"))
loss_fn = None    # reconstruction loss between decoders (torch.nn.MSELoss)
optimizer = None  # Adam over the compact decoder's last trainable layers
|
|
|
|
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
    """Step-end callback that switches off classifier-free guidance late in sampling.

    When the sampler reaches 78% of its steps, the batched (uncond, cond)
    tensors are trimmed to their conditional halves and the guidance scale is
    dropped to 0.1, so the remaining steps run effectively unguided (and with
    half the batch, cheaper).
    """
    cutoff = int(pipe.num_timesteps * 0.78)
    if step_index != cutoff:
        return callback_kwargs

    for key in ('prompt_embeds', 'add_text_embeds', 'add_time_ids'):
        # chunk(2)[-1]: keep only the conditional half of the CFG batch.
        callback_kwargs[key] = callback_kwargs[key].chunk(2)[-1]
    pipe._guidance_scale = 0.1
    return callback_kwargs
|
|
def load_pipeline(pipeline=None) -> StableDiffusionXLPipeline:
    """Build and compile the SDXL pipeline, then run a one-shot fine-tuning of
    the compact decoder (``instance``) against the pipeline's full VAE decoder.

    Steps:
      1. load (or reuse) the SDXL pipeline and wrap its scheduler,
      2. compile it with onediff and restore the cached compiled graph,
      3. hook ``vae.decode`` so a few generations record their final latents,
      4. train the last layers of ``instance`` to match the real VAE decode,
      5. save the trained weights and abort via ``sys.exit(1)``.

    NOTE(review): because of the deliberate sys.exit(1) near the end, this
    function never returns during a training run; the trailing
    ``return pipeline`` only takes effect once that exit is removed.
    """
    global instance, loss_fn, optimizer

    if not pipeline:
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            "stablediffusionapi/newdream-sdxl-20",
            torch_dtype=torch.float16,
        ).to("cuda")

    pipeline.scheduler = SchedulerWrapper(DDIMScheduler.from_config(pipeline.scheduler.config))

    # Compile with onediff and restore the cached compiled graph from disk.
    pipeline = compile_pipe(pipeline)
    load_pipe(pipeline, dir="/home/sandbox/.cache/huggingface/hub/models--RobertML--cached-pipe-02/snapshots/58d70deae87034cce351b780b48841f9746d4ad7")

    # Warm-up generation so the compiled graph and scheduler buffers exist.
    pipeline(
        prompt="telestereography, unstrengthen, preadministrator, copatroness, hyperpersonal, paramountness, paranoid, guaniferous",
        output_type="pil",
        num_inference_steps=20,
    )
    pipeline.scheduler.prepare_loss()

    instance = get_instance("cuda")

    def get_pred_original_sample(sched, model_output, timestep, sample):
        # DDIM x0-prediction: invert the forward-noising process at `timestep`.
        alpha_prod_t = sched.scheduler.alphas_cumprod[timestep]
        return (sample - (1 - alpha_prod_t) ** 0.5 * model_output) / alpha_prod_t ** 0.5

    preview_images_new = []

    def add_taesd_previewing(pipe, compress):
        # NOTE(review): currently unused instrumentation, kept for debugging.
        # Wraps scheduler.step so each denoising step also decodes a preview
        # image of the x0-prediction through `compress`.
        sched = pipe.scheduler
        if not hasattr(sched, "_step"):
            sched._step = sched.step

            @torch.no_grad()
            def step_and_preview(*args, **kwargs):
                latents = get_pred_original_sample(sched, *args).float()
                output = compress(latents).mul_(2).sub_(1)
                preview_images_new.append(pipe.image_processor.postprocess(output)[0])
                return sched._step(*args, **kwargs)

            sched.step = step_and_preview

    loss_fn = torch.nn.MSELoss()

    # Only the last few layers of the compact decoder are trained.
    trainable_layers = list(instance.children())[-4:]
    trainable_params = [p for layer in trainable_layers for p in layer.parameters()]
    for param in trainable_params:
        param.requires_grad = True

    optimizer = torch.optim.Adam(trainable_params, lr=0.001)
    print(len(trainable_params))
    instance.train()
    pipeline.vae.eval()

    scaler = GradScaler()
    checkpoint_path = "../chkpt"
    import collections
    dataset = collections.defaultdict(list)

    def add_latent_capture(pipe, dataset):
        # Wrap vae.decode so every latent handed to it is recorded for
        # training, while still returning the real (fp16 SDXL VAE) decode.
        if not hasattr(pipe.vae, "_decode_without_taesd_preview"):
            pipe.vae._decode_without_taesd_preview = pipe.vae.decode

            def decode_and_record(latents, *args, **kwargs):
                dataset[0].append(latents)
                res_sd = pipe.vae._decode_without_taesd_preview(latents, *args, **kwargs)[0]
                return (res_sd,)

            pipe.vae.decode = decode_and_record

    add_latent_capture(pipeline, dataset)

    # Data collection: each generation records its final latents via the hook.
    for _ in range(2):
        pipeline(
            prompt="telestereography, unstrengthen, preadministrator, copatroness, hyperpersonal, paramountness, paranoid, guaniferous",
            output_type="pt",
            num_inference_steps=20,
            guidance_scale=5.0,
        ).images[0]
    print("dataset collected")
    print(len(dataset))

    epochs = 2
    for epoch in range(epochs):
        total_loss = 0.0  # was read before assignment in the original
        for _, latents_list in dataset.items():
            for latents in latents_list:
                # Teacher target: real VAE decode, frozen (no grad).
                with torch.no_grad():
                    res_sd = pipeline.vae._decode_without_taesd_preview(latents)[0].detach()

                # Student: compact decoder in fp32; rescale output to [-1, 1].
                latents_f = latents.detach().float()
                res_taesd = instance(
                    latents_f.mul(float(pipeline.vae.config.scaling_factor))
                ).mul(2.).sub(1.)

                # Cast the fp16 teacher up so MSELoss sees matching dtypes.
                loss = loss_fn(res_taesd, res_sd.float())
                print(loss)

                optimizer.zero_grad()
                scaler.scale(loss).backward()
                # was a bare optimizer.step(), which skips GradScaler's
                # gradient unscaling and inf-check.
                scaler.step(optimizer)
                scaler.update()
                total_loss += loss.item()

        # Per-epoch checkpoint (original referenced undefined `compress` and a
        # non-existent optimizer.state_dict()['step'] key).
        torch.save(instance.state_dict(), f"{checkpoint_path}/compress_{epoch}.pth")
        print(f"epoch {epoch} loss {total_loss}")

    torch.save(instance.state_dict(), "trained_decoder.pth")
    # Deliberate abort: this invocation exists only to produce trained weights.
    import sys
    sys.exit(1)
    return pipeline
|
|
def infer(request: TextToImageRequest, pipeline: StableDiffusionXLPipeline) -> Image:
    """Run one text-to-image generation for `request` and return the PIL image."""
    global instance, loss_fn, optimizer

    # Seeded generator for reproducibility; None lets the pipeline draw entropy.
    generator = (
        None
        if request.seed is None
        else Generator(pipeline.device).manual_seed(request.seed)
    )

    result = pipeline(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        width=request.width,
        height=request.height,
        generator=generator,
        num_inference_steps=13,
        cache_interval=1,
        cache_layer_id=1,
        cache_block_id=0,
        eta=1.0,
        guidance_scale=5.0,
        guidance_rescale=0.0,
        callback_on_step_end=callback_dynamic_cfg,
        callback_on_step_end_tensor_inputs=['prompt_embeds', 'add_text_embeds', 'add_time_ids'],
    )
    return result.images[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|