BiliSakura committed on
Commit
8a4456c
·
verified ·
1 Parent(s): 2d5c526

Delete pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +0 -278
pipeline.py DELETED
@@ -1,278 +0,0 @@
1
- """Hub custom pipeline: MVSplitDiTPipeline.
2
- Load with native Hugging Face diffusers and trust_remote_code=True.
3
- """
4
-
5
- from __future__ import annotations
6
-
7
- from dataclasses import dataclass
8
- from typing import List, Optional, Tuple, Union
9
-
10
- import torch
11
- from einops import rearrange
12
-
13
try:
    from diffusers.image_processor import VaeImageProcessor
    from diffusers.pipelines.pipeline_utils import DiffusionPipeline
    from diffusers.utils import BaseOutput
except Exception:
    # diffusers could not be imported: define minimal stand-ins exposing only
    # the attributes this module touches, so the file can still be imported.

    class BaseOutput(dict):
        """Dict-backed stand-in for ``diffusers.utils.BaseOutput``."""

        def __post_init__(self):
            # Mirror the dataclass fields into the mapping so they are
            # reachable by key as well as by attribute.
            self.update(vars(self))

    class DiffusionPipeline:
        """Stand-in exposing the pipeline hooks used by this module."""

        def register_modules(self, **kwargs):
            """Attach every keyword argument to the instance as an attribute."""
            for attr_name, component in kwargs.items():
                setattr(self, attr_name, component)

        @property
        def _execution_device(self):
            """Device the stand-in reports for execution (always CPU)."""
            return torch.device("cpu")

        def maybe_free_model_hooks(self):
            """No-op: the stand-in installs no offload hooks to release."""
            pass

    class VaeImageProcessor:
        """Stand-in whose post-processing hands the input back untouched."""

        def postprocess(self, image, output_type="pil"):
            """Return ``image`` unchanged; ``output_type`` is ignored."""
            return image
38
# DiT operates on packed FLUX2 latents at 1/16 of the image resolution.
# The pipeline requires the requested image height and width to be divisible
# by this factor and sizes the latent grid as (height // factor, width // factor).
LATENT_DOWNSAMPLE_FACTOR = 16
40
-
41
-
42
@dataclass
class MVSplitDiTPipelineOutput(BaseOutput):
    """Output container returned by ``MVSplitDiTPipeline.__call__``.

    Holds the single ``images`` field produced at the end of sampling.
    """

    # Result of the run: the raw latents when ``output_type == "latent"``,
    # otherwise whatever the image processor's ``postprocess`` returned.
    images: Union[torch.FloatTensor, List]
