BiliSakura committed on
Commit
248a056
·
verified ·
1 Parent(s): 3a49b2c

Update all files for BitDance-14B-64x-diffusers

Browse files
Files changed (1) hide show
  1. bitdance_diffusers/convert.py +147 -0
bitdance_diffusers/convert.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ import shutil
6
+ from pathlib import Path
7
+ from typing import Optional
8
+
9
+ import torch
10
+ from safetensors.torch import load_file as load_safetensors
11
+ from transformers import AutoTokenizer, Qwen3ForCausalLM
12
+
13
+ from .modeling_autoencoder import BitDanceAutoencoder
14
+ from .modeling_diffusion_head import BitDanceDiffusionHead
15
+ from .modeling_projector import BitDanceProjector
16
+ from .pipeline_bitdance import BitDanceDiffusionPipeline
17
+
18
+
19
+ def _resolve_dtype(dtype: str) -> torch.dtype:
20
+ mapping = {
21
+ "float32": torch.float32,
22
+ "float16": torch.float16,
23
+ "bfloat16": torch.bfloat16,
24
+ }
25
+ if dtype not in mapping:
26
+ raise ValueError(f"Unsupported torch dtype '{dtype}'. Choose from {sorted(mapping)}.")
27
+ return mapping[dtype]
28
+
29
+
30
+ def _load_json(path: Path):
31
+ with path.open("r", encoding="utf-8") as handle:
32
+ return json.load(handle)
33
+
34
+
35
+ def _copy_runtime_source(output_path: Path) -> None:
36
+ package_root = Path(__file__).resolve().parent
37
+ target_pkg = output_path / "bitdance_diffusers"
38
+ shutil.copytree(package_root, target_pkg, dirs_exist_ok=True)
39
+
40
+ loader_script = output_path / "load_pipeline.py"
41
+ loader_script.write_text(
42
+ "\n".join(
43
+ [
44
+ "import sys",
45
+ "from pathlib import Path",
46
+ "",
47
+ "from diffusers import DiffusionPipeline",
48
+ "",
49
+ "model_dir = Path(__file__).resolve().parent",
50
+ "sys.path.insert(0, str(model_dir))",
51
+ 'pipe = DiffusionPipeline.from_pretrained(model_dir, trust_remote_code=True).to("cuda")',
52
+ 'images = pipe(prompt="A scenic mountain lake at sunrise.").images',
53
+ 'images[0].save("sample.png")',
54
+ ]
55
+ )
56
+ + "\n",
57
+ encoding="utf-8",
58
+ )
59
+
60
+
61
def convert_bitdance_to_diffusers(
    source_model_path: str,
    output_path: str,
    torch_dtype: str = "bfloat16",
    device: str = "cpu",
    copy_runtime_source: bool = True,
) -> Path:
    """Repackage a BitDance checkpoint directory as a Diffusers pipeline.

    Loads the tokenizer, Qwen3 text encoder, autoencoder, diffusion head and
    projector from *source_model_path*, assembles them into a
    ``BitDanceDiffusionPipeline`` and saves it under *output_path*.

    Args:
        source_model_path: Directory with the original checkpoint files
            (``ae_config.json``, ``ae.safetensors``, ``vision_head_config.json``,
            ``vision_head.safetensors``, ``projector.safetensors`` plus the
            tokenizer / text-encoder files).
        output_path: Destination directory; created if missing.
        torch_dtype: ``"float32"``, ``"float16"`` or ``"bfloat16"``.
        device: Device the submodules are moved to before saving; a falsy
            value skips the move.
        copy_runtime_source: Also copy the runtime package and a loader
            script into the output so it is self-contained.

    Returns:
        The output directory as a ``Path``.
    """
    src_dir = Path(source_model_path)
    out_dir = Path(output_path)
    out_dir.mkdir(parents=True, exist_ok=True)

    dtype = _resolve_dtype(torch_dtype)

    tokenizer = AutoTokenizer.from_pretrained(src_dir)
    text_encoder = Qwen3ForCausalLM.from_pretrained(
        src_dir,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    ).eval()

    # Autoencoder: config may either nest its settings under "ddconfig" or be
    # the ddconfig itself.
    ae_config = _load_json(src_dir / "ae_config.json")
    ddconfig = ae_config.get("ddconfig", ae_config)
    autoencoder = BitDanceAutoencoder(
        ddconfig=ddconfig,
        gan_decoder=bool(ae_config.get("gan_decoder", False)),
    ).eval()
    autoencoder.load_state_dict(
        load_safetensors(src_dir / "ae.safetensors"), strict=True, assign=True
    )

    head_cfg = _load_json(src_dir / "vision_head_config.json")
    diffusion_head = BitDanceDiffusionHead(**head_cfg).eval()
    diffusion_head.load_state_dict(
        load_safetensors(src_dir / "vision_head.safetensors"), strict=True, assign=True
    )

    # Projector maps autoencoder latents (z_channels) into the text encoder's
    # hidden space.
    projector = BitDanceProjector(
        in_dim=int(ddconfig["z_channels"]),
        out_dim=int(text_encoder.config.hidden_size),
        hidden_act="gelu_pytorch_tanh",
    ).eval()
    projector.load_state_dict(
        load_safetensors(src_dir / "projector.safetensors"), strict=True, assign=True
    )

    if device:
        for module in (text_encoder, autoencoder, diffusion_head, projector):
            module.to(device=device)

    pipeline = BitDanceDiffusionPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        autoencoder=autoencoder,
        diffusion_head=diffusion_head,
        projector=projector,
    )
    pipeline.save_pretrained(out_dir, safe_serialization=True)

    if copy_runtime_source:
        _copy_runtime_source(out_dir)

    return out_dir
117
+
118
+
119
def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Build and evaluate the CLI for the conversion script.

    Args:
        argv: Argument list to parse; ``None`` falls back to ``sys.argv[1:]``.
    """
    cli = argparse.ArgumentParser(
        description="Convert BitDance checkpoints to Diffusers format."
    )
    cli.add_argument("--source_model_path", type=str, required=True)
    cli.add_argument("--output_path", type=str, required=True)
    cli.add_argument(
        "--torch_dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    cli.add_argument("--device", type=str, default="cpu")
    # BooleanOptionalAction also generates the --no-copy_runtime_source flag.
    cli.add_argument(
        "--copy_runtime_source",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Copy self-contained runtime source into output directory.",
    )
    return cli.parse_args(argv)
132
+
133
+
134
def main(argv: Optional[list[str]] = None) -> None:
    """CLI entry point: parse arguments, run the conversion, report the output dir."""
    options = parse_args(argv)
    destination = convert_bitdance_to_diffusers(
        source_model_path=options.source_model_path,
        output_path=options.output_path,
        torch_dtype=options.torch_dtype,
        device=options.device,
        copy_runtime_source=options.copy_runtime_source,
    )
    print(f"Saved converted Diffusers pipeline to: {destination}")
144
+
145
+
146
# Script entry point: run the conversion CLI when executed directly.
if __name__ == "__main__":
    main()