BiliSakura commited on
Commit
bb3feea
·
verified ·
1 Parent(s): 4e67e00

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  MVSplit-DiT-1000L/demo.png filter=lfs diff=lfs merge=lfs -text
37
  MVSplit-DiT-1000L/tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
 
 
 
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  MVSplit-DiT-1000L/demo.png filter=lfs diff=lfs merge=lfs -text
37
  MVSplit-DiT-1000L/tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
38
+ demo.png filter=lfs diff=lfs merge=lfs -text
39
+ tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -11,40 +11,113 @@ tags:
11
  - text-to-image
12
  - flow-matching
13
  - mvsplit
 
14
  widget:
15
  - text: a red panda climbing a bamboo stalk
16
  output:
17
- url: MVSplit-DiT-1000L/demo.png
18
  ---
19
 
20
- # BiliSakura/MVSplit-DiT-diffusers
21
 
22
- Diffusers-ready checkpoints for **MVSplit-DiT** (Mean–Variance Split Residual Diffusion Transformers), converted for local/offline use with a project-owned custom `MVSplitDiTPipeline`.
23
 
24
- > **Re-distribution notice:** weights are converted from [`StableKirito/mvsplit-dit-1000l`](https://huggingface.co/StableKirito/mvsplit-dit-1000l). Original work: [Mean Mode Screaming](https://huggingface.co/papers/2605.06169). License: [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0).
25
-
26
- ## Available checkpoints
27
-
28
- | Subfolder | Params | Task | Resolution |
29
- | --- | ---: | --- | ---: |
30
- | [`MVSplit-DiT-1000L/`](MVSplit-DiT-1000L/) | 1000L | text-to-image | 256×256 |
31
-
32
- Each subfolder is a self-contained Diffusers model repo with `pipeline.py`, `model_index.json`, and component weights.
33
 
34
  ## Demo
35
 
36
- ![MVSplit-DiT-1000L demo](MVSplit-DiT-1000L/demo.png)
37
 
38
  Prompt: *a red panda climbing a bamboo stalk* — 256×256, 35 steps, CFG 2.0.
39
 
 
 
 
 
 
 
 
 
 
 
40
  ## Inference
41
 
 
 
42
  ```bash
43
- cd MVSplit-DiT-1000L
44
  python demo_inference.py
45
  ```
46
 
47
- See [`MVSplit-DiT-1000L/README.md`](MVSplit-DiT-1000L/README.md) for full usage and recommended settings.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  ## Citation
50
 
 
11
  - text-to-image
12
  - flow-matching
13
  - mvsplit
14
+ inference: true
15
  widget:
16
  - text: a red panda climbing a bamboo stalk
17
  output:
18
+ url: demo.png
19
  ---
20
 
21
+ # MVSplit-DiT-1000L
22
 
23
+ Self-contained Diffusers checkpoint for **MVSplit-DiT** (1000-layer Diffusion Transformer) with a custom `MVSplitDiTPipeline` (`pipeline.py`).
24
 
25
+ > **Re-distribution notice:** weights are converted from [`StableKirito/mvsplit-dit-1000l`](https://huggingface.co/StableKirito/mvsplit-dit-1000l). Original work: [Mean Mode Screaming: Mean–Variance Split Residuals for 1000-Layer Diffusion Transformers](https://huggingface.co/papers/2605.06169). License: [Apache 2.0](https://www.apache.org/licenses/LICENSE-2.0).
 
 
 
 
 
 
 
 
26
 
27
  ## Demo
28
 
29
+ ![MVSplit-DiT-1000L demo](demo.png)
30
 
31
  Prompt: *a red panda climbing a bamboo stalk* — 256×256, 35 steps, CFG 2.0.
32
 
33
+ ## Components
34
+
35
+ - `pipeline.py` — `MVSplitDiTPipeline`
36
+ - `model_index.json`
37
+ - `transformer/` — `MVSplitDiTTransformer2DModel` (bf16, 1000 layers)
38
+ - `scheduler/` — `FlowMatchEulerDiscreteScheduler`
39
+ - `text_encoder/` — Qwen3-0.6B (`AutoModel`)
40
+ - `tokenizer/` — Qwen3 tokenizer
41
+ - `vae/` — FLUX2 VAE (`AutoencoderKLFlux2`)
42
+
43
  ## Inference
44
 
45
+ Run the bundled demo script:
46
+
47
  ```bash
 
48
  python demo_inference.py
49
  ```
50
 
51
+ This writes `demo.png` with the default prompt and settings below.
52
+
53
+ ```python
54
+ from pathlib import Path
55
+ import importlib.util
56
+ import sys
57
+ import torch
58
+ from diffusers import AutoencoderKLFlux2
59
+ from transformers import AutoModel, AutoTokenizer
60
+
61
+ model_dir = Path(".").resolve()
62
+
63
+ transformer_path = model_dir / "transformer" / "transformer_mvsplit_dit.py"
64
+ spec = importlib.util.spec_from_file_location("transformer_mvsplit_dit", transformer_path)
65
+ module = importlib.util.module_from_spec(spec)
66
+ sys.modules[spec.name] = module
67
+ spec.loader.exec_module(module)
68
+
69
+ pipe_spec = importlib.util.spec_from_file_location("mvsplit_pipeline", model_dir / "pipeline.py")
70
+ pipe_module = importlib.util.module_from_spec(pipe_spec)
71
+ sys.modules[pipe_spec.name] = pipe_module
72
+ pipe_spec.loader.exec_module(pipe_module)
73
+
74
+ transformer = module.MVSplitDiTTransformer2DModel.from_pretrained(
75
+ model_dir / "transformer",
76
+ torch_dtype=torch.bfloat16,
77
+ local_files_only=True,
78
+ )
79
+ tokenizer = AutoTokenizer.from_pretrained(model_dir / "tokenizer", local_files_only=True)
80
+ text_encoder = AutoModel.from_pretrained(
81
+ model_dir / "text_encoder",
82
+ torch_dtype=torch.bfloat16,
83
+ local_files_only=True,
84
+ )
85
+ vae = AutoencoderKLFlux2.from_pretrained(
86
+ model_dir / "vae",
87
+ torch_dtype=torch.bfloat16,
88
+ local_files_only=True,
89
+ )
90
+
91
+ pipe = pipe_module.MVSplitDiTPipeline(
92
+ transformer=transformer,
93
+ vae=vae,
94
+ text_encoder=text_encoder,
95
+ tokenizer=tokenizer,
96
+ time_shift_alpha=4.0,
97
+ )
98
+ pipe.enable_sequential_cpu_offload()
99
+
100
+ generator = torch.Generator(device="cpu").manual_seed(42)
101
+ image = pipe(
102
+ prompt="a red panda climbing a bamboo stalk",
103
+ height=256,
104
+ width=256,
105
+ num_inference_steps=35,
106
+ guidance_scale=2.0,
107
+ generator=generator,
108
+ ).images[0]
109
+ image.save("demo.png")
110
+ ```
111
+
112
+ ### Recommended settings
113
+
114
+ | Parameter | Default | Notes |
115
+ | --- | ---: | --- |
116
+ | `height` / `width` | 256 | Square output resolution |
117
+ | `num_inference_steps` | 35 | Flow-matching Euler steps |
118
+ | `guidance_scale` | 2.0 | Classifier-free guidance |
119
+ | `time_shift_alpha` | 4.0 | Time-shift in the flow schedule (must match training) |
120
+ | `seed` | 42 | Reproducible sampling |
121
 
122
  ## Citation
123
 
__pycache__/pipeline.cpython-312.pyc ADDED
Binary file (12.9 kB). View file
 
demo.png ADDED

Git LFS Details

  • SHA256: 6e5f8bae051bb3441bfe109f6fc509dd1ed12afbd58e74a0f257729d8a44ce9f
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
demo_inference.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Smoke-test MVSplit-DiT inference from the converted Diffusers Hub folder."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import importlib.util
8
+ import sys
9
+ from pathlib import Path
10
+
11
+ import torch
12
+ from diffusers import AutoencoderKLFlux2
13
+ from transformers import AutoModel, AutoTokenizer
14
+
15
+
16
+ def parse_args() -> argparse.Namespace:
17
+ parser = argparse.ArgumentParser(description="Run MVSplit-DiT inference.")
18
+ parser.add_argument(
19
+ "--model",
20
+ type=Path,
21
+ default=Path(__file__).resolve().parent,
22
+ help="Path to MVSplit-DiT-1000L pipeline directory.",
23
+ )
24
+ parser.add_argument(
25
+ "--prompt",
26
+ type=str,
27
+ default="a red panda climbing a bamboo stalk",
28
+ help="Text prompt for generation.",
29
+ )
30
+ parser.add_argument("--height", type=int, default=256)
31
+ parser.add_argument("--width", type=int, default=256)
32
+ parser.add_argument("--num-inference-steps", type=int, default=35)
33
+ parser.add_argument("--guidance-scale", type=float, default=2.0)
34
+ parser.add_argument("--time-shift-alpha", type=float, default=4.0)
35
+ parser.add_argument("--seed", type=int, default=42)
36
+ parser.add_argument(
37
+ "--output",
38
+ type=Path,
39
+ default=Path(__file__).resolve().parent / "demo.png",
40
+ help="Output image path. Ignored when --output-type=latent.",
41
+ )
42
+ parser.add_argument(
43
+ "--output-type",
44
+ choices=("pil", "latent"),
45
+ default="pil",
46
+ help="Return decoded image or raw latents.",
47
+ )
48
+ parser.add_argument(
49
+ "--skip-vae",
50
+ action="store_true",
51
+ help="Skip VAE decode even when output-type=pil (saves memory).",
52
+ )
53
+ parser.add_argument(
54
+ "--device",
55
+ choices=("auto", "cuda", "cpu"),
56
+ default="auto",
57
+ help="Execution device. auto prefers CUDA when available.",
58
+ )
59
+ parser.add_argument(
60
+ "--cpu-offload",
61
+ action="store_true",
62
+ help="Use sequential CPU offload instead of keeping the pipeline on GPU.",
63
+ )
64
+ return parser.parse_args()
65
+
66
+
67
+ def _resolve_device(choice: str) -> torch.device:
68
+ if choice == "auto":
69
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
70
+ return torch.device(choice)
71
+
72
+
73
+ def _load_pipeline_class(model_dir: Path):
74
+ transformer_path = model_dir / "transformer" / "transformer_mvsplit_dit.py"
75
+ spec = importlib.util.spec_from_file_location("transformer_mvsplit_dit", transformer_path)
76
+ module = importlib.util.module_from_spec(spec)
77
+ sys.modules[spec.name] = module
78
+ spec.loader.exec_module(module)
79
+
80
+ pipe_spec = importlib.util.spec_from_file_location("mvsplit_pipeline", model_dir / "pipeline.py")
81
+ pipe_module = importlib.util.module_from_spec(pipe_spec)
82
+ sys.modules[pipe_spec.name] = pipe_module
83
+ pipe_spec.loader.exec_module(pipe_module)
84
+ return module.MVSplitDiTTransformer2DModel, pipe_module.MVSplitDiTPipeline
85
+
86
+
87
+ def main() -> None:
88
+ args = parse_args()
89
+ model_dir = args.model.resolve()
90
+ device = _resolve_device(args.device)
91
+ transformer_cls, pipeline_cls = _load_pipeline_class(model_dir)
92
+
93
+ print(f"Loading components on {device}...", flush=True)
94
+ transformer = transformer_cls.from_pretrained(
95
+ model_dir / "transformer",
96
+ torch_dtype=torch.bfloat16,
97
+ local_files_only=True,
98
+ )
99
+ tokenizer = AutoTokenizer.from_pretrained(model_dir / "tokenizer", local_files_only=True)
100
+ text_encoder = AutoModel.from_pretrained(
101
+ model_dir / "text_encoder",
102
+ torch_dtype=torch.bfloat16,
103
+ local_files_only=True,
104
+ )
105
+
106
+ vae = None
107
+ if not args.skip_vae and args.output_type == "pil":
108
+ vae = AutoencoderKLFlux2.from_pretrained(
109
+ model_dir / "vae",
110
+ torch_dtype=torch.bfloat16,
111
+ local_files_only=True,
112
+ )
113
+
114
+ pipe = pipeline_cls(
115
+ transformer=transformer,
116
+ scheduler=None,
117
+ vae=vae,
118
+ text_encoder=text_encoder,
119
+ tokenizer=tokenizer,
120
+ time_shift_alpha=args.time_shift_alpha,
121
+ )
122
+ if args.cpu_offload and device.type == "cuda":
123
+ pipe.enable_sequential_cpu_offload(gpu_id=device.index or 0)
124
+ else:
125
+ pipe.to(device)
126
+
127
+ print(
128
+ f"Running inference ({args.num_inference_steps} steps, {args.height}x{args.width})...",
129
+ flush=True,
130
+ )
131
+ generator_device = "cpu" if args.cpu_offload else device.type
132
+ generator = torch.Generator(device=generator_device).manual_seed(args.seed)
133
+ result = pipe(
134
+ prompt=args.prompt,
135
+ height=args.height,
136
+ width=args.width,
137
+ num_inference_steps=args.num_inference_steps,
138
+ guidance_scale=args.guidance_scale,
139
+ generator=generator,
140
+ output_type=args.output_type,
141
+ )
142
+
143
+ if args.output_type == "latent":
144
+ latents = result.images
145
+ print(f"latent shape={tuple(latents.shape)} dtype={latents.dtype}")
146
+ print(
147
+ "latent stats:",
148
+ f"min={float(latents.min()):.4f}",
149
+ f"max={float(latents.max()):.4f}",
150
+ f"mean={float(latents.mean()):.4f}",
151
+ )
152
+ return
153
+
154
+ image = result.images[0]
155
+ args.output.parent.mkdir(parents=True, exist_ok=True)
156
+ image.save(args.output)
157
+ print(f"Saved image to {args.output}")
158
+
159
+
160
+ if __name__ == "__main__":
161
+ main()
model_index.json ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": [
3
+ "pipeline",
4
+ "MVSplitDiTPipeline"
5
+ ],
6
+ "_diffusers_version": "0.36.0",
7
+ "scheduler": [
8
+ "diffusers",
9
+ "FlowMatchEulerDiscreteScheduler"
10
+ ],
11
+ "transformer": [
12
+ "transformer_mvsplit_dit",
13
+ "MVSplitDiTTransformer2DModel"
14
+ ],
15
+ "vae": [
16
+ "diffusers",
17
+ "AutoencoderKLFlux2"
18
+ ],
19
+ "text_encoder": [
20
+ "transformers",
21
+ "AutoModel"
22
+ ],
23
+ "tokenizer": [
24
+ "transformers",
25
+ "AutoTokenizer"
26
+ ]
27
+ }
pipeline.py ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hub custom pipeline: MVSplitDiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from dataclasses import dataclass
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from einops import rearrange
12
+
13
+ try:
14
+ from diffusers.image_processor import VaeImageProcessor
15
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
16
+ from diffusers.utils import BaseOutput
17
+ except Exception:
18
+ class BaseOutput(dict):
19
+ def __post_init__(self):
20
+ self.update(self.__dict__)
21
+
22
+ class DiffusionPipeline:
23
+ def register_modules(self, **kwargs):
24
+ for name, module in kwargs.items():
25
+ setattr(self, name, module)
26
+
27
+ @property
28
+ def _execution_device(self):
29
+ return torch.device("cpu")
30
+
31
+ def maybe_free_model_hooks(self):
32
+ pass
33
+
34
+ class VaeImageProcessor:
35
+ def postprocess(self, image, output_type="pil"):
36
+ return image
37
+
38
+ # DiT operates on packed FLUX2 latents at 1/16 of the image resolution.
39
+ LATENT_DOWNSAMPLE_FACTOR = 16
40
+
41
+
42
+ @dataclass
43
+ class MVSplitDiTPipelineOutput(BaseOutput):
44
+ images: Union[torch.FloatTensor, List]
45
+
46
+
47
+ class MVSplitDiTPipeline(DiffusionPipeline):
48
+ """
49
+ Text-to-image pipeline for MVSplit DiT.
50
+
51
+ Sampling follows the official mv-split Euler ODE integrator with time-shift
52
+ (see https://github.com/erwold/mv-split sample.py).
53
+ """
54
+
55
+ model_cpu_offload_seq = "text_encoder->transformer->vae"
56
+ _optional_components = ["vae", "text_encoder", "tokenizer"]
57
+
58
+ def __init__(
59
+ self,
60
+ transformer,
61
+ scheduler=None,
62
+ vae=None,
63
+ text_encoder=None,
64
+ tokenizer=None,
65
+ max_length: int = 256,
66
+ time_shift_alpha: float = 4.0,
67
+ ):
68
+ super().__init__()
69
+ self.register_modules(
70
+ transformer=transformer,
71
+ scheduler=scheduler,
72
+ vae=vae,
73
+ text_encoder=text_encoder,
74
+ tokenizer=tokenizer,
75
+ )
76
+ self.max_length = max_length
77
+ self.time_shift_alpha = time_shift_alpha
78
+ self.image_processor = VaeImageProcessor()
79
+
80
+ @staticmethod
81
+ def _shift_time(t: float, alpha: float) -> float:
82
+ return t * alpha / (1.0 + (alpha - 1.0) * t)
83
+
84
+ def _prepare_latents(
85
+ self,
86
+ batch_size: int,
87
+ height: int,
88
+ width: int,
89
+ device: torch.device,
90
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]],
91
+ ) -> torch.Tensor:
92
+ if height % LATENT_DOWNSAMPLE_FACTOR != 0 or width % LATENT_DOWNSAMPLE_FACTOR != 0:
93
+ raise ValueError(
94
+ f"height and width must be divisible by {LATENT_DOWNSAMPLE_FACTOR}."
95
+ )
96
+
97
+ latent_height = height // LATENT_DOWNSAMPLE_FACTOR
98
+ latent_width = width // LATENT_DOWNSAMPLE_FACTOR
99
+ latent_shape = (batch_size, self.transformer.config.in_channels, latent_height, latent_width)
100
+ gen_device = device
101
+ if generator is not None and getattr(generator, "device", None) is not None:
102
+ gen_device = generator.device
103
+ noise = torch.randn(latent_shape, generator=generator, device=gen_device, dtype=torch.float32)
104
+ return noise.to(device)
105
+
106
+ def _encode_text(self, text: Union[str, List[str]], device: torch.device) -> torch.Tensor:
107
+ if self.tokenizer is None or self.text_encoder is None:
108
+ raise ValueError("Both tokenizer and text_encoder must be provided for text-to-image inference.")
109
+
110
+ if isinstance(text, str):
111
+ text = [text]
112
+
113
+ if not self.tokenizer.pad_token:
114
+ self.tokenizer.pad_token = self.tokenizer.eos_token
115
+
116
+ tokens = self.tokenizer(
117
+ text,
118
+ padding="longest",
119
+ truncation=True,
120
+ max_length=self.max_length,
121
+ return_tensors="pt",
122
+ )
123
+ input_ids = tokens.input_ids.to(device)
124
+ attention_mask = tokens.attention_mask.to(device)
125
+
126
+ text_model = getattr(self.text_encoder, "model", self.text_encoder)
127
+ embed_tokens = getattr(text_model, "embed_tokens", None)
128
+ if embed_tokens is None:
129
+ outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
130
+ if hasattr(outputs, "last_hidden_state") and outputs.last_hidden_state is not None:
131
+ return outputs.last_hidden_state
132
+ if hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
133
+ return outputs.hidden_states[-1]
134
+ if isinstance(outputs, (tuple, list)):
135
+ return outputs[0]
136
+ raise ValueError("Unable to extract text hidden states from text_encoder output.")
137
+
138
+ inputs_embeds = embed_tokens(input_ids)
139
+ outputs = text_model(
140
+ input_ids=None,
141
+ attention_mask=attention_mask,
142
+ inputs_embeds=inputs_embeds,
143
+ )
144
+ return outputs.last_hidden_state
145
+
146
+ def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
147
+ if self.vae is None:
148
+ return latents
149
+
150
+ vae = self.vae
151
+ if not hasattr(vae, "bn"):
152
+ decoded = vae.decode(latents)
153
+ return decoded.sample if hasattr(decoded, "sample") else decoded
154
+
155
+ bn = vae.bn.float().eval()
156
+ running_var = bn.running_var.view(1, -1, 1, 1)
157
+ running_mean = bn.running_mean.view(1, -1, 1, 1)
158
+ latents = (latents.float() * torch.sqrt(running_var + bn.eps) + running_mean).to(latents.dtype)
159
+
160
+ patch_size = getattr(vae.config, "patch_size", (2, 2))
161
+ if isinstance(patch_size, int):
162
+ patch_size = (patch_size, patch_size)
163
+ latents = rearrange(
164
+ latents,
165
+ "... (c pi pj) i j -> ... c (i pi) (j pj)",
166
+ pi=patch_size[0],
167
+ pj=patch_size[1],
168
+ )
169
+
170
+ decoded = vae.decode(latents)
171
+ return decoded.sample if hasattr(decoded, "sample") else decoded
172
+
173
+ def _euler_sample(
174
+ self,
175
+ latents: torch.Tensor,
176
+ prompt_embeds: torch.Tensor,
177
+ negative_prompt_embeds: Optional[torch.Tensor],
178
+ num_inference_steps: int,
179
+ guidance_scale: float,
180
+ ) -> torch.Tensor:
181
+ model_dtype = next(self.transformer.parameters()).dtype
182
+ alpha = self.time_shift_alpha
183
+ do_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None
184
+
185
+ latents = latents.to(torch.float32)
186
+ for step_index in range(num_inference_steps, 0, -1):
187
+ t = step_index / num_inference_steps
188
+ t_next = (step_index - 1) / num_inference_steps
189
+ t_shifted = self._shift_time(t, alpha)
190
+ t_next_shifted = self._shift_time(t_next, alpha)
191
+ dt = t_shifted - t_next_shifted
192
+
193
+ model_input = latents.to(dtype=model_dtype)
194
+ if do_cfg:
195
+ velocity_cond = self.transformer(
196
+ model_input,
197
+ encoder_hidden_states=prompt_embeds.to(dtype=model_dtype),
198
+ return_dict=True,
199
+ ).sample
200
+ velocity_uncond = self.transformer(
201
+ model_input,
202
+ encoder_hidden_states=negative_prompt_embeds.to(dtype=model_dtype),
203
+ return_dict=True,
204
+ ).sample
205
+ velocity = velocity_uncond + guidance_scale * (velocity_cond - velocity_uncond)
206
+ else:
207
+ velocity = self.transformer(
208
+ model_input,
209
+ encoder_hidden_states=prompt_embeds.to(dtype=model_dtype),
210
+ return_dict=True,
211
+ ).sample
212
+
213
+ latents = latents + dt * velocity.to(torch.float32)
214
+
215
+ return latents
216
+
217
+ @torch.no_grad()
218
+ def __call__(
219
+ self,
220
+ prompt: Union[str, List[str]],
221
+ negative_prompt: Optional[Union[str, List[str]]] = None,
222
+ height: int = 256,
223
+ width: int = 256,
224
+ num_inference_steps: int = 35,
225
+ guidance_scale: float = 2.0,
226
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
227
+ output_type: str = "pil",
228
+ return_dict: bool = True,
229
+ ) -> Union[MVSplitDiTPipelineOutput, Tuple]:
230
+ """Run denoising with the MVSplit Euler sampler and decode the output."""
231
+ device = self._execution_device
232
+
233
+ if isinstance(prompt, str):
234
+ prompt = [prompt]
235
+ batch_size = len(prompt)
236
+
237
+ prompt_embeds = self._encode_text(prompt, device=device)
238
+ negative_prompt_embeds = None
239
+ if guidance_scale > 1.0:
240
+ if negative_prompt is None:
241
+ negative_prompt = [""] * batch_size
242
+ elif isinstance(negative_prompt, str):
243
+ negative_prompt = [negative_prompt] * batch_size
244
+ elif len(negative_prompt) != batch_size:
245
+ raise ValueError("negative_prompt must have the same batch size as prompt.")
246
+
247
+ # Match mv-split sample.py: encode cond + uncond in one batch so empty
248
+ # prompts pick up padding from the conditional sequence length.
249
+ all_embeds = self._encode_text(list(prompt) + list(negative_prompt), device=device)
250
+ prompt_embeds, negative_prompt_embeds = all_embeds.chunk(2, dim=0)
251
+
252
+ latents = self._prepare_latents(
253
+ batch_size=batch_size,
254
+ height=height,
255
+ width=width,
256
+ device=device,
257
+ generator=generator,
258
+ )
259
+ latents = self._euler_sample(
260
+ latents=latents,
261
+ prompt_embeds=prompt_embeds,
262
+ negative_prompt_embeds=negative_prompt_embeds,
263
+ num_inference_steps=num_inference_steps,
264
+ guidance_scale=guidance_scale,
265
+ )
266
+
267
+ if output_type == "latent":
268
+ image = latents
269
+ else:
270
+ decode_dtype = next(self.vae.parameters()).dtype if self.vae is not None else latents.dtype
271
+ image = self._decode_latents(latents.to(decode_dtype))
272
+ image = image.mul(0.5).add(0.5).clamp(0, 1)
273
+ image = self.image_processor.postprocess(image, output_type=output_type)
274
+
275
+ self.maybe_free_model_hooks()
276
+ if not return_dict:
277
+ return (image,)
278
+ return MVSplitDiTPipelineOutput(images=image)
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
3
+ "_diffusers_version": "0.38.0",
4
+ "base_image_seq_len": 256,
5
+ "base_shift": 0.5,
6
+ "invert_sigmas": false,
7
+ "max_image_seq_len": 4096,
8
+ "max_shift": 1.15,
9
+ "num_train_timesteps": 1000,
10
+ "shift": 4.0,
11
+ "shift_terminal": null,
12
+ "stochastic_sampling": false,
13
+ "time_shift_type": "exponential",
14
+ "use_beta_sigmas": false,
15
+ "use_dynamic_shifting": false,
16
+ "use_exponential_sigmas": false,
17
+ "use_karras_sigmas": false
18
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Qwen3ForCausalLM"
4
+ ],
5
+ "attention_bias": false,
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 151643,
8
+ "eos_token_id": 151645,
9
+ "head_dim": 128,
10
+ "hidden_act": "silu",
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "max_position_embeddings": 40960,
15
+ "max_window_layers": 28,
16
+ "model_type": "qwen3",
17
+ "num_attention_heads": 16,
18
+ "num_hidden_layers": 28,
19
+ "num_key_value_heads": 8,
20
+ "rms_norm_eps": 1e-06,
21
+ "rope_scaling": null,
22
+ "rope_theta": 1000000,
23
+ "sliding_window": null,
24
+ "tie_word_embeddings": true,
25
+ "torch_dtype": "bfloat16",
26
+ "transformers_version": "4.51.0",
27
+ "use_cache": true,
28
+ "use_sliding_window": false,
29
+ "vocab_size": 151936
30
+ }
text_encoder/generation_config.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token_id": 151643,
3
+ "do_sample": true,
4
+ "eos_token_id": [
5
+ 151645,
6
+ 151643
7
+ ],
8
+ "pad_token_id": 151643,
9
+ "temperature": 0.6,
10
+ "top_k": 20,
11
+ "top_p": 0.95,
12
+ "transformers_version": "4.51.0"
13
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f47f71177f32bcd101b7573ec9171e6a57f4f4d31148d38e382306f42996874b
3
+ size 1503300328
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aeb13307a71acd8fe81861d94ad54ab689df773318809eed3cbe794b4492dae4
3
+ size 11422654
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "added_tokens_decoder": {
5
+ "151643": {
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false,
11
+ "special": true
12
+ },
13
+ "151644": {
14
+ "content": "<|im_start|>",
15
+ "lstrip": false,
16
+ "normalized": false,
17
+ "rstrip": false,
18
+ "single_word": false,
19
+ "special": true
20
+ },
21
+ "151645": {
22
+ "content": "<|im_end|>",
23
+ "lstrip": false,
24
+ "normalized": false,
25
+ "rstrip": false,
26
+ "single_word": false,
27
+ "special": true
28
+ },
29
+ "151646": {
30
+ "content": "<|object_ref_start|>",
31
+ "lstrip": false,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false,
35
+ "special": true
36
+ },
37
+ "151647": {
38
+ "content": "<|object_ref_end|>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false,
43
+ "special": true
44
+ },
45
+ "151648": {
46
+ "content": "<|box_start|>",
47
+ "lstrip": false,
48
+ "normalized": false,
49
+ "rstrip": false,
50
+ "single_word": false,
51
+ "special": true
52
+ },
53
+ "151649": {
54
+ "content": "<|box_end|>",
55
+ "lstrip": false,
56
+ "normalized": false,
57
+ "rstrip": false,
58
+ "single_word": false,
59
+ "special": true
60
+ },
61
+ "151650": {
62
+ "content": "<|quad_start|>",
63
+ "lstrip": false,
64
+ "normalized": false,
65
+ "rstrip": false,
66
+ "single_word": false,
67
+ "special": true
68
+ },
69
+ "151651": {
70
+ "content": "<|quad_end|>",
71
+ "lstrip": false,
72
+ "normalized": false,
73
+ "rstrip": false,
74
+ "single_word": false,
75
+ "special": true
76
+ },
77
+ "151652": {
78
+ "content": "<|vision_start|>",
79
+ "lstrip": false,
80
+ "normalized": false,
81
+ "rstrip": false,
82
+ "single_word": false,
83
+ "special": true
84
+ },
85
+ "151653": {
86
+ "content": "<|vision_end|>",
87
+ "lstrip": false,
88
+ "normalized": false,
89
+ "rstrip": false,
90
+ "single_word": false,
91
+ "special": true
92
+ },
93
+ "151654": {
94
+ "content": "<|vision_pad|>",
95
+ "lstrip": false,
96
+ "normalized": false,
97
+ "rstrip": false,
98
+ "single_word": false,
99
+ "special": true
100
+ },
101
+ "151655": {
102
+ "content": "<|image_pad|>",
103
+ "lstrip": false,
104
+ "normalized": false,
105
+ "rstrip": false,
106
+ "single_word": false,
107
+ "special": true
108
+ },
109
+ "151656": {
110
+ "content": "<|video_pad|>",
111
+ "lstrip": false,
112
+ "normalized": false,
113
+ "rstrip": false,
114
+ "single_word": false,
115
+ "special": true
116
+ },
117
+ "151657": {
118
+ "content": "<tool_call>",
119
+ "lstrip": false,
120
+ "normalized": false,
121
+ "rstrip": false,
122
+ "single_word": false,
123
+ "special": false
124
+ },
125
+ "151658": {
126
+ "content": "</tool_call>",
127
+ "lstrip": false,
128
+ "normalized": false,
129
+ "rstrip": false,
130
+ "single_word": false,
131
+ "special": false
132
+ },
133
+ "151659": {
134
+ "content": "<|fim_prefix|>",
135
+ "lstrip": false,
136
+ "normalized": false,
137
+ "rstrip": false,
138
+ "single_word": false,
139
+ "special": false
140
+ },
141
+ "151660": {
142
+ "content": "<|fim_middle|>",
143
+ "lstrip": false,
144
+ "normalized": false,
145
+ "rstrip": false,
146
+ "single_word": false,
147
+ "special": false
148
+ },
149
+ "151661": {
150
+ "content": "<|fim_suffix|>",
151
+ "lstrip": false,
152
+ "normalized": false,
153
+ "rstrip": false,
154
+ "single_word": false,
155
+ "special": false
156
+ },
157
+ "151662": {
158
+ "content": "<|fim_pad|>",
159
+ "lstrip": false,
160
+ "normalized": false,
161
+ "rstrip": false,
162
+ "single_word": false,
163
+ "special": false
164
+ },
165
+ "151663": {
166
+ "content": "<|repo_name|>",
167
+ "lstrip": false,
168
+ "normalized": false,
169
+ "rstrip": false,
170
+ "single_word": false,
171
+ "special": false
172
+ },
173
+ "151664": {
174
+ "content": "<|file_sep|>",
175
+ "lstrip": false,
176
+ "normalized": false,
177
+ "rstrip": false,
178
+ "single_word": false,
179
+ "special": false
180
+ },
181
+ "151665": {
182
+ "content": "<tool_response>",
183
+ "lstrip": false,
184
+ "normalized": false,
185
+ "rstrip": false,
186
+ "single_word": false,
187
+ "special": false
188
+ },
189
+ "151666": {
190
+ "content": "</tool_response>",
191
+ "lstrip": false,
192
+ "normalized": false,
193
+ "rstrip": false,
194
+ "single_word": false,
195
+ "special": false
196
+ },
197
+ "151667": {
198
+ "content": "<think>",
199
+ "lstrip": false,
200
+ "normalized": false,
201
+ "rstrip": false,
202
+ "single_word": false,
203
+ "special": false
204
+ },
205
+ "151668": {
206
+ "content": "</think>",
207
+ "lstrip": false,
208
+ "normalized": false,
209
+ "rstrip": false,
210
+ "single_word": false,
211
+ "special": false
212
+ }
213
+ },
214
+ "additional_special_tokens": [
215
+ "<|im_start|>",
216
+ "<|im_end|>",
217
+ "<|object_ref_start|>",
218
+ "<|object_ref_end|>",
219
+ "<|box_start|>",
220
+ "<|box_end|>",
221
+ "<|quad_start|>",
222
+ "<|quad_end|>",
223
+ "<|vision_start|>",
224
+ "<|vision_end|>",
225
+ "<|vision_pad|>",
226
+ "<|image_pad|>",
227
+ "<|video_pad|>"
228
+ ],
229
+ "bos_token": null,
230
+ "chat_template": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
231
+ "clean_up_tokenization_spaces": false,
232
+ "eos_token": "<|im_end|>",
233
+ "errors": "replace",
234
+ "model_max_length": 131072,
235
+ "pad_token": "<|endoftext|>",
236
+ "split_special_tokens": false,
237
+ "tokenizer_class": "Qwen2Tokenizer",
238
+ "unk_token": null
239
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
transformer/__pycache__/transformer_mvsplit_dit.cpython-312.pyc ADDED
Binary file (21.4 kB). View file
 