45
-
46
-
47
class MVSplitDiTPipeline(DiffusionPipeline):
    """
    Text-to-image pipeline for MVSplit DiT.

    Sampling follows the official mv-split Euler ODE integrator with time-shift
    (see https://github.com/erwold/mv-split sample.py).

    Args:
        transformer: Denoising model. It is called as
            ``transformer(latents, encoder_hidden_states=..., return_dict=True)``
            and must return an object with a ``.sample`` attribute. Its
            ``config.in_channels`` sets the latent channel count.
        scheduler: Registered as a module but not used by the sampler here.
        vae: Optional decoder. When ``None`` the latents are returned undecoded.
        text_encoder: Optional text encoder (required to actually run inference).
        tokenizer: Optional tokenizer (required to actually run inference).
        max_length: Truncation length passed to the tokenizer.
        time_shift_alpha: ``alpha`` of the time-shift applied to the step grid.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _optional_components = ["vae", "text_encoder", "tokenizer"]

    def __init__(
        self,
        transformer,
        scheduler=None,
        vae=None,
        text_encoder=None,
        tokenizer=None,
        max_length: int = 256,
        time_shift_alpha: float = 4.0,
    ):
        super().__init__()
        self.register_modules(
            transformer=transformer,
            scheduler=scheduler,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
        )
        self.max_length = max_length
        self.time_shift_alpha = time_shift_alpha
        self.image_processor = VaeImageProcessor()

    @staticmethod
    def _shift_time(t: float, alpha: float) -> float:
        """Warp a time ``t`` with the shift ``t * alpha / (1 + (alpha - 1) * t)``.

        The endpoints ``t = 0`` and ``t = 1`` are fixed points of the map, and
        ``alpha = 1`` leaves ``t`` unchanged.
        """
        return t * alpha / (1.0 + (alpha - 1.0) * t)

    def _prepare_latents(
        self,
        batch_size: int,
        height: int,
        width: int,
        device: torch.device,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]],
    ) -> torch.Tensor:
        """Draw the initial float32 Gaussian noise for the sampler.

        Args:
            batch_size: Number of samples to draw.
            height: Requested image height in pixels.
            width: Requested image width in pixels.
            device: Device the returned tensor is placed on.
            generator: ``None``, a single generator shared by the whole batch,
                or a list with one generator per sample (a one-element list is
                treated like a single generator).

        Returns:
            Noise of shape ``(batch_size, in_channels, height // 16, width // 16)``
            on ``device``.

        Raises:
            ValueError: If ``height``/``width`` are not divisible by
                ``LATENT_DOWNSAMPLE_FACTOR`` or the generator list length does
                not match ``batch_size``.
        """
        if height % LATENT_DOWNSAMPLE_FACTOR != 0 or width % LATENT_DOWNSAMPLE_FACTOR != 0:
            raise ValueError(
                f"height and width must be divisible by {LATENT_DOWNSAMPLE_FACTOR}."
            )

        latent_height = height // LATENT_DOWNSAMPLE_FACTOR
        latent_width = width // LATENT_DOWNSAMPLE_FACTOR
        latent_shape = (batch_size, self.transformer.config.in_channels, latent_height, latent_width)

        def _draw(shape, gen):
            # Sample on the generator's own device (torch.randn requires the
            # two to match), then move the result to the execution device.
            gen_device = device
            if gen is not None and getattr(gen, "device", None) is not None:
                gen_device = gen.device
            noise = torch.randn(shape, generator=gen, device=gen_device, dtype=torch.float32)
            return noise.to(device)

        if isinstance(generator, (list, tuple)):
            if len(generator) == 1:
                generator = generator[0]
            elif len(generator) != batch_size:
                raise ValueError(
                    f"Got {len(generator)} generators for a batch of size {batch_size}; "
                    "pass a single generator or exactly one per sample."
                )
            else:
                # torch.randn accepts only one generator, so draw each sample
                # with its own generator and stack them along the batch axis.
                sample_shape = (1,) + latent_shape[1:]
                return torch.cat([_draw(sample_shape, gen) for gen in generator], dim=0)

        return _draw(latent_shape, generator)

    def _encode_text(self, text: Union[str, List[str]], device: torch.device) -> torch.Tensor:
        """Tokenize ``text`` and return the text encoder's final hidden states.

        Prompts are padded to the longest sequence in the batch and truncated
        to ``self.max_length``. If the tokenizer has no pad token, its EOS token
        is installed as the pad token (this mutates the tokenizer).

        Raises:
            ValueError: If the tokenizer or text encoder is missing, or no
                hidden states can be extracted from the encoder output.
        """
        if self.tokenizer is None or self.text_encoder is None:
            raise ValueError("Both tokenizer and text_encoder must be provided for text-to-image inference.")

        if isinstance(text, str):
            text = [text]

        if not self.tokenizer.pad_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        tokens = self.tokenizer(
            text,
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)

        # Prefer the inner ``.model`` when the encoder wraps one.
        text_model = getattr(self.text_encoder, "model", self.text_encoder)
        embed_tokens = getattr(text_model, "embed_tokens", None)
        if embed_tokens is None:
            # No embedding table exposed: call the encoder directly and take
            # the first usable hidden-state field from its output.
            outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
            if hasattr(outputs, "last_hidden_state") and outputs.last_hidden_state is not None:
                return outputs.last_hidden_state
            if hasattr(outputs, "hidden_states") and outputs.hidden_states is not None:
                return outputs.hidden_states[-1]
            if isinstance(outputs, (tuple, list)):
                return outputs[0]
            raise ValueError("Unable to extract text hidden states from text_encoder output.")

        # Embed the ids explicitly and feed them through ``inputs_embeds``.
        inputs_embeds = embed_tokens(input_ids)
        outputs = text_model(
            input_ids=None,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
        )
        return outputs.last_hidden_state

    def _decode_latents(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents with the VAE, or return them as-is when there is none.

        A VAE exposing a ``bn`` attribute gets two extra steps before decoding:
        the latents are de-normalized with ``bn``'s running statistics, then
        unpacked from ``(c * pi * pj, i, j)`` to ``(c, i * pi, j * pj)`` using
        the VAE's ``patch_size`` (default ``(2, 2)``).
        """
        if self.vae is None:
            return latents

        vae = self.vae
        if hasattr(vae, "bn"):
            bn = vae.bn
            # Take float32 copies of the statistics on the latents' device
            # instead of casting the module in place: the VAE keeps its own
            # dtype/mode, and buffers left on another device (e.g. by offload
            # hooks) do not trigger a device-mismatch error.
            running_var = bn.running_var.to(device=latents.device, dtype=torch.float32).view(1, -1, 1, 1)
            running_mean = bn.running_mean.to(device=latents.device, dtype=torch.float32).view(1, -1, 1, 1)
            latents = (latents.float() * torch.sqrt(running_var + bn.eps) + running_mean).to(latents.dtype)

            patch_size = getattr(vae.config, "patch_size", (2, 2))
            if isinstance(patch_size, int):
                patch_size = (patch_size, patch_size)
            latents = rearrange(
                latents,
                "... (c pi pj) i j -> ... c (i pi) (j pj)",
                pi=patch_size[0],
                pj=patch_size[1],
            )

        decoded = vae.decode(latents)
        return decoded.sample if hasattr(decoded, "sample") else decoded

    def _euler_sample(
        self,
        latents: torch.Tensor,
        prompt_embeds: torch.Tensor,
        negative_prompt_embeds: Optional[torch.Tensor],
        num_inference_steps: int,
        guidance_scale: float,
    ) -> torch.Tensor:
        """Integrate the latents with explicit Euler steps on a shifted time grid.

        The uniform grid ``t = 1, (N-1)/N, ..., 1/N`` is warped by
        ``_shift_time``; each step adds ``dt * velocity`` where ``dt`` is the
        shifted step width. Classifier-free guidance is applied only when
        ``guidance_scale > 1`` and negative embeddings are supplied. The state
        is accumulated in float32; the transformer runs in its parameter dtype.
        """
        model_dtype = next(self.transformer.parameters()).dtype
        alpha = self.time_shift_alpha
        do_cfg = guidance_scale > 1.0 and negative_prompt_embeds is not None

        # The embeddings do not change between steps, so cast them once.
        cond_embeds = prompt_embeds.to(dtype=model_dtype)
        uncond_embeds = negative_prompt_embeds.to(dtype=model_dtype) if do_cfg else None

        latents = latents.to(torch.float32)
        for step_index in range(num_inference_steps, 0, -1):
            t = step_index / num_inference_steps
            t_next = (step_index - 1) / num_inference_steps
            dt = self._shift_time(t, alpha) - self._shift_time(t_next, alpha)

            model_input = latents.to(dtype=model_dtype)
            velocity = self.transformer(
                model_input,
                encoder_hidden_states=cond_embeds,
                return_dict=True,
            ).sample
            if do_cfg:
                velocity_uncond = self.transformer(
                    model_input,
                    encoder_hidden_states=uncond_embeds,
                    return_dict=True,
                ).sample
                velocity = velocity_uncond + guidance_scale * (velocity - velocity_uncond)

            latents = latents + dt * velocity.to(torch.float32)

        return latents

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 35,
        guidance_scale: float = 2.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        output_type: str = "pil",
        return_dict: bool = True,
    ) -> Union[MVSplitDiTPipelineOutput, Tuple]:
        """Run denoising with the MVSplit Euler sampler and decode the output.

        Args:
            prompt: One prompt or a list of prompts; its length is the batch size.
            negative_prompt: Used only when ``guidance_scale > 1``. ``None``
                means an empty string per prompt; a single string is repeated
                for every prompt.
            height: Image height in pixels (must be divisible by 16).
            width: Image width in pixels (must be divisible by 16).
            num_inference_steps: Number of Euler steps.
            guidance_scale: Classifier-free guidance weight; values ``<= 1``
                disable guidance.
            generator: A single generator, or one generator per prompt.
            output_type: ``"latent"`` returns the raw latents; any other value
                is forwarded to the image processor after decoding.
            return_dict: Return a ``MVSplitDiTPipelineOutput`` when true,
                otherwise a one-element tuple.

        Raises:
            ValueError: If ``negative_prompt`` is a list whose length differs
                from the batch size (plus any error raised by the helpers).
        """
        device = self._execution_device

        if isinstance(prompt, str):
            prompt = [prompt]
        batch_size = len(prompt)

        if guidance_scale > 1.0:
            if negative_prompt is None:
                negative_prompt = [""] * batch_size
            elif isinstance(negative_prompt, str):
                negative_prompt = [negative_prompt] * batch_size
            elif len(negative_prompt) != batch_size:
                raise ValueError("negative_prompt must have the same batch size as prompt.")

            # Match mv-split sample.py: encode cond + uncond in one batch so empty
            # prompts pick up padding from the conditional sequence length.
            # This single pass also supplies the conditional embeddings, so
            # the prompts are not encoded a second time on their own.
            all_embeds = self._encode_text(list(prompt) + list(negative_prompt), device=device)
            prompt_embeds, negative_prompt_embeds = all_embeds.chunk(2, dim=0)
        else:
            prompt_embeds = self._encode_text(prompt, device=device)
            negative_prompt_embeds = None

        latents = self._prepare_latents(
            batch_size=batch_size,
            height=height,
            width=width,
            device=device,
            generator=generator,
        )
        latents = self._euler_sample(
            latents=latents,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        )

        if output_type == "latent":
            image = latents
        else:
            decode_dtype = next(self.vae.parameters()).dtype if self.vae is not None else latents.dtype
            image = self._decode_latents(latents.to(decode_dtype))
            # Affine map x -> 0.5 * x + 0.5, then clamp to [0, 1].
            image = image.mul(0.5).add(0.5).clamp(0, 1)
            image = self.image_processor.postprocess(image, output_type=output_type)

        self.maybe_free_model_hooks()
        if not return_dict:
            return (image,)
        return MVSplitDiTPipelineOutput(images=image)