BiliSakura commited on
Commit
55bc029
·
verified ·
1 Parent(s): 5673750

Fix generator determinism: forward generator through scheduler steps and seeded noise

Browse files
JiT-B-16/pipeline.py CHANGED
@@ -1,36 +1,24 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import importlib
16
  import json
17
- import sys
18
  from pathlib import Path
19
- from typing import Dict, List, Optional, Tuple, Union
20
 
21
  import torch
22
-
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
- from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
-
28
  RECOMMENDED_NOISE_BY_SIZE = {
29
  256: 1.0,
30
  512: 2.0,
31
  }
32
 
33
-
34
  class JiTPipeline(DiffusionPipeline):
35
  r"""
36
  Pipeline for image generation using JiT (Just image Transformer).
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
45
  """
46
 
47
- model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- @classmethod
50
- def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
51
- """Load a self-contained variant folder locally or from the Hub.
52
-
53
- Examples:
54
- JiTPipeline.from_pretrained(".")
55
- JiTPipeline.from_pretrained("./JiT-H-32")
56
- DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
57
- """
58
- repo_root = Path(__file__).resolve().parent
59
-
60
- if pretrained_model_name_or_path in (None, "", "."):
61
- variant = repo_root
62
- elif (
63
- isinstance(pretrained_model_name_or_path, str)
64
- and "/" in pretrained_model_name_or_path
65
- and not Path(pretrained_model_name_or_path).exists()
66
- ):
67
- from huggingface_hub import snapshot_download
68
-
69
- hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
- if subfolder:
71
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
- cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
- variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
- else:
75
- variant = Path(pretrained_model_name_or_path)
76
- if not variant.is_absolute():
77
- candidate = (Path.cwd() / variant).resolve()
78
- variant = candidate if candidate.exists() else (repo_root / variant).resolve()
79
- if subfolder:
80
- variant = variant / subfolder
81
-
82
- id2label_override = kwargs.pop("id2label", None)
83
- model_kwargs = dict(kwargs)
84
- inserted: List[str] = []
85
-
86
- def _load_component(folder: str, module_name: str, class_name: str):
87
- comp_dir = variant / folder
88
- module_path = comp_dir / f"{module_name}.py"
89
- has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
90
- if not module_path.exists() or not has_weights:
91
- return None
92
-
93
- comp_path = str(comp_dir)
94
- if comp_path not in sys.path:
95
- sys.path.insert(0, comp_path)
96
- inserted.append(comp_path)
97
-
98
- module = importlib.import_module(module_name)
99
- component_cls = getattr(module, class_name)
100
- return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
101
-
102
- try:
103
- transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
- try:
105
- scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
- except Exception:
107
- scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
-
109
- if transformer is None:
110
- raise ValueError(f"No loadable transformer found under {variant}")
111
-
112
- variant_path = str(variant)
113
- model_index_path = variant / "model_index.json"
114
- id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
-
116
- pipe = cls(
117
- transformer=transformer,
118
- scheduler=scheduler,
119
- id2label=id2label,
120
- )
121
- if variant_path and hasattr(pipe, "register_to_config"):
122
- pipe.register_to_config(_name_or_path=variant_path)
123
- return pipe
124
- finally:
125
- for comp_path in inserted:
126
- if comp_path in sys.path:
127
- sys.path.remove(comp_path)
128
 
129
  def __init__(
130
  self,
131
  transformer,
132
- scheduler: FlowMatchHeunDiscreteScheduler,
133
  id2label: Optional[Dict[Union[int, str], str]] = None,
134
  ):
135
  super().__init__()
136
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
-
139
  self._id2label = self._normalize_id2label(id2label)
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
146
  return {int(key): value for key, value in id2label.items()}
147
 
148
  @staticmethod
149
- def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
 
 
 
 
150
  if not model_index_path.exists():
151
  return {}
152
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
167
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
- """ImageNet class id to English label string (comma-separated synonyms)."""
171
  return self._id2label
172
 
173
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
174
- r"""
175
- Map ImageNet label strings to class ids.
176
-
177
- Args:
178
- label (`str` or `list[str]`):
179
- One or more English label strings. Each string must match a synonym in `id2label`.
180
- """
181
  label2id = self.labels
182
  if not label2id:
183
- raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
188
  missing = [item for item in label if item not in label2id]
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
- raise ValueError(
192
- f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
- )
194
  return [label2id[item] for item in label]
195
 
196
  def _normalize_class_labels(
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
228
- r"""
229
- Generate class-conditional images.
230
-
231
- Args:
232
- class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
- ImageNet class indices or human-readable English label strings.
234
- guidance_scale (`float`, *optional*):
235
- Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
- guidance_interval_min (`float`, defaults to `0.1`):
237
- Lower bound of the CFG interval in flow time `t in [0, 1]`.
238
- guidance_interval_max (`float`, defaults to `1.0`):
239
- Upper bound of the CFG interval in flow time.
240
- noise_scale (`float`, *optional*):
241
- Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
- t_eps (`float`, defaults to `5e-2`):
243
- Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
244
- generator (`torch.Generator`, *optional*):
245
- RNG for reproducibility.
246
- num_inference_steps (`int`, defaults to `50`):
247
- Number of solver steps (at least 2).
248
- output_type (`str`, *optional*, defaults to `"pil"`):
249
- `"pil"`, `"np"`, or `"pt"`.
250
- return_dict (`bool`, *optional*, defaults to `True`):
251
- Return [`ImagePipelineOutput`] if True.
252
- """
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
 
 
255
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
268
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
  )
270
  channels = int(self.transformer.config.in_channels)
271
- null_class_val = int(self.transformer.config.num_classes)
 
 
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
- latents = (
279
- randn_tensor(
280
  shape=(batch_size, channels, height, width),
281
- generator=generator,
282
- device=self._execution_device,
283
- dtype=self.transformer.dtype,
284
- )
285
- * noise_scale
286
- )
287
 
288
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
295
  class_labels_input = class_labels_t
296
 
297
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
 
298
  for t in self.progress_bar(self.scheduler.timesteps):
299
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
329
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
  model_output = -(x_pred - latents) / sigma
332
- latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
344
  if not return_dict:
345
  return (images,)
346
  return ImagePipelineOutput(images=images)
 
 
 
1
+ """Hub custom pipeline: JiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import inspect
8
+
 
 
 
 
 
 
 
9
  import json
 
10
  from pathlib import Path
11
+ from typing import Dict, List, Optional, Tuple, Union, Any
12
 
13
  import torch
 
14
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
15
  from diffusers.utils.torch_utils import randn_tensor
16
 
 
17
  RECOMMENDED_NOISE_BY_SIZE = {
18
  256: 1.0,
19
  512: 2.0,
20
  }
21
 
 
22
  class JiTPipeline(DiffusionPipeline):
23
  r"""
24
  Pipeline for image generation using JiT (Just image Transformer).
 