transformer/config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MVSplitDiTTransformer2DModel",
3
+ "_diffusers_version": "0.38.0",
4
+ "context_dim": 1024,
5
+ "depth": 1000,
6
+ "hidden_size": 1024,
7
+ "in_channels": 128,
8
+ "init_alpha": 0.0,
9
+ "init_beta": 0.03,
10
+ "mlp_hidden_dim": 3072,
11
+ "norm_eps": 1e-05,
12
+ "num_heads": 8,
13
+ "num_kv_heads": 8,
14
+ "patch_size": 1,
15
+ "qkv_bias": false,
16
+ "rope_base": 10000,
17
+ "trainable_rms": true,
18
+ "use_rope": true,
19
+ "torch_dtype": "bfloat16"
20
+ }
transformer/diffusion_pytorch_model-00001-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ebd66315a82685b17dcd82724bd8cb91c5d92af4cec794ab2afa94ac48c0038
3
+ size 4998288504
transformer/diffusion_pytorch_model-00002-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b19bf5b84b48ae73e88c039809a63eb60d3f3cb74a541abe0fcba71d387e3839
3
+ size 4993827600
transformer/diffusion_pytorch_model-00003-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d3b16f617d0934d373015d9097661d37abb46c382f747eb006e1070d28bbdbb
3
+ size 4991729616
transformer/diffusion_pytorch_model-00004-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95883e73ca3680ba3ecbe9cdb88ea1d7794fb384ccffa47a9646d2e8e4bbef76
3
+ size 4991729616
transformer/diffusion_pytorch_model-00005-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:abb4700307e188f5cfd71b4fc2c1319d22ecee354c1ad56267cc61796a2d0fbe
3
+ size 4991729616
transformer/diffusion_pytorch_model-00006-of-00006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf5e9d723915a3db3e6f84e181c96c1237378f4f6f1aff2230ca542cdf42a5af
3
+ size 2310435160
transformer/diffusion_pytorch_model.safetensors.index.json ADDED
The diff for this file is too large to render. See raw diff
 
