|
|
import argparse |
|
|
import os |
|
|
|
|
|
import torch |
|
|
from huggingface_hub import snapshot_download |
|
|
from safetensors.torch import load_file |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, OmniGenPipeline, OmniGenTransformer2DModel |
|
|
|
|
|
|
|
|
# Rotary-embedding scaling factors for the OmniGen transformer (Phi-3-style
# "su" scaled RoPE). Values are copied verbatim from the original checkpoint
# configuration and must not be altered.
_ROPE_SCALING = {
    "long_factor": [
        1.0299999713897705,
        1.0499999523162842,
        1.0499999523162842,
        1.0799999237060547,
        1.2299998998641968,
        1.2299998998641968,
        1.2999999523162842,
        1.4499999284744263,
        1.5999999046325684,
        1.6499998569488525,
        1.8999998569488525,
        2.859999895095825,
        3.68999981880188,
        5.419999599456787,
        5.489999771118164,
        5.489999771118164,
        9.09000015258789,
        11.579999923706055,
        15.65999984741211,
        15.769999504089355,
        15.789999961853027,
        18.360000610351562,
        21.989999771118164,
        23.079999923706055,
        30.009998321533203,
        32.35000228881836,
        32.590003967285156,
        35.56000518798828,
        39.95000457763672,
        53.840003967285156,
        56.20000457763672,
        57.95000457763672,
        59.29000473022461,
        59.77000427246094,
        59.920005798339844,
        61.190006256103516,
        61.96000671386719,
        62.50000762939453,
        63.3700065612793,
        63.48000717163086,
        63.48000717163086,
        63.66000747680664,
        63.850006103515625,
        64.08000946044922,
        64.760009765625,
        64.80001068115234,
        64.81001281738281,
        64.81001281738281,
    ],
    "short_factor": [
        1.05,
        1.05,
        1.05,
        1.1,
        1.1,
        1.1,
        1.2500000000000002,
        1.2500000000000002,
        1.4000000000000004,
        1.4500000000000004,
        1.5500000000000005,
        1.8500000000000008,
        1.9000000000000008,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.000000000000001,
        2.1000000000000005,
        2.1000000000000005,
        2.2,
        2.3499999999999996,
        2.3499999999999996,
        2.3499999999999996,
        2.3499999999999996,
        2.3999999999999995,
        2.3999999999999995,
        2.6499999999999986,
        2.6999999999999984,
        2.8999999999999977,
        2.9499999999999975,
        3.049999999999997,
        3.049999999999997,
        3.049999999999997,
    ],
    "type": "su",
}

# Direct one-to-one key renames from the original OmniGen state dict to the
# diffusers OmniGenTransformer2DModel naming scheme.
_MAPPING_DICT = {
    "pos_embed": "patch_embedding.pos_embed",
    "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
    "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
    "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
    "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
    "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
    "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
    "final_layer.linear.weight": "proj_out.weight",
    "final_layer.linear.bias": "proj_out.bias",
    "time_token.mlp.0.weight": "time_token.linear_1.weight",
    "time_token.mlp.0.bias": "time_token.linear_1.bias",
    "time_token.mlp.2.weight": "time_token.linear_2.weight",
    "time_token.mlp.2.bias": "time_token.linear_2.bias",
    "t_embedder.mlp.0.weight": "t_embedder.linear_1.weight",
    "t_embedder.mlp.0.bias": "t_embedder.linear_1.bias",
    "t_embedder.mlp.2.weight": "t_embedder.linear_2.weight",
    "t_embedder.mlp.2.bias": "t_embedder.linear_2.bias",
    "llm.embed_tokens.weight": "embed_tokens.weight",
}


def _resolve_checkpoint(origin_ckpt_path):
    """Return a local directory containing the original checkpoint.

    If ``origin_ckpt_path`` exists on disk it is returned unchanged;
    otherwise it is treated as a Hugging Face Hub repo id and downloaded
    with ``snapshot_download`` (honoring the ``HF_HUB_CACHE`` env var).
    """
    if os.path.exists(origin_ckpt_path):
        return origin_ckpt_path
    print("Model not found, downloading...")
    cache_folder = os.getenv("HF_HUB_CACHE")
    local_path = snapshot_download(
        repo_id=origin_ckpt_path,
        cache_dir=cache_folder,
        ignore_patterns=["flax_model.msgpack", "rust_model.ot", "tf_model.h5", "model.pt"],
    )
    print(f"Downloaded model to {local_path}")
    return local_path


