| |
| import argparse |
| from contextlib import nullcontext |
|
|
| import torch |
| from safetensors.torch import load_file |
| from transformers import ( |
| AutoTokenizer, |
| CLIPConfig, |
| CLIPImageProcessor, |
| CLIPTextModelWithProjection, |
| CLIPVisionModelWithProjection, |
| ) |
|
|
| from diffusers import ( |
| DDPMWuerstchenScheduler, |
| StableCascadeCombinedPipeline, |
| StableCascadeDecoderPipeline, |
| StableCascadePriorPipeline, |
| ) |
| from diffusers.loaders.single_file_utils import convert_stable_cascade_unet_single_file_to_diffusers |
| from diffusers.models import StableCascadeUNet |
| from diffusers.models.modeling_utils import load_model_dict_into_meta |
| from diffusers.pipelines.wuerstchen import PaellaVQModel |
| from diffusers.utils import is_accelerate_available |
|
|
|
|
| if is_accelerate_available(): |
| from accelerate import init_empty_weights |
|
|
# Command-line interface for the conversion script.
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument(
    "--model_path",
    type=str,
    # Required: every checkpoint path is built from this directory, so running without it
    # previously produced a bogus "None/<checkpoint>" path and failed later.
    required=True,
    help="Directory containing the Stable Cascade checkpoint files",
)
parser.add_argument(
    "--stage_c_name", type=str, default="stage_c_lite.safetensors", help="Name of stage c checkpoint file"
)
parser.add_argument(
    "--stage_b_name", type=str, default="stage_b_lite.safetensors", help="Name of stage b checkpoint file"
)
parser.add_argument("--skip_stage_c", action="store_true", help="Skip converting stage c")
parser.add_argument("--skip_stage_b", action="store_true", help="Skip converting stage b")
parser.add_argument(
    "--use_safetensors",
    action="store_true",
    help="Load the checkpoints with safetensors instead of torch.load",
)
parser.add_argument(
    "--prior_output_path",
    default="stable-cascade-prior-lite",
    type=str,
    help="Directory to save the converted prior (stage c) pipeline to",
)
parser.add_argument(
    "--decoder_output_path",
    type=str,
    default="stable-cascade-decoder-lite",
    help="Directory to save the converted decoder (stage b) pipeline to",
)
parser.add_argument(
    "--combined_output_path",
    type=str,
    default="stable-cascade-combined-lite",
    help="Directory to save the combined prior + decoder pipeline to",
)
parser.add_argument(
    "--save_combined",
    action="store_true",
    help="Also save a combined prior + decoder pipeline (requires converting both stages)",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")

args = parser.parse_args()
|
|
# Reject argument combinations that would make the run pointless or inconsistent.
# Skipping both stages leaves nothing to convert.
if all((args.skip_stage_b, args.skip_stage_c)):
    raise ValueError("At least one stage should be converted")
# The combined pipeline is built from both converted stages, so neither may be skipped.
if args.save_combined and any((args.skip_stage_b, args.skip_stage_c)):
    raise ValueError("Cannot skip stages when creating a combined pipeline")
|
|
model_path = args.model_path

# Checkpoints are loaded onto the CPU.
device = "cpu"

# Map the --variant label to the dtype the pipelines are cast to before saving.
# Previously only "bf16" changed the dtype, so `--variant fp16` saved float32 weights
# under an fp16 variant label. Any other (or absent) variant keeps float32 as before.
VARIANT_TO_DTYPE = {"bf16": torch.bfloat16, "fp16": torch.float16}
dtype = VARIANT_TO_DTYPE.get(args.variant, torch.float32)

# Full paths to the stage C (prior) and stage B (decoder) checkpoints.
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"
|
|
| |
# Components shared by the prior, decoder and combined pipelines.

# CLIP bigG text tower + tokenizer. The text config is given the top-level
# projection_dim so the text model is built with a matching projection.
CLIP_TEXT_REPO = "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k"
clip_config = CLIPConfig.from_pretrained(CLIP_TEXT_REPO)
clip_text_config = clip_config.text_config
clip_text_config.projection_dim = clip_config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(CLIP_TEXT_REPO, config=clip_text_config)
tokenizer = AutoTokenizer.from_pretrained(CLIP_TEXT_REPO)

# Image preprocessing and image encoder used by the prior pipeline.
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")

# One scheduler instance is shared by every pipeline.
scheduler = DDPMWuerstchenScheduler()

# Context used while constructing the UNets: accelerate's init_empty_weights when
# available, otherwise a no-op context.
if is_accelerate_available():
    ctx = init_empty_weights
else:
    ctx = nullcontext
|
|
if not args.skip_stage_c:
    # --- Stage C (prior) ---
    # Load the original checkpoint. A ".safetensors" file is always read with safetensors:
    # the default --stage_c_name is such a file while --use_safetensors defaults to off, and
    # torch.load cannot read the safetensors format.
    if args.use_safetensors or prior_checkpoint_path.endswith(".safetensors"):
        prior_orig_state_dict = load_file(prior_checkpoint_path, device=device)
    else:
        # NOTE(security): torch.load unpickles the file, which can execute arbitrary code.
        # Only convert checkpoints from a trusted source.
        prior_orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)

    # Convert the original state dict to the diffusers StableCascadeUNet format.
    prior_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(prior_orig_state_dict)

    # Build the prior UNet. Under accelerate's init_empty_weights no real parameter storage
    # is allocated; the weights are filled in from the state dict below.
    with ctx():
        prior_model = StableCascadeUNet(
            in_channels=16,
            out_channels=16,
            timestep_ratio_embedding_dim=64,
            patch_size=1,
            conditioning_dim=1536,
            block_out_channels=[1536, 1536],
            num_attention_heads=[24, 24],
            down_num_layers_per_block=[4, 12],
            up_num_layers_per_block=[12, 4],
            down_blocks_repeat_mappers=[1, 1],
            up_blocks_repeat_mappers=[1, 1],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_in_channels=1280,
            clip_text_pooled_in_channels=1280,
            clip_image_in_channels=768,
            clip_seq=4,
            kernel_size=3,
            dropout=[0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca", "crp"],
            switch_level=[False],
        )

    if is_accelerate_available():
        load_model_dict_into_meta(prior_model, prior_state_dict)
    else:
        prior_model.load_state_dict(prior_state_dict)

    # The weights now live in prior_model; drop the checkpoint dicts so they are not kept
    # alive while the rest of the script (e.g. stage B) runs.
    del prior_orig_state_dict, prior_state_dict

    # Assemble and save the prior pipeline, cast to the requested dtype.
    prior_pipeline = StableCascadePriorPipeline(
        prior=prior_model,
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        image_encoder=image_encoder,
        scheduler=scheduler,
        feature_extractor=feature_extractor,
    )
    prior_pipeline.to(dtype).save_pretrained(
        args.prior_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
|
|
if not args.skip_stage_b:
    # --- Stage B (decoder) ---
    # Load the original checkpoint. A ".safetensors" file is always read with safetensors:
    # the default --stage_b_name is such a file while --use_safetensors defaults to off, and
    # torch.load cannot read the safetensors format.
    if args.use_safetensors or decoder_checkpoint_path.endswith(".safetensors"):
        decoder_orig_state_dict = load_file(decoder_checkpoint_path, device=device)
    else:
        # NOTE(security): torch.load unpickles the file, which can execute arbitrary code.
        # Only convert checkpoints from a trusted source.
        decoder_orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)

    # Convert the original state dict to the diffusers StableCascadeUNet format.
    decoder_state_dict = convert_stable_cascade_unet_single_file_to_diffusers(decoder_orig_state_dict)

    # Build the decoder UNet. Under accelerate's init_empty_weights no real parameter storage
    # is allocated; the weights are filled in from the state dict below.
    with ctx():
        decoder = StableCascadeUNet(
            in_channels=4,
            out_channels=4,
            timestep_ratio_embedding_dim=64,
            patch_size=2,
            conditioning_dim=1280,
            block_out_channels=[320, 576, 1152, 1152],
            down_num_layers_per_block=[2, 4, 14, 4],
            up_num_layers_per_block=[4, 14, 4, 2],
            down_blocks_repeat_mappers=[1, 1, 1, 1],
            up_blocks_repeat_mappers=[2, 2, 2, 2],
            num_attention_heads=[0, 9, 18, 18],
            block_types_per_layer=[
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
                ["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
            ],
            clip_text_pooled_in_channels=1280,
            clip_seq=4,
            effnet_in_channels=16,
            pixel_mapper_in_channels=3,
            kernel_size=3,
            dropout=[0, 0, 0.1, 0.1],
            self_attn=True,
            timestep_conditioning_type=["sca"],
        )

    if is_accelerate_available():
        load_model_dict_into_meta(decoder, decoder_state_dict)
    else:
        decoder.load_state_dict(decoder_state_dict)

    # The weights now live in `decoder`; drop the checkpoint dicts to lower peak memory.
    del decoder_orig_state_dict, decoder_state_dict

    # VQ model used as the decoder pipeline's `vqgan` component.
    vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")

    # Assemble and save the decoder pipeline, cast to the requested dtype.
    decoder_pipeline = StableCascadeDecoderPipeline(
        decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
    )
    decoder_pipeline.to(dtype).save_pretrained(
        args.decoder_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
|
|
if args.save_combined:
    # --save_combined is rejected when either stage is skipped, so prior_model, decoder
    # and vqmodel have all been created by this point.
    combined_pipeline = StableCascadeCombinedPipeline(
        # Stage C (prior) components.
        prior_prior=prior_model,
        prior_text_encoder=text_encoder,
        prior_tokenizer=tokenizer,
        prior_image_encoder=image_encoder,
        prior_feature_extractor=feature_extractor,
        prior_scheduler=scheduler,
        # Stage B (decoder) components; the text encoder, tokenizer and scheduler are
        # the same objects the prior uses.
        decoder=decoder,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vqgan=vqmodel,
        scheduler=scheduler,
    )
    # Cast to the requested dtype, then write (and optionally push) the pipeline.
    combined_pipeline = combined_pipeline.to(dtype)
    combined_pipeline.save_pretrained(
        args.combined_output_path, push_to_hub=args.push_to_hub, variant=args.variant
    )
|
|