BiliSakura committed on
Commit
a0ae8da
·
verified ·
1 Parent(s): 6025bd2

Update iMF-L-2/pipeline.py

Browse files
Files changed (1) hide show
  1. iMF-L-2/pipeline.py +49 -84
iMF-L-2/pipeline.py CHANGED
@@ -1,79 +1,47 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
  """Hub custom pipeline: IMFPipeline.
16
-
17
  Load with native Hugging Face diffusers and trust_remote_code=True.
18
  """
19
 
20
  from __future__ import annotations
21
 
 
 
22
  import json
23
  from pathlib import Path
24
- from typing import Dict, List, Optional, Tuple, Union
25
 
26
  import torch
27
- from diffusers.image_processor import VaeImageProcessor
28
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
29
- from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
30
  from diffusers.utils.torch_utils import randn_tensor
31
 
32
-
33
- def _set_imf_timesteps(
34
- scheduler: FlowMatchEulerDiscreteScheduler,
35
- num_inference_steps: int,
36
- device: torch.device,
37
- ) -> torch.Tensor:
38
- flow_sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1, device=device, dtype=torch.float32)
39
- scheduler.set_timesteps(sigmas=flow_sigmas.tolist(), device=device)
40
- return flow_sigmas
41
-
42
-
43
  class IMFPipeline(DiffusionPipeline):
44
- r"""
45
- Pipeline for ImageNet class-conditional generation with Improved Mean Flows (iMF).
46
 
47
- Parameters:
48
- transformer ([`IMFTransformer2DModel`]):
49
- Class-conditioned iMF transformer that predicts mean-flow velocity.
50
- scheduler ([`FlowMatchEulerDiscreteScheduler`]):
51
- Flow-matching Euler scheduler.
52
- vae ([`AutoencoderKL`]):
53
- Variational autoencoder used to decode transformer latents to pixels.
54
- id2label (`dict[int, str]`, *optional*):
55
- ImageNet class id to English label mapping.
56
- """
 
 
 
57
 
58
- model_cpu_offload_seq = "transformer->vae"
59
 
60
  def __init__(
61
  self,
62
  transformer,
63
  scheduler,
64
- vae,
65
  id2label: Optional[Dict[Union[int, str], str]] = None,
66
  ):
67
  super().__init__()
68
  if scheduler is None:
69
- scheduler = FlowMatchEulerDiscreteScheduler(
70
- num_train_timesteps=1000,
71
- shift=1.0,
72
- stochastic_sampling=False,
73
- )
74
- self.register_modules(transformer=transformer, scheduler=scheduler, vae=vae)
75
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
76
- self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
77
  self._id2label = self._normalize_id2label(id2label)
78
  self.labels = self._build_label2id(self._id2label)
79
  self._labels_loaded_from_model_index = bool(self._id2label)
@@ -125,12 +93,13 @@ class IMFPipeline(DiffusionPipeline):
125
  self._ensure_labels_loaded()
126
  if not self.labels:
127
  raise ValueError("No labels loaded. Ensure `id2label` exists in model_index.json.")
128
- labels = [label] if isinstance(label, str) else label
129
- missing = [item for item in labels if item not in self.labels]
 
130
  if missing:
131
  preview = ", ".join(list(self.labels.keys())[:8])
132
  raise ValueError(f"Unknown label(s): {missing}. Example valid labels: {preview}, ...")
133
- return [self.labels[item] for item in labels]
134
 
135
  def _normalize_class_labels(self, class_labels: Union[int, str, List[Union[int, str]]]) -> List[int]:
136
  if isinstance(class_labels, int):
@@ -153,16 +122,12 @@ class IMFPipeline(DiffusionPipeline):
153
  guidance_interval_end: float,
154
  do_classifier_free_guidance: bool,
155
  ) -> torch.Tensor:
156
- dtype = latents.dtype
157
- timestep = timestep.to(device=latents.device, dtype=dtype)
158
- time_gap = time_gap.to(device=latents.device, dtype=dtype)
159
-
160
  if do_classifier_free_guidance:
161
  latents_in = torch.cat([latents, latents], dim=0)
162
  labels = torch.cat([class_labels, class_null], dim=0)
163
- omega = torch.tensor([guidance_scale, 1.0], device=latents.device, dtype=dtype)
164
- t_min = torch.tensor([guidance_interval_start, 0.0], device=latents.device, dtype=dtype)
165
- t_max = torch.tensor([guidance_interval_end, 1.0], device=latents.device, dtype=dtype)
166
  batch = latents.shape[0]
167
  timestep_in = timestep.reshape(1).repeat(2 * batch)
168
  time_gap_in = time_gap.reshape(1).repeat(2 * batch)
@@ -175,9 +140,9 @@ class IMFPipeline(DiffusionPipeline):
175
  batch = latents.shape[0]
176
  timestep_in = timestep.reshape(1).repeat(batch)
177
  time_gap_in = time_gap.reshape(1).repeat(batch)
178
- omega = torch.full((batch,), guidance_scale, device=latents.device, dtype=dtype)
179
- t_min = torch.full((batch,), guidance_interval_start, device=latents.device, dtype=dtype)
180
- t_max = torch.full((batch,), guidance_interval_end, device=latents.device, dtype=dtype)
181
 