32
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
33
  """
34
 
35
+ @staticmethod
36
+ def prepare_extra_step_kwargs(
37
+ scheduler,
38
+ generator=None,
39
+ eta: float | None = None,
40
+ ):
41
+ kwargs = {}
42
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
43
+ if "generator" in step_params:
44
+ kwargs["generator"] = generator
45
+ if eta is not None and "eta" in step_params:
46
+ kwargs["eta"] = eta
47
+ return kwargs
48
 
49
+ model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def __init__(
52
  self,
53
  transformer,
54
+ scheduler,
55
  id2label: Optional[Dict[Union[int, str], str]] = None,
56
  ):
57
  super().__init__()
58
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
59
  self.register_modules(transformer=transformer, scheduler=scheduler)
 
60
  self._id2label = self._normalize_id2label(id2label)
61
  self.labels = self._build_label2id(self._id2label)
62
+ self._labels_loaded_from_model_index = bool(self._id2label)
63
+
64
+ def _ensure_labels_loaded(self) -> None:
65
+ if self._labels_loaded_from_model_index:
66
+ return
67
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
68
+ if loaded:
69
+ self._id2label = loaded
70
+ self.labels = self._build_label2id(self._id2label)
71
+ self._labels_loaded_from_model_index = True
72
 
73
  @staticmethod
74
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
 
77
  return {int(key): value for key, value in id2label.items()}
78
 
79
  @staticmethod
80
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
81
+ if not variant_path:
82
+ return {}
83
+ variant_dir = Path(variant_path).resolve()
84
+ model_index_path = variant_dir / "model_index.json"
85
  if not model_index_path.exists():
86
  return {}
87
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
 
102
 
103
  @property
104
  def id2label(self) -> Dict[int, str]:
105
+ self._ensure_labels_loaded()
106
  return self._id2label
107
 
108
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
109
+ self._ensure_labels_loaded()
 
 
 
 
 
 
110
  label2id = self.labels
111
  if not label2id:
112
+ raise ValueError(
113
+ "No English labels loaded. Ensure `id2label` exists in model_index.json."
114
+ )
115
 
116
  if isinstance(label, str):
117
  label = [label]
 
119
  missing = [item for item in label if item not in label2id]
120
  if missing:
121
  preview = ", ".join(list(label2id.keys())[:8])
122
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
 
 
123
  return [label2id[item] for item in label]
124
 
125
  def _normalize_class_labels(
 
154
  output_type: Optional[str] = "pil",
155
  return_dict: bool = True,
156
  ) -> Union[ImagePipelineOutput, Tuple]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if num_inference_steps < 2:
158
  raise ValueError("num_inference_steps must be >= 2.")
159
+ if output_type not in {"pil", "np", "pt"}:
160
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
161
 
162
  class_label_ids = self._normalize_class_labels(class_labels)
163
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
 
174
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
175
  )
176
  channels = int(self.transformer.config.in_channels)
177
+ null_class_val = int(
178
+ getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
179
+ )
180
 
181
  if guidance_scale is None:
182
  guidance_scale = 1.0
183
  if noise_scale is None:
184
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
185
 
186
+ latents = randn_tensor(
 
187
  shape=(batch_size, channels, height, width),
188
+ generator=generator,
189
+ device=self._execution_device,
190
+ dtype=self.transformer.dtype,
191
+ ) * noise_scale
 
 
192
 
193
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
194
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
 
200
  class_labels_input = class_labels_t
201
 
202
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
203
+ extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
204
  for t in self.progress_bar(self.scheduler.timesteps):
205
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
206
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
 
235
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
236
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
237
  model_output = -(x_pred - latents) / sigma
238
+ latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
239
 
240
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
241
  if output_type == "pt":
 
250
  if not return_dict:
251
  return (images,)
252
  return ImagePipelineOutput(images=images)
253
+
254
+ JiTPipelineOutput = ImagePipelineOutput
JiT-B-32/pipeline.py CHANGED
@@ -1,36 +1,24 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import importlib
16
  import json
17
- import sys
18
  from pathlib import Path
19
- from typing import Dict, List, Optional, Tuple, Union
20
 
21
  import torch
22
-
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
- from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
-
28
  RECOMMENDED_NOISE_BY_SIZE = {
29
  256: 1.0,
30
  512: 2.0,
31
  }
32
 
33
-
34
  class JiTPipeline(DiffusionPipeline):
35
  r"""
36
  Pipeline for image generation using JiT (Just image Transformer).
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
45
  """
46
 
47
- model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- @classmethod
50
- def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
51
- """Load a self-contained variant folder locally or from the Hub.
52
-
53
- Examples:
54
- JiTPipeline.from_pretrained(".")
55
- JiTPipeline.from_pretrained("./JiT-H-32")
56
- DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
57
- """
58
- repo_root = Path(__file__).resolve().parent
59
-
60
- if pretrained_model_name_or_path in (None, "", "."):
61
- variant = repo_root
62
- elif (
63
- isinstance(pretrained_model_name_or_path, str)
64
- and "/" in pretrained_model_name_or_path
65
- and not Path(pretrained_model_name_or_path).exists()
66
- ):
67
- from huggingface_hub import snapshot_download
68
-
69
- hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
- if subfolder:
71
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
- cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
- variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
- else:
75
- variant = Path(pretrained_model_name_or_path)
76
- if not variant.is_absolute():
77
- candidate = (Path.cwd() / variant).resolve()
78
- variant = candidate if candidate.exists() else (repo_root / variant).resolve()
79
- if subfolder:
80
- variant = variant / subfolder
81
-
82
- id2label_override = kwargs.pop("id2label", None)
83
- model_kwargs = dict(kwargs)
84
- inserted: List[str] = []
85
-
86
- def _load_component(folder: str, module_name: str, class_name: str):
87
- comp_dir = variant / folder
88
- module_path = comp_dir / f"{module_name}.py"
89
- has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
90
- if not module_path.exists() or not has_weights:
91
- return None
92
-
93
- comp_path = str(comp_dir)
94
- if comp_path not in sys.path:
95
- sys.path.insert(0, comp_path)
96
- inserted.append(comp_path)
97
-
98
- module = importlib.import_module(module_name)
99
- component_cls = getattr(module, class_name)
100
- return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
101
-
102
- try:
103
- transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
- try:
105
- scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
- except Exception:
107
- scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
-
109
- if transformer is None:
110
- raise ValueError(f"No loadable transformer found under {variant}")
111
-
112
- variant_path = str(variant)
113
- model_index_path = variant / "model_index.json"
114
- id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
-
116
- pipe = cls(
117
- transformer=transformer,
118
- scheduler=scheduler,
119
- id2label=id2label,
120
- )
121
- if variant_path and hasattr(pipe, "register_to_config"):
122
- pipe.register_to_config(_name_or_path=variant_path)
123
- return pipe
124
- finally:
125
- for comp_path in inserted:
126
- if comp_path in sys.path:
127
- sys.path.remove(comp_path)
128
 
129
  def __init__(
130
  self,
131
  transformer,
132
- scheduler: FlowMatchHeunDiscreteScheduler,
133
  id2label: Optional[Dict[Union[int, str], str]] = None,
134
  ):
135
  super().__init__()
136
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
-
139
  self._id2label = self._normalize_id2label(id2label)
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
146
  return {int(key): value for key, value in id2label.items()}
147
 
148
  @staticmethod
149
- def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
 
 
 
 
150
  if not model_index_path.exists():
151
  return {}
152
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
167
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
- """ImageNet class id to English label string (comma-separated synonyms)."""
171
  return self._id2label
172
 
173
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
174
- r"""
175
- Map ImageNet label strings to class ids.
176
-
177
- Args:
178
- label (`str` or `list[str]`):
179
- One or more English label strings. Each string must match a synonym in `id2label`.
180
- """
181
  label2id = self.labels
182
  if not label2id:
183
- raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
188
  missing = [item for item in label if item not in label2id]
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
- raise ValueError(
192
- f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
- )
194
  return [label2id[item] for item in label]
195
 
196
  def _normalize_class_labels(
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
228
- r"""
229
- Generate class-conditional images.
230
-
231
- Args:
232
- class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
- ImageNet class indices or human-readable English label strings.
234
- guidance_scale (`float`, *optional*):
235
- Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
- guidance_interval_min (`float`, defaults to `0.1`):
237
- Lower bound of the CFG interval in flow time `t in [0, 1]`.
238
- guidance_interval_max (`float`, defaults to `1.0`):
239
- Upper bound of the CFG interval in flow time.
240
- noise_scale (`float`, *optional*):
241
- Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
- t_eps (`float`, defaults to `5e-2`):
243
- Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
244
- generator (`torch.Generator`, *optional*):
245
- RNG for reproducibility.
246
- num_inference_steps (`int`, defaults to `50`):
247
- Number of solver steps (at least 2).
248
- output_type (`str`, *optional*, defaults to `"pil"`):
249
- `"pil"`, `"np"`, or `"pt"`.
250
- return_dict (`bool`, *optional*, defaults to `True`):
251
- Return [`ImagePipelineOutput`] if True.
252
- """
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
 
 
255
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
268
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
  )
270
  channels = int(self.transformer.config.in_channels)
271
- null_class_val = int(self.transformer.config.num_classes)
 
 
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
- latents = (
279
- randn_tensor(
280
  shape=(batch_size, channels, height, width),
281
- generator=generator,
282
- device=self._execution_device,
283
- dtype=self.transformer.dtype,
284
- )
285
- * noise_scale
286
- )
287
 
288
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
295
  class_labels_input = class_labels_t
296
 
297
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
 
298
  for t in self.progress_bar(self.scheduler.timesteps):
299
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
329
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
  model_output = -(x_pred - latents) / sigma
332
- latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
344
  if not return_dict:
345
  return (images,)
346
  return ImagePipelineOutput(images=images)
 
 
 
1
+ """Hub custom pipeline: JiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import inspect
8
+
 
 
 
 
 
 
 
9
  import json
 
10
  from pathlib import Path
11
+ from typing import Dict, List, Optional, Tuple, Union, Any
12
 
13
  import torch
 
14
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
15
  from diffusers.utils.torch_utils import randn_tensor
16
 
 
17
  RECOMMENDED_NOISE_BY_SIZE = {
18
  256: 1.0,
19
  512: 2.0,
20
  }
21
 
 
22
  class JiTPipeline(DiffusionPipeline):