def _convert_state_dict(state_dict):
    """Convert original OmniGen state-dict keys to the diffusers layout.

    - Keys present in ``_MAPPING_DICT`` are renamed directly.
    - Fused ``qkv`` projection weights are chunked (dim 0) into separate
      ``to_q`` / ``to_k`` / ``to_v`` weights of the matching layer.
    - ``o_proj`` weights become the attention output projection.
    - All remaining keys drop their leading ``llm.`` prefix (4 chars).
    """
    converted_state_dict = {}
    for k, v in state_dict.items():
        if k in _MAPPING_DICT:
            converted_state_dict[_MAPPING_DICT[k]] = v
        elif "qkv" in k:
            # Layer index is the third dotted component, e.g. "llm.layers.<i>...".
            layer = k.split(".")[2]
            to_q, to_k, to_v = v.chunk(3)
            converted_state_dict[f"layers.{layer}.self_attn.to_q.weight"] = to_q
            converted_state_dict[f"layers.{layer}.self_attn.to_k.weight"] = to_k
            converted_state_dict[f"layers.{layer}.self_attn.to_v.weight"] = to_v
        elif "o_proj" in k:
            layer = k.split(".")[2]
            converted_state_dict[f"layers.{layer}.self_attn.to_out.0.weight"] = v
        else:
            # Remaining keys come from the original "llm." submodule; strip
            # that prefix so they match the diffusers module tree.
            converted_state_dict[k[4:]] = v
    return converted_state_dict


def main(args):
    """Convert an original OmniGen checkpoint into a diffusers pipeline.

    Resolves (downloading if necessary) ``args.origin_ckpt_path``, converts
    the transformer state dict to the diffusers naming scheme, assembles an
    ``OmniGenPipeline`` (transformer in bfloat16, VAE in float32, flow-match
    Euler scheduler) and saves it to ``args.dump_path``.
    """
    args.origin_ckpt_path = _resolve_checkpoint(args.origin_ckpt_path)

    ckpt_file = os.path.join(args.origin_ckpt_path, "model.safetensors")
    state_dict = load_file(ckpt_file, device="cpu")
    converted_state_dict = _convert_state_dict(state_dict)

    transformer = OmniGenTransformer2DModel(
        rope_scaling=_ROPE_SCALING,
        patch_size=2,
        in_channels=4,
        pos_embed_max_size=192,
    )
    # strict=True: fail loudly if the conversion missed or mangled any key.
    transformer.load_state_dict(converted_state_dict, strict=True)
    transformer.to(torch.bfloat16)

    num_model_params = sum(p.numel() for p in transformer.parameters())
    print(f"Total number of transformer parameters: {num_model_params}")

    scheduler = FlowMatchEulerDiscreteScheduler(invert_sigmas=True, num_train_timesteps=1)

    # The VAE is kept in float32 for numerical stability during decoding.
    vae = AutoencoderKL.from_pretrained(os.path.join(args.origin_ckpt_path, "vae"), torch_dtype=torch.float32)

    tokenizer = AutoTokenizer.from_pretrained(args.origin_ckpt_path)

    pipeline = OmniGenPipeline(tokenizer=tokenizer, transformer=transformer, vae=vae, scheduler=scheduler)
    pipeline.save_pretrained(args.dump_path)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse arguments, then run the conversion.
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--origin_ckpt_path",
        type=str,
        default="Shitao/OmniGen-v1",
        required=False,
        help="Path to the checkpoint to convert.",
    )
    arg_parser.add_argument(
        "--dump_path",
        type=str,
        default="OmniGen-v1-diffusers",
        required=False,
        help="Path to the output pipeline.",
    )
    main(arg_parser.parse_args())
|
|
|