transformer/transformer_mvsplit_dit.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import math
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from diffusers.models.activations import SwiGLU
9
+ from diffusers.models.embeddings import PatchEmbed, apply_rotary_emb
10
+ from diffusers.models.normalization import RMSNorm
11
+
12
+ try:
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from diffusers.models.modeling_utils import ModelMixin
15
+ from diffusers.utils import BaseOutput
16
+ except Exception:
17
+ class BaseOutput(dict):
18
+ def __post_init__(self):
19
+ self.update(self.__dict__)
20
+
21
+ class _Config(dict):
22
+ def __getattr__(self, key):
23
+ try:
24
+ return self[key]
25
+ except KeyError as error:
26
+ raise AttributeError(key) from error
27
+
28
+ class ConfigMixin:
29
+ config_name = "config.json"
30
+
31
+ class ModelMixin(nn.Module):
32
+ pass
33
+
34
+ def register_to_config(init):
35
+ def wrapper(self, *args, **kwargs):
36
+ import inspect
37
+
38
+ signature = inspect.signature(init)
39
+ bound = signature.bind(self, *args, **kwargs)
40
+ bound.apply_defaults()
41
+ self.config = _Config({key: value for key, value in bound.arguments.items() if key != "self"})
42
+ init(self, *args, **kwargs)
43
+
44
+ return wrapper
45
+
46
+
47
+ @dataclass
48
+ class MVSplitDiTTransformer2DModelOutput(BaseOutput):
49
+ sample: torch.FloatTensor
50
+
51
+
52
+ class TwoDimRotary(nn.Module):
53
+ def __init__(self, dim: int, base: int = 10000):
54
+ super().__init__()
55
+ inv_freq = 1.0 / (base ** (torch.arange(0, dim, dtype=torch.float32) / max(dim, 1)))
56
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
57
+
58
+ def forward(
59
+ self,
60
+ height: int,
61
+ width: int,
62
+ device: torch.device,
63
+ dtype: torch.dtype,
64
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
65
+ pos_h = torch.arange(height, device=device, dtype=self.inv_freq.dtype)
66
+ pos_w = torch.arange(width, device=device, dtype=self.inv_freq.dtype)
67
+ freqs_h = torch.outer(pos_h, self.inv_freq).unsqueeze(1).repeat(1, width, 1)
68
+ freqs_w = torch.outer(pos_w, self.inv_freq).unsqueeze(0).repeat(height, 1, 1)
69
+ freqs = torch.cat([freqs_h, freqs_w], dim=-1).reshape(height * width, -1)
70
+ cos = freqs.cos().to(dtype=dtype)
71
+ sin = freqs.sin().to(dtype=dtype)
72
+ return cos, sin
73
+
74
+
75
+ class QKNorm(nn.Module):
76
+ def __init__(self, dim: int, eps: float = 1e-6, trainable: bool = False):
77
+ super().__init__()
78
+ self.query_norm = RMSNorm(dim, eps=eps, elementwise_affine=trainable)
79
+ self.key_norm = RMSNorm(dim, eps=eps, elementwise_affine=trainable)
80
+
81
+ def forward(self, query: torch.Tensor, key: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
82
+ return self.query_norm(query), self.key_norm(key)
83
+
84
+
85
+ class FusedMVSplitNorm1(nn.Module):
86
+ def __init__(self, dim: int, eps: float = 1e-5, init_alpha: float = 0.0, init_beta: float = 0.03):
87
+ super().__init__()
88
+ self.eps = eps
89
+ self.alpha = nn.Parameter(torch.full((dim,), init_alpha))
90
+ self.beta = nn.Parameter(torch.full((dim,), init_beta))
91
+ self.weight = nn.Parameter(torch.ones(dim))
92
+
93
+ def _rms_norm(self, hidden_states: torch.Tensor) -> torch.Tensor:
94
+ original_dtype = hidden_states.dtype
95
+ hidden_states = hidden_states.float()
96
+ hidden_states = hidden_states * torch.rsqrt(hidden_states.pow(2).mean(dim=-1, keepdim=True) + self.eps)
97
+ hidden_states = hidden_states * self.weight.float()
98
+ return hidden_states.to(dtype=original_dtype)
99
+
100
+ def forward(
101
+ self,
102
+ residual: torch.Tensor,
103
+ update: torch.Tensor,
104
+ l_image_tokens: Optional[int] = None,
105
+ ) -> torch.Tensor:
106
+ if l_image_tokens is not None and 0 < l_image_tokens < residual.shape[1]:
107
+ residual_img, residual_txt = residual[:, :l_image_tokens], residual[:, l_image_tokens:]
108
+ update_img, update_txt = update[:, :l_image_tokens], update[:, l_image_tokens:]
109
+
110
+ residual_img_mean = residual_img.mean(dim=1, keepdim=True)
111
+ residual_txt_mean = residual_txt.mean(dim=1, keepdim=True)
112
+ update_img_mean = update_img.mean(dim=1, keepdim=True)
113
+ update_txt_mean = update_txt.mean(dim=1, keepdim=True)
114
+
115
+ update_img_var = update_img - update_img_mean
116
+ update_txt_var = update_txt - update_txt_mean
117
+
118
+ alpha = self.alpha.view(1, 1, -1)
119
+ beta = self.beta.view(1, 1, -1)
120
+ var_update = torch.cat([update_img_var * beta, update_txt_var * beta], dim=1)
121
+ mean_update = torch.cat(
122
+ [
123
+ (alpha * (update_img_mean - residual_img_mean)).expand_as(residual_img),
124
+ (alpha * (update_txt_mean - residual_txt_mean)).expand_as(residual_txt),
125
+ ],
126
+ dim=1,
127
+ )
128
+ else:
129
+ residual_mean = residual.mean(dim=1, keepdim=True)
130
+ update_mean = update.mean(dim=1, keepdim=True)
131
+ var_update = self.beta * (update - update_mean)
132
+ mean_update = self.alpha * (update_mean - residual_mean).expand_as(residual)
133
+
134
+ return self._rms_norm(residual + var_update + mean_update)
135
+
136
+
137
+ class Attention(nn.Module):
138
+ def __init__(
139
+ self,
140
+ dim: int,
141
+ num_heads: int,
142
+ num_kv_heads: int,
143
+ qkv_bias: bool,
144
+ trainable_rms: bool,
145
+ ):
146
+ super().__init__()
147
+ if dim % num_heads != 0:
148
+ raise ValueError("dim must be divisible by num_heads.")
149
+
150
+ self.num_heads = num_heads
151
+ self.num_kv_heads = num_kv_heads
152
+ self.head_dim = dim // num_heads
153
+ if self.num_heads % self.num_kv_heads != 0:
154
+ raise ValueError("num_heads must be divisible by num_kv_heads.")
155
+ self.num_groups = self.num_heads // self.num_kv_heads
156
+ kv_dim = self.num_kv_heads * self.head_dim
157
+
158
+ self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
159
+ self.k_proj = nn.Linear(dim, kv_dim, bias=qkv_bias)
160
+ self.v_proj = nn.Linear(dim, kv_dim, bias=qkv_bias)
161
+ self.proj = nn.Linear(dim, dim, bias=False)
162
+ self.qk_norm = QKNorm(self.head_dim, trainable=trainable_rms)
163
+ self.scale = 1.0 / math.sqrt(self.head_dim)
164
+
165
+ def forward(self, hidden_states: torch.Tensor, rope: Optional[Tuple[torch.Tensor, torch.Tensor]]) -> torch.Tensor:
166
+ batch_size, _, _ = hidden_states.shape
167
+ query = self.q_proj(hidden_states).reshape(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
168
+ key = self.k_proj(hidden_states).reshape(batch_size, -1, self.num_kv_heads, self.head_dim).transpose(1, 2)
169
+ value = self.v_proj(hidden_states).reshape(batch_size, -1, self.num_kv_heads, self.head_dim).transpose(1, 2)
170
+
171
+ if rope is not None:
172
+ query = apply_rotary_emb(query, rope)
173
+ key = apply_rotary_emb(key, rope)
174
+ query, key = self.qk_norm(query, key)
175
+
176
+ if self.num_groups > 1:
177
+ key = torch.repeat_interleave(key, self.num_groups, dim=1)
178
+ value = torch.repeat_interleave(value, self.num_groups, dim=1)
179
+
180
+ hidden_states = F.scaled_dot_product_attention(query, key, value, scale=self.scale)
181
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_dim)
182
+ return self.proj(hidden_states)
183
+
184
+
185
+ class DiTBlock(nn.Module):
186
+ def __init__(
187
+ self,
188
+ hidden_size: int,
189
+ num_heads: int,
190
+ num_kv_heads: int,
191
+ mlp_hidden_dim: int,
192
+ qkv_bias: bool,
193
+ trainable_rms: bool,
194
+ norm_eps: float,
195
+ init_alpha: float,
196
+ init_beta: float,
197
+ ):
198
+ super().__init__()
199
+ self.attn = Attention(hidden_size, num_heads, num_kv_heads, qkv_bias=qkv_bias, trainable_rms=trainable_rms)
200
+ self.ffn = nn.Sequential(
201
+ SwiGLU(hidden_size, mlp_hidden_dim, bias=qkv_bias),
202
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=qkv_bias),
203
+ )
204
+ self.norm1 = FusedMVSplitNorm1(hidden_size, eps=norm_eps, init_alpha=init_alpha, init_beta=init_beta)
205
+ self.norm2 = FusedMVSplitNorm1(hidden_size, eps=norm_eps, init_alpha=init_alpha, init_beta=init_beta)
206
+
207
+ def forward(
208
+ self,
209
+ hidden_states: torch.Tensor,
210
+ rope: Optional[Tuple[torch.Tensor, torch.Tensor]],
211
+ l_image_tokens: Optional[int],
212
+ ) -> torch.Tensor:
213
+ residual = hidden_states
214
+ hidden_states = self.attn(hidden_states, rope=rope)
215
+ hidden_states = self.norm1(residual, hidden_states, l_image_tokens=l_image_tokens)
216
+
217
+ residual = hidden_states
218
+ hidden_states = self.ffn(hidden_states)
219
+ hidden_states = self.norm2(residual, hidden_states, l_image_tokens=l_image_tokens)
220
+ return hidden_states
221
+
222
+
223
+ class MVSplitDiTTransformer2DModel(ModelMixin, ConfigMixin):
224
+ config_name = "config.json"
225
+
226
+ @register_to_config
227
+ def __init__(
228
+ self,
229
+ in_channels: int = 128,
230
+ patch_size: int = 1,
231
+ hidden_size: int = 1024,
232
+ depth: int = 1000,
233
+ num_heads: int = 8,
234
+ num_kv_heads: int = 8,
235
+ mlp_hidden_dim: int = 3072,
236
+ context_dim: int = 1024,
237
+ qkv_bias: bool = False,
238
+ trainable_rms: bool = False,
239
+ use_rope: bool = True,
240
+ rope_base: int = 10000,
241
+ norm_eps: float = 1e-5,
242
+ init_alpha: float = 0.0,
243
+ init_beta: float = 0.03,
244
+ ):
245
+ super().__init__()
246
+ self.in_channels = in_channels
247
+ self.out_channels = in_channels
248
+ self.patch_size = patch_size
249
+ self.hidden_size = hidden_size
250
+ self.use_rope = use_rope
251
+ self.rope_dim = hidden_size // (2 * num_heads)
252
+
253
+ self.patch_embed = PatchEmbed(
254
+ height=1,
255
+ width=1,
256
+ patch_size=patch_size,
257
+ in_channels=in_channels,
258
+ embed_dim=hidden_size,
259
+ layer_norm=False,
260
+ flatten=True,
261
+ bias=True,
262
+ pos_embed_type=None,
263
+ )
264
+ self.norm_img_input = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=trainable_rms)
265
+ self.norm_text_input = RMSNorm(hidden_size, eps=norm_eps, elementwise_affine=trainable_rms)
266
+ self.context_proj = nn.Identity() if context_dim == hidden_size else nn.Linear(context_dim, hidden_size, bias=False)
267
+ self.rope = TwoDimRotary(self.rope_dim, base=rope_base) if use_rope else None
268
+
269
+ self.blocks = nn.ModuleList(
270
+ [
271
+ DiTBlock(
272
+ hidden_size=hidden_size,
273
+ num_heads=num_heads,
274
+ num_kv_heads=num_kv_heads,
275
+ mlp_hidden_dim=mlp_hidden_dim,
276
+ qkv_bias=qkv_bias,
277
+ trainable_rms=trainable_rms,
278
+ norm_eps=norm_eps,
279
+ init_alpha=init_alpha,
280
+ init_beta=init_beta,
281
+ )
282
+ for _ in range(depth)
283
+ ]
284
+ )
285
+ self.final_proj = nn.Linear(hidden_size, patch_size * patch_size * self.out_channels, bias=True)
286
+
287
+ def _unpatchify(
288
+ self,
289
+ hidden_states: torch.Tensor,
290
+ batch_size: int,
291
+ height_tokens: int,
292
+ width_tokens: int,
293
+ ) -> torch.Tensor:
294
+ patch = self.patch_size
295
+ hidden_states = hidden_states.reshape(
296
+ batch_size, height_tokens, width_tokens, patch, patch, self.out_channels
297
+ )
298
+ hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4).reshape(
299
+ batch_size, self.out_channels, height_tokens * patch, width_tokens * patch
300
+ )
301
+ return hidden_states
302
+
303
+ def forward(
304
+ self,
305
+ hidden_states: torch.Tensor,
306
+ encoder_hidden_states: torch.Tensor,
307
+ timestep: Optional[Union[torch.Tensor, float]] = None,
308
+ return_dict: bool = True,
309
+ ) -> Union[MVSplitDiTTransformer2DModelOutput, Tuple[torch.Tensor]]:
310
+ del timestep
311
+ if hidden_states.ndim != 4:
312
+ raise ValueError("hidden_states must have shape [B, C, H, W].")
313
+ if encoder_hidden_states.ndim != 3:
314
+ raise ValueError("encoder_hidden_states must have shape [B, L_text, context_dim].")
315
+
316
+ batch_size, channels, height, width = hidden_states.shape
317
+ if channels != self.in_channels:
318
+ raise ValueError(f"Expected {self.in_channels} latent channels, got {channels}.")
319
+ if height % self.patch_size != 0 or width % self.patch_size != 0:
320
+ raise ValueError("Latent height and width must be divisible by patch_size.")
321
+
322
+ height_tokens = height // self.patch_size
323
+ width_tokens = width // self.patch_size
324
+ image_tokens = self.norm_img_input(self.patch_embed(hidden_states))
325
+ l_image_tokens = image_tokens.shape[1]
326
+
327
+ text_tokens = self.norm_text_input(self.context_proj(encoder_hidden_states))
328
+ sequence = torch.cat([image_tokens, text_tokens], dim=1)
329
+
330
+ rope = None
331
+ if self.use_rope and self.rope is not None:
332
+ cos_image, sin_image = self.rope(height_tokens, width_tokens, sequence.device, sequence.dtype)
333
+ text_length = text_tokens.shape[1]
334
+ rope_width = cos_image.shape[-1]
335
+ if text_length > 0:
336
+ cos_text = torch.ones((text_length, rope_width), device=sequence.device, dtype=sequence.dtype)
337
+ sin_text = torch.zeros((text_length, rope_width), device=sequence.device, dtype=sequence.dtype)
338
+ rope = (torch.cat([cos_image, cos_text], dim=0), torch.cat([sin_image, sin_text], dim=0))
339
+ else:
340
+ rope = (cos_image, sin_image)
341
+
342
+ for block in self.blocks:
343
+ sequence = block(sequence, rope=rope, l_image_tokens=l_image_tokens)
344
+
345
+ sequence = self.final_proj(sequence[:, :l_image_tokens, :])
346
+ sequence = self._unpatchify(sequence, batch_size=batch_size, height_tokens=height_tokens, width_tokens=width_tokens)
347
+
348
+ if not return_dict:
349
+ return (sequence,)
350
+ return MVSplitDiTTransformer2DModelOutput(sample=sequence)
vae/config.json ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKLFlux2",
3
+ "_diffusers_version": "0.37.0.dev0",
4
+ "_name_or_path": "black-forest-labs/FLUX.2-dev",
5
+ "act_fn": "silu",
6
+ "batch_norm_eps": 0.0001,
7
+ "batch_norm_momentum": 0.1,
8
+ "block_out_channels": [
9
+ 128,
10
+ 256,
11
+ 512,
12
+ 512
13
+ ],
14
+ "down_block_types": [
15
+ "DownEncoderBlock2D",
16
+ "DownEncoderBlock2D",
17
+ "DownEncoderBlock2D",
18
+ "DownEncoderBlock2D"
19
+ ],
20
+ "force_upcast": true,
21
+ "in_channels": 3,
22
+ "latent_channels": 32,
23
+ "layers_per_block": 2,
24
+ "mid_block_add_attention": true,
25
+ "norm_num_groups": 32,
26
+ "out_channels": 3,
27
+ "patch_size": [
28
+ 2,
29
+ 2
30
+ ],
31
+ "sample_size": 1024,
32
+ "up_block_types": [
33
+ "UpDecoderBlock2D",
34
+ "UpDecoderBlock2D",
35
+ "UpDecoderBlock2D",
36
+ "UpDecoderBlock2D"
37
+ ],
38
+ "use_post_quant_conv": true,
39
+ "use_quant_conv": true
40
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca70d2202afe6415bdbcb8793ba8cd99fd159cfe6192381504d6c4d3036e0f04
3
+ size 168120878