23
  r"""
24
  Pipeline for image generation using JiT (Just image Transformer).
 
32
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
33
  """
34
 
35
+ @staticmethod
36
+ def prepare_extra_step_kwargs(
37
+ scheduler,
38
+ generator=None,
39
+ eta: float | None = None,
40
+ ):
41
+ kwargs = {}
42
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
43
+ if "generator" in step_params:
44
+ kwargs["generator"] = generator
45
+ if eta is not None and "eta" in step_params:
46
+ kwargs["eta"] = eta
47
+ return kwargs
48
 
49
+ model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def __init__(
52
  self,
53
  transformer,
54
+ scheduler,
55
  id2label: Optional[Dict[Union[int, str], str]] = None,
56
  ):
57
  super().__init__()
58
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
59
  self.register_modules(transformer=transformer, scheduler=scheduler)
 
60
  self._id2label = self._normalize_id2label(id2label)
61
  self.labels = self._build_label2id(self._id2label)
62
+ self._labels_loaded_from_model_index = bool(self._id2label)
63
+
64
+ def _ensure_labels_loaded(self) -> None:
65
+ if self._labels_loaded_from_model_index:
66
+ return
67
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
68
+ if loaded:
69
+ self._id2label = loaded
70
+ self.labels = self._build_label2id(self._id2label)
71
+ self._labels_loaded_from_model_index = True
72
 
73
  @staticmethod
74
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
 
77
  return {int(key): value for key, value in id2label.items()}
78
 
79
  @staticmethod
80
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
81
+ if not variant_path:
82
+ return {}
83
+ variant_dir = Path(variant_path).resolve()
84
+ model_index_path = variant_dir / "model_index.json"
85
  if not model_index_path.exists():
86
  return {}
87
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
 
102
 
103
  @property
104
  def id2label(self) -> Dict[int, str]:
105
+ self._ensure_labels_loaded()
106
  return self._id2label
107
 
108
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
109
+ self._ensure_labels_loaded()
 
 
 
 
 
 
110
  label2id = self.labels
111
  if not label2id:
112
+ raise ValueError(
113
+ "No English labels loaded. Ensure `id2label` exists in model_index.json."
114
+ )
115
 
116
  if isinstance(label, str):
117
  label = [label]
 
119
  missing = [item for item in label if item not in label2id]
120
  if missing:
121
  preview = ", ".join(list(label2id.keys())[:8])
122
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
 
 
123
  return [label2id[item] for item in label]
124
 
125
  def _normalize_class_labels(
 
154
  output_type: Optional[str] = "pil",
155
  return_dict: bool = True,
156
  ) -> Union[ImagePipelineOutput, Tuple]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if num_inference_steps < 2:
158
  raise ValueError("num_inference_steps must be >= 2.")
159
+ if output_type not in {"pil", "np", "pt"}:
160
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
161
 
162
  class_label_ids = self._normalize_class_labels(class_labels)
163
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
 
174
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
175
  )
176
  channels = int(self.transformer.config.in_channels)
177
+ null_class_val = int(
178
+ getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
179
+ )
180
 
181
  if guidance_scale is None:
182
  guidance_scale = 1.0
183
  if noise_scale is None:
184
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
185
 
186
+ latents = randn_tensor(
 
187
  shape=(batch_size, channels, height, width),
188
+ generator=generator,
189
+ device=self._execution_device,
190
+ dtype=self.transformer.dtype,
191
+ ) * noise_scale
 
 
192
 
193
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
194
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
 
200
  class_labels_input = class_labels_t
201
 
202
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
203
+ extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
204
  for t in self.progress_bar(self.scheduler.timesteps):
205
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
206
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
 
235
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
236
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
237
  model_output = -(x_pred - latents) / sigma
238
+ latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
239
 
240
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
241
  if output_type == "pt":
 
250
  if not return_dict:
251
  return (images,)
252
  return ImagePipelineOutput(images=images)
253
+
254
+ JiTPipelineOutput = ImagePipelineOutput
JiT-H-16/pipeline.py CHANGED
@@ -1,36 +1,24 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import importlib
16
  import json
17
- import sys
18
  from pathlib import Path
19
- from typing import Dict, List, Optional, Tuple, Union
20
 
21
  import torch
22
-
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
- from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
-
28
  RECOMMENDED_NOISE_BY_SIZE = {
29
  256: 1.0,
30
  512: 2.0,
31
  }
32
 
33
-
34
  class JiTPipeline(DiffusionPipeline):
35
  r"""
36
  Pipeline for image generation using JiT (Just image Transformer).
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
45
  """
46
 
47
- model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- @classmethod
50
- def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
51
- """Load a self-contained variant folder locally or from the Hub.
52
-
53
- Examples:
54
- JiTPipeline.from_pretrained(".")
55
- JiTPipeline.from_pretrained("./JiT-H-32")
56
- DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
57
- """
58
- repo_root = Path(__file__).resolve().parent
59
-
60
- if pretrained_model_name_or_path in (None, "", "."):
61
- variant = repo_root
62
- elif (
63
- isinstance(pretrained_model_name_or_path, str)
64
- and "/" in pretrained_model_name_or_path
65
- and not Path(pretrained_model_name_or_path).exists()
66
- ):
67
- from huggingface_hub import snapshot_download
68
-
69
- hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
- if subfolder:
71
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
- cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
- variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
- else:
75
- variant = Path(pretrained_model_name_or_path)
76
- if not variant.is_absolute():
77
- candidate = (Path.cwd() / variant).resolve()
78
- variant = candidate if candidate.exists() else (repo_root / variant).resolve()
79
- if subfolder:
80
- variant = variant / subfolder
81
-
82
- id2label_override = kwargs.pop("id2label", None)
83
- model_kwargs = dict(kwargs)
84
- inserted: List[str] = []
85
-
86
- def _load_component(folder: str, module_name: str, class_name: str):
87
- comp_dir = variant / folder
88
- module_path = comp_dir / f"{module_name}.py"
89
- has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
90
- if not module_path.exists() or not has_weights:
91
- return None
92
-
93
- comp_path = str(comp_dir)
94
- if comp_path not in sys.path:
95
- sys.path.insert(0, comp_path)
96
- inserted.append(comp_path)
97
-
98
- module = importlib.import_module(module_name)
99
- component_cls = getattr(module, class_name)
100
- return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
101
-
102
- try:
103
- transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
- try:
105
- scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
- except Exception:
107
- scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
-
109
- if transformer is None:
110
- raise ValueError(f"No loadable transformer found under {variant}")
111
-
112
- variant_path = str(variant)
113
- model_index_path = variant / "model_index.json"
114
- id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
-
116
- pipe = cls(
117
- transformer=transformer,
118
- scheduler=scheduler,
119
- id2label=id2label,
120
- )
121
- if variant_path and hasattr(pipe, "register_to_config"):
122
- pipe.register_to_config(_name_or_path=variant_path)
123
- return pipe
124
- finally:
125
- for comp_path in inserted:
126
- if comp_path in sys.path:
127
- sys.path.remove(comp_path)
128
 
129
  def __init__(
130
  self,
131
  transformer,
132
- scheduler: FlowMatchHeunDiscreteScheduler,
133
  id2label: Optional[Dict[Union[int, str], str]] = None,
134
  ):
135
  super().__init__()
136
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
-
139
  self._id2label = self._normalize_id2label(id2label)
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
146
  return {int(key): value for key, value in id2label.items()}
147
 
148
  @staticmethod
149
- def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
 
 
 
 
150
  if not model_index_path.exists():
151
  return {}
152
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
167
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
- """ImageNet class id to English label string (comma-separated synonyms)."""
171
  return self._id2label
172
 
173
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
174
- r"""
175
- Map ImageNet label strings to class ids.
176
-
177
- Args:
178
- label (`str` or `list[str]`):
179
- One or more English label strings. Each string must match a synonym in `id2label`.
180
- """
181
  label2id = self.labels
182
  if not label2id:
183
- raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
188
  missing = [item for item in label if item not in label2id]
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
- raise ValueError(
192
- f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
- )
194
  return [label2id[item] for item in label]
195
 
