import argparse
import os

import torch
from safetensors.torch import load_file
from transformers import AutoModel, AutoTokenizer

from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, LuminaNextDiT2DModel, LuminaText2ImgPipeline

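# Conversion script: remaps an original Lumina-Next safetensors checkpoint into
# the diffusers LuminaNextDiT2DModel / LuminaText2ImgPipeline key layout.
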
def main(args):
    # load the original checkpoint weights onto CPU
    all_sd = load_file(args.origin_ckpt_path, device="cpu")
    converted_state_dict = {}

    # pad token
    converted_state_dict["pad_token"] = all_sd["pad_token"]

    # patch embedder
    converted_state_dict["patch_embedder.weight"] = all_sd["x_embedder.weight"]
    converted_state_dict["patch_embedder.bias"] = all_sd["x_embedder.bias"]

    # timestep and caption embedders
    converted_state_dict["time_caption_embed.timestep_embedder.linear_1.weight"] = all_sd["t_embedder.mlp.0.weight"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_1.bias"] = all_sd["t_embedder.mlp.0.bias"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_2.weight"] = all_sd["t_embedder.mlp.2.weight"]
    converted_state_dict["time_caption_embed.timestep_embedder.linear_2.bias"] = all_sd["t_embedder.mlp.2.bias"]
    converted_state_dict["time_caption_embed.caption_embedder.0.weight"] = all_sd["cap_embedder.0.weight"]
    converted_state_dict["time_caption_embed.caption_embedder.0.bias"] = all_sd["cap_embedder.0.bias"]
    converted_state_dict["time_caption_embed.caption_embedder.1.weight"] = all_sd["cap_embedder.1.weight"]
    converted_state_dict["time_caption_embed.caption_embedder.1.bias"] = all_sd["cap_embedder.1.bias"]

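    # remap each of the 24 transformer blocks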
    for i in range(24):
        # gate and adaLN modulation
        converted_state_dict[f"layers.{i}.gate"] = all_sd[f"layers.{i}.attention.gate"]
        converted_state_dict[f"layers.{i}.adaLN_modulation.1.weight"] = all_sd[f"layers.{i}.adaLN_modulation.1.weight"]
        converted_state_dict[f"layers.{i}.adaLN_modulation.1.bias"] = all_sd[f"layers.{i}.adaLN_modulation.1.bias"]

        # self-attention q/k/v projections
        converted_state_dict[f"layers.{i}.attn1.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
        converted_state_dict[f"layers.{i}.attn1.to_k.weight"] = all_sd[f"layers.{i}.attention.wk.weight"]
        converted_state_dict[f"layers.{i}.attn1.to_v.weight"] = all_sd[f"layers.{i}.attention.wv.weight"]

        # cross-attention projections: queries reuse wq; keys/values come from the caption weights (wk_y / wv_y)
        converted_state_dict[f"layers.{i}.attn2.to_q.weight"] = all_sd[f"layers.{i}.attention.wq.weight"]
        converted_state_dict[f"layers.{i}.attn2.to_k.weight"] = all_sd[f"layers.{i}.attention.wk_y.weight"]
        converted_state_dict[f"layers.{i}.attn2.to_v.weight"] = all_sd[f"layers.{i}.attention.wv_y.weight"]

        # attention output projection (stored on attn2 in the diffusers layout)
        converted_state_dict[f"layers.{i}.attn2.to_out.0.weight"] = all_sd[f"layers.{i}.attention.wo.weight"]

        # q/k norms for self-attention
        converted_state_dict[f"layers.{i}.attn1.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
        converted_state_dict[f"layers.{i}.attn1.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
        converted_state_dict[f"layers.{i}.attn1.norm_k.weight"] = all_sd[f"layers.{i}.attention.k_norm.weight"]
        converted_state_dict[f"layers.{i}.attn1.norm_k.bias"] = all_sd[f"layers.{i}.attention.k_norm.bias"]

        # q/k norms for cross-attention (query norm shared with self-attention)
        converted_state_dict[f"layers.{i}.attn2.norm_q.weight"] = all_sd[f"layers.{i}.attention.q_norm.weight"]
        converted_state_dict[f"layers.{i}.attn2.norm_q.bias"] = all_sd[f"layers.{i}.attention.q_norm.bias"]
        converted_state_dict[f"layers.{i}.attn2.norm_k.weight"] = all_sd[f"layers.{i}.attention.ky_norm.weight"]
        converted_state_dict[f"layers.{i}.attn2.norm_k.bias"] = all_sd[f"layers.{i}.attention.ky_norm.bias"]

        # attention norms
        converted_state_dict[f"layers.{i}.attn_norm1.weight"] = all_sd[f"layers.{i}.attention_norm1.weight"]
        converted_state_dict[f"layers.{i}.attn_norm2.weight"] = all_sd[f"layers.{i}.attention_norm2.weight"]
        converted_state_dict[f"layers.{i}.norm1_context.weight"] = all_sd[f"layers.{i}.attention_y_norm.weight"]

        # gated feed-forward (w1/w2/w3 projections)
        converted_state_dict[f"layers.{i}.feed_forward.linear_1.weight"] = all_sd[f"layers.{i}.feed_forward.w1.weight"]
        converted_state_dict[f"layers.{i}.feed_forward.linear_2.weight"] = all_sd[f"layers.{i}.feed_forward.w2.weight"]
        converted_state_dict[f"layers.{i}.feed_forward.linear_3.weight"] = all_sd[f"layers.{i}.feed_forward.w3.weight"]

        # feed-forward norms
        converted_state_dict[f"layers.{i}.ffn_norm1.weight"] = all_sd[f"layers.{i}.ffn_norm1.weight"]
        converted_state_dict[f"layers.{i}.ffn_norm2.weight"] = all_sd[f"layers.{i}.ffn_norm2.weight"]

    # final layer
    converted_state_dict["final_layer.linear.weight"] = all_sd["final_layer.linear.weight"]
    converted_state_dict["final_layer.linear.bias"] = all_sd["final_layer.linear.bias"]
    converted_state_dict["final_layer.adaLN_modulation.1.weight"] = all_sd["final_layer.adaLN_modulation.1.weight"]
    converted_state_dict["final_layer.adaLN_modulation.1.bias"] = all_sd["final_layer.adaLN_modulation.1.bias"]

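    # instantiate the diffusers transformer and load the remapped weights;
    # strict=True raises on any missing or unexpected key, so naming errors fail fast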
    transformer = LuminaNextDiT2DModel(
        sample_size=128,
        patch_size=2,
        in_channels=4,
        hidden_size=2304,
        num_layers=24,
        num_attention_heads=32,
        num_kv_heads=8,
        multiple_of=256,
        ffn_dim_multiplier=None,
        norm_eps=1e-5,
        learn_sigma=True,
        qk_norm=True,
        cross_attention_dim=2048,
        scaling_factor=1.0,
    )
    transformer.load_state_dict(converted_state_dict, strict=True)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

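    # save either just the transformer, or assemble and save the full text-to-image pipeline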
    if args.only_transformer:
        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"))
    else:
        scheduler = FlowMatchEulerDiscreteScheduler()

        vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae", torch_dtype=torch.float32)

        tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
        text_encoder = AutoModel.from_pretrained("google/gemma-2b")

        pipeline = LuminaText2ImgPipeline(
            tokenizer=tokenizer, text_encoder=text_encoder, transformer=transformer, vae=vae, scheduler=scheduler
        )
        pipeline.save_pretrained(args.dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--origin_ckpt_path", type=str, required=True, help="Path to the checkpoint to convert."
    )
    parser.add_argument(
        "--image_size",
        default=1024,
        type=int,
        choices=[256, 512, 1024],
        required=False,
        help="Image size of the pretrained model: 256, 512, or 1024 (currently unused by the conversion).",
    )
    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
    parser.add_argument(
        "--only_transformer",
        default=True,
        type=lambda x: str(x).lower() in ("1", "true", "yes"),
        help="Whether to save only the converted transformer (default) or the full pipeline.",
    )

    args = parser.parse_args()
    main(args)
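
# Example invocation (paths below are hypothetical):
#   python convert_lumina_to_diffusers.py \
#       --origin_ckpt_path /path/to/lumina_next.safetensors \
#       --dump_path ./lumina-next-diffusers \
#       --only_transformer false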