182
  outputs = self.transformer(
183
  sample=latents_in,
@@ -197,17 +162,6 @@ class IMFPipeline(DiffusionPipeline):
197
  u_cond, u_uncond = velocity_u.chunk(2, dim=0)
198
  return u_uncond + guidance_scale * (u_cond - u_uncond)
199
 
200
def decode_latents(self, latents: torch.Tensor, output_type: str = "pil"):
    """Turn latents into the requested output representation.

    Args:
        latents: Latent tensor to convert.
        output_type: ``"latent"`` returns ``latents`` untouched; ``"pt"``
            returns the raw tensor produced by the VAE decoder; any other
            value is forwarded to ``self.image_processor.postprocess``.

    Returns:
        The input latents, the decoded tensor, or whatever
        ``self.image_processor.postprocess`` returns for ``output_type``.
    """
    # Latent output needs no VAE work at all.
    if output_type == "latent":
        return latents

    vae = self.vae
    # Move the latents to the VAE's device/dtype before decoding.
    vae_input = latents.to(device=vae.device, dtype=vae.dtype)
    # Divide by the VAE's configured scaling factor before decoding.
    decoded = vae.decode(vae_input / vae.config.scaling_factor).sample

    if output_type == "pt":
        return decoded
    return self.image_processor.postprocess(decoded, output_type=output_type)
210
-
211
  @torch.inference_mode()
212
  def __call__(
213
  self,
@@ -218,7 +172,7 @@ class IMFPipeline(DiffusionPipeline):
218
  guidance_interval_end: float = 0.9,
219
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
220
  latents: Optional[torch.Tensor] = None,
221
- output_type: str = "pil",
222
  return_dict: bool = True,
223
  ) -> Union[ImagePipelineOutput, Tuple]:
224
  if output_type not in {"pil", "np", "pt", "latent"}:
@@ -230,7 +184,7 @@ class IMFPipeline(DiffusionPipeline):
230
 
231
  image_size = int(self.transformer.config.sample_size)
232
  channels = int(self.transformer.config.in_channels)
233
- null_class_val = int(self.transformer.config.num_classes)
234
 
235
  if latents is None:
236
  latents = randn_tensor(
@@ -244,11 +198,14 @@ class IMFPipeline(DiffusionPipeline):
244
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
245
  class_null = torch.full_like(class_labels_t, null_class_val)
246
 
247
- flow_sigmas = _set_imf_timesteps(self.scheduler, num_inference_steps, latents.device)
 
 
 
248
 
249
  for i in self.progress_bar(range(num_inference_steps)):
250
- t = flow_sigmas[i]
251
- t_next = flow_sigmas[i + 1]
252
  time_gap = t - t_next
253
  velocity_u = self._predict_velocity_u(
254
  latents,
@@ -261,9 +218,18 @@ class IMFPipeline(DiffusionPipeline):
261
  guidance_interval_end,
262
  do_classifier_free_guidance,
263
  )
264
- latents = self.scheduler.step(velocity_u, self.scheduler.timesteps[i], latents).prev_sample
265
 
266
- images = self.decode_latents(latents, output_type=output_type)
 
 
 
 
 
 
 
 
 
267
 
268
  self.maybe_free_model_hooks()
269
 
@@ -271,5 +237,4 @@ class IMFPipeline(DiffusionPipeline):
271
  return (images,)
272
  return ImagePipelineOutput(images=images)
273
 
274
-
275
- IMFPipelineOutput = ImagePipelineOutput
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Hub custom pipeline: IMFPipeline.
 
2
  Load with native Hugging Face diffusers and trust_remote_code=True.
3
  """
4
 
5
  from __future__ import annotations
6
 
7
+ import inspect
8
+
9
  import json
10
  from pathlib import Path
11
+ from typing import Dict, List, Optional, Tuple, Union, Any
12
 
13
  import torch
 
14
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
15
  from diffusers.utils.torch_utils import randn_tensor
16
 
 
 
 
 
 
 
 
 
 
 
 
17
  class IMFPipeline(DiffusionPipeline):
 
 
18
 
19
@staticmethod
def prepare_extra_step_kwargs(
    scheduler,
    generator=None,
    eta: Optional[float] = None,
):
    """Collect the optional keyword arguments that ``scheduler.step`` accepts.

    The signature of ``scheduler.step`` is inspected so that only parameters
    it actually declares are forwarded.

    Args:
        scheduler: Object exposing a ``step`` callable.
        generator: Value forwarded under ``"generator"`` whenever ``step``
            declares that parameter (including when it is ``None``).
        eta: Value forwarded under ``"eta"`` only when it is not ``None``
            and ``step`` declares that parameter.

    Returns:
        A dict containing at most the keys ``"generator"`` and ``"eta"``.
    """
    # The parameters mapping supports ``in`` directly; no set copy is needed.
    accepted = inspect.signature(scheduler.step).parameters
    extra_kwargs = {}
    if "generator" in accepted:
        extra_kwargs["generator"] = generator
    if eta is not None and "eta" in accepted:
        extra_kwargs["eta"] = eta
    return extra_kwargs
32
 
33
+ model_cpu_offload_seq = "transformer"
34
 
35
  def __init__(
36
  self,
37
  transformer,
38
  scheduler,
 
39
  id2label: Optional[Dict[Union[int, str], str]] = None,
40
  ):
41
  super().__init__()
42
  if scheduler is None:
43
+ raise ValueError("IMFPipeline requires a scheduler loaded from the checkpoint.")
44
+ self.register_modules(transformer=transformer, scheduler=scheduler)
 
 
 
 
 
 
45
  self._id2label = self._normalize_id2label(id2label)
46
  self.labels = self._build_label2id(self._id2label)
47
  self._labels_loaded_from_model_index = bool(self._id2label)
 
93
  self._ensure_labels_loaded()
94
  if not self.labels:
95
  raise ValueError("No labels loaded. Ensure `id2label` exists in model_index.json.")
96
+ if isinstance(label, str):
97
+ label = [label]
98
+ missing = [item for item in label if item not in self.labels]
99
  if missing:
100
  preview = ", ".join(list(self.labels.keys())[:8])
101
  raise ValueError(f"Unknown label(s): {missing}. Example valid labels: {preview}, ...")
102
+ return [self.labels[item] for item in label]
103
 
104
  def _normalize_class_labels(self, class_labels: Union[int, str, List[Union[int, str]]]) -> List[int]:
105
  if isinstance(class_labels, int):
 
122
  guidance_interval_end: float,
123
  do_classifier_free_guidance: bool,
124
  ) -> torch.Tensor:
 
 
 
 
125
  if do_classifier_free_guidance:
126
  latents_in = torch.cat([latents, latents], dim=0)
127
  labels = torch.cat([class_labels, class_null], dim=0)
128
+ omega = torch.tensor([guidance_scale, 1.0], device=latents.device, dtype=latents.dtype)
129
+ t_min = torch.tensor([guidance_interval_start, 0.0], device=latents.device, dtype=latents.dtype)
130
+ t_max = torch.tensor([guidance_interval_end, 1.0], device=latents.device, dtype=latents.dtype)
131
  batch = latents.shape[0]
132
  timestep_in = timestep.reshape(1).repeat(2 * batch)
133
  time_gap_in = time_gap.reshape(1).repeat(2 * batch)
 
140
  batch = latents.shape[0]
141
  timestep_in = timestep.reshape(1).repeat(batch)
142
  time_gap_in = time_gap.reshape(1).repeat(batch)
143
+ omega = torch.full((batch,), guidance_scale, device=latents.device, dtype=latents.dtype)
144
+ t_min = torch.full((batch,), guidance_interval_start, device=latents.device, dtype=latents.dtype)
145
+ t_max = torch.full((batch,), guidance_interval_end, device=latents.device, dtype=latents.dtype)
146
 
147
  outputs = self.transformer(
148
  sample=latents_in,
 
162
  u_cond, u_uncond = velocity_u.chunk(2, dim=0)
163
  return u_uncond + guidance_scale * (u_cond - u_uncond)
164
 
 
 
 
 
 
 
 
 
 
 
 
165
  @torch.inference_mode()
166
  def __call__(
167
  self,
 
172
  guidance_interval_end: float = 0.9,
173
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
174
  latents: Optional[torch.Tensor] = None,
175
+ output_type: Optional[str] = "pil",
176
  return_dict: bool = True,
177
  ) -> Union[ImagePipelineOutput, Tuple]:
178
  if output_type not in {"pil", "np", "pt", "latent"}:
 
184
 
185
  image_size = int(self.transformer.config.sample_size)
186
  channels = int(self.transformer.config.in_channels)
187
+ null_class_val = int(getattr(self.transformer.config, "num_classes", 1000))
188
 
189
  if latents is None:
190
  latents = randn_tensor(
 
198
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
199
  class_null = torch.full_like(class_labels_t, null_class_val)
200
 
201
+ self.scheduler.set_timesteps(num_inference_steps, device=latents.device)
202
+ timesteps = self.scheduler.timesteps
203
+
204
+ extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
205
 
206
  for i in self.progress_bar(range(num_inference_steps)):
207
+ t = timesteps[i]
208
+ t_next = timesteps[i + 1]
209
  time_gap = t - t_next
210
  velocity_u = self._predict_velocity_u(
211
  latents,
 
218
  guidance_interval_end,
219
  do_classifier_free_guidance,
220
  )
221
+ latents = self.scheduler.step(velocity_u, t, latents, **extra_step_kwargs).prev_sample
222
 
223
+ if output_type == "latent":
224
+ images = latents
225
+ else:
226
+ images_pt = latents.float().clamp(-4, 4)
227
+ if output_type == "pt":
228
+ images = images_pt
229
+ elif output_type == "np":
230
+ images = images_pt.cpu().permute(0, 2, 3, 1).numpy()
231
+ else:
232
+ images = self.numpy_to_pil(images_pt.cpu().permute(0, 2, 3, 1).numpy())
233
 
234
  self.maybe_free_model_hooks()
235
 
 
237
  return (images,)
238
  return ImagePipelineOutput(images=images)
239
 
240
+ IMFPipelineOutput = ImagePipelineOutput