196
  def _normalize_class_labels(
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
228
- r"""
229
- Generate class-conditional images.
230
-
231
- Args:
232
- class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
- ImageNet class indices or human-readable English label strings.
234
- guidance_scale (`float`, *optional*):
235
- Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
- guidance_interval_min (`float`, defaults to `0.1`):
237
- Lower bound of the CFG interval in flow time `t in [0, 1]`.
238
- guidance_interval_max (`float`, defaults to `1.0`):
239
- Upper bound of the CFG interval in flow time.
240
- noise_scale (`float`, *optional*):
241
- Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
- t_eps (`float`, defaults to `5e-2`):
243
- Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
244
- generator (`torch.Generator`, *optional*):
245
- RNG for reproducibility.
246
- num_inference_steps (`int`, defaults to `50`):
247
- Number of solver steps (at least 2).
248
- output_type (`str`, *optional*, defaults to `"pil"`):
249
- `"pil"`, `"np"`, or `"pt"`.
250
- return_dict (`bool`, *optional*, defaults to `True`):
251
- Return [`ImagePipelineOutput`] if True.
252
- """
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
 
 
255
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
268
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
  )
270
  channels = int(self.transformer.config.in_channels)
271
- null_class_val = int(self.transformer.config.num_classes)
 
 
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
- latents = (
279
- randn_tensor(
280
  shape=(batch_size, channels, height, width),
281
- generator=generator,
282
- device=self._execution_device,
283
- dtype=self.transformer.dtype,
284
- )
285
- * noise_scale
286
- )
287
 
288
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
295
  class_labels_input = class_labels_t
296
 
297
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
 
298
  for t in self.progress_bar(self.scheduler.timesteps):
299
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
329
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
  model_output = -(x_pred - latents) / sigma
332
- latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
344
  if not return_dict:
345
  return (images,)
346
  return ImagePipelineOutput(images=images)
 
 
 
1
+ """Hub custom pipeline: JiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import inspect
8
+
 
 
 
 
 
 
 
9
  import json
 
10
  from pathlib import Path
11
+ from typing import Dict, List, Optional, Tuple, Union, Any
12
 
13
  import torch
 
14
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
15
  from diffusers.utils.torch_utils import randn_tensor
16
 
 
17
  RECOMMENDED_NOISE_BY_SIZE = {
18
  256: 1.0,
19
  512: 2.0,
20
  }
21
 
 
22
  class JiTPipeline(DiffusionPipeline):
23
  r"""
24
  Pipeline for image generation using JiT (Just image Transformer).
 
32
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
33
  """
34
 
35
+ @staticmethod
36
+ def prepare_extra_step_kwargs(
37
+ scheduler,
38
+ generator=None,
39
+ eta: float | None = None,
40
+ ):
41
+ kwargs = {}
42
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
43
+ if "generator" in step_params:
44
+ kwargs["generator"] = generator
45
+ if eta is not None and "eta" in step_params:
46
+ kwargs["eta"] = eta
47
+ return kwargs
48
 
49
+ model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def __init__(
52
  self,
53
  transformer,
54
+ scheduler,
55
  id2label: Optional[Dict[Union[int, str], str]] = None,
56
  ):
57
  super().__init__()
58
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
59
  self.register_modules(transformer=transformer, scheduler=scheduler)
 
60
  self._id2label = self._normalize_id2label(id2label)
61
  self.labels = self._build_label2id(self._id2label)
62
+ self._labels_loaded_from_model_index = bool(self._id2label)
63
+
64
+ def _ensure_labels_loaded(self) -> None:
65
+ if self._labels_loaded_from_model_index:
66
+ return
67
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
68
+ if loaded:
69
+ self._id2label = loaded
70
+ self.labels = self._build_label2id(self._id2label)
71
+ self._labels_loaded_from_model_index = True
72
 
73
  @staticmethod
74
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
 
77
  return {int(key): value for key, value in id2label.items()}
78
 
79
  @staticmethod
80
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
81
+ if not variant_path:
82
+ return {}
83
+ variant_dir = Path(variant_path).resolve()
84
+ model_index_path = variant_dir / "model_index.json"
85
  if not model_index_path.exists():
86
  return {}
87
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
 
102
 
103
  @property
104
  def id2label(self) -> Dict[int, str]:
105
+ self._ensure_labels_loaded()
106
  return self._id2label
107
 
108
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
109
+ self._ensure_labels_loaded()
 
 
 
 
 
 
110
  label2id = self.labels
111
  if not label2id:
112
+ raise ValueError(
113
+ "No English labels loaded. Ensure `id2label` exists in model_index.json."
114
+ )
115
 
116
  if isinstance(label, str):
117
  label = [label]
 
119
  missing = [item for item in label if item not in label2id]
120
  if missing:
121
  preview = ", ".join(list(label2id.keys())[:8])
122
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
 
 
123
  return [label2id[item] for item in label]
124
 
125
  def _normalize_class_labels(
 
154
  output_type: Optional[str] = "pil",
155
  return_dict: bool = True,
156
  ) -> Union[ImagePipelineOutput, Tuple]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if num_inference_steps < 2:
158
  raise ValueError("num_inference_steps must be >= 2.")
159
+ if output_type not in {"pil", "np", "pt"}:
160
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
161
 
162
  class_label_ids = self._normalize_class_labels(class_labels)
163
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
 
174
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
175
  )
176
  channels = int(self.transformer.config.in_channels)
177
+ null_class_val = int(
178
+ getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
179
+ )
180
 
181
  if guidance_scale is None:
182
  guidance_scale = 1.0
183
  if noise_scale is None:
184
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
185
 
186
+ latents = randn_tensor(
 
187
  shape=(batch_size, channels, height, width),
188
+ generator=generator,
189
+ device=self._execution_device,
190
+ dtype=self.transformer.dtype,
191
+ ) * noise_scale
 
 
192
 
193
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
194
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
 
200
  class_labels_input = class_labels_t
201
 
202
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
203
+ extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
204
  for t in self.progress_bar(self.scheduler.timesteps):
205
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
206
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
 
235
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
236
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
237
  model_output = -(x_pred - latents) / sigma
238
+ latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
239
 
240
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
241
  if output_type == "pt":
 
250
  if not return_dict:
251
  return (images,)
252
  return ImagePipelineOutput(images=images)
253
+
254
+ JiTPipelineOutput = ImagePipelineOutput
JiT-H-32/pipeline.py CHANGED
@@ -1,36 +1,24 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import importlib
16
  import json
17
- import sys
18
  from pathlib import Path
19
- from typing import Dict, List, Optional, Tuple, Union
20
 
21
  import torch
22
-
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
- from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
-
28
  RECOMMENDED_NOISE_BY_SIZE = {
29
  256: 1.0,
30
  512: 2.0,
31
  }
32
 
33
-
34
  class JiTPipeline(DiffusionPipeline):
35
  r"""
36
  Pipeline for image generation using JiT (Just image Transformer).
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
45
  """
46
 
47
- model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- @classmethod
50
- def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
51
- """Load a self-contained variant folder locally or from the Hub.
52
-
53
- Examples:
54
- JiTPipeline.from_pretrained(".")
55
- JiTPipeline.from_pretrained("./JiT-H-32")
56
- DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
57
- """
58
- repo_root = Path(__file__).resolve().parent
59
-
60
- if pretrained_model_name_or_path in (None, "", "."):
61
- variant = repo_root
62
- elif (
63
- isinstance(pretrained_model_name_or_path, str)
64
- and "/" in pretrained_model_name_or_path
65
- and not Path(pretrained_model_name_or_path).exists()
66
- ):
67
- from huggingface_hub import snapshot_download
68
-
69
- hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
- if subfolder:
71
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
- cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
- variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
- else:
75
- variant = Path(pretrained_model_name_or_path)
76
- if not variant.is_absolute():
77
- candidate = (Path.cwd() / variant).resolve()
78
- variant = candidate if candidate.exists() else (repo_root / variant).resolve()
79
- if subfolder:
80
- variant = variant / subfolder
81
-
82
- id2label_override = kwargs.pop("id2label", None)
83
- model_kwargs = dict(kwargs)
84
- inserted: List[str] = []
85
-
86
- def _load_component(folder: str, module_name: str, class_name: str):
87
- comp_dir = variant / folder
88
- module_path = comp_dir / f"{module_name}.py"
89
- has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
90
- if not module_path.exists() or not has_weights:
91
- return None
92
-
93
- comp_path = str(comp_dir)
94
- if comp_path not in sys.path:
95
- sys.path.insert(0, comp_path)
96
- inserted.append(comp_path)
97
-
98
- module = importlib.import_module(module_name)
99
- component_cls = getattr(module, class_name)
100
- return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
101
-
102
- try:
103
- transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
- try:
105
- scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
- except Exception:
107
- scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
-
109
- if transformer is None:
110
- raise ValueError(f"No loadable transformer found under {variant}")
111
-
112
- variant_path = str(variant)
113
- model_index_path = variant / "model_index.json"
114
- id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
-
116
- pipe = cls(
117
- transformer=transformer,
118
- scheduler=scheduler,
119
- id2label=id2label,
120
- )
121
- if variant_path and hasattr(pipe, "register_to_config"):
122
- pipe.register_to_config(_name_or_path=variant_path)
123
- return pipe
124
- finally:
125
- for comp_path in inserted:
126
- if comp_path in sys.path:
127
- sys.path.remove(comp_path)
128
 
129
  def __init__(
130
  self,
131
  transformer,
132
- scheduler: FlowMatchHeunDiscreteScheduler,
133
  id2label: Optional[Dict[Union[int, str], str]] = None,
134
  ):
135
  super().__init__()
136
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
-
139
  self._id2label = self._normalize_id2label(id2label)
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
146
  return {int(key): value for key, value in id2label.items()}
147
 
148
  @staticmethod
149
- def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
 
 
 
 
150
  if not model_index_path.exists():
151
  return {}
152
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
167
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
- """ImageNet class id to English label string (comma-separated synonyms)."""
171
  return self._id2label
172
 
173
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
174
- r"""
175
- Map ImageNet label strings to class ids.
176
-
177
- Args:
178
- label (`str` or `list[str]`):
179
- One or more English label strings. Each string must match a synonym in `id2label`.
180
- """
181
  label2id = self.labels
182
  if not label2id:
183
- raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
188
  missing = [item for item in label if item not in label2id]
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
- raise ValueError(
192
- f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
- )
194
  return [label2id[item] for item in label]
195
 
196
  def _normalize_class_labels(
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
228
- r"""
229
- Generate class-conditional images.
230
-
231
- Args:
232
- class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
- ImageNet class indices or human-readable English label strings.
234
- guidance_scale (`float`, *optional*):
235
- Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
- guidance_interval_min (`float`, defaults to `0.1`):
237
- Lower bound of the CFG interval in flow time `t in [0, 1]`.
238
- guidance_interval_max (`float`, defaults to `1.0`):
239
- Upper bound of the CFG interval in flow time.
240
- noise_scale (`float`, *optional*):
241
- Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
- t_eps (`float`, defaults to `5e-2`):
243
- Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
244
- generator (`torch.Generator`, *optional*):
245
- RNG for reproducibility.
246
- num_inference_steps (`int`, defaults to `50`):
247
- Number of solver steps (at least 2).
248
- output_type (`str`, *optional*, defaults to `"pil"`):
249
- `"pil"`, `"np"`, or `"pt"`.
250
- return_dict (`bool`, *optional*, defaults to `True`):
251
- Return [`ImagePipelineOutput`] if True.
252
- """
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
 
 
255
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
268
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
  )
270
  channels = int(self.transformer.config.in_channels)
271
- null_class_val = int(self.transformer.config.num_classes)
 
 
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
- latents = (
279
- randn_tensor(
280
  shape=(batch_size, channels, height, width),
281
- generator=generator,
282
- device=self._execution_device,
283
- dtype=self.transformer.dtype,
284
- )
285
- * noise_scale
286
- )
287
 
288
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
295
  class_labels_input = class_labels_t
296
 
297
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
 
298
  for t in self.progress_bar(self.scheduler.timesteps):
299
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
329
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
  model_output = -(x_pred - latents) / sigma
332
- latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
344
  if not return_dict:
345
  return (images,)
346
  return ImagePipelineOutput(images=images)
 
 
 
1
+ """Hub custom pipeline: JiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import inspect
8
+
 
 
 
 
 
 
 
9
  import json
 
10
  from pathlib import Path
11
+ from typing import Dict, List, Optional, Tuple, Union, Any
12
 
13
  import torch
 
14
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
15
  from diffusers.utils.torch_utils import randn_tensor
16
 
 
17
  RECOMMENDED_NOISE_BY_SIZE = {
18
  256: 1.0,
19
  512: 2.0,
20
  }
21
 
 
22
  class JiTPipeline(DiffusionPipeline):
23
  r"""
24
  Pipeline for image generation using JiT (Just image Transformer).
 
32
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
33
  """
34
 
35
+ @staticmethod
36
+ def prepare_extra_step_kwargs(
37
+ scheduler,
38
+ generator=None,
39
+ eta: float | None = None,
40
+ ):
41
+ kwargs = {}
42
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
43
+ if "generator" in step_params:
44
+ kwargs["generator"] = generator
45
+ if eta is not None and "eta" in step_params:
46
+ kwargs["eta"] = eta
47
+ return kwargs
48
 
49
+ model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def __init__(
52
  self,
53
  transformer,
54
+ scheduler,
55
  id2label: Optional[Dict[Union[int, str], str]] = None,
56
  ):
57
  super().__init__()
58
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
59
  self.register_modules(transformer=transformer, scheduler=scheduler)
 
60
  self._id2label = self._normalize_id2label(id2label)
61
  self.labels = self._build_label2id(self._id2label)
62
+ self._labels_loaded_from_model_index = bool(self._id2label)
63
+
64
+ def _ensure_labels_loaded(self) -> None:
65
+ if self._labels_loaded_from_model_index:
66
+ return
67
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
68
+ if loaded:
69
+ self._id2label = loaded
70
+ self.labels = self._build_label2id(self._id2label)
71
+ self._labels_loaded_from_model_index = True
72
 
73
  @staticmethod
74
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
 
77
  return {int(key): value for key, value in id2label.items()}
78
 
79
  @staticmethod
80
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
81
+ if not variant_path:
82
+ return {}
83
+ variant_dir = Path(variant_path).resolve()
84
+ model_index_path = variant_dir / "model_index.json"
85
  if not model_index_path.exists():
86
  return {}
87
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
 
102
 
103
  @property
104
  def id2label(self) -> Dict[int, str]:
105
+ self._ensure_labels_loaded()
106
  return self._id2label
107
 
108
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
109
+ self._ensure_labels_loaded()
 
 
 
 
 
 
110
  label2id = self.labels
111
  if not label2id:
112
+ raise ValueError(
113
+ "No English labels loaded. Ensure `id2label` exists in model_index.json."
114
+ )
115
 
116
  if isinstance(label, str):
117
  label = [label]
 
119
  missing = [item for item in label if item not in label2id]
120
  if missing:
121
  preview = ", ".join(list(label2id.keys())[:8])
122
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
 
 
123
  return [label2id[item] for item in label]
124
 
125
  def _normalize_class_labels(
 
154
  output_type: Optional[str] = "pil",
155
  return_dict: bool = True,
156
  ) -> Union[ImagePipelineOutput, Tuple]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if num_inference_steps < 2:
158
  raise ValueError("num_inference_steps must be >= 2.")
159
+ if output_type not in {"pil", "np", "pt"}:
160
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
161
 
162
  class_label_ids = self._normalize_class_labels(class_labels)
163
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
 
174
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
175
  )
176
  channels = int(self.transformer.config.in_channels)
177
+ null_class_val = int(
178
+ getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
179
+ )
180
 
181
  if guidance_scale is None:
182
  guidance_scale = 1.0
183
  if noise_scale is None:
184
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
185
 
186
+ latents = randn_tensor(
 
187
  shape=(batch_size, channels, height, width),
188
+ generator=generator,
189
+ device=self._execution_device,
190
+ dtype=self.transformer.dtype,
191
+ ) * noise_scale
 
 
192
 
193
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
194
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
 
200
  class_labels_input = class_labels_t
201
 
202
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
203
+ extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
204
  for t in self.progress_bar(self.scheduler.timesteps):
205
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
206
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
 
235
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
236
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
237
  model_output = -(x_pred - latents) / sigma
238
+ latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
239
 
240
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
241
  if output_type == "pt":
 
250
  if not return_dict:
251
  return (images,)
252
  return ImagePipelineOutput(images=images)
253
+
254
+ JiTPipelineOutput = ImagePipelineOutput
JiT-L-16/pipeline.py CHANGED
@@ -1,36 +1,24 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import importlib
16
  import json
17
- import sys
18
  from pathlib import Path
19
- from typing import Dict, List, Optional, Tuple, Union
20
 
21
  import torch
22
-
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
- from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
-
28
  RECOMMENDED_NOISE_BY_SIZE = {
29
  256: 1.0,
30
  512: 2.0,
31
  }
32
 
33
-
34
  class JiTPipeline(DiffusionPipeline):
35
  r"""
36
  Pipeline for image generation using JiT (Just image Transformer).
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
45
  """
46
 
47
- model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- @classmethod
50
- def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
51
- """Load a self-contained variant folder locally or from the Hub.
52
-
53
- Examples:
54
- JiTPipeline.from_pretrained(".")
55
- JiTPipeline.from_pretrained("./JiT-H-32")
56
- DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
57
- """
58
- repo_root = Path(__file__).resolve().parent
59
-
60
- if pretrained_model_name_or_path in (None, "", "."):
61
- variant = repo_root
62
- elif (
63
- isinstance(pretrained_model_name_or_path, str)
64
- and "/" in pretrained_model_name_or_path
65
- and not Path(pretrained_model_name_or_path).exists()
66
- ):
67
- from huggingface_hub import snapshot_download
68
-
69
- hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
- if subfolder:
71
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
- cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
- variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
- else:
75
- variant = Path(pretrained_model_name_or_path)
76
- if not variant.is_absolute():
77
- candidate = (Path.cwd() / variant).resolve()
78
- variant = candidate if candidate.exists() else (repo_root / variant).resolve()
79
- if subfolder:
80
- variant = variant / subfolder
81
-
82
- id2label_override = kwargs.pop("id2label", None)
83
- model_kwargs = dict(kwargs)
84
- inserted: List[str] = []
85
-
86
- def _load_component(folder: str, module_name: str, class_name: str):
87
- comp_dir = variant / folder
88
- module_path = comp_dir / f"{module_name}.py"
89
- has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
90
- if not module_path.exists() or not has_weights:
91
- return None
92
-
93
- comp_path = str(comp_dir)
94
- if comp_path not in sys.path:
95
- sys.path.insert(0, comp_path)
96
- inserted.append(comp_path)
97
-
98
- module = importlib.import_module(module_name)
99
- component_cls = getattr(module, class_name)
100
- return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
101
-
102
- try:
103
- transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
- try:
105
- scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
- except Exception:
107
- scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
-
109
- if transformer is None:
110
- raise ValueError(f"No loadable transformer found under {variant}")
111
-
112
- variant_path = str(variant)
113
- model_index_path = variant / "model_index.json"
114
- id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
-
116
- pipe = cls(
117
- transformer=transformer,
118
- scheduler=scheduler,
119
- id2label=id2label,
120
- )
121
- if variant_path and hasattr(pipe, "register_to_config"):
122
- pipe.register_to_config(_name_or_path=variant_path)
123
- return pipe
124
- finally:
125
- for comp_path in inserted:
126
- if comp_path in sys.path:
127
- sys.path.remove(comp_path)
128
 
129
  def __init__(
130
  self,
131
  transformer,
132
- scheduler: FlowMatchHeunDiscreteScheduler,
133
  id2label: Optional[Dict[Union[int, str], str]] = None,
134
  ):
135
  super().__init__()
136
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
-
139
  self._id2label = self._normalize_id2label(id2label)
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
146
  return {int(key): value for key, value in id2label.items()}
147
 
148
  @staticmethod
149
- def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
 
 
 
 
150
  if not model_index_path.exists():
151
  return {}
152
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
167
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
- """ImageNet class id to English label string (comma-separated synonyms)."""
171
  return self._id2label
172
 
173
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
174
- r"""
175
- Map ImageNet label strings to class ids.
176
-
177
- Args:
178
- label (`str` or `list[str]`):
179
- One or more English label strings. Each string must match a synonym in `id2label`.
180
- """
181
  label2id = self.labels
182
  if not label2id:
183
- raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
188
  missing = [item for item in label if item not in label2id]
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
- raise ValueError(
192
- f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
- )
194
  return [label2id[item] for item in label]
195
 
196
  def _normalize_class_labels(
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
228
- r"""
229
- Generate class-conditional images.
230
-
231
- Args:
232
- class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
- ImageNet class indices or human-readable English label strings.
234
- guidance_scale (`float`, *optional*):
235
- Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
- guidance_interval_min (`float`, defaults to `0.1`):
237
- Lower bound of the CFG interval in flow time `t in [0, 1]`.
238
- guidance_interval_max (`float`, defaults to `1.0`):
239
- Upper bound of the CFG interval in flow time.
240
- noise_scale (`float`, *optional*):
241
- Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
- t_eps (`float`, defaults to `5e-2`):
243
- Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
244
- generator (`torch.Generator`, *optional*):
245
- RNG for reproducibility.
246
- num_inference_steps (`int`, defaults to `50`):
247
- Number of solver steps (at least 2).
248
- output_type (`str`, *optional*, defaults to `"pil"`):
249
- `"pil"`, `"np"`, or `"pt"`.
250
- return_dict (`bool`, *optional*, defaults to `True`):
251
- Return [`ImagePipelineOutput`] if True.
252
- """
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
 
 
255
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
268
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
  )
270
  channels = int(self.transformer.config.in_channels)
271
- null_class_val = int(self.transformer.config.num_classes)
 
 
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
- latents = (
279
- randn_tensor(
280
  shape=(batch_size, channels, height, width),
281
- generator=generator,
282
- device=self._execution_device,
283
- dtype=self.transformer.dtype,
284
- )
285
- * noise_scale
286
- )
287
 
288
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
295
  class_labels_input = class_labels_t
296
 
297
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
 
298
  for t in self.progress_bar(self.scheduler.timesteps):
299
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
329
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
  model_output = -(x_pred - latents) / sigma
332
- latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
344
  if not return_dict:
345
  return (images,)
346
  return ImagePipelineOutput(images=images)
 
 
 
1
+ """Hub custom pipeline: JiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import inspect
8
+
 
 
 
 
 
 
 
9
  import json
 
10
  from pathlib import Path
11
+ from typing import Dict, List, Optional, Tuple, Union, Any
12
 
13
  import torch
 
14
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
15
  from diffusers.utils.torch_utils import randn_tensor
16
 
 
17
  RECOMMENDED_NOISE_BY_SIZE = {
18
  256: 1.0,
19
  512: 2.0,
20
  }
21
 
 
22
  class JiTPipeline(DiffusionPipeline):
23
  r"""
24
  Pipeline for image generation using JiT (Just image Transformer).
 
32
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
33
  """
34
 
35
+ @staticmethod
36
+ def prepare_extra_step_kwargs(
37
+ scheduler,
38
+ generator=None,
39
+ eta: float | None = None,
40
+ ):
41
+ kwargs = {}
42
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
43
+ if "generator" in step_params:
44
+ kwargs["generator"] = generator
45
+ if eta is not None and "eta" in step_params:
46
+ kwargs["eta"] = eta
47
+ return kwargs
48
 
49
+ model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def __init__(
52
  self,
53
  transformer,
54
+ scheduler,
55
  id2label: Optional[Dict[Union[int, str], str]] = None,
56
  ):
57
  super().__init__()
58
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
59
  self.register_modules(transformer=transformer, scheduler=scheduler)
 
60
  self._id2label = self._normalize_id2label(id2label)
61
  self.labels = self._build_label2id(self._id2label)
62
+ self._labels_loaded_from_model_index = bool(self._id2label)
63
+
64
+ def _ensure_labels_loaded(self) -> None:
65
+ if self._labels_loaded_from_model_index:
66
+ return
67
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
68
+ if loaded:
69
+ self._id2label = loaded
70
+ self.labels = self._build_label2id(self._id2label)
71
+ self._labels_loaded_from_model_index = True
72
 
73
  @staticmethod
74
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
 
77
  return {int(key): value for key, value in id2label.items()}
78
 
79
  @staticmethod
80
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
81
+ if not variant_path:
82
+ return {}
83
+ variant_dir = Path(variant_path).resolve()
84
+ model_index_path = variant_dir / "model_index.json"
85
  if not model_index_path.exists():
86
  return {}
87
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
 
102
 
103
  @property
104
  def id2label(self) -> Dict[int, str]:
105
+ self._ensure_labels_loaded()
106
  return self._id2label
107
 
108
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
109
+ self._ensure_labels_loaded()
 
 
 
 
 
 
110
  label2id = self.labels
111
  if not label2id:
112
+ raise ValueError(
113
+ "No English labels loaded. Ensure `id2label` exists in model_index.json."
114
+ )
115
 
116
  if isinstance(label, str):
117
  label = [label]
 
119
  missing = [item for item in label if item not in label2id]
120
  if missing:
121
  preview = ", ".join(list(label2id.keys())[:8])
122
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
 
 
123
  return [label2id[item] for item in label]
124
 
125
  def _normalize_class_labels(
 
154
  output_type: Optional[str] = "pil",
155
  return_dict: bool = True,
156
  ) -> Union[ImagePipelineOutput, Tuple]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if num_inference_steps < 2:
158
  raise ValueError("num_inference_steps must be >= 2.")
159
+ if output_type not in {"pil", "np", "pt"}:
160
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
161
 
162
  class_label_ids = self._normalize_class_labels(class_labels)
163
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
 
174
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
175
  )
176
  channels = int(self.transformer.config.in_channels)
177
+ null_class_val = int(
178
+ getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
179
+ )
180
 
181
  if guidance_scale is None:
182
  guidance_scale = 1.0
183
  if noise_scale is None:
184
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
185
 
186
+ latents = randn_tensor(
 
187
  shape=(batch_size, channels, height, width),
188
+ generator=generator,
189
+ device=self._execution_device,
190
+ dtype=self.transformer.dtype,
191
+ ) * noise_scale
 
 
192
 
193
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
194
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
 
200
  class_labels_input = class_labels_t
201
 
202
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
203
+ extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
204
  for t in self.progress_bar(self.scheduler.timesteps):
205
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
206
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
 
235
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
236
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
237
  model_output = -(x_pred - latents) / sigma
238
+ latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
239
 
240
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
241
  if output_type == "pt":
 
250
  if not return_dict:
251
  return (images,)
252
  return ImagePipelineOutput(images=images)
253
+
254
+ JiTPipelineOutput = ImagePipelineOutput
JiT-L-32/pipeline.py CHANGED
@@ -1,36 +1,24 @@
1
- # Copyright 2026 The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import importlib
16
  import json
17
- import sys
18
  from pathlib import Path
19
- from typing import Dict, List, Optional, Tuple, Union
20
 
21
  import torch
22
-
23
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
24
- from diffusers.schedulers import FlowMatchHeunDiscreteScheduler, KarrasDiffusionSchedulers
25
  from diffusers.utils.torch_utils import randn_tensor
26
 
27
-
28
  RECOMMENDED_NOISE_BY_SIZE = {
29
  256: 1.0,
30
  512: 2.0,
31
  }
32
 
33
-
34
  class JiTPipeline(DiffusionPipeline):
35
  r"""
36
  Pipeline for image generation using JiT (Just image Transformer).
@@ -44,100 +32,43 @@ class JiTPipeline(DiffusionPipeline):
44
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
45
  """
46
 
47
- model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
- @classmethod
50
- def from_pretrained(cls, pretrained_model_name_or_path=None, subfolder=None, **kwargs):
51
- """Load a self-contained variant folder locally or from the Hub.
52
-
53
- Examples:
54
- JiTPipeline.from_pretrained(".")
55
- JiTPipeline.from_pretrained("./JiT-H-32")
56
- DiffusionPipeline.from_pretrained("BiliSakura/JiT-diffusers", subfolder="JiT-H-32", trust_remote_code=True)
57
- """
58
- repo_root = Path(__file__).resolve().parent
59
-
60
- if pretrained_model_name_or_path in (None, "", "."):
61
- variant = repo_root
62
- elif (
63
- isinstance(pretrained_model_name_or_path, str)
64
- and "/" in pretrained_model_name_or_path
65
- and not Path(pretrained_model_name_or_path).exists()
66
- ):
67
- from huggingface_hub import snapshot_download
68
-
69
- hub_kwargs = dict(kwargs.pop("hub_kwargs", {}))
70
- if subfolder:
71
- hub_kwargs.setdefault("allow_patterns", [f"{subfolder}/**"])
72
- cache_dir = snapshot_download(pretrained_model_name_or_path, **hub_kwargs)
73
- variant = Path(cache_dir) / subfolder if subfolder else Path(cache_dir)
74
- else:
75
- variant = Path(pretrained_model_name_or_path)
76
- if not variant.is_absolute():
77
- candidate = (Path.cwd() / variant).resolve()
78
- variant = candidate if candidate.exists() else (repo_root / variant).resolve()
79
- if subfolder:
80
- variant = variant / subfolder
81
-
82
- id2label_override = kwargs.pop("id2label", None)
83
- model_kwargs = dict(kwargs)
84
- inserted: List[str] = []
85
-
86
- def _load_component(folder: str, module_name: str, class_name: str):
87
- comp_dir = variant / folder
88
- module_path = comp_dir / f"{module_name}.py"
89
- has_weights = (comp_dir / "config.json").exists() or (comp_dir / "scheduler_config.json").exists()
90
- if not module_path.exists() or not has_weights:
91
- return None
92
-
93
- comp_path = str(comp_dir)
94
- if comp_path not in sys.path:
95
- sys.path.insert(0, comp_path)
96
- inserted.append(comp_path)
97
-
98
- module = importlib.import_module(module_name)
99
- component_cls = getattr(module, class_name)
100
- return component_cls.from_pretrained(str(comp_dir), **model_kwargs)
101
-
102
- try:
103
- transformer = _load_component("transformer", "jit_transformer_2d", "JiTTransformer2DModel")
104
- try:
105
- scheduler = FlowMatchHeunDiscreteScheduler.from_pretrained(str(variant), subfolder="scheduler")
106
- except Exception:
107
- scheduler = FlowMatchHeunDiscreteScheduler(shift=4.0)
108
-
109
- if transformer is None:
110
- raise ValueError(f"No loadable transformer found under {variant}")
111
-
112
- variant_path = str(variant)
113
- model_index_path = variant / "model_index.json"
114
- id2label = id2label_override or cls._read_id2label_from_model_index(model_index_path)
115
-
116
- pipe = cls(
117
- transformer=transformer,
118
- scheduler=scheduler,
119
- id2label=id2label,
120
- )
121
- if variant_path and hasattr(pipe, "register_to_config"):
122
- pipe.register_to_config(_name_or_path=variant_path)
123
- return pipe
124
- finally:
125
- for comp_path in inserted:
126
- if comp_path in sys.path:
127
- sys.path.remove(comp_path)
128
 
129
  def __init__(
130
  self,
131
  transformer,
132
- scheduler: FlowMatchHeunDiscreteScheduler,
133
  id2label: Optional[Dict[Union[int, str], str]] = None,
134
  ):
135
  super().__init__()
136
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
137
  self.register_modules(transformer=transformer, scheduler=scheduler)
138
-
139
  self._id2label = self._normalize_id2label(id2label)
140
  self.labels = self._build_label2id(self._id2label)
 
 
 
 
 
 
 
 
 
 
141
 
142
  @staticmethod
143
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
@@ -146,7 +77,11 @@ class JiTPipeline(DiffusionPipeline):
146
  return {int(key): value for key, value in id2label.items()}
147
 
148
  @staticmethod
149
- def _read_id2label_from_model_index(model_index_path: Path) -> Dict[int, str]:
 
 
 
 
150
  if not model_index_path.exists():
151
  return {}
152
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
@@ -167,20 +102,16 @@ class JiTPipeline(DiffusionPipeline):
167
 
168
  @property
169
  def id2label(self) -> Dict[int, str]:
170
- """ImageNet class id to English label string (comma-separated synonyms)."""
171
  return self._id2label
172
 
173
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
174
- r"""
175
- Map ImageNet label strings to class ids.
176
-
177
- Args:
178
- label (`str` or `list[str]`):
179
- One or more English label strings. Each string must match a synonym in `id2label`.
180
- """
181
  label2id = self.labels
182
  if not label2id:
183
- raise ValueError("No English labels loaded. Ensure `id2label` exists in model_index.json.")
 
 
184
 
185
  if isinstance(label, str):
186
  label = [label]
@@ -188,9 +119,7 @@ class JiTPipeline(DiffusionPipeline):
188
  missing = [item for item in label if item not in label2id]
189
  if missing:
190
  preview = ", ".join(list(label2id.keys())[:8])
191
- raise ValueError(
192
- f"Unknown English label(s): {missing}. Example valid labels: {preview}, ..."
193
- )
194
  return [label2id[item] for item in label]
195
 
196
  def _normalize_class_labels(
@@ -225,33 +154,10 @@ class JiTPipeline(DiffusionPipeline):
225
  output_type: Optional[str] = "pil",
226
  return_dict: bool = True,
227
  ) -> Union[ImagePipelineOutput, Tuple]:
228
- r"""
229
- Generate class-conditional images.
230
-
231
- Args:
232
- class_labels (`int`, `str`, `list[int]`, or `list[str]`):
233
- ImageNet class indices or human-readable English label strings.
234
- guidance_scale (`float`, *optional*):
235
- Classifier-free guidance scale. CFG is active when `guidance_scale > 1.0`.
236
- guidance_interval_min (`float`, defaults to `0.1`):
237
- Lower bound of the CFG interval in flow time `t in [0, 1]`.
238
- guidance_interval_max (`float`, defaults to `1.0`):
239
- Upper bound of the CFG interval in flow time.
240
- noise_scale (`float`, *optional*):
241
- Initial Gaussian noise scale (`1.0` for 256px, `2.0` for 512px by default).
242
- t_eps (`float`, defaults to `5e-2`):
243
- Epsilon clamp for the `1 - t` denominator, matching JiT source defaults.
244
- generator (`torch.Generator`, *optional*):
245
- RNG for reproducibility.
246
- num_inference_steps (`int`, defaults to `50`):
247
- Number of solver steps (at least 2).
248
- output_type (`str`, *optional*, defaults to `"pil"`):
249
- `"pil"`, `"np"`, or `"pt"`.
250
- return_dict (`bool`, *optional*, defaults to `True`):
251
- Return [`ImagePipelineOutput`] if True.
252
- """
253
  if num_inference_steps < 2:
254
  raise ValueError("num_inference_steps must be >= 2.")
 
 
255
 
256
  class_label_ids = self._normalize_class_labels(class_labels)
257
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
@@ -268,22 +174,21 @@ class JiTPipeline(DiffusionPipeline):
268
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
269
  )
270
  channels = int(self.transformer.config.in_channels)
271
- null_class_val = int(self.transformer.config.num_classes)
 
 
272
 
273
  if guidance_scale is None:
274
  guidance_scale = 1.0
275
  if noise_scale is None:
276
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
277
 
278
- latents = (
279
- randn_tensor(
280
  shape=(batch_size, channels, height, width),
281
- generator=generator,
282
- device=self._execution_device,
283
- dtype=self.transformer.dtype,
284
- )
285
- * noise_scale
286
- )
287
 
288
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
289
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
@@ -295,6 +200,7 @@ class JiTPipeline(DiffusionPipeline):
295
  class_labels_input = class_labels_t
296
 
297
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
 
298
  for t in self.progress_bar(self.scheduler.timesteps):
299
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
300
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
@@ -329,7 +235,7 @@ class JiTPipeline(DiffusionPipeline):
329
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
330
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
331
  model_output = -(x_pred - latents) / sigma
332
- latents = self.scheduler.step(model_output, t, latents).prev_sample
333
 
334
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
335
  if output_type == "pt":
@@ -344,3 +250,5 @@ class JiTPipeline(DiffusionPipeline):
344
  if not return_dict:
345
  return (images,)
346
  return ImagePipelineOutput(images=images)
 
 
 
1
+ """Hub custom pipeline: JiTPipeline.
2
+ Load with native Hugging Face diffusers and trust_remote_code=True.
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import inspect
8
+
 
 
 
 
 
 
 
9
  import json
 
10
  from pathlib import Path
11
+ from typing import Dict, List, Optional, Tuple, Union, Any
12
 
13
  import torch
 
14
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
15
  from diffusers.utils.torch_utils import randn_tensor
16
 
 
17
  RECOMMENDED_NOISE_BY_SIZE = {
18
  256: 1.0,
19
  512: 2.0,
20
  }
21
 
 
22
  class JiTPipeline(DiffusionPipeline):
23
  r"""
24
  Pipeline for image generation using JiT (Just image Transformer).
 
32
  ImageNet class id to English label mapping. Values may contain comma-separated synonyms.
33
  """
34
 
35
+ @staticmethod
36
+ def prepare_extra_step_kwargs(
37
+ scheduler,
38
+ generator=None,
39
+ eta: float | None = None,
40
+ ):
41
+ kwargs = {}
42
+ step_params = set(inspect.signature(scheduler.step).parameters.keys())
43
+ if "generator" in step_params:
44
+ kwargs["generator"] = generator
45
+ if eta is not None and "eta" in step_params:
46
+ kwargs["eta"] = eta
47
+ return kwargs
48
 
49
+ model_cpu_offload_seq = "transformer"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def __init__(
52
  self,
53
  transformer,
54
+ scheduler,
55
  id2label: Optional[Dict[Union[int, str], str]] = None,
56
  ):
57
  super().__init__()
58
  scheduler = scheduler or FlowMatchHeunDiscreteScheduler(shift=4.0)
59
  self.register_modules(transformer=transformer, scheduler=scheduler)
 
60
  self._id2label = self._normalize_id2label(id2label)
61
  self.labels = self._build_label2id(self._id2label)
62
+ self._labels_loaded_from_model_index = bool(self._id2label)
63
+
64
+ def _ensure_labels_loaded(self) -> None:
65
+ if self._labels_loaded_from_model_index:
66
+ return
67
+ loaded = self._read_id2label_from_model_index(getattr(self.config, "_name_or_path", None))
68
+ if loaded:
69
+ self._id2label = loaded
70
+ self.labels = self._build_label2id(self._id2label)
71
+ self._labels_loaded_from_model_index = True
72
 
73
  @staticmethod
74
  def _normalize_id2label(id2label: Optional[Dict[Union[int, str], str]]) -> Dict[int, str]:
 
77
  return {int(key): value for key, value in id2label.items()}
78
 
79
  @staticmethod
80
+ def _read_id2label_from_model_index(variant_path: Optional[str]) -> Dict[int, str]:
81
+ if not variant_path:
82
+ return {}
83
+ variant_dir = Path(variant_path).resolve()
84
+ model_index_path = variant_dir / "model_index.json"
85
  if not model_index_path.exists():
86
  return {}
87
  raw = json.loads(model_index_path.read_text(encoding="utf-8"))
 
102
 
103
  @property
104
  def id2label(self) -> Dict[int, str]:
105
+ self._ensure_labels_loaded()
106
  return self._id2label
107
 
108
  def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
109
+ self._ensure_labels_loaded()
 
 
 
 
 
 
110
  label2id = self.labels
111
  if not label2id:
112
+ raise ValueError(
113
+ "No English labels loaded. Ensure `id2label` exists in model_index.json."
114
+ )
115
 
116
  if isinstance(label, str):
117
  label = [label]
 
119
  missing = [item for item in label if item not in label2id]
120
  if missing:
121
  preview = ", ".join(list(label2id.keys())[:8])
122
+ raise ValueError(f"Unknown English label(s): {missing}. Example valid labels: {preview}, ...")
 
 
123
  return [label2id[item] for item in label]
124
 
125
  def _normalize_class_labels(
 
154
  output_type: Optional[str] = "pil",
155
  return_dict: bool = True,
156
  ) -> Union[ImagePipelineOutput, Tuple]:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  if num_inference_steps < 2:
158
  raise ValueError("num_inference_steps must be >= 2.")
159
+ if output_type not in {"pil", "np", "pt"}:
160
+ raise ValueError("output_type must be one of: 'pil', 'np', 'pt'.")
161
 
162
  class_label_ids = self._normalize_class_labels(class_labels)
163
  do_classifier_free_guidance = guidance_scale is not None and guidance_scale > 1.0
 
174
  f"height and width must be divisible by patch_size={patch_size}. Got {(height, width)}."
175
  )
176
  channels = int(self.transformer.config.in_channels)
177
+ null_class_val = int(
178
+ getattr(self.transformer.config, "num_classes", getattr(self.transformer.config, "num_class_embeds", 1000))
179
+ )
180
 
181
  if guidance_scale is None:
182
  guidance_scale = 1.0
183
  if noise_scale is None:
184
  noise_scale = RECOMMENDED_NOISE_BY_SIZE.get(max(height, width), 1.0)
185
 
186
+ latents = randn_tensor(
 
187
  shape=(batch_size, channels, height, width),
188
+ generator=generator,
189
+ device=self._execution_device,
190
+ dtype=self.transformer.dtype,
191
+ ) * noise_scale
 
 
192
 
193
  class_labels_t = torch.tensor(class_label_ids, device=self._execution_device, dtype=torch.long).reshape(-1)
194
  class_labels_t = class_labels_t.clamp(0, null_class_val - 1)
 
200
  class_labels_input = class_labels_t
201
 
202
  self.scheduler.set_timesteps(num_inference_steps, device=self._execution_device)
203
+ extra_step_kwargs = self.prepare_extra_step_kwargs(self.scheduler, generator=generator)
204
  for t in self.progress_bar(self.scheduler.timesteps):
205
  step_index = self.scheduler.index_for_timestep(t, self.scheduler.timesteps)
206
  sigma = self.scheduler.sigmas[step_index].to(device=latents.device, dtype=latents.dtype)
 
235
  sigma = sigma.reshape(*([1] * (latents.ndim - 1)))
236
  # JiT predicts x0; scheduler integrates in sigma space: dz/dsigma = -(x0 - z) / sigma.
237
  model_output = -(x_pred - latents) / sigma
238
+ latents = self.scheduler.step(model_output, t, latents, **extra_step_kwargs).prev_sample
239
 
240
  images_pt = ((latents.float().clamp(-1, 1) + 1.0) / 2.0).cpu()
241
  if output_type == "pt":
 
250
  if not return_dict:
251
  return (images,)
252
  return ImagePipelineOutput(images=images)
253
+
254
+ JiTPipelineOutput = ImagePipelineOutput